import json
import time

import boto3
from botocore.exceptions import ClientError, WaiterError

# A script that provisions the AWS resources needed to run a simple Nginx task on AWS Fargate with Boto3, runs it, and then cleans everything up.

# --- Configuration ---
# Region every service client below is bound to.
REGION = "us-east-1"

# Names of the resources this script creates and later deletes.
ECS_CLUSTER_NAME = "MyBoto3FargateCluster"
TASK_DEFINITION_FAMILY = "MyBoto3FargateTask"
TASK_EXECUTION_ROLE_NAME = "MyBoto3FargateTaskExecutionRole"

# The single container the task definition runs, and the port it exposes.
CONTAINER_NAME = "nginx-container"
CONTAINER_IMAGE = "nginx:latest"
CONTAINER_PORT = 80

# Module-level service clients shared by every function below (created in
# this order: EC2, ECS, IAM).
ec2_client, ecs_client, iam_client = (
    boto3.client(service_name, region_name=REGION)
    for service_name in ('ec2', 'ecs', 'iam')
)

def get_default_vpc_and_subnet():
    """Look up the account's default VPC and one of its default subnets.

    Returns:
        A ``(vpc_id, subnet_id)`` tuple. The subnet is the first one EC2
        returns among the default VPC's ``default-for-az`` subnets.

    Raises:
        RuntimeError: if there is no default VPC, or it has no default subnet.
        ClientError: if an EC2 call fails (logged, then re-raised).
    """
    print("--- Getting Default VPC and Subnet ID ---")
    try:
        vpcs = ec2_client.describe_vpcs(
            Filters=[{'Name': 'is-default', 'Values': ['true']}]
        )['Vpcs']
        if not vpcs:
            # RuntimeError (not bare Exception) so callers can catch it narrowly;
            # it is still an Exception subclass, so existing handlers keep working.
            raise RuntimeError("No default VPC found.")
        vpc_id = vpcs[0]['VpcId']
        print(f"Default VPC ID: {vpc_id}")

        subnets = ec2_client.describe_subnets(
            Filters=[
                {'Name': 'vpc-id', 'Values': [vpc_id]},
                {'Name': 'default-for-az', 'Values': ['true']},
            ]
        )['Subnets']
        if not subnets:
            raise RuntimeError("No default subnet found in the default VPC.")
        subnet_id = subnets[0]['SubnetId']
        print(f"Default Subnet ID: {subnet_id}")
        return vpc_id, subnet_id
    except ClientError as e:
        print(f"Error getting VPC/Subnet info: {e}")
        raise

def create_security_group(vpc_id, group_name="MyBoto3FargateSG"):
    """Create, or reuse, the security group for the Fargate task.

    The group lives in ``vpc_id`` and allows inbound TCP on CONTAINER_PORT from
    0.0.0.0/0. If a group named ``group_name`` already exists in that VPC, it is
    reused. The ingress rule is applied on both paths, so a group left behind by
    an interrupted earlier run still ends up open on CONTAINER_PORT.

    Args:
        vpc_id: ID of the VPC to create the group in (and to search on reuse).
        group_name: Security group name; defaults to the name this script has
            always used.

    Returns:
        The security group ID.

    Raises:
        ClientError: if creating, looking up, or authorizing the group fails
            (logged, then re-raised). A group created by this call is deleted
            before re-raising, because the caller never receives its ID and so
            could not clean it up.
        RuntimeError: if EC2 reports the name as a duplicate but no group with
            that name is found in ``vpc_id``.
    """
    print("Creating Security Group for Fargate task...")
    try:
        sg_id = ec2_client.create_security_group(
            GroupName=group_name,
            Description="Allow HTTP for Fargate task",
            VpcId=vpc_id,
        )['GroupId']
    except ClientError as e:
        if e.response['Error']['Code'] != 'InvalidGroup.Duplicate':
            print(f"Error creating security group: {e}")
            raise
        print(f"Security Group '{group_name}' already exists. Fetching ID.")
        sg_id = _find_security_group_id(vpc_id, group_name)
        created = False
    else:
        print(f"Security Group created with ID: {sg_id}")
        created = True

    try:
        _authorize_container_port(sg_id)
    except ClientError as e:
        print(f"Error creating security group: {e}")
        if created:
            _delete_security_group_quietly(sg_id)
        raise
    return sg_id


def _find_security_group_id(vpc_id, group_name):
    """Return the ID of the security group named ``group_name`` in ``vpc_id``."""
    # Filter on name *and* VPC: the GroupNames parameter only resolves groups in
    # the default VPC, and group names are only unique within one VPC.
    groups = ec2_client.describe_security_groups(
        Filters=[
            {'Name': 'group-name', 'Values': [group_name]},
            {'Name': 'vpc-id', 'Values': [vpc_id]},
        ]
    )['SecurityGroups']
    if not groups:
        raise RuntimeError(f"Security Group '{group_name}' not found in VPC {vpc_id}.")
    return groups[0]['GroupId']


def _authorize_container_port(sg_id):
    """Allow inbound TCP on CONTAINER_PORT from anywhere; an identical existing rule is accepted."""
    try:
        ec2_client.authorize_security_group_ingress(
            GroupId=sg_id,
            IpPermissions=[
                {
                    'IpProtocol': 'tcp',
                    'FromPort': CONTAINER_PORT,
                    'ToPort': CONTAINER_PORT,
                    'IpRanges': [{'CidrIp': '0.0.0.0/0'}],
                }
            ],
        )
    except ClientError as e:
        # A reused group may already carry this exact rule; that is the desired state.
        if e.response['Error']['Code'] != 'InvalidPermission.Duplicate':
            raise
        print("Inbound HTTP rule already present on Security Group.")
        return
    print("Inbound HTTP rule added to Security Group.")


def _delete_security_group_quietly(sg_id):
    """Best-effort delete of a group this script just created; logs instead of raising."""
    try:
        ec2_client.delete_security_group(GroupId=sg_id)
    except ClientError as e:
        print(f"Could not delete Security Group '{sg_id}': {e}")

def create_ecs_cluster():
    """Create the ECS cluster named by ECS_CLUSTER_NAME.

    A ``ClusterAlreadyExistsException`` error is reported and tolerated; any
    other ClientError is printed and re-raised.

    NOTE(review): ECS CreateCluster appears to be idempotent for an existing
    ACTIVE cluster, so the "already exists" branch may never fire — confirm
    against the ECS API reference before relying on it.
    """
    print(f"\n--- Creating ECS Cluster: {ECS_CLUSTER_NAME} ---")
    try:
        ecs_client.create_cluster(clusterName=ECS_CLUSTER_NAME)
    except ClientError as err:
        already_exists = err.response['Error']['Code'] == 'ClusterAlreadyExistsException'
        if not already_exists:
            print(f"Error creating ECS cluster: {err}")
            raise
        print(f"ECS Cluster '{ECS_CLUSTER_NAME}' already exists. Skipping creation.")
    else:
        print("ECS Cluster created.")

def create_task_execution_role():
    """Create, or reuse, the IAM role used as the task's execution role.

    The role's trust policy allows ``ecs-tasks.amazonaws.com`` to assume it.
    The AWS-managed ``AmazonECSTaskExecutionRolePolicy`` is attached on both
    the create and the reuse path. A role left behind by an earlier run that
    failed between CreateRole and AttachRolePolicy is therefore repaired,
    rather than silently reused without the policy.

    Returns:
        The role ARN.

    Raises:
        ClientError: if any IAM call fails (logged, then re-raised). A role
            created by this call is deleted before re-raising an attach
            failure, because the caller never receives its ARN and so could
            not clean it up.
    """
    print(f"\n--- Creating Task Execution Role: {TASK_EXECUTION_ROLE_NAME} ---")
    policy_arn = 'arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy'
    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "ecs-tasks.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }

    try:
        role_arn = iam_client.create_role(
            RoleName=TASK_EXECUTION_ROLE_NAME,
            AssumeRolePolicyDocument=json.dumps(trust_policy),
            Description="Role for ECS Fargate task execution",
        )['Role']['Arn']
        created = True
    except ClientError as e:
        if e.response['Error']['Code'] != 'EntityAlreadyExists':
            print(f"Error creating task execution role: {e}")
            raise
        print(f"IAM role '{TASK_EXECUTION_ROLE_NAME}' already exists. Fetching ARN.")
        role_arn = iam_client.get_role(RoleName=TASK_EXECUTION_ROLE_NAME)['Role']['Arn']
        created = False

    try:
        # Attach on both paths; attaching an already-attached managed policy is
        # expected to be a no-op.
        iam_client.attach_role_policy(RoleName=TASK_EXECUTION_ROLE_NAME, PolicyArn=policy_arn)
    except ClientError as e:
        print(f"Error creating task execution role: {e}")
        if created:
            # No policy is attached (the attach just failed), so the role can be deleted directly.
            try:
                iam_client.delete_role(RoleName=TASK_EXECUTION_ROLE_NAME)
            except ClientError as delete_error:
                print(f"Could not delete IAM role '{TASK_EXECUTION_ROLE_NAME}': {delete_error}")
        raise

    if created:
        print("Task Execution Role created and policy attached.")
        print("Waiting for IAM role to propagate...")
        # A brand-new role may not be usable immediately; give it time before ECS needs it.
        time.sleep(10)
    return role_arn

def register_task_definition(task_execution_role_arn):
    """Register a Fargate task definition that runs one Nginx container.

    The definition uses ``awsvpc`` networking, 256 CPU units and 512 MiB of
    memory. It maps CONTAINER_PORT over TCP and passes
    ``task_execution_role_arn`` through as the task's ``executionRoleArn``.

    Returns:
        The ARN of the newly registered task definition revision.

    Raises:
        ClientError: if ECS rejects the registration (logged, then re-raised).
    """
    print(f"\n--- Registering Task Definition: {TASK_DEFINITION_FAMILY} ---")
    nginx_container = {
        'name': CONTAINER_NAME,
        'image': CONTAINER_IMAGE,
        'portMappings': [{'containerPort': CONTAINER_PORT, 'protocol': 'tcp'}],
    }
    try:
        registered = ecs_client.register_task_definition(
            family=TASK_DEFINITION_FAMILY,
            networkMode='awsvpc',
            cpu='256',
            memory='512',
            executionRoleArn=task_execution_role_arn,
            containerDefinitions=[nginx_container],
            requiresCompatibilities=['FARGATE'],
        )
    except ClientError as err:
        print(f"Error registering task definition: {err}")
        raise

    definition_arn = registered['taskDefinition']['taskDefinitionArn']
    print(f"Task Definition registered: {definition_arn}")
    return definition_arn

def run_ecs_task(subnet_id, sg_id):
    """Start one copy of the task on Fargate and wait until it is RUNNING.

    The task runs from the TASK_DEFINITION_FAMILY family in ECS_CLUSTER_NAME.
    It is placed in ``subnet_id`` with security group ``sg_id`` and a public IP.

    Returns:
        The ARN of the started task.

    Raises:
        ClientError: if RunTask is rejected (logged, then re-raised).
        RuntimeError: if RunTask returns no task. ECS reports per-task problems
            in ``failures`` instead of raising.
        ClientError / WaiterError: if the task never reaches RUNNING. The task is
            stopped before re-raising, because the caller never receives its ARN
            and so could not stop it during cleanup.
    """
    print("\n--- Running ECS Task on Fargate ---")
    try:
        response = ecs_client.run_task(
            cluster=ECS_CLUSTER_NAME,
            taskDefinition=TASK_DEFINITION_FAMILY,
            launchType='FARGATE',
            networkConfiguration={
                'awsvpcConfiguration': {
                    'subnets': [subnet_id],
                    'securityGroups': [sg_id],
                    'assignPublicIp': 'ENABLED',
                }
            },
            count=1,
            platformVersion='LATEST',
        )
    except ClientError as e:
        print(f"Error running ECS task: {e}")
        raise

    tasks = response.get('tasks', [])
    if not tasks:
        message = f"ECS did not start the task: {response.get('failures', [])}"
        print(f"Error running ECS task: {message}")
        raise RuntimeError(message)

    task_arn = tasks[0]['taskArn']
    print(f"Fargate task started with ARN: {task_arn}")
    print("Waiting for task to be running...")
    try:
        ecs_client.get_waiter('tasks_running').wait(cluster=ECS_CLUSTER_NAME, tasks=[task_arn])
    except (ClientError, WaiterError) as e:
        print(f"Error running ECS task: {e}")
        _stop_unconfirmed_task(task_arn)
        raise
    print("Fargate task is running.")
    return task_arn


def _stop_unconfirmed_task(task_arn):
    """Best-effort stop of a task that never reached RUNNING; logs instead of raising."""
    print(f"Stopping task '{task_arn}' because it did not reach RUNNING...")
    try:
        ecs_client.stop_task(
            cluster=ECS_CLUSTER_NAME, task=task_arn, reason="Did not reach RUNNING"
        )
        ecs_client.get_waiter('tasks_stopped').wait(cluster=ECS_CLUSTER_NAME, tasks=[task_arn])
    except (ClientError, WaiterError) as stop_error:
        print(f"Could not stop task '{task_arn}': {stop_error}")

def get_task_public_ip(task_arn):
    """Return the public IPv4 address of a running Fargate task, or None.

    The task's ENI ID is read from its ``ElasticNetworkInterface`` attachment,
    and the address is taken from that ENI's ``Association.PublicIp``.

    Returns:
        The public IP string. Returns None if the task is not returned, has no
        ENI attachment, or its ENI has no public IP association (the caller
        treats None as "could not retrieve").

    Raises:
        ClientError: if an ECS or EC2 call fails (logged, then re-raised).
    """
    print("\n--- Getting Fargate task public IP ---")
    try:
        response = ecs_client.describe_tasks(cluster=ECS_CLUSTER_NAME, tasks=[task_arn])
        tasks = response.get('tasks', [])
        if not tasks:
            return None

        eni_id = _find_eni_id(tasks[0])
        if eni_id is None:
            return None

        interfaces = ec2_client.describe_network_interfaces(
            NetworkInterfaceIds=[eni_id]
        ).get('NetworkInterfaces', [])
        # 'Association' is absent when the ENI has no public IP; don't KeyError on it.
        association = interfaces[0].get('Association', {}) if interfaces else {}
        public_ip = association.get('PublicIp')
        if public_ip is None:
            return None

        print(f"Public IP Address: {public_ip}")
        return public_ip
    except ClientError as e:
        print(f"Error getting task public IP: {e}")
        raise


def _find_eni_id(task):
    """Return the first ``networkInterfaceId`` in a described task's ENI attachments, or None."""
    for attachment in task.get('attachments', []):
        if attachment.get('type') != 'ElasticNetworkInterface':
            continue
        for detail in attachment.get('details', []):
            if detail.get('name') == 'networkInterfaceId':
                return detail.get('value')
    return None

def _error_code(err):
    """Return the AWS error code carried by a botocore ClientError, or None."""
    return err.response.get('Error', {}).get('Code')


def _cleanup_stop_task(task_arn):
    """Stop the task and wait for STOPPED; every failure is logged, never raised."""
    print(f"Stopping Fargate task '{task_arn}'...")
    try:
        ecs_client.stop_task(cluster=ECS_CLUSTER_NAME, task=task_arn)
        ecs_client.get_waiter('tasks_stopped').wait(cluster=ECS_CLUSTER_NAME, tasks=[task_arn])
        print("Fargate task stopped.")
    except ClientError as e:
        if _error_code(e) in ('ClusterNotFoundException', 'TaskNotFoundException'):
            print(f"Task '{task_arn}' not found, skipping stop.")
        else:
            print(f"Error stopping task: {e}")
    except WaiterError as e:
        # Previously this escaped and aborted every remaining cleanup step.
        print(f"Error waiting for task to stop: {e}")


def _cleanup_deregister_task_definition(task_definition_arn):
    """Deregister one task definition revision; failures are logged, never raised."""
    print(f"Deregistering Task Definition '{task_definition_arn}'...")
    try:
        ecs_client.deregister_task_definition(taskDefinition=task_definition_arn)
        print("Task Definition deregistered.")
    except ClientError as e:
        if _error_code(e) == 'ClientException':
            # ClientException is ECS's generic client-side error, not specifically
            # "not found", so include the real message.
            print(f"Task Definition '{task_definition_arn}' could not be deregistered, skipping: {e}")
        else:
            print(f"Error deregistering task definition: {e}")


def _cleanup_delete_cluster():
    """Delete ECS_CLUSTER_NAME; failures are logged, never raised."""
    print(f"Deleting ECS Cluster '{ECS_CLUSTER_NAME}'...")
    try:
        ecs_client.delete_cluster(cluster=ECS_CLUSTER_NAME)
        print("ECS Cluster deleted.")
    except ClientError as e:
        if _error_code(e) == 'ClusterNotFoundException':
            print(f"ECS Cluster '{ECS_CLUSTER_NAME}' not found, skipping deletion.")
        else:
            print(f"Error deleting ECS cluster: {e}")


def _cleanup_delete_role():
    """Detach the execution policy, then delete the role; failures are logged, never raised."""
    print(f"Detaching policy from IAM Role '{TASK_EXECUTION_ROLE_NAME}'...")
    try:
        iam_client.detach_role_policy(
            RoleName=TASK_EXECUTION_ROLE_NAME,
            PolicyArn='arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy',
        )
    except ClientError as e:
        # NoSuchEntity here can mean "policy not attached"; still try to delete
        # the role below instead of skipping it (which used to leak the role).
        if _error_code(e) != 'NoSuchEntity':
            print(f"Error detaching policy from IAM role: {e}")

    print(f"Deleting IAM Role '{TASK_EXECUTION_ROLE_NAME}'...")
    try:
        iam_client.delete_role(RoleName=TASK_EXECUTION_ROLE_NAME)
        print("IAM Role deleted.")
    except ClientError as e:
        if _error_code(e) == 'NoSuchEntity':
            print(f"IAM Role '{TASK_EXECUTION_ROLE_NAME}' not found, skipping deletion.")
        else:
            print(f"Error deleting IAM role: {e}")


def _cleanup_delete_security_group(sg_id, attempts=6, delay_seconds=10):
    """Delete the security group, retrying while it is still in use; never raises.

    ``DependencyViolation`` (the group is still referenced) is retried up to
    ``attempts`` times, sleeping ``delay_seconds`` between tries. Any other
    error, or running out of attempts, is logged.
    """
    print(f"Deleting Security Group '{sg_id}'...")
    for attempt in range(1, attempts + 1):
        try:
            ec2_client.delete_security_group(GroupId=sg_id)
            print("Security Group deleted.")
            return
        except ClientError as e:
            code = _error_code(e)
            if code == 'InvalidGroup.NotFound':
                print(f"Security Group '{sg_id}' not found, skipping deletion.")
                return
            if code != 'DependencyViolation' or attempt == attempts:
                print(f"Error deleting security group: {e}")
                return
            print(f"Security Group '{sg_id}' is still in use. Retrying deletion after a short delay.")
            time.sleep(delay_seconds)


def cleanup_resources(task_arn, task_definition_arn, task_execution_role_arn, sg_id):
    """Tear down everything ``main`` created, best-effort and in dependency order.

    Each argument may be None/empty, meaning that resource was never created;
    the matching step is then skipped. The cluster deletion is always
    attempted. Every step logs its own failures and never raises, so one
    failing step cannot prevent the later ones from running.

    Args:
        task_arn: ARN of the Fargate task to stop.
        task_definition_arn: ARN of the task definition revision to deregister.
        task_execution_role_arn: Truthy if the execution role should be deleted
            (the role is addressed by TASK_EXECUTION_ROLE_NAME).
        sg_id: ID of the security group to delete.
    """
    print("\n--- Cleaning up resources ---")
    if task_arn:
        _cleanup_stop_task(task_arn)
    if task_definition_arn:
        _cleanup_deregister_task_definition(task_definition_arn)
    _cleanup_delete_cluster()
    if task_execution_role_arn:
        _cleanup_delete_role()
    if sg_id:
        # Last, so the task's network interface has had the most time to go away.
        _cleanup_delete_security_group(sg_id)

def main():
    """Provision the demo resources, run one Fargate task, then tear everything down.

    Resources are created in order: security group, cluster, execution role,
    task definition, task. The ``finally`` block always passes whatever was
    created before a failure to ``cleanup_resources``. The task keeps running
    until the user presses Enter.

    Raises:
        SystemExit: with status 1 (after cleanup) if any step raised an
            Exception, so shell and CI callers can detect the failure.
    """
    sg_id = None
    task_execution_role_arn = None
    task_definition_arn = None
    task_arn = None
    failed = False

    try:
        vpc_id, subnet_id = get_default_vpc_and_subnet()
        sg_id = create_security_group(vpc_id)
        create_ecs_cluster()
        task_execution_role_arn = create_task_execution_role()
        task_definition_arn = register_task_definition(task_execution_role_arn)
        task_arn = run_ecs_task(subnet_id, sg_id)

        public_ip = get_task_public_ip(task_arn)
        if public_ip:
            print("\n--- Fargate Task Running Successfully! ---")
            print(f"Task ARN: {task_arn}")
            print(f"Public IP Address: {public_ip}")
            print(f"You can access the Nginx server at: http://{public_ip}")
        else:
            print("Could not retrieve public IP for the task.")

        input("Press Enter to stop the task and clean up resources...")

    except ClientError as e:
        print(f"An AWS client error occurred: {e}")
        failed = True
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        failed = True
    finally:
        cleanup_resources(task_arn, task_definition_arn, task_execution_role_arn, sg_id)
        print("\n--- Fargate task demonstration and cleanup complete ---")

    if failed:
        # Previously the error was swallowed and the script exited 0.
        raise SystemExit(1)

if __name__ == "__main__":
    # Run the provisioning/cleanup demo only when executed as a script, not on import.
    main()
