import base64
import gc
import io
import logging
import os
from pathlib import Path

import runpod
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def health_check():
    """Verify CUDA and models are available"""
    try:
        if not torch.cuda.is_available():
            logger.error("CUDA not available")
            return False
        logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
        return True
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return False

MODEL_DIR = Path("/workspace")
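# Registry of available checkpoints: each entry maps a model name to its
# on-disk .safetensors file and pipeline family ("sd15" or "sdxl")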
MODELS = {
    "ponyrealismv23": {
        "path": MODEL_DIR / "ponyrealismv23.safetensors",
        "type": "sd15"
    },
    "juggxl_ragnarok": {
        "path": MODEL_DIR / "juggxl_ragnarok.safetensors",
        "type": "sdxl"
    }
}

logger.info(f"Files in {MODEL_DIR}: {os.listdir(MODEL_DIR)}")

if not health_check():
    raise RuntimeError("Health check failed")

def verify_models():
    """Verify all models exist before starting the server"""
    missing_models = []
    for model_name, model_info in MODELS.items():
        if not model_info["path"].exists():
            missing_models.append(model_name)
    if missing_models:
        raise RuntimeError(f"Missing models: {', '.join(missing_models)}")
    logger.info("All models verified successfully")

def verify_cuda():
    """Verify CUDA setup and available memory"""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    device = torch.cuda.current_device()
    logger.info(f"Using GPU: {torch.cuda.get_device_name(device)}")
    logger.info(f"CUDA Version: {torch.version.cuda}")
    memory_allocated = torch.cuda.memory_allocated(device) / (1024**3)
    memory_reserved = torch.cuda.memory_reserved(device) / (1024**3)
    logger.info(f"GPU Memory: Allocated={memory_allocated:.2f}GB, Reserved={memory_reserved:.2f}GB")

PIPELINES = {}
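# PIPELINES caches loaded models by name. handler() evicts the entry after
# every request, so at most one checkpoint stays resident in GPU memory.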

def load_pipeline(model_name):
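    """Load the requested checkpoint, reusing a cached pipeline if present"""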
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    if model_name not in MODELS:
        raise ValueError(f"Model {model_name} not found")
    
    # Collect dangling Python references first so the CUDA allocator can
    # actually return the freed blocks
    gc.collect()
    torch.cuda.empty_cache()
    
    logger.info(f"Loading model: {model_name}")
    model_path = str(MODELS[model_name]["path"])
    model_type = MODELS[model_name]["type"]

    if model_type == "sd15":
        pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16).to("cuda")
    elif model_type == "sdxl":
        pipe = StableDiffusionXLPipeline.from_single_file(model_path, torch_dtype=torch.float16).to("cuda")
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    PIPELINES[model_name] = pipe
    return pipe

def cleanup_pipeline(model_name):
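    """Drop a cached pipeline and release its GPU memory"""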
    if model_name in PIPELINES:
        logger.info(f"Cleaning up model: {model_name}")
        del PIPELINES[model_name]
        gc.collect()
        torch.cuda.empty_cache()

def handler(event):
    """
    RunPod handler function that processes image generation requests
    """
    model_name = None
    try:
        input_data = event["input"]

        # Validate required fields
        if "prompt" not in input_data:
            return {"error": "prompt is required"}

        # Load the requested model (defaulting to ponyrealismv23)
        model_name = input_data.get("model", "ponyrealismv23")
        logger.info(f"Received generation request for model: {model_name}")
        pipe = load_pipeline(model_name)
        
        # Set up generation parameters; pin the RNG to the CUDA device so
        # seeded requests reproduce on GPU
        seed = input_data.get("seed")
        generator = torch.Generator(device="cuda").manual_seed(seed) if seed is not None else None

        generation_params = {
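            # Width/height defaults below suit SD 1.5 checkpoints; SDXL
            # models generally expect resolutions around 1024x1024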
            "prompt": input_data["prompt"],
            "negative_prompt": input_data.get("negative_prompt", ""),
            "num_inference_steps": input_data.get("steps", 30),
            "width": input_data.get("width", 512),
            "height": input_data.get("height", 768),
            "guidance_scale": input_data.get("guidance_scale", 7.5),
            "generator": generator
        }

        # Generate image
        result = pipe(**generation_params)
        img = result.images[0]

        # Convert image to base64
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode()
            
        return {
            "image": img_str,
            "type": "base64_png"
        }

    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return {"error": str(e)}
    
    finally:
        # Always unload the pipeline so each request starts with free VRAM;
        # model_name is None if the request failed before a model was chosen
        if model_name is not None:
            cleanup_pipeline(model_name)
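
# Example request payload (illustrative values; only "prompt" is required,
# the rest fall back to the defaults in handler()):
#
# {
#     "input": {
#         "prompt": "a photo of an astronaut riding a horse",
#         "model": "juggxl_ragnarok",
#         "negative_prompt": "blurry, low quality",
#         "steps": 30,
#         "width": 1024,
#         "height": 1024,
#         "guidance_scale": 7.5,
#         "seed": 42
#     }
# }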

if __name__ == "__main__":
    try:
        logger.info("==========================================")
        logger.info("Initializing RunPod Serverless Backend...")
        logger.info("==========================================")
        
        verify_models()
        verify_cuda()
        
        logger.info("Initialization complete. Starting server...")
        runpod.serverless.start({"handler": handler})
    except Exception as e:
        logger.error(f"Server initialization failed: {str(e)}")
        raise