import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
# Load non-gated model and tokenizer
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
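# Note: device_map="auto" requires the accelerate package, and bfloat16 assumes
# hardware support (e.g. Ampere-class GPUs); torch.float16 or torch.float32 are
# possible fallbacks on other hardware.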
def chat(message, history):
    # Build message history
    messages = [
        {"role": "system", "content": "You are a helpful and friendly assistant named Vidyut, an Indian AI created by Rapnss Production Studio India, running on the Rapnss-vertex LLM model."}
    ] + history + [{"role": "user", "content": message}]
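    # With type="messages" on the ChatInterface below, `history` already arrives
    # as a list of {"role": ..., "content": ...} dicts, so it can be spliced in directly.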
    # Prepare inputs
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
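    # This returns only the input_ids tensor, so generate() may warn about a
    # missing attention mask. On recent transformers versions (assumption:
    # apply_chat_template accepts return_dict=True), the mask can be requested too:
    #   encoded = tokenizer.apply_chat_template(
    #       messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    #   ).to(model.device)
    # and then encoded["input_ids"] and encoded["attention_mask"] passed to generate().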
    # Set up streamer for live typing
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Generation kwargs
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
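    # model.generate() blocks until generation finishes, so it runs on a worker
    # thread while this generator function consumes the streamer and yields
    # partial text back to Gradio.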
    # Buffer tokens into chunks
    generated_text = ""
    chunk_buffer = ""
    min_chunk_length = 10  # Minimum characters before yielding
    punctuation_marks = [".", ",", "!", "?", ";", ":"]
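    # Buffering trades a little latency for fewer UI updates: instead of one
    # yield per token, text is flushed at punctuation or every ~10 characters.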
    for new_text in streamer:
        chunk_buffer += new_text
        # Check if buffer has a punctuation mark or is long enough
        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            generated_text += chunk_buffer
            yield generated_text
            chunk_buffer = ""  # Reset buffer after yielding
    # Yield any remaining text in the buffer
    if chunk_buffer:
        generated_text += chunk_buffer
        yield generated_text
    thread.join()
# Create Gradio chat interface
demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    title="Vidyut Omega with VIA-1 Chatbot",
    description="Chat with Vidyut, powered by VIA-1.",
    examples=[["Tell me a fun fact."], ["Explain neural networks in simple terms."]],
)

demo.launch()
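# When running outside a hosted Space, demo.launch(share=True) would expose a
# temporary public URL via Gradio's tunneling service (assumption: outbound
# access to that service is allowed).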