
import boto3
import argparse

# Retention periods, in days, accepted by the CloudWatch Logs PutRetentionPolicy API.
_API_RETENTION_DAYS = (
    1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731,
    1096, 1827, 2192, 2557, 2922, 3288, 3653,
)


def _describe_retention(days):
    """Render a retention value for display; None means the group never expires."""
    return "never expire" if days is None else f"{days} days"


def _apply_retention(logs_client, log_group_name, desired_retention):
    """Set one log group's retention, removing the policy when desired_retention is None."""
    if desired_retention is None:
        # PutRetentionPolicy has no "never expire" value; that state is reached
        # by deleting the policy instead.
        logs_client.delete_retention_policy(logGroupName=log_group_name)
    else:
        logs_client.put_retention_policy(
            logGroupName=log_group_name,
            retentionInDays=desired_retention
        )


def enforce_cloudwatch_log_retention(
    retention_days,
    region_name='us-east-1',
    dry_run=True
):
    """
    Scan all CloudWatch Log Groups in a region and enforce one retention policy.

    Progress, errors and the final summary are printed to stdout; the function
    always returns None. Errors while listing log groups abort the run; an
    error while updating a single log group is reported and the remaining
    groups are still processed.

    Args:
        retention_days (int): Desired retention in days. Must be 0 (never
            expire: any existing retention policy is removed) or one of the
            values in _API_RETENTION_DAYS.
        region_name (str): The AWS region to scan.
        dry_run (bool): If True, only reports changes without applying them.
            If False, applies changes (needs permission to call
            put_retention_policy, and delete_retention_policy when
            retention_days is 0).
    """
    # Validate before creating the client so bad input is rejected without
    # needing AWS configuration or making any API call.
    is_int = isinstance(retention_days, int) and not isinstance(retention_days, bool)
    if not is_int or (retention_days != 0 and retention_days not in _API_RETENTION_DAYS):
        print(
            f"Error: Invalid retention_days value '{retention_days}'. "
            f"Must be 0 (never expire) or one of {list(_API_RETENTION_DAYS)}. Exiting."
        )
        return

    # describe_log_groups omits 'retentionInDays' for groups that never expire,
    # so "never expire" is represented as None on both sides of the comparison.
    desired_retention = None if retention_days == 0 else retention_days
    desired_label = _describe_retention(desired_retention)

    logs_client = boto3.client('logs', region_name=region_name)

    print(f"Starting CloudWatch Log Group retention policy enforcement in region {region_name}...")
    if dry_run:
        print("*** DRY RUN MODE: No changes will be applied. ***")

    # 1. List all CloudWatch Log Groups and collect the ones that differ.
    print("\n>>> Step 1: Listing all CloudWatch Log Groups...")
    log_groups_to_update = []
    try:
        paginator = logs_client.get_paginator('describe_log_groups')
        for page in paginator.paginate():
            for log_group in page['logGroups']:
                log_group_name = log_group['logGroupName']
                current_retention = log_group.get('retentionInDays')

                if current_retention == desired_retention:
                    print(f"   Log Group '{log_group_name}': Already set to {desired_label}. Skipping.")
                    continue

                log_groups_to_update.append({
                    'name': log_group_name,
                    'current_retention': current_retention
                })
                print(
                    f"   Log Group '{log_group_name}': Current retention "
                    f"{_describe_retention(current_retention)}, desired {desired_label}."
                )
    except Exception as e:
        # Without a complete listing there is nothing safe to apply.
        print(f"Error enforcing CloudWatch Log retention: {e}")
        return

    if not log_groups_to_update:
        print("   All log groups already have the desired retention policy. No updates needed.")
        return

    # 2. Apply the new retention policy
    print(f"\n>>> Step 2: Applying retention policy of {desired_label}...")
    if dry_run:
        print("   (Dry run: No changes applied. The following log groups would be updated:)")

    updated_count = 0
    failed_names = []
    for lg in log_groups_to_update:
        log_group_name = lg['name']
        current_label = _describe_retention(lg['current_retention'])
        if dry_run:
            print(f"   (Dry run) Would update log group '{log_group_name}' from {current_label} to {desired_label}.")
            continue

        print(f"   Updating retention for log group '{log_group_name}' from {current_label} to {desired_label}.")
        try:
            _apply_retention(logs_client, log_group_name, desired_retention)
        except Exception as e:
            # A single failure (e.g. the group was deleted after listing) must
            # not prevent the remaining groups from being updated.
            failed_names.append(log_group_name)
            print(f"   Error updating log group '{log_group_name}': {e}")
        else:
            updated_count += 1

    if dry_run:
        print(f"   (Dry run) {len(log_groups_to_update)} log groups would have been updated.")
    else:
        print(f"   Successfully updated retention policy for {updated_count} log groups.")
        if failed_names:
            print(f"   Failed to update {len(failed_names)} log groups: {', '.join(failed_names)}")

    print("\nCloudWatch Log Group retention policy enforcement completed.")

if __name__ == "__main__":
    # Command-line entry point: parse the options, then run one enforcement pass.
    cli = argparse.ArgumentParser(
        description="Enforce CloudWatch Log Group retention policy."
    )
    cli.add_argument(
        "--retention-days",
        required=True,
        type=int,
        help=(
            "Desired retention period in days (e.g., 7, 30, 90, 365). "
            "Valid values: 0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, 3653."
        ),
    )
    cli.add_argument(
        "--region",
        default="us-east-1",
        help="AWS region to scan (default: us-east-1).",
    )
    # store_true makes dry-run opt-in: without the flag the parsed value is
    # False, so changes are applied.
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="If set, only reports changes without applying them.",
    )

    options = cli.parse_args()

    enforce_cloudwatch_log_retention(
        options.retention_days,
        region_name=options.region,
        dry_run=options.dry_run,
    )
