import ipaddress
import time
from datetime import datetime, timedelta, timezone

from .base_agent import BaseAgent
from ..aws_connector import AWSConnector

class RDSAgent(BaseAgent):
    """
    An agent specialized in handling AWS RDS tasks.

    Commands accepted by ``execute``:
      * ``create_db_instance``       -- create an instance from explicit settings.
      * ``smart_create_db_instance`` -- also provision the DB subnet group and
        security group the instance needs.
      * ``troubleshoot_db_instance`` -- run read-only diagnostic checks.
    """

    def execute(self, command: str, **kwargs):
        """
        Executes a given command related to RDS.
        Includes a confirmation step for creation actions.

        Args:
            command: 'create_db_instance', 'smart_create_db_instance' or
                'troubleshoot_db_instance'.
            **kwargs: Forwarded unchanged to the matching handler method.

        Returns:
            The handler's result dict, or a ``{"status": "cancelled", ...}``
            dict when the user declines a creation prompt.

        Raises:
            NotImplementedError: If ``command`` is not supported.
        """
        if command == 'create_db_instance':
            db_instance_identifier = kwargs.get('db_instance_identifier', '')
            engine = kwargs.get('engine', '')
            db_instance_class = kwargs.get('db_instance_class', '')
            allocated_storage = kwargs.get('allocated_storage', '')
            master_username = kwargs.get('master_username', '')

            print(f"You are about to create an RDS instance '{db_instance_identifier}' ({engine}, {db_instance_class}, {allocated_storage}GB) with master user '{master_username}'.")
            if self._confirmed("Are you sure you want to proceed? (yes/no): "):
                return self._create_db_instance(**kwargs)
            return {"status": "cancelled", "message": "Create DB instance command cancelled by user."}

        if command == 'smart_create_db_instance':
            db_instance_identifier = kwargs.get('db_instance_identifier', '')
            engine = kwargs.get('engine', '')
            print(f"You are about to smart-create an RDS instance '{db_instance_identifier}' ({engine}).")
            if self._confirmed("This may create a new VPC, DB Subnet Group, and Security Group. Are you sure? (yes/no): "):
                return self._smart_create_db_instance(**kwargs)
            return {"status": "cancelled", "message": "Smart Create DB instance command cancelled by user."}

        if command == 'troubleshoot_db_instance':
            return self._troubleshoot_db_instance(**kwargs)

        raise NotImplementedError(f"Command '{command}' is not supported by RDSAgent.")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _confirmed(prompt):
        """Prompt on stdin; return True only if the answer is 'yes' (case- and whitespace-insensitive)."""
        return input(prompt).strip().lower() == 'yes'

    @staticmethod
    def _default_db_port(engine):
        """Return the conventional listener port for an RDS engine name.

        PostgreSQL-family engines (``postgres``, ``aurora-postgresql``) -> 5432,
        ``oracle-*`` -> 1521, ``sqlserver-*`` -> 1433; everything else
        (MySQL, MariaDB, Aurora MySQL, unknown) falls back to 3306.
        """
        engine = engine or ''
        if 'postgres' in engine:
            return 5432
        if engine.startswith('oracle'):
            return 1521
        if engine.startswith('sqlserver'):
            return 1433
        return 3306

    @staticmethod
    def _vpc_ipv4_cidrs(vpc):
        """Return the IPv4 CIDR blocks currently associated with a ``describe_vpcs`` entry."""
        cidrs = [
            assoc['CidrBlock']
            for assoc in vpc.get('CidrBlockAssociationSet', [])
            if assoc.get('CidrBlockState', {}).get('State') == 'associated'
        ]
        # Fall back to the primary block if the association set is absent.
        if not cidrs and vpc.get('CidrBlock'):
            cidrs = [vpc['CidrBlock']]
        return cidrs

    @staticmethod
    def _one_subnet_per_az(subnets):
        """Return the first subnet ID seen in each Availability Zone, preserving input order."""
        by_az = {}
        for subnet in subnets:
            by_az.setdefault(subnet['AvailabilityZone'], subnet['SubnetId'])
        return list(by_az.values())

    @staticmethod
    def _permission_covers_port(ip_permission, port):
        """Return True if a security-group ingress rule admits TCP traffic on ``port``."""
        protocol = str(ip_permission.get('IpProtocol', ''))
        if protocol == '-1':  # "all traffic" rules carry no port range
            return True
        if protocol not in ('tcp', '6'):
            return False
        from_port = ip_permission.get('FromPort')
        to_port = ip_permission.get('ToPort')
        return from_port is not None and to_port is not None and from_port <= port <= to_port

    @classmethod
    def _sg_allows_vpc_ingress(cls, sg_details, port, vpc_cidrs):
        """Return True if ``sg_details`` has an ingress rule for ``port`` reachable from inside the VPC.

        A rule counts when its IPv4 range overlaps any of ``vpc_cidrs``, or when
        it references another security group (treated as in-VPC access).
        """
        vpc_networks = [ipaddress.ip_network(cidr, strict=False) for cidr in vpc_cidrs]
        for ip_permission in sg_details.get('IpPermissions', []):
            if not cls._permission_covers_port(ip_permission, port):
                continue
            if ip_permission.get('UserIdGroupPairs'):
                return True
            for ip_range in ip_permission.get('IpRanges', []):
                rule_network = ipaddress.ip_network(ip_range['CidrIp'], strict=False)
                if any(rule_network.overlaps(net) for net in vpc_networks):
                    return True
        return False

    # ----------------------------------------------------------------- commands

    def _create_db_instance(self, region: str, db_instance_identifier: str, engine: str, db_instance_class: str, allocated_storage: int, master_username: str, master_password: str, vpc_security_group_ids: list = None, db_subnet_group_name: str = None):
        """Creates a new RDS DB instance.

        The instance is private, single-AZ, keeps 7 days of backups and listens
        on the engine's default port. Optional security groups / subnet group
        are passed through when given.

        Returns:
            ``{"status": "success", ..., "db_instance_arn": ...}`` or
            ``{"status": "error", "message": ...}`` if the API call fails.
        """
        print(f"RDSAgent: Creating DB instance '{db_instance_identifier}' in region {region}...")
        try:
            rds_client = AWSConnector.get_client('rds', region_name=region)

            params = {
                'DBInstanceIdentifier': db_instance_identifier,
                'Engine': engine,
                'DBInstanceClass': db_instance_class,
                'AllocatedStorage': allocated_storage,
                'MasterUsername': master_username,
                'MasterUserPassword': master_password,
                'PubliclyAccessible': False,  # Default to private for security
                'MultiAZ': False,  # Default to single AZ
                'BackupRetentionPeriod': 7,  # Default backup retention
                'Port': self._default_db_port(engine),
            }

            if vpc_security_group_ids:
                params['VpcSecurityGroupIds'] = vpc_security_group_ids
            if db_subnet_group_name:
                params['DBSubnetGroupName'] = db_subnet_group_name

            response = rds_client.create_db_instance(**params)
            db_instance_arn = response['DBInstance']['DBInstanceArn']
            return {"status": "success", "message": f"DB instance '{db_instance_identifier}' created successfully.", "db_instance_arn": db_instance_arn}
        except Exception as e:
            print(f"Error creating DB instance: {e}")
            return {"status": "error", "message": str(e)}

    def _smart_create_db_instance(self, region: str, db_instance_identifier: str, engine: str, master_username: str, master_password: str, db_instance_class: str = 'db.t3.micro', allocated_storage: int = 20, vpc_id: str = None):
        """
        Smartly creates an RDS DB instance, handling VPC, DB Subnet Group, and Security Group creation.

        Steps:
          1. Resolve the VPC (the account's default VPC when ``vpc_id`` is None).
          2. Create (or reuse) ``<id>-sng`` with one subnet per Availability Zone.
          3. Create (or reuse) ``<id>-sg``; a newly created group allows the DB
             port from the VPC's own IPv4 CIDR blocks.
          4. Start instance creation in that subnet group / security group.

        Returns:
            A success dict with the instance ARN and the resources used, or
            ``{"status": "error", "message": ...}``. Resources created before a
            failure are not cleaned up.
        """
        print(f"RDSAgent: Beginning Smart Create for DB instance '{db_instance_identifier}' in region {region}...")
        try:
            ec2_client = AWSConnector.get_client('ec2', region_name=region)
            rds_client = AWSConnector.get_client('rds', region_name=region)

            # Computed up front: needed for both the SG rule and the instance,
            # including when the security group already exists.
            db_port = self._default_db_port(engine)

            # 1. Resolve the VPC and its address ranges
            if vpc_id is None:
                print("No VPC ID provided. Attempting to find default VPC...")
                vpcs = ec2_client.describe_vpcs(Filters=[{'Name': 'isDefault', 'Values': ['true']}]).get('Vpcs', [])
                if not vpcs:
                    return {"status": "error", "message": "No default VPC found. Please create one or specify a VPC ID."}
                vpc_id = vpcs[0]['VpcId']
                print(f"Using default VPC: {vpc_id}")
            else:
                vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get('Vpcs', [])
                if not vpcs:
                    return {"status": "error", "message": f"VPC '{vpc_id}' not found."}
            vpc_cidrs = self._vpc_ipv4_cidrs(vpcs[0])

            # RDS rejects subnet groups that do not span at least two AZs, so
            # pick one subnet per AZ rather than the first two subnets listed.
            subnets_in_vpc = ec2_client.describe_subnets(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]).get('Subnets', [])
            subnet_ids = self._one_subnet_per_az(subnets_in_vpc)
            if len(subnet_ids) < 2:
                return {"status": "error", "message": f"VPC '{vpc_id}' does not have subnets in at least 2 Availability Zones for a DB Subnet Group."}

            # 2. Create DB Subnet Group
            db_subnet_group_name = f"{db_instance_identifier}-sng"
            try:
                rds_client.create_db_subnet_group(
                    DBSubnetGroupName=db_subnet_group_name,
                    DBSubnetGroupDescription=f"DB Subnet Group for {db_instance_identifier}",
                    SubnetIds=subnet_ids,
                    Tags=[{'Key': 'Name', 'Value': db_subnet_group_name}]
                )
                print(f"Created DB Subnet Group: {db_subnet_group_name}")
            except rds_client.exceptions.DBSubnetGroupAlreadyExistsFault:
                print(f"DB Subnet Group '{db_subnet_group_name}' already exists. Using existing.")

            # 3. Create Security Group for RDS
            sg_name = f"{db_instance_identifier}-sg"
            sg_description = f"Security group for RDS instance {db_instance_identifier}"
            try:
                sg_id = ec2_client.create_security_group(GroupName=sg_name, Description=sg_description, VpcId=vpc_id)['GroupId']
            except ec2_client.exceptions.ClientError as e:
                if e.response.get('Error', {}).get('Code') != 'InvalidGroup.Duplicate':
                    raise
                print(f"Security Group '{sg_name}' already exists. Using existing.")
                existing_groups = ec2_client.describe_security_groups(Filters=[{'Name': 'group-name', 'Values': [sg_name]}, {'Name': 'vpc-id', 'Values': [vpc_id]}]).get('SecurityGroups', [])
                sg_id = existing_groups[0]['GroupId']
            else:
                # Allow the DB port from the VPC's real CIDR blocks (not a
                # hard-coded range that may not match the VPC at all).
                if vpc_cidrs:
                    ec2_client.authorize_security_group_ingress(
                        GroupId=sg_id,
                        IpPermissions=[
                            {'IpProtocol': 'tcp', 'FromPort': db_port, 'ToPort': db_port, 'IpRanges': [{'CidrIp': cidr} for cidr in vpc_cidrs]}
                        ]
                    )
                print(f"Created Security Group: {sg_id}")

            # 4. Create RDS Instance
            params = {
                'DBInstanceIdentifier': db_instance_identifier,
                'Engine': engine,
                'DBInstanceClass': db_instance_class,
                'AllocatedStorage': allocated_storage,
                'MasterUsername': master_username,
                'MasterUserPassword': master_password,
                'PubliclyAccessible': False,
                'MultiAZ': False,
                'BackupRetentionPeriod': 7,
                'Port': db_port,
                'VpcSecurityGroupIds': [sg_id],
                'DBSubnetGroupName': db_subnet_group_name,
            }

            response = rds_client.create_db_instance(**params)
            db_instance_arn = response['DBInstance']['DBInstanceArn']
            return {"status": "success", "message": f"DB instance '{db_instance_identifier}' creation initiated.", "db_instance_arn": db_instance_arn, "vpc_id": vpc_id, "db_subnet_group_name": db_subnet_group_name, "security_group_id": sg_id}
        except Exception as e:
            print(f"Error during smart RDS instance creation: {e}")
            # In a real scenario, add cleanup logic for created resources
            return {"status": "error", "message": str(e)}

    def _troubleshoot_db_instance(self, region: str, db_instance_identifier: str):
        """Runs a series of diagnostic checks on an RDS DB instance.

        Checks, in order: instance status, the newest 5-minute averages of
        CPUUtilization / DatabaseConnections over the last 30 minutes, and
        whether any attached security group admits the DB port from the VPC.

        Returns:
            ``{"status": "success", "report": {...}}`` where the report holds
            ``findings`` and ``recommendations`` lists, or
            ``{"status": "error", "message": ...}`` if any AWS call fails.
        """
        print(f"RDSAgent: Troubleshooting DB instance '{db_instance_identifier}' in region {region}...")
        report = {
            'db_instance_identifier': db_instance_identifier,
            'region': region,
            'findings': [],
            'recommendations': []
        }
        try:
            rds_client = AWSConnector.get_client('rds', region_name=region)
            cloudwatch_client = AWSConnector.get_client('cloudwatch', region_name=region)
            ec2_client = AWSConnector.get_client('ec2', region_name=region)

            # 1. Describe DB Instance Status
            db_instance_desc = rds_client.describe_db_instances(DBInstanceIdentifier=db_instance_identifier)['DBInstances'][0]
            status = db_instance_desc['DBInstanceStatus']
            report['findings'].append(f"DB Instance status: '{status}'.")
            if status != 'available':
                report['recommendations'].append("DB instance is not 'available'. Check RDS console for recent events.")

            # 2. Check CloudWatch Metrics (CPUUtilization, DatabaseConnections)
            self._check_cloudwatch_metrics(cloudwatch_client, db_instance_identifier, report)

            # 3. Check Security Groups
            self._check_security_groups(ec2_client, db_instance_desc, report)

            print("--- Workflow: Troubleshoot DB Instance Finished ---")
            return {"status": "success", "report": report}

        except Exception as e:
            print(f"--- Workflow Failed: {e} ---")
            return {"status": "error", "message": f"Troubleshoot DB Instance workflow failed: {e}"}

    def _check_cloudwatch_metrics(self, cloudwatch_client, db_instance_identifier, report):
        """Append the newest 5-minute CPU / connection averages (last 30 min) and threshold advice to ``report``."""
        end_time = datetime.now(timezone.utc)
        start_time = end_time - timedelta(minutes=30)

        metrics = [
            {'Name': 'CPUUtilization', 'Statistic': 'Average', 'Unit': 'Percent'},
            {'Name': 'DatabaseConnections', 'Statistic': 'Average', 'Unit': 'Count'}
        ]

        for metric in metrics:
            metric_data = cloudwatch_client.get_metric_statistics(
                Namespace='AWS/RDS',
                MetricName=metric['Name'],
                Dimensions=[
                    {'Name': 'DBInstanceIdentifier', 'Value': db_instance_identifier},
                ],
                StartTime=start_time,
                EndTime=end_time,
                Period=300,  # 5 minutes
                Statistics=[metric['Statistic']]
            )
            datapoints = metric_data.get('Datapoints', [])
            if not datapoints:
                report['findings'].append(f"No CloudWatch data available for {metric['Name']}.")
                continue

            # GetMetricStatistics does not return datapoints in time order,
            # so select the newest one explicitly.
            latest = max(datapoints, key=lambda dp: dp['Timestamp'])
            latest_value = latest[metric['Statistic']]
            report['findings'].append(f"{metric['Name']} (last 5 min avg): {latest_value:.2f} {metric['Unit']}.")
            # Simple thresholding for recommendations
            if metric['Name'] == 'CPUUtilization' and latest_value > 80:
                report['recommendations'].append("High CPU utilization detected. Consider scaling up instance type or optimizing queries.")
            if metric['Name'] == 'DatabaseConnections' and latest_value > 100:  # Example threshold
                report['recommendations'].append("High database connections detected. Review application connection pooling or scale up.")

    def _check_security_groups(self, ec2_client, db_instance_desc, report):
        """Report when no attached security group admits the DB port from inside the instance's VPC.

        Security-group rules attached to one instance are combined, so the
        groups are only flagged when none of them allows the port.
        """
        endpoint = db_instance_desc.get('Endpoint') or {}
        if 'Port' not in endpoint:
            # e.g. while the instance is still 'creating' there is no endpoint yet.
            report['findings'].append("DB instance has no endpoint yet; skipped the security group port check.")
            return
        db_port = endpoint['Port']

        vpc_cidrs = []
        vpc_id = (db_instance_desc.get('DBSubnetGroup') or {}).get('VpcId')
        if vpc_id:
            vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get('Vpcs', [])
            if vpcs:
                vpc_cidrs = self._vpc_ipv4_cidrs(vpcs[0])
        if not vpc_cidrs:
            report['findings'].append("Could not determine the VPC CIDR range; the security group check may be incomplete.")

        sg_ids = [sg['VpcSecurityGroupId'] for sg in db_instance_desc.get('VpcSecurityGroups', [])]
        if not sg_ids:
            # An empty GroupIds list would make describe_security_groups return every group.
            return
        sg_by_id = {
            sg['GroupId']: sg
            for sg in ec2_client.describe_security_groups(GroupIds=sg_ids)['SecurityGroups']
        }

        if any(self._sg_allows_vpc_ingress(sg_by_id[sg_id], db_port, vpc_cidrs) for sg_id in sg_ids if sg_id in sg_by_id):
            return

        for sg_id in sg_ids:
            sg_details = sg_by_id.get(sg_id, {})
            report['findings'].append(f"Security Group '{sg_details.get('GroupName', sg_id)}' ({sg_id}) does not appear to allow ingress on DB port {db_port} from within the VPC.")
            report['recommendations'].append(f"Ensure Security Group '{sg_id}' allows ingress on port {db_port} from your application's subnet/security group.")

