import gradio as gr
import torch
import transformers
import huggingface_hub
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from gradio_client import utils as client_utils

# Patch gradio_client's get_type so it tolerates boolean JSON schemas
def patched_get_type(schema):
    if isinstance(schema, bool):
        return "any" if schema else "never"
    if "const" in schema:
        return repr(schema["const"])
    if "type" in schema:
        return schema["type"]
    return "any"

client_utils.get_type = patched_get_type
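# NOTE: this monkey-patch works around a schema-parsing crash observed with
# some gradio_client releases; if your installed version already handles
# boolean schemas, the patch is harmless and can be removed.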

# Print package versions for debugging
print(f"Gradio version: {gr.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Huggingface_hub version: {huggingface_hub.__version__}")

# Load model and tokenizer (resume_download is deprecated upstream, so only force_download is kept)
model_name = "arshiaafshani/Arsh-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name, force_download=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    force_download=True,
)

# Set the special tokens this model expects
tokenizer.bos_token = "<sos>"
tokenizer.eos_token = "<|endoftext|>"

# Create pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)
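# Optional smoke test (hypothetical prompt; uncomment to verify the model
# generates before wiring up the UI):
# print(pipe("Hello!", max_new_tokens=16)[0]["generated_text"])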

def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    chat_history = chat_history or []
    # Rebuild the conversation: system prompt, past (user, assistant) turns, then the new message.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # add_generation_prompt appends the assistant header so the model replies as the assistant.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    output = pipe(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repeat_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    response = output[0]['generated_text'][len(prompt):].strip()
    chat_history.append((message, response))
    return chat_history
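# Note: chat_history follows gr.Chatbot's tuple format, e.g.
# [("Hi", "Hello! How can I help?")], one (user, assistant) pair per turn.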

with gr.Blocks() as demo:
    gr.Markdown("# Arsh-LLM Demo")
    with gr.Row():
        with gr.Column():
            system_msg = gr.Textbox(
                value="You are Arsh, a helpful assistant by Arshia Afshani. You should answer the user carefully.",
                label="System Message",
            )
            max_tokens = gr.Slider(1, 4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k")
            repeat_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")  # transformers requires a strictly positive repetition_penalty

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear")

    def submit_message(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
        chat_history = chat_history or []
        response = respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty)
        return "", response

    msg.submit(
        submit_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty],
        [msg, chatbot]
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
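# Assuming this file is saved as app.py, run it with `python app.py`;
# port 7860 is the default port Hugging Face Spaces expects for Gradio apps.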