File size: 2,060 Bytes
fd7d432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9ca3d0
 
fd7d432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from transformers import AutoProcessor, AutoModelForVision2Seq, Idefics3ForConditionalGeneration
from PIL import Image, ImageOps
import torch
from peft import PeftModel
from huggingface_hub import snapshot_download

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_name = "HuggingFaceTB/SmolVLM-256M-Instruct"

# Select the dtype from ``device.type``: comparing the ``torch.device`` object
# itself against the string "cuda" evaluates to False, so the previous check
# always picked float32, even when running on the GPU.
model = Idefics3ForConditionalGeneration.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)

processor = AutoProcessor.from_pretrained(base_model_name)

# Download the fine-tuned adapter repository, using /tmp as the cache dir.
repo_local_path = snapshot_download(
    repo_id="Irina1402/smolvlm-painting-description",
    cache_dir="/tmp",
)

# Wrap the base model with the downloaded PEFT adapter and switch to eval mode.
model = PeftModel.from_pretrained(model, model_id=repo_local_path)
model.eval()



def process_chat(text: str = None, image: Image.Image = None):
    """Generate a SmolVLM response for a text prompt and/or an image.

    Args:
        text: Optional user prompt. An empty string is treated as "no text".
        image: Optional image input. Either an upload-style object exposing a
            binary file handle as ``.file`` (the only form the previous
            implementation accepted) or an already-opened ``PIL.Image.Image``.

    Returns:
        The decoded text that follows the first "Assistant:" marker, trimmed
        to end at its last period when it contains one.

    Raises:
        ValueError: If neither ``text`` nor ``image`` is provided.
    """
    content = []
    image_data = None

    if image:
        # Accept a ready PIL image as-is; otherwise read from the upload's
        # underlying file handle.
        if isinstance(image, Image.Image):
            opened = image
        else:
            opened = Image.open(image.file)
        image_data = ImageOps.exif_transpose(opened.convert("RGB"))
        content.append({"type": "image"})

    if text:
        content.append({"type": "text", "text": text})

    if not content:
        raise ValueError("process_chat requires text and/or an image")

    # Build the message from the parts that were actually supplied, so the
    # prompt never carries an image placeholder without an image or a text
    # entry whose value is None.
    message = [{"role": "user", "content": content}]

    prompt = processor.apply_chat_template(message, add_generation_prompt=True)

    print(f"Prepared prompt:\n{prompt}")

    processed_inputs = processor(
        text=prompt,
        images=[image_data] if image_data is not None else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **processed_inputs,
            max_new_tokens=350,
            repetition_penalty=1.1,
        )

    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Keep only what follows the first "Assistant:" marker in the decoded text.
    assistant_text = generated_text.split("Assistant:", 1)[-1].strip()

    # Drop anything after the final period (presumably a sentence cut off by
    # the max_new_tokens limit); leave the text untouched if it has no period.
    last_period_idx = assistant_text.rfind(".")
    if last_period_idx != -1:
        assistant_text = assistant_text[: last_period_idx + 1].strip()

    return assistant_text