# VI1 / app.py
# Invescoz's picture
# Update app.py
# 7c28246 verified
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
# Checkpoint backing the chatbot (non-gated); the tokenizer and the weights
# are both fetched from the same model id.
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are loaded in bfloat16 and placed via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
def _iter_buffered_text(
    pieces,
    min_chunk_length=10,
    punctuation_marks=(".", ",", "!", "?", ";", ":"),
):
    """Yield the growing reply text, flushing the buffer at sensible points.

    Incoming pieces are buffered until the buffer contains one of
    ``punctuation_marks`` or holds at least ``min_chunk_length`` characters;
    each flush yields the full text accumulated so far. Whatever is left in
    the buffer when ``pieces`` is exhausted is flushed as a final update.
    """
    generated_text = ""
    chunk_buffer = ""
    for new_text in pieces:
        chunk_buffer += new_text
        if len(chunk_buffer) >= min_chunk_length or any(
            mark in chunk_buffer for mark in punctuation_marks
        ):
            generated_text += chunk_buffer
            yield generated_text
            chunk_buffer = ""
    if chunk_buffer:
        yield generated_text + chunk_buffer


def chat(message, history):
    """Stream the assistant's reply to ``message`` as progressively longer text.

    Args:
        message: The latest user message.
        history: Earlier turns as a list of dicts with ``"role"`` and
            ``"content"`` keys (the interface below is created with
            ``type="messages"``).

    Yields:
        The reply generated so far; each value extends the previous one.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    system_message = {
        "role": "system",
        "content": "You are a helpful and friendly assistant named Vidyut, an Indian AI created by Rapnss Production Studio India, running on the Rapnss-vertex LLM model.",
    }
    # History entries may carry extra UI-only keys next to role/content;
    # forward only the two fields the conversation itself consists of.
    past_turns = [
        {"role": turn["role"], "content": turn["content"]} for turn in (history or [])
    ]
    messages = [system_message, *past_turns, {"role": "user", "content": message}]

    # return_dict=True also returns the attention mask, so generate() receives
    # it explicitly together with the input ids.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # The streamer hands decoded text from the generation thread to this one.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )

    generation_errors = []

    def _run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # The streamer is created without a timeout, so the consuming loop
            # below would wait forever if the stream were never closed.
            streamer.end()

    # Generate in a separate thread so text can be yielded while it is produced.
    thread = Thread(target=_run_generation)
    thread.start()

    yield from _iter_buffered_text(streamer)

    thread.join()
    if generation_errors:
        # Surface the worker's failure to the caller instead of ending silently.
        raise generation_errors[0]
# Expose the streaming chat function through a Gradio chat UI; history is
# exchanged in the role/content "messages" format.
demo = gr.ChatInterface(
    chat,
    type="messages",
    examples=[
        ["Tell me a fun fact."],
        ["Explain neural networks in simple terms."],
    ],
    title="Vidyut Omega with VIA-1 Chatbot",
    description="Chat with Vidyut, powered by VIA-1.",
)

# Start serving the interface.
demo.launch()