
import boto3
import argparse
import datetime

def _delete_unattached_volumes(ec2_client):
    """Delete every EBS volume whose status is 'available' (i.e. unattached).

    All result pages are read and the volume IDs collected before any deletion
    starts, so deleting does not interfere with the pagination cursor. A failure
    to delete one volume is reported and the remaining volumes are still processed.

    Args:
        ec2_client: A boto3 EC2 client.
    """
    paginator = ec2_client.get_paginator('describe_volumes')
    volume_ids = [
        volume['VolumeId']
        for page in paginator.paginate(
            Filters=[{'Name': 'status', 'Values': ['available']}]
        )
        for volume in page['Volumes']
    ]

    if not volume_ids:
        print("   No unattached EBS volumes found.")
        return

    deleted_count = 0
    for volume_id in volume_ids:
        print(f"   Deleting unattached volume: {volume_id}")
        try:
            ec2_client.delete_volume(VolumeId=volume_id)
        except ec2_client.exceptions.ClientError as e:
            # The volume may have been attached or removed since it was listed;
            # report it and keep going instead of aborting the whole step.
            print(f"   Failed to delete volume {volume_id}: {e}")
            continue
        deleted_count += 1
    print(f"   Successfully initiated deletion for {deleted_count} unattached EBS volumes.")


def _delete_old_snapshots(ec2_client, snapshot_age_days):
    """Delete account-owned snapshots older than ``snapshot_age_days`` days.

    A snapshot is skipped (left in place) when:
    - an AMI's block-device mapping references it, or
    - the AMI lookup for it fails.
    A failed deletion is reported, and processing continues with the next snapshot.

    Args:
        ec2_client: A boto3 EC2 client.
        snapshot_age_days (int): Age threshold in days; only snapshots whose
            StartTime is strictly before now minus this many days are deleted.
    """
    paginator = ec2_client.get_paginator('describe_snapshots')
    snapshots = [
        snapshot
        # Only consider snapshots owned by this account.
        for page in paginator.paginate(OwnerIds=['self'])
        for snapshot in page['Snapshots']
    ]

    if not snapshots:
        print("   No EBS snapshots found.")
        return

    now = datetime.datetime.now(datetime.timezone.utc)  # Ensure timezone awareness
    # Compare full timestamps: `age.days > N` would spare snapshots up to N+1 days old.
    cutoff = now - datetime.timedelta(days=snapshot_age_days)
    deleted_count = 0
    for snapshot in snapshots:
        start_time = snapshot['StartTime']
        if start_time >= cutoff:
            continue

        snapshot_id = snapshot['SnapshotId']
        age = now - start_time

        # Before deleting, ensure it's not referenced by any AMI. If that cannot
        # be verified, err on the side of keeping the snapshot.
        try:
            images = ec2_client.describe_images(
                Filters=[
                    {'Name': 'block-device-mapping.snapshot-id', 'Values': [snapshot_id]}
                ]
            )['Images']
        except ec2_client.exceptions.ClientError as e:
            print(f"   Skipping snapshot {snapshot_id}: could not check AMI usage: {e}")
            continue

        if images:
            ami_ids = ', '.join(img['ImageId'] for img in images)
            print(f"   Skipping snapshot {snapshot_id}: Attached to AMI(s) {ami_ids}")
            continue

        print(f"   Deleting old snapshot: {snapshot_id} (Age: {age.days} days)")
        try:
            ec2_client.delete_snapshot(SnapshotId=snapshot_id)
        except ec2_client.exceptions.ClientError as e:
            print(f"   Failed to delete snapshot {snapshot_id}: {e}")
            continue
        deleted_count += 1

    print(f"   Successfully initiated deletion for {deleted_count} old EBS snapshots.")


def cleanup_ebs_resources(
    region_name='us-east-1',
    delete_unattached_volumes=False,
    delete_old_snapshots=False,
    snapshot_age_days=30
):
    """
    Cleans up unused EBS volumes and old snapshots to optimize costs.

    Each step reads all result pages from the EC2 API. Per-resource failures
    are reported and skipped, so one bad volume or snapshot does not stop the
    rest of the step. Any other error in a step is printed and the next step
    still runs.

    Args:
        region_name (str): The AWS region.
        delete_unattached_volumes (bool): If True, deletes EBS volumes whose
            status is 'available' (not attached to any instance).
        delete_old_snapshots (bool): If True, deletes snapshots owned by this
            account that are older than snapshot_age_days and that no AMI
            references.
        snapshot_age_days (int): Snapshots older than this many days will be deleted.

    Raises:
        ValueError: If delete_old_snapshots is True and snapshot_age_days is
            negative (which would otherwise select every snapshot).
    """
    if delete_old_snapshots and snapshot_age_days < 0:
        raise ValueError(
            f"snapshot_age_days must be non-negative, got {snapshot_age_days}"
        )

    ec2_client = boto3.client('ec2', region_name=region_name)

    print(f"Starting EBS cleanup process in region {region_name}...")

    # 1. Cleanup Unattached EBS Volumes
    if delete_unattached_volumes:
        print("\n>>> Step 1: Cleaning up unattached EBS volumes...")
        try:
            _delete_unattached_volumes(ec2_client)
        except Exception as e:
            # Top-level boundary for this step: report and continue to step 2.
            print(f"Error cleaning up unattached volumes: {e}")

    # 2. Cleanup Old Snapshots
    if delete_old_snapshots:
        print(f"\n>>> Step 2: Cleaning up snapshots older than {snapshot_age_days} days...")
        try:
            _delete_old_snapshots(ec2_client, snapshot_age_days)
        except Exception as e:
            print(f"Error cleaning up old snapshots: {e}")

    print("\nEBS cleanup process completed.")

if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, then run the cleanup.
    parser = argparse.ArgumentParser(description="Clean up unused EBS volumes and old snapshots.")
    parser.add_argument("--region", default="us-east-1", help="AWS region (default: us-east-1).")
    parser.add_argument("--delete-unattached-volumes", action="store_true", help="If set, deletes unattached EBS volumes.")
    parser.add_argument("--delete-old-snapshots", action="store_true", help="If set, deletes snapshots older than --snapshot-age-days.")
    parser.add_argument("--snapshot-age-days", type=int, default=30, help="Snapshots older than this many days will be deleted (default: 30).")

    args = parser.parse_args()

    # parser.error prints usage to stderr and exits with status 2, so shells and
    # schedulers can tell that the invocation was invalid and nothing was done.
    if not args.delete_unattached_volumes and not args.delete_old_snapshots:
        parser.error("Please specify at least one cleanup action (--delete-unattached-volumes or --delete-old-snapshots).")
    if args.snapshot_age_days < 0:
        # A negative age would select every snapshot for deletion.
        parser.error("--snapshot-age-days must be non-negative.")

    cleanup_ebs_resources(
        region_name=args.region,
        delete_unattached_volumes=args.delete_unattached_volumes,
        delete_old_snapshots=args.delete_old_snapshots,
        snapshot_age_days=args.snapshot_age_days
    )
