
import runpod
from transformers import AutoModel, AutoTokenizer
import torch

# --- Global Setup (runs once per worker cold start, reused across jobs) ---
print("Worker starting, loading model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained("./model", trust_remote_code=True).to(device)
model.eval()
print(f"Model and tokenizer loaded successfully onto device: {device}")

def get_detailed_instruct(query: str) -> str:
    # Prepend the BGE-style retrieval instruction to each query before embedding.
    return f'Represent this sentence for searching relevant passages: {query}'
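# e.g. get_detailed_instruct("capital of France") returns
#   "Represent this sentence for searching relevant passages: capital of France"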

# --- Handler Function ---
def handler(job):
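    # Each job arrives as a dict whose "input" key carries the client's payload.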
    job_input = job.get('input', {})
    sentences = job_input.get('sentences', [])
    if not isinstance(sentences, list) or not sentences or not all(isinstance(s, str) for s in sentences):
        return {"error": "Input 'sentences' must be a non-empty list of strings."}
    instructed_sentences = [get_detailed_instruct(s) for s in sentences]
    with torch.no_grad():
        inputs = tokenizer(instructed_sentences, padding=True, truncation=True, return_tensors="pt").to(device)
        outputs = model(**inputs, output_hidden_states=True)
        # CLS pooling: take the last hidden state of the first ([CLS]) token.
        embeddings = outputs.hidden_states[-1][:, 0]
        normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return {"embeddings": normalized_embeddings.cpu().tolist()}

runpod.serverless.start({"handler": handler})
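# Example request payload this worker expects:
#   {"input": {"sentences": ["What is the capital of France?", "Paris is the capital of France."]}}
# Response shape (one L2-normalized vector per input sentence):
#   {"embeddings": [[0.01, -0.02, ...], [0.03, 0.00, ...]]}
#
# For a quick local smoke test you can bypass the RunPod queue entirely by
# calling the handler directly (comment out runpod.serverless.start above first):
#   print(handler({"input": {"sentences": ["hello world"]}}))
#
# Once deployed, the worker is reached through RunPod's serverless API; the
# endpoint ID and API key below are placeholders:
#   curl -X POST https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync \
#        -H "Authorization: Bearer <API_KEY>" \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"sentences": ["hello world"]}}'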
