Spaces:
Runtime error
Runtime error
File size: 5,188 Bytes
e3f06ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import asyncio
import websockets
import logging
import json
from convaimodel_extended import ConvAIModelExtended
# Module-level setup shared by every websocket connection.
logging.basicConfig()  # default root-logger configuration
STATE = dict(value=0)  # shared counter changed by the "plus"/"minus" actions
USERS = set()  # websocket connections currently registered
train_args = dict()  # empty dict passed to the model as ``args``
# NOTE(review): use_cuda=False presumably keeps the model on CPU — confirm.
model = ConvAIModelExtended(
    "gpt2",
    "cahya/gpt2-small-indonesian-personachat",
    args=train_args,
    use_cuda=False,
)
def connection_event():
    """Return the JSON frame announcing that the connection was accepted."""
    payload = dict(type="connection", value=True)
    return json.dumps(payload)
def state_event():
    """Return a JSON frame of type "state" carrying every key of STATE."""
    payload = {"type": "state"}
    payload.update(STATE)  # STATE keys follow "type", overriding it if present
    return json.dumps(payload)
def dialog_event(message):
    """Wrap *message* in a JSON frame of type "dialog"."""
    return json.dumps(dict(type="dialog", message=message))
def personality_event(message):
    """Wrap *message* in a JSON frame of type "personality"."""
    frame = {"type": "personality"}
    frame["message"] = message
    return json.dumps(frame)
def persona_list_event(message):
    """Wrap *message* (the persona list) in a JSON frame of type "persona_list"."""
    return json.dumps(dict(type="persona_list", message=message))
def personality_reply_event(message):
    """Wrap *message* in a JSON frame of type "personality_reply"."""
    frame = {"type": "personality_reply"}
    frame["message"] = message
    return json.dumps(frame)
def persona_greeting_event(message):
    """Wrap *message* in a JSON frame of type "persona_greeting"."""
    return json.dumps(dict(type="persona_greeting", message=message))
def talk_event(message):
    """Wrap *message* (a model reply) in a JSON frame of type "talk"."""
    frame = dict(type="talk")
    frame["message"] = message
    return json.dumps(frame)
def users_event():
    """Return a JSON frame of type "users" with the number of registered clients."""
    count = len(USERS)
    return json.dumps(dict(type="users", count=count))
def _talk_options(data):
    """Build the sampling keyword arguments for ``model.talk`` from a request.

    Each option is taken from *data* when the key is present and coerced to
    the expected type; otherwise the default shown here is used.

    Raises:
        ValueError, TypeError: if a supplied option cannot be coerced.
    """
    return {
        "do_sample": bool(data.get("do_sample", True)),
        "min_length": int(data.get("min_length", 1)),
        "max_length": int(data.get("max_length", 20)),
        "temperature": float(data.get("temperature", 0.7)),
        "top_k": int(data.get("top_k", 0)),
        "top_p": float(data.get("top_p", 0.9)),
    }


async def chatbot(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    Registers the client in USERS, sends it a connection frame and the current
    STATE, then handles each incoming JSON message by its "action" key:

    * "minus"/"plus": change the shared counter and broadcast it to everyone.
    * "get_users": send the current client count.
    * "dialog": create a dialog for this connection (once) and send its
      persona list.
    * "talk": generate a reply with the model and send it.
    * "personality": update a persona's personality.
    * "persona_chosen": send a greeting from the chosen persona.

    "talk", "personality" and "persona_chosen" need a dialog and are logged
    and skipped until one exists. Malformed messages are logged and skipped
    rather than closing the connection. On exit the client is unregistered
    and its dialog, if one was created, is deleted.

    Args:
        websocket: the client connection.
        path: request path; unused. Optional so the handler also fits
            websockets versions that call handlers with the connection only.
    """
    dialog_id = 0  # 0 means no dialog has been created for this connection
    try:
        # Register the client and tell everyone the new user count.
        USERS.add(websocket)
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send the current shared state to the new client.
        await websocket.send(state_event())
        async for message in websocket:
            message = message.strip()
            if message == "":
                continue
            try:
                data = json.loads(message)
                action = data["action"]
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                    if dialog_id != 0:
                        await websocket.send(dialog_event("New dialog is created"))
                        persona_list = model.get_persona_list(dialog_id)
                        await websocket.send(persona_list_event(persona_list))
                    else:
                        await websocket.send(dialog_event("Dialog is not created"))
                elif action in ("talk", "personality", "persona_chosen") and dialog_id == 0:
                    # These actions require a dialog; log instead of dropping
                    # the request silently.
                    logging.warning("%s ignored: no dialog has been created yet", action)
                elif action == "talk":
                    # NOTE: model.talk is called synchronously (its result is
                    # sent as-is), so no other client is served while it runs.
                    reply = model.talk(
                        dialog_id,
                        persona_id=data["persona_id"],
                        utterance=data["utterance"],
                        **_talk_options(data),
                    )
                    await websocket.send(talk_event(reply))
                elif action == "personality":
                    model.set_personality(
                        dialog_id,
                        persona_id=data["persona_id"],
                        personality=data["message"],
                    )
                    await websocket.send(personality_reply_event("Personality has been updated"))
                elif action == "persona_chosen":
                    # Call through the instance, like every other model call here.
                    name = model.get_persona_name(dialog_id, data["persona_id"])
                    greeting = (
                        f"Hi, I am {name}. Nice to meet you. "
                        "Feel free to talk in English or Indonesian."
                    )
                    await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except (KeyError, TypeError, ValueError) as error:
                # json.JSONDecodeError subclasses ValueError, so invalid JSON is
                # reported here too; one bad request must not end the session.
                logging.error("cannot handle message %r: %s", message, error)
    finally:
        # Unregister first so a failure while deleting the dialog cannot leave
        # a stale entry in USERS.
        USERS.discard(websocket)
        websockets.broadcast(USERS, users_event())
        if dialog_id != 0:
            # Calling through the instance works whether delete_dialog is an
            # instance method or a static method.
            model.delete_dialog(dialog_id)
async def main():
    """Serve ``chatbot`` on 0.0.0.0:8502 and keep running until cancelled."""
    server = websockets.serve(chatbot, "0.0.0.0", 8502)
    async with server:
        print("Websocket is running")
        never_done = asyncio.Future()  # never resolved, so awaiting it blocks forever
        await never_done
if __name__ == "__main__":
    # Executed as a script (not imported): run the server until interrupted.
    server_coroutine = main()
    asyncio.run(server_coroutine)
|