"""WebSocket server exposing a persona chat model and a shared counter.

Clients send JSON objects with an ``"action"`` key; the server answers with
JSON events whose ``"type"`` key identifies the kind of payload.
"""

import asyncio
import json
import logging

import websockets

from convaimodel_extended import ConvAIModelExtended

logging.basicConfig()

# Counter shared by every connected client ("plus" / "minus" actions).
STATE = {"value": 0}

# Currently connected websocket connections.
USERS = set()

train_args = {}

# Loaded once at import time and shared by all connections; CUDA is disabled.
model = ConvAIModelExtended(
    "gpt2",
    "cahya/gpt2-small-indonesian-personachat",
    args=train_args,
    use_cuda=False,
)


def connection_event():
    """Return the JSON event acknowledging a new connection."""
    return json.dumps({"type": "connection", "value": True})


def state_event():
    """Return the JSON event carrying the current shared counter state."""
    return json.dumps({"type": "state", **STATE})


def dialog_event(message):
    """Return a JSON ``dialog`` event wrapping *message*."""
    return json.dumps({"type": "dialog", "message": message})


def personality_event(message):
    """Return a JSON ``personality`` event wrapping *message*."""
    return json.dumps({"type": "personality", "message": message})


def persona_list_event(message):
    """Return a JSON ``persona_list`` event wrapping *message*."""
    return json.dumps({"type": "persona_list", "message": message})


def personality_reply_event(message):
    """Return a JSON ``personality_reply`` event wrapping *message*."""
    return json.dumps({"type": "personality_reply", "message": message})


def persona_greeting_event(message):
    """Return a JSON ``persona_greeting`` event wrapping *message*."""
    return json.dumps({"type": "persona_greeting", "message": message})


def talk_event(message):
    """Return a JSON ``talk`` event wrapping *message*."""
    return json.dumps({"type": "talk", "message": message})


def users_event():
    """Return the JSON event carrying the number of connected users."""
    return json.dumps({"type": "users", "count": len(USERS)})


def _generation_options(data):
    """Extract the text-generation settings of a "talk" request.

    Missing keys fall back to the server defaults. Raises ``ValueError`` or
    ``TypeError`` when a supplied value cannot be converted.
    """
    return {
        "do_sample": bool(data.get("do_sample", True)),
        "min_length": int(data.get("min_length", 1)),
        "max_length": int(data.get("max_length", 20)),
        "temperature": float(data.get("temperature", 0.7)),
        "top_k": int(data.get("top_k", 0)),
        "top_p": float(data.get("top_p", 0.9)),
    }


async def chatbot(websocket, path=None):
    """Serve one client connection until it closes.

    Registers the connection, sends the initial events, then processes the
    client's JSON requests. A request that is not valid JSON, is not a JSON
    object, or lacks a required field is logged and skipped; it does not
    terminate the connection.

    Args:
        websocket: The connection object handed over by ``websockets``.
        path: Request path passed by older ``websockets`` releases. Unused;
            optional so the handler also works when only the connection is
            passed.
    """
    # 0 means "no dialog yet" for this connection.
    # NOTE(review): new_dialog() presumably returns 0 on failure - confirm.
    dialog_id = 0
    try:
        # Register user
        USERS.add(websocket)
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send current state to user
        await websocket.send(state_event())
        # Manage state changes
        async for message in websocket:
            message = message.strip()
            if not message:
                continue
            try:
                data = json.loads(message)
            except json.JSONDecodeError as error:
                logging.error("invalid JSON message: %s", error)
                continue
            if not isinstance(data, dict):
                logging.error("unsupported event: %s", data)
                continue

            action = data.get("action")
            try:
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                    if dialog_id != 0:
                        await websocket.send(
                            dialog_event("New dialog is created")
                        )
                        persona_list = model.get_persona_list(dialog_id)
                        await websocket.send(persona_list_event(persona_list))
                    else:
                        await websocket.send(
                            dialog_event("Dialog is not created")
                        )
                elif action == "talk":
                    if dialog_id != 0:
                        # NOTE(review): model.talk runs synchronously inside
                        # the event loop, so other clients wait meanwhile.
                        reply = model.talk(
                            dialog_id,
                            persona_id=data["persona_id"],
                            utterance=data["utterance"],
                            **_generation_options(data),
                        )
                        await websocket.send(talk_event(reply))
                elif action == "personality":
                    if dialog_id != 0:
                        model.set_personality(
                            dialog_id,
                            persona_id=data["persona_id"],
                            personality=data["message"],
                        )
                        await websocket.send(
                            personality_reply_event(
                                "Personality has been updated"
                            )
                        )
                elif action == "persona_chosen":
                    if dialog_id != 0:
                        name = ConvAIModelExtended.get_persona_name(
                            dialog_id, data["persona_id"]
                        )
                        greeting = f"Hi, I am {name}. Nice to meet you. Feel free too talk in English or Indonesian."
                        await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except (KeyError, TypeError, ValueError):
                # A missing field or unconvertible value must not tear down
                # the whole connection; log it with the traceback and go on.
                logging.exception("malformed %r request: %s", action, data)
    finally:
        # Unregister user. The nested try/finally guarantees the user is
        # removed and the count re-broadcast even if dialog cleanup fails.
        try:
            if dialog_id != 0:
                ConvAIModelExtended.delete_dialog(dialog_id)
        finally:
            USERS.discard(websocket)
            websockets.broadcast(USERS, users_event())


async def main():
    """Start the WebSocket server on port 8502 (all interfaces) and run forever."""
    async with websockets.serve(chatbot, "0.0.0.0", 8502):
        print("Websocket is running")
        await asyncio.Future()  # run forever


if __name__ == "__main__":
    asyncio.run(main())