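"""Gradio chat app for "Vidyut": streams responses from a locally loaded
Hugging Face model (Qwen/Qwen2-1.5B-Instruct). Generation runs on a
background thread, and tokens are streamed to the UI via TextIteratorStreamer,
buffered into small chunks for smoother rendering."""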
import gradio as gr
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

# Load a non-gated model and tokenizer (no Hugging Face access token required)
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
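# Note: torch.bfloat16 assumes hardware bf16 support; on older GPUs or
# CPU-only machines, torch.float32 is a safe fallback.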

def chat(message, history):
    # Build message history
    messages = [
        {"role": "system", "content": "You are a helpful and friendly assistant named Vidyut, an Indian AI created by Rapnss Production Studio India, running on the Rapnss-vertex LLM model."}
    ] + history + [{"role": "user", "content": message}]
    
    # Prepare inputs
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
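    # Note: only input_ids are passed here; for a single unpadded sequence,
    # generate() infers an all-ones attention mask. Pass one explicitly if batching.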
    
    # Set up streamer for live typing
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Generation kwargs
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
    
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Buffer tokens into chunks
    generated_text = ""
    chunk_buffer = ""
    min_chunk_length = 10  # Minimum characters before yielding
    punctuation_marks = [".", ",", "!", "?", ";", ":"]
    
    for new_text in streamer:
        chunk_buffer += new_text
        # Check if buffer has a punctuation mark or is long enough
        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            generated_text += chunk_buffer
            yield generated_text
            chunk_buffer = ""  # Reset buffer after yielding
    
    # Yield any remaining text in the buffer
    if chunk_buffer:
        generated_text += chunk_buffer
        yield generated_text
    
    thread.join()

# Create Gradio chat interface
demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    title="Vidyut Omega with VIA-1 Chatbot",
    description="Chat with Vidyut, powered by VIA-1.",
    examples=[["Tell me a fun fact."], ["Explain neural networks in simple terms."]],
)

demo.launch()
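
# To run locally (assumed setup; adjust versions as needed):
#   pip install gradio transformers torch accelerate
#   python app.py
# device_map="auto" relies on the accelerate package for weight placement,
# and type="messages" in gr.ChatInterface assumes a recent Gradio release.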