from .base_agent import BaseAgent
from ..aws_connector import AWSConnector
import time

class SageMakerAgent(BaseAgent):
    """
    An agent specialized in handling AWS SageMaker tasks.

    Supported commands (dispatched by ``execute``):
        - ``smart_create_notebook_instance``: ask the user for confirmation,
          then create a notebook instance and wait until it is InService.
    """

    def execute(self, command: str, **kwargs):
        """
        Executes a given command related to SageMaker.

        Args:
            command: Name of the command to run. Only
                ``'smart_create_notebook_instance'`` is supported.
            **kwargs: Arguments for the command. For
                ``smart_create_notebook_instance`` they are forwarded to
                ``_smart_create_notebook_instance`` (``region``,
                ``instance_name``, ``role_arn`` and optionally
                ``instance_type`` / ``volume_size_in_gb``).

        Returns:
            A dict with at least ``"status"`` and ``"message"`` keys.
            ``"status"`` is ``"cancelled"`` when the user does not confirm.

        Raises:
            NotImplementedError: If ``command`` is not supported.
            TypeError: If required handler arguments (e.g. ``role_arn``) are
                missing from ``kwargs``. This is raised by the handler call
                after the user confirms.
        """
        if command == 'smart_create_notebook_instance':
            # .get() with defaults is used here only to build the prompt; the
            # handler call below still requires the real arguments.
            instance_name = kwargs.get('instance_name', '')
            region = kwargs.get('region', '')
            print(f"You are about to smart-create a SageMaker notebook instance '{instance_name}' in {region}.")
            confirm = input("This will create a new SageMaker notebook instance. Are you sure you want to proceed? (yes/no): ")
            # Normalize whitespace and case so an answer like "Yes " is not
            # silently treated as a cancellation.
            if confirm.strip().lower() == 'yes':
                return self._smart_create_notebook_instance(**kwargs)
            return {"status": "cancelled", "message": "Smart Create SageMaker notebook instance command cancelled by user."}
        raise NotImplementedError(f"Command '{command}' is not supported by SageMakerAgent.")

    def _smart_create_notebook_instance(self, region: str, instance_name: str, role_arn: str, instance_type: str = 'ml.t3.medium', volume_size_in_gb: int = 20):
        """
        Creates a SageMaker notebook instance and blocks until it is InService.

        Args:
            region: AWS region in which to create the instance.
            instance_name: Notebook instance name. It is also applied as the
                ``Name`` tag.
            role_arn: IAM role ARN passed as ``RoleArn`` to the create call.
            instance_type: Notebook instance type (default ``'ml.t3.medium'``).
            volume_size_in_gb: Size of the attached volume, in GB (default 20).

        Returns:
            On success: ``{"status": "success", "message": ...,
            "notebook_instance_arn": ...}``.
            On any error: ``{"status": "error", "message": ...}``. If the
            instance was already created before the failure (for example, the
            waiter failed or timed out), the dict also carries
            ``"notebook_instance_arn"`` so the caller can inspect or clean up
            the resource.
        """
        print(f"SageMakerAgent: Smart creating notebook instance '{instance_name}' in region {region}...")
        # Tracks whether the create call succeeded, so the ARN is not lost if
        # a later step (waiting) fails.
        notebook_instance_arn = None
        try:
            # NOTE(review): presumably returns a boto3 SageMaker client; confirm
            # against AWSConnector.
            sm_client = AWSConnector.get_client('sagemaker', region_name=region)

            # Create notebook instance
            response = sm_client.create_notebook_instance(
                NotebookInstanceName=instance_name,
                InstanceType=instance_type,
                RoleArn=role_arn,
                VolumeSizeInGB=volume_size_in_gb,
                Tags=[
                    {'Key': 'Name', 'Value': instance_name},
                ]
            )
            notebook_instance_arn = response['NotebookInstanceArn']

            # Wait for notebook instance to be InService
            print("Waiting for notebook instance to be InService... (this may take several minutes)")
            waiter = sm_client.get_waiter('notebook_instance_in_service')
            waiter.wait(NotebookInstanceName=instance_name)
            print(f"Notebook instance '{instance_name}' is now InService.")

            return {"status": "success", "message": f"SageMaker notebook instance '{instance_name}' created successfully.", "notebook_instance_arn": notebook_instance_arn}
        except Exception as e:
            # Agent boundary: report failures as a result dict rather than
            # raising, matching the success/cancelled return shape.
            print(f"Error during smart SageMaker notebook instance creation: {e}")
            result = {"status": "error", "message": str(e)}
            if notebook_instance_arn is not None:
                # The instance exists (and may be billable) even though it was
                # never confirmed InService; surface its ARN to the caller.
                result["notebook_instance_arn"] = notebook_instance_arn
            return result
