#!/usr/bin/env python3
"""
Smart GPU selector for cost-effective RunPod batch processing
Finds the first available GPU, in a curated cost/availability preference order, that meets a minimum-memory requirement (e.g. for Stable Diffusion XL)
"""

import runpod
import config
import logging
from typing import List, Dict, Optional

# One logger per module, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)

class SmartGPUSelector:
    """Selects a cost-effective available GPU for batch processing.

    Candidates are tried in the fixed order of ``self.gpu_preferences``, which
    balances price against observed availability. The list is not sorted
    purely by price.
    """
    
    def __init__(self):
        """Initialize the GPU selector and authenticate the RunPod SDK.

        Reads the API key from ``config.IMAGE_GENERATION['runpod']['api_key']``
        and installs it as ``runpod.api_key`` for all subsequent SDK calls.

        Raises:
            ValueError: If no RunPod API key is configured.
        """
        # ``or {}`` also covers an explicit ``runpod: None`` entry in the config.
        runpod_settings = config.IMAGE_GENERATION.get('runpod') or {}
        self.api_key = runpod_settings.get('api_key', '')
        if not self.api_key:
            raise ValueError("RUNPOD_API_KEY not configured")
        
        # Authenticate the SDK with the key read above, never a hard-coded literal.
        runpod.api_key = self.api_key
        
        # GPU preferences optimized for EU-RO-1 availability and current pricing
        # Format: (gpu_id, vram_gb, price_tier), where vram_gb is the GPU's own memory
        # Based on current EU-RO-1 availability data as of 2025-01-21
        self.gpu_preferences = [
            # PRIMARY CHOICE - Best value with high availability
            ('NVIDIA GeForce RTX 4090', 24, 'primary'),     # 24GB - $0.69/hr, HIGH availability (10 max)
            
            # SECONDARY CHOICES - Good alternatives with availability
            ('NVIDIA GeForce RTX 5090', 32, 'secondary'),   # 32GB - $0.94/hr, MEDIUM availability (2 max)
            ('NVIDIA RTX 6000 Ada', 48, 'secondary'),       # 48GB - $0.77/hr, currently unavailable but good value
            ('NVIDIA L40S', 48, 'secondary'),               # 48GB - $0.86/hr, currently unavailable but excellent
            
            # BUDGET ALTERNATIVES - Lower cost options
            ('NVIDIA RTX 4000 Ada', 20, 'budget'),          # 20GB - $0.26/hr, LOW availability (3 max)
            ('NVIDIA RTX 2000 Ada', 16, 'budget'),          # 16GB - $0.23/hr, LOW availability (4 max)
            
            # WORKSTATION ALTERNATIVES - Higher cost but more VRAM
            ('NVIDIA RTX PRO 6000', 96, 'workstation'),     # 96GB - $1.79/hr, LOW availability (1 max)
            ('NVIDIA L40', 48, 'workstation'),              # 48GB - $0.99/hr, currently unavailable
            
            # HIGH-END OPTIONS - Premium pricing but maximum performance
            ('NVIDIA H100 PCIe', 80, 'premium'),           # 80GB - $2.39/hr, currently unavailable
            ('NVIDIA H100 SXM', 80, 'premium'),            # 80GB - $2.69/hr, currently unavailable
            ('NVIDIA H100 NVL', 94, 'premium'),            # 94GB - $2.79/hr, currently unavailable
            ('NVIDIA H200 SXM', 141, 'premium'),           # 141GB - $3.99/hr, currently unavailable
            ('NVIDIA B200', 180, 'premium'),               # 180GB - $5.99/hr, LOW availability (7 max)
        ]
    
    def find_best_available_gpu(self, min_memory_gb: int = 8, preferred_region: str = "EU-RO-1") -> Optional[str]:
        """
        Find the first available GPU, in preference order, that meets requirements.

        Each candidate in ``self.gpu_preferences`` with at least
        ``min_memory_gb`` of memory that also appears in the RunPod GPU catalog
        is probed with ``_test_gpu_availability``. That probe creates and then
        terminates a real pod.

        Args:
            min_memory_gb: Minimum GPU memory in GB (default: 8GB for SD 1.5)
            preferred_region: RunPod data center ID to probe (default: EU-RO-1)

        Returns:
            GPU ID of the first available option, or None if none is available
            or the lookup failed (the failure is logged with its traceback).
        """
        logger.info(
            "Searching for available GPUs with %sGB+ memory, prioritizing region %s...",
            min_memory_gb, preferred_region,
        )
        
        try:
            # Only membership is needed, so keep a set of catalog IDs.
            # ``or []`` guards against the SDK returning None instead of a list.
            catalog_ids = {gpu.get('id') for gpu in (runpod.get_gpus() or [])}
            
            for gpu_id, memory_gb, price_tier in self.gpu_preferences:
                if memory_gb < min_memory_gb:
                    continue
                
                if gpu_id not in catalog_ids:
                    logger.debug("GPU %s not found in RunPod catalog", gpu_id)
                    continue
                
                # Probe availability by actually creating (and terminating) a pod.
                if self._test_gpu_availability(gpu_id, preferred_region):
                    logger.info(
                        "Selected GPU: %s (%sGB, %s tier) in region %s",
                        gpu_id, memory_gb, price_tier, preferred_region,
                    )
                    return gpu_id
                logger.debug("GPU %s is not available in %s", gpu_id, preferred_region)
            
            logger.error(
                "No suitable GPUs are currently available in the preferred region %s",
                preferred_region,
            )
            return None
            
        except Exception as e:
            # Boundary: any SDK/network failure means "no GPU selected", but keep
            # the traceback so the cause is diagnosable.
            logger.exception("Error finding available GPU: %s", e)
            return None
    
    def _test_gpu_availability(self, gpu_id: str, preferred_region: str = "EU-RO-1") -> bool:
        """
        Probe whether a GPU type can be provisioned in the specified region.

        Creates a real pod with the given GPU type. If creation succeeds, the
        pod is terminated immediately to limit charges. If termination fails,
        the pod ID is logged at ERROR level so it can be cleaned up manually.

        Args:
            gpu_id: GPU ID to test
            preferred_region: RunPod data center ID for pod creation (default: EU-RO-1)

        Returns:
            True if a pod was created (the GPU is available), False otherwise.
        """
        test_config = {
            "name": f"test-{gpu_id.replace(' ', '-').lower()}",
            "image_name": "runpod/stable-diffusion:web-ui-10.2.1",
            "gpu_type_id": gpu_id,
            "cloud_type": "SECURE",
            "volume_mount_path": "/workspace",
            "data_center_id": preferred_region
        }
        
        try:
            # Creation raises when the GPU is unavailable or the ID is invalid.
            pod = runpod.create_pod(**test_config)
        except Exception as e:
            error_msg = str(e).lower()
            if "no longer any instances available" in error_msg:
                logger.debug("GPU %s is out of stock in %s", gpu_id, preferred_region)
            elif "no gpu found" in error_msg:
                logger.debug("GPU %s has invalid ID", gpu_id)
            else:
                logger.debug("GPU %s test failed in %s: %s", gpu_id, preferred_region, e)
            return False
        
        if not pod or 'id' not in pod:
            return False
        
        pod_id = pod['id']
        # Terminate in its own try. A termination failure must neither hide that
        # the GPU was available nor silently leave a running pod behind.
        try:
            runpod.terminate_pod(pod_id)
        except Exception:
            logger.exception(
                "Failed to terminate test pod %s for %s in %s; terminate it manually to avoid charges",
                pod_id, gpu_id, preferred_region,
            )
        else:
            logger.debug(
                "Test pod for %s created and terminated successfully in %s",
                gpu_id, preferred_region,
            )
        return True
    
    def get_gpu_recommendations(self) -> Dict[str, List[str]]:
        """
        Group the preferred GPUs by price tier.

        Returns:
            Dictionary mapping each price tier that appears in
            ``self.gpu_preferences`` (in order of first appearance) to a list of
            entries formatted as "<gpu_id> (<memory>GB)". The dictionary is empty
            when there are no preferences.
        """
        recommendations: Dict[str, List[str]] = {}
        for gpu_id, memory_gb, price_tier in self.gpu_preferences:
            # Derive tiers from the data so any tier label is accepted; a fixed
            # key set that disagrees with the data would raise KeyError.
            recommendations.setdefault(price_tier, []).append(f"{gpu_id} ({memory_gb}GB)")
        
        return recommendations

def test_smart_gpu_selector():
    """Manually smoke-test the smart GPU selector against the live RunPod API.

    Runs ``find_best_available_gpu`` for several SDXL memory thresholds and
    prints each result, then prints the preferred GPUs grouped by price tier.
    Availability probes create (and terminate) real pods, so this is meant to
    be run by hand as a script, not collected by a test runner. Errors are
    printed rather than raised.
    """
    print("Testing Smart GPU Selector...")
    print()
    
    try:
        selector = SmartGPUSelector()
        
        # (minimum memory in GB, human-readable label)
        test_cases = [
            (10, "Minimum SDXL (10GB)"),
            (12, "Basic SDXL (12GB)"),
            (16, "Recommended SDXL (16GB)"),
            (20, "Optimal SDXL (20GB+)")
        ]
        
        for min_memory, description in test_cases:
            print(f"Testing {description}...")
            best_gpu = selector.find_best_available_gpu(min_memory)
            
            if best_gpu:
                print(f"Best available: {best_gpu}")
            else:
                print("No suitable GPUs available")
            print()
        
        # Show recommendations
        print("GPU Recommendations by Price Tier:")
        recommendations = selector.get_gpu_recommendations()
        
        for tier, gpus in recommendations.items():
            print(f"\n{tier.upper()} TIER:")
            for gpu in gpus:
                print(f"  - {gpu}")
        
    except Exception as e:
        # Top-level boundary of a manual script: report and carry on.
        print(f"Error testing GPU selector: {str(e)}")


# pytest collects functions named ``test_*``. Opt out so a test run never
# provisions real pods through this manual smoke test.
test_smart_gpu_selector.__test__ = False

if __name__ == '__main__':
    # Script entry point: run the manual smoke test of the selector.
    test_smart_gpu_selector()
