import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined
import requests
import json
import ast
import spacy
import spacy.cli
import warnings
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os

# Suppress torch warnings
warnings.filterwarnings("ignore", message=".*torch.classes.*")
warnings.filterwarnings("ignore", message=".*__path__._path.*")

# Set logging level to reduce noise
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)

# Page config
st.set_page_config(
    page_title="Entity Linking by WordLift",
    page_icon="fav-ico.png",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items={
        'Get Help': 'https://wordlift.io/book-a-demo/',
        'About': "# This is a demo app for NEL/NED/NER and SEO"
    }
)

# Sidebar
st.sidebar.image("logo-wordlift.png")
language_options = ["English", "English - spaCy", "German"]
selected_language = st.sidebar.selectbox("Select the Language", language_options, index=0)

# Based on the selected language, configure the model, entity set, and citation options
if selected_language in ("German", "English - spaCy"):
    selected_model_name = None
    selected_entity_set = None
    entity_fishing_citation = """
    @misc{entity-fishing,
        title = {entity-fishing},
        publisher = {GitHub},
        year = {2016--2023},
        archivePrefix = {swh},
        eprint = {1:dir:cb0ba3379413db12b0018b7c3af8d0d2d864139c}
    }
    """
    with st.sidebar.expander('Citations'):
        st.markdown(entity_fishing_citation)
else:
    model_options = ["aida_model", "wikipedia_model_with_numbers"]
    entity_set_options = ["wikidata", "wikipedia"]
    selected_model_name = st.sidebar.selectbox("Select the Model", model_options)
    selected_entity_set = st.sidebar.selectbox("Select the Entity Set", entity_set_options)
    refined_citation = """
    @inproceedings{ayoola-etal-2022-refined,
        title = "{R}e{F}in{ED}: An Efficient Zero-shot-capable Approach to End-to-End Entity Linking",
        author = "Tom Ayoola, Shubhi Tyagi, Joseph Fisher, Christos Christodoulopoulos, Andrea Pierleoni",
        booktitle = "NAACL",
        year = "2022"
    }
    """
    with st.sidebar.expander('Citations'):
        st.markdown(refined_citation)


@st.cache_resource  # Cache the loaded model across Streamlit reruns
def load_model(selected_language, model_name=None, entity_set=None):
    # Maps the short model aliases to their full Hugging Face Hub IDs
    model_mapping = {
        "aida_model": "amazon-science/ReFinED-aida-model",
        "wikipedia_model_with_numbers": "amazon-science/ReFinED-wikipedia-model"
    }

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # spaCy + entity-fishing pipelines for German and English
            if selected_language == "German":
                try:
                    nlp_model_de = spacy.load("de_core_news_lg")
                except OSError:
                    st.info("Downloading German language model... This may take a moment.")
                    spacy.cli.download("de_core_news_lg")
                    nlp_model_de = spacy.load("de_core_news_lg")
                if "entityfishing" not in nlp_model_de.pipe_names:
                    try:
                        nlp_model_de.add_pipe("entityfishing")
                    except Exception as e:
                        st.warning(f"Entity-fishing not available: {e}")
                return nlp_model_de
            elif selected_language == "English - spaCy":
                try:
                    nlp_model_en = spacy.load("en_core_web_sm")
                except OSError:
                    st.info("Downloading English language model... This may take a moment.")
This may take a moment.") spacy.cli.download("en_core_web_sm") nlp_model_en = spacy.load("en_core_web_sm") if "entityfishing" not in nlp_model_en.pipe_names: try: nlp_model_en.add_pipe("entityfishing") except Exception as e: st.warning(f"Entity-fishing not available: {e}") return nlp_model_en # This block handles the ReFinED model and the "add_special_tokens" error else: try: return Refined.from_pretrained(model_name=model_name, entity_set=entity_set) except Exception as e: if "add_special_tokens" in str(e): st.warning("Conflict detected. Applying fix by downloading and patching model...") # 1. Get the REAL model name from our mapping real_model_name = model_mapping.get(model_name) if not real_model_name: st.error(f"Unknown model alias: {model_name}") return None # 2. Define a local path to save the fixed model local_model_path = f"./{model_name}-{entity_set}-fixed" # 3. Download the tokenizer and the model using the REAL name st.info(f"Downloading model files for {real_model_name}...") tokenizer = AutoTokenizer.from_pretrained(real_model_name) model_files = AutoModelForSeq2SeqLM.from_pretrained(real_model_name) # 4. Save them to the local directory tokenizer.save_pretrained(local_model_path) model_files.save_pretrained(local_model_path) st.info("Model files downloaded.") # 5. Patch the tokenizer config file config_path = os.path.join(local_model_path, "tokenizer_config.json") with open(config_path, "r") as f: config_data = json.load(f) config_data.pop("add_special_tokens", None) # Remove the conflicting key with open(config_path, "w") as f: json.dump(config_data, f, indent=2) # 6. Load the model from the local, fixed path st.success("Patch applied. Loading model from local cache...") return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set) else: raise e # If it's a different error, we still want to see it except Exception as e: st.error(f"Failed to load model. Error: {e}") return None # Use the cached model model = load_model(selected_language, selected_model_name, selected_entity_set) # Helper functions def get_wikidata_id(entity_string): entity_list = entity_string.split("=") entity_id = str(entity_list[1]) entity_link = "http://www.wikidata.org/entity/" + entity_id return {"id": entity_id, "link": entity_link} def get_entity_data(entity_link): try: # Format the entity_link formatted_link = entity_link.replace("http://", "http/") response = requests.get(f'https://api.wordlift.io/id/{formatted_link}') return response.json() except Exception as e: print(f"Exception when fetching data for entity: {entity_link}. 
Exception: {e}") return None # Create the form with st.form(key='my_form'): text_input = st.text_area(label='Enter a sentence') submit_button = st.form_submit_button(label='Analyze') # Initialization entities_map = {} entities_data = {} if text_input and model is not None: try: if selected_language in ["German", "English - spaCy"]: # Process the text with error handling for spaCy doc = model(text_input) entities = [] for ent in doc.ents: try: kb_qid = getattr(ent._, 'kb_qid', None) url_wikidata = getattr(ent._, 'url_wikidata', None) entities.append((ent.text, ent.label_, kb_qid, url_wikidata)) except AttributeError: entities.append((ent.text, ent.label_, None, None)) for entity_string, entity_type, wikidata_id, wikidata_url in entities: if wikidata_url: entities_map[entity_string] = {"id": wikidata_id, "link": wikidata_url} entity_data = get_entity_data(wikidata_url) if entity_data: entities_data[entity_string] = entity_data else: # === CORRECTED ReFinED PROCESSING LOGIC === entities = model.process_text(text_input) # Iterate through the entity objects directly and safely for entity in entities: # Check if the entity has a wikidata_id before processing if entity.wikidata_id: entity_text = entity.text entity_id = entity.wikidata_id entity_link = f"http://www.wikidata.org/entity/{entity_id}" # Populate your dictionaries entities_map[entity_text] = {"id": entity_id, "link": entity_link} entity_data = get_entity_data(entity_link) if entity_data is not None: entities_data[entity_text] = entity_data except Exception as e: st.error(f"Error processing text: {e}") if "entityfishing" in str(e).lower(): st.error("This appears to be an entity-fishing related error. Please ensure:") st.error("1. Entity-fishing service is running") st.error("2. spacyfishing package is properly installed") st.error("3. 
# Combine entity information
combined_entity_info_dictionary = {
    k: [entities_map[k], entities_data.get(k)] for k in entities_map
}

if submit_button and entities_map:
    # Prepare a list to hold the final output
    final_text = []

    # JSON-LD data
    json_ld_data = {
        "@context": "https://schema.org",
        "@type": "WebPage",
        "mentions": []
    }

    # Replace each entity in the text with its annotated version
    for entity_string, entity_info in entities_map.items():
        # Skip entities without a valid Wikidata link
        if entity_info["link"] is None or entity_info["link"] == "None":
            continue

        entity_data = entities_data.get(entity_string, None)
        entity_type = None
        if entity_data is not None:
            entity_type = entity_data.get("@type", None)

        # Use different colors based on the entity's type
        color = "#8ef"  # Default color
        if entity_type == "Place":
            color = "#8AC7DB"
        elif entity_type == "Organization":
            color = "#ADD8E6"
        elif entity_type == "Person":
            color = "#67B7D1"
        elif entity_type == "Product":
            color = "#2ea3f2"
        elif entity_type == "CreativeWork":
            color = "#00BFFF"
        elif entity_type == "Event":
            color = "#1E90FF"

        entity_annotation = (entity_string, entity_info["id"], color)
        text_input = text_input.replace(entity_string, f'{{{str(entity_annotation)}}}', 1)

        # Add the entity to the JSON-LD data
        entity_json_ld = combined_entity_info_dictionary[entity_string][1]
        if entity_json_ld and entity_json_ld.get("link") != "None":
            json_ld_data["mentions"].append(entity_json_ld)

    # Split the modified text_input on the injected markers and rebuild the
    # segments expected by annotated_text
    text_list = text_input.split("{")
    for item in text_list:
        if "}" in item:
            item_list = item.split("}")
            try:
                final_text.append(ast.literal_eval(item_list[0]))
            except Exception:
                final_text.append(item_list[0])
            if len(item_list) > 1 and len(item_list[1]) > 0:
                final_text.append(item_list[1])
        else:
            final_text.append(item)

    # Pass the final_text to the annotated_text function
    annotated_text(*final_text)

    with st.expander("See annotations"):
        st.write(combined_entity_info_dictionary)

    with st.expander("Here is the final JSON-LD"):
        st.json(json_ld_data)  # Output JSON-LD
elif submit_button and not entities_map:
    st.warning("No entities found in the text. Please try different text or check that the model is working correctly.")
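
# Minimal way to launch this demo locally (assuming the script is saved as app.py):
#   streamlit run app.py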