import time

import boto3
from botocore.exceptions import ClientError, WaiterError

# A script to delete the multi-tier VPC resources created by create_multi_tier_vpc.py.

# --- Configuration ---
# AWS region that holds the VPC to tear down.
REGION = "us-east-1"
# Must match the name used in create_multi_tier_vpc.py
VPC_NAME = "MyWebAppVPC-Boto3"

# Module-level EC2 handles bound to REGION; the functions below query AWS
# through ec2_client. NOTE(review): ec2_resource appears unused in this file.
ec2_client = boto3.client("ec2", region_name=REGION)
ec2_resource = boto3.resource("ec2", region_name=REGION)

# Per resource type: (describe_* client method, list key in its response,
# ID field of each item, filter name that restricts the search to one VPC).
_TAG_LOOKUPS = {
    'vpcs': ('describe_vpcs', 'Vpcs', 'VpcId', 'vpc-id'),
    'subnets': ('describe_subnets', 'Subnets', 'SubnetId', 'vpc-id'),
    # Internet gateways expose their VPC only through their attachments, so
    # DescribeInternetGateways accepts 'attachment.vpc-id', not 'vpc-id'.
    'internet-gateways': (
        'describe_internet_gateways', 'InternetGateways', 'InternetGatewayId', 'attachment.vpc-id',
    ),
    'nat-gateways': ('describe_nat_gateways', 'NatGateways', 'NatGatewayId', 'vpc-id'),
    'route-tables': ('describe_route_tables', 'RouteTables', 'RouteTableId', 'vpc-id'),
    'security-groups': ('describe_security_groups', 'SecurityGroups', 'GroupId', 'vpc-id'),
}

# Deleted NAT gateways remain visible (state 'deleted') for a while after
# removal; only these states are considered, so a re-run does not pick one up.
_LIVE_NAT_GATEWAY_STATES = ('pending', 'available', 'deleting', 'failed')


def get_resource_id_by_tag(resource_type, tag_value, vpc_id=None, *, client=None):
    """Return the ID of the first resource whose ``Name`` tag equals ``tag_value``.

    Args:
        resource_type: One of 'vpcs', 'subnets', 'internet-gateways',
            'nat-gateways', 'route-tables', 'security-groups'.
        tag_value: Value the resource's ``Name`` tag must have.
        vpc_id: Optional VPC ID restricting the search to that VPC.
        client: EC2 client to query; defaults to the module-level ``ec2_client``.

    Returns:
        The matching resource ID, or None when nothing matches or
        ``resource_type`` is not one of the supported types (in which case no
        API call is made).
    """
    lookup = _TAG_LOOKUPS.get(resource_type)
    if lookup is None:
        return None
    method_name, list_key, id_key, vpc_filter_name = lookup

    filters = [{'Name': 'tag:Name', 'Values': [tag_value]}]
    if vpc_id:
        filters.append({'Name': vpc_filter_name, 'Values': [vpc_id]})

    if client is None:
        client = ec2_client
    describe = getattr(client, method_name)
    if resource_type == 'nat-gateways':
        filters.append({'Name': 'state', 'Values': list(_LIVE_NAT_GATEWAY_STATES)})
        # DescribeNatGateways names its filter argument 'Filter' (singular),
        # unlike the other describe_* calls used here.
        response = describe(Filter=filters)
    else:
        response = describe(Filters=filters)

    items = response[list_key]
    return items[0][id_key] if items else None

def get_nat_gateway_eip_allocation_id(nat_gw_id, *, client=None):
    """Return the Elastic IP allocation ID attached to NAT gateway ``nat_gw_id``.

    Args:
        nat_gw_id: ID of the NAT gateway to inspect.
        client: EC2 client to query; defaults to the module-level ``ec2_client``.

    Returns:
        The first ``AllocationId`` found among the gateway's addresses, or None
        when the gateway cannot be described, has no addresses, or none of its
        addresses carries an allocation ID (AWS only sets it for public NAT
        gateways).
    """
    if client is None:
        client = ec2_client
    try:
        response = client.describe_nat_gateways(NatGatewayIds=[nat_gw_id])
    except ClientError:
        return None
    for gateway in response.get('NatGateways', []):
        for address in gateway.get('NatGatewayAddresses', []):
            # .get(): a missing AllocationId used to raise an uncaught KeyError.
            allocation_id = address.get('AllocationId')
            if allocation_id:
                return allocation_id
    return None

# Pause between attempts while AWS answers DependencyViolation, and how many
# extra attempts to make before reporting the error and moving on.
_RETRY_DELAY_SECONDS = 10
_DEPENDENCY_RETRIES = 5
# Upper bound for the manual NAT-gateway polling fallback.
_NAT_DELETE_TIMEOUT_SECONDS = 600


def _try_ec2(action, success_msg, error_msg, *, ignore_code=None, dependency_retries=0, **params):
    """Call ``action(**params)``, reporting the outcome instead of raising ClientError.

    Args:
        action: Bound EC2 client method to invoke.
        success_msg: Printed when the call succeeds.
        error_msg: Prefix printed together with an unexpected ClientError.
        ignore_code: AWS error code meaning "already gone"; skipped silently.
        dependency_retries: Extra attempts, _RETRY_DELAY_SECONDS apart, made
            while AWS answers DependencyViolation (something still depends on
            the resource and may be in the middle of being torn down).
        **params: Keyword arguments forwarded to ``action``.

    Returns:
        True if the call succeeded, False otherwise.
    """
    for attempt in range(dependency_retries + 1):
        try:
            action(**params)
        except ClientError as e:
            code = e.response['Error']['Code']
            if code == ignore_code:
                return False
            if code == 'DependencyViolation' and attempt < dependency_retries:
                print(f"{error_msg}: still in use, retrying in {_RETRY_DELAY_SECONDS}s.")
                time.sleep(_RETRY_DELAY_SECONDS)
                continue
            print(f"{error_msg}: {e}")
            return False
        print(success_msg)
        return True
    return False


def _wait_for_nat_gateway_deletion(nat_gw_id):
    """Block until NAT gateway ``nat_gw_id`` reports the 'deleted' state.

    Uses botocore's ``nat_gateway_deleted`` waiter when it exists and polls
    describe_nat_gateways otherwise. Failures are printed, not raised, so the
    rest of the cleanup still runs.
    """
    try:
        waiter = ec2_client.get_waiter('nat_gateway_deleted')
    except ValueError:
        # get_waiter raises ValueError for a waiter this botocore release lacks.
        waiter = None

    if waiter is not None:
        try:
            waiter.wait(NatGatewayIds=[nat_gw_id])
        except WaiterError as e:
            print(f"Gave up waiting for NAT GW {nat_gw_id} to be deleted: {e}")
        return

    deadline = time.monotonic() + _NAT_DELETE_TIMEOUT_SECONDS
    while time.monotonic() < deadline:
        try:
            gateways = ec2_client.describe_nat_gateways(NatGatewayIds=[nat_gw_id])['NatGateways']
        except ClientError as e:
            if e.response['Error']['Code'] != 'NatGatewayNotFound':
                print(f"Error checking NAT GW {nat_gw_id}: {e}")
            return
        if not gateways or gateways[0]['State'] == 'deleted':
            return
        time.sleep(_RETRY_DELAY_SECONDS)
    print(f"Gave up waiting for NAT GW {nat_gw_id} to be deleted.")


def _revoke_security_group_rules(sg_id):
    """Remove every ingress and egress rule from security group ``sg_id``.

    A group referenced by another group's rule cannot be deleted
    (DependencyViolation) for as long as the reference exists, so waiting
    never helps; stripping the rules first removes such references.
    """
    try:
        group = ec2_client.describe_security_groups(GroupIds=[sg_id])['SecurityGroups'][0]
    except ClientError as e:
        if e.response['Error']['Code'] != 'InvalidGroup.NotFound':
            print(f"Error describing SG {sg_id}: {e}")
        return
    if group.get('IpPermissions'):
        _try_ec2(
            ec2_client.revoke_security_group_ingress,
            f"Revoked ingress rules of SG: {sg_id}",
            f"Error revoking ingress rules of SG {sg_id}",
            GroupId=sg_id,
            IpPermissions=group['IpPermissions'],
        )
    if group.get('IpPermissionsEgress'):
        _try_ec2(
            ec2_client.revoke_security_group_egress,
            f"Revoked egress rules of SG: {sg_id}",
            f"Error revoking egress rules of SG {sg_id}",
            GroupId=sg_id,
            IpPermissions=group['IpPermissionsEgress'],
        )


def _delete_security_groups(sg_ids):
    """Strip the rules from, then delete, every non-None group in ``sg_ids``."""
    print("Deleting Security Groups...")
    existing = [sg_id for sg_id in sg_ids if sg_id]
    # Revoke everything before deleting anything, so no remaining group still
    # references one that is about to be deleted.
    for sg_id in existing:
        _revoke_security_group_rules(sg_id)
    for sg_id in existing:
        _try_ec2(
            ec2_client.delete_security_group,
            f"Deleted SG: {sg_id}",
            f"Error deleting SG {sg_id}",
            ignore_code='InvalidGroup.NotFound',
            dependency_retries=_DEPENDENCY_RETRIES,
            GroupId=sg_id,
        )


def _disassociate_route_table(rt_id, label):
    """Remove every explicit (non-main) association of route table ``rt_id``.

    Any association left in place makes delete_route_table fail, so this is
    not limited to the subnets that were found by tag.
    """
    try:
        response = ec2_client.describe_route_tables(RouteTableIds=[rt_id])
    except ClientError as e:
        print(f"Error disassociating {label} RT: {e}")
        return
    for assoc in response['RouteTables'][0].get('Associations', []):
        if assoc.get('Main'):
            continue  # a main association can only be replaced, not removed
        target = assoc.get('SubnetId') or assoc.get('GatewayId')
        _try_ec2(
            ec2_client.disassociate_route_table,
            f"Disassociated {label} RT from {target}",
            f"Error disassociating {label} RT from {target}",
            ignore_code='InvalidAssociationID.NotFound',
            AssociationId=assoc['RouteTableAssociationId'],
        )


def _delete_route_tables(route_table_ids):
    """Disassociate, strip the default route from, and delete each route table.

    Args:
        route_table_ids: Mapping of a log label ('public', 'private') to a
            route table ID, or None when the table was not found.
    """
    print("Deleting Route Table Associations and Routes...")
    present = {label: rt_id for label, rt_id in route_table_ids.items() if rt_id}
    for label, rt_id in present.items():
        _disassociate_route_table(rt_id, label)
        _try_ec2(
            ec2_client.delete_route,
            f"Deleted route from {label} RT {rt_id}",
            f"Error deleting route from {label} RT {rt_id}",
            ignore_code='InvalidRoute.NotFound',
            RouteTableId=rt_id,
            DestinationCidrBlock='0.0.0.0/0',
        )
    for rt_id in present.values():
        _try_ec2(
            ec2_client.delete_route_table,
            f"Deleted RT: {rt_id}",
            f"Error deleting RT {rt_id}",
            ignore_code='InvalidRouteTableID.NotFound',
            RouteTableId=rt_id,
        )


def _delete_nat_gateway(nat_gw_id, eip_alloc_id):
    """Delete the NAT gateway, wait until it is gone, then release its Elastic IP."""
    print("Deleting NAT Gateway...")
    if nat_gw_id and _try_ec2(
        ec2_client.delete_nat_gateway,
        f"Deleted NAT GW: {nat_gw_id}",
        f"Error deleting NAT GW {nat_gw_id}",
        ignore_code='NatGatewayNotFound',
        NatGatewayId=nat_gw_id,
    ):
        # The EIP cannot be released while the gateway still holds it.
        _wait_for_nat_gateway_deletion(nat_gw_id)
    if eip_alloc_id:
        _try_ec2(
            ec2_client.release_address,
            f"Released EIP: {eip_alloc_id}",
            f"Error releasing EIP {eip_alloc_id}",
            ignore_code='InvalidAllocationID.NotFound',
            AllocationId=eip_alloc_id,
        )


def _delete_internet_gateway(igw_id, vpc_id):
    """Detach internet gateway ``igw_id`` from ``vpc_id`` and delete it."""
    print("Detaching and deleting Internet Gateway...")
    if not igw_id:
        return
    _try_ec2(
        ec2_client.detach_internet_gateway,
        f"Detached IGW: {igw_id}",
        f"Error detaching IGW {igw_id}",
        ignore_code='Gateway.NotAttached',
        dependency_retries=_DEPENDENCY_RETRIES,
        InternetGatewayId=igw_id,
        VpcId=vpc_id,
    )
    _try_ec2(
        ec2_client.delete_internet_gateway,
        f"Deleted IGW: {igw_id}",
        f"Error deleting IGW {igw_id}",
        ignore_code='InvalidInternetGatewayID.NotFound',
        InternetGatewayId=igw_id,
    )


def _delete_subnets(subnet_ids):
    """Delete every non-None subnet in ``subnet_ids``."""
    print("Deleting Subnets...")
    for subnet_id in filter(None, subnet_ids):
        _try_ec2(
            ec2_client.delete_subnet,
            f"Deleted Subnet: {subnet_id}",
            f"Error deleting Subnet {subnet_id}",
            ignore_code='InvalidSubnetID.NotFound',
            dependency_retries=_DEPENDENCY_RETRIES,
            SubnetId=subnet_id,
        )


def _delete_vpc(vpc_id):
    """Delete the VPC itself (last step, once its dependents are gone)."""
    print("Deleting VPC...")
    _try_ec2(
        ec2_client.delete_vpc,
        f"Deleted VPC: {vpc_id}",
        f"Error deleting VPC {vpc_id}",
        ignore_code='InvalidVpcID.NotFound',
        dependency_retries=_DEPENDENCY_RETRIES,
        VpcId=vpc_id,
    )


def main():
    """Tear down, in dependency order, the VPC created by create_multi_tier_vpc.py.

    Resources are located by Name tags derived from VPC_NAME. Missing resources
    and "not found" errors are skipped, so the script can be re-run after a
    partial cleanup.
    """
    print("--- Retrieving IDs of resources to delete ---")
    # get_resource_id_by_tag(resource_type, tag_value, vpc_id) always matches
    # the Name tag, so only the tag *value* is passed here.
    vpc_id = get_resource_id_by_tag('vpcs', VPC_NAME)
    if not vpc_id:
        print(f"VPC '{VPC_NAME}' not found. Exiting cleanup.")
        return

    print(f"Found VPC ID: {vpc_id}")

    subnet_ids = [
        get_resource_id_by_tag('subnets', f"{VPC_NAME}-{suffix}", vpc_id)
        for suffix in ("Public-Subnet", "Private-App-Subnet", "Private-DB-Subnet")
    ]
    # Internet gateways are not VPC-scoped resources; search by tag alone.
    igw_id = get_resource_id_by_tag('internet-gateways', f"{VPC_NAME}-IGW")
    nat_gw_id = get_resource_id_by_tag('nat-gateways', f"{VPC_NAME}-NAT-GW", vpc_id)
    eip_alloc_id = get_nat_gateway_eip_allocation_id(nat_gw_id) if nat_gw_id else None
    route_table_ids = {
        'public': get_resource_id_by_tag('route-tables', f"{VPC_NAME}-Public-RT", vpc_id),
        'private': get_resource_id_by_tag('route-tables', f"{VPC_NAME}-Private-RT", vpc_id),
    }
    sg_ids = [
        get_resource_id_by_tag('security-groups', f"{VPC_NAME}-{tier}-SG", vpc_id)
        for tier in ("Web", "App", "DB")
    ]

    print("--- Starting cleanup of VPC resources ---")

    try:
        # Short settle pauses between stages; slower propagation is covered
        # by the DependencyViolation retries inside each stage.
        _delete_security_groups(sg_ids)
        time.sleep(5)
        _delete_route_tables(route_table_ids)
        time.sleep(5)
        _delete_nat_gateway(nat_gw_id, eip_alloc_id)
        time.sleep(5)
        _delete_internet_gateway(igw_id, vpc_id)
        time.sleep(5)
        _delete_subnets(subnet_ids)
        time.sleep(5)
        _delete_vpc(vpc_id)

        print("\n--- VPC cleanup complete! ---")

    except ClientError as e:
        print(f"An AWS client error occurred during cleanup: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during cleanup: {e}")


if __name__ == "__main__":
    main()
