import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Load non-gated model and tokenizer
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


def chat(message, history):
    # Build message history
    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful and friendly assistant named Vidyut, an Indian AI "
                "created by Rapnss Production Studio India, running on the "
                "Rapnss-vertex LLM model."
            ),
        }
    ] + history + [{"role": "user", "content": message}]

    # Prepare inputs
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Set up streamer for live typing
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generation kwargs
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }

    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Buffer tokens into chunks
    generated_text = ""
    chunk_buffer = ""
    min_chunk_length = 10  # Minimum characters before yielding
    punctuation_marks = [".", ",", "!", "?", ";", ":"]

    for new_text in streamer:
        chunk_buffer += new_text

        # Check if buffer has a punctuation mark or is long enough
        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            generated_text += chunk_buffer
            yield generated_text
            chunk_buffer = ""  # Reset buffer after yielding

    # Yield any remaining text in the buffer
    if chunk_buffer:
        generated_text += chunk_buffer
        yield generated_text

    thread.join()


# Create Gradio chat interface
demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    title="Vidyut Omega with VIA-1 Chatbot",
    description="Chat with Vidyut, powered by VIA-1.",
    examples=[["Tell me a fun fact."], ["Explain neural networks in simple terms."]],
)

demo.launch()
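
# Quick sanity check (a sketch, not part of the original script): because chat() is a
# generator, it can be driven directly in Python, without launching the Gradio UI, to
# confirm that partial responses stream back incrementally. The prompt below is an
# arbitrary example.
#
#   for partial in chat("Give me one fun fact about space.", []):
#       print(partial)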