# src.kg.generate_kg.py
import pickle
from collections import defaultdict, Counter
from contextlib import redirect_stdout
from pathlib import Path
import json
import argparse
import os
import openai
import time
import numpy as np
import networkx as nx
from pyvis.network import Network
from tqdm import tqdm
from .knowledge_graph import generate_knowledge_graph
from .openai_api import load_response_text
from .save_triples import get_response_save_path
from .utils import set_up_logging
logger = set_up_logging('generate-knowledge-graphs-books.log')
KNOWLEDGE_GRAPHS_DIRECTORY_PATH = Path('../knowledge-graphs_new')
"""def gpt_inference(system_instruction, prompt, retries=10, delay=5):
# api
messages = [{"role": "system", "content": system_instruction},
{"role": "user", "content": prompt}]
for attempt in range(retries):
try:
response = openai.ChatCompletion.create(
model='gpt-4o-mini-2024-07-18',
messages=messages,
temperature=0.0,
max_tokens=128,
top_p=0.5,
frequency_penalty=0,
presence_penalty=0
)
result = response['choices'][0]['message']['content']
return result
except openai.error.APIError as e:
time.sleep(delay)
continue"""

def generate_knowledge_graph_from_saved_responses(book, idx, save_path):
    """
    Use the responses from the OpenAI API, previously saved to disk as JSON
    files, to generate a knowledge graph for a book.
    """
    response_texts = defaultdict(list)
    project_gutenberg_id = book['id']
    for chapter in book['chapters']:
        chapter_index = chapter['index']
        chapter_responses_directory = get_response_save_path(
            idx, save_path, project_gutenberg_id, chapter_index)
        for response_path in chapter_responses_directory.glob('*.json'):
            response_text = load_response_text(response_path)
            response_texts[chapter_index].append(response_text)
    knowledge_graph = generate_knowledge_graph(response_texts, project_gutenberg_id)
    return knowledge_graph

def generate_knowledge_graph_for_scripts(book, idx, response_list):
    """
    Use the responses from the OpenAI API to generate a knowledge graph for a
    book.
    """
    response_texts = defaultdict(list)
    project_gutenberg_id = book['id']
    for chapter in book['chapters']:
        chapter_index = chapter['index']
        # Every response in `response_list` is attributed to the current
        # chapter index.
        for response in response_list:
            response_texts[chapter_index].append(response['response'])
    knowledge_graph = generate_knowledge_graph(response_texts, project_gutenberg_id)
    return knowledge_graph
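
# Assumed input shapes for `generate_knowledge_graph_for_scripts`, inferred
# from the loops above rather than from a documented contract:
#   book = {'id': 74, 'chapters': [{'index': 0}, {'index': 1}]}
#   response_list = [{'response': '<raw model output for one prompt>'}]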

def save_knowledge_graph(knowledge_graph, project_gutenberg_id, save_path):
    """Save a knowledge graph to a `pickle` file."""
    save_path = save_path / 'kg.pkl'
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'wb') as knowledge_graph_file:
        pickle.dump(knowledge_graph, knowledge_graph_file)


def load_knowledge_graph(project_gutenberg_id, save_path):
    """Load a knowledge graph from a `pickle` file."""
    save_path = save_path / 'kg.pkl'
    with open(save_path, 'rb') as knowledge_graph_file:
        knowledge_graph = pickle.load(knowledge_graph_file)
    return knowledge_graph

def display_knowledge_graph(knowledge_graph, save_path):
    """Display a knowledge graph using pyvis."""
    # Convert the knowledge graph into a format that can be displayed by pyvis.
    # Merge all edges with the same subject and object into a single edge.
    pyvis_graph = nx.MultiDiGraph()
    for node in knowledge_graph.nodes:
        pyvis_graph.add_node(str(node), label='\n'.join(node.names),
                             shape='box')
    for edge in knowledge_graph.edges(data=True):
        subject = str(edge[0])
        object_ = str(edge[1])
        predicate = edge[2]['predicate']
        chapter_index = edge[2]['chapter_index']
        if pyvis_graph.has_edge(subject, object_):
            pyvis_graph[subject][object_][0].update(
                title=(f'{pyvis_graph[subject][object_][0]["title"]}\n'
                       f'{predicate}'))  # f'{predicate} ({chapter_index})'))
        else:
            pyvis_graph.add_edge(subject, object_,
                                 title=f'{predicate}')  # title=f'{predicate} ({chapter_index})')
    network = Network(height='99vh', directed=True, bgcolor='#262626',
                      cdn_resources='remote')
    network.set_options('''
    const options = {
        "interaction": {
            "tooltipDelay": 0
        },
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 100,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 0
            },
            "solver": "forceAtlas2Based"
        }
    }''')
    network.from_nx(pyvis_graph)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    # `show()` tries to print the name of the HTML file to the console, so
    # suppress it.
    with redirect_stdout(None):
        network.show(str(save_path), notebook=False)
    logger.info(f'Saved pyvis knowledge graph to {save_path}.')

def fuse_subject(subjects):
    """
    Fuse a '/'-separated list of subject aliases: if every alias contains the
    shortest (stripped) alias, return that shortest alias; otherwise return
    the original string unchanged.
    """
    subject_list = subjects.split('/')
    if len(subject_list) == 1:
        return subject_list[0]
    flag = 0
    stripped_subject_list = []
    len_list = []
    for subject in subject_list:
        stripped = subject.strip()
        stripped_subject_list.append(stripped)
        len_list.append(len(stripped))
    idx = np.argmin(len_list)
    for subject in stripped_subject_list:
        if stripped_subject_list[idx] in subject:
            flag += 1
    if flag == len(stripped_subject_list):
        return stripped_subject_list[idx]
    return subjects
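
# Illustrative behaviour of `fuse_subject` (the alias strings below are made
# up for clarity only):
#   fuse_subject('Tom Sawyer / Tom')       -> 'Tom'  (every alias contains 'Tom')
#   fuse_subject('Tom Sawyer / Huck Finn') -> 'Tom Sawyer / Huck Finn'  (unchanged)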

def init_kg(script, idx, response_list):
    """
    Generate a knowledge graph for a book (script) in the books dataset using
    the responses returned by the OpenAI API.
    """
    knowledge_graph = generate_knowledge_graph_for_scripts(script, idx, response_list)
    return knowledge_graph

def refine_kg(knowledge_graph, idx, topk):
    result = []
    edge_count = Counter()
    for edge in knowledge_graph.edges(data=True):
        subject = str(edge[0])
        object_ = str(edge[1])
        edge_count[subject] += 1
        edge_count[object_] += 1
    # Select the top-k nodes with the most edges.
    top_k_nodes = [node for node, count in edge_count.most_common(topk)]
    # Collect every relation between the top-k nodes.
    rel_dict = {}
    for edge in knowledge_graph.edges(data=True):
        subject = str(edge[0])
        object_ = str(edge[1])
        if subject in top_k_nodes and object_ in top_k_nodes:
            predicate = edge[2]['predicate']
            chapter_index = edge[2]['chapter_index']
            count = edge[2]['count']
            key = f"{subject}\t{object_}"
            if key not in rel_dict:
                rel_dict[key] = []
            rel_dict[key].append((predicate, chapter_index, count))
    # Visualize the refined graph.
    pyvis_graph = nx.MultiDiGraph()
    for node in top_k_nodes:
        pyvis_graph.add_node(node, label=node, shape='box')
    for key, relations in rel_dict.items():
        subject, object_ = key.split('\t')
        for relation in relations:
            predicate, chapter_index, count = relation
            if 'output' in predicate:
                continue
            if count >= 2:
                if pyvis_graph.has_edge(subject, object_):
                    pyvis_graph[subject][object_][0]['title'] += f', {predicate}'
                else:
                    pyvis_graph.add_edge(subject, object_, title=f'{predicate}')
    network = Network(height='99vh', directed=True, bgcolor='#262626', cdn_resources='remote')
    network.from_nx(pyvis_graph)
    with redirect_stdout(None):
        network.show('refined_kg.html', notebook=False)
    # Build the relationship list, fusing '/'-separated aliases into a single
    # name. Node degrees are looked up with the original (unfused) node names.
    for key, relations in rel_dict.items():
        subject, object_ = key.split('\t')
        for relation in relations:
            predicate, chapter_index, count = relation
            if 'output' in predicate:
                continue
            subject_node_count = edge_count[subject]
            object_node_count = edge_count[object_]
            fused_subject = fuse_subject(subject)
            fused_object = fuse_subject(object_)
            relationship = {
                'subject': fused_subject,
                'predicate': predicate,
                'object': fused_object,
                'chapter_index': chapter_index,
                'count': count,
                'subject_node_count': subject_node_count,
                'object_node_count': object_node_count
            }
            if count >= 2:
                result.append(relationship)
    return result
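
# A minimal end-to-end sketch (assumed usage, not part of the module): build a
# graph from a book dict plus its model responses, keep the densest nodes, and
# persist/visualize the result. `script`, `response_list`, and the output
# directory below are hypothetical placeholders.
#
#   kg = init_kg(script, idx=0, response_list=response_list)
#   triples = refine_kg(kg, idx=0, topk=20)
#   out_dir = Path('output/example-book')
#   save_knowledge_graph(kg, script['id'], out_dir)
#   display_knowledge_graph(kg, out_dir / 'kg.html')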