import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# CPU-only deployment: no @spaces.GPU decorator or `spaces` import is needed.
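# (For a GPU/ZeroGPU Space one would instead keep `import spaces`, decorate
# predict with @spaces.GPU, and load with device_map="auto" and float16.)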
# Load the model and tokenizer once at startup rather than on every message.
model_id = "kurakurai/Luth-LFM2-350M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",           # CPU only
    torch_dtype=torch.float32,  # fp16 kernels are incomplete on CPU; fp32 is safe
    trust_remote_code=True,
    # bitsandbytes 4-bit quantization is GPU-only, so no load_in_4bit here
)

def predict(message, history):
    # Format the conversation history for the chat template.
    # Assumes Gradio's tuple-format history: a list of [user, assistant] pairs.
    messages = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg}
                for conv in history for i, msg in enumerate(conv) if msg]
    messages.append({"role": "user", "content": message})
    
    # Apply chat template
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True
    ).to('cpu')  # CPU device
    
    # Setup streamer for real-time output
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    
    # Generation parameters
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,
        min_p=0.15,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.eos_token_id
    )
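    # Note: min_p sampling is only available in recent transformers releases;
    # drop the min_p kwarg if your installed version does not support it.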
    
    # Start generation in separate thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    # Stream tokens
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message

# Setup Gradio interface
gr.ChatInterface(
    predict,
    description="""
    <center><h2>Kurakura AI Luth-LFM2-350M Chat</h2></center>
    
    Chat with [Luth-LFM2-350M](https://huggingface.co/kurakurai/Luth-LFM2-350M), a French-tuned version of LFM2-350M.
    """,
    examples=[
        "Peux-tu résoudre l'équation 3x - 7 = 11 pour x ?",
        "Explique la photosynthèse en termes simples.",
        "Écris un petit poème sur l'intelligence artificielle."
    ],
    theme=gr.themes.Soft(primary_hue="blue"),
).launch()
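
# Run locally with `python app.py` and open http://127.0.0.1:7860; on a
# Hugging Face Space (Gradio SDK) this file is app.py and launches automatically.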