from .base_agent import BaseAgent
from ..aws_connector import AWSConnector
import time

class WAFAgent(BaseAgent):
    """
    An agent specialized in handling AWS WAF (Web Application Firewall) tasks.

    Supported commands (dispatched by ``execute``):
        * ``smart_create_web_acl`` -- create a Web ACL that already contains the
          AWS managed common rule set (``AWSManagedRulesCommonRuleSet``).
        * ``troubleshoot_web_acl`` -- look up a Web ACL by name and return a
          summary of its configuration and rules.

    Both commands ask for interactive confirmation on stdin before any AWS
    call is made.
    """

    # Rule name and CloudWatch metric name used for the managed common rule set.
    _COMMON_RULE_NAME = 'AWS-Managed-CommonRuleSet'
    _COMMON_RULE_METRIC = 'AWS-Managed-CommonRuleSet-Metric'
    # Page size passed to ListWebACLs while searching for a Web ACL by name.
    _LIST_PAGE_LIMIT = 100

    def execute(self, command: str, **kwargs):
        """
        Executes a given command related to WAF.

        Args:
            command: ``'smart_create_web_acl'`` or ``'troubleshoot_web_acl'``.
            **kwargs: Forwarded unchanged to the matching private handler.
                ``name`` and ``region`` are required by both handlers; ``scope``
                is optional; ``default_action`` is accepted only by
                ``smart_create_web_acl``.

        Returns:
            The handler's result dict, or ``{"status": "cancelled", ...}`` when
            the user does not answer "yes" to the confirmation prompt.

        Raises:
            NotImplementedError: If ``command`` is not supported.
        """
        name = kwargs.get('name', '')
        region = kwargs.get('region', '')

        if command == 'smart_create_web_acl':
            print(f"You are about to smart-create a WAF Web ACL '{name}' in {region}.")
            if self._confirm("This will create a new WAF Web ACL with a common managed rule set. Are you sure you want to proceed? (yes/no): "):
                return self._smart_create_web_acl(**kwargs)
            return {"status": "cancelled", "message": "Smart Create WAF Web ACL command cancelled by user."}

        if command == 'troubleshoot_web_acl':
            print(f"You are about to troubleshoot WAF Web ACL '{name}' in {region}.")
            if self._confirm("This will retrieve details about the Web ACL. Are you sure you want to proceed? (yes/no): "):
                return self._troubleshoot_web_acl(**kwargs)
            return {"status": "cancelled", "message": "Troubleshoot WAF Web ACL command cancelled by user."}

        raise NotImplementedError(f"Command '{command}' is not supported by WAFAgent.")

    @staticmethod
    def _confirm(prompt: str) -> bool:
        """Prompt on stdin; return True only for 'yes' (case- and whitespace-insensitive)."""
        return input(prompt).strip().lower() == 'yes'

    @staticmethod
    def _default_action_spec(default_action: str) -> dict:
        """
        Build the WAFv2 ``DefaultAction`` structure for 'allow' or 'block'.

        WAFv2 expects exactly one of ``Allow`` / ``Block``. Sending the unused key
        with a ``None`` value is rejected by boto3's parameter validation, so
        only the selected key is emitted.

        Raises:
            ValueError: If ``default_action`` is not 'allow' or 'block'
                (case-insensitive).
        """
        action = (default_action or '').strip().lower()
        if action == 'allow':
            return {'Allow': {}}
        if action == 'block':
            return {'Block': {}}
        raise ValueError(f"default_action must be 'allow' or 'block', got {default_action!r}.")

    @classmethod
    def _common_rule_set_rule(cls) -> dict:
        """Return the rule definition that attaches AWSManagedRulesCommonRuleSet at priority 1."""
        return {
            'Name': cls._COMMON_RULE_NAME,
            'Priority': 1,
            'Statement': {
                'ManagedRuleGroupStatement': {
                    'VendorName': 'AWS',
                    'Name': 'AWSManagedRulesCommonRuleSet'
                }
            },
            # Rule groups take OverrideAction (not Action); None keeps the group's own actions.
            'OverrideAction': {'None': {}},
            'VisibilityConfig': {
                'SampledRequestsEnabled': True,
                'CloudWatchMetricsEnabled': True,
                'MetricName': cls._COMMON_RULE_METRIC
            }
        }

    def _smart_create_web_acl(self, region: str, name: str, scope: str = 'REGIONAL', default_action: str = 'allow'):
        """
        Creates a WAF Web ACL with a common managed rule set.

        The Web ACL and its AWSManagedRulesCommonRuleSet rule are created in a
        single ``create_web_acl`` call. This avoids two problems with a
        separate ``update_web_acl``: that call replaces the whole mutable
        specification (and would drop the Description), and a failed update
        would leave a rule-less Web ACL behind.

        Args:
            region: AWS region used to build the ``wafv2`` client.
            name: Web ACL name; also used to derive its CloudWatch metric name.
            scope: WAFv2 scope passed through to the API (default 'REGIONAL').
            default_action: 'allow' or 'block' (case-insensitive).

        Returns:
            On success, a dict with ``status='success'``, ``message``,
            ``web_acl_arn`` and ``web_acl_id``. On any exception, a dict with
            ``status='error'`` and the exception text as ``message``.
        """
        print(f"WAFAgent: Smart creating Web ACL '{name}' in region {region}...")
        try:
            default_action_spec = self._default_action_spec(default_action)
            waf_client = AWSConnector.get_client('wafv2', region_name=region)

            response = waf_client.create_web_acl(
                Name=name,
                Scope=scope,
                DefaultAction=default_action_spec,
                Rules=[self._common_rule_set_rule()],
                VisibilityConfig={
                    'SampledRequestsEnabled': True,
                    'CloudWatchMetricsEnabled': True,
                    'MetricName': f'{name}-metric'
                },
                Description=f'Web ACL for {name} with common managed rules'
            )
            summary = response['Summary']
            web_acl_arn = summary['ARN']
            web_acl_id = summary['Id']
            print(f"Web ACL '{name}' ({web_acl_id}) created.")
            print("AWSManagedRulesCommonRuleSet added to Web ACL.")

            return {"status": "success", "message": f"WAF Web ACL '{name}' created with common managed rules.", "web_acl_arn": web_acl_arn, "web_acl_id": web_acl_id}
        except Exception as e:
            # Agent boundary: report failures as a result dict instead of raising.
            print(f"Error during smart WAF Web ACL creation: {e}")
            return {"status": "error", "message": str(e)}

    @classmethod
    def _find_web_acl_summary(cls, waf_client, name: str, scope: str):
        """
        Return the ListWebACLs summary entry whose ``Name`` equals *name*, or None.

        Follows ``NextMarker`` across pages so that Web ACLs beyond the first
        page are also found.
        """
        request = {'Scope': scope, 'Limit': cls._LIST_PAGE_LIMIT}
        while True:
            page = waf_client.list_web_acls(**request)
            web_acls = page.get('WebACLs', [])
            match = next((acl for acl in web_acls if acl.get('Name') == name), None)
            if match is not None:
                return match
            next_marker = page.get('NextMarker')
            # Stop on no marker, an empty page, or a repeated marker so the loop cannot spin forever.
            if not next_marker or not web_acls or next_marker == request.get('NextMarker'):
                return None
            request['NextMarker'] = next_marker

    @staticmethod
    def _summarize_rule(rule: dict) -> dict:
        """Condense one Web ACL rule into the fields reported by troubleshooting."""
        statement = rule.get('Statement') or {}
        return {
            "Name": rule['Name'],
            "Priority": rule['Priority'],
            "Action": rule.get('Action', 'N/A'),
            # Rules that reference rule groups carry OverrideAction instead of Action.
            "OverrideAction": rule.get('OverrideAction', 'N/A'),
            "StatementType": next(iter(statement), 'N/A'),
            "VisibilityConfig": rule['VisibilityConfig'],
        }

    def _troubleshoot_web_acl(self, region: str, name: str, scope: str = 'REGIONAL'):
        """
        Retrieves and displays details about a WAF Web ACL, including its rules and default action.

        Args:
            region: AWS region used to build the ``wafv2`` client.
            name: Name of the Web ACL to look up.
            scope: WAFv2 scope passed through to the API (default 'REGIONAL').

        Returns:
            On success, a dict with ``status='success'``, ``message`` and a
            ``details`` dict (Name, ARN, Id, Description, DefaultAction, Rules,
            Capacity, VisibilityConfig). Returns ``status='error'`` when no Web
            ACL with that name exists, or when any exception is raised.
        """
        print(f"WAFAgent: Troubleshooting Web ACL '{name}' in region {region}...")
        try:
            waf_client = AWSConnector.get_client('wafv2', region_name=region)

            web_acl_summary = self._find_web_acl_summary(waf_client, name, scope)
            if web_acl_summary is None:
                return {"status": "error", "message": f"Web ACL '{name}' not found in region {region} with scope {scope}."}

            get_response = waf_client.get_web_acl(Name=name, Scope=scope, Id=web_acl_summary['Id'])
            web_acl = get_response['WebACL']

            details = {
                "Name": web_acl['Name'],
                "ARN": web_acl['ARN'],
                "Id": web_acl['Id'],
                "Description": web_acl.get('Description', 'N/A'),
                "DefaultAction": web_acl['DefaultAction'],
                "Rules": [self._summarize_rule(rule) for rule in web_acl.get('Rules', [])],
                "Capacity": web_acl['Capacity'],
                "VisibilityConfig": web_acl['VisibilityConfig'],
            }

            return {"status": "success", "message": f"Details for Web ACL '{name}'.", "details": details}
        except Exception as e:
            # Agent boundary: report failures as a result dict instead of raising.
            print(f"Error troubleshooting WAF Web ACL: {e}")
            return {"status": "error", "message": str(e)}
