import re
import json
import os
import chevron
from typing import Collection
import copy
import functools
import inspect
from tinytroupe.openai_utils import LLMRequest
from tinytroupe.utils import logger
from tinytroupe.utils.rendering import break_text_at_length

################################################################################
# Model input utilities
################################################################################

def compose_initial_LLM_messages_with_templates(system_template_name: str, user_template_name: str = None,
                                                base_module_folder: str = None,
                                                rendering_configs: dict = {}) -> list:
    """
    Composes the initial messages for the LLM model call, under the assumption that it always involves
    a system message (overall task description) and an optional user message (specific task description).
    These messages are composed using the specified templates and rendering configurations.
    """

    # ../ to go to the base library folder, because that's the most natural reference point for the user
    if base_module_folder is None:
        sub_folder = "../prompts/"
    else:
        sub_folder = f"../{base_module_folder}/prompts/"

    base_template_folder = os.path.join(os.path.dirname(__file__), sub_folder)
    system_prompt_template_path = os.path.join(base_template_folder, system_template_name)

    messages = []

    # render the system message from its template; a context manager ensures the
    # file handle is closed promptly
    with open(system_prompt_template_path) as f:
        messages.append({"role": "system",
                         "content": chevron.render(f.read(), rendering_configs)})

    # optionally add a user message, rendered from its own template
    if user_template_name is not None:
        user_prompt_template_path = os.path.join(base_template_folder, user_template_name)
        with open(user_prompt_template_path) as f:
            messages.append({"role": "user",
                             "content": chevron.render(f.read(), rendering_configs)})

    return messages
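
# Illustrative usage (the template file name and rendering variables below are
# hypothetical; a real call would reference an existing template under
# tinytroupe/prompts/):
#
#   messages = compose_initial_LLM_messages_with_templates(
#       "generate_person.mustache",
#       rendering_configs={"context": "a small rural town"})
#   # -> [{"role": "system", "content": "<rendered template text>"}]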

def llm(**model_overrides):
    """
    Decorator that turns the decorated function into an LLM-based function.
    The decorated function must either return a string (the instruction to the LLM),
    or the parameters of the function will be used instead as the instruction to the LLM.
    The LLM response is coerced to the function's annotated return type, if present.

    Usage example:
    @llm(model="gpt-4-0613", temperature=0.5, max_tokens=100)
    def joke():
        return "Tell me a joke."
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)

            # default the output type to str if the function has no return annotation
            sig = inspect.signature(func)
            return_type = sig.return_annotation if sig.return_annotation != inspect.Signature.empty else str

            # the function's docstring, if any, becomes the system prompt
            system_prompt = func.__doc__.strip() if func.__doc__ else "You are an AI system that executes a computation as requested."

            if isinstance(result, str):
                user_prompt = "EXECUTE THE INSTRUCTIONS BELOW:\n\n " + result
            else:
                # note: only keyword arguments are echoed to the LLM here; positional
                # arguments are consumed by the call above but not forwarded
                user_prompt = f"Execute your function as best as you can using the following parameters: {kwargs}"

            llm_req = LLMRequest(system_prompt=system_prompt,
                                 user_prompt=user_prompt,
                                 output_type=return_type,
                                 **model_overrides)

            return llm_req.call()
        return wrapper
    return decorator
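
# Illustrative example of return-type coercion (hypothetical function; running it
# requires a configured LLM backend, and model parameters are simply forwarded
# to LLMRequest):
#
#   @llm(temperature=0.0)
#   def count_words(text: str) -> int:
#       """You count the number of words in the given text."""
#       return f"How many words are in the following text? {text}"
#
#   # count_words("one two three") would return an int such as 3, because the
#   # annotated return type is passed to LLMRequest as output_type.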

################################################################################
# Model output utilities
################################################################################

def extract_json(text: str) -> dict:
    """
    Extracts a JSON object from a string, ignoring: any text before the first
    opening curly brace; and any Markdown opening (```json) or closing (```) tags.
    """
    try:
        # remove any text before the first opening curly or square bracket, using regex. Leave the brackets.
        text = re.sub(r'^.*?({|\[)', r'\1', text, flags=re.DOTALL)

        # remove any trailing text after the LAST closing curly or square bracket, using regex. Leave the brackets.
        text = re.sub(r'(}|\])(?!.*(\]|\})).*$', r'\1', text, flags=re.DOTALL)

        # remove invalid escape sequences, which show up sometimes. Raw strings are
        # needed here: the pattern r"\\'" matches a literal backslash followed by a
        # quote, whereas "\\'" would match just the quote and do nothing useful.
        text = re.sub(r"\\'", "'", text)  # replace \' with just '
        text = re.sub(r"\\,", ",", text)  # replace \, with just ,

        # use strict=False to correctly parse new lines, tabs, etc.
        parsed = json.loads(text, strict=False)

        # return the parsed JSON object
        return parsed

    except Exception as e:
        logger.error(f"Error occurred while extracting JSON: {e}")
        return {}
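
# Illustrative usage (hypothetical LLM output):
#
#   raw = 'Sure! Here is the JSON:\n```json\n{"name": "Alice", "age": 30}\n```'
#   extract_json(raw)
#   # -> {'name': 'Alice', 'age': 30}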

def extract_code_block(text: str) -> str:
    """
    Extracts a code block from a string, ignoring any text before the first
    opening triple backticks and any text after the closing triple backticks.
    """
    try:
        # remove any text before the first opening triple backticks, using regex. Leave the backticks.
        text = re.sub(r'^.*?(```)', r'\1', text, flags=re.DOTALL)

        # remove any trailing text after the LAST closing triple backticks, using regex. Leave the backticks.
        text = re.sub(r'(```)(?!.*```).*$', r'\1', text, flags=re.DOTALL)

        return text

    except Exception:
        return ""

################################################################################
# Model control utilities
################################################################################

def repeat_on_error(retries: int, exceptions: list):
    """
    Decorator that repeats the specified function call if an exception among those specified occurs,
    up to the specified number of retries. If that number of retries is exceeded, the
    exception is raised. If no exception occurs, the function returns normally.

    Args:
        retries (int): The number of retries to attempt.
        exceptions (list): The list of exception classes to catch.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(retries):
                try:
                    return func(*args, **kwargs)
                except tuple(exceptions) as e:
                    logger.debug(f"Exception occurred: {e}")
                    if i == retries - 1:
                        raise e
                    else:
                        logger.debug(f"Retrying ({i+1}/{retries})...")
                        continue
        return wrapper
    return decorator
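
# Illustrative usage (hypothetical flaky function):
#
#   @repeat_on_error(retries=3, exceptions=[ConnectionError, TimeoutError])
#   def fetch_remote_data():
#       ...  # may raise ConnectionError; the call is retried up to 3 times total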

################################################################################
# Prompt engineering
################################################################################

def add_rai_template_variables_if_enabled(template_variables: dict) -> dict:
    """
    Adds the RAI template variables to the specified dictionary, if the RAI disclaimers are enabled.
    These can be configured in the config.ini file. If enabled, the disclaimers are loaded from the
    appropriate files in the prompts directory and stored in the variables. Otherwise, the variables
    are set to None.

    Args:
        template_variables (dict): The dictionary of template variables to add the RAI variables to.

    Returns:
        dict: The updated dictionary of template variables.
    """

    from tinytroupe import config  # avoids circular import

    rai_harmful_content_prevention = config["Simulation"].getboolean(
        "RAI_HARMFUL_CONTENT_PREVENTION", True
    )
    rai_copyright_infringement_prevention = config["Simulation"].getboolean(
        "RAI_COPYRIGHT_INFRINGEMENT_PREVENTION", True
    )

    # Harmful content
    with open(os.path.join(os.path.dirname(__file__), "prompts/rai_harmful_content_prevention.md"), "r") as f:
        rai_harmful_content_prevention_content = f.read()

    template_variables['rai_harmful_content_prevention'] = rai_harmful_content_prevention_content if rai_harmful_content_prevention else None

    # Copyright infringement
    with open(os.path.join(os.path.dirname(__file__), "prompts/rai_copyright_infringement_prevention.md"), "r") as f:
        rai_copyright_infringement_prevention_content = f.read()

    template_variables['rai_copyright_infringement_prevention'] = rai_copyright_infringement_prevention_content if rai_copyright_infringement_prevention else None

    return template_variables
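
# Illustrative config.ini snippet controlling the flags read above (section and
# key names are taken from this function; the layout of a full TinyTroupe config
# file may differ, and both keys default to True when absent):
#
#   [Simulation]
#   RAI_HARMFUL_CONTENT_PREVENTION = True
#   RAI_COPYRIGHT_INFRINGEMENT_PREVENTION = False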

################################################################################
# Truncation
################################################################################

def truncate_actions_or_stimuli(list_of_actions_or_stimuli: Collection[dict], max_content_length: int) -> Collection[dict]:
    """
    Truncates the content of actions or stimuli at the specified maximum length. Does not modify the original list.

    Args:
        list_of_actions_or_stimuli (Collection[dict]): The list of actions or stimuli to truncate.
        max_content_length (int): The maximum length of the content.

    Returns:
        Collection[dict]: The truncated list of actions or stimuli. It is a new list, not a reference to the original list,
        to avoid unexpected side effects.
    """
    cloned_list = copy.deepcopy(list_of_actions_or_stimuli)

    for element in cloned_list:
        # the external wrapper of the LLM message: {'role': ..., 'content': ...}
        if "content" in element:
            msg_content = element["content"]

            # the actual action or stimulus content: does it have an "action",
            # "stimulus" or "stimuli" key?
            if "action" in msg_content:
                # is content there?
                if "content" in msg_content["action"]:
                    msg_content["action"]["content"] = break_text_at_length(msg_content["action"]["content"], max_content_length)

            elif "stimulus" in msg_content:
                # is content there?
                if "content" in msg_content["stimulus"]:
                    msg_content["stimulus"]["content"] = break_text_at_length(msg_content["stimulus"]["content"], max_content_length)

            elif "stimuli" in msg_content:
                # for each element in the list
                for stimulus in msg_content["stimuli"]:
                    # is content there?
                    if "content" in stimulus:
                        stimulus["content"] = break_text_at_length(stimulus["content"], max_content_length)

    return cloned_list
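
# Illustrative usage (hypothetical message structure; break_text_at_length is
# assumed to cut the text off at the given length):
#
#   messages = [{"role": "user",
#                "content": {"stimulus": {"type": "CONVERSATION",
#                                         "content": "a very long text..."}}}]
#   truncated = truncate_actions_or_stimuli(messages, max_content_length=100)
#   # the original `messages` list is left unmodified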