# handler.py — custom Inference Endpoint handler (T5 text-to-query model).
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a T5 query-generation model.

    Loads a ``T5ForConditionalGeneration`` checkpoint from *path* and serves
    requests of the form::

        {
            "inputs": "Instruction: Generate the correct Frappe query ...",
            "parameters": {
                "max_new_tokens": 128,
                "temperature": 0.3,
                "do_sample": false
            }
        }
    """

    def __init__(self, path=""):
        # Load tokenizer and model from the checkpoint directory.
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        # Prefer GPU when available; inference only, so switch to eval mode.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Generate text for one request.

        Parameters
        ----------
        data : dict
            ``{"inputs": <prompt str>, "parameters": {...}}``; both keys are
            optional and missing values fall back to defaults.

        Returns
        -------
        list[dict]
            ``[{"generated_text": <str>}]``, as the HF Inference API expects.
        """
        inputs = data.get("inputs", "")
        # `or {}` also tolerates an explicit `"parameters": null` in the JSON.
        params = data.get("parameters") or {}

        # Tokenize the prompt (truncate to the 512-token encoder limit) and
        # move the tensors to the model's device.
        encoded_input = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Only forward `temperature` when sampling is enabled: with
        # do_sample=False it is ignored and recent `transformers` versions
        # emit a warning about unused generation flags.
        do_sample = params.get("do_sample", False)
        gen_kwargs = {
            "max_new_tokens": params.get("max_new_tokens", 128),
            "do_sample": do_sample,
            "num_beams": params.get("num_beams", 1),
        }
        if do_sample:
            gen_kwargs["temperature"] = params.get("temperature", 0.3)

        with torch.no_grad():
            generated_ids = self.model.generate(**encoded_input, **gen_kwargs)

        # Fix: decode with skip_special_tokens=True so *all* special tokens
        # (<pad>, </s>, <unk>, <extra_id_*> sentinels) are removed, instead
        # of manually stripping only <pad> and </s> and leaving the rest —
        # plus stray whitespace — in the response.
        output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # The model emits [BT] as a placeholder for backticks; restore them.
        output_text = output.replace("[BT]", "`").strip()

        # Return in the list-of-dicts shape the HF Inference API expects.
        return [{"generated_text": output_text}]