
import boto3
import argparse
import time

def _find_latest_automated_snapshot(rds_client, db_identifier):
    """Return the newest available automated snapshot of a DB instance, or None.

    All result pages are read so the newest snapshot is not missed when the
    instance has more snapshots than fit in one API response.
    """
    paginator = rds_client.get_paginator('describe_db_snapshots')
    candidates = []
    for page in paginator.paginate(
        DBInstanceIdentifier=db_identifier,
        SnapshotType='automated',
        IncludeShared=False,
        IncludePublic=False,
    ):
        candidates.extend(page['DBSnapshots'])

    # A snapshot that is still being created cannot be copied and may not
    # carry 'SnapshotCreateTime' yet, which would break the ordering below.
    usable = [
        snap for snap in candidates
        if snap.get('Status') == 'available' and 'SnapshotCreateTime' in snap
    ]
    if not usable:
        return None
    return max(usable, key=lambda snap: snap['SnapshotCreateTime'])


def _build_target_snapshot_id(source_snapshot_id):
    """Derive a valid manual snapshot identifier from a source snapshot identifier.

    Automated snapshot identifiers contain a colon (e.g. 'rds:mydb-2024-...'),
    while manual snapshot identifiers may only contain ASCII letters, digits
    and single hyphens. Every other character is turned into a hyphen and
    runs of hyphens are collapsed.
    """
    hyphenated = ''.join(
        ch if (ch.isascii() and ch.isalnum()) else '-'
        for ch in source_snapshot_id
    )
    parts = [part for part in hyphenated.split('-') if part]
    return '-'.join(['copy'] + parts + [str(int(time.time()))])


def _copy_snapshot_and_wait(dest_client, dest_region, source_snapshot_ref,
                            target_snapshot_id, source_region=None, kms_key_id=None):
    """Start a snapshot copy through dest_client and block until it is available.

    Optional parameters are only sent when set, because boto3 rejects None
    values during parameter validation.
    """
    copy_params = {
        'SourceDBSnapshotIdentifier': source_snapshot_ref,
        'TargetDBSnapshotIdentifier': target_snapshot_id,
        'CopyTags': True,
    }
    if source_region:
        # With SourceRegion set, boto3 builds the pre-signed URL needed for
        # a cross-region copy.
        copy_params['SourceRegion'] = source_region
    if kms_key_id:
        copy_params['KmsKeyId'] = kms_key_id

    copy_response = dest_client.copy_db_snapshot(**copy_params)
    print(f"   Copy initiated. New snapshot ID in {dest_region}: {copy_response['DBSnapshot']['DBSnapshotIdentifier']}")

    waiter = dest_client.get_waiter('db_snapshot_available')
    waiter.wait(DBSnapshotIdentifier=target_snapshot_id)
    print(f"   Snapshot '{target_snapshot_id}' available in {dest_region}.")


def automate_rds_snapshot_dr(
    source_db_identifier,
    target_account_id=None,
    target_region=None,
    kms_key_id=None,
    region_name='us-east-1'
):
    """
    Copies and/or shares the latest automated RDS snapshot for disaster recovery.

    The newest available automated snapshot of the source instance is copied
    to a manual snapshot (in ``target_region`` when given, otherwise in the
    source region) and, when ``target_account_id`` is given, the manual copy
    is shared with that account. A manual copy is always made before sharing
    because RDS does not allow automated snapshots to be shared directly.

    Progress and errors are reported with ``print``; the function returns
    ``None`` in every case and does not raise on AWS errors.

    Args:
        source_db_identifier (str): The identifier of the source RDS DB instance.
        target_account_id (str, optional): The AWS account ID to share the copied
                                           snapshot with.
        target_region (str, optional): The AWS region to copy the snapshot to.
                                       Defaults to the source region.
        kms_key_id (str, optional): The ARN or ID of the KMS key used to encrypt the copy.
                                    Required when an encrypted snapshot is copied to a
                                    different region (the key must exist in that region).
                                    NOTE: per AWS documentation, snapshots encrypted with
                                    the default AWS-managed RDS key cannot be shared, so a
                                    customer-managed key is needed for encrypted sharing.
        region_name (str): The AWS region where the source RDS instance is located.
    """
    rds_client = boto3.client('rds', region_name=region_name)

    print(f"Starting RDS snapshot DR automation for '{source_db_identifier}' in region {region_name}...")

    # 1. Find the latest automated snapshot of the source DB instance
    print("\n>>> Step 1: Finding the latest automated snapshot...")
    try:
        latest_snapshot = _find_latest_automated_snapshot(rds_client, source_db_identifier)
        if latest_snapshot is None:
            print(f"Error: No automated snapshots found for '{source_db_identifier}'. Exiting.")
            return

        source_snapshot_id = latest_snapshot['DBSnapshotIdentifier']
        source_snapshot_arn = latest_snapshot['DBSnapshotArn']
        source_snapshot_encrypted = latest_snapshot['Encrypted']

        print(f"   Found latest snapshot: '{source_snapshot_id}' (Encrypted: {source_snapshot_encrypted}) created at {latest_snapshot['SnapshotCreateTime']}")

    except Exception as e:
        print(f"Error finding latest snapshot: {e}")
        return

    if not target_region and not target_account_id:
        print("No target account ID or target region specified. No action taken.")
        print("\nRDS snapshot DR automation completed.")
        return

    dest_region = target_region or region_name
    cross_region = dest_region != region_name

    # RDS rejects a cross-region copy of an encrypted snapshot without a
    # destination-region KMS key; fail early with a clear message instead.
    if cross_region and source_snapshot_encrypted and not kms_key_id:
        print(f"Error: Snapshot '{source_snapshot_id}' is encrypted; a KMS key in {dest_region} "
              "must be supplied via kms_key_id for a cross-region copy. Exiting.")
        return

    if target_region and target_account_id:
        print("\n>>> Step 2: Copying snapshot to another region AND sharing with another account...")
        print(f"   Copying snapshot '{source_snapshot_id}' to region {target_region} for account {target_account_id}...")
        failure_prefix = "Error copying/sharing snapshot cross-region/cross-account"
    elif target_region:
        print("\n>>> Step 2: Copying snapshot to another region...")
        failure_prefix = f"Error copying snapshot to region {target_region}"
    else:
        print("\n>>> Step 2: Sharing snapshot with another account...")
        print("   Automated snapshots cannot be shared directly; creating a manual copy first...")
        failure_prefix = f"Error sharing snapshot with account {target_account_id}"

    target_snapshot_id = _build_target_snapshot_id(source_snapshot_id)

    try:
        # A cross-region copy must be requested in the DESTINATION region,
        # naming the source snapshot by ARN and passing SourceRegion.
        if cross_region:
            dest_client = boto3.client('rds', region_name=dest_region)
            _copy_snapshot_and_wait(
                dest_client, dest_region, source_snapshot_arn, target_snapshot_id,
                source_region=region_name, kms_key_id=kms_key_id,
            )
        else:
            dest_client = rds_client
            _copy_snapshot_and_wait(
                dest_client, dest_region, source_snapshot_id, target_snapshot_id,
                kms_key_id=kms_key_id,
            )

        if target_account_id:
            print(f"   Sharing snapshot '{target_snapshot_id}' in {dest_region} with account {target_account_id}...")
            dest_client.modify_db_snapshot_attribute(
                DBSnapshotIdentifier=target_snapshot_id,
                AttributeName='restore',
                ValuesToAdd=[target_account_id]
            )
            print(f"   Snapshot '{target_snapshot_id}' shared with account {target_account_id} in {dest_region}.")

    except Exception as e:
        print(f"{failure_prefix}: {e}")
        return

    print("\nRDS snapshot DR automation completed.")

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the DR automation.
    parser = argparse.ArgumentParser(description="Automate RDS snapshot sharing/copying for DR.")
    parser.add_argument("--source-db-identifier", required=True, help="Identifier of the source RDS DB instance.")
    parser.add_argument("--target-account-id", help="Optional. AWS account ID to share/copy the snapshot to.")
    parser.add_argument("--target-region", help="Optional. AWS region to copy the snapshot to.")
    parser.add_argument("--kms-key-id", help="Optional. KMS key ID for encryption in target (required for cross-account/region copy of encrypted snapshots).")
    parser.add_argument("--region", default="us-east-1", help="AWS region of the source RDS instance (default: us-east-1).")

    args = parser.parse_args()

    # At least one target is needed for there to be anything to do.
    # parser.error() prints the usage plus this message to stderr and exits
    # with status 2, so calling automation can detect the misuse.
    if not args.target_account_id and not args.target_region:
        parser.error("please specify either --target-account-id or --target-region (or both)")

    automate_rds_snapshot_dr(
        source_db_identifier=args.source_db_identifier,
        target_account_id=args.target_account_id,
        target_region=args.target_region,
        kms_key_id=args.kms_key_id,
        region_name=args.region
    )
