import gradio as gr
from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM
import logging
from threading import Thread

logging.basicConfig(level=logging.INFO)

# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})
SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}
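
# The special tokens above imply a chat-style prompt format, but the exact
# template Sam 3.0-2 was trained with is not documented here. The helper below
# is a hypothetical sketch of how such a prompt might be assembled; adjust it
# to match the model's real template before relying on it.
def build_chat_prompt(user_message, system_message=None):
    """Assemble a single-turn chat prompt from the special tokens (assumed format)."""
    parts = [SPECIAL_TOKENS["bos"]]
    if system_message:
        parts.append(SPECIAL_TOKENS["system"] + system_message + SPECIAL_TOKENS["eot"])
    parts.append(SPECIAL_TOKENS["user"] + user_message + SPECIAL_TOKENS["eot"])
    parts.append(SPECIAL_TOKENS["assistant"])
    return "".join(parts)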

# -----------------------------------------------------------------------------
# Custom Model Configuration
# -----------------------------------------------------------------------------
class Sam3Config(PretrainedConfig):
    """Configuration for the custom Sam 3.0-2 causal LM architecture."""

    def __init__(self, vocab_size=50257, d_model=384, n_layers=10, n_heads=6,
                 ff_mult=4.0, dropout=0.1, input_modality="text",
                 head_type="causal_lm", version="0.1", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        self.is_encoder_decoder = False
        self._attn_implementation_internal = "eager"
        # Aliases expected by the Transformers generation utilities.
        self.hidden_size = self.d_model
        self.num_attention_heads = self.n_heads

# Instantiate the custom configuration
model_config = Sam3Config()
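
# Sanity check: attention needs an integer per-head width,
# here d_model / n_heads = 384 / 6 = 64.
assert model_config.d_model % model_config.n_heads == 0, "d_model must be divisible by n_heads"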

# Load the ONNX model by providing the configuration
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
    logging.info("ONNX model loaded successfully.")
    
except Exception as e:
    logging.error(f"Failed to load ONNX model: {e}")
    raise

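# Optional smoke test (a hypothetical addition, disabled by default): one
# short greedy generation to confirm the exported graph and the tokenizer
# agree before the UI starts.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    test_ids = tokenizer("Hello", return_tensors="pt").input_ids
    test_output = model.generate(test_ids, max_length=16, do_sample=False, use_cache=False)
    logging.info("Smoke test output: %s", tokenizer.decode(test_output[0], skip_special_tokens=True))
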
# -----------------------------------------------------------------------------
# Streaming Generation Function
# -----------------------------------------------------------------------------
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """
    This function acts as a generator to stream text.
    It yields each new token as it's generated by the model.
    """
    # Create a streamer to iterate over the generated tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Tokenize the prompt and keep the tensors on the ONNX Runtime session's
    # device (CPU by default); a device mismatch raises at inference time.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Set generation parameters in a GenerationConfig object. use_cache=False
    # works around exported graphs that lack past-key-value inputs; the
    # integer parameters are cast because Gradio sliders deliver floats.
    gen_config = GenerationConfig(
        max_length=int(max_length),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        use_cache=False,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Create a thread to run the generation in the background
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        generation_config=gen_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Gradio replaces the output component's value with each yielded result,
    # so accumulate the chunks instead of yielding single tokens.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()

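# A minimal usage sketch outside Gradio (the prompt and sampling values are
# arbitrary): each yielded string is the full text generated so far, so a
# console consumer should print only the new suffix.
#
#     previous = ""
#     for text in generate_text_stream("Once upon a time", 64, 0.8, 60, 0.9):
#         print(text[len(previous):], end="", flush=True)
#         previous = text
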
# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)

# queue() enables streaming generator outputs; recent Gradio versions turn the
# queue on by default, but calling it explicitly is harmless.
demo.queue().launch()