#!/usr/bin/env python3
"""
RunPod Worker Health Monitor

This script monitors the health of RunPod serverless workers by:
1. Fetching worker information from RunPod API
2. Extracting ping URLs for each worker
3. Pinging each worker directly to check health
4. Running continuously at specified intervals

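Usage:
    python monitor.py --endpoint-id YOUR_ENDPOINT_ID --api-key YOUR_API_KEY --interval 30
    RUNPOD_API_KEY=YOUR_API_KEY python monitor.py --endpoint-id YOUR_ENDPOINT_ID --once

Pass --once to run a single health check instead of looping. Setting the
RUNPOD_API_KEY environment variable keeps the key out of the process list.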
"""

import os
import sys
import time
import json
import requests
import argparse
from datetime import datetime
from typing import List, Dict, Optional
from dataclasses import dataclass
from urllib.parse import urlparse


@dataclass
class Worker:
    """A RunPod worker plus the outcome of its most recent health ping.

    The first eight fields are copied from the RunPod REST API response;
    the ``ping_*`` fields are filled in by ``RunPodHealthMonitor.ping_worker``.
    """
    id: str
    machine_id: str
    # Desired state reported by the API, e.g. "RUNNING" or "EXITED".
    desired_status: str
    # Ping URL with template variables substituted ("" if none was provided).
    ping_url: str
    cost_per_hr: float
    memory_gb: int
    vcpu_count: int
    last_started_at: str
    # HEALTHY / UNHEALTHY / TIMEOUT / CONNECTION_ERROR / ERROR, or None if the
    # worker has not been pinged.
    ping_status: Optional[str] = None
    # Duration of the ping request in seconds (None if no response arrived).
    ping_response_time: Optional[float] = None
    # Human-readable reason for a non-HEALTHY status.
    ping_error: Optional[str] = None


class RunPodHealthMonitor:
    """Monitor RunPod serverless worker health.

    One health check fetches the endpoint's workers from the RunPod REST API,
    pings every worker whose desired status is RUNNING, and prints a report.
    """

    # Display symbol per ping status; missing/unknown statuses get "❓".
    _STATUS_EMOJI = {
        "HEALTHY": "✅",
        "UNHEALTHY": "❌",
        "TIMEOUT": "⏰",
        "CONNECTION_ERROR": "🔌",
        "ERROR": "⚠️",
    }

    def __init__(self, endpoint_id: str, api_key: str, timeout: int = 10):
        """Create a monitor.

        Args:
            endpoint_id: RunPod serverless endpoint whose workers are checked.
            api_key: RunPod API key, sent as a Bearer token to the REST API.
            timeout: Timeout in seconds for the API request and for each ping.
        """
        self.endpoint_id = endpoint_id
        self.api_key = api_key
        self.timeout = timeout
        self.base_url = "https://rest.runpod.io/v1"

    def _fetch_workers(self) -> List[Worker]:
        """Fetch and parse the endpoint's workers, raising on failure.

        Raises:
            requests.RequestException: network or HTTP error status.
            ValueError: the body is not JSON (older ``requests`` versions
                raise a plain ``ValueError``) or not a JSON object.
        """
        response = requests.get(
            f"{self.base_url}/endpoints/{self.endpoint_id}",
            headers={"Authorization": f"Bearer {self.api_key}"},
            params={"includeWorkers": "true"},
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict):
            raise ValueError("Unexpected API response: expected a JSON object")
        # The key may be missing or null when the endpoint has no workers.
        worker_list = data.get('workers') or []
        return [self._parse_worker(w) for w in worker_list if isinstance(w, dict)]

    @staticmethod
    def _parse_worker(worker_data: Dict) -> Worker:
        """Build a Worker from one entry of the API's ``workers`` list.

        JSON nulls get the same defaults as missing keys, so later string
        operations and column formatting never see ``None``.
        """
        def value_of(key: str, default):
            value = worker_data.get(key)
            return default if value is None else value

        worker_id = value_of('id', '')
        # The ping URL comes from the worker's environment variables.
        env = value_of('env', {})
        ping_url_template = env.get('RUNPOD_WEBHOOK_PING') or ''
        # Substitute the template variables; the GPU type is not available in
        # the worker data, so a placeholder is used.
        ping_url = (
            ping_url_template
            .replace('$RUNPOD_POD_ID', worker_id)
            .replace('$RUNPOD_GPU_TYPE_ID', 'unknown')
        )
        return Worker(
            id=worker_id,
            machine_id=value_of('machineId', ''),
            desired_status=value_of('desiredStatus', ''),
            ping_url=ping_url,
            cost_per_hr=value_of('costPerHr', 0.0),
            memory_gb=value_of('memoryInGb', 0),
            vcpu_count=value_of('vcpuCount', 0),
            last_started_at=value_of('lastStartedAt', ''),
        )

    def get_workers(self) -> List[Worker]:
        """Fetch worker information from the RunPod API.

        Returns an empty list (after printing the error) if the request or
        response parsing fails.
        """
        try:
            return self._fetch_workers()
        except (requests.RequestException, ValueError) as e:
            print(f"Error fetching workers: {e}")
            return []

    def ping_worker(self, worker: Worker) -> None:
        """Ping *worker* once and record the outcome on it in place.

        Results from any earlier ping of the same object are cleared first.
        Sets ``ping_status`` to HEALTHY (HTTP 200), UNHEALTHY (other HTTP
        status), TIMEOUT, CONNECTION_ERROR or ERROR, and fills ``ping_error``
        and ``ping_response_time`` where applicable.
        """
        worker.ping_status = None
        worker.ping_response_time = None
        worker.ping_error = None

        if not worker.ping_url:
            worker.ping_status = "ERROR"
            worker.ping_error = "No ping URL available"
            return
        if '$' in worker.ping_url:
            # A '$' left over after substitution means an unresolved variable.
            worker.ping_status = "ERROR"
            worker.ping_error = "Invalid ping URL (contains template variables)"
            return

        try:
            # perf_counter is monotonic, unlike time.time().
            start = time.perf_counter()
            response = requests.get(worker.ping_url, timeout=self.timeout)
            worker.ping_response_time = time.perf_counter() - start
        except requests.exceptions.Timeout:
            # Listed first: ConnectTimeout is both a Timeout and a ConnectionError.
            worker.ping_status = "TIMEOUT"
            worker.ping_error = f"Timeout after {self.timeout}s"
        except requests.exceptions.ConnectionError as e:
            worker.ping_status = "CONNECTION_ERROR"
            worker.ping_error = str(e)
        except requests.RequestException as e:
            worker.ping_status = "ERROR"
            worker.ping_error = str(e)
        else:
            if response.status_code == 200:
                worker.ping_status = "HEALTHY"
            else:
                worker.ping_status = "UNHEALTHY"
                worker.ping_error = f"HTTP {response.status_code}"

    def monitor_workers(self) -> List[Worker]:
        """Run one health check: fetch workers, ping RUNNING ones, print a report.

        Returns the fetched workers with ping results filled in. Returns an
        empty list when the endpoint has no workers or the API request failed;
        the two cases are reported differently.
        """
        fetch_error = None
        try:
            workers = self._fetch_workers()
        except (requests.RequestException, ValueError) as e:
            workers, fetch_error = [], e

        print(f"\n{'='*80}")
        print(f"Health Check - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        if fetch_error is not None:
            # Don't report "0 workers": that would make an API outage look
            # like an idle endpoint.
            print(f"❌ Error fetching workers: {fetch_error}")
        else:
            print(f"Found {len(workers)} workers")
        print(f"{'='*80}")

        if fetch_error is not None:
            return []
        if not workers:
            print("No workers found!")
            return []

        for worker in workers:
            if worker.desired_status == "RUNNING":
                self.ping_worker(worker)

        self.display_results(workers)
        return workers

    def display_results(self, workers: List[Worker]) -> None:
        """Print the health report.

        The report has a table of RUNNING workers, lists of EXITED and other
        workers, and a healthy/unhealthy summary of the RUNNING workers.
        """
        running_workers = [w for w in workers if w.desired_status == "RUNNING"]
        exited_workers = [w for w in workers if w.desired_status == "EXITED"]
        other_workers = [
            w for w in workers if w.desired_status not in ("RUNNING", "EXITED")
        ]

        if running_workers:
            print(f"\n🟢 RUNNING WORKERS ({len(running_workers)}):")
            print(f"{'ID':<15} {'Status':<15} {'Response Time':<15} {'Memory':<10} {'vCPU':<6} {'Cost/hr':<8} {'Error'}")
            print("-" * 90)

            for worker in running_workers:
                # `is not None`: a 0.0s response time is a real measurement.
                if worker.ping_response_time is not None:
                    response_time_str = f"{worker.ping_response_time:.3f}s"
                else:
                    response_time_str = "N/A"
                status = worker.ping_status or "UNKNOWN"
                status_emoji = self._STATUS_EMOJI.get(worker.ping_status, "❓")

                error_str = worker.ping_error or ""
                if len(error_str) > 30:
                    error_str = error_str[:30] + "..."

                # Pad the whole "<n>GB" cell so the column aligns for any width.
                memory_str = f"{worker.memory_gb}GB"
                print(f"{worker.id:<15} {status_emoji} {status:<13} {response_time_str:<15} "
                      f"{memory_str:<10} {worker.vcpu_count:<6} ${worker.cost_per_hr:<7} {error_str}")

        if exited_workers:
            print(f"\n🔴 EXITED WORKERS ({len(exited_workers)}):")
            for worker in exited_workers:
                print(f"  {worker.id} (Last started: {worker.last_started_at})")

        if other_workers:
            # Any other desired status is listed so every fetched worker is shown.
            print(f"\n⚪ OTHER WORKERS ({len(other_workers)}):")
            for worker in other_workers:
                print(f"  {worker.id} (Desired status: {worker.desired_status or 'UNKNOWN'})")

        if running_workers:
            unhealthy_ids = [w.id for w in running_workers if w.ping_status != "HEALTHY"]
            healthy_count = len(running_workers) - len(unhealthy_ids)
            print(f"\n📊 SUMMARY: {healthy_count}/{len(running_workers)} workers healthy")
            if unhealthy_ids:
                print(f"⚠️  UNHEALTHY WORKERS: {unhealthy_ids}")

    def run_continuous(self, interval: int) -> None:
        """Run a health check every *interval* seconds until Ctrl+C.

        An unexpected error in one check is reported and the loop continues,
        so a transient failure does not silently end monitoring.
        """
        print("🚀 Starting RunPod Health Monitor")
        print(f"📡 Endpoint ID: {self.endpoint_id}")
        print(f"⏱️  Check interval: {interval} seconds")
        print(f"⏰ Timeout: {self.timeout} seconds")

        try:
            while True:
                try:
                    self.monitor_workers()
                except Exception as e:  # top-level boundary: report and keep going
                    print(f"\n❌ Error during monitoring: {e}")
                print(f"\n💤 Sleeping for {interval} seconds...")
                time.sleep(interval)
        except KeyboardInterrupt:
            print("\n🛑 Monitoring stopped by user")


def positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1.

    Rejecting zero and negative values here gives a clear usage error.
    Otherwise they fail later: ``time.sleep`` raises on a negative interval,
    a zero interval polls the API in a tight loop, and a non-positive request
    timeout is meaningless.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def main():
    """Parse command-line arguments and run the monitor once or continuously."""
    parser = argparse.ArgumentParser(description="Monitor RunPod serverless worker health")
    parser.add_argument("--endpoint-id", required=True, help="RunPod endpoint ID")
    parser.add_argument("--api-key", help="RunPod API key (or use RUNPOD_API_KEY env var)")
    parser.add_argument("--interval", type=positive_int, default=30, help="Check interval in seconds (default: 30)")
    parser.add_argument("--timeout", type=positive_int, default=10, help="Request timeout in seconds (default: 10)")
    parser.add_argument("--once", action="store_true", help="Run once instead of continuously")

    args = parser.parse_args()

    # Get the API key from the argument or the environment. The env var is
    # preferable on shared machines: command-line arguments are visible to
    # other users in the process list.
    api_key = args.api_key or os.getenv("RUNPOD_API_KEY")
    if not api_key:
        print("❌ Error: API key required. Use --api-key or set RUNPOD_API_KEY environment variable")
        sys.exit(1)

    monitor = RunPodHealthMonitor(
        endpoint_id=args.endpoint_id,
        api_key=api_key,
        timeout=args.timeout
    )

    if args.once:
        monitor.monitor_workers()
    else:
        monitor.run_continuous(args.interval)


if __name__ == "__main__":
    main()
