import torch
from PIL import Image
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoProcessor, AutoModelForVision2Seq
from pydantic import BaseModel
import base64
import io
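
# Assumed dependencies for this Space (a plausible requirements set, not taken
# from an actual requirements.txt): torch, transformers, fastapi, uvicorn,
# pillow, pydantic.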
# Select device and dtype: bfloat16 on CUDA when supported, float32 otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    torch_dtype=torch_dtype,  # reuse the dtype chosen above instead of recomputing it
    # flash_attention_2 requires the flash-attn package; fall back to eager on CPU.
    _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
).to(DEVICE)
app = FastAPI()

# Allow cross-origin requests from any host (e.g. a browser-based demo frontend).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictRequest(BaseModel):
    instruction: str
    imageBase64URL: str
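
# Illustrative request body (field names come from PredictRequest above; the
# data-URL prefix before the comma is what split(',', 1) in /predict strips off):
# {
#   "instruction": "Describe this image.",
#   "imageBase64URL": "data:image/png;base64,iVBORw0KGgo..."
# }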
@app.post("/predict")
async def predict(request: PredictRequest):
    try:
        # Strip the "data:image/...;base64," prefix and decode the payload.
        header, base64_string = request.imageBase64URL.split(',', 1)
        image_bytes = base64.b64decode(base64_string)
        # Convert to RGB so RGBA or palette images do not trip up the processor.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Build a single-turn chat with one image placeholder plus the instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": request.instruction},
                ],
            },
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

        generated_ids = model.generate(**inputs, max_new_tokens=500)
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        response_text = generated_texts[0]
        return {"response": response_text}
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
@app.get("/")
async def read_root(request: Request):
    current_path = request.url.path
    print(f"Received GET request at path: {current_path}")
    return {"message": "SmolVLM-500M API is running!"}