|
from transformers import AutoModelForCausalLM, AutoProcessor |
|
from PIL import Image |
|
import requests |
|
import torch |
|
import io |
|
|
|
class EndpointHandler:
    """Inference endpoint handler: fetch an image from a URL and return a
    detailed caption generated by a vision-language model.

    Expected request payload: ``{"inputs": {"url": "<image url>"}}`` or,
    equivalently, ``{"inputs": "<image url>"}``.
    """

    def __init__(self, model_dir):
        """Load the model and processor from *model_dir*, on GPU if available.

        Args:
            model_dir: path to the checkpoint directory provided by the
                serving platform.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # trust_remote_code is required for checkpoints (e.g. Florence-2)
        # that ship custom modeling/processing code alongside the weights.
        self.model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).to(device)
        self.processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
        self.device = device

    def __call__(self, data):
        """Handle one captioning request.

        Args:
            data: request payload. The image URL is read from
                ``data["inputs"]["url"]``, or from ``data["inputs"]`` itself
                when it is a plain string.

        Returns:
            ``{"caption": <text>}`` on success, ``{"error": <message>}`` on
            any failure (bad payload, fetch error, inference error).
        """
        try:
            inputs_field = data.get("inputs", {})
            # Accept both {"inputs": {"url": ...}} and {"inputs": "<url>"}.
            if isinstance(inputs_field, str):
                url = inputs_field
            else:
                url = inputs_field.get("url")
            if not url:
                return {"error": "Missing URL"}

            headers = {
                "User-Agent": "Mozilla/5.0",
                "Accept": "image/*",
            }
            # Fix: verify TLS certificates (the previous verify=False allowed
            # man-in-the-middle on the fetched image) and bound the request so
            # a stalled server cannot hang the endpoint worker.
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()

            image = Image.open(io.BytesIO(response.content)).convert("RGB")

            # "<MORE_DETAILED_CAPTION>" is the task prompt the checkpoint's
            # processor understands (Florence-2-style captioning task).
            model_inputs = self.processor(
                text="<MORE_DETAILED_CAPTION>",
                images=image,
                return_tensors="pt",
            ).to(self.device)

            with torch.inference_mode():
                output = self.model.generate(
                    **model_inputs,
                    max_new_tokens=512,
                    num_beams=3,
                )

            text = self.processor.batch_decode(output, skip_special_tokens=True)[0]
            return {"caption": text}

        except Exception as e:
            # Top-level service boundary: surface any failure as a JSON error
            # payload rather than letting a traceback produce an opaque 500.
            return {"error": str(e)}