cyberandy committed on
Commit
813c7ba
·
verified ·
1 Parent(s): ad44c67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -27
app.py CHANGED
@@ -74,10 +74,14 @@ else:
74
 
75
  @st.cache_resource # 👈 Add the caching decorator
76
  def load_model(selected_language, model_name=None, entity_set=None):
77
- # Suppress warnings during model loading
 
 
 
 
 
78
  with warnings.catch_warnings():
79
  warnings.simplefilter("ignore")
80
-
81
  try:
82
  # This block handles the spaCy models for German and English
83
  if selected_language == "German":
@@ -87,12 +91,9 @@ def load_model(selected_language, model_name=None, entity_set=None):
87
  st.info("Downloading German language model... This may take a moment.")
88
  spacy.cli.download("de_core_news_lg")
89
  nlp_model_de = spacy.load("de_core_news_lg")
90
-
91
  if "entityfishing" not in nlp_model_de.pipe_names:
92
- try:
93
- nlp_model_de.add_pipe("entityfishing")
94
- except Exception as e:
95
- st.warning(f"Entity-fishing not available, using basic NER only: {e}")
96
  return nlp_model_de
97
 
98
  elif selected_language == "English - spaCy":
@@ -102,52 +103,58 @@ def load_model(selected_language, model_name=None, entity_set=None):
102
  st.info("Downloading English language model... This may take a moment.")
103
  spacy.cli.download("en_core_web_sm")
104
  nlp_model_en = spacy.load("en_core_web_sm")
105
-
106
  if "entityfishing" not in nlp_model_en.pipe_names:
107
- try:
108
- nlp_model_en.add_pipe("entityfishing")
109
- except Exception as e:
110
- st.warning(f"Entity-fishing not available, using basic NER only: {e}")
111
  return nlp_model_en
112
 
113
  # This block handles the ReFinED model and the "add_special_tokens" error
114
  else:
115
  try:
116
- # First, attempt to load the model as usual
117
  return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
118
 
119
  except Exception as e:
120
- # If the specific "add_special_tokens" error occurs, apply the fix
121
  if "add_special_tokens" in str(e):
122
- st.warning("Conflict detected. Applying fix by modifying tokenizer config...")
123
-
124
- # Define a local path to save/load the fixed model
 
 
 
 
 
 
125
  local_model_path = f"./{model_name}-{entity_set}-fixed"
126
 
127
- # Download tokenizer, modify config, and save locally
128
- tokenizer = AutoTokenizer.from_pretrained(model_name)
129
- tokenizer.save_pretrained(local_model_path)
 
130
 
 
 
 
 
 
 
131
  config_path = os.path.join(local_model_path, "tokenizer_config.json")
132
  with open(config_path, "r") as f:
133
  config_data = json.load(f)
134
 
135
- # Remove the conflicting parameter
136
- config_data.pop("add_special_tokens", None)
137
 
138
  with open(config_path, "w") as f:
139
  json.dump(config_data, f, indent=2)
140
 
141
- # Now, load the model from the local, fixed path
142
- st.success("Fix applied. Loading model from local cache.")
143
  return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set)
144
 
145
  else:
146
- # If it's a different error, raise it
147
- raise e
148
 
149
  except Exception as e:
150
- st.error(f"Error loading model: {e}")
151
  return None
152
 
153
  # Use the cached model
 
74
 
75
  @st.cache_resource # 👈 Add the caching decorator
76
  def load_model(selected_language, model_name=None, entity_set=None):
77
+ # This dictionary maps the easy names to their full Hugging Face Hub IDs
78
+ model_mapping = {
79
+ "aida_model": "amazon-science/ReFinED-aida-model",
80
+ "wikipedia_model_with_numbers": "amazon-science/ReFinED-wikipedia-model"
81
+ }
82
+
83
  with warnings.catch_warnings():
84
  warnings.simplefilter("ignore")
 
85
  try:
86
  # This block handles the spaCy models for German and English
87
  if selected_language == "German":
 
91
  st.info("Downloading German language model... This may take a moment.")
92
  spacy.cli.download("de_core_news_lg")
93
  nlp_model_de = spacy.load("de_core_news_lg")
 
94
  if "entityfishing" not in nlp_model_de.pipe_names:
95
+ try: nlp_model_de.add_pipe("entityfishing")
96
+ except Exception as e: st.warning(f"Entity-fishing not available: {e}")
 
 
97
  return nlp_model_de
98
 
99
  elif selected_language == "English - spaCy":
 
103
  st.info("Downloading English language model... This may take a moment.")
104
  spacy.cli.download("en_core_web_sm")
105
  nlp_model_en = spacy.load("en_core_web_sm")
 
106
  if "entityfishing" not in nlp_model_en.pipe_names:
107
+ try: nlp_model_en.add_pipe("entityfishing")
108
+ except Exception as e: st.warning(f"Entity-fishing not available: {e}")
 
 
109
  return nlp_model_en
110
 
111
  # This block handles the ReFinED model and the "add_special_tokens" error
112
  else:
113
  try:
 
114
  return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
115
 
116
  except Exception as e:
 
117
  if "add_special_tokens" in str(e):
118
+ st.warning("Conflict detected. Applying fix by downloading and patching model...")
119
+
120
+ # 1. Get the REAL model name from our mapping
121
+ real_model_name = model_mapping.get(model_name)
122
+ if not real_model_name:
123
+ st.error(f"Unknown model alias: {model_name}")
124
+ return None
125
+
126
+ # 2. Define a local path to save the fixed model
127
  local_model_path = f"./{model_name}-{entity_set}-fixed"
128
 
129
+ # 3. Download the tokenizer and the model using the REAL name
130
+ st.info(f"Downloading model files for {real_model_name}...")
131
+ tokenizer = AutoTokenizer.from_pretrained(real_model_name)
132
+ model_files = AutoModelForSeq2SeqLM.from_pretrained(real_model_name)
133
 
134
+ # 4. Save them to the local directory
135
+ tokenizer.save_pretrained(local_model_path)
136
+ model_files.save_pretrained(local_model_path)
137
+ st.info("Model files downloaded.")
138
+
139
+ # 5. Patch the tokenizer config file
140
  config_path = os.path.join(local_model_path, "tokenizer_config.json")
141
  with open(config_path, "r") as f:
142
  config_data = json.load(f)
143
 
144
+ config_data.pop("add_special_tokens", None) # Remove the conflicting key
 
145
 
146
  with open(config_path, "w") as f:
147
  json.dump(config_data, f, indent=2)
148
 
149
+ # 6. Load the model from the local, fixed path
150
+ st.success("Patch applied. Loading model from local cache...")
151
  return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set)
152
 
153
  else:
154
+ raise e # If it's a different error, we still want to see it
 
155
 
156
  except Exception as e:
157
+ st.error(f"Failed to load model. Error: {e}")
158
  return None
159
 
160
  # Use the cached model