"""
RunPod Batch Processor

This module handles batch processing of queued images using RunPod
for cost-efficient image generation.
"""

import logging
import asyncio
import aiohttp
import os
import json
import base64
import time
import requests
from datetime import datetime, date
from typing import Dict, Any, List, Optional
from io import BytesIO
from PIL import Image
import runpod
from smart_gpu_selector import SmartGPUSelector
from services.article_pipeline.image_queue_manager import ImageQueueManager
from services.article_pipeline.batch_post_processor import BatchPostProcessor
from services.database import get_db_connection
import config
from services.system_settings import SystemSettings

logger = logging.getLogger(__name__)


class RunPodBatchProcessor:
    """Handles batch processing of images using RunPod."""
    
    def __init__(self):
        """
        Initialize the batch processor from ``config.IMAGE_GENERATION['runpod']``.

        Raises:
            ValueError: If no RunPod API key is configured.
        """
        self.queue_manager = ImageQueueManager()

        # Read the RunPod config section once; every setting below comes from it.
        self.runpod_config = config.IMAGE_GENERATION.get('runpod', {})
        self.api_key = self.runpod_config.get('api_key', '')

        if not self.api_key:
            raise ValueError("RUNPOD_API_KEY not configured")

        # Configure the runpod SDK (module-level setting).
        runpod.api_key = self.api_key

        self.model_id = self.runpod_config.get('model_id', 'stabilityai/stable-diffusion-xl-base-1.0')
        self.volume_size = self.runpod_config.get('volume_size', 50)

        # Smart GPU selector for cost-effective GPU selection.
        self.gpu_selector = SmartGPUSelector()

        # Preferred GPU from config; the concrete GPU is chosen dynamically later.
        self.preferred_gpu = self.runpod_config.get('gpu_type', 'NVIDIA RTX A4000')
        self.gpu_type = None  # Will be set dynamically

        # Template-based pod creation configuration.
        self.template_id = self.runpod_config.get('template_id', 'iyz74sp6xl')  # Default template ID
        self.use_template = self.runpod_config.get('use_template', True)  # Enable template-based creation by default

        # GraphQL requests must authenticate with the configured key. A fixed
        # literal here would either embed a secret in source or fail auth.
        self.graphql_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
    
    async def process_daily_batch(self, scheduled_date: Optional[date] = None, *, max_images: Optional[int] = 5):
        """
        Process a daily batch of queued images.

        The run is skipped when either the ``runpod_batch_enabled`` system
        setting or ``config.RUNPOD_BATCH_ENABLED`` is falsy. Otherwise:
        1. The pending images for ``scheduled_date`` are fetched.
        2. They are generated on a single pod.
        3. The batch is handed to ``BatchPostProcessor`` for post-processing
           and failure notification.

        Args:
            scheduled_date: Date to process (defaults to today).
            max_images: Upper bound on how many pending images this run
                processes. Expected to be non-negative. Defaults to 5, the
                testing cap this method has always applied; pass ``None`` to
                process every pending image. Images beyond the cap are not
                touched by this run.

        Raises:
            Re-raises any error from fetching, generation or post-processing
            after logging it with its traceback.
        """
        # Both the database setting and the config/env flag must allow the run.
        settings = SystemSettings()
        batch_enabled = await settings.get_setting('runpod_batch_enabled', True)
        if not batch_enabled:
            logger.info("RunPod batch processing is disabled via system settings")
            return

        if not config.RUNPOD_BATCH_ENABLED:
            logger.info("RunPod batch processing is disabled via RUNPOD_BATCH_ENABLED environment variable")
            return

        if scheduled_date is None:
            scheduled_date = date.today()

        logger.info(f"Starting daily batch processing for {scheduled_date}")

        try:
            pending_images = await self.queue_manager.get_pending_images(scheduled_date)
            if not pending_images:
                logger.info("No pending images to process")
                return

            logger.info(f"Found {len(pending_images)} pending images")

            # Cap the run size. Callers lift the cap with max_images=None instead
            # of editing a hard-coded debug constant.
            if max_images is not None and len(pending_images) > max_images:
                logger.info(f"Limiting to {max_images} images for this run")
                pending_images = pending_images[:max_images]

            batch_id = await self.queue_manager.create_batch_id()

            # One pod for the whole batch avoids repeated pod start-up delays.
            logger.info(f"Processing all {len(pending_images)} images with a single pod")
            await self._process_all_images_single_pod(pending_images, batch_id)

            logger.info(f"Running post-processing for batch {batch_id}")
            post_processor = BatchPostProcessor()
            await post_processor.process_completed_images(batch_id)
            await post_processor.check_and_notify_failures(batch_id)

            logger.info(f"Batch {batch_id} processing complete")
        except Exception as e:
            # logger.exception keeps the traceback for diagnosing failed runs.
            logger.exception(f"Batch processing failed: {str(e)}")
            raise
    
    async def _process_all_images_single_pod(self, images: List[Dict[str, Any]], batch_id: str):
        """
        Process ALL images using a single pod to avoid restart delays.
        
        Args:
            images: List of image requests to process
            batch_id: Batch identifier
        """
        pod = None
        try:
            logger.info(f"Creating single pod for {len(images)} images (batch {batch_id})")
            
            # Create one pod for all images
            pod = await self._create_pod()
            
            # Check if pod creation failed
            if not pod or not isinstance(pod, dict) or 'id' not in pod:
                logger.error(f"Pod creation failed - pod is None or invalid: {pod}")
                # Mark all images as failed
                for img in images:
                    await self.queue_manager.update_status(
                        img['id'],
                        'failed',
                        error_message="Pod creation failed"
                    )
                return
            
            logger.info(f"Pod created successfully: {pod['id']}")
            
            # Mark all images as processing
            for img in images:
                await self.queue_manager.update_status(
                    img['id'], 
                    'processing', 
                    batch_id=batch_id
                )
            
            # Wait for Stable Diffusion service to be ready before processing images
            logger.info(f"Waiting for Stable Diffusion service to be ready on pod {pod['id']}...")
            try:
                await self._wait_for_sd_service_ready_simple(pod)
                logger.info(f"Stable Diffusion service is ready on pod {pod['id']}")
            except Exception as e:
                logger.error(f"Stable Diffusion service failed to become ready: {str(e)}")
                # Mark all images as failed
                for img in images:
                    await self.queue_manager.update_status(
                        img['id'],
                        'failed',
                        error_message=f"SD service not ready: {str(e)}"
                    )
                return
            
            # Process each image using the same pod
            for i, img in enumerate(images):
                try:
                    logger.info(f"Processing image {i+1}/{len(images)}: {img['id']}")
                    await self._process_single_image(pod, img)
                    logger.info(f"Completed image {i+1}/{len(images)}: {img['id']}")
                except Exception as e:
                    logger.error(f"Failed to process image {img['id']}: {str(e)}")
                    await self.queue_manager.update_status(
                        img['id'],
                        'failed',
                        error_message=str(e)
                    )
            
            logger.info(f"Completed processing all {len(images)} images with single pod (batch {batch_id})")
            
        except Exception as e:
            logger.error(f"Single pod batch processing failed: {str(e)}")
            raise
        finally:
            # Always terminate the pod to avoid charges
            if pod:
                logger.info(f"Terminating single pod for batch {batch_id}")
                await self._terminate_pod(pod)
    
    async def _process_batch_chunk(self, images: List[Dict[str, Any]], batch_id: str):
        """
        Process a chunk of images using pod-based RunPod processing.
        
        Args:
            images: List of image requests to process
            batch_id: Unique identifier for this batch
        """
        pod = None
        try:
            # Create RunPod instance for batch processing
            logger.info(f"Creating RunPod instance for batch {batch_id} with {len(images)} images")
            pod = await self._create_pod()
            
            # Mark all images as processing
            for img in images:
                await self.queue_manager.update_status(
                    img['id'], 
                    'processing', 
                    batch_id=batch_id
                )
            
            # Process each image using the pod
            for i, img in enumerate(images):
                try:
                    logger.info(f"Processing image {i+1}/{len(images)}: {img['id']}")
                    await self._process_single_image(pod, img)
                    logger.info(f"Completed image {i+1}/{len(images)}: {img['id']}")
                except Exception as e:
                    logger.error(f"Failed to process image {img['id']}: {str(e)}")
                    await self.queue_manager.update_status(
                        img['id'],
                        'failed',
                        error_message=str(e)
                    )
            
            logger.info(f"Completed batch {batch_id}")
            
        except Exception as e:
            logger.error(f"Batch processing failed: {str(e)}")
            raise
        finally:
            # Always terminate the pod to avoid charges
            if pod:
                logger.info(f"Terminating RunPod instance for batch {batch_id}")
                await self._terminate_pod(pod)
    
    async def _terminate_pod(self, pod_info: Dict[str, Any]):
        """
        Terminate a RunPod instance to prevent ongoing charges.
        
        Args:
            pod_info: Pod information containing pod ID
        """
        try:
            pod_id = pod_info.get('id')
            if pod_id:
                logger.info(f"Terminating pod {pod_id}...")
                result = runpod.terminate_pod(pod_id)
                logger.info(f"Pod {pod_id} terminated successfully: {result}")
            else:
                logger.warning("No pod ID found in pod_info for termination")
        except Exception as e:
            logger.error(f"Failed to terminate pod {pod_id}: {str(e)}")
            # Don't raise - termination failure shouldn't break the batch process
    
    async def _process_single_image(self, pod_info: Dict[str, Any], image_request: Dict[str, Any]):
        """
        Process a single image using the RunPod instance.
        
        Args:
            pod_info: Pod information with connection details
            image_request: Image request data from queue
        """
        logger.info(f"ENTERING _process_single_image method")
        logger.info(f"pod_info keys: {list(pod_info.keys()) if pod_info else 'None'}")
        logger.info(f"image_request keys: {list(image_request.keys()) if image_request else 'None'}")
        
        image_id = image_request.get('id', 'unknown')
        logger.info(f"Extracted image_id: {image_id}")
        
        try:
            logger.info(f"Starting image processing for {image_id}")
            
            public_ip = pod_info.get('public_ip')
            exposed_port = pod_info.get('exposed_port')
            
            if not public_ip or not exposed_port:
                raise Exception(f"Pod missing connection info - IP: {public_ip}, Port: {exposed_port}")
            
            # Extract image generation parameters
            prompt = image_request.get('prompt', 'A professional image')
            logger.info(f"Prompt for {image_id}: {prompt[:50]}...")
            
            # First, test basic connectivity with multiple endpoints
            test_endpoints = [
                f"http://{public_ip}:{exposed_port}/sdapi/v1/options",
                f"http://{public_ip}:{exposed_port}/",
                f"http://{public_ip}:{exposed_port}/docs"
            ]
            
            timeout = aiohttp.ClientTimeout(total=30)  # 30 second timeout for test
            api_available = False
            
            for test_url in test_endpoints:
                logger.info(f"Testing connectivity to: {test_url}")
                try:
                    async with aiohttp.ClientSession(timeout=timeout) as session:
                        async with session.get(test_url) as response:
                            logger.info(f"Connectivity test result for {test_url}: {response.status}")
                            if response.status == 200:
                                if 'sdapi/v1/options' in test_url:
                                    api_available = True
                                    logger.info(f"API endpoints are available!")
                                    break
                                else:
                                    logger.info(f"Service is running but API may not be enabled")
                except Exception as e:
                    logger.warning(f"Connectivity test failed for {test_url}: {str(e)}")
            
            if not api_available:
                logger.warning(f"API endpoints not available, but proceeding with image generation attempt")
            
            # Prepare payload for Stable Diffusion API (using default SD 1.5 model)
            payload = {
                "prompt": prompt,
                "steps": 20,
                "width": 512,
                "height": 512,
                "cfg_scale": 7.0,
                "sampler_name": "Euler a"
            }
            
            # Make API call to generate image
            generation_url = f"http://{public_ip}:{exposed_port}/sdapi/v1/txt2img"
            logger.info(f"Making API call to: {generation_url}")
            logger.info(f"Payload: {payload}")
            
            # Use longer timeout for SDXL generation
            timeout = aiohttp.ClientTimeout(total=180)  # 180 second timeout for SDXL
            
            async with aiohttp.ClientSession(timeout=timeout) as session:
                logger.info(f"Sending POST request for image {image_id}...")
                logger.info(f"Request started at: {datetime.now()}")
                
                try:
                    async with session.post(generation_url, json=payload) as response:
                        logger.info(f"Received response for {image_id}: status {response.status}")
                        logger.info(f"Response received at: {datetime.now()}")
                        
                        if response.status == 200:
                            result = await response.json()
                            if "images" in result and result["images"]:
                                # Save the generated image
                                image_data = result["images"][0]
                                await self._save_generated_image(image_id, image_data, prompt)
                                
                                # Update status to completed
                                await self.queue_manager.update_status(image_id, 'completed')
                                logger.info(f"Successfully generated image {image_id}")
                            else:
                                raise Exception("No images in API response")
                        else:
                            error_text = await response.text()
                            raise Exception(f"API returned status {response.status}: {error_text}")
                            
                except asyncio.TimeoutError:
                    raise Exception(f"Request timeout after 60 seconds for {generation_url}")
                except aiohttp.ClientError as e:
                    raise Exception(f"HTTP client error: {str(e)}")
                        
        except Exception as e:
            logger.error(f"Failed to process image {image_id}: {str(e)}")
            await self.queue_manager.update_status(
                image_id,
                'failed',
                error_message=str(e)
            )
            # Don't re-raise - continue with next image
            return
    
    async def _save_generated_image(self, image_id: str, image_data: str, prompt: str):
        """
        Decode a base64 image, write it under /images/generated and record its path.

        The file is named ``generated_<image_id>_<YYYYmmdd_HHMMSS>.jpg``.

        Args:
            image_id: Image queue ID (embedded in the filename).
            image_data: Base64-encoded image bytes from the SD API response.
            prompt: Original prompt used for generation (not used by this method).

        Raises:
            Re-raises any decode, filesystem or database error after logging it.
        """
        output_dir = "/images/generated"
        try:
            image_bytes = base64.b64decode(image_data)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # NOTE(review): the decoded bytes are written unchanged under a .jpg
            # name; if the API returns PNG data the extension is misleading —
            # confirm downstream consumers are fine with that.
            filename = f"generated_{image_id}_{timestamp}.jpg"

            os.makedirs(output_dir, exist_ok=True)
            file_path = os.path.join(output_dir, filename)
            with open(file_path, "wb") as f:
                f.write(image_bytes)

            # The database stores the same location as a '/'-joined path.
            await self.queue_manager.update_image_path(image_id, f"{output_dir}/{filename}")

            logger.info(f"Saved generated image to: {file_path}")

        except Exception as e:
            logger.error(f"Failed to save generated image {image_id}: {str(e)}")
            raise
    
    async def _create_pod(self):
        """
        Create a RunPod instance for batch processing using template-based approach.
        
        Returns:
            Pod information with connection details
        """
        try:
            logger.info("Creating new RunPod instance for batch processing")
            
            # Use template-based pod creation with explicit region and volume settings
            pod = await self._create_pod_from_template('iyz74sp6xl')
            
            if not pod or 'id' not in pod:
                logger.error("Pod creation failed, no pod ID returned")
                return None
            
            logger.info(f"Created pod with ID: {pod['id']}")
            
            # Wait for pod to be ready and extract connection information
            logger.info(f"Waiting for pod {pod['id']} to be ready...")
            pod_with_connection = await self._wait_for_pod_ready(pod['id'])
            
            if not pod_with_connection:
                logger.error(f"Pod {pod['id']} failed to become ready")
                return None
                
            logger.info(f"Pod {pod['id']} is ready with connection info")
            return pod_with_connection
            
        except Exception as e:
            logger.error(f"Failed to create pod: {str(e)}")
            return None
    
    async def _create_pod_from_template(self, template_id: str):
        """
        Create a pod using template with fallback to SDK if GraphQL fails.
        
        Args:
            template_id: Template ID to use (iyz74sp6xl for sdxl-batch-run-v2)
        
        Returns:
            Pod data
        """
        # First try GraphQL approach
        try:
            data_center_id = "EU-RO-1"
            network_volume_id = "sftk35ogwh"  # sdxl-model-store volume ID
            
            logger.info(f"Attempting GraphQL pod creation in region {data_center_id} with network volume {network_volume_id}")
            
            mutation = """
            mutation podFindAndDeployOnDemand($input: PodFindAndDeployOnDemandInput!) {
                podFindAndDeployOnDemand(input: $input) {
                    id
                    desiredStatus
                }
            }
            """
            
            variables = {
                "input": {
                    "name": f"batch-pod-{int(time.time())}",
                    "imageName": "runpod/stable-diffusion:web-ui-10.2.1",
                    "gpuTypeId": "NVIDIA GeForce RTX 4090",
                    "containerDiskInGb": 50,
                    "dataCenterId": data_center_id,
                    "networkVolumeId": network_volume_id,
                    "volumeMountPath": "/workspace",
                    "cloudType": "SECURE"
                }
            }
            
            payload = {"query": mutation, "variables": variables}
            response = requests.post("https://api.runpod.io/graphql", headers=self.graphql_headers, json=payload)
            response_data = response.json()
            
            if response.status_code == 200 and 'errors' not in response_data:
                pod_data = response_data['data']['podFindAndDeployOnDemand']
                logger.info(f"Pod created via GraphQL: {pod_data['id']}")
                return pod_data
            else:
                error_msg = response_data.get('errors', [{}])[0].get('message', 'Unknown error')
                logger.warning(f"GraphQL pod creation failed: {error_msg}")
                raise Exception(f"GraphQL failed: {error_msg}")
                
        except Exception as graphql_error:
            logger.warning(f"GraphQL approach failed: {graphql_error}")
            logger.info("Falling back to RunPod SDK with template...")
            
            # Fallback to RunPod SDK approach using your working template
            try:
                # Use the GPU selector to get the best available GPU
                gpu_type = self.gpu_selector.find_best_available_gpu(min_memory_gb=24, preferred_region="EU-RO-1")
                if not gpu_type:
                    gpu_type = "NVIDIA GeForce RTX 4090"  # Fallback to known available GPU
                
                logger.info(f"Using template {template_id} with GPU {gpu_type}")
                
                # First try with network volume
                pod_config = {
                    "name": f"batch-pod-{int(time.time())}",
                    "image_name": "runpod/stable-diffusion:web-ui-10.2.1",
                    "gpu_type_id": gpu_type,
                    "cloud_type": "SECURE",
                    "support_public_ip": True,
                    "ports": "7860/http",
                    "volume_in_gb": 50,
                    "volume_mount_path": "/workspace",
                    "data_center_id": "EU-RO-1"  # Enforce EU-RO-1 region for network volume compatibility
                }
                
                # Try to add network volume - different SDKs may use different parameter names
                try:
                    pod_config["network_volume_id"] = "sftk35ogwh"
                    logger.info(f"Attempting SDK pod creation with network volume: {pod_config}")
                    pod = runpod.create_pod(**pod_config)
                    logger.info(f"Pod created via SDK with network volume: {pod['id']}")
                    return pod
                except Exception as volume_error:
                    logger.warning(f"SDK pod creation with network volume failed: {volume_error}")
                    logger.info("Retrying SDK pod creation without network volume...")
                    
                    # Fallback: Create pod without network volume (will download models)
                    pod_config_no_volume = {
                        "name": f"batch-pod-{int(time.time())}",
                        "image_name": "runpod/stable-diffusion:web-ui-10.2.1",
                        "gpu_type_id": gpu_type,
                        "cloud_type": "SECURE",
                        "support_public_ip": True,
                        "ports": "7860/http",
                        "volume_in_gb": 50,
                        "volume_mount_path": "/workspace",
                        "data_center_id": "EU-RO-1"
                    }
                    
                    pod = runpod.create_pod(**pod_config_no_volume)
                    logger.info(f"Pod created via SDK without network volume: {pod['id']}")
                    logger.warning("Pod created without network volume - models will be downloaded (slower startup)")
                    return pod
                
            except Exception as sdk_error:
                logger.error(f"SDK pod creation also failed: {sdk_error}")
                return None
    
    async def _wait_for_pod_ready(self, pod_id: str, timeout: int = 600) -> Dict[str, Any]:
        """
        Wait for a RunPod instance to be ready and return pod info with public IP/port.
        
        Args:
            pod_id: Pod ID to wait for
            timeout: Timeout in seconds
            
        Returns:
            Pod information including public IP and port
        """
        logger.info(f"Waiting for pod {pod_id} to be ready...")
        
        start_time = datetime.utcnow()
        
        while (datetime.utcnow() - start_time).seconds < timeout:
            try:
                # Use RunPod API to get pod status
                pod = runpod.get_pod(pod_id)
                # Only log full pod response on first few attempts to reduce noise
                elapsed_seconds = (datetime.utcnow() - start_time).seconds
                if elapsed_seconds < 60:  # First minute only
                    logger.info(f"Pod status response: {pod}")
                else:
                    logger.info(f"Pod {pod_id} status check (elapsed: {elapsed_seconds}s)")
                
                # Check if pod exists and has the expected structure
                if pod and isinstance(pod, dict):
                    status = pod.get('status', 'UNKNOWN')
                    logger.info(f"Pod {pod_id} status: {status} (raw response keys: {list(pod.keys())})")
                    
                    # Check for alternative status fields
                    alt_status = pod.get('desiredStatus') or pod.get('runtime', {}).get('status')
                    if alt_status:
                        logger.info(f"Pod {pod_id} alternative status: {alt_status}")
                        # Use alternative status if main status is UNKNOWN
                        if status == 'UNKNOWN' and alt_status != 'UNKNOWN':
                            status = alt_status
                            logger.info(f"Using alternative status: {status}")
                    
                    if status == 'RUNNING':
                        logger.info(f"Pod {pod_id} is ready and running")
                        
                        # PRIORITY 1: Check for endpoint field (working example approach)
                        endpoint = pod.get('endpoint')
                        if endpoint:
                            logger.info(f"Found endpoint field: {endpoint}")
                            # Parse endpoint to extract IP and port
                            if endpoint.startswith('http://'):
                                endpoint_clean = endpoint.replace('http://', '')
                                if ':' in endpoint_clean:
                                    public_ip, port_str = endpoint_clean.split(':', 1)
                                    exposed_port = int(port_str)
                                else:
                                    public_ip = endpoint_clean
                                    exposed_port = 80  # Default HTTP port
                                
                                logger.info(f"  - Using ENDPOINT connection: {public_ip}:{exposed_port}")
                                
                                pod_info = {
                                    'id': pod_id,
                                    'status': status,
                                    'public_ip': public_ip,
                                    'exposed_port': exposed_port,
                                    'endpoint': endpoint,
                                    'raw_pod_data': pod,
                                    'connection_failed': False
                                }
                                logger.info(f"Pod {pod_id} ready - Endpoint: {endpoint}")
                                return pod_info
                        
                        # PRIORITY 2: Extract public IP and port information from runtime.ports
                        public_ip = None
                        exposed_port = None
                        legacy_ports = []  # Initialize to avoid scope issues
                        
                        # Check runtime.ports for actual network information
                        runtime_data = pod.get('runtime')
                        runtime_ports = []
                        
                        if runtime_data is not None:
                            runtime_ports = runtime_data.get('ports', [])
                            logger.info(f"Pod {pod_id} runtime ports: {runtime_ports}")
                        else:
                            logger.info(f"Pod {pod_id} runtime data not available yet, checking for direct IP in other fields...")
                            # Sometimes RunPod provides IP info in other fields when runtime isn't ready
                            machine_data = pod.get('machine', {})
                            if machine_data.get('publicIp'):
                                logger.info(f"Found direct public IP in machine data: {machine_data.get('publicIp')}")
                                public_ip = machine_data.get('publicIp')
                                exposed_port = 7860  # SD Web UI standard port
                                logger.info(f"  - Using DIRECT IP from machine: {public_ip}:{exposed_port}")
                                
                                pod_info = {
                                    'id': pod_id,
                                    'status': status,
                                    'public_ip': public_ip,
                                    'exposed_port': exposed_port,
                                    'ports': [],
                                    'raw_pod_data': pod,
                                    'connection_failed': False
                                }
                                logger.info(f"Pod {pod_id} ready - Direct IP: {public_ip}, Port: {exposed_port}")
                                return pod_info
                            
                            # FIXED: Find the actual SD Web UI port - prioritize 3001 over 3000 and 7860
                            # Pod logs show SD Web UI runs on port 3001 in some configurations
                            sd_port_info = None
                            
                            # First, look for port 3001 (where SD Web UI actually runs in some configs)
                            for port_info in runtime_ports:
                                private_port = port_info.get('privatePort')
                                if private_port == 3001:
                                    sd_port_info = port_info
                                    logger.info(f"Found SD Web UI on port {private_port}: {port_info}")
                                    break
                            
                            # If port 3001 not found, try port 3000
                            if not sd_port_info:
                                for port_info in runtime_ports:
                                    private_port = port_info.get('privatePort')
                                    if private_port == 3000:
                                        sd_port_info = port_info
                                        logger.info(f"Found SD Web UI on port {private_port}: {port_info}")
                                        break
                            
                            # If port 3000 not found, fall back to 7860
                            if not sd_port_info:
                                for port_info in runtime_ports:
                                    private_port = port_info.get('privatePort')
                                    if private_port == 7860:
                                        sd_port_info = port_info
                                        logger.info(f"Found SD Web UI on fallback port {private_port}: {port_info}")
                                        break
                            
                            if sd_port_info:
                                runtime_ip = sd_port_info.get('ip')
                                exposed_port = sd_port_info.get('publicPort') or sd_port_info.get('privatePort')
                                is_public = sd_port_info.get('isIpPublic', False)
                                
                                logger.info(f"  - Port info: IP={runtime_ip}, Port={exposed_port}, isPublic={is_public}")
                                
                                # CRITICAL FIX: Always prefer direct IP connection to avoid proxy issues
                                if runtime_ip:
                                    # Use direct IP connection (works for both public and private IPs)
                                    public_ip = runtime_ip
                                    # Use the ACTUAL exposed port from RunPod (not forced 3000)
                                    exposed_port = sd_port_info.get('publicPort') or sd_port_info.get('privatePort') or 3000
                                    logger.info(f"  - Using DIRECT connection: {public_ip}:{exposed_port} (isPublic: {is_public}) [ACTUAL EXPOSED PORT]")
                                    logger.info(f"  - Port mapping: publicPort={sd_port_info.get('publicPort')} -> privatePort={sd_port_info.get('privatePort')}")
                                    logger.info(f"  - Runtime ports available: {[p.get('publicPort') for p in runtime_ports]}")
                                else:
                                    # Only use proxy as absolute last resort
                                    logger.warning(f"  - No runtime IP available, falling back to proxy")
                                    pod_host_id = pod.get('machine', {}).get('podHostId')
                                    if pod_host_id:
                                        public_ip = f"{pod_host_id}.proxy.runpod.net"
                                        exposed_port = 3000  # RunPod service runs on port 3000 internally
                                        logger.info(f"  - Using PROXY connection: {public_ip}:{exposed_port}")
                                    else:
                                        logger.error(f"  - No podHostId available for proxy connection!")
                                        public_ip = None
                                    
                                if public_ip:
                                    logger.info(f"  - Connected to SD Web UI: {public_ip}:{exposed_port}")
                                    
                                    # Create pod_info immediately after successful connection
                                    pod_info = {
                                        'id': pod_id,
                                        'status': status,
                                        'public_ip': public_ip,
                                        'exposed_port': exposed_port,
                                        'ports': runtime_ports,
                                        'raw_pod_data': pod,
                                        'connection_failed': False
                                    }
                        
                        # PRIORITY 3: Fallback to proxy connection if runtime ports not available
                        if not public_ip or not exposed_port:
                            logger.info(f"Pod {pod_id} runtime ports not available, using proxy connection")
                            legacy_ports = pod.get('ports', [])
                            logger.info(f"  - Legacy ports: {legacy_ports} (type: {type(legacy_ports)})")
                            
                            # Try multiple ways to extract podHostId for RunPod proxy connection
                            pod_host_id = None
                            
                            # Method 1: Check machine.podHostId (most common)
                            if pod.get('machine', {}).get('podHostId'):
                                pod_host_id = pod.get('machine', {}).get('podHostId')
                                logger.info(f"  - Found podHostId in machine: {pod_host_id}")
                            
                            # Method 2: Try to construct from pod ID and machine ID
                            elif pod.get('id') and pod.get('machineId'):
                                # Sometimes podHostId follows pattern: {pod_id}-{machine_suffix}
                                pod_host_id = f"{pod['id']}-{pod['machineId'][:8]}"
                                logger.info(f"  - Constructed podHostId: {pod_host_id}")
                            
                            # Method 3: Use pod ID as fallback
                            elif pod.get('id'):
                                pod_host_id = pod['id']
                                logger.info(f"  - Using pod ID as podHostId: {pod_host_id}")
                            
                            if pod_host_id:
                                public_ip = f"{pod_host_id}.proxy.runpod.net"
                                exposed_port = 7860  # SD Web UI standard port
                                logger.info(f"  - Using proxy connection: {public_ip}:{exposed_port}")
                            else:
                                logger.error(f"  - No podHostId available for proxy connection!")
                                logger.error(f"  - Pod keys available: {list(pod.keys())}")
                                logger.error(f"  - Machine data: {pod.get('machine', {})}")
                            
                            if public_ip:
                                logger.info(f"  - Connected to SD Web UI: {public_ip}:{exposed_port}")
                                logger.info(f"DEBUG: About to create pod_info with pod_host_id: {pod_host_id}")
                                pod_info = {
                                    'id': pod_id,
                                    'status': status,
                                    'public_ip': public_ip,
                                    'exposed_port': exposed_port,
                                    'ports': runtime_ports,
                                    'raw_pod_data': pod,
                                    'connection_failed': False,
                                    'pod_host_id': pod_host_id  # Include podHostId for proxy connection
                                }
                            else:
                                logger.warning(f"Runtime port missing IP: IP={public_ip}")
                                pod_info = {
                                    'id': pod_id,
                                    'status': status,
                                    'public_ip': None,
                                    'exposed_port': None,
                                    'ports': runtime_ports,
                                    'raw_pod_data': pod,
                                    'connection_failed': True,
                                    'pod_host_id': pod_host_id  # Include podHostId for proxy connection
                                }
                        
                        logger.info(f"Pod {pod_id} ready - IP: {public_ip}, Port: {exposed_port}")
                        return pod_info
                        
                    elif status in ['FAILED', 'TERMINATED', 'EXITED']:
                        raise RuntimeError(f"Pod {pod_id} failed with status: {status}")
                    else:
                        logger.info(f"Pod {pod_id} status: {status}, waiting...")
                elif pod is None:
                    logger.warning(f"Pod {pod_id} not found in API response - may still be initializing")
                else:
                    logger.warning(f"Unexpected pod response format: {pod}")
                    
            except Exception as e:
                logger.warning(f"Error checking pod status: {str(e)}")
                # Continue trying - API errors are often temporary
                
            await asyncio.sleep(10)  # Wait 10 seconds between checks
        
        raise TimeoutError(f"Pod {pod_id} did not become ready within {timeout} seconds")
    
    async def _wait_for_sd_service_ready_simple(
        self,
        pod_info: Dict[str, Any],
        max_attempts: int = 60,
        poll_interval: float = 10,
    ):
        """
        Poll the pod until the Stable Diffusion WebUI API answers, or give up.

        Candidate base URLs are chosen in priority order:
          1. ``raw_pod_data["endpoint"]`` when present;
          2. proxy URLs built from ``pod_info["pod_host_id"]`` (HTTPS, then HTTP);
          3. a direct ``http://{public_ip}:{exposed_port}`` URL.
        Every polling round sends ``GET {base_url}/sdapi/v1/options`` to each
        candidate; the first HTTP 200 ends the wait.

        Args:
            pod_info: Pod information with connection details
            max_attempts: Number of polling rounds before giving up (default 60)
            poll_interval: Seconds to sleep between rounds (default 10, so the
                defaults give a 10 minute budget)

        Raises:
            Exception: If no connection method is available, or if the WebUI
                never answers within the budget (the pod is terminated first).
        """
        pod_id = pod_info.get('id')
        # ``or {}`` also covers the key being present with a None value
        raw_pod_data = pod_info.get('raw_pod_data') or {}

        # Priority 1: official endpoint if the pod data carries one
        pod_url = raw_pod_data.get("endpoint")

        if pod_url:
            pod_urls = [pod_url]
            logger.info(f"Using official endpoint: {pod_url}")
        else:
            # Priority 2: proxy connection via podHostId
            pod_host_id = pod_info.get('pod_host_id')

            logger.info(f"DEBUG: Using podHostId: {pod_host_id}")

            if pod_host_id:
                # NOTE(review): RunPod's documented HTTP proxy form is
                # https://{POD_ID}-{PORT}.proxy.runpod.net (served on 443); confirm
                # that this "{podHostId}.proxy.runpod.net:7860" form is reachable.
                pod_urls = [
                    f"https://{pod_host_id}.proxy.runpod.net:7860",
                    f"http://{pod_host_id}.proxy.runpod.net:7860"
                ]
                logger.info(f"Will try proxy connections: {pod_urls}")
            else:
                # Priority 3: direct IP connection
                public_ip = pod_info.get('public_ip')
                exposed_port = pod_info.get('exposed_port')
                if not (public_ip and exposed_port):
                    raise Exception(f"No valid connection method available for pod {pod_id}")
                pod_urls = [f"http://{public_ip}:{exposed_port}"]
                logger.info(f"Using direct IP connection: {pod_urls[0]}")

        # Human-readable budget, e.g. "10" for the defaults
        budget_minutes = f"{max_attempts * poll_interval / 60:g}"

        logger.info("Actively polling for SD WebUI readiness...")

        started = time.monotonic()
        request_timeout = aiohttp.ClientTimeout(total=10, connect=5)
        # One session for the whole wait instead of one per request
        async with aiohttp.ClientSession(timeout=request_timeout) as session:
            for attempt in range(max_attempts):
                # Try each candidate URL in order (HTTPS before HTTP for the proxy)
                for url_index, base_url in enumerate(pod_urls):
                    test_url = f"{base_url}/sdapi/v1/options"
                    try:
                        async with session.get(test_url) as response:
                            if response.status == 200:
                                elapsed = int(time.monotonic() - started)
                                logger.info(f"SD WebUI is online after {elapsed} seconds using {base_url}")
                                return
                            logger.info(f"Attempt {attempt + 1}/{max_attempts}: URL {url_index + 1}/{len(pod_urls)} Status {response.status}")
                    except Exception as e:
                        # Any transport error just moves on to the next candidate
                        logger.info(f"Attempt {attempt + 1}/{max_attempts}: URL {url_index + 1}/{len(pod_urls)} {type(e).__name__}: {str(e)}")

                # All URLs failed this round; don't sleep after the last round
                if attempt < max_attempts - 1:
                    await asyncio.sleep(poll_interval)

        logger.error(f"WebUI never responded after {budget_minutes} minutes. Terminating pod {pod_id}")
        try:
            # runpod.terminate_pod is called synchronously elsewhere; run it off the event loop
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, runpod.terminate_pod, pod_id)
            logger.info(f"Successfully terminated unresponsive pod {pod_id}")
        except Exception as cleanup_error:
            logger.error(f"Failed to terminate unresponsive pod {pod_id}: {cleanup_error}")

        raise Exception(f"SD WebUI failed to respond after {budget_minutes} minutes on pod {pod_id}")
    
    async def _wait_for_sd_service_ready(self, pod_info: Dict[str, Any], timeout: int = 900):
        """
        Wait for the Stable Diffusion service inside the pod to answer HTTP requests.

        Each round probes several endpoints at ``http://{public_ip}:{exposed_port}``.
        The first HTTP 200 from any of them ends the wait. Between rounds the method
        sleeps 30 seconds. The deadline is measured with a monotonic clock, so
        wall-clock adjustments cannot stretch or shorten it.

        Args:
            pod_info: Pod information with connection details
            timeout: Timeout in seconds (default 15 minutes for SDXL model loading)

        Raises:
            Exception: If ``public_ip`` or ``exposed_port`` is missing.
            TimeoutError: If no endpoint returns 200 within ``timeout`` seconds.
        """
        public_ip = pod_info.get('public_ip')
        exposed_port = pod_info.get('exposed_port')

        if not public_ip or not exposed_port:
            raise Exception(f"Pod missing connection info for SD service check - IP: {public_ip}, Port: {exposed_port}")

        # Try multiple health check endpoints.
        # NOTE(review): "/docs" and "/" may return 200 before the txt2img API is
        # usable; confirm they are an acceptable readiness signal.
        health_endpoints = [
            "/sdapi/v1/options",
            "/docs",
            "/",
            "/sdapi/v1/progress"
        ]

        logger.info(f"Waiting for Stable Diffusion service to be ready at {public_ip}:{exposed_port}...")

        # First, test basic connectivity
        logger.info(f"Testing basic connectivity to {public_ip}:{exposed_port}...")

        base_url = f"http://{public_ip}:{exposed_port}"
        deadline = time.monotonic() + timeout
        request_timeout = aiohttp.ClientTimeout(total=30)

        async with aiohttp.ClientSession(timeout=request_timeout) as session:
            while time.monotonic() < deadline:
                for endpoint in health_endpoints:
                    health_url = f"{base_url}{endpoint}"
                    try:
                        async with session.get(health_url) as response:
                            logger.info(f"Endpoint {endpoint} returned status: {response.status}")
                            if response.status == 200:
                                logger.info(f"Stable Diffusion service is ready! Endpoint: {endpoint}")
                                return
                            if response.status in [404, 405]:  # Expected for some endpoints
                                continue
                            logger.info(f"Endpoint {endpoint} returned unexpected status: {response.status}")
                    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                        logger.info(f"Endpoint {endpoint} error: {type(e).__name__}: {str(e)}")
                        continue

                logger.info(f"All endpoints failed, waiting 30 seconds before retry...")
                await asyncio.sleep(30)

        raise TimeoutError(f"Stable Diffusion service did not become ready within {timeout} seconds")
    
    async def _terminate_pod(self, pod: Dict[str, Any]):
        """
        Terminate a RunPod instance, best effort.

        ``runpod.terminate_pod`` is a plain (non-awaitable) call, so it runs in the
        default executor rather than blocking the event loop. Failures, including a
        missing ``'id'`` key, are logged rather than raised.

        Args:
            pod: Pod information; only ``pod['id']`` is used.
        """
        pod_id = None
        try:
            pod_id = pod['id']
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, runpod.terminate_pod, pod_id)
            logger.info(f"Terminated pod {pod_id}")
        except Exception as e:
            logger.error(f"Failed to terminate pod {pod_id}: {str(e)}")
    
    async def _process_single_image(self, pod: Dict[str, Any], image_request: Dict[str, Any]):
        """
        Process a single image on a pod via the Stable Diffusion WebUI API, with retries.

        NOTE: ``_process_single_image`` is defined more than once in this class.
        Python binds the class attribute to the *last* ``def`` of that name in the
        class body, so while a later definition exists this version is shadowed
        and never called. It is kept correct so the duplicate can be resolved
        deliberately.

        Connection errors and timeouts are retried (3 attempts, 30 s apart). A
        non-200 response fails immediately. On success the decoded image is saved
        as ``generated_{id}_{timestamp}.jpg`` and the queue entry is marked
        ``completed`` with its ``image_url``. On any failure the entry is marked
        ``failed`` and the exception is re-raised.

        Args:
            pod: Pod information with connection details (``public_ip``,
                ``exposed_port``, optional ``connection_failed`` flag)
            image_request: Image request from queue (uses ``id``, ``prompt`` and
                optional ``dimensions`` given as a dict or a JSON string)

        Raises:
            Exception: On missing connection info, API/transport failure, or a
                response without images.
        """
        try:
            # Check if connection info extraction failed
            if pod.get('connection_failed', False):
                raise Exception(f"Pod connection info extraction failed - IP: {pod.get('public_ip')}, Port: {pod.get('exposed_port')}")

            public_ip = pod.get('public_ip')
            exposed_port = pod.get('exposed_port')

            if not public_ip or not exposed_port:
                raise Exception(f"Pod missing connection info - IP: {public_ip}, Port: {exposed_port}")

            # Dimensions may arrive as a dict or as a JSON-encoded string
            dimensions_data = image_request.get('dimensions', {})
            if isinstance(dimensions_data, str):
                dimensions = json.loads(dimensions_data) if dimensions_data else {}
            else:
                dimensions = dimensions_data
            if not isinstance(dimensions, dict):
                # e.g. JSON "null", a list, or None: fall back to defaults
                dimensions = {}

            width = dimensions.get('width', 1024)
            height = dimensions.get('height', 1024)

            # Payload tuned for a lightweight model (OFA-Sys/small-stable-diffusion-v0)
            payload = {
                "prompt": image_request['prompt'],
                "negative_prompt": "blurry, low quality, distorted, deformed, worst quality",
                "width": width,
                "height": height,
                "steps": 20,  # Reduced steps for faster generation with lightweight model
                "cfg_scale": 7.0,  # Slightly lower for better performance
                "sampler_name": "Euler a",  # Faster sampler for lightweight model
                "batch_size": 1,
                "n_iter": 1,
                "seed": -1,  # Random seed
                "restore_faces": False,  # Disable for speed
                "tiling": False,  # Disable for speed
                "do_not_save_samples": True,  # Don't save to disk on pod
                "do_not_save_grid": True  # Don't save grid on pod
            }

            api_url = f"http://{public_ip}:{exposed_port}/sdapi/v1/txt2img"
            logger.info(f"Making request to pod API: {api_url}")

            # Retry logic for when SD service is starting up
            max_retries = 3
            retry_delay = 30  # seconds
            request_timeout = aiohttp.ClientTimeout(total=300)
            headers = {"Content-Type": "application/json"}
            result = None

            for attempt in range(max_retries):
                try:
                    async with aiohttp.ClientSession() as session:
                        async with session.post(api_url, json=payload, headers=headers, timeout=request_timeout) as response:
                            if response.status != 200:
                                # Plain Exception is not caught below, so HTTP errors
                                # are not retried: they fail the image immediately.
                                error_text = await response.text()
                                raise Exception(f"API returned status {response.status}: {error_text}")

                            result = await response.json()
                    break  # Success, exit retry loop

                except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                    if attempt < max_retries - 1:
                        logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
                        await asyncio.sleep(retry_delay)
                    else:
                        logger.error(f"All {max_retries} attempts failed. Final error: {str(e)}")
                        raise Exception(f"Failed to connect to pod after {max_retries} attempts: {str(e)}") from e

            # Response handling runs once, after a successful attempt
            if not isinstance(result, dict) or not result.get('images'):
                raise Exception(f"No images in response: {result}")

            # First image, base64 encoded
            image_bytes = base64.b64decode(result['images'][0])

            timestamp = int(time.time())
            filename = f"generated_{image_request['id']}_{timestamp}.jpg"

            await self._save_image(image_bytes, filename)

            # Update database with image URL
            image_url = f"/images/generated/{filename}"
            await self.queue_manager.update_status(
                image_request['id'],
                'completed',
                image_url=image_url
            )

            logger.info(f"Successfully processed image {image_request['id']} -> {image_url}")

        except Exception as e:
            logger.error(f"Failed to process image {image_request['id']}: {str(e)}")
            await self.queue_manager.update_status(
                image_request['id'],
                'failed',
                error_message=str(e)
            )
            raise
    
    async def _process_single_image_serverless(self, image_request: Dict[str, Any]):
        """
        Process a single image using a RunPod serverless endpoint.

        The ``endpoint_id`` config value may be either a full URL, which is used
        as-is, or a bare endpoint ID. A bare ID is expanded to
        ``https://api.runpod.ai/v2/{endpoint_id}/runsync``. The request is
        authenticated with the configured RunPod API key (``self.api_key``).

        The endpoint's ``output`` may be a list whose first item is the image, or
        a dict with an ``images`` list. The image itself may be an http(s) URL
        (downloaded) or base64 data, optionally wrapped in a ``data:`` URI. On
        success the image is saved and the queue entry is marked ``completed``.
        This method does not mark failures in the queue; it logs and re-raises.

        Args:
            image_request: Image request from queue (uses ``id``, ``prompt``,
                ``image_type``, ``article_id`` and optional ``dimensions``)

        Raises:
            Exception: On missing configuration, API failure, or an unexpected
                response shape.
        """
        try:
            # Dimensions may arrive as a dict or as a JSON-encoded string
            dimensions_data = image_request.get('dimensions', {})
            if isinstance(dimensions_data, str):
                dimensions = json.loads(dimensions_data) if dimensions_data else {}
            else:
                dimensions = dimensions_data
            if not isinstance(dimensions, dict):
                dimensions = {}

            width = dimensions.get('width', 1024)
            height = dimensions.get('height', 1024)

            # Accept either a full URL or a bare endpoint ID from config
            endpoint_setting = str(self.runpod_config.get('endpoint_id') or '').strip()
            if not endpoint_setting:
                raise Exception("RUNPOD_ENDPOINT_ID not configured")
            if endpoint_setting.startswith(('http://', 'https://')):
                endpoint_url = endpoint_setting
            else:
                # Synchronous route: the response is expected to carry ``output``.
                # NOTE(review): runsync may return only a status (no output) for
                # long-running jobs; that surfaces as "No output in response".
                endpoint_url = f"https://api.runpod.ai/v2/{endpoint_setting}/runsync"

            payload = {
                "input": {
                    "prompt": image_request['prompt'],
                    "negative_prompt": "blurry, low quality, distorted, deformed",
                    "width": width,
                    "height": height,
                    "num_inference_steps": 30,
                    "guidance_scale": 7.5,
                    "num_outputs": 1,
                    "scheduler": "DPMSolverMultistep"
                }
            }

            async with aiohttp.ClientSession() as session:
                headers = {
                    "Content-Type": "application/json",
                    # Use the configured key rather than a hard-coded token
                    "Authorization": f"Bearer {self.api_key}"
                }

                async with session.post(endpoint_url, json=payload, headers=headers) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"API returned status {response.status}: {error_text}")

                    result = await response.json()

                    if 'output' not in result:
                        raise Exception(f"No output in response: {result}")

                    output = result['output']
                    if isinstance(output, list) and output:
                        image_data = output[0]
                    elif isinstance(output, dict) and output.get('images'):
                        # Alternative format
                        image_data = output['images'][0]
                    else:
                        # Also covers an empty list / empty "images"
                        raise Exception(f"Unexpected output format: {output}")

                    # Image data is either a URL to download or base64 content
                    if not isinstance(image_data, str):
                        raise Exception(f"Unexpected image data format: {type(image_data)}")

                    if image_data.startswith('http'):
                        async with session.get(image_data) as img_response:
                            if img_response.status != 200:
                                raise Exception(f"Failed to download image from {image_data}")
                            image_bytes = await img_response.read()
                    else:
                        b64_content = image_data
                        if b64_content.startswith('data:'):
                            # Strip a "data:<mime>;base64," prefix; otherwise
                            # b64decode would silently fold it into the image bytes
                            b64_content = b64_content.partition(',')[2]
                        image_bytes = base64.b64decode(b64_content)

                    filename = self.queue_manager.generate_filename(
                        image_request['image_type'],
                        image_request['article_id']
                    )

                    await self._save_image(image_bytes, filename)

                    await self.queue_manager.update_status(
                        image_request['id'],
                        'completed',
                        output_filename=filename
                    )

                    logger.info(f"Successfully processed image {image_request['id']} -> {filename}")

        except Exception as e:
            logger.error(f"Failed to process image {image_request['id']}: {str(e)}")
            raise
    
    async def _process_single_image(self, pod: Dict[str, Any], image_request: Dict[str, Any]):
        """
        Process a single image on the pod via the Stable Diffusion WebUI txt2img API.

        This is the last ``_process_single_image`` defined in the class body, so it
        is the definition Python binds for the class.

        On success the decoded image is saved under a queue-manager-generated
        filename and the queue entry is marked ``completed``. On any failure the
        entry is marked ``failed`` and the exception is re-raised.

        Args:
            pod: Pod information (uses ``id``, ``public_ip``, ``exposed_port``)
            image_request: Image request from queue (uses ``id``, ``prompt``,
                ``image_type``, ``article_id`` and optional ``dimensions``)

        Raises:
            Exception: On missing connection info, a non-200 response, or a
                response without images.
        """
        try:
            # Dimensions may arrive as a dict or as a JSON-encoded string
            dimensions_data = image_request.get('dimensions', {})
            if isinstance(dimensions_data, str):
                dimensions = json.loads(dimensions_data) if dimensions_data else {}
            else:
                dimensions = dimensions_data
            if not isinstance(dimensions, dict):
                # e.g. JSON "null", a list, or None: fall back to defaults
                dimensions = {}

            width = dimensions.get('width', 1024)
            height = dimensions.get('height', 1024)

            public_ip = pod.get('public_ip')
            exposed_port = pod.get('exposed_port')

            if not public_ip or not exposed_port:
                raise Exception(f"Pod {pod['id']} missing connection info - IP: {public_ip}, Port: {exposed_port}")

            # Direct connection to the pod's Stable Diffusion Web UI
            endpoint_url = f"http://{public_ip}:{exposed_port}/sdapi/v1/txt2img"

            payload = {
                "prompt": image_request['prompt'],
                "negative_prompt": "blurry, low quality, distorted, deformed",
                "width": width,
                "height": height,
                "steps": 30,
                "cfg_scale": 7.5,
                "sampler_name": "DPM++ 2M Karras",
                "enable_hr": False,
                "denoising_strength": 0.7,
                "batch_size": 1,
                "n_iter": 1
            }

            # The request goes over plain HTTP to the pod's own address, so no
            # account credential (Authorization header) is attached to it.
            headers = {"Content-Type": "application/json"}

            async with aiohttp.ClientSession() as session:
                async with session.post(endpoint_url, json=payload, headers=headers) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"API returned status {response.status}: {error_text}")

                    result = await response.json()

                    if not isinstance(result, dict) or not result.get('images'):
                        raise Exception("No images in response")

                    image_bytes = base64.b64decode(result['images'][0])

                    filename = self.queue_manager.generate_filename(
                        image_request['image_type'],
                        image_request['article_id']
                    )

                    await self._save_image(image_bytes, filename)

                    await self.queue_manager.update_status(
                        image_request['id'],
                        'completed',
                        output_filename=filename
                    )

                    logger.info(f"Successfully processed image {image_request['id']} -> {filename}")

        except Exception as e:
            logger.error(f"Failed to process image {image_request['id']}: {str(e)}")
            await self.queue_manager.update_status(
                image_request['id'],
                'failed',
                error_message=str(e)
            )
            raise
    
    async def _save_image(self, image_bytes: bytes, filename: str):
        """
        Decode image bytes and save them as an RGB JPEG (quality 90).

        The file is written to ``<cwd>/public/images/generated/<filename>``. The
        directory is created if missing. The image is always re-encoded as JPEG,
        whatever its source format.

        Args:
            image_bytes: Encoded image data (any format Pillow can open)
            filename: Filename to save as

        Raises:
            Exception: Any decode or write error is logged and re-raised.
        """
        try:
            base_dir = os.path.join(os.getcwd(), "public", "images", "generated")
            os.makedirs(base_dir, exist_ok=True)

            filepath = os.path.join(base_dir, filename)

            # Image.open is lazy; the context manager releases the source image
            # once convert() has produced an independent RGB copy.
            with Image.open(BytesIO(image_bytes)) as source_image:
                rgb_image = source_image.convert('RGB')
            rgb_image.save(filepath, 'JPEG', quality=90)

            logger.info(f"Saved image to {filepath}")

        except Exception as e:
            logger.error(f"Failed to save image: {str(e)}")
            raise
    
    async def check_and_trigger_batch(self, threshold: int = 500) -> bool:
        """
        Check queue size and run the daily batch if the threshold is met.

        Errors never propagate. A failure while reading or comparing the queue
        count, or while running the batch, is logged and reported as ``False``.
        The two failure cases log distinct messages.

        Args:
            threshold: Number of pending images to trigger batch

        Returns:
            True if the batch was triggered and completed without raising,
            False otherwise.
        """
        try:
            count = await self.queue_manager.get_queue_count()
            # Comparison stays inside the guard: an unexpected count type fails here
            threshold_met = count >= threshold
        except Exception as e:
            logger.error(f"Failed to check queue threshold: {str(e)}")
            return False

        if not threshold_met:
            logger.info(f"Queue count ({count}) below threshold ({threshold})")
            return False

        logger.info(f"Queue threshold met ({count} >= {threshold}), triggering batch")
        try:
            await self.process_daily_batch()
        except Exception as e:
            logger.error(f"Batch processing failed after queue threshold was met: {str(e)}")
            return False
        return True
