import gradio as gr
import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
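
# Load the tokenizer and model once at import time so every chat request
# reuses them. device_map="auto" places the weights on a GPU when one is
# available and falls back to CPU otherwise.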
|
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
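

# Streaming chat callback for gr.ChatInterface. With type="messages", Gradio
# passes `history` as a list of {"role": ..., "content": ...} dicts, so it can
# be concatenated directly with the system and user messages.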
|
def chat(message, history):
    messages = [
        {"role": "system", "content": "You are a helpful and friendly assistant named Vidyut, an Indian AI created by Rapnss Production Studio India, running on the Rapnss-vertex LLM model."}
    ] + history + [{"role": "user", "content": message}]
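
    # Render the conversation with the model's chat template and move the
    # token IDs to the same device as the model weights.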
|
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
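
    # TextIteratorStreamer yields decoded text as generate() produces tokens;
    # skip_prompt drops the echoed input so only newly generated text streams out.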
|
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
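
    # Sampling settings for generate(); the streamer receives tokens as they
    # are produced.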
|
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
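
    # model.generate() blocks until generation finishes, so run it in a
    # background thread and consume the streamer on this one.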
|
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
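
    # Buffer the raw stream into slightly larger chunks: flush whenever the
    # buffer contains punctuation or reaches min_chunk_length characters, so
    # the UI updates smoothly rather than one token at a time.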
|
    generated_text = ""
    chunk_buffer = ""
    min_chunk_length = 10
    punctuation_marks = [".", ",", "!", "?", ";", ":"]

    for new_text in streamer:
        chunk_buffer += new_text

        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            generated_text += chunk_buffer
            yield generated_text
            chunk_buffer = ""
|
    if chunk_buffer:
        generated_text += chunk_buffer
        yield generated_text

    thread.join()
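

# type="messages" makes Gradio exchange role/content dicts with the callback,
# matching the format expected by apply_chat_template above.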
|
demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    title="Vidyut Omega with VIA-1 Chatbot",
    description="Chat with Vidyut, powered by VIA-1.",
    examples=[["Tell me a fun fact."], ["Explain neural networks in simple terms."]],
)
|

demo.launch()