"""Gradio app that extracts a knowledge graph from free text and visualizes it.

The module loads a handful of sample texts, asks an ``LLMGraph`` model to
extract nodes and edges, renders the entities with spaCy's displaCy and the
graph with pyvis, and wires everything into a Gradio ``Blocks`` UI.
"""

import os
import time
import shutil
import pickle
import random
import logging
import asyncio
import warnings

import spacy
import rapidjson
import gradio as gr
import networkx as nx
from pyvis.network import Network
from spacy import displacy
from spacy.tokens import Span

from llm_graph import LLMGraph, MODEL_LIST

logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore", category=UserWarning)

# Constants
TITLE = "🌐 Text2Graph: Extract Knowledge Graphs from Natural Language"
SUBTITLE = "✨ Extract and visualize knowledge graphs from texts in any language!"

# Basic CSS for styling
CUSTOM_CSS = """
.gradio-container {
    font-family: 'Segoe UI', Roboto, sans-serif;
}
"""

# Cache directory and file paths
CACHE_DIR = "./cache"
WORKING_DIR = "./sample"
EXAMPLE_CACHE_FILE = os.path.join(CACHE_DIR, "first_example_cache.pkl")
GRAPHML_FILE = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")


def _read_text_file(path):
    """Return the full contents of the UTF-8 text file at *path*."""
    with open(path, 'r', encoding='utf-8') as file:
        return file.read()


# Load the sample texts
text_en_file1 = "./data/sample1_en.txt"
text1_en = _read_text_file(text_en_file1)

text_en_file2 = "./data/sample2_en.txt"
text2_en = _read_text_file(text_en_file2)

text_en_file3 = "./data/sample3_en.txt"
text3_en = _read_text_file(text_en_file3)

text_fr_file = "./data/sample_fr.txt"
text_fr = _read_text_file(text_fr_file)

text_es_file = "./data/sample_es.txt"
text_es = _read_text_file(text_es_file)

# Create cache and working directories if they don't exist
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(WORKING_DIR, exist_ok=True)


def get_random_light_color():
    """Return a random light color as a ``#rrggbb`` hex string.

    Each channel is drawn from 140..255 so the color stays light enough to
    serve as a highlight background.
    """
    r = random.randint(140, 255)
    g = random.randint(140, 255)
    b = random.randint(140, 255)
    return f"#{r:02x}{g:02x}{b:02x}"


def handle_text(text=""):
    """Collapse every run of whitespace in *text* into a single space.

    Returns an empty string for empty or ``None`` input.
    """
    if not text:
        return ""
    return " ".join(text.split())


def extract_kg(text="", model_name=MODEL_LIST[0], model=None):
    """Extract a knowledge graph from *text* with the given model.

    Args:
        text: Input text to analyze.
        model_name: Name of the model, forwarded to ``model.extract``.
        model: Object exposing ``extract(text, model_name)``.

    Returns:
        The extraction result as a dict. A non-dict result is parsed as JSON.

    Raises:
        gr.Error: If an argument is missing, or if extraction/parsing fails
            (the original exception is attached as the cause).
    """
    if not text or not model_name:
        raise gr.Error("⚠️ Both text and model must be provided!")
    if not model:
        raise gr.Error("⚠️ Model must be provided!")

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()
        result = model.extract(text, model_name)
        duration = time.perf_counter() - start_time
        logging.info("Response time: %.4f seconds", duration)

        if isinstance(result, dict):
            return result
        # Convert string to dict
        return rapidjson.loads(result)
    except Exception as e:
        raise gr.Error(f"❌ Extraction error: {str(e)}") from e


def find_token_indices(doc, substring, text):
    """Locate token spans of *doc* that exactly cover occurrences of *substring*.

    Occurrences are searched left to right without overlap. An occurrence is
    reported only when it starts at the first character of a token and ends
    at the last character of a token.

    Args:
        doc: spaCy doc built from *text*.
        substring: Text to look for.
        text: The full text the doc was created from.

    Returns:
        A list of ``{"start": i, "end": j}`` dicts where ``i`` is the index of
        the first token and ``j`` is one past the index of the last token.
        Empty or non-string *substring* yields an empty list.
    """
    result = []
    # An empty needle matches at every position without advancing the search,
    # which would loop forever; a non-string needle would make str.find raise.
    if not isinstance(substring, str) or not substring:
        return result

    start_idx = text.find(substring)
    if start_idx == -1:
        return result

    # Build the character-offset -> token-index tables once instead of
    # rescanning every token for every occurrence.
    token_starting_at = {token.idx: token.i for token in doc}
    token_ending_at = {token.idx + len(token): token.i + 1 for token in doc}

    while start_idx != -1:
        end_idx = start_idx + len(substring)
        start_token = token_starting_at.get(start_idx)
        end_token = token_ending_at.get(end_idx)
        if start_token is not None and end_token is not None:
            result.append({"start": start_token, "end": end_token})
        # Search for next occurrence
        start_idx = text.find(substring, end_idx)
    return result


def create_custom_entity_viz(data, full_text, type_col="type"):
    """Render the extracted entities as highlighted spans with displaCy.

    Args:
        data: Extraction result; ``data["nodes"]`` is a list of dicts that
            each carry an ``"id"`` (the entity text) and optionally *type_col*.
        full_text: The text the entities were extracted from.
        type_col: Node key holding the entity label; nodes without it are
            labelled ``"Entity"``.

    Returns:
        The displaCy HTML markup, padded with a leading and trailing newline.
    """
    nlp = spacy.blank("xx")
    doc = nlp(full_text)

    spans = []
    colors = {}
    for node in data["nodes"]:
        for entity in find_token_indices(doc, node["id"], full_text):
            start = entity["start"]
            end = entity["end"]
            if start < len(doc) and end <= len(doc):
                # Skip candidates that overlap a span accepted earlier
                overlapping = any(s.start < end and start < s.end for s in spans)
                if not overlapping:
                    node_type = node.get(type_col, "Entity")
                    spans.append(Span(doc, start, end, label=node_type))
                    # One random color per label, assigned on first sight
                    if node_type not in colors:
                        colors[node_type] = get_random_light_color()

    doc.set_ents(spans, default="unmodified")
    doc.spans["sc"] = spans

    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True,
    }
    html = displacy.render(doc, style="span", options=options)

    # Surround the markup with newlines before handing it to the UI
    styled_html = f"\n{html}\n"
    return styled_html


def create_graph(json_data, model_name=MODEL_LIST[0]):
    """Create an interactive knowledge graph with pyvis.

    For the first model in ``MODEL_LIST`` the graph is built from
    ``json_data["nodes"]`` / ``json_data["edges"]``; for any other model it is
    read from ``GRAPHML_FILE``.

    Args:
        json_data: Extraction result with ``"nodes"`` and ``"edges"`` lists.
        model_name: Name of the model that produced the data.

    Returns:
        An ``<iframe>`` element whose ``srcdoc`` holds the pyvis page, so the
        page's styles and scripts stay isolated from the host page.
    """
    if model_name == MODEL_LIST[0]:
        G = nx.Graph()

        # Add nodes with tooltips, falling back when keys are missing
        for node in json_data['nodes']:
            node_type = node.get("type", "Entity")
            detailed_type = node.get("detailed_type", node_type)
            G.add_node(node['id'], title=f"{node_type}: {detailed_type}")

        # Add edges with labels; edges lacking an endpoint are skipped
        for edge in json_data['edges']:
            if 'from' in edge and 'to' in edge:
                label = edge.get('label', 'related')
                G.add_edge(edge['from'], edge['to'], title=label, label=label)
    else:
        G = nx.read_graphml(GRAPHML_FILE)

    # Create network visualization
    network = Network(
        width="100%",
        height="100vh",
        notebook=False,
        bgcolor="#f8fafc",
        font_color="#1e293b",
    )

    # Configure network display
    network.from_nx(G)

    if model_name == MODEL_LIST[0]:
        network.barnes_hut(
            gravity=-3000,
            central_gravity=0.3,
            spring_length=50,
            spring_strength=0.001,
            damping=0.09,
            overlap=0,
        )

    # Customize node appearance
    for node in network.nodes:
        if "description" in node:
            node["title"] = node["description"]
        node['color'] = {
            'background': '#e0e7ff',
            'border': '#6366f1',
            'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'},
        }
        node['font'] = {'size': 14, 'color': '#1e293b'}
        node['shape'] = 'dot'
        node['size'] = 20

    # Customize edge appearance
    for edge in network.edges:
        if "description" in edge:
            edge["title"] = edge["description"]
        edge['width'] = 4
        edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
        edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}

    # Generate HTML with iframe to isolate styles
    html = network.generate_html()
    # The page is embedded in a single-quoted srcdoc attribute, so it must not
    # contain single quotes itself.
    html = html.replace("'", '"')

    # The previous return value was an empty string, which discarded the
    # generated page and left the graph tab blank.
    return (
        '<iframe style="width: 100%; height: 700px; margin: 0 auto; border: none;" '
        f"name=\"result\" srcdoc='{html}'></iframe>"
    )


def _reset_working_dir():
    """Delete and recreate ``WORKING_DIR`` so a run starts from an empty directory."""
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    os.makedirs(WORKING_DIR, exist_ok=True)


def _load_example_cache():
    """Return the unpickled contents of ``EXAMPLE_CACHE_FILE``."""
    # SECURITY NOTE: unpickling can execute arbitrary code. This is only
    # acceptable while the cache file is written exclusively by this app.
    with open(EXAMPLE_CACHE_FILE, 'rb') as f:
        return pickle.load(f)


def _save_example_cache(cached_data):
    """Pickle *cached_data* into ``EXAMPLE_CACHE_FILE``."""
    with open(EXAMPLE_CACHE_FILE, 'wb') as f:
        pickle.dump(cached_data, f)


def _format_stats(json_data):
    """Return the one-line summary of node and edge counts in *json_data*."""
    node_count = len(json_data["nodes"])
    edge_count = len(json_data["edges"])
    return f"📊 Extracted {node_count} entities and {edge_count} relationships"


# NOTE: gr.Progress() as a default argument is how Gradio injects the progress
# tracker into an event handler; it is intentional.
def process_and_visualize(text, model_name, progress=gr.Progress()):
    """Process text and visualize the knowledge graph and entities.

    Args:
        text: Input text.
        model_name: Name of the model to use.
        progress: Gradio progress tracker.

    Returns:
        A tuple ``(graph_html, entities_viz, json_data, stats)``.

    Raises:
        gr.Error: If *text* or *model_name* is empty, or extraction fails.
    """
    if not text or not model_name:
        raise gr.Error("⚠️ Both text and model must be provided!")

    # Only the first example processed with the first model is cached
    is_first_example = text == EXAMPLES[0][0]
    use_cache = is_first_example and model_name == MODEL_LIST[0]

    # Consult the cache before wiping the working directory and initializing
    # the model, so a cache hit does not pay for either.
    if use_cache and os.path.exists(EXAMPLE_CACHE_FILE):
        try:
            progress(0.3, desc="Loading from cache...")
            cached_data = _load_example_cache()
            progress(1.0, desc="Loaded from cache!")
            return (
                cached_data["graph_html"],
                cached_data["entities_viz"],
                cached_data["json_data"],
                cached_data["stats"],
            )
        except Exception as e:
            # Continue with normal processing if the cache is unusable
            logging.error("Cache loading error: %s", e)

    # Clear the working directory
    _reset_working_dir()

    # Initialize the LLMGraph model
    model = LLMGraph()
    asyncio.run(model.initialize_rag())

    progress(0, desc="Starting extraction...")
    json_data = extract_kg(text, model_name, model)

    progress(0.5, desc="Creating entity visualization...")
    if model_name == MODEL_LIST[0]:
        entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
    else:
        entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")

    progress(0.8, desc="Building knowledge graph...")
    graph_html = create_graph(json_data, model_name)

    stats = _format_stats(json_data)

    # Save to cache if it's the first example
    if use_cache:
        try:
            _save_example_cache({
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats,
            })
        except Exception as e:
            # Caching is best-effort; a failed write must not fail the request
            logging.error("Cache saving error: %s", e)

    progress(1.0, desc="Complete!")
    return graph_html, entities_viz, json_data, stats


# Example texts
EXAMPLES = [
    [handle_text(text1_en)],
    [handle_text(text_fr)],
    [handle_text(text2_en)],
    [handle_text(text_es)],
    [handle_text(text3_en)],
]


def generate_first_example():
    """Generate or load the cached result for the first example.

    Called when the app starts. If ``EXAMPLE_CACHE_FILE`` does not exist, the
    first example is processed with the first model and the result is cached;
    otherwise the existing cache is loaded.

    Returns:
        A dict with keys ``"graph_html"``, ``"entities_viz"``, ``"json_data"``
        and ``"stats"``, or ``None`` if generation or loading failed (the
        failure is logged).
    """
    if not os.path.exists(EXAMPLE_CACHE_FILE):
        logging.info("Generating cache for first example...")
        try:
            text = EXAMPLES[0][0]
            model_name = MODEL_LIST[0] if MODEL_LIST else None

            # Initialize the LLMGraph model
            model = LLMGraph()
            asyncio.run(model.initialize_rag())

            # Extract data
            json_data = extract_kg(text, model_name, model)
            entities_viz = create_custom_entity_viz(json_data, text)
            graph_html = create_graph(json_data)
            stats = _format_stats(json_data)

            # Save to cache
            cached_data = {
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats,
            }
            _save_example_cache(cached_data)

            logging.info("First example cache generated successfully")
            return cached_data
        except Exception as e:
            # Startup must not fail just because the example cannot be built
            logging.error("Error generating first example cache: %s", e)
    else:
        logging.info("First example cache already exists")
        # Load existing cache
        try:
            return _load_example_cache()
        except Exception as e:
            logging.error("Error loading existing cache: %s", e)

    return None


def create_ui():
    """Create the Gradio UI.

    Clears the working directory, generates or loads the first example's
    cache, and builds the ``Blocks`` layout with its event handlers.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    # Clear the working directory
    _reset_working_dir()

    # Try to generate/load the first example cache
    first_example = generate_first_example()

    with gr.Blocks(css=CUSTOM_CSS, title=TITLE) as demo:
        # Header
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(f"{SUBTITLE}")

        # Main content area
        with gr.Row():
            # Left panel - Input controls
            with gr.Column(scale=1):
                input_model = gr.Radio(
                    MODEL_LIST,
                    label="🤖 Select Model",
                    info="Choose a model to process your text",
                    value=MODEL_LIST[0] if MODEL_LIST else None,
                )
                input_text = gr.TextArea(
                    label="📝 Input Text",
                    info="Enter text in any language to extract a knowledge graph",
                    placeholder="Enter text here...",
                    lines=8,
                    value=EXAMPLES[0][0],  # Pre-fill with first example
                )

                with gr.Row():
                    submit_button = gr.Button("🚀 Extract & Visualize", variant="primary", scale=2)
                    clear_button = gr.Button("🔄 Clear", variant="secondary", scale=1)

                # Statistics will appear here
                stats_output = gr.Markdown("", label="🔍 Analysis Results")

            # Right panel - Examples and raw JSON output
            with gr.Column(scale=1):
                gr.Markdown("## 📚 Example Texts")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=input_text,
                    label="",
                )

                with gr.Accordion("📊 JSON Data", open=False):
                    output_json = gr.JSON(label="")

        # Full width visualization area at the bottom
        with gr.Row():
            with gr.Tabs():
                with gr.TabItem("🧩 Knowledge Graph"):
                    output_graph = gr.HTML(label="")
                with gr.TabItem("🏷️ Entity Recognition"):
                    output_entity_viz = gr.HTML(label="")

        # Functionality
        submit_button.click(
            fn=process_and_visualize,
            inputs=[input_text, input_model],
            outputs=[output_graph, output_entity_viz, output_json, stats_output],
        )

        clear_button.click(
            fn=lambda: [None, None, None, ""],
            inputs=[],
            outputs=[output_graph, output_entity_viz, output_json, stats_output],
        )

        # Set initial values from cache if available
        if first_example:
            # Populate the outputs when the app loads
            demo.load(
                lambda: [
                    first_example["graph_html"],
                    first_example["entities_viz"],
                    first_example["json_data"],
                    first_example["stats"],
                ],
                inputs=None,
                outputs=[output_graph, output_entity_viz, output_json, stats_output],
            )

        # Footer
        gr.Markdown("---")
        gr.Markdown("📋 **Instructions:** Enter text in any language, select a model and click `Extract & Visualize` to generate a knowledge graph.")
        gr.Markdown("🛠️ Powered by [GPT-4.1-mini](https://platform.openai.com/docs/models/gpt-4.1-mini) and [Phi-3-mini-128k-instruct-graph](https://huggingface.co/EmergentMethods/Phi-3-mini-128k-instruct-graph)")

    return demo


def main():
    """Build the Gradio UI and launch it without a public share link."""
    demo = create_ui()
    demo.launch(share=False)


if __name__ == "__main__":
    main()