# src.kg.save_triples.py
from pathlib import Path
import json
import argparse
import os

from pysbd import Segmenter
from tiktoken import Encoding

from .knowledge_graph import PROMPT_FILE_PATH
from .openai_api import (RESPONSES_DIRECTORY_PATH,
                         get_max_chapter_segment_token_count,
                         get_openai_model_encoding,
                         save_openai_api_response)
from .utils import (execute_function_in_parallel, set_up_logging,
                    strip_and_remove_empty_strings)

logger = set_up_logging('openai-api-scripts.log')


def get_paragraphs(text):
    """Split a text into paragraphs."""
    paragraphs = strip_and_remove_empty_strings(text.split('\n\n'))
    # Convert all whitespace into single spaces.
    paragraphs = [' '.join(paragraph.split()) for paragraph in paragraphs]
    return paragraphs


def combine_text_subunits_into_segments(subunits, join_string,
                                        encoding: Encoding, max_token_count):
    """
    Combine subunits of text into segments that do not exceed a maximum
    number of tokens.
    """
    # `encode_ordinary_batch()` ignores special tokens and is slightly faster
    # than `encode_batch()`.
    subunit_token_counts = [len(tokens) for tokens
                            in encoding.encode_ordinary_batch(subunits)]
    join_string_token_count = len(encoding.encode_ordinary(join_string))
    total_token_count = (sum(subunit_token_counts)
                         + join_string_token_count * (len(subunits) - 1))
    if total_token_count <= max_token_count:
        return [join_string.join(subunits)]
    # Calculate the approximate number of segments and the approximate number
    # of tokens per segment, in order to keep the segment lengths roughly
    # equal.
    approximate_segment_count = total_token_count // max_token_count + 1
    approximate_segment_token_count = round(total_token_count
                                            / approximate_segment_count)
    segments = []
    current_segment_subunits = []
    current_segment_token_count = 0
    for i, (subunit, subunit_token_count) in enumerate(
            zip(subunits, subunit_token_counts)):
        # The token count if the current subunit is added to the current
        # segment.
        extended_segment_token_count = (current_segment_token_count
                                        + join_string_token_count
                                        + subunit_token_count)
        # Add the current subunit to the current segment if it results in a
        # token count that is closer to the approximate segment token count
        # than the current segment token count.
        if (extended_segment_token_count <= max_token_count
                and abs(extended_segment_token_count
                        - approximate_segment_token_count)
                <= abs(current_segment_token_count
                       - approximate_segment_token_count)):
            current_segment_subunits.append(subunit)
            current_segment_token_count = extended_segment_token_count
        else:
            segment = join_string.join(current_segment_subunits)
            segments.append(segment)
            # If it is possible to join the remaining subunits into a single
            # segment, do so. Additionally, add the current subunit as a
            # segment if it is the last subunit.
            if (sum(subunit_token_counts[i:]) + join_string_token_count
                    * (len(subunits) - i - 1) <= max_token_count
                    or i == len(subunits) - 1):
                segment = join_string.join(subunits[i:])
                segments.append(segment)
                break
            current_segment_subunits = [subunit]
            current_segment_token_count = subunit_token_count
    return segments
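

# A minimal sketch (not part of the original pipeline) of how
# `combine_text_subunits_into_segments()` can be exercised on its own. The
# tiktoken encoding name and the token budget below are assumptions chosen
# purely for illustration; the real pipeline obtains its encoding via
# `get_openai_model_encoding()`.
def _demo_combine_paragraphs():
    import tiktoken

    encoding = tiktoken.get_encoding('cl100k_base')
    paragraphs = ['First paragraph.', 'Second paragraph.', 'Third paragraph.']
    # With a small budget the paragraphs are distributed over several
    # segments of roughly equal token counts; with a large enough budget a
    # single joined segment is returned.
    return combine_text_subunits_into_segments(paragraphs, '\n', encoding,
                                               max_token_count=8)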
""" token_counts = [len(tokens) for tokens in encoding.encode_ordinary_batch(sentences)] split_sentences = [] for sentence, token_count in zip(sentences, token_counts): if token_count > max_token_count: words = sentence.split() segments = combine_text_subunits_into_segments( words, ' ', encoding, max_token_count) split_sentences.extend(segments) else: split_sentences.append(sentence) return split_sentences def split_long_paragraphs(paragraphs, encoding: Encoding, max_token_count): """ Given a list of paragraphs, split paragraphs that exceed a maximum number of tokens into multiple segments. """ token_counts = [len(tokens) for tokens in encoding.encode_ordinary_batch(paragraphs)] split_paragraphs = [] for paragraph, token_count in zip(paragraphs, token_counts): if token_count > max_token_count: sentences = Segmenter().segment(paragraph) sentences = split_long_sentences(sentences, encoding, max_token_count) segments = combine_text_subunits_into_segments( sentences, ' ', encoding, max_token_count) split_paragraphs.extend(segments) else: split_paragraphs.append(paragraph) return split_paragraphs def get_chapter_segments(chapter_text, encoding: Encoding, max_token_count): """ Split a chapter text into segments that do not exceed a maximum number of tokens. """ paragraphs = get_paragraphs(chapter_text) paragraphs = split_long_paragraphs(paragraphs, encoding, max_token_count) chapter_segments = combine_text_subunits_into_segments( paragraphs, '\n', encoding, max_token_count) return chapter_segments def get_response_save_path(idx, save_path, project_gutenberg_id, chapter_index = None, chapter_segment_index = None, chapter_segment_count = None): """ Get the path to the JSON file(s) containing response data from the OpenAI API. """ save_path = Path(save_path) os.makedirs(save_path, exist_ok=True) if chapter_index is not None: save_path /= str(chapter_index) if chapter_segment_index is not None: save_path /= (f'{chapter_segment_index + 1}-of-' f'{chapter_segment_count}.json') return save_path def save_openai_api_responses_for_script(script, prompt, encoding, max_chapter_segment_token_count, idx, api_key, model_id): """ Call the OpenAI API for each chapter segment in a script and save the responses to a list. """ project_gutenberg_id = script['id'] chapter_count = len(script['chapters']) logger.info(f'Starting to call OpenAI API and process responses for script ' f'{project_gutenberg_id} ({chapter_count} chapters).') prompt_message_lists = [] response_list = [] for chapter in script['chapters']: chapter_index = chapter['index'] chapter_segments = chapter['text'] chapter_segment_count = len(chapter_segments) for chapter_segment_index, chapter_segment in enumerate(chapter_segments): prompt_with_story = prompt.replace('{STORY}', chapter_segment) prompt_message_lists.append([{ 'role': 'user', 'content': prompt_with_story, 'api_key': api_key, 'model_id': model_id }]) responses = execute_function_in_parallel(save_openai_api_response, prompt_message_lists) for response in responses: response_list.append(response) logger.info(f'Finished processing responses for script {project_gutenberg_id}.') return response_list def save_triples_for_scripts(input_data, idx, api_key, model_id): """ Call the OpenAI API to generate knowledge graph nodes and edges, and store the responses in a list. 
""" # 1) load data script = input_data # 2) call OpenAI API prompt = PROMPT_FILE_PATH.read_text() # load prompt max_chapter_segment_token_count = get_max_chapter_segment_token_count(prompt, model_id) encoding = get_openai_model_encoding(model_id) responses = save_openai_api_responses_for_script( script, prompt, encoding, max_chapter_segment_token_count, idx, api_key, model_id ) return responses