
import boto3
import argparse
import json
import os
from datetime import datetime

def _write_private_json(path, data):
    """Write ``data`` as indented JSON to ``path``, readable/writable by the owner only (0o600)."""
    # os.open applies the 0o600 mode only when it *creates* the file; fchmod
    # also tightens a pre-existing file before the secret is written into it.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        if hasattr(os, 'fchmod'):  # os.fchmod is not available on Windows
            os.fchmod(f.fileno(), 0o600)
        json.dump(data, f, indent=4)


def rotate_iam_access_key(
    user_name,
    output_file_path,
    region_name='us-east-1'
):
    """
    Rotate AWS IAM access keys for a specified user.

    Creates a new access key, saves it to ``output_file_path`` as JSON with
    owner-only permissions (where the OS supports POSIX modes), then
    deactivates and deletes every access key the user had before the
    rotation started.

    If the key file cannot be written, the freshly created key is deleted
    again so that no credential is left active whose secret nobody holds
    (the secret is only returned by ``create_access_key``).

    Args:
        user_name (str): The IAM user name whose access key will be rotated.
        output_file_path (str): The path to the JSON file where the new access key will be saved.
        region_name (str): The AWS region (used for boto3 client initialization).

    Returns:
        bool: True if every step completed; False if any step failed (the
        error is printed). A failure during step 4 happens after the new key
        has already been saved, so the file is still valid in that case.
    """
    iam_client = boto3.client('iam', region_name=region_name)

    print(f"Starting IAM access key rotation for user '{user_name}'...")

    try:
        # 1. List existing access keys for the user
        print("\n>>> Step 1: Listing existing access keys...")
        existing_keys = iam_client.list_access_keys(UserName=user_name)['AccessKeyMetadata']
        print(f"   Found {len(existing_keys)} existing access key(s).")

        # 2. Create a new access key. IAM limits the number of keys per user
        # (two, per AWS IAM quotas), so this raises if the user is at the cap.
        print("\n>>> Step 2: Creating a new access key...")
        new_access_key = iam_client.create_access_key(UserName=user_name)['AccessKey']
        new_access_key_id = new_access_key['AccessKeyId']
        print(f"   New Access Key ID created: {new_access_key_id}")

        # 3. Save the new access key to a file
        print(f"\n>>> Step 3: Saving new access key to '{output_file_path}'...")
        key_data = {
            "AccessKeyId": new_access_key_id,
            "SecretAccessKey": new_access_key['SecretAccessKey'],
            "UserName": user_name,
            # astimezone() attaches the local UTC offset so the timestamp is unambiguous.
            "CreationDate": datetime.now().astimezone().isoformat()
        }
        try:
            _write_private_json(output_file_path, key_data)
        except OSError as write_error:
            # The secret cannot be fetched again later, so a key we failed to
            # persist is unusable; delete it rather than leave it active.
            print(f"   Could not save the new access key: {write_error}")
            print(f"   Deleting unsaved Access Key ID: {new_access_key_id}")
            iam_client.delete_access_key(
                AccessKeyId=new_access_key_id,
                UserName=user_name
            )
            raise
        print(f"   New access key saved successfully to '{output_file_path}'.")
        print("   IMPORTANT: Securely store this file and restrict its access.")

        # 4. Deactivate and delete old access keys
        if existing_keys:
            print("\n>>> Step 4: Deactivating and deleting old access keys...")
            for old_key in existing_keys:
                old_access_key_id = old_key['AccessKeyId']
                if old_access_key_id == new_access_key_id:
                    # Defensive only: existing_keys was listed before the new
                    # key was created, so it should never contain it.
                    continue
                print(f"   Deactivating old Access Key ID: {old_access_key_id}")
                iam_client.update_access_key(
                    AccessKeyId=old_access_key_id,
                    Status='Inactive',
                    UserName=user_name
                )
                print(f"   Deleting old Access Key ID: {old_access_key_id}")
                iam_client.delete_access_key(
                    AccessKeyId=old_access_key_id,
                    UserName=user_name
                )
                print(f"   Old Access Key ID {old_access_key_id} deleted.")
        else:
            print("   No old access keys to deactivate/delete.")

    except Exception as e:
        # Top-level boundary of this CLI helper: report the error and signal
        # failure through the return value instead of propagating.
        print(f"Error during IAM access key rotation: {e}")
        return False

    print(f"\nIAM access key rotation for user '{user_name}' completed successfully.")
    return True

if __name__ == "__main__":
    # Command-line entry point: parse arguments and run a single rotation.
    parser = argparse.ArgumentParser(description="Rotate AWS IAM access keys for a specified user.")
    parser.add_argument("--user-name", required=True, help="The IAM user name whose access key will be rotated.")
    parser.add_argument("--output-file-path", required=True, help="The path to the JSON file where the new access key will be saved.")
    parser.add_argument("--region", default="us-east-1", help="AWS region (default: us-east-1).")

    args = parser.parse_args()

    result = rotate_iam_access_key(
        user_name=args.user_name,
        output_file_path=args.output_file_path,
        region_name=args.region
    )

    # Exit non-zero so shells and schedulers can detect a failed rotation.
    # Only an explicit False counts as failure; a None return is treated as success.
    if result is False:
        raise SystemExit(1)
