cyberandy committed on
Commit
014cdc2
·
verified ·
1 Parent(s): 49fe6db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -59
app.py CHANGED
@@ -74,87 +74,53 @@ else:
74
 
75
  @st.cache_resource # 👈 Add the caching decorator
76
  def load_model(selected_language, model_name=None, entity_set=None):
77
- # This dictionary maps the easy names to their full Hugging Face Hub IDs
78
- model_mapping = {
79
- "aida_model": "amazon-science/ReFinED-aida-model",
80
- "wikipedia_model_with_numbers": "amazon-science/ReFinED-wikipedia-model"
81
- }
82
-
83
  with warnings.catch_warnings():
84
  warnings.simplefilter("ignore")
85
  try:
86
- # This block handles the spaCy models for German and English
87
  if selected_language == "German":
88
  try:
89
  nlp_model_de = spacy.load("de_core_news_lg")
90
  except OSError:
91
- st.info("Downloading German language model... This may take a moment.")
92
  spacy.cli.download("de_core_news_lg")
93
  nlp_model_de = spacy.load("de_core_news_lg")
 
94
  if "entityfishing" not in nlp_model_de.pipe_names:
95
- try: nlp_model_de.add_pipe("entityfishing")
96
- except Exception as e: st.warning(f"Entity-fishing not available: {e}")
 
 
97
  return nlp_model_de
98
-
99
  elif selected_language == "English - spaCy":
100
  try:
101
  nlp_model_en = spacy.load("en_core_web_sm")
102
  except OSError:
103
- st.info("Downloading English language model... This may take a moment.")
104
  spacy.cli.download("en_core_web_sm")
105
  nlp_model_en = spacy.load("en_core_web_sm")
 
106
  if "entityfishing" not in nlp_model_en.pipe_names:
107
- try: nlp_model_en.add_pipe("entityfishing")
108
- except Exception as e: st.warning(f"Entity-fishing not available: {e}")
109
- return nlp_model_en
 
 
110
 
111
- # This block handles the ReFinED model and the "add_special_tokens" error
112
  else:
113
- try:
114
- return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
115
-
116
- except Exception as e:
117
- if "add_special_tokens" in str(e):
118
- st.warning("Conflict detected. Applying fix by downloading and patching model...")
119
-
120
- # 1. Get the REAL model name from our mapping
121
- real_model_name = model_mapping.get(model_name)
122
- if not real_model_name:
123
- st.error(f"Unknown model alias: {model_name}")
124
- return None
125
-
126
- # 2. Define a local path to save the fixed model
127
- local_model_path = f"./{model_name}-{entity_set}-fixed"
128
-
129
- # 3. Download the tokenizer and the model using the REAL name
130
- st.info(f"Downloading model files for {real_model_name}...")
131
- tokenizer = AutoTokenizer.from_pretrained(real_model_name)
132
- model_files = AutoModelForSeq2SeqLM.from_pretrained(real_model_name)
133
-
134
- # 4. Save them to the local directory
135
- tokenizer.save_pretrained(local_model_path)
136
- model_files.save_pretrained(local_model_path)
137
- st.info("Model files downloaded.")
138
-
139
- # 5. Patch the tokenizer config file
140
- config_path = os.path.join(local_model_path, "tokenizer_config.json")
141
- with open(config_path, "r") as f:
142
- config_data = json.load(f)
143
-
144
- config_data.pop("add_special_tokens", None) # Remove the conflicting key
145
-
146
- with open(config_path, "w") as f:
147
- json.dump(config_data, f, indent=2)
148
-
149
- # 6. Load the model from the local, fixed path
150
- st.success("Patch applied. Loading model from local cache...")
151
- return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set)
152
-
153
- else:
154
- raise e # If it's a different error, we still want to see it
155
 
156
  except Exception as e:
157
- st.error(f"Failed to load model. Error: {e}")
158
  return None
159
 
160
  # Use the cached model
 
74
 
75
  @st.cache_resource # 👈 Add the caching decorator
76
  def load_model(selected_language, model_name=None, entity_set=None):
77
+ """
78
+ Loads the appropriate model based on user selection.
79
+ This simplified version works by using older, compatible library versions
80
+ specified in requirements.txt, avoiding the 'add_special_tokens' conflict.
81
+ """
 
82
  with warnings.catch_warnings():
83
  warnings.simplefilter("ignore")
84
  try:
 
85
  if selected_language == "German":
86
  try:
87
  nlp_model_de = spacy.load("de_core_news_lg")
88
  except OSError:
89
+ st.info("Downloading German language model for the first time...")
90
  spacy.cli.download("de_core_news_lg")
91
  nlp_model_de = spacy.load("de_core_news_lg")
92
+
93
  if "entityfishing" not in nlp_model_de.pipe_names:
94
+ try:
95
+ nlp_model_de.add_pipe("entityfishing")
96
+ except Exception as e:
97
+ st.warning(f"Could not add entity-fishing pipe: {e}")
98
  return nlp_model_de
99
+
100
  elif selected_language == "English - spaCy":
101
  try:
102
  nlp_model_en = spacy.load("en_core_web_sm")
103
  except OSError:
104
+ st.info("Downloading English language model for the first time...")
105
  spacy.cli.download("en_core_web_sm")
106
  nlp_model_en = spacy.load("en_core_web_sm")
107
+
108
  if "entityfishing" not in nlp_model_en.pipe_names:
109
+ try:
110
+ nlp_model_en.add_pipe("entityfishing")
111
+ except Exception as e:
112
+ st.warning(f"Could not add entity-fishing pipe: {e}")
113
+ return nlp_model_en
114
 
 
115
  else:
116
+ # With the correct libraries, this will now work directly.
117
+ st.info(f"Loading ReFinED model: {model_name}...")
118
+ refined_model = Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
119
+ st.success("ReFinED model loaded successfully!")
120
+ return refined_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  except Exception as e:
123
+ st.error(f"Error loading model: {e}")
124
  return None
125
 
126
  # Use the cached model