import boto3
from botocore.exceptions import ClientError
import time

# Demo script: creates an AWS CloudHSM cluster (and a security group) with Boto3,
# then deletes both again once the user presses Enter.

# --- Configuration ---
REGION = "us-east-1"  # Region both clients below are created for
CLUSTER_NAME = "MyBoto3HSMCluster"  # Used as the value of the cluster's 'Name' tag
HSM_TYPE = "hsm1.medium" # Passed as HsmType to create_cluster; must be a type CloudHSM accepts
SG_NAME = "MyBoto3HSMSG"  # GroupName of the security group this script creates

# Module-level clients shared by every function in this script. They are
# constructed when the module is imported, not when main() runs.
ec2_client = boto3.client('ec2', region_name=REGION)
cloudhsmv2_client = boto3.client('cloudhsmv2', region_name=REGION)

def get_default_vpc_and_subnet():
    """Look up the default VPC and one of its default-for-AZ subnets.

    Returns:
        A ``(vpc_id, subnet_id)`` tuple. The first VPC and the first subnet
        returned by EC2 are used.

    Raises:
        Exception: If there is no default VPC, or it has no default subnet.
        ClientError: If an EC2 call fails (the error is printed, then re-raised).
    """
    print("--- Getting Default VPC and Subnet ID ---")
    try:
        vpc_filters = [{'Name': 'is-default', 'Values': ['true']}]
        vpc_listing = ec2_client.describe_vpcs(Filters=vpc_filters)
        default_vpcs = vpc_listing['Vpcs']
        if not default_vpcs:
            raise Exception("No default VPC found.")
        default_vpc_id = default_vpcs[0]['VpcId']
        print(f"Default VPC ID: {default_vpc_id}")

        # Restrict the search to subnets that are the default for their AZ
        # inside the VPC found above.
        subnet_filters = [
            {'Name': 'vpc-id', 'Values': [default_vpc_id]},
            {'Name': 'default-for-az', 'Values': ['true']},
        ]
        subnet_listing = ec2_client.describe_subnets(Filters=subnet_filters)
        default_subnets = subnet_listing['Subnets']
        if not default_subnets:
            raise Exception("No default subnet found in the default VPC.")
        default_subnet_id = default_subnets[0]['SubnetId']
        print(f"Default Subnet ID: {default_subnet_id}")
    except ClientError as aws_error:
        print(f"Error getting VPC/Subnet info: {aws_error}")
        raise
    return default_vpc_id, default_subnet_id

def _find_security_group_id(vpc_id):
    """Return the ID of the security group named SG_NAME inside *vpc_id*."""
    # Filter on name *and* VPC: the GroupNames parameter only resolves names in
    # the default VPC, whereas this function is handed an explicit VPC ID.
    response = ec2_client.describe_security_groups(
        Filters=[
            {'Name': 'group-name', 'Values': [SG_NAME]},
            {'Name': 'vpc-id', 'Values': [vpc_id]},
        ]
    )
    groups = response['SecurityGroups']
    if not groups:
        raise RuntimeError(
            f"Security Group '{SG_NAME}' was reported as a duplicate but was not found in VPC {vpc_id}."
        )
    return groups[0]['GroupId']


def _allow_self_ingress(sg_id):
    """Allow all traffic between members of the group; an existing rule is accepted."""
    try:
        ec2_client.authorize_security_group_ingress(
            GroupId=sg_id,
            IpPermissions=[
                {'IpProtocol': '-1', 'FromPort': -1, 'ToPort': -1, 'UserIdGroupPairs': [{'GroupId': sg_id}]}
            ]
        )
    except ClientError as e:
        if e.response['Error']['Code'] != 'InvalidPermission.Duplicate':
            raise
        print("Inbound rule for self-communication already present on Security Group.")
        return
    print("Inbound rule for self-communication added to Security Group.")


def create_security_group(vpc_id):
    """Create (or reuse) the security group SG_NAME and return its ID.

    The group is given an ingress rule that allows all traffic originating
    from the group itself. If a group with the same name already exists in
    *vpc_id*, it is reused and the rule is added when missing.

    Args:
        vpc_id: ID of the VPC in which the group is created or looked up.

    Returns:
        The security group ID.

    Raises:
        ClientError: If an EC2 call fails (the error is printed, then re-raised).
        RuntimeError: If EC2 reports a duplicate name but the group cannot be found.
    """
    print(f"\n--- Creating Security Group: {SG_NAME} ---")
    sg_id = None
    created_here = False
    try:
        try:
            sg_response = ec2_client.create_security_group(
                GroupName=SG_NAME,
                Description="Security Group for CloudHSM cluster",
                VpcId=vpc_id
            )
            sg_id = sg_response['GroupId']
            created_here = True
            print(f"Security Group '{SG_NAME}' created with ID: {sg_id}")
        except ClientError as e:
            if e.response['Error']['Code'] != 'InvalidGroup.Duplicate':
                raise
            print(f"Security Group '{SG_NAME}' already exists. Fetching ID.")
            sg_id = _find_security_group_id(vpc_id)

        # Authorize inbound traffic from itself (for cluster communication)
        _allow_self_ingress(sg_id)
        return sg_id
    except ClientError as e:
        print(f"Error creating security group: {e}")
        if created_here:
            # The caller never learns the ID of a half-configured group, so
            # remove it here instead of leaking it.
            try:
                ec2_client.delete_security_group(GroupId=sg_id)
            except ClientError as cleanup_error:
                print(f"Could not remove Security Group '{sg_id}': {cleanup_error}")
        raise

def create_cloudhsm_cluster(subnet_id, sg_id, poll_interval=30, timeout=1800):
    """Create a CloudHSM cluster and wait until it reaches 'UNINITIALIZED'.

    Args:
        subnet_id: Subnet the cluster is created in.
        sg_id: Accepted for interface compatibility; not passed to
            create_cluster.
        poll_interval: Seconds to sleep between state checks.
        timeout: Maximum number of seconds to wait for 'UNINITIALIZED'.

    Returns:
        The ID of the new cluster.

    Raises:
        ClientError: If a CloudHSM call fails (the error is printed, then re-raised).
        RuntimeError: If the cluster vanishes or enters a state treated here
            as terminal while waiting.
        TimeoutError: If 'UNINITIALIZED' is not reached within *timeout* seconds.
    """
    print(f"\n--- Creating CloudHSM Cluster: {CLUSTER_NAME} ---")
    # States treated as "will not become UNINITIALIZED"; waiting on them would
    # otherwise spin forever.
    terminal_states = ('DELETE_IN_PROGRESS', 'DELETED', 'DEGRADED')
    cluster_id = None
    try:
        response = cloudhsmv2_client.create_cluster(
            HsmType=HSM_TYPE,
            SubnetIds=[subnet_id],
            TagList=[{'Key': 'Name', 'Value': CLUSTER_NAME}]
        )
        cluster_id = response['Cluster']['ClusterId']
        print(f"CloudHSM Cluster created with ID: {cluster_id}. Waiting for it to be in 'UNINITIALIZED' state...")

        # CloudHSM doesn't have a 'cluster_uninitialized' waiter, so we poll
        deadline = time.monotonic() + timeout
        while True:
            status_response = cloudhsmv2_client.describe_clusters(Filters={'clusterIds': [cluster_id]})
            clusters = status_response['Clusters']
            if not clusters:
                raise RuntimeError(
                    f"CloudHSM cluster {cluster_id} is no longer listed while waiting for 'UNINITIALIZED'."
                )
            state = clusters[0]['State']
            if state == 'UNINITIALIZED':
                print("CloudHSM Cluster is in 'UNINITIALIZED' state.")
                return cluster_id
            if state in terminal_states:
                detail = clusters[0].get('StateMessage', '')
                raise RuntimeError(
                    f"CloudHSM cluster {cluster_id} entered state '{state}' instead of 'UNINITIALIZED'. {detail}".rstrip()
                )
            if time.monotonic() >= deadline:
                # The ID is part of the message because the caller never
                # receives it on this path and may need it for manual cleanup.
                raise TimeoutError(
                    f"CloudHSM cluster {cluster_id} did not reach 'UNINITIALIZED' within {timeout} seconds "
                    f"(last state: {state})."
                )
            print(f"Cluster status: {state}, waiting...")
            time.sleep(poll_interval)
    except ClientError as e:
        print(f"Error creating CloudHSM cluster: {e}")
        if cluster_id:
            print(f"Cluster '{cluster_id}' was created before the error and may need manual deletion.")
        raise

def _delete_cluster_and_wait(cluster_id, poll_interval, timeout):
    """Delete the cluster and poll until it is gone; errors are printed, not raised."""
    # 'CloudHsmClusterNotFoundException' is kept only because the previous
    # implementation checked for it.
    not_found_codes = ('CloudHsmResourceNotFoundException', 'CloudHsmClusterNotFoundException')
    print(f"Deleting CloudHSM Cluster '{cluster_id}'...")
    try:
        cloudhsmv2_client.delete_cluster(ClusterId=cluster_id)
    except ClientError as e:
        if e.response['Error']['Code'] in not_found_codes:
            print(f"CloudHSM Cluster '{cluster_id}' not found, skipping deletion.")
        else:
            print(f"Error deleting CloudHSM cluster: {e}")
        return

    # CloudHSM doesn't have a 'cluster_deleted' waiter, so we poll
    deadline = time.monotonic() + timeout
    try:
        while True:
            response = cloudhsmv2_client.describe_clusters(Filters={'clusterIds': [cluster_id]})
            clusters = response['Clusters']
            # A finished deletion shows up as an empty list or as state
            # 'DELETED', so the response must be inspected; waiting for an
            # error alone would never terminate.
            if not clusters or clusters[0]['State'] == 'DELETED':
                print("CloudHSM Cluster deleted.")
                return
            if time.monotonic() >= deadline:
                print(f"Timed out after {timeout} seconds waiting for CloudHSM Cluster '{cluster_id}' to be deleted.")
                return
            print("Cluster still exists, waiting...")
            time.sleep(poll_interval)
    except ClientError as e:
        if e.response['Error']['Code'] in not_found_codes:
            print("CloudHSM Cluster deleted.")
        else:
            print(f"Error deleting CloudHSM cluster: {e}")


def _delete_security_group_with_retry(sg_id, attempts=5, delay=10):
    """Delete the security group, retrying while EC2 reports it as still in use."""
    print(f"Deleting Security Group '{SG_NAME}'...")
    for attempt in range(1, attempts + 1):
        try:
            ec2_client.delete_security_group(GroupId=sg_id)
        except ClientError as e:
            code = e.response['Error']['Code']
            if code == 'InvalidGroup.NotFound':
                print(f"Security Group '{SG_NAME}' not found, skipping deletion.")
                return
            if code == 'DependencyViolation' and attempt < attempts:
                print(f"Security Group '{sg_id}' is still in use. Retrying deletion after a short delay.")
                time.sleep(delay)
                continue
            print(f"Error deleting security group: {e}")
            return
        print("Security Group deleted.")
        return


def cleanup_resources(cluster_id, sg_id, poll_interval=30, timeout=1800):
    """Delete the CloudHSM cluster and the security group, best effort.

    AWS client errors are printed rather than raised, so that one failed
    deletion does not prevent the next one from being attempted.

    Args:
        cluster_id: ID of the cluster to delete; skipped when falsy.
        sg_id: ID of the security group to delete; skipped when falsy.
        poll_interval: Seconds to sleep between cluster-deletion checks.
        timeout: Maximum number of seconds to wait for the cluster to disappear.
    """
    print("\n--- Cleaning up resources ---")

    # Delete CloudHSM Cluster
    if cluster_id:
        _delete_cluster_and_wait(cluster_id, poll_interval, timeout)

    # Delete Security Group
    if sg_id:
        _delete_security_group_with_retry(sg_id)

def main():
    """Run the demo: discover networking, create resources, then clean up.

    Whatever was created is handed to cleanup_resources() in the ``finally``
    clause, whether the setup succeeded or an error was reported.
    """
    default_vpc = default_subnet = None
    security_group_id = hsm_cluster_id = None
    try:
        default_vpc, default_subnet = get_default_vpc_and_subnet()
        security_group_id = create_security_group(default_vpc)
        hsm_cluster_id = create_cloudhsm_cluster(default_subnet, security_group_id)

        summary_lines = (
            "\n--- CloudHSM Cluster Setup Complete! ---",
            f"Cluster ID: {hsm_cluster_id}",
            "Next steps: Initialize the cluster and create HSM users.",
        )
        for summary_line in summary_lines:
            print(summary_line)

        # Hold the resources until the user has had a chance to inspect them.
        input("Press Enter to delete the CloudHSM cluster and clean up resources...")

    except ClientError as aws_error:
        print(f"An AWS client error occurred: {aws_error}")
    except Exception as other_error:
        print(f"An unexpected error occurred: {other_error}")
    finally:
        cleanup_resources(hsm_cluster_id, security_group_id)
        print("\n--- CloudHSM cluster demonstration and cleanup complete ---")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
