from .base_agent import BaseAgent
from ..aws_connector import AWSConnector

class CloudWatchAgent(BaseAgent):
    """
    An agent specialized in handling AWS CloudWatch tasks.

    Supported commands (dispatched by ``execute``):
      * ``list_alarms``            -- names of the metric alarms in a region.
      * ``list_dashboards``        -- names of the dashboards in a region.
      * ``smart_create_cpu_alarm`` -- create a CPUUtilization alarm for an EC2
                                      instance after interactive confirmation.

    Handlers return a dict with a ``"status"`` key (``"success"``, ``"error"``
    or ``"cancelled"``). Exceptions raised while talking to CloudWatch are
    caught inside each handler and reported via ``"message"`` instead of
    being propagated.
    """

    def execute(self, command: str, **kwargs):
        """
        Executes a given command related to CloudWatch.

        Args:
            command: One of ``'list_alarms'``, ``'list_dashboards'`` or
                ``'smart_create_cpu_alarm'``.
            **kwargs: Forwarded unchanged to the matching private handler
                (for example ``region`` and ``instance_id``).

        Returns:
            The result dict produced by the handler, or a ``"cancelled"``
            dict when the user declines the alarm-creation prompt.

        Raises:
            NotImplementedError: If ``command`` is not supported.
        """
        if command == 'list_alarms':
            return self._list_alarms(**kwargs)
        elif command == 'list_dashboards':
            return self._list_dashboards(**kwargs)
        elif command == 'smart_create_cpu_alarm':
            # Only used for the confirmation prompt; the handler receives the
            # original kwargs (and will raise TypeError if required ones are missing).
            instance_id = kwargs.get('instance_id', '')
            region = kwargs.get('region', '')
            print(f"You are about to smart-create a CPU utilization alarm for instance '{instance_id}' in {region}.")
            confirm = input("This will create a new CloudWatch alarm. Are you sure you want to proceed? (yes/no): ")
            # Tolerate surrounding whitespace and letter case, but still require
            # the full word "yes" before creating an AWS resource.
            if confirm.strip().lower() == 'yes':
                return self._smart_create_cpu_alarm(**kwargs)
            else:
                return {"status": "cancelled", "message": "Smart Create CPU alarm command cancelled by user."}
        else:
            raise NotImplementedError(f"Command '{command}' is not supported by CloudWatchAgent.")

    def _smart_create_cpu_alarm(self, region: str, instance_id: str, threshold: float = 80.0, period: int = 300, evaluation_periods: int = 2, sns_topic_arn: str = None):
        """
        Creates a CloudWatch alarm for high CPU utilization on a specified EC2 instance.

        The alarm fires when the ``Average`` of ``AWS/EC2`` ``CPUUtilization``
        for the instance is greater than ``threshold`` for
        ``evaluation_periods`` consecutive periods of ``period`` seconds.

        Args:
            region: AWS region in which to create the alarm.
            instance_id: EC2 instance id used as the ``InstanceId`` dimension.
            threshold: CPU percentage the average must exceed.
            period: Length of each evaluation period, in seconds.
            evaluation_periods: Number of consecutive breaching periods required.
            sns_topic_arn: Optional SNS topic ARN added as the alarm action;
                when omitted, no ``AlarmActions`` are configured.

        Returns:
            ``{"status": "success", "message": ...}`` on success, or
            ``{"status": "error", "message": str(exc)}`` if anything raises.

        NOTE: the alarm name depends only on ``instance_id``
        (``High-CPU-<instance_id>``), so repeated calls target the same alarm.
        Per the PutMetricAlarm API, that should update the existing alarm
        rather than create a duplicate.
        """
        print(f"CloudWatchAgent: Smart creating CPU alarm for instance '{instance_id}' in region {region}...")
        try:
            cloudwatch_client = AWSConnector.get_client('cloudwatch', region_name=region)
            alarm_name = f"High-CPU-{instance_id}"

            alarm_params = {
                'AlarmName': alarm_name,
                'ComparisonOperator': 'GreaterThanThreshold',
                'EvaluationPeriods': evaluation_periods,
                'MetricName': 'CPUUtilization',
                'Namespace': 'AWS/EC2',
                'Period': period,
                'Statistic': 'Average',
                'Threshold': threshold,
                'ActionsEnabled': True,
                'AlarmDescription': f"Alarm when CPU utilization exceeds {threshold}% for {evaluation_periods} consecutive periods of {period/60} minutes on instance {instance_id}.",
                'Dimensions': [
                    {
                        'Name': 'InstanceId',
                        'Value': instance_id
                    },
                ]
            }

            # Notifications are optional; without a topic the alarm only changes state.
            if sns_topic_arn:
                alarm_params['AlarmActions'] = [sns_topic_arn]

            cloudwatch_client.put_metric_alarm(**alarm_params)

            return {"status": "success", "message": f"CloudWatch CPU alarm '{alarm_name}' created successfully for instance '{instance_id}'."}
        except Exception as e:
            print(f"Error during smart CloudWatch CPU alarm creation: {e}")
            return {"status": "error", "message": str(e)}

    def _list_alarms(self, region: str):
        """
        Lists all CloudWatch alarms in a specified region.

        Only metric alarms (the ``MetricAlarms`` entries of DescribeAlarms)
        are collected; composite alarms are not included.

        Returns:
            ``{"status": "success", "alarms": [name, ...]}`` or
            ``{"status": "error", "message": str(exc)}``.
        """
        print(f"CloudWatchAgent: Listing alarms in region {region}...")
        try:
            cloudwatch_client = AWSConnector.get_client('cloudwatch', region_name=region)
            # A single describe_alarms call returns one page and a NextToken;
            # walk every page so large accounts are not silently truncated.
            paginator = cloudwatch_client.get_paginator('describe_alarms')
            alarm_names = [
                alarm['AlarmName']
                for page in paginator.paginate()
                for alarm in page.get('MetricAlarms', [])
            ]
            return {"status": "success", "alarms": alarm_names}
        except Exception as e:
            print(f"Error listing CloudWatch alarms: {e}")
            return {"status": "error", "message": str(e)}

    def _list_dashboards(self, region: str):
        """
        Lists all CloudWatch dashboards in a specified region.

        Returns:
            ``{"status": "success", "dashboards": [name, ...]}`` or
            ``{"status": "error", "message": str(exc)}``.
        """
        print(f"CloudWatchAgent: Listing dashboards in region {region}...")
        try:
            cloudwatch_client = AWSConnector.get_client('cloudwatch', region_name=region)
            # list_dashboards is paginated via NextToken; follow all pages.
            paginator = cloudwatch_client.get_paginator('list_dashboards')
            dashboard_names = [
                dashboard['DashboardName']
                for page in paginator.paginate()
                for dashboard in page.get('DashboardEntries', [])
            ]
            return {"status": "success", "dashboards": dashboard_names}
        except Exception as e:
            print(f"Error listing CloudWatch dashboards: {e}")
            return {"status": "error", "message": str(e)}
