import time

import boto3
from botocore.exceptions import ClientError, WaiterError

# Demo script: creates a Site-to-Site VPN connection with Boto3, waits for it to
# become available, then deletes every resource it created.

# --- Configuration ---
# Region in which the shared EC2 client (and therefore every resource) is created.
REGION: str = "us-east-1"

# Public IP of the on-premises VPN device. This is a placeholder value
# (an RFC 5737 documentation address) -- replace it with your real public IP.
CUSTOMER_GATEWAY_IP: str = "198.51.100.1"
# BGP ASN of the on-premises device, passed to CreateCustomerGateway.
BGP_ASN: int = 65000

# Values for the "Name" tag applied to each resource this script creates.
CGW_NAME: str = "MyBoto3CustomerGateway"
VGW_NAME: str = "MyBoto3VirtualPrivateGateway"
VPN_CONN_NAME: str = "MyBoto3VPNConnection"

# Single EC2 client shared by all functions below.
ec2_client = boto3.client('ec2', region_name=REGION)

def get_default_vpc_id():
    """Return the ID of the default VPC in the client's region.

    Returns:
        The ``VpcId`` of the first VPC matching the ``is-default=true`` filter.

    Raises:
        RuntimeError: if the region has no default VPC.
        ClientError: if the DescribeVpcs call fails (the error is printed first).
    """
    print("--- Getting Default VPC ID ---")
    try:
        response = ec2_client.describe_vpcs(Filters=[{'Name': 'is-default', 'Values': ['true']}])
    except ClientError as e:
        print(f"Error getting VPC info: {e}")
        raise
    vpcs = response['Vpcs']
    if not vpcs:
        # RuntimeError subclasses Exception, so existing broad handlers still catch it.
        raise RuntimeError("No default VPC found.")
    vpc_id = vpcs[0]['VpcId']
    print(f"Default VPC ID: {vpc_id}")
    return vpc_id

def create_customer_gateway():
    """Create the on-premises Customer Gateway described by the configuration.

    The gateway uses ``CUSTOMER_GATEWAY_IP`` and ``BGP_ASN`` with type
    ``ipsec.1`` and is tagged ``Name=CGW_NAME``.

    Returns:
        The ``CustomerGatewayId`` returned by CreateCustomerGateway.

    Raises:
        ClientError: if CreateCustomerGateway fails (the error is printed first).

    NOTE(review): the EC2 API reference states that an identical request
    (same type, IP and ASN) returns an already-existing customer gateway
    rather than creating a new one. In that case this script's cleanup step
    would delete the pre-existing gateway -- confirm none exists beforehand.
    """
    print(f"\n--- Creating Customer Gateway: {CGW_NAME} ---")
    try:
        cgw_response = ec2_client.create_customer_gateway(
            BgpAsn=BGP_ASN,
            PublicIp=CUSTOMER_GATEWAY_IP,
            Type='ipsec.1',
            # A list of per-resource specs; each spec's 'Tags' is a list of Key/Value dicts.
            TagSpecifications=[
                {'ResourceType': 'customer-gateway', 'Tags': [{'Key': 'Name', 'Value': CGW_NAME}]},
            ],
        )
    except ClientError as e:
        print(f"Error creating Customer Gateway: {e}")
        raise
    cgw_id = cgw_response['CustomerGateway']['CustomerGatewayId']
    print(f"Customer Gateway created with ID: {cgw_id}")
    return cgw_id

def _wait_for_vgw_attached(vgw_id, vpc_id, timeout, poll_interval):
    """Poll DescribeVpnGateways until ``vgw_id`` reports an 'attached' attachment to ``vpc_id``.

    Raises:
        TimeoutError: if the attachment is not 'attached' within ``timeout`` seconds.
        ClientError: if DescribeVpnGateways fails.
    """
    deadline = time.monotonic() + timeout
    while True:
        gateways = ec2_client.describe_vpn_gateways(VpnGatewayIds=[vgw_id])['VpnGateways']
        attachments = gateways[0].get('VpcAttachments', []) if gateways else []
        if any(a.get('VpcId') == vpc_id and a.get('State') == 'attached' for a in attachments):
            return
        if time.monotonic() >= deadline:
            raise TimeoutError(f"VGW '{vgw_id}' did not attach to VPC '{vpc_id}' within {timeout}s")
        time.sleep(poll_interval)


def _rollback_vpn_gateway(vgw_id):
    """Best-effort delete of a VGW whose ID will not be returned to the caller."""
    print(f"Rolling back: deleting Virtual Private Gateway '{vgw_id}'...")
    try:
        ec2_client.delete_vpn_gateway(VpnGatewayId=vgw_id)
        print("Virtual Private Gateway deleted.")
    except ClientError as e:
        # E.g. the gateway is still attached/attaching; report the ID so it can be removed manually.
        print(f"Rollback failed; delete VGW '{vgw_id}' manually: {e}")


def create_virtual_private_gateway(vpc_id, attach_timeout=300, poll_interval=5):
    """Create a Virtual Private Gateway and attach it to ``vpc_id``.

    The gateway has type ``ipsec.1``, Amazon-side ASN 64512 and tag
    ``Name=VGW_NAME``. After the attach request, the attachment state is
    polled until it is ``attached`` instead of sleeping a fixed interval.

    Args:
        vpc_id: ID of the VPC to attach the gateway to.
        attach_timeout: Maximum seconds to wait for the attachment.
        poll_interval: Seconds between attachment-state polls.

    Returns:
        The ``VpnGatewayId`` of the new, attached gateway.

    Raises:
        ClientError: if creating, attaching or describing the gateway fails.
        TimeoutError: if the attachment is not ``attached`` within ``attach_timeout``.
        If either error occurs after the gateway exists, a best-effort delete is
        attempted first, because the caller never receives the gateway's ID.
    """
    print(f"\n--- Creating Virtual Private Gateway: {VGW_NAME} ---")
    try:
        vgw_response = ec2_client.create_vpn_gateway(
            Type='ipsec.1',
            AmazonSideAsn=64512,  # Example private ASN
            TagSpecifications=[
                {'ResourceType': 'vpn-gateway', 'Tags': [{'Key': 'Name', 'Value': VGW_NAME}]},
            ],
        )
    except ClientError as e:
        print(f"Error creating/attaching VGW: {e}")
        raise
    vgw_id = vgw_response['VpnGateway']['VpnGatewayId']
    print(f"Virtual Private Gateway created with ID: {vgw_id}. Attaching to VPC '{vpc_id}'...")
    try:
        ec2_client.attach_vpn_gateway(VpcId=vpc_id, VpnGatewayId=vgw_id)
        print("Waiting for VGW to attach...")
        _wait_for_vgw_attached(vgw_id, vpc_id, attach_timeout, poll_interval)
    except (ClientError, TimeoutError) as e:
        print(f"Error creating/attaching VGW: {e}")
        _rollback_vpn_gateway(vgw_id)
        raise
    return vgw_id

def _rollback_vpn_connection(vpn_conn_id):
    """Best-effort delete of a VPN connection whose ID will not be returned to the caller.

    Waits for the deletion so that the later Customer Gateway cleanup is not
    blocked by a connection that still exists.
    """
    print(f"Rolling back: deleting VPN Connection '{vpn_conn_id}'...")
    try:
        ec2_client.delete_vpn_connection(VpnConnectionId=vpn_conn_id)
        ec2_client.get_waiter('vpn_connection_deleted').wait(VpnConnectionIds=[vpn_conn_id])
        print("VPN Connection deleted.")
    except (ClientError, WaiterError) as e:
        print(f"Rollback failed; delete VPN Connection '{vpn_conn_id}' manually: {e}")


def create_vpn_connection(cgw_id, vgw_id):
    """Create a static-routed Site-to-Site VPN connection and wait until it is available.

    Args:
        cgw_id: ID of the Customer Gateway (on-premises side).
        vgw_id: ID of the Virtual Private Gateway (AWS side).

    Returns:
        The ``VpnConnectionId`` of the available connection.

    Raises:
        ClientError: if CreateVpnConnection fails.
        WaiterError: if the connection never reaches the 'available' state.
        If the waiter fails, the connection is deleted (best effort) before
        re-raising, because the caller never receives its ID.
    """
    print(f"\n--- Creating VPN Connection: {VPN_CONN_NAME} ---")
    try:
        vpn_conn_response = ec2_client.create_vpn_connection(
            Type='ipsec.1',
            CustomerGatewayId=cgw_id,
            VpnGatewayId=vgw_id,
            Options={'StaticRoutesOnly': True},
            TagSpecifications=[
                {'ResourceType': 'vpn-connection', 'Tags': [{'Key': 'Name', 'Value': VPN_CONN_NAME}]},
            ],
        )
    except ClientError as e:
        print(f"Error creating VPN Connection: {e}")
        raise
    vpn_conn_id = vpn_conn_response['VpnConnection']['VpnConnectionId']
    print(f"VPN Connection created with ID: {vpn_conn_id}. Waiting for it to be available...")
    try:
        ec2_client.get_waiter('vpn_connection_available').wait(VpnConnectionIds=[vpn_conn_id])
    except (ClientError, WaiterError) as e:
        # WaiterError is not a ClientError, so it must be listed explicitly.
        print(f"Error creating VPN Connection: {e}")
        _rollback_vpn_connection(vpn_conn_id)
        raise
    print("VPN Connection is available.")
    return vpn_conn_id

def _wait_for_vgw_detached(vgw_id, vpc_id, timeout=300, poll_interval=5):
    """Poll until ``vgw_id`` has no attachment to ``vpc_id`` other than 'detached'.

    An attachment that is no longer listed counts as detached.

    Raises:
        TimeoutError: if still attached/detaching after ``timeout`` seconds.
        ClientError: if DescribeVpnGateways fails.
    """
    deadline = time.monotonic() + timeout
    while True:
        gateways = ec2_client.describe_vpn_gateways(VpnGatewayIds=[vgw_id])['VpnGateways']
        attachments = gateways[0].get('VpcAttachments', []) if gateways else []
        if all(a.get('State') == 'detached' for a in attachments if a.get('VpcId') == vpc_id):
            return
        if time.monotonic() >= deadline:
            raise TimeoutError(f"VGW '{vgw_id}' did not detach from VPC '{vpc_id}' within {timeout}s")
        time.sleep(poll_interval)


def _cleanup_vpn_connection(vpn_conn_id):
    """Delete the VPN connection and wait for it; failures are printed, never raised."""
    print(f"Deleting VPN Connection '{vpn_conn_id}'...")
    try:
        ec2_client.delete_vpn_connection(VpnConnectionId=vpn_conn_id)
        ec2_client.get_waiter('vpn_connection_deleted').wait(VpnConnectionIds=[vpn_conn_id])
        print("VPN Connection deleted.")
    except ClientError as e:
        if e.response['Error']['Code'] == 'InvalidVpnConnectionID.NotFound':
            print(f"VPN Connection '{vpn_conn_id}' not found, skipping deletion.")
        else:
            print(f"Error deleting VPN Connection: {e}")
    except WaiterError as e:
        # Not a ClientError; without this the VGW/CGW cleanup below would be skipped.
        print(f"Error deleting VPN Connection: {e}")


def _cleanup_vpn_gateway(vpc_id, vgw_id):
    """Detach the VGW, wait for the detachment, then delete it; failures are printed, never raised."""
    print(f"Detaching Virtual Private Gateway '{vgw_id}' from VPC '{vpc_id}'...")
    try:
        ec2_client.detach_vpn_gateway(VpcId=vpc_id, VpnGatewayId=vgw_id)
    except ClientError as e:
        code = e.response['Error']['Code']
        if code == 'InvalidVpnGatewayID.NotFound':
            print(f"VGW '{vgw_id}' not found, skipping deletion.")
            return
        if code not in ('VpnGateway.NotAttached', 'InvalidVpnGatewayAttachment.NotFound'):
            print(f"Error deleting VGW: {e}")
            return
        print(f"VGW '{vgw_id}' not attached, skipping detach.")
    else:
        print("Waiting for VGW to detach...")
        try:
            # DeleteVpnGateway fails while the attachment is still 'detaching'.
            _wait_for_vgw_detached(vgw_id, vpc_id)
        except (ClientError, TimeoutError) as e:
            print(f"Error deleting VGW: {e}")
            return

    print(f"Deleting Virtual Private Gateway '{vgw_id}'...")
    try:
        ec2_client.delete_vpn_gateway(VpnGatewayId=vgw_id)
        print("Virtual Private Gateway deleted.")
    except ClientError as e:
        if e.response['Error']['Code'] == 'InvalidVpnGatewayID.NotFound':
            print(f"VGW '{vgw_id}' not found, skipping deletion.")
        else:
            print(f"Error deleting VGW: {e}")


def _cleanup_customer_gateway(cgw_id):
    """Delete the Customer Gateway; failures are printed, never raised."""
    print(f"Deleting Customer Gateway '{cgw_id}'...")
    try:
        ec2_client.delete_customer_gateway(CustomerGatewayId=cgw_id)
        print("Customer Gateway deleted.")
    except ClientError as e:
        if e.response['Error']['Code'] == 'InvalidCustomerGatewayID.NotFound':
            print(f"CGW '{cgw_id}' not found, skipping deletion.")
        else:
            print(f"Error deleting Customer Gateway: {e}")


def cleanup_resources(vpc_id, cgw_id, vgw_id, vpn_conn_id):
    """Delete every resource this script created, skipping any whose ID is falsy.

    Resources are removed in dependency order: VPN connection first (it
    references both gateways), then the VGW (detached from ``vpc_id`` first),
    then the Customer Gateway. Each step reports its own errors and never
    raises, so a failure in one step does not stop the later ones.

    Args:
        vpc_id: VPC the VGW is attached to.
        cgw_id: Customer Gateway ID, or None if it was never created.
        vgw_id: Virtual Private Gateway ID, or None if it was never created.
        vpn_conn_id: VPN connection ID, or None if it was never created.
    """
    print("\n--- Cleaning up resources ---")
    if vpn_conn_id:
        _cleanup_vpn_connection(vpn_conn_id)
    if vgw_id:
        _cleanup_vpn_gateway(vpc_id, vgw_id)
    if cgw_id:
        _cleanup_customer_gateway(cgw_id)

def main():
    """Build the demo VPN stack, pause for the user, then tear it all down.

    Cleanup runs in ``finally``, so whatever was created is deleted even when
    a step fails or the prompt is interrupted (e.g. Ctrl+C).

    Returns:
        0 if setup completed and the prompt returned normally, 1 if an error
        was caught and reported. It is returned so the script entry point can
        use it as the process exit status.
    """
    vpc_id = cgw_id = vgw_id = vpn_conn_id = None
    exit_code = 1
    try:
        vpc_id = get_default_vpc_id()
        cgw_id = create_customer_gateway()
        vgw_id = create_virtual_private_gateway(vpc_id)
        vpn_conn_id = create_vpn_connection(cgw_id, vgw_id)

        print("\n--- VPN Connection Setup Complete! ---")
        print(f"VPN Connection ID: {vpn_conn_id}")
        print("You can download the configuration file for your on-premises device using AWS CLI or Console.")

        input("Press Enter to delete the VPN connection and clean up resources...")
        exit_code = 0

    except ClientError as e:
        print(f"An AWS client error occurred: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        cleanup_resources(vpc_id, cgw_id, vgw_id, vpn_conn_id)
        print("\n--- Site-to-Site VPN demonstration and cleanup complete ---")
    return exit_code

if __name__ == "__main__":
    # Use main()'s return value as the exit status (None -> 0) so shells/CI can see failures.
    raise SystemExit(main())
