import os
import logging
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr

log_level = os.environ.get("LOG_LEVEL", "WARNING")
logging.basicConfig(encoding='utf-8', level=log_level)

logging.info("Loading Model")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)

def format_prompt(message, history):
    """Formats the prompt in phi-2's Instruct/Output style."""
    # Note: `history` is accepted (and logged) but not folded into the prompt.
    logging.info("Formatting Prompt")
    logging.debug("Input Message: %s", message)
    logging.debug("Input History: %s", history)
    prompt = f"Instruct: {message}\n"
    prompt += "Output: "
    return prompt

def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    logging.info("Generating Response")
    logging.debug("Input Prompt: %s", prompt)
    logging.debug("Input History: %s", history)
    logging.debug("Input System Prompt: %s", system_prompt)
    logging.debug("Input Temperature: %s", temperature)
    logging.debug("Input Max New Tokens: %s", max_new_tokens)
    logging.debug("Input Top P: %s", top_p)
    logging.debug("Input Repetition Penalty: %s", repetition_penalty)

    logging.info("Converting Parameters to Correct Type")
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)
    repetition_penalty = float(repetition_penalty)
    logging.debug("Temperature: %s", temperature)
    logging.debug("Top P: %s", top_p)

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    logging.debug("Prompt: %s", formatted_prompt)
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids

    logging.info("Creating Generate kwargs")
    # model.generate() does not return an iterable stream, so tokens are streamed
    # through a TextIteratorStreamer: generation runs in a background thread while
    # the decoded text is yielded to the UI as it arrives.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    logging.debug("Generate Args: %s", generate_kwargs)

    logging.info("Generating Text")
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    logging.info("Creating Output")
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    logging.debug("Output: %s", output)
    return output

additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
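
# Optional local smoke test (not in the original Space code): this sketch calls
# `generate` directly to verify the streaming path without the Gradio UI. The
# SMOKE_TEST flag, prompt text, and parameter values are illustrative assumptions.
if os.environ.get("SMOKE_TEST"):
    final = ""
    for final in generate(
        "Explain nucleus sampling in one sentence.",
        history=[],
        system_prompt="You are a concise assistant",
        temperature=0.7,
        max_new_tokens=64,
    ):
        pass
    print(final)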

examples = []

logging.info("Creating Chat Interface")
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Phi-2 Instruct",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)