# handler.py (synchronous, with log streaming and liveness checks)
import os, io, re, time, base64, requests, mimetypes
from pathlib import Path
from urllib.parse import urlparse

import certifi, urllib3
from minio import Minio
from minio.error import S3Error
import runpod

print("--- CUSTOM handler.py SCRIPT IS BEING EXECUTED ---")
print("--- If you see this, start.sh successfully called this script. ---")

# --- ComfyUI settings ---
COMFY_API = os.getenv("COMFY_API", "http://127.0.0.1:8188")

# Work out the filesystem base for Comfy
COMFY_BASE = os.getenv("COMFY_BASE_DIR") or (
    "/runpod-volume/ComfyUI" if Path("/runpod-volume/ComfyUI").exists()
    else "/workspace/ComfyUI"
)
OUT_DIR = Path(os.getenv("COMFY_OUTPUT_DIR", str(Path(COMFY_BASE) / "output")))
IN_DIR  = Path(os.getenv("COMFY_INPUT_DIR",  str(Path(COMFY_BASE) / "input")))
COMFY_LOG = Path(os.getenv("COMFY_LOG_PATH", str(Path(COMFY_BASE) / "user" / "comfyui.log")))

# File types we’ll upload
IMG = {".png", ".jpg", ".jpeg", ".webp"}
VID = {".mp4", ".webm", ".gif", ".mov", ".mkv"}
ALLOWED = IMG | VID

# Timeouts / behavior
HTTP_TIMEOUT     = int(os.getenv("HTTP_TIMEOUT", "15"))
READY_TIMEOUT    = int(os.getenv("COMFY_READY_TIMEOUT", "180"))
READY_INTERVAL   = float(os.getenv("COMFY_READY_INTERVAL", "2"))
HIST_TIMEOUT     = int(os.getenv("COMFY_HISTORY_TIMEOUT", "900")) # <<< MODIFIED: Increased for video
SCAN_TIMEOUT     = int(os.getenv("OUTPUT_SCAN_TIMEOUT", "300"))
SCAN_INTERVAL    = float(os.getenv("OUTPUT_SCAN_INTERVAL", "2"))
FLATTEN          = os.getenv("BUCKET_FLATTEN", "").strip() == "1"
DO_R2_HEALTH     = os.getenv("R2_HEALTHCHECK", "0") == "1"

# Pull latest NetDist session id from comfyui.log
SESSION_RE = re.compile(r"NetDist: Set session ID to '([a-z0-9]+)'", re.I)

# ---------- R2 / MinIO ----------
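# Build a MinIO client for the endpoint in BUCKET_ENDPOINT_URL, with TLS
# verification pinned to certifi's CA bundle and urllib3 retries disabled.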
def minio_client():
    ep_raw = os.environ.get("BUCKET_ENDPOINT_URL", "").strip()
    host = urlparse(ep_raw).netloc or urlparse("https://"+ep_raw).netloc
    if not host:
        raise RuntimeError("BUCKET_ENDPOINT_URL is missing or invalid (no host)")
    # No need to print, it's in the final metadata
    httpc = urllib3.PoolManager(
        cert_reqs="CERT_REQUIRED",
        ca_certs=certifi.where(),
        retries=False,
    )
    return Minio(
        host,
        access_key=os.environ["BUCKET_ACCESS_KEY_ID"].strip(),
        secret_key=os.environ["BUCKET_SECRET_ACCESS_KEY"].strip(),
        secure=True,
        region="auto",
        http_client=httpc,
    )

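# Upload one file under obj_key, confirm it landed with stat_object, and return its public URL.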
def upload_object(obj_key: str, f: Path, client: Minio):
    bucket = os.environ["BUCKET_NAME"].strip()
    public = os.environ["BUCKET_PUBLIC_DOMAIN"].rstrip("/")
    ctype = mimetypes.guess_type(str(f))[0] or "application/octet-stream"

    client.fput_object(
        bucket, obj_key, str(f),
        content_type=ctype,
        metadata={"Cache-Control": os.getenv("BUCKET_CACHE_CONTROL", "public, max-age=31536000, immutable")}
    )
    client.stat_object(bucket, obj_key)
    return f"{public}/{obj_key}"

# ---------- ComfyUI helpers ----------
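# ComfyUI counts as ready once /queue answers with a JSON response.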
def comfy_ready():
    try:
        r = requests.get(f"{COMFY_API}/queue", timeout=HTTP_TIMEOUT)
        return r.ok and r.headers.get("content-type","").startswith("application/json")
    except Exception:
        return False

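# Poll comfy_ready() until it succeeds or READY_TIMEOUT elapses; returns (ok, seconds_waited).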
def wait_ready():
    start = time.time()
    while time.time() - start < READY_TIMEOUT:
        if comfy_ready():
            return True, time.time() - start
        time.sleep(READY_INTERVAL)
    return False, time.time() - start

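# Fetch /history/<prompt_id> and return just this prompt's entry (None on any failure).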
def history(prompt_id):
    try:
        r = requests.get(f"{COMFY_API}/history/{prompt_id}", timeout=HTTP_TIMEOUT)
        if not r.ok:
            return None
        data = r.json()
        return data.get(prompt_id) if isinstance(data, dict) else data
    except Exception:
        return None

# <<< NEW: Generator to stream comfyui.log file
def stream_log(interval=1):
    if not COMFY_LOG.exists():
        yield "[Livelog] Log file not found at start."
        # Don't end the generator: keep yielding heartbeats (None) so the caller
        # can still poll history and enforce its deadline until the log appears.
        while not COMFY_LOG.exists():
            yield None
            time.sleep(interval)

    with open(COMFY_LOG, 'r', encoding='utf-8', errors='ignore') as f:
        f.seek(0, 2) # Go to the end of the file
        log_pos = f.tell()

        while True:
            f.seek(log_pos)
            line = f.readline()
            if line:
                yield f"[Livelog] {line.strip()}"
                log_pos = f.tell()
            else:
                # No new output: yield a heartbeat so the caller keeps polling,
                # and only sleep while the log is idle.
                yield None
                time.sleep(interval)

# <<< MODIFIED: Now takes a log_streamer to yield from
def wait_done(prompt_id, log_streamer):
    deadline = time.time() + HIST_TIMEOUT

    # Interleave log streaming and history polling
    while True:
        # Pull the next log line; a None heartbeat (or an exhausted streamer)
        # means there is nothing new to surface, so pause briefly before polling.
        log_line = next(log_streamer, None)
        if log_line is not None:
            yield log_line
        else:
            time.sleep(1)

        if time.time() >= deadline:
            return "timeout", {"message": "History polling timed out"}

        # <<< NEW: Liveness check
        if not comfy_ready():
            return "error", {"message": "ComfyUI became unresponsive during the job."}

        h = history(prompt_id)
        if h:
            st = h.get("status", {})
            exec_info = (st.get("exec_info") or {})
            if exec_info.get("error"):
                err = exec_info["error"]
                if isinstance(err, dict):
                    msg = err.get("message", str(err))
                else:
                    msg = str(err)
                return "error", {"history": h, "message": msg}
            # ComfyUI marks a finished prompt with a truthy "completed" flag; the old
            # isinstance(..., int) check also matched False/0 and could fire too early.
            if st.get("completed"):
                return "completed", {"history": h}

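# Yield media files under `root` modified at or after since_ts, skipping hidden directories.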
def iter_new(since_ts: float, root: Path):
    if not root.exists():
        return
    for dp, dn, fn in os.walk(root):
        dn[:] = [d for d in dn if not d.startswith(".")]
        for name in fn:
            p = Path(dp) / name
            if p.suffix.lower() in ALLOWED:
                try:
                    if p.is_file() and p.stat().st_mtime >= since_ts:
                        yield p
                except FileNotFoundError:
                    pass

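# Return the most recent NetDist session id found in the tail (last 128 KiB) of comfyui.log, if any.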
def read_netdist_id() -> str | None:
    try:
        if COMFY_LOG.exists():
            data = COMFY_LOG.read_bytes()
            tail = data[-131072:] if len(data) > 131072 else data
            text = tail.decode("utf-8", "ignore")
            last = None
            for m in SESSION_RE.finditer(text):
                last = m.group(1)
            return last
    except Exception:
        pass
    return None

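# Object keys are <prompt_id>[_<netdist_id>], with a zero-padded index appended when
# the job produced more than one file.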
def make_object_key(idx: int, total: int, src: Path, prompt_id: str, netdist_id: str | None):
    ext = src.suffix.lower()
    base = prompt_id + (f"_{netdist_id}" if netdist_id else "")
    return f"{base}{ext}" if total == 1 else f"{base}_{idx:05d}{ext}"

# ---------- Runpod handler ----------
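# Expected job payload (shape inferred from the lookups in handler() below; these are
# simply the fields this handler reads, not an official schema), e.g.:
# {
#   "input": {
#     "workflow": { ... ComfyUI API-format prompt ... },
#     "images": [{"name": "ref.png", "image": "<base64 or data: URI>"}]
#   }
# }
# A top-level "prompt" key is also accepted in place of input.workflow.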
# <<< MODIFIED: to be a generator function
def handler(job):
    # Optional R2 preflight
    if DO_R2_HEALTH:
        try:
            c = minio_client()
            c.put_object(
                os.environ["BUCKET_NAME"], f"sanity/health-{int(time.time())}.txt",
                io.BytesIO(b"1"), length=1, content_type="text/plain",
                metadata={"Cache-Control":"public, max-age=60"}
            )
        except Exception as e:
            yield {"error": "R2 health check failed", "details": str(e)}
            return

    wf = job.get("prompt") or job.get("input", {}).get("workflow")
    if not wf:
        yield {"error": "No workflow provided"}
        return

    IN_DIR.mkdir(parents=True, exist_ok=True)
    for img in job.get("input", {}).get("images", []):
        name = Path(img.get("name") or "").name
        data = img.get("image")
        if not name or not data:
            continue
        if data.startswith("data:"):
            data = data.split("base64,", 1)[-1]
        try:
            (IN_DIR / name).write_bytes(base64.b64decode(data))
        except Exception as e:
            yield {"error": "Invalid base64 image", "details": str(e)}
            return

    ok, waited_ready = wait_ready()
    if not ok:
        yield {"error": "ComfyUI not ready", "details": {"waited_seconds": round(waited_ready, 2)}}
        return

    started = time.time()
    try:
        r = requests.post(f"{COMFY_API}/prompt", json=wf, timeout=HTTP_TIMEOUT)
    except Exception as e:
        yield {"error": "Failed to reach ComfyUI /prompt", "details": str(e)}
        return
    if not r.ok:
        yield {"error": "Failed to submit workflow", "details": r.text}
        return
    prompt_id = (r.json() or {}).get("prompt_id") or "unknown"

    # <<< MODIFIED: Main loop now yields log updates before getting the final status
    log_streamer = stream_log()
    status, details = yield from wait_done(prompt_id, log_streamer)

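    # A "timeout" status is not fatal here: we still fall through to the output scan
    # below as a last chance to pick up any media that was written.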
    if status == "error":
        yield {"error": "ComfyUI reported an error", "details": {"prompt_id": prompt_id, **details}}
        return

    waited_scan = 0
    found = []
    while waited_scan < SCAN_TIMEOUT:
        found = list(iter_new(started, OUT_DIR))
        if found:
            break
        time.sleep(SCAN_INTERVAL)
        waited_scan += SCAN_INTERVAL

    if not found:
        yield {
            "error": "No media produced",
            "details": {"waited_seconds": waited_scan, "prompt_id": prompt_id, "history_status": status},
        }
        return

    netdist_id = read_netdist_id()
    results = []
    client = minio_client()
    total = len(found)
    for idx, f in enumerate(sorted(found)):
        try:
            obj_key = make_object_key(idx, total, f, prompt_id, netdist_id)
            url = upload_object(obj_key, f, client)
            kind = "video" if f.suffix.lower() in VID else "image"
            results.append({
                "filename": f.name,
                "object_key": obj_key,
                "kind": kind,
                "type": "s3_url",
                "data": url
            })
        except S3Error as e:
            results.append({
                "filename": f.name,
                "type": "error",
                "error": f"S3Error {e.code}: {getattr(e,'message',str(e))}"
            })
        except Exception as e:
            results.append({"filename": f.name, "type": "error", "error": str(e)})
        finally:
            try:
                f.unlink()
            except Exception:
                pass

    ep = os.environ.get("BUCKET_ENDPOINT_URL", "")
    yield {
        "images": results,
        "meta": {
            "prompt_id": prompt_id,
            "files_found": total,
            "scanned_dir": str(OUT_DIR),
            "input_dir": str(IN_DIR),
            "ready_wait_seconds": round(waited_ready, 2),
            "history_status": status,
            "flatten": FLATTEN,
            "endpoint_host": urlparse(ep).netloc or urlparse("https://"+ep).netloc,
            "bucket": os.environ.get("BUCKET_NAME", "").strip(),
            "public_domain": os.environ.get("BUCKET_PUBLIC_DOMAIN", "").strip(),
            "netdist_id": netdist_id,
        },
    }

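# handler is a generator, so each yielded dict is streamed as a partial update.
# (Assumption: if the final yield should also be returned as the aggregate job
# output, runpod's "return_aggregate_stream" start option may be required.)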
if __name__ == "__main__":
    runpod.serverless.start({"handler": handler})
