sam-api / app.py
Keeby-smilyai's picture
Update app.py
01d2db3 verified
import gradio as gr
import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM
import onnx
import logging
from threading import Thread
logging.basicConfig(level=logging.INFO)

# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
# Marker strings keyed by a short name; each marker is its own name wrapped
# in "<|" and "|>". Insertion order matches the tuple below.
SPECIAL_TOKENS = {
    name: f"<|{name}|>"
    for name in ("bos", "eot", "user", "assistant", "system", "think")
}
# Start from the stock GPT-2 tokenizer and extend it with the markers
# declared in SPECIAL_TOKENS.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # No pad token defined: reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens(
    {"additional_special_tokens": list(SPECIAL_TOKENS.values())}
)

# Short name -> vocabulary id for every marker, in SPECIAL_TOKENS order.
SPECIAL_IDS = dict(
    zip(
        SPECIAL_TOKENS,
        tokenizer.convert_tokens_to_ids(list(SPECIAL_TOKENS.values())),
    )
)

# Torch device used for input tensors: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -----------------------------------------------------------------------------
# Custom Model Configuration
# -----------------------------------------------------------------------------
@dataclass
class Sam3Config(PretrainedConfig):
    """Configuration values for the Sam 3 model.

    Each constructor argument is stored on the instance under the same name.
    ``hidden_size`` and ``num_attention_heads`` are then filled in from
    ``d_model`` and ``n_heads``. Any additional keyword arguments are passed
    through to ``PretrainedConfig.__init__``.
    """

    # Class-level defaults; these annotations are also what @dataclass sees
    # as the field declarations.
    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"
    _attn_implementation_internal: str = "eager"
    is_encoder_decoder: bool = False
    hidden_size: int = 384
    num_attention_heads: int = 6

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        # The base initialiser runs first, so the explicit assignments below
        # are applied after whatever it sets from ``kwargs``.
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version

        # Mirror the model width / head count under these two extra names.
        self.hidden_size = self.d_model
        self.num_attention_heads = self.n_heads
# Instantiate the custom configuration with its default values.
model_config = Sam3Config()

# Load the ONNX model, supplying the configuration explicitly.
# Only the load itself sits in the ``try`` body; the success message lives in
# ``else`` so a logging problem is never reported as a load failure.
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
except Exception as e:
    # logging.exception records the traceback alongside the message; the bare
    # ``raise`` re-raises the original exception unchanged.
    logging.exception("Failed to load ONNX model: %s", e)
    raise
else:
    logging.info("ONNX model loaded successfully.")
# -----------------------------------------------------------------------------
# Streaming Generation Function
# -----------------------------------------------------------------------------
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """Stream sampled text for ``prompt``, yielding the output produced so far.

    ``model.generate`` runs on a background thread and pushes decoded text
    into a ``TextIteratorStreamer``; this generator drains the streamer and,
    after every chunk, yields the accumulated text. Gradio replaces the output
    component's value on each yield, so yielding only the newest chunk would
    leave just the last fragment visible.

    Args:
        prompt: Text the model should continue.
        max_length: Passed as ``GenerationConfig.max_length``. Coerced to
            ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff. Coerced to ``int`` for the same reason
            as ``max_length``.
        top_p: Nucleus-sampling probability threshold.

    Yields:
        str: The text generated so far (prompt and special tokens are
        skipped by the streamer). An empty prompt yields a single empty
        string without invoking the model.
    """
    if not prompt:
        # An empty prompt gives the model no input tokens to continue from;
        # return an empty result instead of failing inside generate().
        yield ""
        return

    # Streamer that the generation thread writes to and this function reads.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Place the inputs on the device the ONNX model actually runs on, which
    # need not match torch's CUDA availability.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # use_cache=False is set explicitly to avoid the ONNX export bug.
    gen_config = GenerationConfig(
        max_length=int(max_length),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )

    # Run generation in the background so text can be consumed as it arrives.
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        generation_config=gen_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate the chunks and yield the running text after each one.
    generated = ""
    for new_text in streamer:
        generated += new_text
        yield generated

    # The streamer is exhausted, so generation is finishing; reap the worker.
    thread.join()
# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
# Web UI / API wrapper around generate_text_stream: one prompt box plus four
# sampling controls, with the streamed text shown in a single text output.
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        # step=1 keeps the integer-valued controls on whole numbers; when no
        # step is given Gradio derives one from the range, which can be
        # fractional (e.g. for the 1..100 Top K range).
        gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, step=1, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)
demo.launch()