File size: 4,786 Bytes
baa67ca
 
7e43cba
baa67ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e43cba
 
baa67ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import spaces
import gradio as gr
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM, TextIteratorStreamer
import torch
import torch.amp.autocast_mode
from PIL import Image
import torchvision.transforms.functional as TVF
from threading import Thread
from typing import Generator


MODEL_PATH = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"
TITLE = "<h1><center>Turkish LLaVA - Görsel Soru Cevap Sistemi</center></h1>"
DESCRIPTION = "Bu model görsel içerikleri analiz ederek Türkçe sorularınızı yanıtlar. Bir resim yükleyip hakkında soru sorabilirsiniz."

PLACEHOLDER = "Merhaba! Size nasıl yardımcı olabilirim? Bir resim yükleyip hakkında soru sorabilirsiniz."


tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Expected PreTrainedTokenizer, got {type(tokenizer)}"

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="bfloat16", device_map=0)
# assert isinstance(model, LlavaLlamaForCausalLM), f"Expected LlavaLlamaForCausalLM, got {type(model)}"


@spaces.GPU()
@torch.no_grad()
def chat_turkish_llava(message: dict, history, temperature: float, top_p: float, max_new_tokens: int) -> Generator[str, None, None]:
	torch.cuda.empty_cache()

	prompt = message['text'].strip()

	if "files" not in message or len(message["files"]) != 1:
		yield "HATA: Bu model tam olarak bir resim girişi gerektirir."
		return
	
	image = Image.open(message["files"][0])
	
	print(f"Prompt: {prompt}")

	if image.size != (384, 384):
		image = image.resize((384, 384), Image.LANCZOS)
	image = image.convert("RGB")
	pixel_values = TVF.pil_to_tensor(image)

	convo = [
		{
			"role": "system",
			"content": "Sen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.",
		},
		{
			"role": "user",
			"content": prompt,
		},
	]

	convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
	assert isinstance(convo_string, str)

	convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

	input_tokens = []
	for token in convo_tokens:
		if token == model.config.image_token_index:
			input_tokens.extend([model.config.image_token_index] * model.config.image_seq_length)
		else:
			input_tokens.append(token)
	
	input_ids = torch.tensor(input_tokens, dtype=torch.long)
	attention_mask = torch.ones_like(input_ids)

	input_ids = input_ids.unsqueeze(0).to("cuda")
	attention_mask = attention_mask.unsqueeze(0).to("cuda")
	pixel_values = pixel_values.unsqueeze(0).to("cuda")

	pixel_values = pixel_values / 255.0
	pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
	pixel_values = pixel_values.to(torch.bfloat16)

	streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

	generate_kwargs = dict(
		input_ids=input_ids,
		pixel_values=pixel_values,
		attention_mask=attention_mask,
		max_new_tokens=max_new_tokens,
		do_sample=True,
		suppress_tokens=None,
		use_cache=True,
		temperature=temperature,
		top_k=None,
		top_p=top_p,
		streamer=streamer,
	)

	if temperature == 0:
		generate_kwargs["do_sample"] = False
	
	t = Thread(target=model.generate, kwargs=generate_kwargs)
	t.start()

	outputs = []
	for text in streamer:
		outputs.append(text)
		yield "".join(outputs)


chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Turkish LLaVA Sohbet', type="messages")
textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single", placeholder="Mesajınızı yazın ve resim yükleyin...")

with gr.Blocks() as demo:
	gr.HTML(TITLE)
	chat_interface = gr.ChatInterface(
		fn=chat_turkish_llava,
		chatbot=chatbot,
		type="messages",
		fill_height=True,
		multimodal=True,
		textbox=textbox,
		examples=[{"text": "Bu kitabın adı ne?", "files": ["./book.jpg"]},
                  {"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
                  {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]}],
		additional_inputs_accordion=gr.Accordion(label="⚙️ Parametreler", open=True, render=False),
		additional_inputs=[
			gr.Slider(minimum=0,
						maximum=1, 
						step=0.1,
						value=0.6, 
						label="Temperature", 
						render=False),
			gr.Slider(minimum=0,
			 			maximum=1,
						step=0.05,
						value=0.9,
						label="Top P",
						render=False),
			gr.Slider(minimum=8, 
						maximum=4096,
						step=1,
						value=1024, 
						label="Max New Tokens", 
						render=False ),
		],
    )
	gr.Markdown(DESCRIPTION)


if __name__ == "__main__":
    demo.launch()