import gradio as gr
from py2neo import Graph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import json

# Set up the connection to the Neo4j database
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Graph(url, auth=(username, password))
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
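
# Optional connectivity check (a minimal sketch; assumes the Neo4j Aura
# instance above is reachable); uncomment to verify before loading data:
# print(graph.run("RETURN 1 AS ok").data())  # expected: [{'ok': 1}]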

# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
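
# Optional (a sketch, assuming a CUDA-capable GPU; the app runs on CPU as-is):
# import torch
# model.to("cuda" if torch.cuda.is_available() else "cpu")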

# Parse the model's linearized output ("<triplet> head <subj> tail <obj> relation")
# into a list of {'head', 'type', 'tail'} triplets.
def extract_relations_from_model_output(text):
    triplets = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets
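# Illustrative example of the linearized format and its parsed form (the
# sample text is an assumption, not output produced by this app):
#   "<s><triplet> Punta Cana <subj> Dominican Republic <obj> country </s>"
#   -> [{'head': 'Punta Cana', 'type': 'country', 'tail': 'Dominican Republic'}]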

# Minimal in-memory knowledge base that stores unique relation triplets.
class KB:
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

# Run the model over a short text and merge the extracted relations into a KB.
def from_small_text_to_kb(text, verbose=False):
    kb = KB()
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    # Generate three beam-search candidates; each is parsed and deduplicated via the KB.
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        if verbose:
            print(f"Extracted {len(relations)} relations")
        for r in relations:
            kb.add_relation(r)
    return kb
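
# Illustrative call (assumes the model and tokenizer loaded above):
#   kb = from_small_text_to_kb("Punta Cana is a resort town in the Dominican Republic.")
#   kb.print()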

# Function to insert data into Neo4j from Wikipedia query
def insert_data_from_wikipedia(query):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, length_function=len, is_separator_regex=False)
    raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)

    if not raw_documents:
        print("No documents found for query:", query)
        return False

    for doc in raw_documents:
        kb = from_small_text_to_kb(doc.page_content, verbose=True)
        for relation in kb.relations:
            head = relation['head']
            relationship = relation['type']
            tail = relation['tail']
            if head and relationship and tail:
                # Strip backticks from the extracted names: they are interpolated
                # into the query as labels/relationship types, which Cypher cannot
                # parameterize, so stray backticks would break the statement.
                h, r, t = (s.replace('`', '') for s in (head, relationship, tail))
                cypher = f"MERGE (h:`{h}`) MERGE (t:`{t}`) MERGE (h)-[:`{r}`]->(t)"
                print(f"Executing Cypher query: {cypher}")  # Debug print for Cypher query
                graph.run(cypher)
            else:
                print(f"Skipping invalid relation: head='{head}', relationship='{relationship}', tail='{tail}'")  # Skip invalid relations
    return True
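
# Example Cypher (illustrative) to inspect the loaded graph afterwards:
#   MATCH (h)-[r]->(t) RETURN labels(h) AS head, type(r) AS rel, labels(t) AS tail LIMIT 25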

# Function to query the database
def query_neo4j(query):
    if not query.strip():
        return json.dumps({"error": "Empty Cypher query"}, indent=2)  # Handle empty query case
    try:
        result = graph.run(query).data()
        return json.dumps(result, indent=2)  # Convert to JSON string
    except Exception as e:
        return json.dumps({"error": str(e)}, indent=2)  # Return error as JSON

# Gradio interface function
def gradio_interface(wiki_query, cypher_query):
    if not wiki_query.strip():
        return json.dumps({"error": "Wikipedia query cannot be empty"}, indent=2)
    success = insert_data_from_wikipedia(wiki_query)
    if not success:
        return json.dumps({"error": f"No data found for Wikipedia query: {wiki_query}"}, indent=2)
    if not cypher_query.strip():
        return json.dumps({"error": "Cypher query cannot be empty"}, indent=2)
    result = query_neo4j(cypher_query)
    return result

# Create the Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=["text", "text"],
    outputs="json",
    title="Neo4j and Wikipedia Interface",
    description="Insert data from Wikipedia and query the Neo4j database."
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()