
#!/usr/bin/env python3

import boto3
import time
import argparse

# ==========================================================================
# SCRIPT: db_refresh_rds.py
# DESCRIPTION: Refreshes an Amazon RDS DB instance (typically non-production)
#              from the latest automated snapshot of a source DB instance
#              (typically production). This script automates the process of
#              finding the latest snapshot, optionally deleting an old target
#              instance, and then restoring a new instance with the specified
#              configuration. This is a common practice for maintaining up-to-date
#              development or staging environments with production-like data.
#
# USE CASE SCENARIO:
# A development team needs to periodically refresh their development database
# with the latest data from the production database to test new features or fix
# bugs. Manually performing this process is time-consuming and error-prone.
# This script automates that refresh process.
#
# PREREQUISITES:
# 1.  **AWS Credentials:** The AWS CLI or SDK must be configured with credentials
#     that have the necessary permissions. This can be done via `~/.aws/credentials`,
#     environment variables, or an instance profile/IAM role for EC2 instances.
# 2.  **boto3 Library:** The AWS SDK for Python (`boto3`) must be installed.
#     You can install it using `pip install boto3`.
# 3.  **IAM Permissions:** The principal executing this script (user, role, etc.)
#     must have comprehensive permissions for RDS operations, including:
#     - `rds:DescribeDBSnapshots`
#     - `rds:DeleteDBInstance` (if `--delete-old-instance` is used. Use with caution!)
#     - `rds:RestoreDBInstanceFromDBSnapshot`
#     - `rds:AddTagsToResource` (the restore call tags the new instance with a `Name` tag)
#     - `rds:DescribeDBInstances`
# 4.  **Existing Resources:**
#     - A source RDS DB instance (e.g., your production database) with automated backups enabled
#       (backup retention period greater than 0), so that RDS creates automated snapshots.
#     - A pre-existing RDS DB Subnet Group where the new instance will be launched.
#     - One or more pre-existing VPC Security Group(s) for the new instance, allowing
#       necessary inbound/outbound traffic.
#
# HOW TO USE:
# 1.  **Save the script:** Save this content as `db_refresh_rds.py`.
# 2.  **Make it executable (Linux/macOS):** `chmod +x db_refresh_rds.py`
# 3.  **Run from your terminal:**
#     python db_refresh_rds.py \
#       --source-db-identifier "my-prod-db-instance" \
#       --target-db-identifier "my-dev-db-refreshed" \
#       --db-instance-class "db.t3.medium" \
#       --db-subnet-group-name "my-dev-db-subnet-group" \
#       --vpc-security-group-ids "sg-0abcdef1234567890" "sg-0fedcba9876543210" \
#       --region "us-east-1" \
#       --delete-old-instance
#
#     **Arguments:**
#     - `--source-db-identifier`: The identifier of the production/source RDS DB instance.
#     - `--target-db-identifier`: The desired identifier for the new development/target RDS DB instance.
#     - `--db-instance-class`: The DB instance class (compute and memory capacity) for the new RDS DB instance (e.g., `db.t3.medium`).
#     - `--db-subnet-group-name`: The DB Subnet Group where the new instance will be deployed.
#     - `--vpc-security-group-ids`: Space-separated list of Security Group IDs for the new instance.
#     - `--region`: (Optional) The AWS region. Defaults to `us-east-1`.
#     - `--delete-old-instance`: (Optional flag) If present, deletes the existing instance with `--target-db-identifier` before restoring. USE WITH EXTREME CAUTION IN PRODUCTION!
#
# IMPORTANT CONSIDERATIONS:
# - This script is designed for non-production refreshes. Deleting the target instance, and all data in it, is intended behavior.
# - The new instance will inherit the database engine, version, and master credentials from the snapshot.
# - If the source DB is encrypted, the target DB will also be encrypted (using the same KMS key).
# - The script waits for AWS operations to complete, which can take time depending on instance size.
# ==========================================================================

def _find_latest_snapshot(rds_client, source_db_identifier: str):
    """Return the newest restorable automated snapshot of a DB instance.

    Pages through ``describe_db_snapshots`` so snapshots beyond the first page of
    results are not silently ignored. Keeps only snapshots whose ``Status`` is
    ``'available'``: a snapshot that is still being created cannot be restored
    from yet and may not carry a ``SnapshotCreateTime``.

    Args:
        rds_client: A boto3 RDS client.
        source_db_identifier (str): Identifier of the DB instance whose snapshots are searched.

    Returns:
        dict or None: The newest matching snapshot description, or None if there is none.

    Errors raised by the boto3 client are propagated unchanged.
    """
    paginator = rds_client.get_paginator('describe_db_snapshots')
    pages = paginator.paginate(
        DBInstanceIdentifier=source_db_identifier,
        SnapshotType='automated',  # Only consider automated snapshots.
        IncludeShared=False,       # Do not include snapshots shared by other accounts.
        IncludePublic=False,       # Do not include publicly available snapshots.
    )

    candidates = []
    for page in pages:
        for snapshot in page['DBSnapshots']:
            # Skip in-progress (or otherwise unusable) snapshots.
            if snapshot.get('Status') == 'available' and 'SnapshotCreateTime' in snapshot:
                candidates.append(snapshot)

    return max(candidates, key=lambda snap: snap['SnapshotCreateTime'], default=None)


def _delete_instance_if_exists(rds_client, target_db_identifier: str, waiter_config: dict) -> bool:
    """Delete ``target_db_identifier`` (no final snapshot) and wait until it is gone.

    Args:
        rds_client: A boto3 RDS client.
        target_db_identifier (str): Identifier of the instance to delete.
        waiter_config (dict): ``WaiterConfig`` passed to the ``db_instance_deleted`` waiter.

    Returns:
        bool: True when the instance was deleted or did not exist, False on any other
        error (the error is printed).
    """
    try:
        # Describe the target first; DBInstanceNotFoundFault means there is nothing to delete.
        described = rds_client.describe_db_instances(DBInstanceIdentifier=target_db_identifier)
        status = described['DBInstances'][0]['DBInstanceStatus']

        if status == 'deleting':
            # A delete is already in progress; a second delete request would be
            # rejected, so just wait for the existing one to finish.
            print(f"   Instance '{target_db_identifier}' is already being deleted. Waiting for instance to be fully deleted...")
        else:
            print(f"   Existing instance '{target_db_identifier}' found with status '{status}'. Initiating deletion...")
            # SkipFinalSnapshot: no final snapshot is wanted for a throwaway non-prod copy.
            # DeleteAutomatedBackups: also remove this instance's automated backups.
            rds_client.delete_db_instance(
                DBInstanceIdentifier=target_db_identifier,
                SkipFinalSnapshot=True,
                DeleteAutomatedBackups=True,
            )
            print(f"   Deletion initiated for '{target_db_identifier}'. Waiting for instance to be fully deleted...")

        # Block until the identifier is free so the restore does not conflict with it.
        rds_client.get_waiter('db_instance_deleted').wait(
            DBInstanceIdentifier=target_db_identifier,
            WaiterConfig=waiter_config,
        )
        print(f"   Instance '{target_db_identifier}' deleted successfully.")
    except rds_client.exceptions.DBInstanceNotFoundFault:
        print(f"   No existing instance '{target_db_identifier}' found. Skipping deletion.")
    except Exception as e:
        print(f"Error deleting old instance '{target_db_identifier}': {e}")
        return False
    return True


def _restore_from_snapshot(
    rds_client,
    target_db_identifier: str,
    snapshot_id: str,
    db_instance_class: str,
    db_subnet_group_name: str,
    vpc_security_group_ids: list,
    waiter_config: dict,
) -> bool:
    """Restore ``target_db_identifier`` from ``snapshot_id``, wait for it, and print its endpoint.

    Returns:
        bool: True on success, False if any step fails (the error is printed).
    """
    try:
        rds_client.restore_db_instance_from_db_snapshot(
            DBInstanceIdentifier=target_db_identifier,
            DBSnapshotIdentifier=snapshot_id,
            DBInstanceClass=db_instance_class,
            DBSubnetGroupName=db_subnet_group_name,
            VpcSecurityGroupIds=vpc_security_group_ids,
            PubliclyAccessible=False,  # Keep the non-production copy off the public internet.
            Tags=[{'Key': 'Name', 'Value': target_db_identifier}],  # Identification / cost tracking.
        )
        print(f"   Restore initiated for '{target_db_identifier}'. Waiting for it to become available...")

        # Wait for the instance to become available so the endpoint lookup below succeeds.
        rds_client.get_waiter('db_instance_available').wait(
            DBInstanceIdentifier=target_db_identifier,
            WaiterConfig=waiter_config,
        )
        print(f"   New DB instance '{target_db_identifier}' restored and is now AVAILABLE.\n")

        # STEP 4: Report the connection endpoint of the restored instance.
        print(f"\n>>> Step 4: Retrieving endpoint for new instance '{target_db_identifier}'...")
        described = rds_client.describe_db_instances(DBInstanceIdentifier=target_db_identifier)
        endpoint = described['DBInstances'][0]['Endpoint']
        print(f"   New instance endpoint: {endpoint['Address']}:{endpoint['Port']}")
        print("   You can now connect to your refreshed database using this endpoint.")
    except Exception as e:
        print(f"Error restoring instance '{target_db_identifier}': {e}")
        return False
    return True


def refresh_rds_instance(
    source_db_identifier: str,
    target_db_identifier: str,
    db_instance_class: str,
    db_subnet_group_name: str,
    vpc_security_group_ids: list,
    region_name: str = 'us-east-1',
    delete_old_instance: bool = False,
    waiter_delay: "int | None" = None,
    waiter_max_attempts: "int | None" = None,
) -> bool:
    """
    Refreshes an RDS instance from the latest snapshot of a source DB instance.

    Steps:
      1. Find the newest *available* automated snapshot of the source instance.
      2. Optionally delete the existing target instance and wait for it to disappear.
      3. Restore the target from the snapshot, wait until it is available, and print
         its endpoint.

    Args:
        source_db_identifier (str): The identifier of the source (e.g., production) RDS DB instance.
                                    This is the database from which the latest snapshot will be retrieved.
        target_db_identifier (str): The desired identifier for the new/refreshed RDS DB instance.
                                    If `delete_old_instance` is True, this also specifies the instance to be deleted.
                                    Must differ (case-insensitively) from `source_db_identifier`.
        db_instance_class (str): The DB instance class for the restored instance (e.g. 'db.t3.medium').
        db_subnet_group_name (str): The DB Subnet Group Name that the new instance will use.
        vpc_security_group_ids (list): VPC Security Group IDs to associate with the new instance.
        region_name (str): The AWS region in which all operations are performed. Defaults to 'us-east-1'.
        delete_old_instance (bool): If True, deletes an existing instance named `target_db_identifier`
                                    (without a final snapshot) before restoring. **Use with extreme caution!**
        waiter_delay (int | None): Seconds between status polls while waiting for deletion/restore.
                                   None keeps the boto3 waiter default.
        waiter_max_attempts (int | None): Maximum number of status polls before giving up. None keeps
                                          the boto3 waiter default; raise it for large instances whose
                                          restore can outlast the default wait.

    Returns:
        bool: True when the refresh completed; False when it was refused or any step failed.
        Failures are printed rather than raised.
    """
    print(f"Starting RDS refresh process for '{target_db_identifier}' from '{source_db_identifier}' in region {region_name}...")

    # RDS instance identifiers are case-insensitive. Reusing the source identifier as
    # the target would (with delete_old_instance) delete the source instance together
    # with its automated backups, including the snapshot we are about to restore.
    if source_db_identifier.lower() == target_db_identifier.lower():
        print(f"Error: Target DB identifier '{target_db_identifier}' must differ from source DB identifier '{source_db_identifier}'. Exiting.")
        return False

    rds_client = boto3.client('rds', region_name=region_name)

    # Only override the waiter settings that were explicitly requested.
    waiter_config = {}
    if waiter_delay is not None:
        waiter_config['Delay'] = waiter_delay
    if waiter_max_attempts is not None:
        waiter_config['MaxAttempts'] = waiter_max_attempts

    # STEP 1: Identify the latest automated snapshot of the source DB instance.
    print(f"\n>>> Step 1: Searching for the latest automated snapshot of '{source_db_identifier}'...")
    try:
        latest_snapshot = _find_latest_snapshot(rds_client, source_db_identifier)
    except Exception as e:
        print(f"Error finding snapshots for '{source_db_identifier}': {e}")
        return False

    if latest_snapshot is None:
        print(f"Error: No available automated snapshots found for DB instance '{source_db_identifier}'. Please ensure automated backups are enabled. Exiting.")
        return False

    latest_snapshot_id = latest_snapshot['DBSnapshotIdentifier']
    print(f"   Found latest snapshot: '{latest_snapshot_id}' created at {latest_snapshot['SnapshotCreateTime']}")

    # STEP 2: (Optional) Delete the old target DB instance if it exists.
    if delete_old_instance:
        print(f"\n>>> Step 2: Checking for and deleting existing target DB instance '{target_db_identifier}' (as requested)...")
        if not _delete_instance_if_exists(rds_client, target_db_identifier, waiter_config):
            return False
    else:
        print(f"\n>>> Step 2: Skipping deletion of old instance as '--delete-old-instance' flag was not set. Ensure no existing instance with '{target_db_identifier}' conflicts.")

    # STEP 3 (and 4): Restore the new instance and report its endpoint.
    print(f"\n>>> Step 3: Restoring new DB instance '{target_db_identifier}' from snapshot '{latest_snapshot_id}'...")
    if not _restore_from_snapshot(
        rds_client,
        target_db_identifier,
        latest_snapshot_id,
        db_instance_class,
        db_subnet_group_name,
        vpc_security_group_ids,
        waiter_config,
    ):
        return False

    print(f"\n=== RDS refresh process for '{target_db_identifier}' completed successfully. ===")
    return True

# ==========================================================================
# Main execution block to parse command-line arguments and call the function.
# This allows the script to be run from the command line with various options.
# ==========================================================================
if __name__ == "__main__":
    # Build the command-line interface; each option maps onto a
    # refresh_rds_instance() parameter.
    parser = argparse.ArgumentParser(
        description="""
        Automates the refresh of an Amazon RDS instance from the latest automated snapshot
        of a source database. Ideal for development/staging database refreshes.
        """
    )
    parser.add_argument(
        "--source-db-identifier",
        required=True,
        help="Identifier of the source RDS DB instance (e.g., 'my-prod-db')."
    )
    parser.add_argument(
        "--target-db-identifier",
        required=True,
        help="Identifier for the new RDS DB instance to be created (e.g., 'my-dev-db-new')."
    )
    parser.add_argument(
        "--db-instance-class",
        required=True,
        help="DB instance class for the new instance (e.g., 'db.t3.medium')."
    )
    parser.add_argument(
        "--db-subnet-group-name",
        required=True,
        help="DB Subnet Group Name for the new instance."
    )
    parser.add_argument(
        "--vpc-security-group-ids",
        nargs='+',  # One or more space-separated values.
        required=True,
        help="List of VPC Security Group IDs for the new instance. Provide as space-separated values (e.g., sg-xxxxxxxxxxxx sg-yyyyyyyyyyyy)."
    )
    parser.add_argument(
        "--region",
        default="us-east-1",
        help="AWS region where the RDS instances are located (default: us-east-1)."
    )
    parser.add_argument(
        "--delete-old-instance",
        action="store_true",  # Boolean flag: True only when present on the command line.
        help="If set, the existing instance with target_db_identifier will be deleted before restore. USE WITH CAUTION."
    )

    args = parser.parse_args()

    # Safety check: the target must never be the source. RDS identifiers are
    # case-insensitive, so compare them lower-cased. Otherwise --delete-old-instance
    # would delete the source (e.g. production) instance and its automated backups.
    # parser.error() prints usage and exits with status 2.
    if args.source_db_identifier.lower() == args.target_db_identifier.lower():
        parser.error("--target-db-identifier must differ from --source-db-identifier")

    refresh_rds_instance(
        source_db_identifier=args.source_db_identifier,
        target_db_identifier=args.target_db_identifier,
        db_instance_class=args.db_instance_class,
        db_subnet_group_name=args.db_subnet_group_name,
        vpc_security_group_ids=args.vpc_security_group_ids,
        region_name=args.region,
        delete_old_instance=args.delete_old_instance
    )

