import argparse
import datetime
import sys

import boto3

def audit_iam_access_key_age(
    max_age_days=90,
    region_name='us-east-1',
    report_file_path=None,
    iam_client=None,
):
    """
    Audit all IAM users for access keys older than a specified maximum age.

    Progress messages are printed to stdout. The final report is written to
    ``report_file_path`` or, if that is None, printed to stdout.

    Args:
        max_age_days (int): Keys strictly older than this many whole days
                            are flagged.
        region_name (str): Region passed to ``boto3.client`` when no client
                           is supplied (also shown in the start message).
        report_file_path (str, optional): Path of a file to write the report
                                          to (UTF-8). If None, the report is
                                          printed to the console.
        iam_client (optional): Pre-built IAM client, e.g. one created from a
                               ``boto3.Session`` with a specific profile or
                               assumed role. If None, a client is created
                               with ``boto3.client('iam', ...)``.

    Returns:
        list[dict] | None: One dict per flagged key with the keys
        ``UserName``, ``AccessKeyId``, ``CreateDate`` (ISO 8601 string),
        ``AgeDays`` and ``Status``. The list is empty when nothing is
        flagged. Returns None if listing users or keys failed; the error is
        printed to stderr and no report is produced.
    """
    if iam_client is None:
        iam_client = boto3.client('iam', region_name=region_name)

    print(f"Starting IAM access key age audit in region {region_name}...")

    # One reference time for the whole run so every key is aged consistently.
    now = datetime.datetime.now(datetime.timezone.utc)

    try:
        users = _list_iam_users(iam_client)
        old_keys_found = _find_old_access_keys(iam_client, users, max_age_days, now)
    except Exception as e:
        # Top-level boundary (credentials, throttling, network, ...): report
        # the failure instead of emitting a misleadingly empty report.
        print(f"Error during IAM access key audit: {e}", file=sys.stderr)
        return None

    # 3. Generate Report
    print("\n>>> Step 3: Generating report...")
    report_output = _format_report(old_keys_found, max_age_days)

    if report_file_path:
        with open(report_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_output) + '\n')
        print(f"Report saved to '{report_file_path}'.")
    else:
        print('\n'.join(report_output))

    print("\nIAM access key age audit completed.")
    return old_keys_found


def _list_iam_users(iam_client):
    """Return every IAM user, following ``list_users`` pagination."""
    print("\n>>> Step 1: Listing all IAM users...")
    users = []
    for page in iam_client.get_paginator('list_users').paginate():
        users.extend(page['Users'])
    print(f"   Found {len(users)} IAM users.")
    return users


def _list_access_keys(iam_client, user_name):
    """Return all access-key metadata for one user, following pagination."""
    keys = []
    paginator = iam_client.get_paginator('list_access_keys')
    for page in paginator.paginate(UserName=user_name):
        keys.extend(page['AccessKeyMetadata'])
    return keys


def _find_old_access_keys(iam_client, users, max_age_days, now):
    """Return a record for every key of ``users`` older than ``max_age_days``."""
    print("\n>>> Step 2: Checking access keys for each user...")
    old_keys_found = []
    for user in users:
        user_name = user['UserName']
        print(f"   Checking user: {user_name}")
        try:
            access_keys = _list_access_keys(iam_client, user_name)
        except iam_client.exceptions.NoSuchEntityException:
            # The user was removed after list_users ran; skip it rather than
            # aborting the audit for everyone else.
            print(f"      User {user_name} no longer exists; skipping.")
            continue

        for access_key in access_keys:
            access_key_id = access_key['AccessKeyId']
            create_date = access_key['CreateDate']
            status = access_key.get('Status', 'Unknown')
            age_days = (now - create_date).days

            if age_days > max_age_days:
                old_keys_found.append({
                    'UserName': user_name,
                    'AccessKeyId': access_key_id,
                    'CreateDate': create_date.isoformat(),
                    'AgeDays': age_days,
                    'Status': status,
                })
                print(f"      WARNING: Access Key {access_key_id} ({status}) for user {user_name} is {age_days} days old (older than {max_age_days} days).")
            else:
                print(f"      Access Key {access_key_id} ({status}) for user {user_name} is {age_days} days old (within {max_age_days} days).")
    return old_keys_found


def _format_report(old_keys_found, max_age_days):
    """Build the human-readable report as a list of lines (no newlines)."""
    if not old_keys_found:
        return [f"No IAM access keys found older than {max_age_days} days."]
    lines = [f"--- IAM Access Keys Older Than {max_age_days} Days ---"]
    for key_info in old_keys_found:
        lines.extend([
            f"  User Name: {key_info['UserName']}",
            f"  Access Key ID: {key_info['AccessKeyId']}",
            f"  Status: {key_info['Status']}",
            f"  Created On: {key_info['CreateDate']}",
            f"  Age (Days): {key_info['AgeDays']}",
            "----------------------------------------",
        ])
    return lines

if __name__ == "__main__":
    # Command-line entry point: parse options and run a single audit.
    cli = argparse.ArgumentParser(description="Audit AWS IAM access keys for age.")
    cli.add_argument(
        "--max-age-days",
        type=int,
        default=90,
        help="Maximum age in days for an access key (default: 90).",
    )
    cli.add_argument(
        "--region",
        default="us-east-1",
        help="AWS region (default: us-east-1).",
    )
    cli.add_argument(
        "--report-file-path",
        help="Optional. Path to a file to save the report. If not provided, prints to console.",
    )

    options = cli.parse_args()
    audit_iam_access_key_age(
        max_age_days=options.max_age_days,
        region_name=options.region,
        report_file_path=options.report_file_path,
    )
