|
import re |
|
import yaml |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Set, Optional |
|
|
|
@dataclass
class PromptTemplate:
    """
    A template class for managing and validating LLM prompts.

    This class handles:
    - Storing system and user prompts
    - Validating required template variables
    - Formatting prompts with provided variables

    Attributes:
        system_prompt (str): The system-level instructions for the LLM
        user_template (str): Template string with variables in {variable} format
    """
    system_prompt: str
    user_template: str

    def __post_init__(self):
        """Initialize the set of required variables from the template."""
        self.required_variables: Set[str] = self._get_required_variables()

    def _get_required_variables(self) -> Set[str]:
        """
        Extract required variables from the template using regex.

        Returns:
            Set[str]: Set of variable names found in the template

        Example:
            Template "Write about {topic} in {style}" returns {'topic', 'style'}
        """
        # \w+ matches simple {name} placeholders; format specs like
        # {x:>10} or attribute access {a.b} are intentionally not captured.
        return set(re.findall(r'\{(\w+)\}', self.user_template))

    def _validate_variables(self, provided_vars: Dict) -> None:
        """
        Ensure all required template variables are provided.

        Args:
            provided_vars: Dictionary of variable names and values

        Raises:
            ValueError: If any required variables are missing
        """
        provided_keys = set(provided_vars.keys())
        missing_vars = self.required_variables - provided_keys
        if missing_vars:
            # sorted() gives a deterministic message; raw set iteration
            # order varies between runs, which breaks log comparison/tests.
            error_msg = (
                f"\nPrompt Template Error:\n"
                f"Missing required variables: {', '.join(sorted(missing_vars))}\n"
                f"Template requires: {', '.join(sorted(self.required_variables))}\n"
                f"You provided: {', '.join(sorted(provided_keys))}\n"
                f"Template string: '{self.user_template}'"
            )
            raise ValueError(error_msg)

    def format(self, **kwargs) -> List[Dict[str, str]]:
        """
        Format the prompt template with provided variables.

        Args:
            **kwargs: Key-value pairs for template variables

        Returns:
            List[Dict[str, str]]: Formatted messages ready for LLM API

        Raises:
            ValueError: If required variables are missing or formatting fails

        Example:
            template.format(topic="AI", style="academic")
        """
        self._validate_variables(kwargs)

        try:
            formatted_user_message = self.user_template.format(**kwargs)
        except Exception as e:
            # Broad catch is deliberate: str.format can raise KeyError,
            # IndexError, ValueError, AttributeError, ... depending on the
            # template. Chain with `from e` so the original cause survives.
            raise ValueError(f"Error formatting template: {str(e)}") from e

        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": formatted_user_message}
        ]
|
|
|
|
|
def load_prompt(yaml_path: str, version: Optional[str] = None) -> tuple[PromptTemplate, dict]:
    """
    Load prompt configuration from YAML file.

    Args:
        yaml_path: Path to YAML configuration file
        version: Specific version to load (defaults to 'current_version')

    Returns:
        tuple: (PromptTemplate instance, generation parameters dictionary)

    Raises:
        ValueError: If the YAML file is empty or not a mapping
        KeyError: If the requested version is not present in the file

    Example:
        prompt, params = load_prompt('prompts.yaml', version='v2')
    """
    # Explicit encoding: without it, open() falls back to the locale
    # default, which breaks non-ASCII prompts on some platforms.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty file and may return a scalar or
    # list for malformed configs; fail early with a clear message instead
    # of an AttributeError below.
    if not isinstance(data, dict):
        raise ValueError(f"Invalid or empty prompt config in {yaml_path}: expected a mapping")

    version_to_use = version or data.get('current_version')
    if version_to_use not in data:
        raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")

    version_data = data[version_to_use]

    prompt = PromptTemplate(
        system_prompt=version_data['system_prompt'],
        user_template=version_data['user_template']
    )

    # Generation params (temperature, max_tokens, ...) are optional per version.
    generation_params = version_data.get('generation_params', {})

    return prompt, generation_params