import base64
import os
import shutil
import stat
import time

import runpod
import torch
from pruna import PrunaModel
from runpod.serverless.utils import rp_cleanup, rp_upload
from runpod.serverless.utils.rp_validator import validate

from schemas import INPUT_SCHEMA


def restructure_to_hf_cache(source_base="/runpod/model-store/huggingface",
                            cache_dir="/root/.cache/huggingface"):
    """
    Restructure models from a custom directory into HF hub-cache format.

    The source layout is assumed to be ``<source_base>/<org>/<model>/<commit>/...``
    and each commit directory is exposed to huggingface_hub as
    ``<cache_dir>/hub/models--<Org>--<Model>/snapshots/<commit>`` plus a
    ``refs/main`` file pointing at the commit.

    If HF_MODEL is set (e.g. ``PrunaAI/FLUX.1-dev-smashed``) we only attempt to
    restructure that specific model and preserve its exact capitalization in the
    HF cache name (``models--PrunaAI--FLUX.1-dev-smashed``); the on-disk org and
    model directories are matched case-insensitively.

    Parameters:
        source_base: root of the baked-in model store to read from.
        cache_dir:   HF cache root to populate (``hub/`` is created beneath it).

    Returns None; all outcomes (including errors) are reported via debug prints
    and per-model failures are swallowed so one bad model does not stop others.
    """
    # Debug output defaults ON; disable with DEBUG_RESTRUCTURE=false/0/no.
    DEBUG_RESTRUCTURE = os.environ.get("DEBUG_RESTRUCTURE", "true").lower() in ("true", "1", "yes")
    hf_model_env = os.environ.get("HF_MODEL")  # e.g. PrunaAI/FLUX.1-dev-smashed

    def debug_print(msg):
        # flush=True so messages appear immediately in serverless logs
        if DEBUG_RESTRUCTURE:
            print(msg, flush=True)

    def find_dir_case_insensitive(parent, target_name):
        """Return the directory name in parent matching target_name
        (case-insensitive), or None."""
        if not os.path.exists(parent):
            return None
        # Exact match first (cheap, and preserves the common case).
        candidate = os.path.join(parent, target_name)
        if os.path.isdir(candidate):
            return target_name
        target_lower = target_name.lower()
        for entry in os.listdir(parent):
            if os.path.isdir(os.path.join(parent, entry)) and entry.lower() == target_lower:
                return entry
        return None

    debug_print(f"[DEBUG] Restructuring models from {source_base} to HF cache format...")
    debug_print(f"[DEBUG] Target cache directory: {cache_dir}")
    debug_print(f"[DEBUG] Source base exists: {os.path.exists(source_base)}")
    if not os.path.exists(source_base):
        debug_print("[DEBUG] Source base missing, nothing to restructure.")
        return

    # Each entry: (on-disk org dir, on-disk model dir, exact org, exact model).
    # The "exact" pair preserves HF_MODEL capitalization in the cache name.
    target_pairs = []
    if hf_model_env:
        # If HF_MODEL provided, scope to that org/model only.
        try:
            target_org, target_model = hf_model_env.split("/", 1)
        except Exception:
            debug_print(f"[DEBUG] HF_MODEL value malformed: {hf_model_env}")
            return
        debug_print(f"[DEBUG] HF_MODEL detected: org={target_org}, model={target_model}")
        # Find actual directory names on disk (case-insensitive).
        found_org = find_dir_case_insensitive(source_base, target_org)
        if not found_org:
            debug_print(f"[DEBUG] Organization directory not found for '{target_org}' in {source_base}")
            return
        org_path = os.path.join(source_base, found_org)
        found_model = find_dir_case_insensitive(org_path, target_model)
        if not found_model:
            # Sometimes the model folder may be lowercase or slightly different;
            # fall back to any directory containing the model name as a substring.
            for c in os.listdir(org_path):
                if os.path.isdir(os.path.join(org_path, c)) and target_model.lower() in c.lower():
                    found_model = c
                    break
        if not found_model:
            debug_print(f"[DEBUG] Model directory not found for '{target_model}' under {org_path}")
            return
        debug_print(f"[DEBUG] Found on-disk org: {found_org}, model: {found_model}")
        target_pairs.append((found_org, found_model, target_org, target_model))
    else:
        # No HF_MODEL specified: process all orgs/models found on disk.
        for org_name in os.listdir(source_base):
            org_path = os.path.join(source_base, org_name)
            if not os.path.isdir(org_path):
                continue
            for model_name in os.listdir(org_path):
                model_path = os.path.join(org_path, model_name)
                if not os.path.isdir(model_path):
                    continue
                target_pairs.append((org_name, model_name, org_name, model_name))

    for on_disk_org, on_disk_model, hf_org_exact, hf_model_exact in target_pairs:
        org_path = os.path.join(source_base, on_disk_org)
        model_path = os.path.join(org_path, on_disk_model)
        debug_print(f"[DEBUG] Processing organization: {on_disk_org}")
        debug_print(f"[DEBUG] Organization path: {org_path}")
        debug_print(f"[DEBUG] Processing model: {on_disk_model}")
        debug_print(f"[DEBUG] Model path: {model_path}")
        try:
            # Each immediate subdirectory of the model dir is treated as a
            # commit hash (the snapshot name huggingface_hub will look up).
            commit_dirs = [d for d in os.listdir(model_path)
                           if os.path.isdir(os.path.join(model_path, d))]
            debug_print(f"[DEBUG] Found commit directories: {commit_dirs}")
            if not commit_dirs:
                debug_print(f"[DEBUG] No commit directories found in {model_path}")
                continue
            for commit_hash in commit_dirs:
                source_model_path = os.path.join(model_path, commit_hash)
                hf_name = f"{hf_org_exact}--{hf_model_exact}"
                dest_model_dir = os.path.join(cache_dir, "hub", f"models--{hf_name}")
                dest_snapshot_dir = os.path.join(dest_model_dir, "snapshots", commit_hash)
                debug_print(f"[DEBUG] Destination model directory: {dest_model_dir}")
                debug_print(f"[DEBUG] Destination snapshot directory: {dest_snapshot_dir}")
                # Create .../snapshots (the parent) — the snapshot itself is a link.
                os.makedirs(os.path.dirname(dest_snapshot_dir), exist_ok=True)
                if os.path.exists(dest_snapshot_dir):
                    debug_print(f"[DEBUG] Destination already exists: {dest_snapshot_dir}")
                    continue
                debug_print(f"[DEBUG] Linking {source_model_path} -> {dest_snapshot_dir}")
                # Prefer a symlink (no extra disk usage); fall back to a full copy.
                try:
                    os.symlink(source_model_path, dest_snapshot_dir)
                except Exception as e_symlink:
                    debug_print(f"[DEBUG] Symlink failed ({e_symlink}), falling back to copytree")
                    try:
                        shutil.copytree(source_model_path, dest_snapshot_dir)
                    except Exception as e_copy:
                        debug_print(f"[DEBUG] copytree also failed: {e_copy}")
                        raise
                # refs/main tells huggingface_hub which snapshot is current.
                refs_dir = os.path.join(dest_model_dir, "refs")
                os.makedirs(refs_dir, exist_ok=True)
                refs_main_path = os.path.join(refs_dir, "main")
                debug_print(f"[DEBUG] Creating refs/main at: {refs_main_path}")
                with open(refs_main_path, "w") as f:
                    f.write(commit_hash)
                debug_print(f"[DEBUG] ✓ Successfully linked model: {hf_name}")
        except Exception as e:
            # Best-effort: log and keep going so one broken model does not
            # prevent the others from being linked.
            debug_print(f"[DEBUG] Error processing {on_disk_org}/{on_disk_model}: {e}")
            if DEBUG_RESTRUCTURE:
                import traceback
                traceback.print_exc()
            continue

    debug_print("[DEBUG] Restructuring complete!")
    debug_print(f"[DEBUG] Final cache directory exists: {os.path.exists(cache_dir)}")
    hub_dir = os.path.join(cache_dir, "hub")
    if os.path.exists(hub_dir):
        debug_print(f"[DEBUG] Hub directory contents: {os.listdir(hub_dir)}")


# --- Update: default HF model used when env is not set --
# In ModelHandler.load_models replace the default string with the new default (optional)
# Example change inside ModelHandler.load_models:
#     os.environ.get("HF_MODEL", "PrunaAI/FLUX.1-dev-smashed")
# (No other changes required if HF_MODEL will be set externally.)

# Populate the local HF cache from the baked-in model store, then free any
# GPU memory held from earlier imports before loading the pipeline.
restructure_to_hf_cache()
torch.cuda.empty_cache()


class ModelHandler:
    """Loads and holds the Pruna-smashed FLUX pipeline for the worker."""

    def __init__(self):
        # pipe: loaded model pipeline; populated by load_models().
        self.pipe = None
        self.load_models()

    def load_models(self):
        """Load the FLUX pipeline strictly from the local HF cache and move it
        to CUDA. Raises on failure after a long sleep for log inspection.

        NOTE(review): nearby comments reference FLUX.1-dev, but the default
        model here is FLUX.1-schnell-smashed — confirm which is intended.
        """
        try:
            self.pipe = PrunaModel.from_hub(
                os.environ.get("HF_MODEL", "PrunaAI/FLUX.1-schnell-smashed"),
                local_files_only=True,  # never hit the network at load time
            )
            self.pipe.move_to_device("cuda")
        except Exception:
            # Was a bare `except:` — that would also trap KeyboardInterrupt /
            # SystemExit during the 500 s sleep. Catch Exception instead.
            print("[ERROR] Failed to load model from local cache.", flush=True)
            time.sleep(500)  # Sleep to allow log inspection
            raise


MODELS = ModelHandler()


def _save_and_upload_images(images, job_id):
    """Save PIL images under /<job_id> and return their URLs.

    If BUCKET_ENDPOINT_URL is configured the images are uploaded and the
    bucket URLs returned; otherwise each image is returned as a base64
    data-URL. The temporary /<job_id> directory is removed afterwards.
    """
    os.makedirs(f"/{job_id}", exist_ok=True)
    image_urls = []
    for index, image in enumerate(images):
        image_path = os.path.join(f"/{job_id}", f"{index}.png")
        image.save(image_path)
        if os.environ.get("BUCKET_ENDPOINT_URL", False):
            image_url = rp_upload.upload_image(job_id, image_path)
            image_urls.append(image_url)
        else:
            with open(image_path, "rb") as image_file:
                image_data = base64.b64encode(image_file.read()).decode("utf-8")
                image_urls.append(f"data:image/png;base64,{image_data}")
    rp_cleanup.clean([f"/{job_id}"])
    return image_urls


@torch.inference_mode()
def generate_image(job):
    """
    Generate an image from text using FLUX.1-dev Model
    """
    # -------------------------------------------------------------------------
    # 🐞 DEBUG LOGGING
    # -------------------------------------------------------------------------
    import json
    import pprint

    # Log the exact structure RunPod delivers so we can see every nesting level.
    print("[generate_image] RAW job dict:")
    try:
        print(json.dumps(job, indent=2, default=str), flush=True)
    except Exception:
        pprint.pprint(job, depth=4, compact=False)

    # -------------------------------------------------------------------------
    # Original (strict) behaviour – assume the expected single wrapper exists.
    # -------------------------------------------------------------------------
    job_input = job["input"]
    print("[generate_image] job['input'] payload:")
    try:
        print(json.dumps(job_input, indent=2, default=str), flush=True)
    except Exception:
        pprint.pprint(job_input, depth=4, compact=False)

    # Input validation
    try:
        validated_input = validate(job_input, INPUT_SCHEMA)
    except Exception as err:
        import traceback
        print("[generate_image] validate(...) raised an exception:", err, flush=True)
        traceback.print_exc()
        # Re-raise so RunPod registers the failure (but logs are now visible).
        raise

    print("[generate_image] validate(...) returned:")
    try:
        print(json.dumps(validated_input, indent=2, default=str), flush=True)
    except Exception:
        pprint.pprint(validated_input, depth=4, compact=False)

    if "errors" in validated_input:
        return {"error": validated_input["errors"]}
    job_input = validated_input["validated_input"]

    if job_input["seed"] is None:
        # NOTE(review): 2 random bytes limit the seed to 0..65535 — confirm
        # whether a wider seed range is wanted.
        job_input["seed"] = int.from_bytes(os.urandom(2), "big")

    # Create generator with proper device handling.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device).manual_seed(job_input["seed"])

    try:
        # Generate images with the FLUX pipeline. The @torch.inference_mode()
        # decorator already covers this call (the previous inner
        # `with torch.inference_mode():` was redundant and has been removed).
        result = MODELS.pipe(
            prompt=job_input["prompt"],
            negative_prompt=job_input["negative_prompt"],
            height=job_input["height"],
            width=job_input["width"],
            num_inference_steps=job_input["num_inference_steps"],
            guidance_scale=job_input["guidance_scale"],
            num_images_per_prompt=job_input["num_images"],
            generator=generator,
        )
        output = result.images
    except RuntimeError as err:
        # Typically OOM or CUDA faults — ask RunPod to recycle the worker.
        print(f"[ERROR] RuntimeError in generation pipeline: {err}", flush=True)
        return {
            "error": f"RuntimeError: {err}, Stack Trace: {err.__traceback__}",
            "refresh_worker": True,
        }
    except Exception as err:
        print(f"[ERROR] Unexpected error in generation pipeline: {err}", flush=True)
        return {
            "error": f"Unexpected error: {err}",
            "refresh_worker": True,
        }

    image_urls = _save_and_upload_images(output, job["id"])

    results = {
        "images": image_urls,
        "image_url": image_urls[0],  # convenience single-image field
        "seed": job_input["seed"],
    }
    return results


runpod.serverless.start({"handler": generate_image})