from fix_int8 import fix_pytorch_int8
fix_pytorch_int8()


# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py


import torch
import gradio as gr
from threading import Thread
from model import model, tokenizer
from session import db, logger, log_sys_info
from transformers import AutoTokenizer, GenerationConfig, AutoModel


max_length = 224
# Starting prompt, roughly: "You are Kuma, please chat with me, separating each sentence
# with two vertical bars." / "OK, what do you want to talk about?"
default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]

gr_title = """

<h1 align="center">KumaGLM</h1>
<h3 align="center">This is an AI Kuma. You can chat with him, or just press Enter in the text box.</h3>
<p align="center">Sampling range: 2020/06/13 - 2023/04/15</p>
<p align="center">GitHub Repo: <a href="https://github.com/KumaTea/ChatGLM" target="_blank">KumaTea/ChatGLM</a></p>
"""

gr_footer = """
<p align="center">
This project is based on
<a href="https://huggingface.co/spaces/ljsabc/Fujisaki" target="_blank">ljsabc/Fujisaki</a>;
the model is
<a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank">THUDM/chatglm-6b</a>.
</p>
<p align="center"><em>First sentence after waking up every day!</em></p>
""" def evaluate(context, temperature, top_p): generation_config = GenerationConfig( temperature=temperature, top_p=top_p, # top_k=top_k, #repetition_penalty=1.1, num_beams=1, do_sample=True, ) with torch.no_grad(): # input_text = f"Context: {context}Answer: " # input_text = '||'.join(default_start) + '||' # No need for starting prompt in API if not context.endswith('||'): context += '||' # logger.info('[API] Request: ' + context) ids = tokenizer([context], return_tensors="pt") inputs = ids.to("cpu") out = model.generate( **inputs, max_length=max_length, generation_config=generation_config ) out = out.tolist()[0] decoder_output = tokenizer.decode(out) # out_text = decoder_output.split("Answer: ")[1] out_text = decoder_output logger.info('[API] Results: ' + out_text.replace('\n', '


def evaluate_wrapper(context, temperature, top_p):
    db.lock()
    index = db.index
    db.set(index, prompt=context)
    result = evaluate(context, temperature, top_p)
    db.set(index, result=result)
    db.unlock()
    return result


def api_wrapper(context='', temperature=0.5, top_p=0.8, query=0):
    # Simple asynchronous protocol:
    #   * a call with `context` starts generation in a background thread and returns an index
    #   * a later call with `query=<index>` (and an empty context) fetches the cached result
    query = int(query)
    assert context or query
    return_json = {
        'status': '',
        'code': 0,
        'message': '',
        'index': 0,
        'result': ''
    }

    if context:
        if db.islocked():
            logger.info(f'[API] Request: {context}, Status: busy')
            return_json['status'] = 'busy'
            return_json['code'] = 503
            return_json['message'] = '[context] Server is busy, please try again later.'
            return return_json
        else:
            for index in db.prompts:
                if db.prompts[index] == context:
                    return_json['status'] = 'done'
                    return_json['code'] = 200
                    return_json['message'] = '[context] Request cached.'
                    return_json['index'] = index
                    return_json['result'] = db.results[index]
                    return return_json
            # new request
            index = db.index
            t = Thread(target=evaluate_wrapper, args=(context, temperature, top_p))
            t.start()
            logger.info(f'[API] Request: {context}, Status: processing, Index: {index}')
            return_json['status'] = 'processing'
            return_json['code'] = 202
            return_json['message'] = '[context] Request accepted, please check back later.'
            return_json['index'] = index
            return return_json
    else:  # query
        if query in db.prompts and query in db.results:
            logger.info(f'[API] Query: {query}, Status: hit')
            return_json['status'] = 'done'
            return_json['code'] = 200
            return_json['message'] = '[query] Request processed.'
            return_json['index'] = query
            return_json['result'] = db.results[query]
            return return_json
        else:
            if db.islocked():
                logger.info(f'[API] Query: {query}, Status: processing')
                return_json['status'] = 'processing'
                return_json['code'] = 202
                return_json['message'] = '[query] Request in processing, please check back later.'
                return_json['index'] = query
                return return_json
            else:
                logger.info(f'[API] Query: {query}, Status: error')
                return_json['status'] = 'error'
                return_json['code'] = 404
                return_json['message'] = '[query] Index not found.'
                return_json['index'] = query
                return return_json


def evaluate_stream(msg, history, temperature, top_p):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    # note: model.stream_chat below takes temperature/top_p directly, so this config object is currently unused

    if not msg:
        msg = '……'

    history.append([msg, ""])

    context = '||'.join(default_start) + '||'
    if len(history) > 4:
        history.pop(0)
    for j in range(len(history)):
        history[j][0] = history[j][0].replace("<br>", "")
", "") # concatenate context for h in history[:-1]: context += h[0] + "||" + h[1] + "||" context += history[-1][0] + "||" context = context.replace(r'
    # TODO: avoid feeding the model a context that is too long
    # CUTOFF = 224
    while len(tokenizer.encode(context)) > max_length:
        # trim 15 characters from the front at a time to leave room for the answer
        context = context[15:]

    h = []
    logger.info('[UI] Request: ' + context)
    for response, h in model.stream_chat(tokenizer, context, h,
                                         max_length=max_length, top_p=top_p, temperature=temperature):
        history[-1][1] = response
        yield history, ""

    logger.info('[UI] Results: ' + response.replace('\n', '<br>'))


with gr.Blocks() as demo:
    gr.HTML(gr_title)
    # state = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            temp = gr.components.Slider(
                minimum=0, maximum=1.1, value=0.5, label="Temperature",
                info="Higher temperatures give more varied output but may introduce grammatical errors; "
                     "a lower temperature also helps generate more relevant replies.")
            top_p = gr.components.Slider(
                minimum=0.5, maximum=1.0, value=0.8, label="Top-p",
                info="Nucleus sampling: only tokens within the cumulative probability top_p are sampled. "
                     "Larger values give more varied output but may introduce grammatical errors; "
                     "smaller values seem to keep the context more coherent.")
            # code = gr.Textbox(label="temp_output", info="decoder output")
            # top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
            #                              info="The next token is chosen from the top-k candidates. Larger values give "
            #                                   "more varied output but may introduce grammatical errors; smaller values "
            #                                   "seem to keep the context more coherent.")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat")
            msg = gr.Textbox(label="Input", placeholder="最近过得怎么样?",
                             info="Type your message and press [Enter] to send. Sending an empty message often causes errors.")
            clear = gr.Button("Clear chat")

    api_handler = gr.Button("API", visible=False)
    api_index = gr.Number(visible=False)
    api_result = gr.JSON(visible=False)
    info_handler = gr.Button("Info", visible=False)
    info_text = gr.Textbox('System info logged. Check it in the log viewer.', visible=False)

    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    api_handler.click(api_wrapper, [msg, temp, top_p, api_index], api_result, api_name='chat')
    info_handler.click(log_sys_info, None, info_text, api_name='info')
    gr.HTML(gr_footer)

demo.queue()
demo.launch(debug=False)
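
# Client-side sketch (not part of the app itself): the hidden "API" button above is registered
# as the `/chat` endpoint, so the submit-then-poll protocol of `api_wrapper` can be driven with
# gradio_client. The local URL is an assumption; point it at wherever this app is served.
#
#     from gradio_client import Client
#     client = Client("http://127.0.0.1:7860/")
#
#     # 1) submit a prompt; api_wrapper starts generation in a thread and returns an index
#     ticket = client.predict("最近过得怎么样?", 0.5, 0.8, 0, api_name="/chat")
#     # (depending on the Gradio version, the JSON output may arrive as a dict or as a path
#     #  to a .json file; read its 'index' field accordingly)
#
#     # 2) poll with an empty context and that index until 'status' == 'done'
#     result = client.predict("", 0.5, 0.8, ticket["index"], api_name="/chat")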