|
|
|
import pickle |
|
from collections import defaultdict, Counter |
|
from contextlib import redirect_stdout |
|
from pathlib import Path |
|
import json |
|
import argparse |
|
import os |
|
import openai |
|
import time |
|
import numpy as np |
|
|
|
import networkx as nx |
|
from pyvis.network import Network |
|
from tqdm import tqdm |
|
from contextlib import redirect_stdout |
|
|
|
|
|
|
|
from .knowledge_graph import generate_knowledge_graph |
|
from .openai_api import load_response_text |
|
from .save_triples import get_response_save_path |
|
from .utils import set_up_logging |
|
|
|
# Module-level logger obtained from the project's `set_up_logging` helper.
# NOTE(review): presumably the argument is the log file name -- confirm
# against `utils.set_up_logging`.
logger = set_up_logging('generate-knowledge-graphs-books.log')
# Relative path of a knowledge-graph directory. No function visible in this
# section of the file references this constant.
KNOWLEDGE_GRAPHS_DIRECTORY_PATH = Path('../knowledge-graphs_new')


# NOTE: the `gpt_inference` helper below is disabled. It is wrapped in a bare
# string literal, so it is never defined or executed; the text is kept only
# for reference. As written it would retry the chat-completion call up to
# `retries` times, sleeping `delay` seconds after each `openai.error.APIError`,
# and would fall through (implicitly returning None) once retries run out.
"""def gpt_inference(system_instruction, prompt, retries=10, delay=5):
    # api
    messages = [{"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt}]

    for attempt in range(retries):
        try:
            response = openai.ChatCompletion.create(
                model='gpt-4o-mini-2024-07-18',
                messages=messages,
                temperature=0.0,
                max_tokens=128,
                top_p=0.5,
                frequency_penalty=0,
                presence_penalty=0
            )
            result = response['choices'][0]['message']['content']
            return result
        except openai.error.APIError as e:

            time.sleep(delay)
            continue"""
|
|
|
|
|
def generate_knowledge_graph_from_saved_responses(book, idx, save_path):
    """
    Build a knowledge graph for a book from response files saved on disk.

    This variant used to share its name with
    `generate_knowledge_graph_for_scripts` and was therefore shadowed by the
    later definition (and unreachable). It now has its own name so both
    entry points are usable.

    Args:
        book: Mapping with an 'id' entry and a 'chapters' iterable whose
            items each provide an 'index' entry.
        idx: Forwarded unchanged to `get_response_save_path`.
        save_path: Forwarded unchanged to `get_response_save_path`.

    Returns:
        Whatever `generate_knowledge_graph` returns for the collected
        per-chapter response texts.
    """
    project_gutenberg_id = book['id']
    response_texts = defaultdict(list)
    for chapter in book['chapters']:
        chapter_index = chapter['index']
        chapter_responses_directory = get_response_save_path(
            idx, save_path, project_gutenberg_id, chapter_index)
        # Every '*.json' file in the chapter's directory contributes one
        # response text to that chapter.
        for response_path in chapter_responses_directory.glob('*.json'):
            response_texts[chapter_index].append(
                load_response_text(response_path))
    return generate_knowledge_graph(response_texts, project_gutenberg_id)


def generate_knowledge_graph_for_scripts(book, idx, response_list):
    """
    Build a knowledge graph for a book from in-memory API responses.

    Args:
        book: Mapping with an 'id' entry and a 'chapters' iterable whose
            items each provide an 'index' entry.
        idx: Unused; kept so existing callers keep working.
        response_list: Iterable of mappings, each with a 'response' entry.

    Returns:
        Whatever `generate_knowledge_graph` returns for the collected
        per-chapter response texts.
    """
    project_gutenberg_id = book['id']
    response_texts = defaultdict(list)
    for chapter in book['chapters']:
        chapter_index = chapter['index']
        # NOTE(review): every chapter receives the *entire* response list;
        # responses are not filtered per chapter. Confirm this is intended
        # (it is harmless when a script has a single chapter).
        for response in response_list:
            response_texts[chapter_index].append(response['response'])
    return generate_knowledge_graph(response_texts, project_gutenberg_id)
|
|
|
|
|
def save_knowledge_graph(knowledge_graph,
                         project_gutenberg_id, save_path):
    """
    Save a knowledge graph to a `pickle` file.

    Args:
        knowledge_graph: Any picklable object.
        project_gutenberg_id: Unused; kept for interface compatibility.
        save_path: Directory (``str`` or ``Path``) that will contain
            'kg.pkl'. Missing directories are created.
    """
    # Coerce so plain string paths work as well as Path objects.
    pickle_path = Path(save_path) / 'kg.pkl'
    pickle_path.parent.mkdir(parents=True, exist_ok=True)
    with open(pickle_path, 'wb') as knowledge_graph_file:
        pickle.dump(knowledge_graph, knowledge_graph_file)
|
|
|
|
|
def load_knowledge_graph(project_gutenberg_id, save_path):
    """
    Load a knowledge graph from a `pickle` file.

    Args:
        project_gutenberg_id: Unused; kept for interface compatibility.
        save_path: Directory (``str`` or ``Path``) containing 'kg.pkl'.

    Returns:
        The object unpickled from ``save_path / 'kg.pkl'``.

    Raises:
        FileNotFoundError: If 'kg.pkl' does not exist in `save_path`.
    """
    # Coerce so plain string paths work as well as Path objects.
    pickle_path = Path(save_path) / 'kg.pkl'
    # SECURITY: unpickling can execute arbitrary code; only load files that
    # this project wrote itself, never untrusted input.
    with open(pickle_path, 'rb') as knowledge_graph_file:
        return pickle.load(knowledge_graph_file)
|
|
|
|
|
def display_knowledge_graph(knowledge_graph, save_path):
    """
    Display a knowledge graph using pyvis.

    Each node becomes a box labelled with its names (one per line). All
    predicates between the same ordered node pair are collapsed onto a single
    edge whose tooltip lists them one per line.

    Args:
        knowledge_graph: Graph whose nodes expose a `names` iterable of
            strings and whose edges carry a 'predicate' attribute.
        save_path: Output HTML file (``str`` or ``Path``). Its parent
            directory is created if missing.
    """
    # Coerce so plain string paths work as well as Path objects.
    save_path = Path(save_path)

    pyvis_graph = nx.MultiDiGraph()
    for node in knowledge_graph.nodes:
        pyvis_graph.add_node(str(node), label='\n'.join(node.names),
                             shape='box')
    for source, target, data in knowledge_graph.edges(data=True):
        subject = str(source)
        object_ = str(target)
        predicate = data['predicate']
        if pyvis_graph.has_edge(subject, object_):
            # Reuse the first edge between the pair and extend its tooltip.
            pyvis_graph[subject][object_][0]['title'] += f'\n{predicate}'
        else:
            pyvis_graph.add_edge(subject, object_, title=f'{predicate}')

    network = Network(height='99vh', directed=True, bgcolor='#262626',
                      cdn_resources='remote')
    network.set_options('''
        const options = {
            "interaction": {
                "tooltipDelay": 0
            },
            "physics": {
                "forceAtlas2Based": {
                    "gravitationalConstant": -50,
                    "centralGravity": 0.01,
                    "springLength": 100,
                    "springConstant": 0.08,
                    "damping": 0.4,
                    "avoidOverlap": 0
                },
                "solver": "forceAtlas2Based"
            }
        }''')
    network.from_nx(pyvis_graph)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Silence anything pyvis prints while writing the file.
    with redirect_stdout(None):
        network.show(str(save_path), notebook=False)
    logger.info(f'Saved pyvis knowledge graph to {save_path}.')
|
|
|
def fuse_subject(subjects):
    """
    Collapse a '/'-separated list of aliases into its shortest alias.

    The aliases are stripped of surrounding whitespace and empty aliases
    (e.g. from a trailing '/') are ignored. If the shortest alias is a
    substring of every other alias, that alias is returned; otherwise the
    input is returned unchanged. On ties the first shortest alias wins.

    Args:
        subjects: String such as 'John Smith/John'.

    Returns:
        The shortest alias when it is contained in all aliases, else
        `subjects` as given. A string without '/' is returned as-is
        (not stripped).
    """
    raw_parts = subjects.split('/')
    if len(raw_parts) == 1:
        return raw_parts[0]

    stripped_parts = (part.strip() for part in raw_parts)
    candidates = [part for part in stripped_parts if part]
    if not candidates:
        # Nothing but separators/whitespace: there is no alias to fuse to.
        return subjects

    # Compare lengths of the *stripped* aliases, since those are the strings
    # the containment check below operates on.
    shortest = min(candidates, key=len)
    if all(shortest in candidate for candidate in candidates):
        return shortest
    return subjects
|
|
|
def init_kg(script, idx, response_list):
    """
    Create the initial knowledge graph for a script.

    Thin wrapper that forwards all three arguments, unchanged, to
    `generate_knowledge_graph_for_scripts` and returns its result.
    """
    return generate_knowledge_graph_for_scripts(script, idx, response_list)
|
|
|
def _show_refined_graph(nodes, kept_relations):
    """Render the kept relations between `nodes` to 'refined_kg.html'."""
    pyvis_graph = nx.MultiDiGraph()
    for node in nodes:
        pyvis_graph.add_node(node, label=node, shape='box')
    for subject, object_, predicate, _chapter_index, _count in kept_relations:
        if pyvis_graph.has_edge(subject, object_):
            # Collapse parallel relations onto the first edge's tooltip.
            pyvis_graph[subject][object_][0]['title'] += f', {predicate}'
        else:
            pyvis_graph.add_edge(subject, object_, title=f'{predicate}')

    network = Network(height='99vh', directed=True, bgcolor='#262626',
                      cdn_resources='remote')
    network.from_nx(pyvis_graph)
    # Silence anything pyvis prints while writing the file.
    with redirect_stdout(None):
        network.show('refined_kg.html', notebook=False)


def refine_kg(knowledge_graph, idx, topk):
    """
    Reduce a knowledge graph to relations among its `topk` busiest nodes.

    A node's count is the number of edge endpoints it occupies. Only edges
    whose subject and object are both among the `topk` most frequent nodes
    are considered; of those, relations whose predicate contains 'output' or
    whose 'count' is below 2 are dropped. The surviving relations are
    rendered to 'refined_kg.html' (in the current working directory) and
    returned.

    Args:
        knowledge_graph: Graph whose edges carry 'predicate',
            'chapter_index' and 'count' attributes.
        idx: Unused; kept for interface compatibility.
        topk: Number of most-connected nodes to keep.

    Returns:
        List of dicts with keys 'subject', 'predicate', 'object',
        'chapter_index', 'count', 'subject_node_count' and
        'object_node_count'. Subject and object names are passed through
        `fuse_subject`; the node counts refer to the original (unfused)
        nodes.
    """
    edge_count = Counter()
    for source, target, _data in knowledge_graph.edges(data=True):
        edge_count[str(source)] += 1
        edge_count[str(target)] += 1

    # Keep the ranked list for node insertion order and a set for O(1)
    # membership tests inside the edge loop.
    top_k_nodes = [node for node, _ in edge_count.most_common(topk)]
    top_k_lookup = set(top_k_nodes)

    # Keyed by (subject, object) tuples so node names may contain any
    # character, including tabs.
    relations_by_pair = defaultdict(list)
    for source, target, data in knowledge_graph.edges(data=True):
        subject = str(source)
        object_ = str(target)
        if subject in top_k_lookup and object_ in top_k_lookup:
            relations_by_pair[(subject, object_)].append(
                (data['predicate'], data['chapter_index'], data['count']))

    kept_relations = [
        (subject, object_, predicate, chapter_index, count)
        for (subject, object_), relations in relations_by_pair.items()
        for predicate, chapter_index, count in relations
        if 'output' not in predicate and count >= 2
    ]

    _show_refined_graph(top_k_nodes, kept_relations)

    return [
        {
            'subject': fuse_subject(subject),
            'predicate': predicate,
            'object': fuse_subject(object_),
            'chapter_index': chapter_index,
            'count': count,
            # Look counts up under the original node names: the fused alias
            # is generally not a key of `edge_count`.
            'subject_node_count': edge_count[subject],
            'object_node_count': edge_count[object_],
        }
        for subject, object_, predicate, chapter_index, count in kept_relations
    ]
|
|
|
|