import os
import time

import boto3
from botocore.exceptions import ClientError, WaiterError

# Demo script: uses Boto3 to create an Amazon Redshift cluster together with its
# prerequisites (a cluster subnet group and a VPC security group), then deletes
# all of these resources again.

# --- Configuration ---
REGION = "us-east-1"
CLUSTER_IDENTIFIER = "my-boto3-redshift-cluster"
MASTER_USERNAME = "admin"
# Supply the real password through the REDSHIFT_MASTER_PASSWORD environment
# variable so it never has to be committed to source control. The literal below
# is only a fallback for throwaway demos.
# !!! IMPORTANT: Use a strong password in production !!!
MASTER_USER_PASSWORD = os.environ.get("REDSHIFT_MASTER_PASSWORD", "MySecurePassword123!")
NODE_TYPE = "dc2.large"
NUMBER_OF_NODES = 1  # Single node for demo
DB_NAME = "dev"
CLUSTER_SUBNET_GROUP_NAME = "my-boto3-redshift-subnet-group"
SECURITY_GROUP_NAME = "my-boto3-redshift-sg"

# Clients are created at import time and shared by every function below.
ec2_client = boto3.client('ec2', region_name=REGION)
redshift_client = boto3.client('redshift', region_name=REGION)

def get_default_vpc_and_subnets():
    """Return the default VPC ID and the IDs of two of its default subnets.

    Only subnets flagged ``default-for-az`` in the default VPC are considered.
    They are sorted by (Availability Zone, subnet ID) before the first two are
    taken, so repeated runs pick the same pair regardless of API response order.

    Returns:
        tuple[str, list[str]]: ``(vpc_id, [subnet_id_1, subnet_id_2])``.

    Raises:
        RuntimeError: if there is no default VPC, or it has fewer than two
            default subnets.
        botocore.exceptions.ClientError: if an EC2 call fails (the error is
            printed, then re-raised).
    """
    print("--- Getting Default VPC and Subnet IDs ---")
    try:
        vpcs = ec2_client.describe_vpcs(
            Filters=[{'Name': 'is-default', 'Values': ['true']}]
        )['Vpcs']
        if not vpcs:
            raise RuntimeError("No default VPC found.")
        vpc_id = vpcs[0]['VpcId']
        print(f"Default VPC ID: {vpc_id}")

        subnets = ec2_client.describe_subnets(
            Filters=[
                {'Name': 'vpc-id', 'Values': [vpc_id]},
                {'Name': 'default-for-az', 'Values': ['true']},
            ]
        )['Subnets']
        if len(subnets) < 2:
            raise RuntimeError("Not enough default subnets found in the default VPC.")
        # Sort so the selection does not depend on the order the API returns.
        subnets.sort(key=lambda s: (s['AvailabilityZone'], s['SubnetId']))
        subnet_ids = [s['SubnetId'] for s in subnets[:2]]
        print(f"Default Subnet IDs: {subnet_ids}")
        return vpc_id, subnet_ids
    except ClientError as e:
        print(f"Error getting VPC/Subnet info: {e}")
        raise

def create_cluster_subnet_group(subnet_ids):
    """Create the Redshift cluster subnet group named CLUSTER_SUBNET_GROUP_NAME.

    A group that already exists under that name is treated as success and left
    untouched. Any other AWS error is printed and re-raised.

    Args:
        subnet_ids: IDs of the subnets to place in the group.
    """
    print(f"\n--- Creating Cluster Subnet Group: {CLUSTER_SUBNET_GROUP_NAME} ---")
    request = {
        "ClusterSubnetGroupName": CLUSTER_SUBNET_GROUP_NAME,
        "Description": "Subnet group for Boto3 Redshift cluster",
        "SubnetIds": subnet_ids,
    }
    try:
        redshift_client.create_cluster_subnet_group(**request)
    except ClientError as err:
        if err.response['Error']['Code'] != 'ClusterSubnetGroupAlreadyExists':
            print(f"Error creating Cluster Subnet Group: {err}")
            raise
        print(f"Cluster Subnet Group '{CLUSTER_SUBNET_GROUP_NAME}' already exists. Skipping creation.")
        return
    print("Cluster Subnet Group created.")

def _authorize_redshift_ingress(sg_id, ingress_cidr):
    """Allow inbound TCP 5439 from ``ingress_cidr``; an identical existing rule counts as success."""
    try:
        ec2_client.authorize_security_group_ingress(
            GroupId=sg_id,
            IpPermissions=[
                {'IpProtocol': 'tcp', 'FromPort': 5439, 'ToPort': 5439,
                 'IpRanges': [{'CidrIp': ingress_cidr}]}
            ]
        )
        print("Inbound Redshift rule added to Security Group.")
    except ClientError as e:
        if e.response['Error']['Code'] != 'InvalidPermission.Duplicate':
            raise
        print("Inbound Redshift rule already present on Security Group.")


def create_security_group(vpc_id, ingress_cidr="0.0.0.0/0"):
    """Create (or reuse) the Redshift security group in ``vpc_id`` and return its ID.

    The group gets an inbound rule for TCP port 5439 (Redshift) from
    ``ingress_cidr``. The default ``0.0.0.0/0`` admits any address and is only
    appropriate for a throwaway demo. Pass a narrower CIDR (e.g. your own
    ``a.b.c.d/32``) otherwise.

    If a group named SECURITY_GROUP_NAME already exists in ``vpc_id``, it is
    reused and the ingress rule is (re)applied. This repairs a group left
    half-configured by an earlier run.

    Args:
        vpc_id: VPC in which to create or find the group.
        ingress_cidr: Source CIDR allowed to reach port 5439.

    Returns:
        str: the security group ID.

    Raises:
        botocore.exceptions.ClientError: if creating/looking up the group or
            adding the ingress rule fails. A group created by this call is
            deleted (best effort) before the error propagates, because the
            caller never receives its ID and could not clean it up.
    """
    print(f"\n--- Creating Security Group: {SECURITY_GROUP_NAME} ---")
    created = False
    try:
        sg_response = ec2_client.create_security_group(
            GroupName=SECURITY_GROUP_NAME,
            Description="Allow Redshift traffic",
            VpcId=vpc_id
        )
        sg_id = sg_response['GroupId']
        created = True
        print(f"Security Group '{SECURITY_GROUP_NAME}' created with ID: {sg_id}")
    except ClientError as e:
        if e.response['Error']['Code'] != 'InvalidGroup.Duplicate':
            print(f"Error creating security group: {e}")
            raise
        print(f"Security Group '{SECURITY_GROUP_NAME}' already exists. Fetching ID.")
        # Scope the lookup to this VPC instead of relying on a name-only lookup.
        response = ec2_client.describe_security_groups(
            Filters=[
                {'Name': 'group-name', 'Values': [SECURITY_GROUP_NAME]},
                {'Name': 'vpc-id', 'Values': [vpc_id]},
            ]
        )
        sg_id = response['SecurityGroups'][0]['GroupId']

    try:
        _authorize_redshift_ingress(sg_id, ingress_cidr)
    except ClientError as e:
        print(f"Error creating security group: {e}")
        if created:
            # Don't leak the group we just created: on failure the caller never
            # gets sg_id, so nothing else can delete it.
            try:
                ec2_client.delete_security_group(GroupId=sg_id)
            except ClientError as cleanup_err:
                print(f"Could not delete Security Group {sg_id}: {cleanup_err}")
        raise e
    return sg_id

def _wait_for_cluster_endpoint():
    """Wait until CLUSTER_IDENTIFIER is available, then return its endpoint address."""
    redshift_client.get_waiter('cluster_available').wait(ClusterIdentifier=CLUSTER_IDENTIFIER)
    response = redshift_client.describe_clusters(ClusterIdentifier=CLUSTER_IDENTIFIER)
    return response['Clusters'][0]['Endpoint']['Address']


def create_redshift_cluster(sg_id):
    """Create the demo Redshift cluster (or reuse an existing one) and return its endpoint.

    The cluster is placed in CLUSTER_SUBNET_GROUP_NAME, attached to ``sg_id``, and
    made publicly accessible for demo purposes. Both a newly created cluster and
    a pre-existing one are waited on until available. A pre-existing cluster may
    still be creating, and then has no ``Endpoint`` yet.

    Args:
        sg_id: ID of the VPC security group to attach to the cluster.

    Returns:
        str: the cluster endpoint address.

    Raises:
        botocore.exceptions.ClientError: if creation fails for any reason other
            than the cluster already existing, or if describing it fails.
        botocore.exceptions.WaiterError: if the cluster does not become
            available within the waiter's limits.
    """
    print(f"\n--- Creating Redshift Cluster: {CLUSTER_IDENTIFIER} ---")
    try:
        redshift_client.create_cluster(
            ClusterIdentifier=CLUSTER_IDENTIFIER,
            NodeType=NODE_TYPE,
            NumberOfNodes=NUMBER_OF_NODES,
            MasterUsername=MASTER_USERNAME,
            MasterUserPassword=MASTER_USER_PASSWORD,
            DBName=DB_NAME,
            ClusterSubnetGroupName=CLUSTER_SUBNET_GROUP_NAME,
            VpcSecurityGroupIds=[sg_id],
            PubliclyAccessible=True,  # For demo purposes
            Tags=[{'Key': 'Name', 'Value': CLUSTER_IDENTIFIER}]
        )
    except ClientError as e:
        if e.response['Error']['Code'] != 'ClusterAlreadyExists':
            print(f"Error creating Redshift cluster: {e}")
            raise
        print(f"Redshift cluster '{CLUSTER_IDENTIFIER}' already exists. Skipping creation.")
        print("Waiting for the existing cluster to be available...")
    else:
        print(f"Redshift cluster '{CLUSTER_IDENTIFIER}' created. Waiting for it to be available (this can take 10-15 minutes)...")

    endpoint = _wait_for_cluster_endpoint()
    print("Redshift cluster is available.")
    return endpoint

def _cleanup_cluster():
    """Delete the demo cluster without a final snapshot and wait until it is gone."""
    print(f"Deleting Redshift cluster '{CLUSTER_IDENTIFIER}'...")
    try:
        redshift_client.delete_cluster(
            ClusterIdentifier=CLUSTER_IDENTIFIER,
            SkipFinalClusterSnapshot=True
        )
        redshift_client.get_waiter('cluster_deleted').wait(ClusterIdentifier=CLUSTER_IDENTIFIER)
        print("Redshift cluster deleted.")
    except ClientError as e:
        if e.response['Error']['Code'] == 'ClusterNotFound':
            print(f"Redshift cluster '{CLUSTER_IDENTIFIER}' not found, skipping deletion.")
        else:
            print(f"Error deleting Redshift cluster: {e}")
    except WaiterError as e:
        # Report and continue: later steps print their own failures if the
        # cluster is still holding the subnet group / security group.
        print(f"Timed out waiting for Redshift cluster deletion: {e}")


def _cleanup_subnet_group():
    """Delete the demo cluster subnet group; report (don't raise) AWS errors."""
    print(f"Deleting Cluster Subnet Group '{CLUSTER_SUBNET_GROUP_NAME}'...")
    try:
        redshift_client.delete_cluster_subnet_group(ClusterSubnetGroupName=CLUSTER_SUBNET_GROUP_NAME)
        print("Cluster Subnet Group deleted.")
    except ClientError as e:
        if e.response['Error']['Code'] == 'ClusterSubnetGroupNotFoundFault':
            print(f"Cluster Subnet Group '{CLUSTER_SUBNET_GROUP_NAME}' not found, skipping deletion.")
        else:
            print(f"Error deleting Cluster Subnet Group: {e}")


def _cleanup_security_group(sg_id, attempts, delay):
    """Delete security group ``sg_id``, retrying while it is still in use; never raises ClientError."""
    print(f"Deleting Security Group '{SECURITY_GROUP_NAME}'...")
    if sg_id is None:
        # Passing GroupId=None to EC2 would fail client-side validation.
        print(f"Security Group '{SECURITY_GROUP_NAME}' ID unknown, skipping deletion.")
        return
    attempts = max(1, attempts)
    for attempt in range(1, attempts + 1):
        try:
            ec2_client.delete_security_group(GroupId=sg_id)
        except ClientError as e:
            code = e.response['Error']['Code']
            if code == 'InvalidGroup.NotFound':
                print(f"Security Group '{SECURITY_GROUP_NAME}' not found, skipping deletion.")
                return
            if code == 'DependencyViolation' and attempt < attempts:
                print(f"Security Group '{SECURITY_GROUP_NAME}' is still in use. Retrying deletion after a short delay.")
                time.sleep(delay)
                continue
            print(f"Error deleting security group: {e}")
            return
        print("Security Group deleted.")
        return


def cleanup_resources(sg_id, *, sg_delete_attempts=6, sg_retry_delay=10):
    """Best-effort deletion of every resource the demo creates.

    Resources are removed in dependency order: cluster, subnet group, then
    security group. Each step reports and swallows its own AWS failure so
    the remaining steps still run. This function is called from main()'s
    ``finally`` block, where an escaping exception would hide the original error.

    Args:
        sg_id: Security group ID to delete, or None if it is unknown (the
            security-group step is then skipped).
        sg_delete_attempts: Total attempts to delete the security group while
            EC2 reports ``DependencyViolation``.
        sg_retry_delay: Seconds to sleep between those attempts.
    """
    print("\n--- Cleaning up resources ---")
    _cleanup_cluster()
    _cleanup_subnet_group()
    _cleanup_security_group(sg_id, sg_delete_attempts, sg_retry_delay)

def main():
    """Run the demo: provision networking and a Redshift cluster, pause, then clean up.

    Cleanup runs from the ``finally`` block, so it happens after success, after a
    caught error, and when the prompt is interrupted. AWS client errors and other
    ``Exception`` subclasses are reported rather than re-raised.
    """
    # Stays None unless create_security_group returns; cleanup skips it then.
    sg_id = None
    try:
        vpc_id, subnet_ids = get_default_vpc_and_subnets()
        create_cluster_subnet_group(subnet_ids)
        sg_id = create_security_group(vpc_id)
        cluster_endpoint = create_redshift_cluster(sg_id)

        print("\n--- Redshift Cluster Created Successfully! ---")
        print(f"Cluster Identifier: {CLUSTER_IDENTIFIER}")
        print(f"Endpoint: {cluster_endpoint}")
        print(f"Database Name: {DB_NAME}")
        print(f"Username: {MASTER_USERNAME}")
        # Never echo the secret itself to the terminal or captured logs.
        print("Password: (the configured MASTER_USER_PASSWORD; not shown)")

        input("Press Enter to delete the Redshift cluster and clean up resources...")

    except ClientError as e:
        print(f"An AWS client error occurred: {e}")
    except Exception as e:
        # Include the type: some exceptions (e.g. EOFError from input()) have an empty str().
        print(f"An unexpected error occurred: {type(e).__name__}: {e}")
    finally:
        cleanup_resources(sg_id)
        print("\n--- Redshift cluster demonstration and cleanup complete ---")

if __name__ == "__main__":
    # Run the demo only when executed directly, not when imported as a module.
    main()
