# chat-server / app/chat_server.py
# Commit e3f06ac ("first commit") by cahya.
import asyncio
import websockets
import logging
import json
from convaimodel_extended import ConvAIModelExtended
# Install a default root handler so the logging.error(...) calls in this module
# are actually printed.
logging.basicConfig()

# Shared counter, changed by "plus"/"minus" requests and broadcast to every client.
STATE = dict(value=0)
# Currently open client connections; the audience for every broadcast.
USERS = set()

# No extra model arguments are passed.
train_args = dict()
# One model instance shared by all connections (use_cuda=False — presumably CPU-only).
model = ConvAIModelExtended(
    "gpt2",
    "cahya/gpt2-small-indonesian-personachat",
    use_cuda=False,
    args=train_args,
)
def connection_event():
    """Return the JSON text acknowledging a freshly opened connection."""
    payload = {"type": "connection", "value": True}
    return json.dumps(payload)
def state_event():
    """Return the JSON text of a "state" event carrying every entry of STATE."""
    payload = {"type": "state"}
    # Merge STATE after "type" so its keys follow it (and would override it),
    # exactly like a {"type": ..., **STATE} literal.
    payload.update(STATE)
    return json.dumps(payload)
def _message_event(event_type, message):
    """Serialize ``{"type": event_type, "message": message}`` as JSON text.

    Args:
        event_type: value of the ``"type"`` field the client dispatches on.
        message: payload; must be JSON-serializable (``json.dumps`` raises
            ``TypeError`` otherwise).
    """
    return json.dumps({"type": event_type, "message": message})


def dialog_event(message):
    """Return a "dialog" event (used to report whether a dialog was created)."""
    return _message_event("dialog", message)


def personality_event(message):
    """Return a "personality" event."""
    return _message_event("personality", message)


def persona_list_event(message):
    """Return a "persona_list" event (used to send the dialog's persona list)."""
    return _message_event("persona_list", message)


def personality_reply_event(message):
    """Return a "personality_reply" event (used to confirm a personality update)."""
    return _message_event("personality_reply", message)


def persona_greeting_event(message):
    """Return a "persona_greeting" event (used to send a persona's greeting)."""
    return _message_event("persona_greeting", message)


def talk_event(message):
    """Return a "talk" event (used to send the model's reply)."""
    return _message_event("talk", message)
def users_event():
    """Return the JSON text of a "users" event holding the connected-client count."""
    connected = len(USERS)
    return json.dumps({"type": "users", "count": connected})
# Optional sampling parameters of a "talk" request:
# (request key, converter applied to a client-supplied value, default when absent).
_TALK_OPTIONS = (
    ("do_sample", bool, True),
    ("min_length", int, 1),
    ("max_length", int, 20),
    ("temperature", float, 0.7),
    ("top_k", int, 0),
    ("top_p", float, 0.9),
)


def _talk_options(data):
    """Build the generation keyword arguments for a "talk" request.

    Each key listed in ``_TALK_OPTIONS`` is converted when the client sent it
    and otherwise replaced by its default.

    Raises:
        ValueError, TypeError: a supplied value cannot be converted.
    """
    return {
        key: convert(data[key]) if key in data else default
        for key, convert, default in _TALK_OPTIONS
    }


async def chatbot(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    On connect the client is added to ``USERS``. It is sent a connection event
    and the current counter state, and all clients are sent the new user count.
    Each non-empty message is then decoded as a JSON object whose ``"action"``
    selects the request:

    * ``"minus"`` / ``"plus"`` — change the shared counter and broadcast it;
    * ``"get_users"`` — reply with the connected-user count;
    * ``"dialog"`` — create this connection's dialog (once) and send its persona list;
    * ``"talk"`` — generate a reply for ``persona_id`` to ``utterance``;
    * ``"personality"`` — update a persona's personality;
    * ``"persona_chosen"`` — send a greeting from the chosen persona.

    Invalid JSON, non-object payloads and requests with missing or unconvertible
    fields are logged and skipped instead of closing the connection. On
    disconnect the client is removed from ``USERS``, the remaining clients get
    the new user count, and this connection's dialog (if one was created) is
    deleted.

    Args:
        websocket: the client connection.
        path: request path; unused. Optional so the handler also works with
            ``websockets`` versions that pass only the connection.
    """
    dialog_id = 0  # 0 means no dialog has been created for this connection
    try:
        # Register user
        USERS.add(websocket)
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send current state to user
        await websocket.send(state_event())
        # Manage state changes
        async for message in websocket:
            message = message.strip()
            if not message:
                continue
            try:
                data = json.loads(message)
            except json.JSONDecodeError as error:
                logging.error("invalid JSON message: %s", error)
                continue
            if not isinstance(data, dict):
                logging.error("unsupported event: %s", data)
                continue
            action = data.get("action")
            try:
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                    if dialog_id != 0:
                        await websocket.send(dialog_event("New dialog is created"))
                        persona_list = model.get_persona_list(dialog_id)
                        await websocket.send(persona_list_event(persona_list))
                    else:
                        await websocket.send(dialog_event("Dialog is not created"))
                elif action == "talk":
                    if dialog_id != 0:
                        # NOTE(review): model.talk is synchronous, so generation blocks
                        # the event loop (and every other client) until it returns.
                        reply = model.talk(
                            dialog_id,
                            persona_id=data["persona_id"],
                            utterance=data["utterance"],
                            **_talk_options(data),
                        )
                        await websocket.send(talk_event(reply))
                elif action == "personality":
                    if dialog_id != 0:
                        model.set_personality(
                            dialog_id,
                            persona_id=data["persona_id"],
                            personality=data["message"],
                        )
                        await websocket.send(
                            personality_reply_event("Personality has been updated")
                        )
                elif action == "persona_chosen":
                    if dialog_id != 0:
                        # Go through the shared model instance like every other dialog call.
                        name = model.get_persona_name(dialog_id, data["persona_id"])
                        greeting = (
                            f"Hi, I am {name}. Nice to meet you. "
                            "Feel free to talk in English or Indonesian."
                        )
                        await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except (KeyError, TypeError, ValueError):
                # A missing field or unconvertible parameter only spoils this request.
                logging.exception("malformed request: %s", data)
    finally:
        # Unregister the user first, so a failure below cannot leave a dead
        # socket in the broadcast set.
        USERS.discard(websocket)
        try:
            if dialog_id != 0:
                model.delete_dialog(dialog_id)
        finally:
            websockets.broadcast(USERS, users_event())
async def main():
    """Serve `chatbot` on all interfaces, port 8502, until the process is stopped."""
    async with websockets.serve(chatbot, "0.0.0.0", 8502):
        print("Websocket is running")
        # A future that nobody ever resolves keeps the server open indefinitely.
        never_done = asyncio.get_running_loop().create_future()
        await never_done


if __name__ == "__main__":
    asyncio.run(main())