import gradio as gr
import torch
from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
from optimum.onnxruntime import ORTModelForCausalLM
import logging
from threading import Thread
logging.basicConfig(level=logging.INFO)
# -----------------------------------------------------------------------------
# Configuration and Special Tokens
# -----------------------------------------------------------------------------
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})
SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}
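# A minimal sketch of how a chat-style prompt could be assembled from these
# special tokens. The exact template Sam-3.0-2 was trained on is not shown
# here, so this ordering (system turn, user turn, then an open assistant
# turn) is an assumption, not the model's confirmed format.
def build_prompt(user_message, system_message="You are a helpful assistant."):
    return (
        f"{SPECIAL_TOKENS['bos']}"
        f"{SPECIAL_TOKENS['system']}{system_message}{SPECIAL_TOKENS['eot']}"
        f"{SPECIAL_TOKENS['user']}{user_message}{SPECIAL_TOKENS['eot']}"
        f"{SPECIAL_TOKENS['assistant']}"
    )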
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# Custom Model Configuration
# -----------------------------------------------------------------------------
class Sam3Config(PretrainedConfig):
    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"
    _attn_implementation_internal: str = "eager"
    is_encoder_decoder: bool = False
    hidden_size: int = 384
    num_attention_heads: int = 6

    def __init__(self, vocab_size=50257, d_model=384, n_layers=10, n_heads=6,
                 ff_mult=4.0, dropout=0.1, input_modality="text",
                 head_type="causal_lm", version="0.1", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        # Mirror the custom fields under the standard transformers attribute
        # names so generic utilities that read hidden_size and
        # num_attention_heads still work with this config.
        self.hidden_size = self.d_model
        self.num_attention_heads = self.n_heads
# Instantiate the custom configuration
model_config = Sam3Config()
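# Optional sanity check (a sketch, not required for serving): subclassing
# PretrainedConfig gives this config dict round-tripping for free, and the
# alias attributes set in __init__ should survive it.
assert Sam3Config.from_dict(model_config.to_dict()).hidden_size == model_config.d_model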
# Load the ONNX model by providing the configuration
try:
    model = ORTModelForCausalLM.from_pretrained(
        "Smilyai-labs/Sam-3.0-2-onnx",
        config=model_config,
        trust_remote_code=True,
    )
    logging.info("ONNX model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load ONNX model: {e}")
    raise
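# Optional smoke test (a minimal sketch; uncomment to verify the model end to
# end before wiring up the UI). It uses the same generate() API the streaming
# path below relies on, just without a streamer. The prompt is arbitrary.
#
# smoke_ids = tokenizer("Hello", return_tensors="pt").input_ids
# smoke_out = model.generate(smoke_ids, max_length=20, use_cache=False)
# logging.info("Smoke test: %s", tokenizer.decode(smoke_out[0], skip_special_tokens=True))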
# -----------------------------------------------------------------------------
# Streaming Generation Function
# -----------------------------------------------------------------------------
def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
    """
    Generator that streams text from the model, yielding the accumulated
    output each time a new chunk of tokens arrives.
    """
    # Create a streamer to iterate over the generated tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Prepare the generation inputs
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Set generation parameters within a GenerationConfig object.
    # use_cache=False works around the ONNX export bug; the Gradio sliders
    # hand us floats, so max_length and top_k are cast back to ints.
    gen_config = GenerationConfig(
        max_length=int(max_length),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
        use_cache=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Run generation in a background thread so the streamer can be consumed here
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        generation_config=gen_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Gradio replaces the displayed output with each yield, so accumulate the
    # chunks and yield the full text so far rather than each fragment alone.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
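# Example of consuming the generator outside Gradio (a sketch; the prompt and
# sampling values are arbitrary). Each yielded value is the full text so far,
# so only the new suffix is printed.
#
# shown = ""
# for so_far in generate_text_stream("Once upon a time", 64, 0.8, 60, 0.9):
#     print(so_far[len(shown):], end="", flush=True)
#     shown = so_far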
# -----------------------------------------------------------------------------
# Gradio Interface
# -----------------------------------------------------------------------------
demo = gr.Interface(
    fn=generate_text_stream,
    inputs=[
        gr.Textbox(label="Prompt", lines=2),
        gr.Slider(minimum=10, maximum=512, value=128, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=60, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
)
demo.launch()