import logging
from threading import Thread

import gradio as gr
from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM

logging.basicConfig(level=logging.INFO)
# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}
# The checkpoint reuses the GPT-2 tokenizer, extended with Sam's chat-format
# markers. Note: adding tokens here only changes tokenization; the ONNX graph's
# embedding table is fixed, so the checkpoint must already know these tokens.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})
SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}
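
# A minimal sketch of how these markers could be assembled into a single-turn
# chat prompt. Assumption: Sam 3.0-2 expects this bos/system/user/assistant
# ordering; the checkpoint's exact template is not documented here, and
# build_chat_prompt is a hypothetical helper, not part of the model repo.
def build_chat_prompt(system_msg: str, user_msg: str) -> str:
    return (
        f"{SPECIAL_TOKENS['bos']}"
        f"{SPECIAL_TOKENS['system']}{system_msg}{SPECIAL_TOKENS['eot']}"
        f"{SPECIAL_TOKENS['user']}{user_msg}{SPECIAL_TOKENS['eot']}"
        f"{SPECIAL_TOKENS['assistant']}"
    )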
# -----------------------------------------------------------------------------
# Custom Model Configuration
# -----------------------------------------------------------------------------
class Sam3Config(PretrainedConfig):
    # Class-level defaults; `hidden_size` and `num_attention_heads` mirror
    # `d_model` and `n_heads` so generic transformers utilities can read them.
    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"
    _attn_implementation_internal: str = "eager"
    is_encoder_decoder: bool = False
    hidden_size: int = 384
    num_attention_heads: int = 6

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        # Aliases expected by transformers' generic config plumbing.
        self.hidden_size = self.d_model
        self.num_attention_heads = self.n_heads
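
# Optional sanity check (a sketch, left commented out so it doesn't run at
# startup): PretrainedConfig serialization should round-trip the custom fields
# via config.json. The /tmp path is illustrative.
# cfg = Sam3Config(n_layers=12)
# cfg.save_pretrained("/tmp/sam3-config")
# assert Sam3Config.from_pretrained("/tmp/sam3-config").n_layers == 12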
# Instantiate the custom configuration
model_config = Sam3Config()

# Load the ONNX model, providing the configuration explicitly since the
# repository ships a custom architecture.
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
    logging.info("ONNX model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load ONNX model: {e}")
    raise
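
# Optional smoke test (commented out to keep Space startup fast): one short,
# greedy, non-streaming generation to confirm the ONNX session responds.
# _ids = tokenizer("Hello", return_tensors="pt").input_ids
# _out = model.generate(_ids, max_length=20, do_sample=False, use_cache=False)
# logging.info("Smoke test: %s", tokenizer.decode(_out[0], skip_special_tokens=True))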
# -----------------------------------------------------------------------------
# Streaming Generation Function
# -----------------------------------------------------------------------------
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """
    Generator that streams text. Yields the accumulated output so far each time
    the model produces a new chunk (Gradio replaces the displayed value on
    every yield, so yields must be cumulative).
    """
    # The streamer iterates over decoded text as generate() emits tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Tokenize the prompt and place it on the same device as the ORT model
    # (CPU unless a CUDA execution provider was configured).
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # Generation parameters. Gradio sliders return floats, so the integer
    # parameters are cast explicitly. use_cache=False works around the ONNX
    # export's missing past-key-value support.
    gen_config = GenerationConfig(
        max_length=int(max_length),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        use_cache=False,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Run generation on a background thread so we can consume the streamer here.
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        generation_config=gen_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate chunks and yield the running text as it grows.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
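
# The generator also works outside Gradio. Because yields are cumulative, a
# console client (illustrative sketch) would print only the new suffix:
# shown = ""
# for partial in generate_text_stream("Once upon a time", 64, 0.8, 60, 0.9):
#     print(partial[len(shown):], end="", flush=True)
#     shown = partial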
# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)
demo.launch()
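
# If several users stream at once, Gradio's request queue keeps token delivery
# smooth. Queueing is on by default in recent Gradio releases; on older ones it
# can be enabled explicitly (max_size here is an illustrative value):
# demo.queue(max_size=16).launch()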