# ThinkSquare: src/llm/base_llm_wrapper.py
# Last change: "Fix imports" (commit 2158804) by Falguni
from abc import ABC, abstractmethod
from typing import List
from src.llm.data_models.typeddict_data_models import MultiCommentModel
from src.llm.prompts.prompt_head import (
prompt_head_expert,
prompt_head_jarvis,
prompt_head_novice,
prompt_head_natural,
prompt_head_generic,
)
from src.llm.prompts.prompt_core import prompt_core
class BaseLLMWrapper(ABC):
    """Abstract base for LLM backends that rewrite game comments.

    Subclasses implement :meth:`generate_response`. :meth:`comment` builds the
    prompt (a character-specific head followed by the shared core prompt) and
    delegates it to :meth:`generate_response`.
    """

    def comment(
        self,
        character: str,
        game,
        comment_refs: List[int],
        move_nums: List,
        played_moves: List,
        played_by: List,
        comments: List,
        move_suggestions: List,
        pre_eval_scores: List,
        post_eval_scores: List,
    ) -> MultiCommentModel:
        """Rewrite the comments in a specific tone or style.

        All per-comment lists are parallel: element ``i`` of every list
        describes the same comment.

        Args:
            character (str): The character or style in which to rewrite the
                comments. It is lower-cased and stripped first. "natural",
                "jarvis", "novice" and "expert" select a dedicated prompt
                head. Any other value is substituted into the generic
                prompt head.
            game: The game, inserted into the core prompt as its ``pgn``
                field.
            comment_refs (List[int]): Reference of each comment.
            move_nums (List): Move number each comment corresponds to.
            played_moves (List): The move that was played.
            played_by (List): Who played each move.
            comments (List): The comments to be rewritten.
            move_suggestions (List): Suggested better variation, if any.
            pre_eval_scores (List): Evaluation score before each move.
            post_eval_scores (List): Evaluation score after each move.

        Returns:
            MultiCommentModel: The result of :meth:`generate_response` for
            the built prompt.

        Raises:
            ValueError: If the per-comment lists do not all have the same
                length.
        """
        character = character.lower().strip()

        named_heads = {
            "natural": prompt_head_natural,
            "jarvis": prompt_head_jarvis,
            "novice": prompt_head_novice,
            "expert": prompt_head_expert,
        }
        if character in named_heads:
            prompt_head = named_heads[character]
        else:
            prompt_head = prompt_head_generic.format(character=character)

        # Materialize every input so lengths can be compared and one-shot
        # iterables are still accepted, as they were with a bare ``zip``.
        columns = {
            "comment_refs": list(comment_refs),
            "move_nums": list(move_nums),
            "played_moves": list(played_moves),
            "played_by": list(played_by),
            "comments": list(comments),
            "move_suggestions": list(move_suggestions),
            "pre_eval_scores": list(pre_eval_scores),
            "post_eval_scores": list(post_eval_scores),
        }
        lengths = {name: len(values) for name, values in columns.items()}
        # ``zip`` would silently truncate to the shortest list and drop
        # comments from the prompt; fail loudly instead.
        if len(set(lengths.values())) > 1:
            raise ValueError(
                "comment() expects all per-comment lists to have the same "
                f"length, got {lengths}"
            )

        # Keys of each per-comment entry, in the same order as ``columns``.
        # This order is preserved in each dict and thus in the prompt text.
        entry_keys = (
            "comment_ref",
            "move_num",
            "move",
            "played_by",
            "comment",
            "better_variation",
            "score_before_move",
            "score_after_move",
        )
        comments_info = [
            dict(zip(entry_keys, row)) for row in zip(*columns.values())
        ]

        prompt = (
            prompt_head + "\n" + prompt_core.format(pgn=game, comments=comments_info)
        )
        return self.generate_response(prompt)

    @abstractmethod
    def generate_response(self, prompt: str) -> MultiCommentModel:
        """Generate a response based on the provided prompt.

        Args:
            prompt (str): The fully built prompt to send to the model.

        Returns:
            MultiCommentModel: The model containing the rewritten comments.
        """