# AG4DP-Example-Chatbot / prompt_template.py
# Last change by harpreetsahota: "Update prompt_template.py" (commit 41f63c2, verified)
# File size: 3.8 kB
import re
import string
from dataclasses import dataclass, field
from typing import Dict, List, Set, Optional

import yaml
@dataclass
class PromptTemplate:
    """
    A template class for managing and validating LLM prompts.

    This class handles:
    - Storing system and user prompts
    - Validating required template variables
    - Formatting prompts with provided variables

    Attributes:
        system_prompt (str): The system-level instructions for the LLM
        user_template (str): Template string with variables in {variable}
            format, using ``str.format`` syntax (``{{`` and ``}}`` are
            literal braces, not variables)
        required_variables (Set[str]): Top-level variable names the template
            references; computed once in ``__post_init__``
    """

    system_prompt: str
    user_template: str

    def __post_init__(self):
        """Initialize the set of required variables from the template."""
        # Plain attribute, not a dataclass field, so it does not appear in
        # __init__, __repr__ or __eq__. It is computed once; mutating
        # user_template afterwards does not refresh it.
        self.required_variables: Set[str] = self._get_required_variables()

    def _get_required_variables(self) -> Set[str]:
        """
        Extract the variable names that ``str.format`` will look up.

        Uses ``string.Formatter().parse``, the same parser ``str.format``
        relies on. As a result:
        - escaped braces (``{{literal}}``) are not mistaken for variables;
        - fields with a format spec, conversion, attribute access or index
          (``{x:.2f}``, ``{x!r}``, ``{x.attr}``, ``{x[0]}``) are reduced to
          their top-level name ``x``;
        - replacement fields nested inside a format spec (``{x:>{w}}``) are
          included as well.

        Returns:
            Set[str]: Set of variable names found in the template

        Example:
            Template "Write about {topic} in {style}" returns {'topic', 'style'}
        """
        formatter = string.Formatter()
        names: Set[str] = set()
        # Format specs can themselves contain replacement fields, so parse
        # them too. A stack avoids recursion.
        pending = [self.user_template]
        try:
            while pending:
                for _literal, field_name, format_spec, _conversion in formatter.parse(
                    pending.pop()
                ):
                    if format_spec:
                        pending.append(format_spec)
                    # None marks trailing literal text; "" is an auto-numbered
                    # positional field "{}", which kwargs cannot satisfy anyway.
                    if not field_name:
                        continue
                    # Keep only the root name: "user.name" / "items[0]" -> "user" / "items".
                    root = re.split(r"[.\[]", field_name, maxsplit=1)[0]
                    if root:
                        names.add(root)
        except ValueError:
            # Malformed template (e.g. an unmatched "{"). Don't fail at
            # construction time; fall back to a simple scan and let format()
            # report the problem as a ValueError.
            return set(re.findall(r'\{(\w+)\}', self.user_template))
        return names

    def _validate_variables(self, provided_vars: Dict[str, object]) -> None:
        """
        Ensure all required template variables are provided.

        Extra variables that the template does not use are allowed.

        Args:
            provided_vars: Dictionary of variable names and values

        Raises:
            ValueError: If any required variables are missing
        """
        provided_keys = set(provided_vars)
        missing_vars = self.required_variables - provided_keys
        if missing_vars:
            # Sorted so the message is deterministic (set iteration order is not).
            error_msg = (
                f"\nPrompt Template Error:\n"
                f"Missing required variables: {', '.join(sorted(missing_vars))}\n"
                f"Template requires: {', '.join(sorted(self.required_variables))}\n"
                f"You provided: {', '.join(sorted(provided_keys))}\n"
                f"Template string: '{self.user_template}'"
            )
            raise ValueError(error_msg)

    def format(self, **kwargs) -> List[Dict[str, str]]:
        """
        Format the prompt template with provided variables.

        Args:
            **kwargs: Key-value pairs for template variables

        Returns:
            List[Dict[str, str]]: Two messages, ``system`` then ``user``, each
            a dict with ``role`` and ``content`` keys

        Raises:
            ValueError: If required variables are missing or the template
                cannot be formatted (the underlying error is chained as
                ``__cause__``)

        Example:
            template.format(topic="AI", style="academic")
        """
        self._validate_variables(kwargs)
        try:
            formatted_user_message = self.user_template.format(**kwargs)
        except Exception as e:
            # Broad on purpose. str.format can raise KeyError, IndexError,
            # AttributeError, TypeError or ValueError, and a value's
            # __format__ may raise anything. Callers see a single ValueError;
            # chaining keeps the original traceback.
            raise ValueError(f"Error formatting template: {e}") from e
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": formatted_user_message},
        ]
def load_prompt(yaml_path: str, version: Optional[str] = None) -> tuple[PromptTemplate, dict]:
    """
    Load prompt configuration from YAML file.

    The file must hold a mapping from version names to version entries. It
    may also have a top-level ``current_version`` key naming the default
    entry. Each version entry must contain ``system_prompt`` and
    ``user_template``, and may contain ``generation_params``.

    Args:
        yaml_path: Path to YAML configuration file (read as UTF-8)
        version: Specific version to load (defaults to 'current_version')

    Returns:
        tuple: (PromptTemplate instance, generation parameters dictionary),
        where the dictionary is ``{}`` when the entry defines none

    Raises:
        KeyError: If no version can be determined, the version is not in the
            file, or the entry lacks ``system_prompt``/``user_template``
        ValueError: If the file's top level or the version entry is not a mapping

    Example:
        prompt, params = load_prompt('prompts.yaml', version='v2')
    """
    # Explicit UTF-8: prompts often contain non-ASCII text, and the platform
    # default encoding varies. safe_load avoids constructing arbitrary objects.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # An empty file loads as None; treat it as "no versions defined".
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, "
            f"got {type(data).__name__}"
        )

    # Use specified version or fall back to current_version
    version_to_use = version or data.get('current_version')
    if version_to_use is None:
        raise KeyError(
            f"No version specified and no 'current_version' set in {yaml_path}"
        )
    if version_to_use not in data:
        raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")

    version_data = data[version_to_use]
    if not isinstance(version_data, dict):
        raise ValueError(
            f"Version '{version_to_use}' in {yaml_path} is not a mapping"
        )

    missing = [key for key in ('system_prompt', 'user_template') if key not in version_data]
    if missing:
        raise KeyError(
            f"Version '{version_to_use}' in {yaml_path} is missing: {', '.join(missing)}"
        )

    prompt = PromptTemplate(
        system_prompt=version_data['system_prompt'],
        user_template=version_data['user_template'],
    )
    # A bare "generation_params:" key loads as None; normalize it to {} so
    # callers always receive a dict.
    generation_params = version_data.get('generation_params') or {}
    return prompt, generation_params