|
|
|
from models.tinygpt2_model import TinyGPT2Model |
|
from pydantic import PrivateAttr |
|
from crewai.tools import BaseTool |
|
|
|
|
|
|
|
class MistralChatTool(BaseTool):
    """CrewAI tool that generates a chat reply with the local model.

    Despite the ``mistral_chat`` name, generation is delegated to a
    ``TinyGPT2Model`` instance created in ``__init__``.
    """

    name: str = "mistral_chat"
    description: str = "Generate an empathetic AI chat response."
    model_config = {"arbitrary_types_allowed": True}

    # Pydantic private attribute: holds the model instance without exposing
    # it as a validated field of the tool.
    _model: TinyGPT2Model = PrivateAttr()

    def __init__(self, config=None):
        """Create the tool and its backing model.

        Args:
            config: Accepted for API compatibility; not currently used.
        """
        super().__init__()
        self._model = TinyGPT2Model()

    def _run(self, prompt: str, context: dict = None):
        """Generate a reply to *prompt*, optionally prefixed with *context*.

        Args:
            prompt: The user's message.
            context: Optional extra context. When truthy, it is rendered
                ahead of the user's message as ``Context: ...\\nUser: ...``.

        Returns:
            Whatever ``TinyGPT2Model.generate`` returns for the built prompt.
        """
        msg = f"Context: {context}\nUser: {prompt}" if context else prompt
        # The model lives in the private attribute ``_model`` set in
        # ``__init__``; there is no public ``model`` attribute on this tool.
        return self._model.generate(msg)
|
|
|
class GenerateAdviceTool(BaseTool):
    """CrewAI tool that generates personalized advice with the local model.

    Generation is delegated to a ``TinyGPT2Model`` instance created in
    ``__init__``.
    """

    name: str = "generate_advice"
    description: str = "Generate personalized advice."
    model_config = {"arbitrary_types_allowed": True}

    # Pydantic private attribute: holds the model instance without exposing
    # it as a validated field of the tool.
    _model: TinyGPT2Model = PrivateAttr()

    def __init__(self, config=None):
        """Create the tool and its backing model.

        Args:
            config: Accepted for API compatibility; not currently used.
        """
        super().__init__()
        self._model = TinyGPT2Model()

    def _run(self, user_analysis: dict, wisdom_quotes: list):
        """Generate advice from a user analysis and a list of quotes.

        Args:
            user_analysis: Analysis of the user; interpolated into the prompt
                via its ``str()`` form.
            wisdom_quotes: Quotes to draw on; interpolated into the prompt
                via their ``str()`` form.

        Returns:
            Whatever ``TinyGPT2Model.generate`` returns for the built prompt,
            requested with ``max_length=300``.
        """
        prompt = f"Advice for: {user_analysis}, with wisdom: {wisdom_quotes}"
        # Read the model from the private attribute assigned in ``__init__``;
        # there is no public ``model`` attribute on this tool.
        return self._model.generate(prompt, max_length=300)
|
|
|
class SummarizeConversationTool(BaseTool):
    """CrewAI tool that summarizes a conversation with the local model.

    Generation is delegated to a ``TinyGPT2Model`` instance created in
    ``__init__``.
    """

    name: str = "summarize_conversation"
    description: str = "Summarize chat with insights and next steps."
    model_config = {"arbitrary_types_allowed": True}

    # Pydantic private attribute: holds the model instance without exposing
    # it as a validated field of the tool.
    _model: TinyGPT2Model = PrivateAttr()

    def __init__(self, config=None):
        """Create the tool and its backing model.

        Args:
            config: Accepted for API compatibility; not currently used.
        """
        super().__init__()
        self._model = TinyGPT2Model()

    def _run(self, conversation: list):
        """Summarize *conversation*.

        Args:
            conversation: The conversation to summarize; interpolated into
                the prompt via its ``str()`` form.

        Returns:
            Whatever ``TinyGPT2Model.generate`` returns for the built prompt,
            requested with ``max_length=200``.
        """
        prompt = f"Summarize: {conversation}"
        # Read the model from the private attribute assigned in ``__init__``;
        # there is no public ``model`` attribute on this tool.
        return self._model.generate(prompt, max_length=200)
|
|
|
class LLMTools:
    """Container that bundles the model-backed CrewAI tools.

    Exposes ``mistral_chat``, ``generate_advice`` and
    ``summarize_conversation`` as instance attributes.

    NOTE(review): each tool's constructor builds its own ``TinyGPT2Model``,
    so one ``LLMTools`` holds three separate model instances.
    """

    def __init__(self, config=None):
        """Build every tool, handing each the same *config*.

        Args:
            config: Forwarded unchanged to each tool's constructor.
        """
        self.mistral_chat = MistralChatTool(config=config)
        self.generate_advice = GenerateAdviceTool(config=config)
        self.summarize_conversation = SummarizeConversationTool(config=config)