# Spaces:
# Sleeping
# Sleeping
import asyncio
import base64
import importlib.util
import io
import logging
from concurrent.futures import ThreadPoolExecutor

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForVision2Seq, AutoProcessor
MODEL_ID = "HuggingFaceTB/SmolVLM-500M-Instruct"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use bfloat16 only on GPUs that actually support it; otherwise stay in float32.
torch_dtype = (
    torch.bfloat16
    if DEVICE == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)
# FlashAttention-2 requires the optional `flash_attn` package and a half-precision
# dtype. Request it only when both hold; otherwise fall back to eager attention so
# start-up does not fail on machines without it.
_use_flash_attention = (
    DEVICE == "cuda"
    and torch_dtype == torch.bfloat16
    and importlib.util.find_spec("flash_attn") is not None
)

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    # Previously computed above but ignored; bfloat16 was forced on every GPU.
    torch_dtype=torch_dtype,
    _attn_implementation="flash_attention_2" if _use_flash_attention else "eager",
).to(DEVICE)
app = FastAPI()

# Let browser clients from any origin call the API with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive (Starlette may reflect the caller's Origin when credentials are
# sent -- verify); consider allow_credentials=False if no client relies on it.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Request body model for predict(). Only "#" comments are used here: pydantic
# builds the JSON schema from this class body (a docstring would become the
# schema description), so the class itself is left untouched.
class PredictRequest(BaseModel):
    # Free-text prompt sent to the model together with the image.
    instruction: str
    # Base64-encoded image. The name suggests a data URL
    # ("data:<mime>;base64,<payload>"); which forms are accepted is decided by
    # predict(). Kept camelCase because the field name is the JSON key.
    imageBase64URL: str
# Generation is blocking and compute-heavy. Running it on one dedicated worker
# thread keeps the event loop free (other requests and health checks still get
# answered) while requests are still processed one at a time, as before.
_inference_executor = ThreadPoolExecutor(max_workers=1)
_logger = logging.getLogger(__name__)


def _decode_image(image_base64_url):
    """Decode a base64 image string into a fully loaded PIL image.

    Accepts either a data URL (``data:<mime>;base64,<payload>``) or a bare
    base64 payload.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: the decoded bytes are not an image PIL can read.
    """
    # Split on the first comma, as before; with no comma, treat the whole string
    # as the payload instead of failing to unpack.
    header, comma, payload = image_base64_url.partition(",")
    base64_string = payload if comma else header
    image = Image.open(io.BytesIO(base64.b64decode(base64_string)))
    # Image.open only parses the header; load() forces the full decode here so
    # corrupt or truncated images are reported as bad input.
    image.load()
    return image


def _run_inference(image, instruction, max_new_tokens=500):
    """Run the model on one image plus instruction and return the decoded text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": instruction},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # NOTE(review): for decoder-style models generate() usually returns the prompt
    # tokens followed by the completion, so this text likely echoes the prompt.
    # That is left unchanged because clients may depend on it -- confirm before
    # slicing off the first inputs["input_ids"].shape[1] tokens.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


# NOTE(review): no route decorator is visible on this function in the reviewed
# source; unless it is registered elsewhere, FastAPI will not serve it.
async def predict(request: PredictRequest):
    """Answer ``request.instruction`` about the image in ``request.imageBase64URL``.

    Returns:
        ``{"response": text}`` containing the model's decoded output.

    Raises:
        HTTPException: 400 when the image cannot be decoded, 500 when inference fails.
    """
    try:
        image = _decode_image(request.imageBase64URL)
    except (ValueError, OSError) as e:
        # Undecodable input is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    try:
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            _inference_executor, _run_inference, image, request.instruction
        )
    except Exception as e:
        # Top-level boundary: log with traceback, then report a 500 to the client.
        _logger.exception("Error durante la predicción: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}") from e
    return {"response": response_text}
# Liveness endpoint: logs the requested path and returns a fixed status message.
# ("#" comments instead of a docstring so any generated OpenAPI description is
# unchanged.)
# NOTE(review): no route decorator is visible on this function in the reviewed
# source; unless it is registered elsewhere, FastAPI will not serve it.
async def read_root(request: Request):
    print("Received GET request at path: {}".format(request.url.path))
    status = dict(message="SmolVLM-500M API is running!")
    return status