import os
import logging
import time
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logger = logging.getLogger("gradio_chat_001")
logger.setLevel(logging.INFO)
logger.debug("Starting logging for gradio_chat_001.")

categories = [
    "Legal",
    "Specification",
    "Facts and Figures",
    "Publication",
    "Payment Scheme",
    "Alternative Payment Systems",
    "Crypto Payments",
    "Card Payments",
    "Banking",
    "Regulations",
    "Account Payments",
]
logger.debug("Categories to classify: " + repr(categories))

# DESCRIPTION = """\
# # Llama 3.2 3B Instruct
# Llama 3.2 3B is Meta's latest iteration of open LLMs.
# This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
# For more details, please check [our post](https://huggingface.co/blog/llama32).
# """

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# CUDA is deliberately disabled for this demo; everything runs on CPU.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    logger.warning("CUDA is available but intentionally disabled; forcing CPU.")
    USE_CUDA = False
device = torch.device("cuda:0" if USE_CUDA else "cpu")

# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
logger.info("Created model: " + model_id)
logger.info("Model repr: " + repr(model))
logger.info("Tokenizer repr: " + repr(tokenizer))
model.eval()

# Example:
# from transformers import AutoTokenizer, DeepseekV3ForCausalLM
# model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
# tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
# prompt = "Hey, are you conscious? Can you talk to me?"
# inputs = tokenizer(prompt, return_tensors="pt")
# # Generate
# generate_ids = model.generate(inputs.input_ids, max_length=30)
# tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
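# Untested sketch (kept commented out, like the example above): a quick smoke test of the
# chat template that generate()/try_generate() below rely on. The message content is illustrative only.
# conversation = [{"role": "user", "content": "Hello there! How are you doing?"}]
# ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
# logger.info("Prompt shape: " + repr(ids.shape))  # (1, prompt_length); trimmed to MAX_INPUT_TOKEN_LENGTH later
# logger.info(tokenizer.decode(ids[0]))  # shows the special tokens the template inserts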
# @spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a reply for `message` given the prior chat history (plain variant, no error reporting)."""
    conversation = [*chat_history, {"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        logger.warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids, "attention_mask": attention_mask},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread and stream partial text from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def analyse_time_array(arr, extended=False):
    """Summarise a list of per-token timestamps as token count, elapsed time and tokens/s."""
    length = len(arr)
    if length == 0:
        return "Empty"
    if length == 1:
        return "Start"
    start = arr[0]
    end = arr[-1]
    diff = end - start
    rate = (length - 1) / diff if diff > 0 else float("inf")
    msg = f"{length - 1} Tokens in {diff:.2f}s | {rate:.2f} Tokens/s"
    if extended:
        diffs = sorted(arr[i + 1] - arr[i] for i in range(length - 1))
        # msg += "\nDiffs between tokens:"
        msg += "\nBest/shortest: " + ", ".join(f"{x:.02f}s" for x in diffs[:5])
        msg += "\nWorst/longest: " + ", ".join(f"{x:.02f}s" for x in diffs[-5:])
    return msg


SPACER = "\n\n" + "-" * 80 + "\n\n"
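# Untested sketch (commented out): what analyse_time_array() reports for a hand-made list of
# timestamps; the values are made up for illustration.
# fake_times = [0.0, 0.4, 0.9, 1.5, 2.0]  # 5 timestamps -> 4 generated tokens
# print(analyse_time_array(fake_times, extended=True))
# # e.g. "4 Tokens in 2.00s | 2.00 Tokens/s" plus the shortest/longest inter-token gaps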
def try_generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """Like generate(), but each stage is wrapped in try/except so failures and timing info surface in the chat."""
    try:
        logger.info("Create input")
        yield ""
        conversation = [*chat_history, {"role": "user", "content": message}]
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
            logger.warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        yield f"{repr(input_ids)}"
        input_ids = input_ids.to(model.device)
        attention_mask = torch.ones_like(input_ids)
    except Exception as e:
        msg = "Failed to create input parameters: " + repr(e)
        logger.warning(msg)
        yield msg
        return

    try:
        streamer = TextIteratorStreamer(tokenizer, timeout=120.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            {"input_ids": input_ids, "attention_mask": attention_mask},
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_beams=1,
            repetition_penalty=repetition_penalty,
        )
    except Exception as e:
        msg = "Failed to create streamer: " + repr(e)
        logger.warning(msg)
        yield msg
        return

    try:
        yield ""
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
    except Exception as e:
        msg = "Failed to create thread: " + repr(e)
        logger.warning(msg)
        yield msg
        return

    outputs = []
    times = [time.time()]
    try:
        yield ""
        for text in streamer:
            outputs.append(text)
            times.append(time.time())
            # yield "".join(outputs)
            msg = "".join(outputs)
            info = analyse_time_array(times, True)
            yield msg + SPACER + info
    except Exception as e:
        n = len(outputs)
        exp = repr(e)
        error = f"Failed creating output @ position {n}: {exp}"
        logger.warning(error)
        msg = "".join(outputs)
        info = analyse_time_array(times, True)
        yield msg + SPACER + info + SPACER + error
        # yield f"{output}\n--------------------\n{msg}"
        return
    msg = "".join(outputs)
    info = analyse_time_array(times, True)
    yield msg + SPACER + info + "\n--- DONE ---"


demo = gr.ChatInterface(
    fn=try_generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    type="messages",
    # description=DESCRIPTION,
    # css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
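# Untested sketch (commented out): alternative launch options, not part of this demo's config.
# share=True asks Gradio for a temporary public link; server_name/server_port pin the bind address
# when running outside Hugging Face Spaces.
# demo.queue(max_size=20).launch(share=True, server_name="0.0.0.0", server_port=7860)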