Commit c94b54b
Parent(s): 4e646fa

first commit
- .DS_Store +0 -0
- .gitignore +20 -0
- app.py +543 -0
- requirements.txt +200 -0
- src/fact/narrativefactscore.py +230 -0
- src/fact/openai_api.py +52 -0
- src/fact/prompt.py +28 -0
- src/fact/utils.py +47 -0
- src/kg/__init__.py +0 -0
- src/kg/generate_kg.py +253 -0
- src/kg/knowledge_graph.py +301 -0
- src/kg/main.py +23 -0
- src/kg/openai_api.py +101 -0
- src/kg/preprocess.py +20 -0
- src/kg/save_triples.py +215 -0
- src/kg/utils.py +57 -0
- src/summary/prompt.py +33 -0
- src/summary/summarizer.py +65 -0
- src/summary/utils.py +118 -0
- templates/atomic_fact.txt +5 -0
- templates/external_summary.txt +8 -0
- templates/fact_score.txt +8 -0
- templates/fact_score_kg.txt +9 -0
- templates/self_correction.txt +8 -0
- templates/story-prompt.txt +43 -0
.DS_Store
ADDED
Binary file (6.15 kB).
.gitignore
ADDED
@@ -0,0 +1,20 @@
+# Python
+__pycache__/
+*.py[cod]
+*.pyo
+*.egg
+*.egg-info/
+dist/
+build/
+*.log
+
+.env
+*.env.*
+
+*.pem
+*.key
+
+.cache/
+*.pytest_cache/
+
+.gradio/
app.py
ADDED
@@ -0,0 +1,543 @@
+import gradio as gr
+import json
+from tqdm import tqdm
+import numpy as np
+import random
+import torch
+import ast
+from difflib import HtmlDiff
+
+from src.kg.main import script2kg
+from src.summary.summarizer import Summarizer
+from src.summary.utils import preprocess_script, chunk_script_gpt
+from src.summary.prompt import build_summarizer_prompt
+from src.fact.narrativefactscore import NarrativeFactScore
+
+def _set_seed(seed):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+def parse_scenes(scene_text):
+    try:
+        return json.loads(scene_text)
+    except json.JSONDecodeError:
+        return ast.literal_eval(scene_text)
+
+def set_name_list(dataset, data_type):
+    if dataset == "MovieSum":
+        if data_type == "train":
+            return ['8MM_1999', 'The Iron Lady_2011', 'Adventureland_2009', 'Napoleon_2023',
+                    'Kubo and the Two Strings_2016', 'The Woman King_2022', 'What They Had_2018',
+                    'Synecdoche, New York_2008', 'Black Christmas_2006', 'Superbad_2007']
+        elif data_type == "validation":
+            return ['The Boondock Saints_1999', 'The House with a Clock in Its Walls_2018',
+                    'The Unbelievable Truth_1989', 'Insidious_2010', 'If Beale Street Could Talk_2018',
+                    'The Battle of Shaker Heights_2003', '20th Century Women_2016',
+                    'Captain Phillips_2013', 'Conspiracy Theory_1997', 'Domino_2005']
+        elif data_type == "test":
+            # Return test dataset names (shortened for brevity)
+            return ['A Nightmare on Elm Street 3: Dream Warriors_1987', 'Van Helsing_2004',
+                    'Oppenheimer_2023', 'Armored_2009', 'The Martian_2015']
+    elif dataset == "MENSA":
+        if data_type == "train":
+            return ['The_Ides_of_March_(film)', 'An_American_Werewolf_in_Paris',
+                    'Batman_&_Robin_(film)', 'Airplane_II:_The_Sequel', 'Krull_(film)']
+        elif data_type == "validation":
+            return ['Pleasantville_(film)', 'V_for_Vendetta_(film)',
+                    'Mary_Shelleys_Frankenstein_(film)', 'Rapture_(1965_film)', 'Get_Out']
+        elif data_type == "test":
+            return ['Knives_Out', 'Black_Panther', 'Pet_Sematary_(film)',
+                    'Panic_Room', 'The_Village_(2004_film)']
+    return []
+
+def update_name_list_interface(dataset, data_type):
+    if dataset in ["MovieSum", "MENSA"]:
+        return (
+            gr.update(choices=set_name_list(dataset, data_type), value=None, visible=True),
+            gr.update(visible=False),
+            gr.update(value="")
+        )
+    else:
+        return (
+            gr.update(visible=False),
+            gr.update(visible=True),
+            gr.update(value="Continue to the 'Knowledge Graph' tab")
+        )
+
+def read_data(dataset, data_type):
+    file_path = f"dataset/{dataset}/{data_type}.jsonl"
+    try:
+        with open(file_path, 'r', encoding='utf8') as f:
+            data = [json.loads(line) for line in f]
+            return data
+    except FileNotFoundError:
+        return []
+
+def find_work_index(data, work_name):
+    for idx, entry in enumerate(data):
+        if entry.get("name") == work_name:
+            return idx, entry
+    return None, "Work not found in the selected dataset."
+
+def get_narrative_content(dataset, data_type, work):
+    data = read_data(dataset, data_type)
+    for entry in data:
+        if entry.get("name") == work:
+            return entry['scenes']
+    return "Work not found in the selected dataset."
+
+def get_narrative_content_with_index(dataset, data_type, work):
+    data = read_data(dataset, data_type)
+    for idx, entry in enumerate(data):
+        if entry.get("name") == work:
+            # For MovieSum and MENSA datasets, only return scenes
+            if dataset in ["MovieSum", "MENSA"]:
+                return "\n".join(entry['scenes']), idx, data
+            # For other datasets or custom input, return full content
+            return entry, idx, data
+    return "Work not found in the selected dataset.", None, None
+
+def show_diff(original, revised):
+    d = HtmlDiff()
+    original_lines = original.splitlines(keepends=True)
+    revised_lines = revised.splitlines(keepends=True)
+    diff_table = d.make_table(original_lines, revised_lines, fromdesc='Original Summary', todesc='Refined Summary', context=True, numlines=2)
+    return diff_table
+
+def extract_initial_summary(summary_result):
+    return summary_result['summary_agg']['summaries']
+
+def extract_factuality_score_and_details(fact_score_result):
+    factuality_score = fact_score_result['fact_score']
+    feedback_list = []
+    for i, feedback_data in enumerate(fact_score_result['summary_feedback_pairs']):
+        feedbacks = [fb for fb in feedback_data['feedbacks'] if fb.strip()]
+        if feedbacks:
+            feedback_list.append(f"In chunk {i + 1}: {'; '.join(feedbacks)}")
+    incorrect_details = "\n".join(feedback_list)
+    return factuality_score, incorrect_details
+
+def build_kg(script, idx, api_key, model_id):
+    kg = script2kg(script['scenes'], idx, script['name'], api_key, model_id)
+    return kg
+
+def build_kg_custom(scenes, idx, api_key, model_id):
+    kg = script2kg(scenes, idx, "custom", api_key, model_id)
+    return kg
+
+def build_kg_with_data(data, work_index, custom_scenes, api_key, model_id):
+    if data and work_index is not None:  # Dataset mode
+        script = data[int(work_index)]
+        try:
+            kg = script2kg(script['scenes'], int(work_index), script['name'], api_key, model_id)
+            return kg, "Knowledge Graph built successfully!"
+        except Exception as e:
+            return None, f"Error building knowledge graph: {str(e)}"
+    elif custom_scenes:  # Custom script mode
+        try:
+            scenes = parse_scenes(custom_scenes)
+            if not isinstance(scenes, list):
+                return None, "Invalid format. Please provide scenes as a list."
+            kg = build_kg_custom(scenes, 0, api_key, model_id)
+            return kg, "Knowledge Graph built successfully!"
+        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
+            return None, f"Invalid format. Error: {str(e)}"
+        except Exception as e:
+            return None, f"Error building knowledge graph: {str(e)}"
+    return None, "Please select a work or input custom scenes."
+
+def generate_summary(script, idx, api_key, model_id):
+    _set_seed(42)
+    scripty_summarizer = Summarizer(
+        inference_mode="org",
+        model_id=model_id,
+        api_key=api_key,
+        dtype="float16",
+        seed=42
+    )
+    scenes = [f"s#{i}\n{s}" for i, s in enumerate(script['scenes'])]
+    script = "\n\n".join(scenes)
+    script_chunks = chunk_script_gpt(script=script, model_id=model_id, chunk_size=2048)
+
+    script_summaries = []
+    for chunk in tqdm(script_chunks):
+        chunk = preprocess_script(chunk)
+        prompt = build_summarizer_prompt(
+            prompt_template="./templates/external_summary.txt",
+            input_text_list=[chunk]
+        )
+        script_summ = scripty_summarizer.inference_with_gpt(prompt=prompt)
+        script_summaries.append(script_summ.strip())
+
+    elem_dict_list = []
+    agg_dict = {
+        'script': ' '.join(script_chunks),
+        'summaries': ' '.join(script_summaries)
+    }
+
+    for i, (chunk, summary) in enumerate(zip(script_chunks, script_summaries)):
+        elem_dict = {
+            "chunk_index": i,
+            "chunk": chunk.strip(),
+            "summary": summary.strip()
+        }
+        elem_dict_list.append(elem_dict)
+
+    processed_dataset = {
+        "script": script,
+        "scenes": scenes,
+        "script_chunks": script_chunks,
+        "script_summaries": script_summaries,
+    }
+
+    return {"summary_sep": elem_dict_list, "summary_agg": agg_dict, "processed_dataset": processed_dataset}
+
+def generate_summary_with_data(data, work_index, custom_scenes, api_key, model_id):
+    if data and work_index is not None:  # Dataset mode
+        script = data[int(work_index)]
+        try:
+            summary = generate_summary(script, int(work_index), api_key, model_id)
+            return summary, extract_initial_summary(summary)
+        except Exception as e:
+            return None, f"Error generating summary: {str(e)}"
+    elif custom_scenes:  # Custom script mode
+        try:
+            scenes = parse_scenes(custom_scenes)
+            if not isinstance(scenes, list):
+                return None, "Invalid format. Please provide scenes as a list."
+            script = {"name": "custom", "scenes": scenes}
+            summary = generate_summary(script, 0, api_key, model_id)
+            return summary, extract_initial_summary(summary)
+        except (json.JSONDecodeError, SyntaxError, ValueError) as e:
+            return None, f"Invalid format. Error: {str(e)}"
+        except Exception as e:
+            return None, f"Error generating summary: {str(e)}"
+    return None, "Please select a work or input custom scenes."
+
+def calculate_narrative_fact_score(summary, kg_raw, api_key, model_id):
+    _set_seed(42)
+    factscorer = NarrativeFactScore(split_type='gpt', model='gptscore', api_key=api_key, model_id=model_id)
+
+    summary = summary['processed_dataset']
+    chunks, summaries = summary['script_chunks'], summary['script_summaries']
+    total_output = {'fact_score': 0, 'summary_feedback_pairs': []}
+    partial_output = {'fact_score': 0, 'summary_feedback_pairs': []}
+    total_score = 0
+    kg = []
+    for elem in kg_raw:
+        if elem['subject'] == elem['object']:
+            kg.append(f"{elem['subject']} {elem['predicate']}")
+        else:
+            kg.append(f"{elem['subject']} {elem['predicate']} {elem['object']}")
+
+    scores, scores_per_sent, relevant_scenes, summary_chunks, feedbacks = factscorer.score_src_hyp_long(chunks, summaries, kg)
+    for i, score in enumerate(scores):
+        output_elem = {
+            'src': chunks[i],
+            'summary': summaries[i],
+            'score': score,
+            'scores_per_sent': scores_per_sent[i],
+            'relevant_scenes': relevant_scenes[i],
+            'summary_chunks': summary_chunks[i],
+            'feedbacks': feedbacks[i],
+        }
+        output_elem_part = {
+            'scores_per_sent': scores_per_sent[i],
+            'summary_chunks': summary_chunks[i],
+            'feedbacks': feedbacks[i],
+        }
+        total_output['summary_feedback_pairs'].append(output_elem)
+        partial_output['summary_feedback_pairs'].append(output_elem_part)
+        total_score += score
+
+    total_output['fact_score'] = float(total_score / len(scores))
+    partial_output['fact_score'] = float(total_score / len(scores))
+    return total_output, partial_output
+
+def refine_summary(summary, fact_score, api_key, model_id):
+    _set_seed(42)
+    threshold = 0.9
+    summarizer = Summarizer(
+        inference_mode="org",
+        model_id=model_id,
+        api_key=api_key,
+        dtype="float16",
+        seed=42
+    )
+
+    processed_dataset = {
+        "script": summary["script"],
+        "scenes": summary["scenes"],
+        "script_chunks": [],
+        "script_summaries": []
+    }
+    elem_dict_list = []
+    agg_dict = {}
+
+    for factscore_chunk in tqdm(fact_score['summary_feedback_pairs']):
+        src_chunk = factscore_chunk['src']
+        original_summary = factscore_chunk['summary']
+
+        if factscore_chunk['score'] >= threshold:
+            processed_dataset["script_chunks"].append(src_chunk)
+            processed_dataset["script_summaries"].append(original_summary.strip())
+            continue
+
+        hallu_idxs = np.where(np.array(factscore_chunk['scores_per_sent']) == 0)[0]
+        hallu_summary_parts = np.array(factscore_chunk['summary_chunks'])[hallu_idxs]
+        feedbacks = np.array(factscore_chunk['feedbacks'])[hallu_idxs]
+
+        prompt = build_summarizer_prompt(
+            prompt_template="./templates/self_correction.txt",
+            input_text_list=[src_chunk, original_summary]
+        )
+
+        for j, (hallu_summ, feedback) in enumerate(zip(hallu_summary_parts, feedbacks)):
+            prompt += f"\n- Statement to Revise {j + 1}: {hallu_summ} (Reason for Revision: {feedback})"
+        prompt += "\n- Revised Summary: "
+
+        revised_summary = summarizer.inference_with_gpt(prompt=prompt)
+
+        if len(revised_summary.strip()) == 0:
+            revised_summary = original_summary
+
+        processed_dataset["script_chunks"].append(src_chunk)
+        processed_dataset["script_summaries"].append(revised_summary)
+
+        elem_dict = {
+            "chunk_index": len(processed_dataset["script_chunks"]) - 1,
+            "chunk": src_chunk.strip(),
+            "summary": revised_summary.strip(),
+            "org_summary": original_summary.strip(),
+            "hallu_in_summary": list(hallu_summary_parts),
+            "feedbacks": list(feedbacks),
+        }
+        elem_dict_list.append(elem_dict)
+
+    agg_dict['script'] = summary['script']
+    agg_dict['summaries'] = ' '.join(processed_dataset["script_summaries"])
+
+    return {
+        "summary_sep": elem_dict_list,
+        "summary_agg": agg_dict,
+        "processed_dataset": processed_dataset
+    }
+
+def refine_summary_and_return_diff(summary, fact_score, api_key, model_id):
+    refined_summary = refine_summary(summary['processed_dataset'], fact_score, api_key, model_id)
+    diff = HtmlDiff().make_file(
+        summary['summary_agg']['summaries'].splitlines(),
+        refined_summary['summary_agg']['summaries'].splitlines(),
+        context=True
+    )
+    return diff
+
+def open_kg(kg_data):
+    if kg_data is None:
+        return "Please build the knowledge graph first."
+    try:
+        with open('refined_kg.html', 'r', encoding='utf-8') as f:
+            html_content = f.read()
+        return f'''
+        <iframe
+            srcdoc="{html_content.replace('"', '&quot;')}"
+            style="width: 100%; height: 500px; border: none;"
+        ></iframe>
+        '''
+    except Exception as e:
+        return f'<div style="color: red;">Error reading KG file: {str(e)}</div>'
+
+def format_fact_score_output(fact_score_result):
+    if not fact_score_result:
+        return "No factuality analysis available"
+
+    formatted_output = []
+
+    # Overall score
+    formatted_output.append(f"Overall Factuality Score: {fact_score_result['fact_score']*100:.1f}%\n")
+
+    # Individual chunk analysis
+    for i, chunk in enumerate(fact_score_result['summary_feedback_pairs'], 1):
+        formatted_output.append(f"\nChunk {i} Analysis:")
+        formatted_output.append("Original Text:")
+        formatted_output.append(f"{' '.join(chunk['summary_chunks'])}\n")
+
+        if chunk['feedbacks']:
+            formatted_output.append("Feedback:")
+            feedbacks = [f"• {feedback}" for feedback in chunk['feedbacks'] if feedback.strip()]
+            formatted_output.extend(feedbacks)
+
+        formatted_output.append("-" * 80)
+
+    return "\n".join(formatted_output)
+
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # NarrativeFactScore: Script Factuality Evaluation
+        Evaluate and refine script summaries using narrative factuality scoring.
+        """
+    )
+
+    with gr.Accordion("Model Settings", open=True):
+        with gr.Row():
+            api_key_input = gr.Textbox(
+                label="GPT API Key",
+                placeholder="Enter your GPT API key",
+                type="password",
+                scale=2
+            )
+            model_selector = gr.Dropdown(
+                choices=[
+                    "gpt-4o-mini",
+                    "gpt-4o",
+                    "gpt-4-turbo",
+                    "gpt-3.5-turbo-0125"
+                ],
+                value="gpt-4o",
+                label="Model Selection",
+                scale=1
+            )
+
+    with gr.Tabs():
+        with gr.TabItem("Dataset Selection"):
+            with gr.Row():
+                dataset_selector = gr.Radio(
+                    choices=["MovieSum", "MENSA", "Custom"],
+                    label="Dataset",
+                    info="Choose the dataset or input custom script"
+                )
+                data_type_selector = gr.Radio(
+                    choices=["train", "validation", "test"],
+                    label="Split Type",
+                    info="Select data split",
+                    visible=True
+                )
+            name_list = gr.Dropdown(
+                choices=[],
+                label="Select Script",
+                info="Choose a script to analyze",
+                visible=True
+            )
+            custom_input = gr.Textbox(
+                label="Custom Script Input",
+                info="Enter scenes as a JSON list: ['scene1', 'scene2', ...]",
+                lines=10,
+                visible=False
+            )
+            narrative_output = gr.Textbox(
+                label="Script Content",
+                interactive=False,
+                lines=10
+            )
+
+        with gr.TabItem("Knowledge Graph"):
+            with gr.Row():
+                generate_kg_button = gr.Button(
+                    "Generate Knowledge Graph",
+                    variant="primary"
+                )
+                open_kg_button = gr.Button("View Graph")
+            kg_status = gr.Textbox(
+                label="Status",
+                interactive=False
+            )
+            kg_viewer = gr.HTML(label="Knowledge Graph Visualization")
+
+        with gr.TabItem("Summary Generation"):
+            generate_summary_button = gr.Button(
+                "Generate Initial Summary",
+                variant="primary"
+            )
+            summary_output = gr.Textbox(
+                label="Generated Summary",
+                interactive=False,
+                lines=5
+            )
+            calculate_score_button = gr.Button("Calculate Factuality Score")
+            fact_score_display = gr.Textbox(
+                label="Factuality Analysis",
+                interactive=False,
+                lines=10
+            )
+
+        with gr.TabItem("Summary Refinement"):
+            refine_button = gr.Button(
+                "Refine Summary",
+                variant="primary"
+            )
+            refined_output = gr.HTML(label="Refined Summary with Changes")
+
+    # Hidden states
+    work_index = gr.State()
+    data_state = gr.State()
+    kg_output = gr.State()
+    summary_state = gr.State()
+    fact_score_state = gr.State()
+
+    # Event handlers
+    dataset_selector.change(
+        fn=lambda x: gr.update(visible=x in ["MovieSum", "MENSA"]),
+        inputs=[dataset_selector],
+        outputs=data_type_selector
+    )
+
+    dataset_selector.change(
+        fn=update_name_list_interface,
+        inputs=[dataset_selector, data_type_selector],
+        outputs=[name_list, custom_input, narrative_output]
+    )
+
+    name_list.change(
+        fn=get_narrative_content_with_index,
+        inputs=[dataset_selector, data_type_selector, name_list],
+        outputs=[narrative_output, work_index, data_state]
+    )
+
+    generate_kg_button.click(
+        fn=build_kg_with_data,
+        inputs=[
+            data_state,     # data
+            work_index,     # work_index
+            custom_input,   # custom_scenes
+            api_key_input,  # api_key
+            model_selector  # model_id
+        ],
+        outputs=[kg_output, kg_status]
+    )
+
+    open_kg_button.click(
+        fn=open_kg,
+        inputs=[kg_output],
+        outputs=kg_viewer
+    )
+
+    generate_summary_button.click(
+        fn=generate_summary_with_data,
+        inputs=[data_state, work_index, custom_input, api_key_input, model_selector],
+        outputs=[summary_state, summary_output]
+    )
+
+    calculate_score_button.click(
+        # Score once, then return (full result, formatted text) for the two outputs.
+        fn=lambda summary, kg, api_key, model: (
+            lambda total, partial: (total, format_fact_score_output(total))
+        )(*calculate_narrative_fact_score(summary, kg, api_key, model)),
+        inputs=[summary_state, kg_output, api_key_input, model_selector],
+        outputs=[fact_score_state, fact_score_display]
+    )
+
+    refine_button.click(
+        fn=refine_summary_and_return_diff,
+        inputs=[summary_state, fact_score_state, api_key_input, model_selector],
+        outputs=refined_output
+    )
+
+if __name__ == "__main__":
+    demo.launch()
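The handlers above wrap a plain Python pipeline (build KG, summarize, score, refine), so it can also be driven without the Gradio UI. A minimal sketch, assuming a dataset file such as dataset/MENSA/test.jsonl is present; the API key value is a placeholder:

# Drive the pipeline defined in app.py directly, without the UI.
from app import (read_data, build_kg, generate_summary,
                 calculate_narrative_fact_score, refine_summary)

API_KEY = "sk-..."   # placeholder; supply a real OpenAI key
MODEL = "gpt-4o-mini"

data = read_data("MENSA", "test")          # list of {"name": ..., "scenes": [...]}
script = data[0]

kg = build_kg(script, 0, API_KEY, MODEL)                        # triples used for scoring
summary = generate_summary(script, 0, API_KEY, MODEL)           # chunked GPT summaries
total, partial = calculate_narrative_fact_score(summary, kg, API_KEY, MODEL)
refined = refine_summary(summary["processed_dataset"], total, API_KEY, MODEL)
print(total["fact_score"], refined["summary_agg"]["summaries"][:200])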
requirements.txt
ADDED
@@ -0,0 +1,200 @@
+absl-py==2.1.0
+accelerate==1.1.1
+aiofiles==23.2.1
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.9
+aiosignal==1.3.1
+airportsdata==20241001
+alembic==1.14.0
+annotated-types==0.7.0
+anyio==4.6.2.post1
+asttokens
+async-timeout==5.0.1
+attrs==24.2.0
+banal==1.0.6
+bert-score==0.3.13
+bitsandbytes==0.44.1
+blis==1.0.1
+breadability==0.1.20
+catalogue==2.0.10
+certifi==2024.8.30
+chardet==5.2.0
+charset-normalizer==3.4.0
+click==8.1.7
+cloudpathlib==0.20.0
+cloudpickle==3.1.0
+cmake==3.31.1
+confection==0.1.5
+contourpy==1.3.1
+cycler==0.12.1
+cymem==2.0.10
+dataset==1.6.2
+datasets==2.16.0
+debugpy
+decorator
+dill==0.3.6
+diskcache==5.6.3
+distro==1.9.0
+docopt==0.6.2
+entrypoints
+evaluate==0.4.3
+exceptiongroup
+executing
+fastapi==0.115.6
+ffmpy==0.4.0
+filelock==3.16.1
+FlagEmbedding==1.2.11
+fonttools==4.55.1
+frozenlist==1.5.0
+fsspec==2023.10.0
+gradio==5.8.0
+gradio_client==1.5.1
+greenlet==3.1.1
+h11==0.14.0
+httpcore==1.0.7
+httptools==0.6.4
+httpx==0.28.0
+huggingface-hub==0.26.3
+idna==3.10
+interegular==0.3.3
+ipdb==0.13.13
+jedi
+Jinja2==3.1.4
+jiter==0.8.0
+joblib==1.4.2
+jsonpickle==4.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+kiwisolver==1.4.7
+langcodes==3.5.0
+language_data==1.3.0
+lark==1.2.2
+lm-format-enforcer==0.10.1
+longdocfactscore==1.0.0
+lxml==5.3.0
+Mako==1.3.8
+marisa-trie==1.2.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matching==1.4
+matplotlib==3.9.3
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.1.0
+multiprocess==0.70.14
+murmurhash==1.0.11
+nest_asyncio
+networkx==3.4.2
+ninja==1.11.1.2
+nltk==3.6.2
+numpy==2.0.2
+nvidia-cublas-cu11==11.11.3.6
+nvidia-cuda-cupti-cu11==11.8.87
+nvidia-cuda-nvrtc-cu11==11.8.89
+nvidia-cuda-runtime-cu11==11.8.89
+nvidia-cudnn-cu11==8.7.0.84
+nvidia-cufft-cu11==10.9.0.58
+nvidia-curand-cu11==10.3.0.86
+nvidia-cusolver-cu11==11.4.1.48
+nvidia-cusparse-cu11==11.7.5.86
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu11==2.20.5
+nvidia-nvtx-cu11==11.8.86
+openai==0.28.1
+orjson==3.10.12
+outlines==0.1.7
+outlines_core==0.1.17
+packaging
+pandas==2.2.3
+parso
+peft==0.13.2
+pexpect
+pickleshare
+pillow==11.0.0
+platformdirs
+preshed==3.0.9
+prometheus-fastapi-instrumentator==7.0.0
+prometheus_client==0.21.1
+prompt_toolkit
+propcache==0.2.1
+protobuf==5.29.0
+psutil
+ptyprocess
+pure_eval
+py-cpuinfo==9.0.0
+pyarrow==18.1.0
+pyarrow-hotfix==0.6
+pycountry==24.6.1
+pydantic==2.10.3
+pydantic_core==2.27.1
+pydub==0.25.1
+Pygments
+pyparsing==3.2.0
+pysbd==0.3.4
+python-dateutil
+python-dotenv==1.0.1
+python-multipart==0.0.19
+pytz==2024.2
+pyvis==0.3.2
+PyYAML==6.0.2
+pyzmq
+ray==2.40.0
+referencing==0.35.1
+regex==2024.11.6
+requests==2.32.3
+responses==0.18.0
+rich==13.9.4
+rouge_score==0.1.2
+rpds-py==0.22.1
+ruff==0.8.3
+safehttpx==0.1.6
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+seaborn==0.13.2
+semantic-version==2.10.0
+sentence-transformers==3.3.1
+sentencepiece==0.2.0
+shellingham==1.5.4
+six==1.16.0
+smart-open==7.0.5
+sniffio==1.3.1
+spacy==3.8.2
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+SQLAlchemy==1.4.54
+srsly==2.4.8
+stack_data
+starlette==0.41.3
+sumy==0.11.0
+sympy==1.13.3
+tenacity==9.0.0
+thinc==8.3.2
+threadpoolctl==3.5.0
+tiktoken==0.8.0
+tokenizers==0.20.3
+tomli==2.2.1
+tomlkit==0.13.2
+torch==2.3.0
+tornado
+tqdm==4.67.1
+traitlets
+transformers==4.46.3
+triton==2.3.0
+typer==0.15.0
+typing_extensions
+tzdata==2024.2
+Unidecode==1.2.0
+urllib3==2.2.3
+uvicorn==0.32.1
+uvloop==0.21.0
+wasabi==1.1.3
+watchfiles==1.0.0
+wcwidth
+weasel==0.4.1
+websockets==14.1
+wrapt==1.17.0
+xformers==0.0.26.post1
+xxhash==3.5.0
+yarl==1.18.3
src/fact/narrativefactscore.py
ADDED
@@ -0,0 +1,230 @@
+# Suppress unavoidable warnings from threadpoolctl (see
+# https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md) and the transformers package.
+import warnings
+warnings.filterwarnings("ignore")
+
+import re
+import torch
+import torch.nn as nn
+import traceback
+from transformers import BartTokenizer, BartForConditionalGeneration
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import numpy as np
+from nltk import sent_tokenize
+import logging
+import openai
+from tqdm import tqdm
+from sentence_transformers import SentenceTransformer, util
+from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
+                          Timeout, APIConnectionError, InvalidRequestError)
+from tenacity import (before_sleep_log, retry, retry_if_exception_type,
+                      stop_after_delay, wait_random_exponential, stop_after_attempt)
+from .utils import break_down2scenes
+from .prompt import build_fact_prompt
+from .openai_api import openai_api_response
+
+
+logger = logging.getLogger(__name__)
+
+class OpenAIEmbedding:
+    def __init__(self, api_key, model="text-embedding-3-large"):
+        self.api_key = api_key
+        self.model = model
+        openai.api_key = api_key
+
+    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
+                                          ServiceUnavailableError, APIConnectionError)),
+           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
+           before_sleep=before_sleep_log(logger, logging.WARNING))
+    def encode(self, texts, **kwargs):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        try:
+            response = openai.Embedding.create(
+                model=self.model,
+                input=texts,
+            )
+
+            # Extract embeddings from response
+            embeddings = [item["embedding"] for item in response["data"]]
+            return np.array(embeddings)
+
+        except Exception as e:
+            logger.error(f"Embedding API failed: {str(e)}")
+            return None
+
+class NarrativeFactScore():
+    def __init__(self, model="gpt-4o-mini", split_type="fast", checkpoint=None, api_key=None, model_id="gpt-4"):
+        self.sent_model = OpenAIEmbedding(api_key=api_key)
+        self.split_type = split_type
+        self.checkpoint = checkpoint
+        self.api_key = api_key
+        self.model_id = model_id
+        openai.api_key = api_key
+
+        if model == "gptscore":
+            self.metric = GPTScore(model=self.model_id, api_key=self.api_key)
+            self.metric_function = self.metric.gpt_score
+        else:
+            raise ValueError("NarrativeFactScore currently only supports GPTScore")
+
+    def get_surrounding_sentences(self, sentence_array, ii):
+        if ii > 0 and ii < len(sentence_array) - 1:
+            sents = " ".join(np.array(sentence_array)[ii - 1 : ii + 1])
+        elif ii == 0:
+            sents = " ".join(np.array(sentence_array)[:2])
+        elif ii == len(sentence_array) - 1:
+            sents = " ".join(np.array(sentence_array)[ii - 1 :])
+        return sents
+
+    def group_into_sections(self, sentence_array, num_sent):
+        sectioned_sents = []
+        for ii in range(0, len(sentence_array), num_sent):
+            # Join each group of `num_sent` sentences into one section.
+            sectioned_sents.append(" ".join(sentence_array[ii : ii + num_sent]))
+        return sectioned_sents
+
+    def split_sent(self, text):
+        text_list = []
+        if self.split_type == "fast":
+            for t in text.split('.'):
+                if len(t) == 0:
+                    continue
+                text_list.append(t)
+            return text_list
+        elif self.split_type == "fast_comma":
+            for t in re.split(r'[.,]', text):
+                if len(t) == 0:
+                    continue
+                text_list.append(t)
+            return text_list
+        elif self.split_type == "gpt":
+            prompt = build_fact_prompt(
+                prompt_template='./templates/atomic_fact.txt',
+                input_text_list=[text],
+            )
+            response = openai_api_response(prompt, model=self.model_id, api_key=self.api_key)
+            text_list = []
+            for res in response.split('\n'):
+                text_list.append(res.strip())
+            return text_list
+        else:
+            return None
+
+    def score_src_hyp_long(self, srcs, hyps, kgs):
+        all_scores = []
+        all_scores_per_sent = []
+        all_relevant_scenes = []
+        all_summary_chunks = []
+        all_feedback_list = []
+        # srcs is a list containing source documents.
+        # hyps is a list containing predicted documents.
+        total_score = 0
+        for global_idx, (src, hyp) in enumerate(zip(tqdm(srcs), hyps)):
+            src_sents = break_down2scenes(src)
+            # Get embeddings using OpenAI API
+            sentence_embeddings_src = self.sent_model.encode(src_sents)
+            sentence_embeddings_kg = self.sent_model.encode(kgs)
+
+            doc_scores = []
+            relevant_scenes = []
+            feedbacks = []
+            hyp_array = self.split_sent(hyp)
+            for idx, hyp_sentence in enumerate(hyp_array):
+                # Get embedding for hypothesis sentence
+                sentence_embeddings_hyp = self.sent_model.encode(hyp_sentence)
+
+                # Calculate cosine similarity
+                scores = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_src)[0]
+                scores_kg = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_kg)[0]
+
+                sorted_idxs = np.argsort(-1 * scores)  # descending order
+                sorted_idxs_kg = np.argsort(-1 * scores_kg)  # descending order
+                similar_src_sentences = []
+                similar_src_sentences_kg = []
+                triple = ''
+
+                for sorted_idx, ii in enumerate(sorted_idxs_kg[0:1]):
+                    if sorted_idx == 0:
+                        triple += f'{kgs[ii]}'
+                    else:
+                        triple += f', {kgs[ii]}'
+                for ii in sorted_idxs[0:1]:
+                    similar_sents = src_sents[ii]
+                    similar_src_sentences.append(similar_sents)
+
+                scores, feedback_list = self.metric_function(similar_src_sentences, [hyp_sentence for i in range(0, len(similar_src_sentences))], triple)
+                score = np.max(scores)
+                max_scene_idx = np.argmax(scores)
+                max_scene = similar_src_sentences[max_scene_idx]
+                feedback = feedback_list[max_scene_idx]
+
+                doc_scores.append(int(score))
+                relevant_scenes.append(max_scene)
+                feedbacks.append(feedback)
+
+            doc_score = np.mean(doc_scores)
+            all_scores_per_sent.append(doc_scores)
+            all_scores.append(doc_score)
+            all_relevant_scenes.append(relevant_scenes)
+            all_summary_chunks.append(hyp_array)
+            all_feedback_list.append(feedbacks)
+            total_score += doc_score
+            if global_idx % 100 == 99:
+                print(f"Mean score after {global_idx+1} documents: {total_score/(global_idx+1):.3f}")
+        return all_scores, all_scores_per_sent, all_relevant_scenes, all_summary_chunks, all_feedback_list
+
+class GPTScore():
+    def __init__(self, model="gpt-4o", api_key=None, prompt='./templates/fact_score_kg.txt'):
+        self.max_length = 1024
+        self.model = model
+        self.api_key = api_key
+        self.prompt = prompt
+        openai.api_key = api_key
+
+    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
+                                          ServiceUnavailableError, APIConnectionError, InvalidRequestError)),
+           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
+           before_sleep=before_sleep_log(logger, logging.WARNING))
+    def gpt_inference(self, prompt):
+        prompt_messages = [{"role": "user", "content": prompt}]
+        try:
+            response = openai.ChatCompletion.create(
+                model=self.model,
+                messages=prompt_messages,
+                temperature=0,
+                api_key=self.api_key
+            )
+            response = response.choices[0].message.content
+        except InvalidRequestError:
+            # Treat invalid requests as factual by default ('1' is the positive label).
+            response = '1'
+        return response
+
+    def gpt_score(self, srcs, tgts, kgs, batch_size=4):
+        score_list = []
+        feedback_list = []
+
+        for i in range(len(srcs)):
+            src = srcs[i]
+            tgt = tgts[i]
+
+            prompt = build_fact_prompt(
+                prompt_template=self.prompt,
+                input_text_list=[src, kgs, tgt],
+            )
+
+            try:
+                score = self.gpt_inference(prompt)
+                if '1' in score:
+                    score_list.append(float(1))
+                    feedback_list.append('')
+                else:
+                    score_list.append(float(0))
+                    feedback_list.append(score)
+
+            except RuntimeError:
+                traceback.print_exc()
+                print(f"source: {srcs}")
+                print(f"target: {tgts}")
+                exit(0)
+
+        return score_list, feedback_list
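GPTScore treats scoring as a binary check: any model answer containing '1' counts as factually consistent, and anything else is kept verbatim as feedback. A minimal sketch of a single call, assuming a valid API key; the scene text and triple below are made up for illustration:

from src.fact.narrativefactscore import GPTScore

scorer = GPTScore(model="gpt-4o-mini", api_key="sk-...")  # placeholder key
scores, feedbacks = scorer.gpt_score(
    srcs=["s#1\nJOHN enters the barn and hides the ledger."],
    tgts=["John burns the ledger in the barn."],
    kgs="John hides ledger",
)
# scores -> [1.0] if the model answered "1", otherwise [0.0] with the
# model's explanation returned in feedbacks[0].
print(scores, feedbacks)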
src/fact/openai_api.py
ADDED
@@ -0,0 +1,52 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import openai
+from dotenv import load_dotenv
+from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
+                          Timeout, APIConnectionError, InvalidRequestError)
+from tenacity import (before_sleep_log, retry, retry_if_exception_type,
+                      stop_after_delay, wait_random_exponential, stop_after_attempt)
+from tiktoken import Encoding, encoding_for_model
+
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+# This value is set by OpenAI for the selected model and cannot be changed.
+MAX_MODEL_TOKEN_COUNT = 4096
+# This value can be changed.
+MAX_RESPONSE_TOKEN_COUNT = 512
+
+
+@retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
+                                      ServiceUnavailableError, APIConnectionError, InvalidRequestError)),
+       wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
+       before_sleep=before_sleep_log(logger, logging.WARNING))
+def openai_api_response(prompt, model, api_key, save_path=None):
+    """
+    Use a prompt to make a request to the OpenAI API and save the response to a
+    JSON file.
+    """
+    openai.api_key = api_key
+    try:
+        prompt_messages = [{"role": "user", "content": prompt}]
+        response = openai.ChatCompletion.create(
+            model=model, messages=prompt_messages, temperature=0)
+        finish_reason = response.choices[0].finish_reason
+        if finish_reason != 'stop':
+            logger.error(f'`finish_reason` is `{finish_reason}` for {save_path}.')
+        save_data = {'model': response.model, 'usage': response.usage,
+                     'finish_reason': finish_reason,
+                     'prompt_messages': prompt_messages,
+                     'response': response.choices[0].message.content}
+    except InvalidRequestError:
+        logger.error(f'InvalidRequestError encountered 10 times. Returning empty string for {save_path}.')
+        save_data = {'model': None, 'usage': None,
+                     'finish_reason': 'invalid_request',
+                     'prompt_messages': prompt_messages,
+                     'response': ' '}
+    return save_data['response']
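Note that the helper returns only the response text; the save_data dict is assembled but never written, and save_path appears only in log messages. A minimal usage sketch, assuming a valid key:

from src.fact.openai_api import openai_api_response

text = openai_api_response(
    prompt="List three facts about the Moon.",
    model="gpt-4o-mini",
    api_key="sk-...",  # placeholder key
)
print(text)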
src/fact/prompt.py
ADDED
@@ -0,0 +1,28 @@
+import os
+from typing import List, Optional
+
+
+def build_fact_prompt(
+    prompt_template:str,
+    input_text_list:List[str],
+    chat_mode:Optional[str] = None) -> str:
+
+    """Format a prompt template with the given inputs.
+
+    prompt_template: a path to a template file, or a literal template string.
+    chat_mode: 'hf-chat', 'kullm', or None (currently unused).
+
+    Returns:
+        str: the template with each positional {} slot filled in order.
+    """
+
+    if os.path.isfile(prompt_template):
+        with open(prompt_template,'r') as f:
+            prompt_template = f.read()
+
+    assert isinstance(prompt_template, str)
+
+    prompt = prompt_template.format(*input_text_list)
+
+    return prompt
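build_fact_prompt reads the template from a file when the path exists, and otherwise treats the string itself as the template. A minimal sketch with an inline template; the template text here is illustrative, not one of the shipped files:

from src.fact.prompt import build_fact_prompt

template = "Scene:\n{}\n\nTriples: {}\n\nStatement: {}\nAnswer 1 if supported, otherwise explain why not."
prompt = build_fact_prompt(template, ["JOHN hides the ledger.", "John hides ledger", "John burns the ledger."])
print(prompt)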
src/fact/utils.py
ADDED
@@ -0,0 +1,47 @@
+import re
+
+def delete_special(pre_text, character_list):
+    for c in character_list:
+        pre_text = pre_text.replace(c, "")
+    return pre_text
+
+def break_down2scenes(text: str):
+    # Split the text based on the 's#' pattern
+    scenes = re.split(r'(s#\d+)', text)
+
+    # Remove empty elements from the split results
+    scenes = [scene for scene in scenes if scene.strip()]
+
+    scenes_list = []
+    current_scene_number = None
+
+    for i in range(0, len(scenes), 2):  # Process the 's#' marker and its corresponding text as pairs
+        scene_marker = scenes[i].strip()
+        try:
+            scene_number = int(scene_marker.split('#')[1])  # Extract the number part
+        except:
+            if len(scenes) % 2 == 1:
+                return [scenes[0]]
+            import ipdb;ipdb.set_trace(context=10)
+        scene_text = scenes[i+1].strip() if i+1 < len(scenes) else ""
+
+        # Check if the scene numbers are in the correct sequence
+        if current_scene_number is not None:
+            expected_scene_number = current_scene_number + 1
+            if scene_number != expected_scene_number:
+                raise ValueError(f"Unexpected scene number: {scene_number}, expected {expected_scene_number}")
+
+        # Store the scene number and its corresponding text together
+        scenes_list.append({
+            'detected_scene_number': scene_number,
+            'text': f"{scene_marker}\n{scene_text}".strip()
+        })
+
+    filtered_scene_list = []
+    scene_number = 0
+    for scene_dict in scenes_list:
+        detected_scene_number = int(scene_dict['detected_scene_number'])
+        filtered_scene_list.append(scene_dict['text'])
+        scene_number = detected_scene_number
+
+    return filtered_scene_list
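break_down2scenes splits a chunk on its s#N markers and keeps each marker with its text, which is how score_src_hyp_long recovers per-scene units from a source chunk. A quick sketch of the expected behavior:

from src.fact.utils import break_down2scenes

chunk = "s#0\nINT. BARN - NIGHT\nJohn hides the ledger.\ns#1\nEXT. FIELD - DAY\nMary searches."
print(break_down2scenes(chunk))
# -> ['s#0\nINT. BARN - NIGHT\nJohn hides the ledger.', 's#1\nEXT. FIELD - DAY\nMary searches.']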
src/kg/__init__.py
ADDED
File without changes
src/kg/generate_kg.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
# src.kg.generate_kg.py
|
2 |
+
import pickle
|
3 |
+
from collections import defaultdict, Counter
|
4 |
+
from contextlib import redirect_stdout
|
5 |
+
from pathlib import Path
|
6 |
+
import json
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import openai
|
10 |
+
import time
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import networkx as nx
|
14 |
+
from pyvis.network import Network
|
15 |
+
from tqdm import tqdm
|
16 |
+
from contextlib import redirect_stdout
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
from .knowledge_graph import generate_knowledge_graph
|
21 |
+
from .openai_api import load_response_text
|
22 |
+
from .save_triples import get_response_save_path
|
23 |
+
from .utils import set_up_logging
|
24 |
+
|
25 |
+
logger = set_up_logging('generate-knowledge-graphs-books.log')
|
26 |
+
KNOWLEDGE_GRAPHS_DIRECTORY_PATH = Path('../knowledge-graphs_new')
|
27 |
+
|
28 |
+
|
29 |
+
"""def gpt_inference(system_instruction, prompt, retries=10, delay=5):
|
30 |
+
# api
|
31 |
+
messages = [{"role": "system", "content": system_instruction},
|
32 |
+
{"role": "user", "content": prompt}]
|
33 |
+
|
34 |
+
for attempt in range(retries):
|
35 |
+
try:
|
36 |
+
response = openai.ChatCompletion.create(
|
37 |
+
model='gpt-4o-mini-2024-07-18',
|
38 |
+
messages=messages,
|
39 |
+
temperature=0.0,
|
40 |
+
max_tokens=128,
|
41 |
+
top_p=0.5,
|
42 |
+
frequency_penalty=0,
|
43 |
+
presence_penalty=0
|
44 |
+
)
|
45 |
+
result = response['choices'][0]['message']['content']
|
46 |
+
return result
|
47 |
+
except openai.error.APIError as e:
|
48 |
+
|
49 |
+
time.sleep(delay)
|
50 |
+
continue"""
|
51 |
+
|
52 |
+
|
53 |
+
def generate_knowledge_graph_for_scripts(book, idx, save_path):
|
54 |
+
"""
|
55 |
+
Use the responses from the OpenAI API to generate a knowledge graph for a
|
56 |
+
book.
|
57 |
+
"""
|
58 |
+
response_texts = defaultdict(list)
|
59 |
+
project_gutenberg_id = book['id']
|
60 |
+
for chapter in book['chapters']:
|
61 |
+
chapter_index = chapter['index']
|
62 |
+
chapter_responses_directory = get_response_save_path(
|
63 |
+
idx, save_path, project_gutenberg_id, chapter_index)
|
64 |
+
for response_path in chapter_responses_directory.glob('*.json'):
|
65 |
+
response_text = load_response_text(response_path)
|
66 |
+
response_texts[chapter_index].append(response_text)
|
67 |
+
knowledge_graph = generate_knowledge_graph(response_texts, project_gutenberg_id)
|
68 |
+
return knowledge_graph
|
69 |
+
|
70 |
+
def generate_knowledge_graph_for_scripts(book, idx, response_list):
|
71 |
+
"""
|
72 |
+
Use the responses from the OpenAI API to generate a knowledge graph for a
|
73 |
+
book.
|
74 |
+
"""
|
75 |
+
|
76 |
+
response_texts = defaultdict(list)
|
77 |
+
project_gutenberg_id = book['id']
|
78 |
+
for chapter in book['chapters']:
|
79 |
+
chapter_index = chapter['index']
|
80 |
+
for response in response_list:
|
81 |
+
response_texts[chapter_index].append(response['response'])
|
82 |
+
knowledge_graph = generate_knowledge_graph(response_texts, project_gutenberg_id)
|
83 |
+
return knowledge_graph
|
84 |
+
|
85 |
+
|
86 |
+
def save_knowledge_graph(knowledge_graph,
|
87 |
+
project_gutenberg_id, save_path):
|
88 |
+
"""Save a knowledge graph to a `pickle` file."""
|
89 |
+
save_path = save_path / 'kg.pkl'
|
90 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
91 |
+
with open(save_path, 'wb') as knowledge_graph_file:
|
92 |
+
pickle.dump(knowledge_graph, knowledge_graph_file)
|
93 |
+
|
94 |
+
|
95 |
+
def load_knowledge_graph(project_gutenberg_id, save_path):
|
96 |
+
"""Load a knowledge graph from a `pickle` file."""
|
97 |
+
save_path = save_path / 'kg.pkl'
|
98 |
+
with open(save_path, 'rb') as knowledge_graph_file:
|
99 |
+
knowledge_graph = pickle.load(knowledge_graph_file)
|
100 |
+
return knowledge_graph
|
101 |
+
|
102 |
+
|
103 |
+
def display_knowledge_graph(knowledge_graph, save_path):
|
104 |
+
"""Display a knowledge graph using pyvis."""
|
105 |
+
# Convert the knowledge graph into a format that can be displayed by pyvis.
|
106 |
+
# Merge all edges with the same subject and object into a single edge.
|
107 |
+
pyvis_graph = nx.MultiDiGraph()
|
108 |
+
for node in knowledge_graph.nodes:
|
109 |
+
pyvis_graph.add_node(str(node), label='\n'.join(node.names),
|
110 |
+
shape='box')
|
111 |
+
for edge in knowledge_graph.edges(data=True):
|
112 |
+
subject = str(edge[0])
|
113 |
+
object_ = str(edge[1])
|
114 |
+
predicate = edge[2]['predicate']
|
115 |
+
chapter_index = edge[2]['chapter_index']
|
116 |
+
if pyvis_graph.has_edge(subject, object_):
|
117 |
+
pyvis_graph[subject][object_][0].update(
|
118 |
+
title=(f'{pyvis_graph[subject][object_][0]["title"]}\n'
|
119 |
+
f'{predicate}')) # f'{predicate} ({chapter_index})'))
|
120 |
+
else:
|
121 |
+
pyvis_graph.add_edge(subject, object_,
|
122 |
+
title=f'{predicate}') # title=f'{predicate} ({chapter_index})')
|
123 |
+
network = Network(height='99vh', directed=True, bgcolor='#262626',
|
124 |
+
cdn_resources='remote')
|
125 |
+
network.set_options('''
|
126 |
+
const options = {
|
127 |
+
"interaction": {
|
128 |
+
"tooltipDelay": 0
|
129 |
+
},
|
130 |
+
"physics": {
|
131 |
+
"forceAtlas2Based": {
|
132 |
+
"gravitationalConstant": -50,
|
133 |
+
"centralGravity": 0.01,
|
134 |
+
"springLength": 100,
|
135 |
+
"springConstant": 0.08,
|
136 |
+
"damping": 0.4,
|
137 |
+
"avoidOverlap": 0
|
138 |
+
},
|
139 |
+
"solver": "forceAtlas2Based"
|
140 |
+
}
|
141 |
+
}''')
|
142 |
+
network.from_nx(pyvis_graph)
|
143 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
144 |
+
# `show()` tries to print the name of the HTML file to the console, so
|
145 |
+
# suppress it.
|
146 |
+
with redirect_stdout(None):
|
147 |
+
network.show(str(save_path), notebook=False)
|
148 |
+
logger.info(f'Saved pyvis knowledge graph to {save_path}.')
|
149 |
+
|
150 |
+
def fuse_subject(subjects):
|
151 |
+
subject_list = subjects.split('/')
|
152 |
+
if len(subject_list) == 1:
|
153 |
+
return subject_list[0]
|
154 |
+
flag = 0
|
155 |
+
striped_subject_list = []
|
156 |
+
len_list = []
|
157 |
+
for subject in subject_list:
|
158 |
+
striped_subject_list.append(subject.strip())
|
159 |
+
len_list.append(len(subject))
|
160 |
+
idx = np.argmin(len_list)
|
161 |
+
for subject in striped_subject_list:
|
162 |
+
if striped_subject_list[idx] in subject:
|
163 |
+
flag += 1
|
164 |
+
|
165 |
+
if flag == len(striped_subject_list):
|
166 |
+
return striped_subject_list[idx]
|
167 |
+
else:
|
168 |
+
return subjects
|
169 |
+
|
170 |
+
def init_kg(script, idx, response_list):
|
171 |
+
"""
|
172 |
+
Generate knowledge graphs for book in the books dataset using saved
|
173 |
+
responses from the OpenAI API.
|
174 |
+
"""
|
175 |
+
knowledge_graph = generate_knowledge_graph_for_scripts(script, idx, response_list)
|
176 |
+
return knowledge_graph
|
177 |
+
|
178 |
+
def refine_kg(knowledge_graph, idx, topk):
|
179 |
+
result = []
|
180 |
+
edge_count = Counter()
|
181 |
+
for edge in knowledge_graph.edges(data=True):
|
182 |
+
subject = str(edge[0])
|
183 |
+
object_ = str(edge[1])
|
184 |
+
edge_count[subject] += 1
|
185 |
+
edge_count[object_] += 1
|
186 |
+
|
187 |
+
# 엣지가 많은 상위 k개의 노드 선택
|
188 |
+
top_k_nodes = [node for node, count in edge_count.most_common(topk)]
|
189 |
+
|
190 |
+
# 상위 k개 노드 간의 모든 관계를 수집
|
191 |
+
rel_dict = {}
|
192 |
+
for edge in knowledge_graph.edges(data=True):
|
193 |
+
subject = str(edge[0])
|
194 |
+
object_ = str(edge[1])
|
195 |
+
if subject in top_k_nodes and object_ in top_k_nodes:
|
196 |
+
predicate = edge[2]['predicate']
|
197 |
+
chapter_index = edge[2]['chapter_index']
|
198 |
+
count = edge[2]['count']
|
199 |
+
key = f"{subject}\t{object_}"
|
200 |
+
if key not in rel_dict:
|
201 |
+
rel_dict[key] = []
|
202 |
+
rel_dict[key].append((predicate, chapter_index, count))
|
203 |
+
|
204 |
+
# 시각화 코드
|
205 |
+
pyvis_graph = nx.MultiDiGraph()
|
206 |
+
for node in top_k_nodes:
|
207 |
+
pyvis_graph.add_node(node, label=node, shape='box')
|
208 |
+
|
209 |
+
for key, relations in rel_dict.items():
|
210 |
+
subject, object_ = key.split('\t')
|
211 |
+
for relation in relations:
|
212 |
+
predicate, chapter_index, count = relation
|
213 |
+
if 'output' in predicate:
|
214 |
+
continue
|
215 |
+
if count >= 2:
|
216 |
+
if pyvis_graph.has_edge(subject, object_):
|
217 |
+
pyvis_graph[subject][object_][0]['title'] += f', {predicate}'
|
218 |
+
else:
|
219 |
+
pyvis_graph.add_edge(subject, object_, title=f'{predicate}')
|
220 |
+
|
221 |
+
network = Network(height='99vh', directed=True, bgcolor='#262626', cdn_resources='remote')
|
222 |
+
network.from_nx(pyvis_graph)
|
223 |
+
|
224 |
+
with redirect_stdout(None):
|
225 |
+
network.show('refined_kg.html', notebook=False)
|
226 |
+
|
227 |
+
for key, relations in rel_dict.items():
|
228 |
+
subject, object_ = key.split('\t')
|
229 |
+
|
230 |
+
for relation in relations:
|
231 |
+
predicate, chapter_index, count = relation
|
232 |
+
|
233 |
+
if 'output' in predicate:
|
234 |
+
continue
|
235 |
+
|
236 |
+
subject = fuse_subject(subject)
|
237 |
+
object_ = fuse_subject(object_)
|
238 |
+
|
239 |
+
relationship = {
|
240 |
+
'subject': subject,
|
241 |
+
'predicate': predicate,
|
242 |
+
'object': object_,
|
243 |
+
'chapter_index': chapter_index,
|
244 |
+
'count': count,
|
245 |
+
'subject_node_count': edge_count[subject],
|
246 |
+
'object_node_count': edge_count[object_]
|
247 |
+
}
|
248 |
+
|
249 |
+
if count >= 2:
|
250 |
+
result.append(relationship)
|
251 |
+
|
252 |
+
return result
|
253 |
+
|
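For illustration, `fuse_subject` collapses a slash-joined alias string to its shortest alias only when that alias occurs in every other alias; the inputs below are hypothetical, not values from this commit:

from src.kg.generate_kg import fuse_subject

fuse_subject('Jo March / Jo')  # -> 'Jo' (shortest alias occurs in every alias)
fuse_subject('Meg / Beth')     # -> 'Meg / Beth' (no shared alias, returned unchanged)
fuse_subject('Amy')            # -> 'Amy' (single names pass through)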
src/kg/knowledge_graph.py
ADDED
@@ -0,0 +1,301 @@
# src.kg.knowledge_graph.py
import itertools
import logging
import re
from collections import defaultdict
from itertools import combinations, product
from pathlib import Path

import networkx as nx

from .utils import strip_and_remove_empty_strings

logger = logging.getLogger(__name__)

PROMPT_FILE_PATH = Path('templates/story-prompt.txt')
MAX_RESPONSE_EDGE_COUNT = 15
MAX_PREDICATE_WORD_COUNT = 5
MAX_POSSESSION_WORD_COUNT = 2
MAX_MERGEABLE_NODE_EDGE_COUNT = 2
MIN_NODE_EDGE_COUNT = 1


class NamedEntity:
    """A knowledge graph node representing a named entity."""

    def __init__(self, names):
        self.names = names

    def __repr__(self):
        return ' / '.join(self.names)


def remove_number_prefix(text):
    """Remove a leading "<number>. " prefix from a text."""
    clean_text = re.sub(r'^\d+\.\s*', '', text)
    return clean_text


def parse_response_text(response_text, identifier, are_edges_numbered=True):
    """
    Parse a response text from the OpenAI model into names (a list of names
    for each entity) and edges (relations between entities). `identifier` is
    a string used to identify the response text in error messages.
    """
    lines = strip_and_remove_empty_strings(response_text.split('\n'))
    if 'Named entities' not in lines[0]:
        logger.error(f'{identifier}: First line of response text does not '
                     f'start with "Named entities:". ("{lines[0]}")')
        return [], []
    mode = 'names'
    names = []
    edges = []
    for line in lines[1:]:
        if 'Knowledge graph edges' in line:
            mode = 'edges'
            continue
        if mode == 'names':
            if line.startswith('-'):
                line = line[1:]
            name_group = strip_and_remove_empty_strings(line.split(' / '))
            name_group = [remove_number_prefix(name) for name in name_group]
            names.append(name_group)
        elif mode == 'edges':
            if are_edges_numbered:
                if not re.match(r'^\d{1,2}\. ', line):
                    break
                if int(line.split('.')[0]) > MAX_RESPONSE_EDGE_COUNT:
                    break
                line = line[3:]
            edge_components = strip_and_remove_empty_strings(line.split(';'))
            if len(edge_components) not in (2, 3):
                continue
            subjects = strip_and_remove_empty_strings(
                edge_components[0].split(','))
            predicate = edge_components[1]
            if len(edge_components) == 3:
                objects = strip_and_remove_empty_strings(
                    edge_components[2].split(','))
            else:
                objects = [None]
            for subject, object_ in product(subjects, objects):
                edge = (subject, predicate, object_)
                edges.append(edge)
    if not names:
        logger.error(f'{identifier}: No names were parsed from the response '
                     f'text.')
    if not edges:
        logger.error(f'{identifier}: No edges were parsed from the response '
                     f'text.')

    return names, edges


def generate_names_graph(names):
    """
    Generate a graph of names where the nodes are names and the edges
    indicate that two names refer to the same entity.
    """
    names_graph = nx.Graph()
    for name_group in names:
        for name in name_group:
            names_graph.add_node(name)
        for name_pair in combinations(name_group, 2):
            names_graph.add_edge(*name_pair)
    return names_graph


def expand_contracted_possessive(predicate, names):
    """
    Check if a predicate is of the form "<owner>'s <possession>", where the
    owner is a named entity. If so, return a predicate of the form
    "<possession> of" and an object of the form "<owner>".
    """
    match = re.search(
        fr'\'s\s\w+(?:\s\w+)'
        fr'{{0,{MAX_POSSESSION_WORD_COUNT - 1}}}$', predicate)
    if not match:
        return predicate, None
    apostrophe_index = match.start()
    owner = next(
        (name for name in names
         if predicate[:apostrophe_index].endswith(name)), None)
    if owner is None:
        return predicate, None
    possession = predicate[apostrophe_index + 2:].strip()
    predicate = (f'{predicate[:apostrophe_index - len(owner)].strip()} '
                 f'{possession} of')
    object_ = owner
    return predicate, object_


def does_duplicate_edge_exist(knowledge_graph, subject, predicate, object_):
    """
    Check if an edge with a given subject, predicate, and object already
    exists in a knowledge graph. If it exists, return the edge data;
    otherwise, return None.
    """
    for edge in knowledge_graph.edges(subject, data=True):
        if edge[1] == object_ and edge[2]['predicate'] == predicate:
            return edge
    return None


def add_edge_to_knowledge_graph(knowledge_graph, names, edge,
                                max_predicate_word_count, **edge_attributes):
    """Add an edge to a knowledge graph, updating the count if the edge
    already exists."""
    subject, predicate, object_ = edge
    if subject not in names:
        return
    if object_ is not None and object_ not in names:
        predicate += f' {object_}'
        object_ = None
    if object_ is None:
        object_at_end_of_predicate = next(
            (name for name in names if predicate.endswith(' ' + name)), None)
        if object_at_end_of_predicate is not None:
            object_ = object_at_end_of_predicate
            predicate = predicate[:-len(object_)].strip()
        else:
            predicate, object_ = expand_contracted_possessive(predicate, names)
    while predicate.endswith(('.', ',', '!', '?')):
        predicate = predicate[:-1]
    if (max_predicate_word_count
            and len(predicate.split()) > max_predicate_word_count):
        return
    if subject == object_:
        return
    if object_ is None:
        object_ = subject
    subject_node = next((node for node in knowledge_graph.nodes
                         if subject in node.names), None)
    object_node = next((node for node in knowledge_graph.nodes
                        if object_ in node.names), None)

    if subject_node is None or object_node is None:
        return

    existing_edge = does_duplicate_edge_exist(knowledge_graph, subject_node,
                                              predicate, object_node)
    if existing_edge:
        existing_edge[2]['count'] += 1
    else:
        knowledge_graph.add_edge(subject_node, object_node,
                                 predicate=predicate, count=1,
                                 **edge_attributes)


def initialize_knowledge_graph(names_graph, edges):
    """
    Initialize a knowledge graph from a graph of names and a dictionary of
    edges grouped by chapter index.
    """
    names = set(names_graph.nodes)
    knowledge_graph = nx.MultiDiGraph()
    for name in names:
        knowledge_graph.add_node(NamedEntity({name}))
    for chapter_index, chapter_edges in edges.items():
        for edge in chapter_edges:
            add_edge_to_knowledge_graph(
                knowledge_graph, names, edge,
                max_predicate_word_count=MAX_PREDICATE_WORD_COUNT,
                chapter_index=chapter_index)
    return knowledge_graph


def get_node_edge_count(knowledge_graph, node):
    """
    Get the number of edges for a node in a knowledge graph, excluding
    self-loops.
    """
    edges = (set(knowledge_graph.in_edges(node))
             | set(knowledge_graph.out_edges(node)))
    edge_count = sum(1 for edge in edges if edge[0] is not edge[1])
    return edge_count


def merge_nodes(knowledge_graph, nodes_to_merge):
    """
    Merge a list of nodes in a knowledge graph into one node, combining
    their sets of names and preserving their edges.
    """
    merged_node = NamedEntity(set())
    for node in nodes_to_merge:
        merged_node.names.update(node.names)
    knowledge_graph.add_node(merged_node)
    for node in nodes_to_merge:
        for edge in itertools.chain(knowledge_graph.out_edges(node, data=True),
                                    knowledge_graph.in_edges(node, data=True)):
            subject, object_, attributes = edge
            if (does_duplicate_edge_exist(knowledge_graph, merged_node,
                                          attributes['predicate'], object_)
                    or does_duplicate_edge_exist(knowledge_graph, subject,
                                                 attributes['predicate'],
                                                 merged_node)):
                continue
            if subject is object_:
                knowledge_graph.add_edge(merged_node, merged_node,
                                         **attributes)
            elif subject is node:
                knowledge_graph.add_edge(merged_node, object_, **attributes)
            else:
                knowledge_graph.add_edge(subject, merged_node, **attributes)
        knowledge_graph.remove_node(node)


def merge_same_entity_nodes(knowledge_graph, names_graph):
    """
    Using a graph of names, merge nodes in a knowledge graph corresponding
    to the same entity.
    """
    for name_pair in names_graph.edges:
        first_node = next((node for node in knowledge_graph.nodes
                           if name_pair[0] in node.names), None)
        if first_node is None:
            continue
        if name_pair[1] in first_node.names:
            continue
        second_node = next((node for node in knowledge_graph.nodes
                            if name_pair[1] in node.names), None)
        if second_node is None:
            continue
        if knowledge_graph.has_edge(first_node, second_node):
            continue
        first_node_edge_count = get_node_edge_count(knowledge_graph,
                                                    first_node)
        second_node_edge_count = get_node_edge_count(knowledge_graph,
                                                     second_node)
        if (first_node_edge_count > MAX_MERGEABLE_NODE_EDGE_COUNT
                and second_node_edge_count > MAX_MERGEABLE_NODE_EDGE_COUNT):
            continue
        merge_nodes(knowledge_graph, [first_node, second_node])


def remove_nodes_with_few_edges(knowledge_graph):
    """
    Remove nodes that have fewer than `MIN_NODE_EDGE_COUNT` edges (excluding
    self-loops) from a knowledge graph. Repeat until no more nodes are
    removed.
    """
    while True:
        nodes_to_remove = []
        for node in knowledge_graph.nodes:
            edge_count = get_node_edge_count(knowledge_graph, node)
            if edge_count < MIN_NODE_EDGE_COUNT:
                nodes_to_remove.append(node)
        if not nodes_to_remove:
            break
        knowledge_graph.remove_nodes_from(nodes_to_remove)


def generate_knowledge_graph(response_texts, project_gutenberg_index):
    """
    Use OpenAI API response texts grouped by chapter index to generate a
    knowledge graph for a book.
    """
    names = []
    edges = defaultdict(list)
    for chapter_index, chapter_response_texts in response_texts.items():
        for response_text in chapter_response_texts:
            identifier = (f'Book {project_gutenberg_index}, chapter '
                          f'{chapter_index}')
            chapter_segment_names, chapter_segment_edges = parse_response_text(
                response_text, identifier)
            names.extend(chapter_segment_names)
            edges[chapter_index].extend(chapter_segment_edges)
    names_graph = generate_names_graph(names)
    knowledge_graph = initialize_knowledge_graph(names_graph, edges)
    merge_same_entity_nodes(knowledge_graph, names_graph)
    remove_nodes_with_few_edges(knowledge_graph)
    return knowledge_graph
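A minimal sketch of `parse_response_text` on a hypothetical response that follows the format requested by templates/story-prompt.txt:

from src.kg.knowledge_graph import parse_response_text

response_text = '''Named entities:
Jo / Jo March
Meg
Knowledge graph edges:
1. Jo; sister of; Meg
2. Jo; fifteen years'''

names, edges = parse_response_text(response_text, identifier='demo')
# names -> [['Jo', 'Jo March'], ['Meg']]
# edges -> [('Jo', 'sister of', 'Meg'), ('Jo', 'fifteen years', None)]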
src/kg/main.py
ADDED
@@ -0,0 +1,23 @@
# src.kg.main.py
import os
import json
from pathlib import Path

from .preprocess import preprocess
from .save_triples import save_triples_for_scripts
from .generate_kg import init_kg, refine_kg


def script2kg(scene_list, idx, name, api_key, model_id):
    # 1) Preprocess the script.
    preprocessed_script = preprocess(scene_list, idx)

    # 2) Extract triples.
    triple_list = save_triples_for_scripts(preprocessed_script, idx, api_key, model_id)

    # 3) Build the knowledge graph.
    kg = init_kg(preprocessed_script, idx, triple_list)

    # 4) Refine the knowledge graph.
    kg = refine_kg(kg, idx, topk=10)

    return kg
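A minimal usage sketch for `script2kg`; the scene texts, API key, and model name below are placeholders, not values from this commit:

from src.kg.main import script2kg

scenes = ['s#1\nJo and Meg talk by the fire.', 's#2\nAmy arrives with a letter.']
kg = script2kg(scenes, idx=0, name='demo-script',
               api_key='sk-...', model_id='gpt-3.5-turbo')
# `kg` is a list of relation dicts: subject, predicate, object,
# chapter_index, count, and per-node edge counts.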
src/kg/openai_api.py
ADDED
@@ -0,0 +1,101 @@
# src.kg.openai_api.py
import json
import logging
import os
from pathlib import Path

import openai
from dotenv import load_dotenv
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_delay, wait_random_exponential,
                      stop_after_attempt)
from tiktoken import Encoding, encoding_for_model

logger = logging.getLogger(__name__)

load_dotenv()

# This value is set by OpenAI for the selected model and cannot be changed.
MAX_MODEL_TOKEN_COUNT = 4096
# This value can be changed.
MAX_RESPONSE_TOKEN_COUNT = 512
RESPONSES_DIRECTORY_PATH = Path('../openai-api-responses-new')


def get_openai_model_encoding(model_id):
    """Get the encoding (tokenizer) for the OpenAI model."""
    return encoding_for_model(model_id)


def get_max_chapter_segment_token_count(prompt: str, model_id: str) -> int:
    """
    Calculate the maximum number of tokens that a chapter segment may
    contain given the prompt.
    """
    encoding = get_openai_model_encoding(model_id)
    # `encode_ordinary()` ignores special tokens and is slightly faster than
    # `encode()`.
    prompt_token_count = len(encoding.encode_ordinary(prompt))
    # Subtract 8 for tokens added by OpenAI in the prompt and response (refer
    # to https://platform.openai.com/docs/guides/chat/managing-tokens for
    # details).
    # Subtract 1 for the newline added below to the end of the prompt.
    # This calculation does not have to be exact.
    max_chapter_segment_token_count = (MAX_MODEL_TOKEN_COUNT
                                       - MAX_RESPONSE_TOKEN_COUNT
                                       - prompt_token_count - 8 - 1)
    return max_chapter_segment_token_count


@retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                      ServiceUnavailableError,
                                      APIConnectionError,
                                      InvalidRequestError)),
       wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
       before_sleep=before_sleep_log(logger, logging.WARNING))
def save_openai_api_response(prompt_messages):
    """
    Use a prompt to make a request to the OpenAI API and return the response
    data.
    """
    openai.api_key = prompt_messages[0]['api_key']  # Set the API key for OpenAI.
    model_id = prompt_messages[0]['model_id']  # Get the model ID from the prompt messages.
    prompt_messages[0].pop('api_key')  # Remove the API key from the prompt messages.
    prompt_messages[0].pop('model_id')  # Remove the model ID from the prompt messages.

    try:
        logger.info('Calling OpenAI API...')
        response = openai.ChatCompletion.create(
            model=model_id, messages=prompt_messages, temperature=0
        )
        finish_reason = response.choices[0].finish_reason
        if finish_reason != 'stop':
            logger.error(f'`finish_reason` is `{finish_reason}`.')

        save_data = {
            'model': response.model,
            'usage': response.usage,
            'finish_reason': finish_reason,
            'prompt_messages': prompt_messages,
            'response': response.choices[0].message.content
        }
    except InvalidRequestError:
        logger.error('InvalidRequestError encountered. Returning an empty '
                     'response.')
        save_data = {
            'model': None,
            'usage': None,
            'finish_reason': 'invalid_request',
            'prompt_messages': prompt_messages,
            'response': ' '
        }

    return save_data


def load_response_text(save_path):
    """
    Load the response text from a JSON file containing response data from
    the OpenAI API.
    """
    with open(save_path, 'r') as save_file:
        save_data = json.load(save_file)
    return save_data['response']
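As a worked example of the budget above: with the 4096-token model limit and the 512-token response reserve, a prompt that encodes to 1000 tokens leaves 4096 - 512 - 1000 - 8 - 1 = 2575 tokens per chapter segment. A minimal sketch (the model name is illustrative):

from pathlib import Path
from src.kg.openai_api import get_max_chapter_segment_token_count

prompt = Path('templates/story-prompt.txt').read_text()
budget = get_max_chapter_segment_token_count(prompt, 'gpt-3.5-turbo')
# budget == 4096 - 512 - <prompt token count> - 9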
src/kg/preprocess.py
ADDED
@@ -0,0 +1,20 @@
def split_scenes(script):
    # TBD
    return script


def preprocess(scene_list, idx):
    script_dict = {}
    script_dict['id'] = idx
    script_dict['chapters'] = []

    elem_dict = {}
    elem_dict['index'] = 1
    elem_dict['text'] = scene_list
    elem_dict['summaries'] = ""

    script_dict['chapters'].append(elem_dict)

    return script_dict
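`preprocess` wraps a scene list in the single-chapter script structure the rest of the pipeline expects, e.g.:

from src.kg.preprocess import preprocess

preprocess(['s#1 ...', 's#2 ...'], idx=7)
# -> {'id': 7,
#     'chapters': [{'index': 1, 'text': ['s#1 ...', 's#2 ...'], 'summaries': ''}]}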
src/kg/save_triples.py
ADDED
@@ -0,0 +1,215 @@
# src.kg.save_triples.py
from pathlib import Path
import json
import argparse
import os

from pysbd import Segmenter
from tiktoken import Encoding

from .knowledge_graph import PROMPT_FILE_PATH
from .openai_api import (RESPONSES_DIRECTORY_PATH,
                         get_max_chapter_segment_token_count,
                         get_openai_model_encoding, save_openai_api_response)
from .utils import (execute_function_in_parallel, set_up_logging,
                    strip_and_remove_empty_strings)

logger = set_up_logging('openai-api-scripts.log')


def get_paragraphs(text):
    """Split a text into paragraphs."""
    paragraphs = strip_and_remove_empty_strings(text.split('\n\n'))
    # Convert all whitespace into single spaces.
    paragraphs = [' '.join(paragraph.split()) for paragraph in paragraphs]
    return paragraphs


def combine_text_subunits_into_segments(subunits, join_string,
                                        encoding: Encoding,
                                        max_token_count):
    """
    Combine subunits of text into segments that do not exceed a maximum
    number of tokens.
    """
    # `encode_ordinary_batch()` ignores special tokens and is slightly
    # faster than `encode_batch()`.
    subunit_token_counts = [len(tokens) for tokens
                            in encoding.encode_ordinary_batch(subunits)]
    join_string_token_count = len(encoding.encode_ordinary(join_string))
    total_token_count = (sum(subunit_token_counts) + join_string_token_count
                         * (len(subunits) - 1))
    if total_token_count <= max_token_count:
        return [join_string.join(subunits)]
    # Calculate the approximate number of segments and the approximate
    # number of tokens per segment, in order to keep the segment lengths
    # roughly equal.
    approximate_segment_count = total_token_count // max_token_count + 1
    approximate_segment_token_count = round(total_token_count
                                            / approximate_segment_count)
    segments = []
    current_segment_subunits = []
    current_segment_token_count = 0
    for i, (subunit, subunit_token_count) in enumerate(
            zip(subunits, subunit_token_counts)):
        # The token count if the current subunit is added to the current
        # segment.
        extended_segment_token_count = (current_segment_token_count
                                        + join_string_token_count
                                        + subunit_token_count)
        # Add the current subunit to the current segment if it results in a
        # token count that is closer to the approximate segment token count
        # than the current segment token count.
        if (extended_segment_token_count <= max_token_count
                and abs(extended_segment_token_count
                        - approximate_segment_token_count)
                <= abs(current_segment_token_count
                       - approximate_segment_token_count)):
            current_segment_subunits.append(subunit)
            current_segment_token_count = extended_segment_token_count
        else:
            segment = join_string.join(current_segment_subunits)
            segments.append(segment)
            # If it is possible to join the remaining subunits into a single
            # segment, do so. Additionally, add the current subunit as a
            # segment if it is the last subunit.
            if (sum(subunit_token_counts[i:]) + join_string_token_count
                    * (len(subunits) - i - 1) <= max_token_count
                    or i == len(subunits) - 1):
                segment = join_string.join(subunits[i:])
                segments.append(segment)
                break
            current_segment_subunits = [subunit]
            current_segment_token_count = subunit_token_count
    return segments


def split_long_sentences(sentences, encoding: Encoding, max_token_count):
    """
    Given a list of sentences, split sentences that exceed a maximum number
    of tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(sentences)]
    split_sentences = []
    for sentence, token_count in zip(sentences, token_counts):
        if token_count > max_token_count:
            words = sentence.split()
            segments = combine_text_subunits_into_segments(
                words, ' ', encoding, max_token_count)
            split_sentences.extend(segments)
        else:
            split_sentences.append(sentence)
    return split_sentences


def split_long_paragraphs(paragraphs, encoding: Encoding, max_token_count):
    """
    Given a list of paragraphs, split paragraphs that exceed a maximum
    number of tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(paragraphs)]
    split_paragraphs = []
    for paragraph, token_count in zip(paragraphs, token_counts):
        if token_count > max_token_count:
            sentences = Segmenter().segment(paragraph)
            sentences = split_long_sentences(sentences, encoding,
                                             max_token_count)
            segments = combine_text_subunits_into_segments(
                sentences, ' ', encoding, max_token_count)
            split_paragraphs.extend(segments)
        else:
            split_paragraphs.append(paragraph)
    return split_paragraphs


def get_chapter_segments(chapter_text, encoding: Encoding, max_token_count):
    """
    Split a chapter text into segments that do not exceed a maximum number
    of tokens.
    """
    paragraphs = get_paragraphs(chapter_text)
    paragraphs = split_long_paragraphs(paragraphs, encoding, max_token_count)
    chapter_segments = combine_text_subunits_into_segments(
        paragraphs, '\n', encoding, max_token_count)
    return chapter_segments


def get_response_save_path(idx, save_path, project_gutenberg_id,
                           chapter_index=None, chapter_segment_index=None,
                           chapter_segment_count=None):
    """
    Get the path to the JSON file(s) containing response data from the
    OpenAI API.
    """
    save_path = Path(save_path)
    os.makedirs(save_path, exist_ok=True)

    if chapter_index is not None:
        save_path /= str(chapter_index)
        if chapter_segment_index is not None:
            save_path /= (f'{chapter_segment_index + 1}-of-'
                          f'{chapter_segment_count}.json')
    return save_path


def save_openai_api_responses_for_script(script, prompt, encoding,
                                         max_chapter_segment_token_count,
                                         idx, api_key, model_id):
    """
    Call the OpenAI API for each chapter segment in a script and save the
    responses to a list.
    """
    project_gutenberg_id = script['id']
    chapter_count = len(script['chapters'])
    logger.info(f'Starting to call OpenAI API and process responses for '
                f'script {project_gutenberg_id} ({chapter_count} chapters).')

    prompt_message_lists = []
    response_list = []

    for chapter in script['chapters']:
        chapter_index = chapter['index']
        chapter_segments = chapter['text']
        chapter_segment_count = len(chapter_segments)

        for chapter_segment_index, chapter_segment in enumerate(chapter_segments):
            prompt_with_story = prompt.replace('{STORY}', chapter_segment)
            prompt_message_lists.append([{
                'role': 'user',
                'content': prompt_with_story,
                'api_key': api_key,
                'model_id': model_id
            }])

    responses = execute_function_in_parallel(save_openai_api_response,
                                             prompt_message_lists)

    for response in responses:
        response_list.append(response)

    logger.info(f'Finished processing responses for script '
                f'{project_gutenberg_id}.')
    return response_list


def save_triples_for_scripts(input_data, idx, api_key, model_id):
    """
    Call the OpenAI API to generate knowledge graph nodes and edges, and
    store the responses in a list.
    """
    # 1) Load the data.
    script = input_data

    # 2) Call the OpenAI API.
    prompt = PROMPT_FILE_PATH.read_text()  # Load the prompt.
    max_chapter_segment_token_count = get_max_chapter_segment_token_count(
        prompt, model_id)
    encoding = get_openai_model_encoding(model_id)
    responses = save_openai_api_responses_for_script(
        script, prompt, encoding, max_chapter_segment_token_count, idx,
        api_key, model_id
    )

    return responses
src/kg/utils.py
ADDED
@@ -0,0 +1,57 @@
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Callable

LOGS_DIRECTORY_PATH = Path('../logs')
# Number of processes to use when executing functions in parallel.
MAX_PROCESS_COUNT = 20


def set_up_logging(log_name):
    """Set up a logger that logs to both the console and a file."""
    log_path = LOGS_DIRECTORY_PATH / log_name
    log_path.parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # Remove the default handler.
    logger.handlers.clear()
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S')
    stream_handler.setFormatter(stream_formatter)
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)
    return logger


def strip_and_remove_empty_strings(strings):
    """Strip a list of strings and remove empty strings."""
    strings = [string.strip() for string in strings]
    strings = [string for string in strings if string]
    return strings


def execute_function_in_parallel(function: Callable, *argument_lists,
                                 logger=None):
    """Execute a function in parallel using multiple processes."""
    with ProcessPoolExecutor(max_workers=MAX_PROCESS_COUNT) as executor:
        futures = [executor.submit(function, *arguments)
                   for arguments in zip(*argument_lists)]
        results = []
        for future in as_completed(futures):
            try:
                result = future.result()
                results.append(result)
            except Exception as exception:
                if logger:
                    logger.exception('Exception')
                raise exception
    return results
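A minimal usage sketch with a toy function; it must live in an importable module because `ProcessPoolExecutor` pickles the callable, and `as_completed` yields results in completion order, not submission order:

from src.kg.utils import execute_function_in_parallel

def square(x):
    return x * x

results = execute_function_in_parallel(square, [1, 2, 3, 4])
# e.g. [1, 4, 9, 16], possibly permuted by completion order.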
src/summary/prompt.py
ADDED
@@ -0,0 +1,33 @@
import os
from typing import List, Optional


def build_summarizer_prompt(
        prompt_template: str,
        input_text_list: List[str],
        chat_mode: Optional[str] = None) -> str:
    """Build a summarization prompt by filling a template with input texts.

    `prompt_template` may be either a path to a template file or the
    template string itself. `chat_mode` may be 'hf-chat', 'kullm', or None.
    """
    if os.path.isfile(prompt_template):
        with open(prompt_template, 'r') as f:
            prompt_template = f.read()

    # TODO: proper error handling is needed here.
    assert isinstance(prompt_template, str)

    prompt = prompt_template.format(*input_text_list)

    if chat_mode == "hf-chat":
        prompt = _get_hf_chat_template().format(prompt)
    elif chat_mode == "kullm":
        prompt = _get_kullm_template().format(prompt)

    return prompt
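For instance, the two-slot fact-checking template can be filled like this, assuming the code runs from the repository root (the scene and statement strings are placeholders):

from src.summary.prompt import build_summarizer_prompt

prompt = build_summarizer_prompt(
    'templates/fact_score.txt',
    input_text_list=['<scene text>', '<statement to verify>'])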
src/summary/summarizer.py
ADDED
@@ -0,0 +1,65 @@
import os
import math
from typing import Union, Optional
import torch
import logging

# from vllm import LLM, SamplingParams
# from vllm.lora.request import LoRARequest

from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoConfig,
                          set_seed, BitsAndBytesConfig)
import openai
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_delay, wait_random_exponential,
                      stop_after_attempt)

logger = logging.getLogger(__name__)


class Summarizer:
    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed=42,
                 context_size: int = int(1024 * 26),
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size=1
                 ) -> None:
        self.inference_mode = inference_mode
        self.model = None
        self.tokenizer = None
        self.seed = seed
        openai.api_key = api_key
        self.model = model_id

    def get_generation_config(self,
                              repetition_penalty: float = 1.2,
                              do_sample: bool = True,
                              temperature: float = 0.1,
                              top_p: float = 0.9,
                              max_tokens: int = 1024):
        # Collect the sampling parameters into a generation config dict.
        generation_config = {
            'repetition_penalty': repetition_penalty,
            'do_sample': do_sample,
            'temperature': temperature,
            'top_p': top_p,
            'max_tokens': max_tokens
        }
        return generation_config

    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError,
                                          APIConnectionError,
                                          InvalidRequestError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(
                model=self.model, messages=prompt_messages, temperature=0.1)
            response = response.choices[0].message.content
        except InvalidRequestError:
            response = ''

        return response
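A minimal usage sketch; the inference mode, API key, and model name are placeholders, not values from this commit:

from src.summary.summarizer import Summarizer

summarizer = Summarizer(inference_mode='gpt', model_id='gpt-3.5-turbo',
                        api_key='sk-...')
summary = summarizer.inference_with_gpt('Summarize this scene: ...')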
src/summary/utils.py
ADDED
@@ -0,0 +1,118 @@
+
#src.summary.utils.py
|
2 |
+
import re
|
3 |
+
from typing import List
|
4 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
5 |
+
from typing import Callable
|
6 |
+
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
from tiktoken import Encoding, encoding_for_model
|
9 |
+
|
10 |
+
SCENE_INDICATORS = ['씬/','씬','SS##','S#','s#','S','s','#\d+.','\d+.']
|
11 |
+
|
12 |
+
def delete_special(pre_text, character_list):
|
13 |
+
for c in character_list:
|
14 |
+
pre_text = pre_text.replace(c, "")
|
15 |
+
return pre_text
|
16 |
+
|
17 |
+
def preprocess_script(script:str) -> str:
|
18 |
+
|
19 |
+
lines = script.split("\n")
|
20 |
+
|
21 |
+
new_text = ""
|
22 |
+
for line in lines:
|
23 |
+
line = delete_special(line, ["\n", "\t", "\xa0",'၀','ᝰ','ศ','ನ','tุ','\x00Ā\x00\x00\x00'])
|
24 |
+
cleaned = re.sub('[^가-힣a-zA-Z0-9\s,.!?/#]',' ', line).strip()
|
25 |
+
cleaned = delete_special(cleaned, [" "]).strip()
|
26 |
+
cleaned = cleaned.replace("<|start|>", "").replace("<|end|>","")
|
27 |
+
if len(cleaned)>0:
|
28 |
+
new_text += f"{line}\n"
|
29 |
+
new_text = new_text.strip()
|
30 |
+
|
31 |
+
return new_text
|
32 |
+
|
33 |
+
|
34 |
+
def preprocess_scripts(scripts:List[str]) -> List[str]:
|
35 |
+
scripts = [preprocess_script(s) for s in scripts]
|
36 |
+
|
37 |
+
return scripts
|
38 |
+
|
39 |
+
def break_down2scenes(text: str):
|
40 |
+
# Split the text using "s#" as the delimiter
|
41 |
+
scenes = re.split(r'(s#\d+)', text)
|
42 |
+
|
43 |
+
# Remove empty elements from the split results
|
44 |
+
scenes = [scene for scene in scenes if scene.strip()]
|
45 |
+
|
46 |
+
scenes_list = []
|
47 |
+
current_scene_number = None
|
48 |
+
|
49 |
+
for i in range(0, len(scenes), 2): # Process the "s#" marker and corresponding text as pairs
|
50 |
+
scene_marker = scenes[i].strip()
|
51 |
+
scene_number = int(scene_marker.split('#')[1]) # Extract only the number
|
52 |
+
scene_text = scenes[i+1].strip() if i+1 < len(scenes) else ""
|
53 |
+
|
54 |
+
# Verify that the scene numbers are in the correct order
|
55 |
+
if current_scene_number is not None:
|
56 |
+
expected_scene_number = current_scene_number + 1
|
57 |
+
if scene_number != expected_scene_number:
|
58 |
+
raise ValueError(f"Unexpected scene number: {scene_number}, expected {expected_scene_number}")
|
59 |
+
|
60 |
+
# Save the scene number and text together
|
61 |
+
scenes_list.append({
|
62 |
+
'detected_scene_number': scene_number,
|
63 |
+
'text': f"{scene_marker}\n{scene_text}".strip()
|
64 |
+
})
|
65 |
+
return scenes_list
|
66 |
+
|
67 |
+
|
68 |
+
def chunk_script_gpt(script:str,
|
69 |
+
model_id:str,
|
70 |
+
chunk_size:int=-1) -> List[str]:
|
71 |
+
if chunk_size == -1:
|
72 |
+
chunks = [script]
|
73 |
+
print("Single Inference Mode")
|
74 |
+
return chunks
|
75 |
+
|
76 |
+
encoding = encoding_for_model(model_id)
|
77 |
+
|
78 |
+
scenes = break_down2scenes(script)
|
79 |
+
|
80 |
+
len_scenes = len(scenes)
|
81 |
+
|
82 |
+
chunks = []
|
83 |
+
if len_scenes > 10:
|
84 |
+
print(f"Num of detected scenes : {len_scenes}")
|
85 |
+
|
86 |
+
chunk = ""
|
87 |
+
token_len_chunk = 0
|
88 |
+
for i, scene_data in enumerate(scenes):
|
89 |
+
scene = scene_data["text"].strip()
|
90 |
+
token_len_scene = len(encoding.encode_ordinary(scene))
|
91 |
+
if token_len_chunk + token_len_scene > chunk_size:
|
92 |
+
if token_len_chunk == 0:
|
93 |
+
chunk += scene
|
94 |
+
token_len_chunk += token_len_scene
|
95 |
+
else:
|
96 |
+
chunks.append(chunk)
|
97 |
+
chunk = scene
|
98 |
+
token_len_chunk = token_len_scene
|
99 |
+
else:
|
100 |
+
chunk += scene
|
101 |
+
token_len_chunk += token_len_scene
|
102 |
+
|
103 |
+
if i == len_scenes-1:
|
104 |
+
chunks.append(chunk)
|
105 |
+
else:
|
106 |
+
print(f"No Detected Scenes ({len_scenes})")
|
107 |
+
tokenized_script = encoding.encode_ordinary(script)
|
108 |
+
token_len_script = len(tokenized_script)
|
109 |
+
for start in range(0,token_len_script,chunk_size):
|
110 |
+
if start + chunk_size >= token_len_script:
|
111 |
+
end = token_len_script+1
|
112 |
+
else:
|
113 |
+
end = start+chunk_size
|
114 |
+
|
115 |
+
chunk = encoding.decode(tokenized_script[start:end])
|
116 |
+
chunks.append(chunk)
|
117 |
+
print(f"Num of chunks : {len(chunks)}")
|
118 |
+
return chunks
|
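A minimal sketch of the chunking entry point; the script and model name are illustrative:

from src.summary.utils import chunk_script_gpt

script = '\n'.join(f's#{i}\nScene {i} dialogue...' for i in range(1, 13))
chunks = chunk_script_gpt(script, model_id='gpt-3.5-turbo', chunk_size=100)
# With more than 10 detected scenes, chunks are packed scene by scene up to
# roughly 100 tokens each; otherwise the raw token stream is sliced.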
templates/atomic_fact.txt
ADDED
@@ -0,0 +1,5 @@
I will give you a summary of a chunk of a movie script.
Your task is to provide me with a list of atomic facts expressed in the given summary.
Each atomic fact should be described in a name-only third-person format.
Please separate each atomic fact with a `\n`.
Summary: {}
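For illustration, the template is filled positionally; the summary below and the completion sketched in the comments are hypothetical:

from pathlib import Path

template = Path('templates/atomic_fact.txt').read_text()
prompt = template.format('Jo gives Meg the letter at the station.')
# The model is expected to reply with one atomic fact per line, e.g.:
#   Jo gives Meg the letter.
#   Jo is at the station.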
templates/external_summary.txt
ADDED
@@ -0,0 +1,8 @@
This is a part of a script from a movie. Read the following content carefully, then answer my question:
{}
The script has ended now.

Please summarize the content:
- Provide a detailed summary of the key characters' actions, emotions, and situations as reflected in the dialogue or context.
- Clearly state the outcome of the events.
- The summary should be between 2 and 5 sentences long.
templates/fact_score.txt
ADDED
@@ -0,0 +1,8 @@
Consider the given statement and the related scene.
Indicate whether the statement is supported by the scene.
Negation of a false statement should be considered supported.
If the statement is true, output 1.
If the statement is false, output the reason why it is false.
Scene: {}
Statement: {}
Output:
templates/fact_score_kg.txt
ADDED
@@ -0,0 +1,9 @@
Consider the given statement, the related scene, and the relationship subgraph.
Indicate whether the statement is supported by the scene and the relationship subgraph.
Negation of a false statement should be considered supported.
If the statement is true, output 1.
If the statement is false, output the reason why it is false.
Scene: {}
Relationship Subgraph: {}
Statement: {}
Output:
templates/self_correction.txt
ADDED
@@ -0,0 +1,8 @@
Below is a part of the script from the titled movie.
- Script: {}
Based on the 'Statement to Revise' and 'Reason for Revision', create a 'Revised Summary' of the 'Summary of the Script'.
Keep the revised summary concise and similar in length to the original summary.
Do not directly copy any part of the 'Script'.
If the 'Summary of the Script' is accurate, generate the original summary as is.
- Summary of the Script: {}
- Revised Summary:
templates/story-prompt.txt
ADDED
@@ -0,0 +1,43 @@
Read part of a story, then identify named entities and generate knowledge graph edges.

<Important Instructions>
- Do not use or refer to the content, names, or entities from the example provided below in your response.
- Only apply the process of identifying named entities and generating knowledge graph edges to the new story excerpt provided.

<Example> (for reference only, do not use in your response):
[Begin story excerpt]
"Christmas won't be Christmas without any presents," grumbled Jo. "It's so dreadful to be poor!" sighed Meg, looking out the window at the snow-covered streets of Concord. "I don't think it's fair for some girls to have plenty of pretty things, and other girls nothing at all," added little Amy, with an injured sniff. "We've got Father and Mother, and each other," said Beth contentedly from her corner. The four young faces brightened at the cheerful words, but darkened again as Jo said sadly, "We haven't got Father, and shall not have him for a long time." She didn't say "perhaps never," but each silently added it, thinking of Father far away, where the fighting was.
As young readers like to know 'how people look', we will take this moment to give them a little sketch of the four sisters. Margaret March, the eldest of the four, was sixteen, and very pretty, with large eyes, plenty of soft brown hair, a sweet mouth, and white hands. Fifteen-year-old Jo March was very tall, thin, and brown, and never seemed to know what to do with her long limbs. Elizabeth, or Beth, as everyone called her, was a rosy, smooth-haired, bright-eyed girl of thirteen, with a shy manner, a timid voice, and a peaceful expression which was seldom disturbed. Amy, the youngest, was a regular snow maiden, with blue eyes, and yellow hair curling on her shoulders.
The clock struck six and, having swept up the hearth, Beth put a pair of slippers down to warm. Somehow the sight of the old shoes had a good effect upon the girls, for Mother was coming, and everyone brightened to welcome her. Jo sat up to hold the slippers nearer to the blaze. "They are quite worn out. Marmee must have a new pair." "I thought I'd get her some with my dollar," said Beth. "No, I shall!" cried Amy. "I'll tell you what we'll do," said Beth, "let's each get her something for Christmas, and not get anything for ourselves." "Let Marmee think we are getting things for ourselves, and then surprise her. We must go shopping tomorrow afternoon," said Jo, marching up and down.
"Glad to find you so merry, my girls," said a cheery voice at the door, and the girls turned to welcome a tall, motherly lady. She was not elegantly dressed, but the girls thought the gray cloak and unfashionable bonnet covered the most splendid mother in the world. As they gathered about the table, Mrs. March said, with a particularly happy face, "I've got a treat for you after supper." A quick, bright smile went round like a streak of sunshine. Beth clapped her hands, and Jo tossed up her napkin, crying, "A letter! A letter! Three cheers for Father!" "Yes, a nice long letter. He is well, and he sends all sorts of loving wishes for Christmas, and an especial message to you girls," said Mrs. March, patting her pocket as if she had got a treasure there. "I think it was so splendid in Father to go as chaplain when he was too old to be drafted, and not strong enough for a soldier," said Meg warmly, proud of her father's work with the Union Army.
[End story excerpt]

Named entities (include all aliases and name variations):
Jo / Jo March
Meg / Margaret / Margaret March
Amy
Beth / Elizabeth
March sisters
Mrs. March / Marmee / Mother
Father
Concord
Union Army

Knowledge graph edges (select up to 20 most important, `subject(s); predicate; [object(s)]` format, named entities only, predicate(relation): two words or fewer):
1. Jo, Meg, Amy, Beth; in; March sisters
2. March sisters; daughters of; Mrs. March, Father
3. Mrs. March; mother of; March sisters
4. Father; father of; March sisters
5. March sisters, Mrs. March; living in; Concord
6. Father; fighting in war
7. Father; chaplain in; Union Army
8. Meg; sixteen years
9. Jo; fifteen years
10. Beth; thirteen years
11. Beth; shy
12. Amy; youngest among; March sisters

<Request>
[Begin story excerpt]
{STORY}
[End story excerpt]