skellychat / utilities /run_chain.py
jonmatthis's picture
push for posterity
dc2726e
raw
history blame
2.74 kB
import datetime
import re
import sys
from contextlib import redirect_stdout
from io import StringIO

from openai import InvalidRequestError
from openai.error import RateLimitError, AuthenticationError

from config.config import BUG_FOUND_MSG, AUTHORIZATION_ERROR_MESSAGE
# Matches ANSI escape sequences ending in "m" (colour/bold codes such as "\x1b[1m").
_ANSI_ESCAPE_RE = re.compile(r"\x1b[^m]*m")

# Labels that get two leading newlines so each agent step starts its own
# paragraph. Replacements are applied one after another in this order.
_SECTION_LABELS = ("Thought:", "Action:", "Observation:", "Input:", "AI:")


def _format_error(exc):
    """Return the message reported to the user for an exception from ``chain.run``.

    The checks run in the same order an ``except`` chain would, so the first
    matching exception type determines the message.
    """
    if isinstance(exc, AuthenticationError):
        return (
            AUTHORIZATION_ERROR_MESSAGE
            + str(datetime.datetime.now())
            + ". "
            + str(exc)
        )
    if isinstance(exc, RateLimitError):
        return "\n\nRateLimitError: " + str(exc)
    if isinstance(exc, ValueError):
        return "\n\nValueError: " + str(exc)
    if isinstance(exc, InvalidRequestError):
        return "\n\nInvalidRequestError: " + str(exc)
    return "\n\n" + BUG_FOUND_MSG + ":\n\n" + str(exc)


def _clean_hidden_text(raw):
    """Strip terminal noise from captured chain output and space out its sections."""
    text = _ANSI_ESCAPE_RE.sub("", raw)
    # Literal matches: the dots are part of the printed text, not wildcards.
    text = text.replace("Entering new AgentExecutor chain...\n", "")
    text = text.replace("Finished chain.", "")
    for label in _SECTION_LABELS:
        text = text.replace(label, "\n\n" + label)
    return text


def run_chain(chain, inp, capture_hidden_text):
    """Run ``chain`` on ``inp``, turning any ``Exception`` into a readable message.

    Args:
        chain: Object exposing ``run(input=...)``. Presumably a LangChain chain
            or agent executor; confirm at call sites.
        inp: Value passed to ``chain.run`` as the ``input`` keyword.
        capture_hidden_text: When true, everything the chain prints to stdout
            is captured, cleaned up and returned as ``hidden_text``.

    Returns:
        A ``(output, hidden_text)`` tuple.
        - Without capture: ``output`` is the chain's result, or the error
          message on failure, and ``hidden_text`` is ``None``.
        - With capture: ``output`` is the chain's result (``""`` on failure),
          and any error message is appended to ``hidden_text`` instead.

    Exceptions derived from ``Exception`` are never propagated. Anything else
    (e.g. ``KeyboardInterrupt``) propagates, but ``sys.stdout`` is always
    restored first.
    """
    if not capture_hidden_text:
        try:
            output = chain.run(input=inp)
        except Exception as exc:
            output = _format_error(exc)
            if isinstance(exc, AuthenticationError):
                print("output", output)
        return output, None

    output = ""
    error_msg = None
    auth_failed = False
    captured = StringIO()
    # redirect_stdout restores sys.stdout even when chain.run raises something
    # that is not an Exception, unlike a manual save/assign/restore.
    with redirect_stdout(captured):
        try:
            output = chain.run(input=inp)
        except Exception as exc:
            error_msg = _format_error(exc)
            auth_failed = isinstance(exc, AuthenticationError)

    # Log only after stdout is restored. Otherwise this debug line would be
    # captured into the user-visible hidden text and duplicate error_msg.
    if auth_failed:
        print("error_msg", error_msg)

    hidden_text = _clean_hidden_text(captured.getvalue())
    if error_msg:
        hidden_text += error_msg
        print("hidden_text: ", hidden_text)
    return output, hidden_text