igmeMarcial committed
Commit 106f99c · 1 Parent(s): a16b16a

Add SmolVLM API

Files changed (2)
  1. app.py +66 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ from PIL import Image
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+ from pydantic import BaseModel
+ import base64
+ import io
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Prefer bfloat16 on GPUs that support it; otherwise fall back to float32.
+ torch_dtype = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
+ model = AutoModelForVision2Seq.from_pretrained(
+     "HuggingFaceTB/SmolVLM-500M-Instruct",
+     torch_dtype=torch_dtype,
+     _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
+ ).to(DEVICE)
+
+ app = FastAPI()
+
+ # Allow browser clients from any origin to call the API.
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ class PredictRequest(BaseModel):
+     instruction: str
+     imageBase64URL: str
+
+
+ @app.post("/v1/chat/completions")
+ async def predict(request: PredictRequest):
+     try:
+         # Strip the data-URL header ("data:image/...;base64,") and decode the image.
+         header, base64_string = request.imageBase64URL.split(",", 1)
+         image_bytes = base64.b64decode(base64_string)
+         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},
+                     {"type": "text", "text": request.instruction},
+                 ],
+             },
+         ]
+         prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+         inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
+         generated_ids = model.generate(**inputs, max_new_tokens=500)
+         generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+         response_text = generated_texts[0]
+         return {"response": response_text}
+
+     except Exception as e:
+         print(f"Error during prediction: {e}")
+         raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
+
+
+ @app.get("/")
+ async def read_root():
+     return {"message": "SmolVLM-500M API is running!"}
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ torch
+ Pillow
+ fastapi
+ uvicorn[standard]
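
With these dependencies installed (e.g. pip install -r requirements.txt), the API can be served with uvicorn. A minimal launch sketch, assuming the file above is saved as app.py; the host and port are placeholders, not part of this commit:

# Equivalent to running: uvicorn app:app --host 0.0.0.0 --port 8000
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000)
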