import logging
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    GenerationConfig,
    PretrainedConfig,
    TextIteratorStreamer,
)
from optimum.onnxruntime import ORTModelForCausalLM

logging.basicConfig(level=logging.INFO)

# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------------------------------------------------------
# Custom Model Configuration
# -----------------------------------------------------------------------------
class Sam3Config(PretrainedConfig):
    """Configuration for the Sam 3.0-2 model.

    `hidden_size` and `num_attention_heads` mirror `d_model` and `n_heads` so
    that generic `transformers` utilities can read them.
    """

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        self.is_encoder_decoder = False
        self.hidden_size = self.d_model
        self.num_attention_heads = self.n_heads
        self._attn_implementation_internal = "eager"


# Instantiate the custom configuration
model_config = Sam3Config()

# Load the ONNX model by providing the configuration
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
    logging.info("ONNX model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load ONNX model: {e}")
    raise

# -----------------------------------------------------------------------------
# Streaming Generation Function
# -----------------------------------------------------------------------------
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """Generator that streams text, yielding the accumulated output each time
    the model produces a new chunk of tokens."""
    # Create a streamer to iterate over the generated tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Prepare the generation inputs
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Set generation parameters within a GenerationConfig object.
    # use_cache=False is set explicitly to avoid the ONNX export bug.
    gen_config = GenerationConfig(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )

    # Run generation in a background thread so the streamer can be consumed here
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        generation_config=gen_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Gradio replaces the output with each yielded value, so yield the running
    # text rather than only the newest chunk.
    generated = ""
    for new_text in streamer:
        generated += new_text
        yield generated
    thread.join()

# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        gr.Slider(minimum=10, maximum=512, value=128, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)

demo.launch()