
import boto3
import argparse

# Well-known TCP ports whose services should generally not be reachable from
# the whole internet, mapped to a human-readable service name.
_SENSITIVE_PORTS = {
    20: 'FTP Data',
    21: 'FTP Control',
    22: 'SSH',
    23: 'Telnet',
    25: 'SMTP',
    3306: 'MySQL',
    3389: 'RDP',
    5432: 'PostgreSQL',
    1433: 'SQL Server',
    1521: 'Oracle',
    27017: 'MongoDB',
    27018: 'MongoDB',
    27019: 'MongoDB',
    6379: 'Redis',
    9200: 'Elasticsearch/OpenSearch',
    9300: 'Elasticsearch/OpenSearch',
}

# IpProtocol values meaning "all protocols, all ports". The EC2 API reports
# all-traffic rules as '-1'; 'all' is the fallback used below when the
# IpProtocol key is missing entirely.
_ALL_TRAFFIC_PROTOCOLS = ('-1', 'all')

# IpProtocol values for TCP, by name or by IANA protocol number.
_TCP_PROTOCOLS = ('tcp', '6')

# (ranges key, CIDR key, world-open CIDR) for IPv4 then IPv6, in report order.
_PUBLIC_SOURCES = (
    ('IpRanges', 'CidrIp', '0.0.0.0/0'),
    ('Ipv6Ranges', 'CidrIpv6', '::/0'),
)


def _exposes_sensitive_port(ip_permission):
    """Return True if an inbound permission opens any port in _SENSITIVE_PORTS.

    All-traffic rules count as sensitive because they open every port.
    Otherwise only TCP rules are considered.
    """
    protocol = str(ip_permission.get('IpProtocol', 'All')).lower()
    if protocol in _ALL_TRAFFIC_PROTOCOLS:
        return True
    if protocol not in _TCP_PROTOCOLS:
        return False
    # Explicit None checks: FromPort 0 (e.g. a 0-65535 rule) is valid but falsy.
    from_port = ip_permission.get('FromPort')
    to_port = ip_permission.get('ToPort')
    low = 0 if from_port is None else from_port
    high = 65535 if to_port is None else to_port
    # Interval test rather than walking every port in a potentially huge range.
    return any(low <= port <= high for port in _SENSITIVE_PORTS)


def _public_rule_findings(sg):
    """Yield one finding dict per inbound rule source of *sg* open to the internet.

    For each permission, IPv4 (0.0.0.0/0) sources are yielded before IPv6 (::/0).
    """
    for ip_permission in sg.get('IpPermissions', []):
        protocol = ip_permission.get('IpProtocol', 'All')
        port_range = f"{ip_permission.get('FromPort', 'All')}-{ip_permission.get('ToPort', 'All')}"
        is_sensitive = _exposes_sensitive_port(ip_permission)
        for ranges_key, cidr_key, public_cidr in _PUBLIC_SOURCES:
            for ip_range in ip_permission.get(ranges_key, []):
                if ip_range.get(cidr_key) != public_cidr:
                    continue
                yield {
                    'GroupId': sg['GroupId'],
                    'GroupName': sg['GroupName'],
                    'VpcId': sg.get('VpcId', 'N/A'),
                    'Protocol': protocol,
                    'PortRange': port_range,
                    'Source': public_cidr,
                    # Rule descriptions are stored per range entry, not on the permission.
                    'Description': ip_range.get('Description', ''),
                    'IsSensitivePort': is_sensitive,
                }


def _format_report(findings):
    """Return the report as a list of text lines, one section per finding."""
    if not findings:
        return ["No security groups found with public access."]
    lines = ["--- Security Groups with Public Access ---"]
    for finding in findings:
        sensitive_warning = " (SENSITIVE PORT!)" if finding['IsSensitivePort'] else ""
        lines.append(f"  SG ID: {finding['GroupId']}")
        lines.append(f"  SG Name: {finding['GroupName']}")
        lines.append(f"  VPC ID: {finding['VpcId']}")
        lines.append(f"  Rule: Protocol={finding['Protocol']}, Port={finding['PortRange']}, Source={finding['Source']}{sensitive_warning}")
        if finding['Description']:
            lines.append(f"  Description: {finding['Description']}")
        lines.append("----------------------------------------")
    return lines


def scan_security_groups_for_public_access(
    region_name='us-east-1',
    report_file_path=None
):
    """
    Scan AWS Security Groups in a region for inbound rules open to the internet.

    Every inbound rule whose source is 0.0.0.0/0 or ::/0 is reported. Rules
    that expose a port from _SENSITIVE_PORTS, or that allow all traffic, are
    marked as sensitive.

    Args:
        region_name (str): The AWS region to scan.
        report_file_path (str, optional): Path to a file to save the report
            (UTF-8). If None, the report is printed to the console.

    Errors raised while listing security groups are printed and the scan
    stops without producing a report.
    """
    ec2_client = boto3.client('ec2', region_name=region_name)

    print(f"Starting Security Group scan for public access in region {region_name}...")

    # 1. Describe all security groups in the region. The API is paginated, so
    #    walk every page instead of trusting a single response.
    try:
        print("\n>>> Step 1: Describing all security groups...")
        paginator = ec2_client.get_paginator('describe_security_groups')
        security_groups = [
            sg for page in paginator.paginate() for sg in page['SecurityGroups']
        ]
        print(f"   Found {len(security_groups)} security groups.")
    except Exception as e:
        print(f"Error describing security groups: {e}")
        return

    # 2. Collect every world-open inbound rule (IPv4 and IPv6).
    print("\n>>> Step 2: Analyzing inbound rules for public access...")
    public_access_sgs = [
        finding for sg in security_groups for finding in _public_rule_findings(sg)
    ]

    # 3. Generate Report
    print("\n>>> Step 3: Generating report...")
    report_output = _format_report(public_access_sgs)

    if report_file_path:
        with open(report_file_path, 'w', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in report_output)
        print(f"Report saved to '{report_file_path}'.")
    else:
        for line in report_output:
            print(line)

    print("\nSecurity Group scan completed.")

if __name__ == "__main__":
    # Command-line entry point: parse the region / output options and run one scan.
    cli = argparse.ArgumentParser(
        description="Scan AWS Security Groups for public access to sensitive ports."
    )
    cli.add_argument(
        "--region",
        default="us-east-1",
        help="AWS region to scan (default: us-east-1).",
    )
    cli.add_argument(
        "--report-file-path",
        help="Optional. Path to a file to save the report. If not provided, prints to console.",
    )
    options = cli.parse_args()

    scan_security_groups_for_public_access(
        region_name=options.region,
        report_file_path=options.report_file_path,
    )
