import gradio as gr
from py2neo import Graph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import json
# Set up the connection to the Neo4j database
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Graph(url, auth=(username, password))
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
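# REBEL is a seq2seq model that emits relations as a linearized string of
# <triplet>/<subj>/<obj> special tokens; the parser below turns that string
# into {'head', 'type', 'tail'} dictionaries.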
# Function to extract relations from model output
def extract_relations_from_model_output(text):
    triplets = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # Walk the linearized REBEL output token by token; the special tokens
    # <triplet>, <subj>, and <obj> mark which field the following words belong to.
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the final triplet once the token stream ends.
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets
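# Illustrative example (not executed), assuming a decoded REBEL prediction like
#   "<s><triplet> Naples <subj> Italy <obj> country</s>"
# extract_relations_from_model_output() returns
#   [{'head': 'Naples', 'type': 'country', 'tail': 'Italy'}]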
# Simple in-memory knowledge base of unique (head, type, tail) triplets
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
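# Illustrative usage (not executed): KB deduplicates exact triplet matches.
#   kb = KB()
#   kb.add_relation({'head': 'Naples', 'type': 'country', 'tail': 'Italy'})
#   kb.add_relation({'head': 'Naples', 'type': 'country', 'tail': 'Italy'})  # skipped as duplicate
#   kb.print()  # prints the single stored relation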
def from_small_text_to_kb(text, verbose=False):
    kb = KB()
    # Tokenize the passage for the REBEL seq2seq model (truncated to 512 tokens).
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    # Decode several beams per input so more candidate triplets are recovered.
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    # Keep special tokens: the parser relies on the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        if verbose:
            print(f"Extracted {len(relations)} relations")
        for r in relations:
            kb.add_relation(r)
    return kb
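# Illustrative usage (not executed): with num_return_sequences=3, three beams
# are decoded per chunk; duplicate triplets across beams collapse in the KB.
#   kb = from_small_text_to_kb("Naples is a city in southern Italy.", verbose=True)
#   kb.print()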
# Function to insert data from a Wikipedia query into Neo4j
def insert_data_from_wikipedia(query):
    # Split the Wikipedia articles into ~512-character chunks the model can handle.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, length_function=len, is_separator_regex=False)
    raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)
    if not raw_documents:
        print("No documents found for query:", query)
        return False
    for doc in raw_documents:
        kb = from_small_text_to_kb(doc.page_content, verbose=True)
        for relation in kb.relations:
            head = relation['head']
            relationship = relation['type']
            tail = relation['tail']
            if head and relationship and tail:
                # Backticks let MERGE accept labels and relationship types that
                # contain spaces; entity names become node labels here.
                cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
                print(f"Executing Cypher query: {cypher}")  # Debug print for Cypher query
                graph.run(cypher)
            else:
                print(f"Skipping invalid relation: head='{head}', relationship='{relationship}', tail='{tail}'")  # Skip invalid relations
    return True
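# Illustrative example (not executed): the triplet
#   {'head': 'Naples', 'type': 'country', 'tail': 'Italy'}
# produces the query
#   MERGE (h:`Naples`) MERGE (t:`Italy`) MERGE (h)-[:`country`]->(t)
# so entity names become node labels and the relation becomes the relationship
# type. Note this would break if an extracted name itself contained a backtick.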
# Function to query the database
def query_neo4j(query):
    if not query.strip():
        return json.dumps({"error": "Empty Cypher query"}, indent=2)  # Handle empty query case
    try:
        result = graph.run(query).data()
        return json.dumps(result, indent=2)  # Convert to JSON string
    except Exception as e:
        return json.dumps({"error": str(e)}, indent=2)  # Return error as JSON
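# Illustrative usage (not executed), assuming the graph has been populated:
#   query_neo4j("MATCH (h)-[r]->(t) RETURN labels(h) AS head, type(r) AS rel, labels(t) AS tail LIMIT 5")
# returns a JSON string such as
#   [{"head": ["Naples"], "rel": "country", "tail": ["Italy"]}]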
# Gradio interface function
def gradio_interface(wiki_query, cypher_query):
    if not wiki_query.strip():
        return json.dumps({"error": "Wikipedia query cannot be empty"}, indent=2)
    success = insert_data_from_wikipedia(wiki_query)
    if not success:
        return json.dumps({"error": f"No data found for Wikipedia query: {wiki_query}"}, indent=2)
    if not cypher_query.strip():
        return json.dumps({"error": "Cypher query cannot be empty"}, indent=2)
    result = query_neo4j(cypher_query)
    return result
# Create the Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=["text", "text"],
    outputs="json",
    title="Neo4j and Wikipedia Interface",
    description="Insert data from Wikipedia and query the Neo4j database."
)
# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()