import runpod
from runpod import RunPodLogger
from google.cloud import storage
import shutil
import os
import json
import uuid
import asyncio
from pathlib import Path
from datetime import timedelta
import psutil

# ----------------- Setup Runpod Logger -----------------
# Module-level logger shared by every function in this worker.
log = RunPodLogger()

# ----------------- GCS Setup -----------------
# The service-account key is passed as raw JSON in an env var (not a file path),
# so it is parsed here and handed to the client directly. Both env vars are
# validated at import time so the worker fails fast on misconfiguration
# instead of failing on the first job.
gcs_credentials_json = os.getenv("GCS_SERVICE_ACCOUNT_JSON")
if not gcs_credentials_json:
    raise ValueError("Missing GCS_SERVICE_ACCOUNT_JSON environment variable")

credentials_dict = json.loads(gcs_credentials_json)
client = storage.Client.from_service_account_info(credentials_dict)

bucket_name = os.getenv("GCS_BUCKET_NAME")
if not bucket_name:
    raise ValueError("Missing GCS_BUCKET_NAME environment variable")

# Bucket handle used by all download/upload/signed-URL helpers below.
# NOTE(review): bucket existence is not verified here — the first blob
# operation will surface a missing/inaccessible bucket.
bucket = client.bucket(bucket_name)

# ----------------- Helper Functions -----------------
def log_disk_usage():
    """Log the root filesystem's usage percentage and remaining free space in MB."""
    stats = psutil.disk_usage('/')
    free_mb = stats.free // (1024 * 1024)
    log.info(f"Disk Usage: {stats.percent}% used. Free: {free_mb} MB")

def cleanup_temp_files(files):
    """Best-effort deletion of the given Path objects.

    Each failure is logged as a warning and never raised, so cleanup of one
    file cannot prevent cleanup of the rest.
    """
    for path in files:
        try:
            path.unlink()
            log.info(f"Deleted {path}")
        except Exception as err:
            log.warning(f"Failed to delete {path}: {err}")

async def download_blob_to_file(job, blob_name, destination_path):
    """Download a GCS blob to a local file, retrying with exponential backoff.

    Args:
        job: RunPod job dict, used for progress updates on retries.
        blob_name: Name of the source blob in the module-level bucket.
        destination_path: pathlib.Path to write the downloaded file to.

    Raises:
        Exception: the last download error, once all attempts are exhausted.
    """
    tries = 3
    delay = 2  # seconds; doubled after each failed attempt (2s, 4s)
    for attempt in range(tries):
        try:
            log.info(f"Downloading {blob_name} to {destination_path}")
            blob = bucket.blob(blob_name)
            blob.download_to_filename(destination_path)
            log.info(f"Downloaded {destination_path} ({destination_path.stat().st_size // 1024} KB)")
            return
        except Exception as e:
            if attempt < tries - 1:
                msg = f"Download {blob_name} failed: {e}. Retrying..."
                log.warning(msg)
                runpod.serverless.progress_update(job, msg)
                await asyncio.sleep(delay)
                delay *= 2
            else:
                # Final attempt: log the permanent failure (previously only the
                # retries were logged) and re-raise with the original traceback.
                log.error(f"Download {blob_name} failed after {tries} attempts: {e}")
                raise

async def upload_file_to_gcs(job, file_path, destination_blob_name):
    """Upload a local file to GCS, retrying with exponential backoff.

    Args:
        job: RunPod job dict, used for progress updates on retries.
        file_path: Local filesystem path (str) of the file to upload.
        destination_blob_name: Target blob name in the module-level bucket.

    Returns:
        The blob's public URL string on success.

    Raises:
        Exception: the last upload error, once all attempts are exhausted.
    """
    tries = 3
    delay = 2  # seconds; doubled after each failed attempt (2s, 4s)
    for attempt in range(tries):
        try:
            log.info(f"Uploading {file_path} to {destination_blob_name}")
            blob = bucket.blob(destination_blob_name)
            blob.upload_from_filename(file_path)
            log.info(f"Uploaded to {destination_blob_name}")
            return blob.public_url
        except Exception as e:
            if attempt < tries - 1:
                msg = f"Upload {destination_blob_name} failed: {e}. Retrying..."
                log.warning(msg)
                runpod.serverless.progress_update(job, msg)
                await asyncio.sleep(delay)
                delay *= 2
            else:
                # Final attempt: log the permanent failure (previously only the
                # retries were logged) and re-raise with the original traceback.
                log.error(f"Upload {destination_blob_name} failed after {tries} attempts: {e}")
                raise

def generate_signed_url(blob_name, expiration_minutes=60):
    """Create a time-limited signed URL granting read access to a blob.

    Args:
        blob_name: Blob name in the module-level bucket.
        expiration_minutes: URL validity window in minutes (default 60).

    Returns:
        The signed URL string.

    Security note: a signed URL is a bearer credential — the original code
    logged it verbatim, giving anyone with log access access to the object.
    Only non-sensitive metadata is logged now.
    """
    log.info(f"Generating signed URL for {blob_name}")
    blob = bucket.blob(blob_name)
    url = blob.generate_signed_url(expiration=timedelta(minutes=expiration_minutes))
    log.info(f"Signed URL generated for {blob_name} (valid {expiration_minutes} min)")
    return url

# ----------------- Concurrency Modifier -----------------
def adjust_concurrency(current_concurrency):
    """Return the next worker concurrency level.

    Policy: step up by one while below the ceiling (5); otherwise step down
    by one while above the floor (1); otherwise leave unchanged.
    """
    ceiling = 5
    floor = 1
    log.info(f"Adjusting concurrency. Current: {current_concurrency}")
    if current_concurrency < ceiling:
        return current_concurrency + 1
    return current_concurrency - 1 if current_concurrency > floor else current_concurrency

# ----------------- Main Handler -----------------
async def handler(job):
    """Process one RunPod job: fetch video+audio from GCS, produce an output
    copy, upload it, and return a signed URL for the result.

    Expects job['input'] to contain 'video_path' and 'audio_path' (GCS blob
    names). Returns a dict with 'message' and 'signed_url' on success, or
    {'error': str} on failure.
    """
    # Populated once the per-job workspace exists so the finally block can
    # always clean up — the original only cleaned up on the success path,
    # leaking files in /tmp on any error across warm invocations.
    workspace = None
    temp_files = []
    try:
        log.info(f"Job started with input: {job.get('input')}")
        input_data = job['input']
        video_blob_name = input_data.get('video_path')
        audio_blob_name = input_data.get('audio_path')

        if not video_blob_name or not audio_blob_name:
            raise ValueError("Both video_path and audio_path must be provided.")

        # Per-job directory: the original wrote fixed names (/tmp/video.mp4,
        # /tmp/audio.m4a) that collide when adjust_concurrency allows several
        # jobs to run at once, corrupting each other's downloads.
        workspace = Path("/tmp") / f"job_{uuid.uuid4().hex}"
        workspace.mkdir(parents=True, exist_ok=True)

        video_local = workspace / "video.mp4"
        audio_local = workspace / "audio.m4a"
        output_filename = f"output_{uuid.uuid4().hex}.mp4"
        output_path = workspace / output_filename
        temp_files = [video_local, audio_local, output_path]

        log_disk_usage()

        runpod.serverless.progress_update(job, "Downloading video and audio in parallel...")
        await asyncio.gather(
            download_blob_to_file(job, video_blob_name, video_local),
            download_blob_to_file(job, audio_blob_name, audio_local)
        )

        log_disk_usage()

        # Placeholder "processing" step: the output is a straight copy of the
        # input video. Replace with real processing when available.
        runpod.serverless.progress_update(job, "Processing video (copying)...")
        log.info(f"Copying video file to {output_path}")
        shutil.copy(video_local, output_path)
        log.info(f"Copied output file size: {output_path.stat().st_size // 1024} KB")

        log_disk_usage()

        runpod.serverless.progress_update(job, "Uploading output video...")
        output_blob_name = f"light-test/{output_filename}"
        public_url = await upload_file_to_gcs(job, str(output_path), output_blob_name)

        log_disk_usage()

        runpod.serverless.progress_update(job, "Generating signed URL...")
        signed_url = generate_signed_url(output_blob_name)

        log.info("Job completed successfully.")
        return {
            "message": "Files processed, uploaded, and cleaned up successfully",
            "signed_url": signed_url
        }

    except Exception as e:
        # RunPodLogger is not a stdlib logger: its error() does not accept the
        # exc_info kwarg, so the original call raised TypeError here and
        # masked the real failure. Include the exception text directly.
        log.error(f"Handler error: {e}")
        return {"error": str(e)}

    finally:
        if workspace is not None:
            log.info("Cleaning up temporary files...")
            cleanup_temp_files(temp_files)
            # Remove the per-job directory itself; best-effort like the files.
            shutil.rmtree(workspace, ignore_errors=True)

# ----------------- Start Serverless Handler -----------------
# ----------------- Start Serverless Handler -----------------
# Registers the async job handler and the concurrency-adjustment callback
# with the RunPod serverless runtime. NOTE(review): start() is expected to
# run the worker's polling loop, so nothing below this call would execute.
runpod.serverless.start({
    "handler": handler,
    "concurrency_modifier": adjust_concurrency,
})
