import os
import time
import spacy
import shutil
import pickle
import random
import logging
import asyncio
import warnings
import rapidjson
import gradio as gr
import networkx as nx
from llm_graph import LLMGraph, MODEL_LIST
from pyvis.network import Network
from spacy import displacy
from spacy.tokens import Span
logging.basicConfig(level=logging.INFO)
# Suppress UserWarning messages to keep the console output readable.
warnings.filterwarnings("ignore", category=UserWarning)

# Constants
TITLE = "🌐 Text2Graph: Extract Knowledge Graphs from Natural Language"
SUBTITLE = "✨ Extract and visualize knowledge graphs from texts in any language!"

# Basic CSS for styling
CUSTOM_CSS = """
.gradio-container {
font-family: 'Segoe UI', Roboto, sans-serif;
}
"""

# Cache directory and file paths
CACHE_DIR = "./cache"
WORKING_DIR = "./sample"
EXAMPLE_CACHE_FILE = os.path.join(CACHE_DIR, "first_example_cache.pkl")
# Built with os.path.join (like EXAMPLE_CACHE_FILE) instead of string concatenation.
GRAPHML_FILE = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
# Load the bundled sample texts that back the example list.
def _read_sample_text(path):
    """Return the whole UTF-8 decoded contents of the file at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()


text_en_file1 = "./data/sample1_en.txt"
text1_en = _read_sample_text(text_en_file1)

text_en_file2 = "./data/sample2_en.txt"
text2_en = _read_sample_text(text_en_file2)

text_en_file3 = "./data/sample3_en.txt"
text3_en = _read_sample_text(text_en_file3)

text_fr_file = "./data/sample_fr.txt"
text_fr = _read_sample_text(text_fr_file)

text_es_file = "./data/sample_es.txt"
text_es = _read_sample_text(text_es_file)
# Ensure the cache and working directories exist (no-op when already present).
for _directory in (CACHE_DIR, WORKING_DIR):
    os.makedirs(_directory, exist_ok=True)
def get_random_light_color():
    """
    Return a random light color as a ``#rrggbb`` hex string.

    Red, green and blue are drawn independently (in that order) from the
    inclusive range 140-255, so every channel sits in the upper part of the
    scale and the resulting color is light.
    """
    channels = (random.randint(140, 255) for _ in range(3))
    return "#" + "".join(format(value, "02x") for value in channels)
def handle_text(text=""):
    """
    Normalize whitespace in *text*.

    Every run of whitespace (spaces, tabs, newlines) becomes a single space,
    and leading/trailing whitespace is dropped. Falsy input (``""`` or
    ``None``) yields ``""``.
    """
    return " ".join(text.split()) if text else ""
def extract_kg(text="", model_name=MODEL_LIST[0], model=None):
    """
    Extract a knowledge graph from *text* using *model*.

    Args:
        text: Input text to analyse.
        model_name: Model identifier forwarded to ``model.extract``.
        model: Initialized extractor exposing ``extract(text, model_name)``.

    Returns:
        dict: The extraction result. If ``model.extract`` returns something
        other than a dict, it is parsed as JSON with ``rapidjson.loads``.

    Raises:
        gr.Error: If an argument is missing, or if extraction or JSON
        parsing fails (the original exception is chained as the cause).
    """
    # Catch empty input before doing any work.
    if not text or not model_name:
        raise gr.Error("⚠️ Both text and model must be provided!")
    if not model:
        raise gr.Error("⚠️ Model must be provided!")

    try:
        # perf_counter is monotonic, so the timing is unaffected by clock changes.
        start_time = time.perf_counter()
        result = model.extract(text, model_name)
        duration = time.perf_counter() - start_time
        logging.info("Response time: %.4f seconds", duration)
        if isinstance(result, dict):
            return result
        # Non-dict results are expected to be a JSON string.
        return rapidjson.loads(result)
    except gr.Error:
        # Already a user-facing error; don't wrap it in a second "Extraction error".
        raise
    except Exception as e:
        raise gr.Error(f"❌ Extraction error: {str(e)}") from e
def find_token_indices(doc, substring, text):
    """
    Find the token span of every occurrence of *substring* in *text*.

    Args:
        doc: A spaCy ``Doc`` (or any iterable of tokens exposing ``.i``,
            ``.idx`` and ``len()``) tokenized from *text*.
        substring: Literal string to search for.
        text: The text *doc* was built from; character offsets are
            resolved against it.

    Returns:
        list[dict]: One ``{"start": first_token_i, "end": last_token_i + 1}``
        dict per non-overlapping occurrence, in text order. Occurrences that
        do not begin and end exactly on token boundaries are skipped. An
        empty *substring* yields ``[]``.
    """
    # An empty needle matches at every offset, and text.find("", i) never
    # advances past i, which would loop forever.
    if not substring:
        return []

    # Map character offsets to token indices once, instead of rescanning the
    # whole doc for every occurrence. Later tokens win on duplicate offsets,
    # as in a full linear scan.
    token_at_start = {}
    token_at_end = {}
    for token in doc:
        token_at_start[token.idx] = token.i
        token_at_end[token.idx + len(token)] = token.i + 1

    result = []
    start_idx = text.find(substring)
    while start_idx != -1:
        end_idx = start_idx + len(substring)
        start_token = token_at_start.get(start_idx)
        end_token = token_at_end.get(end_idx)
        if start_token is not None and end_token is not None:
            result.append({"start": start_token, "end": end_token})
        # Resume after this occurrence, so matches never overlap.
        start_idx = text.find(substring, end_idx)
    return result
def create_custom_entity_viz(data, full_text, type_col="type"):
    """
    Render the extracted entities as highlighted spans over *full_text*.

    Every occurrence of a node's ``id`` that aligns with token boundaries is
    labelled with the node's type. Matches that overlap an already accepted
    span are dropped, so the first match wins. Each distinct type gets a
    random light color.

    Args:
        data: Extraction result. ``data["nodes"]`` is a list of node dicts
            carrying an ``"id"`` string and, optionally, a type under
            *type_col*. A missing or empty list renders plain text.
        full_text: The text the entities were extracted from.
        type_col: Node key that holds the entity type. Nodes without it (or
            with an empty value) are labelled ``"Entity"``.

    Returns:
        str: HTML produced by ``displacy.render(..., style="span")``.
    """
    # A blank multi-language pipeline is only needed for tokenization.
    nlp = spacy.blank("xx")
    doc = nlp(full_text)
    spans = []
    colors = {}

    for node in data.get("nodes") or []:
        entity_text = node.get("id")
        # Empty ids would match everywhere and non-strings break str.find,
        # so skip anything that is not a non-empty string.
        if not isinstance(entity_text, str) or not entity_text:
            continue
        node_type = node.get(type_col) or "Entity"
        for entity in find_token_indices(doc, entity_text, full_text):
            start, end = entity["start"], entity["end"]
            if not (start < len(doc) and end <= len(doc)):
                continue
            # Keep spans disjoint: doc.set_ents() rejects overlapping entities.
            if any(s.start < end and start < s.end for s in spans):
                continue
            spans.append(Span(doc, start, end, label=node_type))
            if node_type not in colors:
                colors[node_type] = get_random_light_color()

    doc.set_ents(spans, default="unmodified")
    # displacy's span renderer reads spans from the "sc" key by default.
    doc.spans["sc"] = spans
    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True,
    }
    return displacy.render(doc, style="span", options=options)
def create_graph(json_data, model_name=MODEL_LIST[0]):
    """
    Build an interactive pyvis visualization of the knowledge graph.

    For ``MODEL_LIST[0]`` the graph is built from *json_data*:
    ``nodes`` entries carry ``id`` and optional ``type`` / ``detailed_type``,
    and ``edges`` entries carry ``from`` / ``to`` and an optional ``label``.
    For any other model the graph is read from ``GRAPHML_FILE`` and
    *json_data* is ignored.

    Returns:
        str: An ``<iframe>`` whose ``srcdoc`` holds the complete pyvis page.
        The iframe keeps the page's scripts and styles isolated from the
        Gradio UI.
    """
    # Aliased so it does not clash with the usual "html" name for markup.
    import html as html_lib

    if model_name == MODEL_LIST[0]:
        G = nx.Graph()
        # Add nodes with a "type: detailed type" tooltip, tolerating missing keys.
        for node in json_data.get("nodes", []):
            if "id" not in node:
                continue
            node_type = node.get("type", "Entity")
            detailed_type = node.get("detailed_type", node_type)
            G.add_node(node["id"], title=f"{node_type}: {detailed_type}")
        # Add labelled edges, skipping any that lack an endpoint.
        for edge in json_data.get("edges", []):
            if "from" in edge and "to" in edge:
                label = edge.get("label", "related")
                G.add_edge(edge["from"], edge["to"], title=label, label=label)
    else:
        # Presumably written by LLMGraph into WORKING_DIR during extraction — TODO confirm.
        G = nx.read_graphml(GRAPHML_FILE)

    network = Network(
        width="100%",
        height="100vh",
        notebook=False,
        bgcolor="#f8fafc",
        font_color="#1e293b",
    )
    network.from_nx(G)
    if model_name == MODEL_LIST[0]:
        network.barnes_hut(
            gravity=-3000,
            central_gravity=0.3,
            spring_length=50,
            spring_strength=0.001,
            damping=0.09,
            overlap=0,
        )

    # Customize node appearance; a description, when present, becomes the tooltip.
    for node in network.nodes:
        if "description" in node:
            node["title"] = node["description"]
        node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
        node['font'] = {'size': 14, 'color': '#1e293b'}
        node['shape'] = 'dot'
        node['size'] = 20

    # Customize edge appearance in the same way.
    for edge in network.edges:
        if "description" in edge:
            edge["title"] = edge["description"]
        edge['width'] = 4
        edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
        edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}

    page = network.generate_html()
    # Escape the full page for the srcdoc attribute. Rewriting ' to " instead
    # would corrupt any JS/JSON string containing an apostrophe.
    srcdoc = html_lib.escape(page, quote=True)
    return (
        '<iframe style="width: 100%; height: 100vh; margin: 0 auto;" '
        'name="result" frameborder="0" '
        'sandbox="allow-scripts allow-same-origin allow-popups allow-modals allow-forms" '
        f'srcdoc="{srcdoc}"></iframe>'
    )
def process_and_visualize(text, model_name, progress=gr.Progress()):
    """
    Gradio click handler: extract a knowledge graph from *text* and render it.

    When the first bundled example is processed with ``MODEL_LIST[0]``, the
    result is served from ``EXAMPLE_CACHE_FILE`` if that file loads.
    Otherwise the cache is refreshed after a fresh extraction.

    Args:
        text: Input text from the UI.
        model_name: Selected entry of ``MODEL_LIST``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        tuple: ``(graph_html, entities_viz, json_data, stats)``.

    Raises:
        gr.Error: If *text* or *model_name* is missing, or extraction fails.
    """
    if not text or not model_name:
        raise gr.Error("⚠️ Both text and model must be provided!")

    # Only the first example processed with the first model is cached.
    use_cache = text == EXAMPLES[0][0] and model_name == MODEL_LIST[0]

    # Check the cache before the working-dir reset and model initialization,
    # since a cache hit needs neither.
    if use_cache and os.path.exists(EXAMPLE_CACHE_FILE):
        try:
            progress(0.3, desc="Loading from cache...")
            # NOTE: pickle is acceptable only because this app writes the file itself.
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                cached_data = pickle.load(f)
            progress(1.0, desc="Loaded from cache!")
            return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
        except Exception as e:
            # Continue with normal processing if the cache is unusable.
            logging.error(f"Cache loading error: {str(e)}")

    progress(0, desc="Starting extraction...")

    # Start every extraction from an empty working directory.
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    os.makedirs(WORKING_DIR, exist_ok=True)

    model = LLMGraph()
    asyncio.run(model.initialize_rag())

    json_data = extract_kg(text, model_name, model)

    progress(0.5, desc="Creating entity visualization...")
    # The first model stores node types under "type"; the others under "entity_type".
    type_col = "type" if model_name == MODEL_LIST[0] else "entity_type"
    entities_viz = create_custom_entity_viz(json_data, text, type_col=type_col)

    progress(0.8, desc="Building knowledge graph...")
    graph_html = create_graph(json_data, model_name)

    node_count = len(json_data.get("nodes", []))
    edge_count = len(json_data.get("edges", []))
    stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"

    if use_cache:
        try:
            cached_data = {
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats,
            }
            # Write to a temp file, then atomically swap it in, so an
            # interrupted write never leaves a truncated cache behind.
            tmp_path = EXAMPLE_CACHE_FILE + ".tmp"
            with open(tmp_path, 'wb') as f:
                pickle.dump(cached_data, f)
            os.replace(tmp_path, EXAMPLE_CACHE_FILE)
        except Exception as e:
            logging.error(f"Cache saving error: {str(e)}")

    progress(1.0, desc="Complete!")
    return graph_html, entities_viz, json_data, stats
# Example texts for the examples panel, whitespace-normalized; one single-input
# row per sample. The first row doubles as the cached default example.
EXAMPLES = [
    [handle_text(sample)]
    for sample in (text1_en, text_fr, text2_en, text_es, text3_en)
]
def generate_first_example():
    """
    Return the cached result for the first example, building it if needed.

    If ``EXAMPLE_CACHE_FILE`` exists and unpickles, its content is returned.
    If the file is missing or cannot be loaded, the first example is
    processed with ``MODEL_LIST[0]`` and the result is written to the cache.

    Returns:
        dict | None: A dict with ``graph_html``, ``entities_viz``,
        ``json_data`` and ``stats`` keys, or ``None`` if the cache can
        neither be loaded nor generated. Errors are logged, not raised.
    """
    if os.path.exists(EXAMPLE_CACHE_FILE):
        logging.info("First example cache already exists")
        try:
            # NOTE: pickle is acceptable only because this app writes the file itself.
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            # A corrupt cache should not disable the preview forever: rebuild it.
            logging.error(f"Error loading existing cache: {str(e)}")

    logging.info("Generating cache for first example...")
    try:
        text = EXAMPLES[0][0]
        model_name = MODEL_LIST[0] if MODEL_LIST else None
        # Initialize the LLMGraph model
        model = LLMGraph()
        asyncio.run(model.initialize_rag())
        json_data = extract_kg(text, model_name, model)
        # The first model stores node types under "type".
        entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
        graph_html = create_graph(json_data, model_name)
        node_count = len(json_data.get("nodes", []))
        edge_count = len(json_data.get("edges", []))
        stats = f"📊 Extracted {node_count} entities and {edge_count} relationships"
        cached_data = {
            "graph_html": graph_html,
            "entities_viz": entities_viz,
            "json_data": json_data,
            "stats": stats,
        }
        # Write to a temp file, then atomically swap it in, so an interrupted
        # write never leaves a truncated cache behind.
        tmp_path = EXAMPLE_CACHE_FILE + ".tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(cached_data, f)
        os.replace(tmp_path, EXAMPLE_CACHE_FILE)
        logging.info("First example cache generated successfully")
        return cached_data
    except Exception as e:
        logging.error(f"Error generating first example cache: {str(e)}")
        return None
def create_ui():
    """
    Build the Gradio Blocks app for Text2Graph.

    Clears ``WORKING_DIR``, then loads or generates the cached first-example
    result. It then lays out:

    - the model selector and text input;
    - the example list and JSON viewer;
    - the graph / entity tabs.

    Finally it wires the buttons to their handlers.

    Returns:
        gr.Blocks: The assembled app (not yet launched).
    """
    # Clear the working directory if it exists
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    os.makedirs(WORKING_DIR, exist_ok=True)
    # Try to generate/load the first example cache; this is None if both fail
    first_example = generate_first_example()
    with gr.Blocks(css=CUSTOM_CSS, title=TITLE) as demo:
        # Header
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(f"{SUBTITLE}")
        # Main content area
        with gr.Row():
            # Left panel - input controls
            with gr.Column(scale=1):
                input_model = gr.Radio(
                    MODEL_LIST,
                    label="🤖 Select Model",
                    info="Choose a model to process your text",
                    value=MODEL_LIST[0] if MODEL_LIST else None,
                )
                input_text = gr.TextArea(
                    label="📝 Input Text",
                    info="Enter text in any language to extract a knowledge graph",
                    placeholder="Enter text here...",
                    lines=8,
                    value=EXAMPLES[0][0]  # Pre-fill with the first (cached) example
                )
                with gr.Row():
                    submit_button = gr.Button("🚀 Extract & Visualize", variant="primary", scale=2)
                    clear_button = gr.Button("🔄 Clear", variant="secondary", scale=1)
                # Statistics will appear here
                stats_output = gr.Markdown("", label="🔍 Analysis Results")
            # Right panel - example texts
            with gr.Column(scale=1):
                gr.Markdown("## 📚 Example Texts")
                # Clicking an example only fills input_text; it does not run extraction
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=input_text,
                    label=""
                )
                # Collapsible raw JSON output, also in the right panel
                with gr.Accordion("📊 JSON Data", open=False):
                    output_json = gr.JSON(label="")
        # Full-width visualization area at the bottom
        with gr.Row():
            with gr.Tabs():
                with gr.TabItem("🧩 Knowledge Graph"):
                    output_graph = gr.HTML(label="")
                with gr.TabItem("🏷️ Entity Recognition"):
                    output_entity_viz = gr.HTML(label="")
        # Event wiring: output order must match process_and_visualize's return tuple
        submit_button.click(
            fn=process_and_visualize,
            inputs=[input_text, input_model],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )
        # Reset the four outputs; the input text and model selection are left as-is
        clear_button.click(
            fn=lambda: [None, None, None, ""],
            inputs=[],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )
        # Set initial values from cache if available
        if first_example:
            # Populate the outputs with the cached first-example result on page load
            demo.load(
                lambda: [
                    first_example["graph_html"],
                    first_example["entities_viz"],
                    first_example["json_data"],
                    first_example["stats"]
                ],
                inputs=None,
                outputs=[output_graph, output_entity_viz, output_json, stats_output]
            )
        # Footer
        gr.Markdown("---")
        gr.Markdown("📋 **Instructions:** Enter text in any language, select a model and click `Extract & Visualize` to generate a knowledge graph.")
        gr.Markdown("🛠️ Powered by [GPT-4.1-mini](https://platform.openai.com/docs/models/gpt-4.1-mini) and [Phi-3-mini-128k-instruct-graph](https://huggingface.co/EmergentMethods/Phi-3-mini-128k-instruct-graph)")
    return demo
def main():
    """
    Entry point: build the Gradio UI and launch it without a public share link.
    """
    create_ui().launch(share=False)


if __name__ == "__main__":
    main()