Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
bf45c7d
1
Parent(s):
4dbdfda
Add user-based session tracking
Browse files
README.md
CHANGED
|
@@ -8,4 +8,4 @@ This project provides a simple async interface to interact with an Ollama model
|
|
| 8 |
python run.py
|
| 9 |
```
|
| 10 |
|
| 11 |
-
The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db
|
|
|
|
| 8 |
python run.py
|
| 9 |
```
|
| 10 |
|
| 11 |
+
The script will ask the model to compute an arithmetic expression and print the answer. Conversations are automatically persisted to `chat.db` and are now associated with a user and session.
|
run.py
CHANGED
|
@@ -6,7 +6,7 @@ from src.chat import ChatSession
|
|
| 6 |
|
| 7 |
|
| 8 |
async def _main() -> None:
|
| 9 |
-
async with ChatSession() as chat:
|
| 10 |
answer = await chat.chat("What is 10 + 23?")
|
| 11 |
print("\n>>>", answer)
|
| 12 |
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
async def _main() -> None:
|
| 9 |
+
async with ChatSession(user="demo_user") as chat:
|
| 10 |
answer = await chat.chat("What is 10 + 23?")
|
| 11 |
print("\n>>>", answer)
|
| 12 |
|
src/chat.py
CHANGED
|
@@ -6,7 +6,7 @@ import json
|
|
| 6 |
from ollama import AsyncClient, ChatResponse
|
| 7 |
|
| 8 |
from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
|
| 9 |
-
from .db import Conversation, Message, _db, init_db
|
| 10 |
from .log import get_logger
|
| 11 |
from .schema import Msg
|
| 12 |
from .tools import add_two_numbers
|
|
@@ -15,10 +15,11 @@ _LOG = get_logger(__name__)
|
|
| 15 |
|
| 16 |
|
| 17 |
class ChatSession:
|
| 18 |
-
def __init__(self, host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
|
| 19 |
init_db()
|
| 20 |
self._client = AsyncClient(host=host)
|
| 21 |
self._model = model
|
|
|
|
| 22 |
|
| 23 |
async def __aenter__(self) -> "ChatSession":
|
| 24 |
return self
|
|
@@ -82,7 +83,7 @@ class ChatSession:
|
|
| 82 |
return response
|
| 83 |
|
| 84 |
async def chat(self, prompt: str) -> str:
|
| 85 |
-
conversation = Conversation.create()
|
| 86 |
Message.create(conversation=conversation, role="user", content=prompt)
|
| 87 |
messages: List[Msg] = [{"role": "user", "content": prompt}]
|
| 88 |
response = await self.ask(messages)
|
|
|
|
| 6 |
from ollama import AsyncClient, ChatResponse
|
| 7 |
|
| 8 |
from .config import MAX_TOOL_CALL_DEPTH, MODEL_NAME, OLLAMA_HOST
|
| 9 |
+
from .db import Conversation, Message, User, _db, init_db
|
| 10 |
from .log import get_logger
|
| 11 |
from .schema import Msg
|
| 12 |
from .tools import add_two_numbers
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class ChatSession:
|
| 18 |
+
def __init__(self, user: str = "default", host: str = OLLAMA_HOST, model: str = MODEL_NAME) -> None:
|
| 19 |
init_db()
|
| 20 |
self._client = AsyncClient(host=host)
|
| 21 |
self._model = model
|
| 22 |
+
self._user, _ = User.get_or_create(username=user)
|
| 23 |
|
| 24 |
async def __aenter__(self) -> "ChatSession":
|
| 25 |
return self
|
|
|
|
| 83 |
return response
|
| 84 |
|
| 85 |
async def chat(self, prompt: str) -> str:
|
| 86 |
+
conversation = Conversation.create(user=self._user)
|
| 87 |
Message.create(conversation=conversation, role="user", content=prompt)
|
| 88 |
messages: List[Msg] = [{"role": "user", "content": prompt}]
|
| 89 |
response = await self.ask(messages)
|
src/db.py
CHANGED
|
@@ -23,8 +23,14 @@ class BaseModel(Model):
|
|
| 23 |
database = _db
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class Conversation(BaseModel):
|
| 27 |
id = AutoField()
|
|
|
|
| 28 |
started_at = DateTimeField(default=datetime.utcnow)
|
| 29 |
|
| 30 |
|
|
@@ -36,11 +42,11 @@ class Message(BaseModel):
|
|
| 36 |
created_at = DateTimeField(default=datetime.utcnow)
|
| 37 |
|
| 38 |
|
| 39 |
-
__all__ = ["_db", "Conversation", "Message"]
|
| 40 |
|
| 41 |
|
| 42 |
def init_db() -> None:
|
| 43 |
"""Initialise the database and create tables if they do not exist."""
|
| 44 |
if _db.is_closed():
|
| 45 |
_db.connect()
|
| 46 |
-
_db.create_tables([Conversation, Message])
|
|
|
|
| 23 |
database = _db
|
| 24 |
|
| 25 |
|
| 26 |
+
class User(BaseModel):
|
| 27 |
+
id = AutoField()
|
| 28 |
+
username = CharField(unique=True)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
class Conversation(BaseModel):
|
| 32 |
id = AutoField()
|
| 33 |
+
user = ForeignKeyField(User, backref="conversations")
|
| 34 |
started_at = DateTimeField(default=datetime.utcnow)
|
| 35 |
|
| 36 |
|
|
|
|
| 42 |
created_at = DateTimeField(default=datetime.utcnow)
|
| 43 |
|
| 44 |
|
| 45 |
+
__all__ = ["_db", "User", "Conversation", "Message"]
|
| 46 |
|
| 47 |
|
| 48 |
def init_db() -> None:
|
| 49 |
"""Initialise the database and create tables if they do not exist."""
|
| 50 |
if _db.is_closed():
|
| 51 |
_db.connect()
|
| 52 |
+
_db.create_tables([User, Conversation, Message])
|