import os
import tempfile
import time
from typing import Optional

import pandas as pd

from .base_agent import BaseAgent
from ..aws_connector import AWSConnector

class AthenaAgent(BaseAgent):
    """
    An agent specialized in handling AWS Athena tasks.

    Supported commands (dispatched by ``execute``):
        ``run_query`` -- run a SQL query and return its result rows.
    """

    # Athena query states meaning "not finished yet"; any other state is terminal.
    _PENDING_STATES = ('QUEUED', 'RUNNING')

    def execute(self, command: str, **kwargs):
        """
        Executes a given command related to Athena.

        Args:
            command: Name of the command. Only ``'run_query'`` is supported.
            **kwargs: Forwarded unchanged to the command handler.

        Returns:
            The handler's result (a status dict for ``run_query``).

        Raises:
            NotImplementedError: If ``command`` is not supported.
        """
        if command == 'run_query':
            return self._run_query(**kwargs)
        raise NotImplementedError(f"Command '{command}' is not supported by AthenaAgent.")

    def _run_query(self, region: str, database: str, query_string: str, output_location: str,
                   poll_interval: float = 5, timeout: Optional[float] = None):
        """
        Runs a SQL query on Athena and retrieves the results.

        Args:
            region: AWS region used for both the Athena and S3 clients.
            database: Athena database used as the query's execution context.
            query_string: SQL to execute.
            output_location: ``s3://bucket/prefix`` URI Athena writes results to.
            poll_interval: Seconds to wait between status checks (default 5).
            timeout: Maximum seconds to wait for the query to finish; ``None``
                (default) waits indefinitely. On timeout a stop is requested
                and an error dict is returned.

        Returns:
            ``{"status": "success", "message": ..., "results": [row dicts]}`` on
            success, otherwise ``{"status": "error", "message": ...}``. Failures
            are reported through the returned dict, not raised.
        """
        print(f"AthenaAgent: Running query on database '{database}' in region {region}...")
        try:
            athena_client = AWSConnector.get_client('athena', region_name=region)

            # 1. Start query execution.
            response = athena_client.start_query_execution(
                QueryString=query_string,
                QueryExecutionContext={'Database': database},
                ResultConfiguration={'OutputLocation': output_location},
            )
            query_execution_id = response['QueryExecutionId']
            print(f"Query execution started with ID: {query_execution_id}")

            # 2. Wait for the query to reach a terminal state.
            execution = self._wait_for_query(
                athena_client, query_execution_id, poll_interval, timeout)
            status = execution['Status']
            state = status['State']
            if state == 'FAILED':
                # StateChangeReason is not guaranteed to be present.
                failure_reason = status.get('StateChangeReason', 'no reason given')
                return {"status": "error", "message": f"Query failed: {failure_reason}"}
            if state == 'CANCELLED':
                return {"status": "error", "message": "Query cancelled."}
            if state != 'SUCCEEDED':
                return {"status": "error",
                        "message": f"Query finished in unexpected state: {state}"}

            # 3. Download and parse the result file from S3.
            s3_client = AWSConnector.get_client('s3', region_name=region)
            records = self._fetch_results(
                s3_client, execution, output_location, query_execution_id)

            return {
                "status": "success",
                "message": f"Query '{query_execution_id}' completed successfully.",
                "results": records,
            }
        except Exception as e:
            # Agent boundary: report any failure to the caller as an error dict.
            print(f"Error running Athena query: {e}")
            return {"status": "error", "message": str(e)}

    def _wait_for_query(self, athena_client, query_execution_id: str,
                        poll_interval: float, timeout: Optional[float]) -> dict:
        """
        Poll Athena until the query leaves QUEUED/RUNNING.

        Returns:
            The final ``QueryExecution`` dict from ``get_query_execution``.

        Raises:
            TimeoutError: If ``timeout`` seconds elapse before the query
                finishes; a stop is requested for the query first.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            execution = athena_client.get_query_execution(
                QueryExecutionId=query_execution_id)['QueryExecution']
            state = execution['Status']['State']
            print(f"Query state: {state}")
            if state not in self._PENDING_STATES:
                return execution
            if deadline is not None and time.monotonic() >= deadline:
                # Don't leave an abandoned query running in Athena.
                athena_client.stop_query_execution(QueryExecutionId=query_execution_id)
                raise TimeoutError(
                    f"Query '{query_execution_id}' did not finish within {timeout} seconds; "
                    f"stop requested.")
            time.sleep(poll_interval)

    def _fetch_results(self, s3_client, execution: dict, output_location: str,
                       query_execution_id: str) -> list:
        """
        Download a finished query's result file from S3 and return its rows as dicts.

        Uses the result path reported in the execution's ``ResultConfiguration``
        when present; otherwise falls back to
        ``<output_location>/<query_execution_id>.csv``.
        """
        result_uri = execution.get('ResultConfiguration', {}).get('OutputLocation')
        if result_uri:
            bucket, key = self._split_s3_uri(result_uri)
        else:
            bucket, prefix = self._split_s3_uri(output_location)
            # Strip trailing '/' so "s3://b/p/" does not yield the key "p//<id>.csv".
            prefix = prefix.rstrip('/')
            key = f"{prefix}/{query_execution_id}.csv" if prefix else f"{query_execution_id}.csv"

        # A private temp directory avoids clobbering files in the CWD and is
        # removed even when the download or parsing fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_file_name = os.path.join(tmp_dir, f'athena_results_{query_execution_id}.csv')
            s3_client.download_file(bucket, key, local_file_name)
            try:
                df = pd.read_csv(local_file_name)
            except pd.errors.EmptyDataError:
                # An empty result file has no header row to parse.
                return []
        return df.to_dict(orient='records')

    @staticmethod
    def _split_s3_uri(uri: str):
        """Split ``s3://bucket/key`` into ``(bucket, key)``; ``key`` may be empty."""
        path = uri[len('s3://'):] if uri.startswith('s3://') else uri
        bucket, _, key = path.partition('/')
        return bucket, key
