Rom89823974978 committed on
Commit b0ca692 · 1 Parent(s): 418b952
backend/main.py CHANGED
@@ -1,55 +1,252 @@
1
  from fastapi import FastAPI, Request, HTTPException, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi import Request
4
- # Access the global FastAPI app state
5
- #from fastapi import current_app
6
  from pydantic import BaseModel
7
- #try:
8
- # from rag import get_rag_chain, RAGRequest, RAGResponse
9
- #except:
10
- # from .rag import get_rag_chain, RAGRequest, RAGResponse
11
  from contextlib import asynccontextmanager
 
 
12
  import os
 
13
  import polars as pl
14
  import gcsfs
15
 
16
- #os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = r"C:\Users\Romain\OneDrive - KU Leuven\focal-pager-460414-e9-45369b738be0.json"
18
  @asynccontextmanager
19
- async def lifespan(app: FastAPI):
20
- bucket = "mda_eu_project"
21
- path = "data/consolidated_clean_pred.parquet" #"data/consolidated_clean.parquet"
22
- uri = f"gs://{bucket}/{path}"
24
- fs = gcsfs.GCSFileSystem()
26
- with fs.open(uri, "rb") as f:
27
  df = pl.read_parquet(f)
28
 
 
29
  for col in ("title", "status", "legalBasis"):
30
  df = df.with_columns(pl.col(col).str.to_lowercase().alias(f"_{col}_lc"))
31
 
32
- statuses = df["_status_lc"].unique().to_list()
33
- legal_bases = df["_legalBasis_lc"].unique().to_list()
34
- organizations = df.explode("list_name")["list_name"].unique().to_list()
35
- countries = df.explode("list_country")["list_country"].unique().to_list()
36
-
37
- app.state.df = df
38
- app.state.statuses = statuses
39
- app.state.legal_bases = legal_bases
40
- app.state.orgs_list = organizations
41
- app.state.countries_list = countries
42
 
43
  yield
44
-
45
  app = FastAPI(lifespan=lifespan)
46
  app.add_middleware(
47
  CORSMiddleware,
48
- allow_origins=["*"],
49
  allow_methods=["*"],
50
  allow_headers=["*"],
51
  )
52
53
  @app.get("/api/projects")
54
  def get_projects(
55
  page: int = 0,
@@ -71,10 +268,10 @@ def get_projects(
71
  if search:
72
  sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
73
  if status:
74
- if status =="UNKNOWN":
75
- sel = sel.filter(pl.col("status").is_null())
76
- else:
77
- sel = sel.filter(pl.col("_status_lc") == status.lower())
78
  if legalBasis:
79
  sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
80
  if organization:
@@ -86,59 +283,57 @@ def get_projects(
86
  if proj_id:
87
  sel = sel.filter(pl.col("id") == proj_id)
88
 
89
- cols = [
90
- "id", "title", "status", "startDate", "endDate",
91
- "ecMaxContribution", "acronym", "legalBasis", "objective",
92
- "frameworkProgramme", "list_euroSciVocTitle", "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
93
  ]
94
- for i in range(1, 7):
95
- cols += [f"top{i}_feature", f"top{i}_shap"]
 
 
96
 
97
- cols += ["predicted_label", "predicted_prob"]
98
- sortOrder = True if sortOrder == "desc" else False
99
  sortField = sortField if sortField in df.columns else "startDate"
 
100
  rows = (
101
- sel.sort(sortField,descending=sortOrder).slice(start, limit)
102
- .select(cols)
 
103
  .to_dicts()
104
  )
105
 
106
  projects = []
107
  for row in rows:
108
  explanations = []
109
- for i in range(1, 7):
110
  feat = row.pop(f"top{i}_feature", None)
111
  shap = row.pop(f"top{i}_shap", None)
112
  if feat is not None and shap is not None:
113
  explanations.append({"feature": feat, "shap": shap})
114
  row["explanations"] = explanations
115
- # 2) transform list_publications into a { type: count } map
116
- raw_pubs = row.pop("list_publications", None) or []
117
- pub_counts: dict[str, int] = {}
118
- for entry in raw_pubs:
119
- # assuming entry is a string like "paper" or "peer reviewed paper"
120
- pub_counts[entry] = pub_counts.get(entry, 0) + 1
121
122
  row["publications"] = pub_counts
123
 
124
  projects.append(row)
125
 
126
  return projects
127
 
128
-
129
  @app.get("/api/filters")
130
  def get_filters(request: Request):
131
  df = app.state.df
132
  params = request.query_params
133
 
134
- # apply the same filters you use elsewhere
135
  if s := params.get("status"):
136
- if s == "UNKNOWN":
137
- df = df.filter(pl.col("status").is_null())
138
- else:
139
- df = df.filter(pl.col("_status_lc") == s.lower())
140
  if lb := params.get("legalBasis"):
141
- df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
142
  if org := params.get("organization"):
143
  df = df.filter(pl.col("list_name").list.contains(org))
144
  if c := params.get("country"):
@@ -146,40 +341,39 @@ def get_filters(request: Request):
146
  if search := params.get("search"):
147
  df = df.filter(pl.col("_title_lc").str.contains(search.lower()))
148
 
149
- def normalize(values):
150
- return sorted(set("UNKNOWN" if v is None else v for v in values))
151
 
152
  return {
153
- "statuses": normalize(df["status"].to_list()),
154
- "legalBases": normalize(df["legalBasis"].to_list()),
155
  "organizations": normalize(df["list_name"].explode().to_list()),
156
- "countries": normalize(df["list_country"].explode().to_list()),
157
  "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
158
- "ids": normalize(df["id"].to_list()),
159
  }
160
 
161
-
162
  @app.get("/api/stats")
163
  def get_stats(request: Request):
164
- params = request.query_params
165
  lf = app.state.df.lazy()
 
166
 
167
  if s := params.get("status"):
168
- lf = lf.filter(pl.col("_status_lc") == s.lower())
169
  if lb := params.get("legalBasis"):
170
- lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
171
  if org := params.get("organization"):
172
  lf = lf.filter(pl.col("list_name").list.contains(org))
173
  if c := params.get("country"):
174
  lf = lf.filter(pl.col("list_country").list.contains(c))
175
  if mn := params.get("minFunding"):
176
- lf = lf.filter(pl.col("ecMaxContribution") >= int(mn))
177
  if mx := params.get("maxFunding"):
178
- lf = lf.filter(pl.col("ecMaxContribution") <= int(mx))
179
  if y1 := params.get("minYear"):
180
- lf = lf.filter(pl.col("startDate").dt.year() >= int(y1))
181
  if y2 := params.get("maxYear"):
182
- lf = lf.filter(pl.col("startDate").dt.year() <= int(y2))
183
 
184
  grouped = (
185
  lf.select(pl.col("startDate").dt.year().alias("year"))
@@ -189,25 +383,25 @@ def get_stats(request: Request):
189
  .collect()
190
  )
191
  years, counts = grouped["year"].to_list(), grouped["count"].to_list()
192
- return {"Projects per Year": {"labels": years, "values": counts},
193
- "Projects per Year 2": {"labels": years, "values": counts},
194
- "Projects per Year 3": {"labels": years, "values": counts},
195
- "Projects per Year 4": {"labels": years, "values": counts},
196
- "Projects per Year 5": {"labels": years, "values": counts},
197
- "Projects per Year 6": {"labels": years, "values": counts}}
198
200
  @app.get("/api/project/{project_id}/organizations")
201
  def get_project_organizations(project_id: str):
202
  df = app.state.df
203
-
204
- sel = df.filter(pl.col("id") == project_id)
205
  if sel.is_empty():
206
  raise HTTPException(status_code=404, detail="Project not found")
207
 
208
  orgs_df = (
209
- sel
210
- .select([
211
  pl.col("list_name").explode().alias("name"),
212
  pl.col("list_city").explode().alias("city"),
213
  pl.col("list_SME").explode().alias("sme"),
@@ -219,7 +413,6 @@ def get_project_organizations(project_id: str):
219
  pl.col("list_geolocation").explode().alias("geoloc"),
220
  ])
221
  .with_columns([
222
- # now this is a List(Utf8)
223
  pl.col("geoloc").str.split(",").alias("latlon"),
224
  ])
225
  .with_columns([
@@ -227,29 +420,10 @@ def get_project_organizations(project_id: str):
227
  pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
228
  ])
229
  .filter(pl.col("name").is_not_null())
230
- .select(["name", "city", "sme","role","contribution","activityType","orgURL","country", "latitude", "longitude"])
231
  )
232
 
233
  return orgs_df.to_dicts()
234
-
235
- """def rag_chain_depender():
236
- """
237
- #Dependency injector for the RAG chain stored in app.state.
238
- #Raises HTTPException if not initialized.
239
- """
240
- chain = current_app.state.rag_chain
241
- if chain is None:
242
- raise HTTPException(status_code=500, detail="RAG chain not initialized")
243
- return chain
244
-
245
- @app.post("/rag", response_model=RAGResponse)
246
- async def ask_rag(
247
- req: RAGRequest,
248
- rag_chain = Depends(rag_chain_depender)
249
- ):
250
- """
251
- #Handle a RAG query. Uses session memory and the provided RAG chain.
252
- """
253
- # Invoke the chain with the named input
254
- result = await rag_chain.ainvoke({"question": req.query})
255
- return RAGResponse(answer=result["answer"])"""
 
1
  from fastapi import FastAPI, Request, HTTPException, Depends
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ from pydantic_settings import BaseSettings
5
  from contextlib import asynccontextmanager
6
+ from typing import Any, Dict, List, Optional, AsyncGenerator
7
+
8
  import os
9
+ import logging
10
  import polars as pl
11
  import gcsfs
12
 
13
+ from langchain.schema import Document, BaseRetriever
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain_community.vectorstores import FAISS
16
+ from langchain.retrievers.document_compressors import DocumentCompressorPipeline
17
+ from langchain_community.document_transformers import EmbeddingsRedundantFilter
18
+ from langchain.memory import ConversationBufferWindowMemory
19
+ from langchain.chains import ConversationalRetrievalChain
20
+ from langchain.prompts import PromptTemplate
21
+ from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
22
+
23
+ from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
24
+ from sentence_transformers import CrossEncoder
+ from whoosh import index  # needed: HybridRetriever below annotates a whoosh index.Index attribute
25
+
26
+ try:
27
+ # Preferred: direct import (works if rag.py is on sys.path)
28
+ from rag import build_indexes, bm25_search, load_documents
29
+ except ImportError:
30
+ try:
31
+ # Next: relative import (works if you're inside a package)
32
+ from .rag import build_indexes, bm25_search, load_documents
33
+ except ImportError:
34
+ # Last: explicit absolute package import
35
+ from app.rag import build_indexes, bm25_search, load_documents
36
+
37
+ from functools import lru_cache
38
+ # ---------------------------------------------------------------------------- #
39
+ # Settings #
40
+ # ---------------------------------------------------------------------------- #
41
+ class Settings(BaseSettings):
42
+ # Parquet + Whoosh/FAISS
43
+ parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
44
+ whoosh_dir: str = "gs://mda_eu_project/whoosh_index"
45
+ vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
46
+ # Models
47
+ embedding_model: str = "sentence-transformers/LaBSE"
48
+ llm_model: str = "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
49
+ cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
50
+ # RAG parameters
51
+ chunk_size: int = 750
52
+ chunk_overlap: int = 100
53
+ hybrid_k: int = 50
54
+ assistant_role: str = (
55
+ "You are a concise, factual assistant. Cite Document [ID] for each claim."
56
+ )
57
+ skip_warmup: bool = False
58
+ allowed_origins: List[str] = ["*"]
59
+
60
+ class Config:
61
+ env_file = ".env"
62
+
63
+ settings = Settings()
64
+
65
+ # Pre‐instantiate embedding model (used by filter/compressor)
66
+ EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)
67
+
68
+ @lru_cache(maxsize=256)
69
+ def embed_query_cached(query: str) -> List[float]:
70
+ """Cache embedding vectors for queries."""
71
+ return EMBEDDING.embed_query(query.strip().lower())
72
+
73
+ # === Hybrid Retriever ===
74
+ class HybridRetriever(BaseRetriever):
75
+ """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
76
+ # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
77
+ from pydantic import PrivateAttr
78
+ _vs: FAISS = PrivateAttr()
79
+ _ix: index.Index = PrivateAttr()
80
+ _compressor: DocumentCompressorPipeline = PrivateAttr()
81
+ _cross_encoder: CrossEncoder = PrivateAttr()
82
+
83
+ def __init__(
84
+ self,
85
+ vs: FAISS,
86
+ ix: index.Index,
87
+ compressor: DocumentCompressorPipeline,
88
+ cross_encoder: CrossEncoder
89
+ ) -> None:
90
+ super().__init__()
91
+ object.__setattr__(self, '_vs', vs)
92
+ object.__setattr__(self, '_ix', ix)
93
+ object.__setattr__(self, '_compressor', compressor)
94
+ object.__setattr__(self, '_cross_encoder', cross_encoder)
95
+
96
+ async def _aget_relevant_documents(self, query: str) -> List[Document]:
97
+ # BM25 retrieval using Whoosh index
98
+ bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
99
+ # Dense retrieval using FAISS
100
+ dense_docs = self._vs.similarity_search_by_vector(
101
+ embed_query_cached(query), k=settings.hybrid_k
102
+ )
103
+ # Cross-encoder re-ranking
104
+ candidates = bm_docs + dense_docs
105
+ scores = self._cross_encoder.predict([
106
+ (query, doc.page_content) for doc in candidates
107
+ ])
108
+ ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
109
+ top = [doc for _, doc in ranked[: settings.hybrid_k]]
110
+ # Compress and return
111
+ return self._compressor.compress_documents(top, query=query)
112
 
113
+ def _get_relevant_documents(self, query: str) -> List[Document]:
114
+ import asyncio
115
+ return asyncio.get_event_loop().run_until_complete(
116
+ self._aget_relevant_documents(query)
117
+ )
118
+
119
+ #os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = r"C:\Users\Romain\OneDrive - KU Leuven\focal-pager-460414-e9-45369b738be0.json"
120
+ # ---------------------------------------------------------------------------- #
121
+ # Single Lifespan #
122
+ # ---------------------------------------------------------------------------- #
123
  @asynccontextmanager
124
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
125
+ # --- 1) RAG Initialization --- #
126
+ logger = logging.getLogger("uvicorn")
127
+ logger.info("Initializing RAG components…")
128
+
129
+ # Compressor pipeline to de‐duplicate via embeddings
130
+ compressor = DocumentCompressorPipeline(
131
+ transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
132
+ )
133
 
134
+ # Cross‐encoder ranker
135
+ cross_encoder = CrossEncoder(settings.cross_encoder_model)
136
+
137
+ # Causal LLM pipeline
138
+ llm_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)
139
+ gen_pipe = pipeline(
140
+ "text-generation",
141
+ model=llm_model,
142
+ tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
143
+ max_new_tokens=256,
144
+ do_sample=True,
145
+ temperature=0.7,
146
+ )
147
+ llm = HuggingFacePipeline(pipeline=gen_pipe)
148
+
149
+ # Conversational memory
150
+ memory = ConversationBufferWindowMemory(
151
+ memory_key="chat_history",
152
+ k=5,
153
+ input_key="question",
154
+ output_key="answer",
155
+ return_messages=True,
156
+ )
157
 
158
+ # Build or load FAISS & Whoosh once
159
+ vs, ix = await build_indexes(
160
+ settings.parquet_path,
161
+ settings.vectorstore_path,
162
+ settings.whoosh_dir,
163
+ settings.chunk_size,
164
+ settings.chunk_overlap,
165
+ None,
166
+ )
167
+ retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
168
+
169
+ prompt = PromptTemplate.from_template(
170
+ f"{settings.assistant_role}\n\n"
171
+ "Context (up to 2,000 tokens, with document IDs):\n"
172
+ "{context}\n"
173
+ "Q: {question}\n"
174
+ "A: Provide your answer."
175
+ )
176
+
177
+ app.state.rag_chain = ConversationalRetrievalChain.from_llm(
178
+ llm=llm,
179
+ retriever=retriever,
180
+ memory=memory,
181
+ combine_docs_chain_kwargs={"prompt": prompt},
182
+ return_source_documents=True,
183
+ )
184
+
185
+ if not settings.skip_warmup:
186
+ logger.info("Warming up RAG chain…")
187
+ await app.state.rag_chain.ainvoke({"question": "warmup"})
188
+ logger.info("RAG ready.")
189
+
190
+ # --- 2) Dataframe Initialization --- #
191
+ logger.info("Loading Parquet data from GCS…")
192
+ fs = gcsfs.GCSFileSystem()
193
+ with fs.open(settings.parquet_path, "rb") as f:
194
  df = pl.read_parquet(f)
195
 
196
+ # lowercase for filtering
197
  for col in ("title", "status", "legalBasis"):
198
  df = df.with_columns(pl.col(col).str.to_lowercase().alias(f"_{col}_lc"))
199
 
200
+ # materialize unique filter values
201
+ app.state.df = df
202
+ app.state.statuses = df["_status_lc"].unique().to_list()
203
+ app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
204
+ app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
205
+ app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()
206
 
207
  yield
208
+ # teardown (if any) goes here
209
+
210
+ # ---------------------------------------------------------------------------- #
211
+ # App Setup #
212
+ # ---------------------------------------------------------------------------- #
213
  app = FastAPI(lifespan=lifespan)
214
  app.add_middleware(
215
  CORSMiddleware,
216
+ allow_origins=settings.allowed_origins,
217
  allow_methods=["*"],
218
  allow_headers=["*"],
219
  )
220
 
221
+ # ---------------------------------------------------------------------------- #
222
+ # RAG Endpoint #
223
+ # ---------------------------------------------------------------------------- #
224
+ class RAGRequest(BaseModel):
225
+ session_id: Optional[str]
226
+ query: str
227
+
228
+ class RAGResponse(BaseModel):
229
+ answer: str
230
+ source_ids: List[str]
231
+
232
+ def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
233
+ chain = app.state.rag_chain
234
+ if chain is None:
235
+ raise HTTPException(status_code=500, detail="RAG chain not initialized")
236
+ return chain
237
+
238
+ @app.post("/rag", response_model=RAGResponse)
239
+ async def ask_rag(
240
+ req: RAGRequest,
241
+ rag_chain = Depends(rag_chain_depender)
242
+ ):
243
+ result = await rag_chain.ainvoke({"question": req.query})
244
+ sources = [doc.metadata.get("id", "") for doc in result.get("source_documents", [])]
245
+ return RAGResponse(answer=result["answer"], source_ids=sources)
246
+
247
+ # ---------------------------------------------------------------------------- #
248
+ # Data Endpoints #
249
+ # ---------------------------------------------------------------------------- #
250
  @app.get("/api/projects")
251
  def get_projects(
252
  page: int = 0,
 
268
  if search:
269
  sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
270
  if status:
271
+ sel = sel.filter(
272
+ pl.col("status").is_null() if status == "UNKNOWN"
273
+ else pl.col("_status_lc") == status.lower()
274
+ )
275
  if legalBasis:
276
  sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
277
  if organization:
 
283
  if proj_id:
284
  sel = sel.filter(pl.col("id") == proj_id)
285
 
286
+ base_cols = [
287
+ "id","title","status","startDate","endDate","ecMaxContribution","acronym",
288
+ "legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
289
+ "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
290
  ]
291
+ # add shap/explanation columns
292
+ for i in range(1,7):
293
+ base_cols += [f"top{i}_feature", f"top{i}_shap"]
294
+ base_cols += ["predicted_label","predicted_prob"]
295
 
296
+ sort_desc = True if sortOrder=="desc" else False
 
297
  sortField = sortField if sortField in df.columns else "startDate"
298
+
299
  rows = (
300
+ sel.sort(sortField, descending=sort_desc)
301
+ .slice(start, limit)
302
+ .select(base_cols)
303
  .to_dicts()
304
  )
305
 
306
  projects = []
307
  for row in rows:
308
  explanations = []
309
+ for i in range(1,7):
310
  feat = row.pop(f"top{i}_feature", None)
311
  shap = row.pop(f"top{i}_shap", None)
312
  if feat is not None and shap is not None:
313
  explanations.append({"feature": feat, "shap": shap})
314
  row["explanations"] = explanations
315
 
316
+ # publications aggregation
317
+ raw_pubs = row.pop("list_publications", []) or []
318
+ pub_counts: Dict[str,int] = {}
319
+ for p in raw_pubs:
320
+ pub_counts[p] = pub_counts.get(p, 0) + 1
321
  row["publications"] = pub_counts
322
 
323
  projects.append(row)
324
 
325
  return projects
326
 
 
327
  @app.get("/api/filters")
328
  def get_filters(request: Request):
329
  df = app.state.df
330
  params = request.query_params
331
 
 
332
  if s := params.get("status"):
333
+ df = df.filter(pl.col("status").is_null() if s=="UNKNOWN"
334
+ else pl.col("_status_lc")==s.lower())
 
 
335
  if lb := params.get("legalBasis"):
336
+ df = df.filter(pl.col("_legalBasis_lc")==lb.lower())
337
  if org := params.get("organization"):
338
  df = df.filter(pl.col("list_name").list.contains(org))
339
  if c := params.get("country"):
 
341
  if search := params.get("search"):
342
  df = df.filter(pl.col("_title_lc").str.contains(search.lower()))
343
 
344
+ def normalize(vals):
345
+ return sorted({("UNKNOWN" if v is None else v) for v in vals})
346
 
347
  return {
348
+ "statuses": normalize(df["status"].to_list()),
349
+ "legalBases": normalize(df["legalBasis"].to_list()),
350
  "organizations": normalize(df["list_name"].explode().to_list()),
351
+ "countries": normalize(df["list_country"].explode().to_list()),
352
  "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
353
+ "ids": normalize(df["id"].to_list()),
354
  }
355
 
 
356
  @app.get("/api/stats")
357
  def get_stats(request: Request):
 
358
  lf = app.state.df.lazy()
359
+ params = request.query_params
360
 
361
  if s := params.get("status"):
362
+ lf = lf.filter(pl.col("_status_lc")==s.lower())
363
  if lb := params.get("legalBasis"):
364
+ lf = lf.filter(pl.col("_legalBasis_lc")==lb.lower())
365
  if org := params.get("organization"):
366
  lf = lf.filter(pl.col("list_name").list.contains(org))
367
  if c := params.get("country"):
368
  lf = lf.filter(pl.col("list_country").list.contains(c))
369
  if mn := params.get("minFunding"):
370
+ lf = lf.filter(pl.col("ecMaxContribution")>=int(mn))
371
  if mx := params.get("maxFunding"):
372
+ lf = lf.filter(pl.col("ecMaxContribution")<=int(mx))
373
  if y1 := params.get("minYear"):
374
+ lf = lf.filter(pl.col("startDate").dt.year()>=int(y1))
375
  if y2 := params.get("maxYear"):
376
+ lf = lf.filter(pl.col("startDate").dt.year()<=int(y2))
377
 
378
  grouped = (
379
  lf.select(pl.col("startDate").dt.year().alias("year"))
 
383
  .collect()
384
  )
385
  years, counts = grouped["year"].to_list(), grouped["count"].to_list()
386
 
387
+ return {
388
+ "Projects per Year": {"labels": years, "values": counts},
389
+ "Projects per Year 2": {"labels": years, "values": counts},
390
+ "Projects per Year 3": {"labels": years, "values": counts},
391
+ "Projects per Year 4": {"labels": years, "values": counts},
392
+ "Projects per Year 5": {"labels": years, "values": counts},
393
+ "Projects per Year 6": {"labels": years, "values": counts},
394
+ }
395
 
396
  @app.get("/api/project/{project_id}/organizations")
397
  def get_project_organizations(project_id: str):
398
  df = app.state.df
399
+ sel = df.filter(pl.col("id")==project_id)
 
400
  if sel.is_empty():
401
  raise HTTPException(status_code=404, detail="Project not found")
402
 
403
  orgs_df = (
404
+ sel.select([
 
405
  pl.col("list_name").explode().alias("name"),
406
  pl.col("list_city").explode().alias("city"),
407
  pl.col("list_SME").explode().alias("sme"),
 
413
  pl.col("list_geolocation").explode().alias("geoloc"),
414
  ])
415
  .with_columns([
 
416
  pl.col("geoloc").str.split(",").alias("latlon"),
417
  ])
418
  .with_columns([
 
420
  pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
421
  ])
422
  .filter(pl.col("name").is_not_null())
423
+ .select([
424
+ "name","city","sme","role","contribution",
425
+ "activityType","orgURL","country","latitude","longitude"
426
+ ])
427
  )
428
 
429
  return orgs_df.to_dicts()
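For quick manual testing of the new routes above, a minimal client sketch; the base URL, port, and example query are assumptions for illustration, not part of this commit:

# Hypothetical client for the new /rag and /api/projects routes; adjust BASE_URL to your deployment.
import requests

BASE_URL = "http://localhost:8000"  # assumed local dev server

# RAG question answering: request/response shapes follow RAGRequest / RAGResponse above.
rag_resp = requests.post(
    f"{BASE_URL}/rag",
    json={"session_id": "demo", "query": "Which signed projects look at risk of termination?"},
    timeout=300,
)
rag_resp.raise_for_status()
body = rag_resp.json()
print(body["answer"], body["source_ids"])

# Paginated project listing, exercising the UNKNOWN-status branch handled above.
projects = requests.get(
    f"{BASE_URL}/api/projects", params={"page": 0, "status": "UNKNOWN"}, timeout=60
).json()
print(len(projects))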
backend/rag.py CHANGED
@@ -1,15 +1,15 @@
1
  import os
2
- import io
3
- import asyncio
4
- from functools import lru_cache
5
- from typing import Any, Dict, List
6
 
7
  import gcsfs
 
8
  import polars as pl
9
  from pydantic_settings import BaseSettings
10
- from fastapi import FastAPI, HTTPException, Depends, Request
11
  from fastapi.middleware.cors import CORSMiddleware
12
- from contextlib import asynccontextmanager
13
  from pydantic import BaseModel
14
 
15
  from langchain.schema import BaseRetriever, Document
@@ -17,241 +17,286 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain_community.vectorstores import FAISS
18
  from langchain.retrievers.document_compressors import DocumentCompressorPipeline
19
  from langchain_community.document_transformers import EmbeddingsRedundantFilter
20
- from langchain.memory import ConversationBufferMemory
21
  from langchain.chains import ConversationalRetrievalChain
22
  from langchain.prompts import PromptTemplate
23
  from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
24
 
25
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
26
  from sentence_transformers import CrossEncoder
27
-
28
  from whoosh import index
29
  from whoosh.fields import Schema, TEXT, ID
30
  from whoosh.analysis import StemmingAnalyzer
31
  from whoosh.qparser import MultifieldParser
 
 
32
 
33
- # === Settings ===
34
- class Settings(BaseSettings):
35
- # allow GCS paths (e.g. "gs://bucket/...") via gcsfs
36
- parquet_path: str = "/content/drive/MyDrive/consolidated_clean.parquet"
37
- vectorstore_path: str = "/content/drive/MyDrive/vectorstore_index"
38
- whoosh_dir: str = "whoosh_index"
39
- embedding_model: str = "sentence-transformers/all-mpnet-base-v2"
40
- llm_model: str = "bigscience/bloomz-560m"
41
- cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
42
- chunk_size: int = 300
43
- chunk_overlap: int = 50
44
- hybrid_k: int = 50
45
- assistant_role: str = "You are a concise, factual assistant. Cite Document [ID] for each claim."
46
 
47
- class Config:
48
- env_file = ".env"
 
49
 
50
- settings = Settings()
 
51
 
52
- # === Globals ===
53
- embedding: HuggingFaceEmbeddings # will be set in lifespan
54
- cross_encoder: CrossEncoder
55
- compressor: DocumentCompressorPipeline
56
- session_meta: Dict[str, Dict] = {} # per‐session metadata
57
- # rag_chain will be stored on app.state
58
 
59
- # === Whoosh Index Builder ===
60
  _WHOOSH_CACHE: Dict[str, index.Index] = {}
61
- def build_whoosh_index(docs: List[Document]) -> index.Index:
62
- key = settings.whoosh_dir
63
- if key in _WHOOSH_CACHE:
64
- return _WHOOSH_CACHE[key]
65
- os.makedirs(key, exist_ok=True)
66
- schema = Schema(id=ID(stored=True, unique=True),
67
- content=TEXT(analyzer=StemmingAnalyzer()))
68
- ix = index.create_in(key, schema)
69
- writer = ix.writer()
70
- for doc in docs:
71
- writer.add_document(id=doc.metadata["id"], content=doc.page_content)
72
- writer.commit()
73
- _WHOOSH_CACHE[key] = ix
74
- return ix
75
 
76
- # === Load Documents ===
77
- def load_documents(path: str) -> List[Document]:
78
- # support GCS via gcsfs
79
- if path.startswith("gs://"):
80
  fs = gcsfs.GCSFileSystem()
81
- with fs.open(path, "rb") as f:
82
- df = pl.read_parquet(f).to_pandas()
83
- else:
84
- df = pl.read_parquet(path).to_pandas()
85
- df = df.head(20)
86
- docs = []
87
- for _, row in df.iterrows():
88
- text = f"{row.get('title','')} {row.get('objective','')} {row.get('list_description','')}"
89
- meta = {c: str(row[c]) for c in row.index
90
- if c in ("id","startDate","topics","list_euroSciVocTitle","list_euroSciVocPath")}
91
- docs.append(Document(page_content=text, metadata=meta))
92
  return docs
93
 
94
- # === Query Expansion & BM25 ===
95
- def expand_query_with_metadata(query: str, md: Dict[str,str]) -> str:
96
- parts = [query]
97
- for key,label in [
98
- ("startDate","from"), ("topics","about"),
99
- ("list_euroSciVocTitle","with topics"), ("list_euroSciVocPath","with topic path")
100
- ]:
101
- if md.get(key):
102
- parts.append(f"{label} {md[key]}")
103
- return "; ".join(parts)
104
-
105
- async def iterative_retrieval(
106
- query: str, vs: FAISS, ix: index.Index,
107
- rewriter: HuggingFacePipeline, md: Dict[str,str]
108
- ) -> List[Document]:
109
- first_q = await rewriter.ainvoke({"query": query})
110
  parser = MultifieldParser(["content"], schema=ix.schema)
111
- def bm25(q: str) -> List[Document]:
112
- with ix.searcher() as s:
113
- hits = s.search(parser.parse(q), limit=settings.hybrid_k)
114
- return [
115
- Document(page_content=h["content"], metadata={"id": h["id"]})
116
- for h in hits
117
- ]
118
- docs1 = bm25(first_q)
119
- docs2 = bm25(expand_query_with_metadata(first_q, md))
120
- uniq = {d.metadata["id"]: d for d in docs1 + docs2}
121
- return list(uniq.values())
122
-
123
- # === Cached Embeddings ===
124
- @lru_cache(maxsize=1024)
125
- def embed_query_cached(q: str):
126
- return embedding.embed_query(q)
127
 
128
- # === Retriever & Index Builder ===
129
- def build_retriever(docs: List[Document], embedder: HuggingFaceEmbeddings):
130
- ix = build_whoosh_index(docs)
131
  splitter = RecursiveCharacterTextSplitter(
132
- chunk_size=settings.chunk_size,
133
- chunk_overlap=settings.chunk_overlap
134
  )
135
  chunks = splitter.split_documents(docs)
136
- if os.path.exists(settings.vectorstore_path):
137
- vs = FAISS.load_local(
138
- settings.vectorstore_path,
139
- embeddings=embedder,
140
- allow_dangerous_deserialization=True
141
- )
142
- else:
143
- vs = FAISS.from_documents(chunks, embedder)
144
- vs.save_local(settings.vectorstore_path)
145
- return vs, ix
146
-
147
- # === Rewriter & LLM Pipelines ===
148
- rewriter_pipe = pipeline(
149
- "text-generation",
150
- model=AutoModelForCausalLM.from_pretrained(settings.llm_model),
151
- tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
152
- max_new_tokens=64, truncation=True, do_sample=False, return_full_text=False
153
- )
154
- rewriter = HuggingFacePipeline(pipeline=rewriter_pipe)
155
- query_rewriter = PromptTemplate.from_template("Rewrite question: {query}") | rewriter
156
-
157
- def get_llm_pipeline():
158
- gen_pipe = pipeline(
159
- "text-generation",
160
- model=AutoModelForCausalLM.from_pretrained(settings.llm_model),
161
- tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
162
- max_new_tokens=256, truncation=True,
163
- temperature=0.7, do_sample=True, return_full_text=False
164
- )
165
- return HuggingFacePipeline(pipeline=gen_pipe)
166
-
167
- combine_prompt = PromptTemplate.from_template(
168
- f"""{settings.assistant_role}
169
 
170
- Context (up to 2,000 tokens):
171
- {{context}}
172
 
173
- Q: {{question}}
174
- A:
175
- """
176
- )
177
-
178
- async def retrieve_and_rank(
179
- query: str, vs: FAISS, ix: index.Index, md: Dict[str,str]
180
- ) -> List[Document]:
181
- bm25_docs = await iterative_retrieval(query, vs, ix, query_rewriter, md)
182
- dense = vs.similarity_search_by_vector(embed_query_cached(query), k=settings.hybrid_k)
183
- candidates = bm25_docs + dense
184
- scores = cross_encoder.predict([(query, d.page_content) for d in candidates])
185
- top_docs = [doc for _, doc in sorted(zip(scores, candidates), reverse=True)[:settings.hybrid_k]]
186
- return compressor.compress_documents(top_docs, query=query)
187
-
188
- class HybridRetriever(BaseRetriever):
189
- vs: Any
190
- ix: Any
191
- meta_store: Dict[str,Dict]
192
-
193
- def __init__(self, vs, ix, meta_store):
194
- super().__init__(vs=vs, ix=ix, meta_store=meta_store)
195
- def _get_relevant_documents(self, query: str):
196
- return asyncio.get_event_loop().run_until_complete(
197
- retrieve_and_rank(query, self.vs, self.ix, self.meta_store.get(query, {}))
198
- )
199
- async def _aget_relevant_documents(self, query: str):
200
- return await retrieve_and_rank(query, self.vs, self.ix, self.meta_store.get(query, {}))
201
-
202
- # === FastAPI Lifespan & Initialization ===
203
- @asynccontextmanager
204
- async def lifespan(app: FastAPI):
205
- global embedding, cross_encoder, compressor
206
-
207
- # 1) Embeddings & docs
208
- embedding = HuggingFaceEmbeddings(model_name=settings.embedding_model)
209
- docs = load_documents(settings.parquet_path)
210
-
211
- # 2) Retriever + Whoosh index
212
- vs, ix = build_retriever(docs, embedding)
213
-
214
- # 3) Compressor and reranker
215
- compressor = DocumentCompressorPipeline(
216
- transformers=[EmbeddingsRedundantFilter(embeddings=embedding)]
217
- )
218
- cross_encoder = CrossEncoder(settings.cross_encoder_model)
219
-
220
- # 4) LLM & memory
221
- llm = get_llm_pipeline()
222
- memory = ConversationBufferMemory(
223
- memory_key="chat_history",
224
- input_key="question",
225
- output_key="answer",
226
- return_messages=True
227
- )
228
-
229
- # 5) Build RAG chain
230
- retriever = HybridRetriever(vs, ix, session_meta)
231
- rag_chain = ConversationalRetrievalChain.from_llm(
232
- llm=llm,
233
- retriever=retriever,
234
- memory=memory,
235
- combine_docs_chain_kwargs={"prompt": combine_prompt},
236
- return_source_documents=True
237
- )
238
-
239
- # 6) Warm-up
240
- vs.similarity_search("warmup", k=1)
241
- embed_query_cached("warmup")
242
- await rag_chain.ainvoke({"question": "warmup"})
243
-
244
- # 7) Store on app.state so get_rag_chain() can find it
245
- app.state.rag_chain = rag_chain
246
-
247
- yield
248
-
249
- # === Pydantic Models & FastAPI Setup ===
250
- class RAGRequest(BaseModel):
251
- session_id: str
252
- topic: str
253
- startDate: str
254
- query: str
255
-
256
- class RAGResponse(BaseModel):
257
- answer: str
 
1
  import os
2
+ import logging
+ import tempfile  # used by the GCS download path in load_documents
3
+ from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
4
+ from contextlib import asynccontextmanager
 
5
 
6
  import gcsfs
7
+ import aiofiles
8
  import polars as pl
9
  from pydantic_settings import BaseSettings
10
+ from fastapi import FastAPI, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
12
+ from starlette.concurrency import run_in_threadpool
13
  from pydantic import BaseModel
14
 
15
  from langchain.schema import BaseRetriever, Document
 
17
  from langchain_community.vectorstores import FAISS
18
  from langchain.retrievers.document_compressors import DocumentCompressorPipeline
19
  from langchain_community.document_transformers import EmbeddingsRedundantFilter
20
+ from langchain.memory import ConversationBufferWindowMemory
21
  from langchain.chains import ConversationalRetrievalChain
22
  from langchain.prompts import PromptTemplate
23
  from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
24
 
25
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
27
+ #from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
28
  from sentence_transformers import CrossEncoder
 
29
  from whoosh import index
30
  from whoosh.fields import Schema, TEXT, ID
31
  from whoosh.analysis import StemmingAnalyzer
32
  from whoosh.qparser import MultifieldParser
33
+ from tqdm import tqdm
34
+ import faiss
35
 
36
+ from functools import lru_cache
37
 
38
+ # === Logging ===
39
+ logging.basicConfig(level=logging.INFO)
40
+ logger = logging.getLogger(__name__)
41
 
42
+ # === Global Embeddings & Cache ===
43
+ # NOTE: Settings now lives in main.py; fall back to an env var (same default) so this module stays importable
+ EMBEDDING = HuggingFaceEmbeddings(
+     model_name=os.getenv("EMBEDDING_MODEL", "sentence-transformers/LaBSE")
+ )
44
 
45
+ @lru_cache(maxsize=256)
46
+ def embed_query_cached(query: str) -> List[float]:
47
+ """Cache embedding vectors for queries."""
48
+ return EMBEDDING.embed_query(query.strip().lower())
 
 
49
 
50
+ # === Whoosh Cache & Builder ===
51
  _WHOOSH_CACHE: Dict[str, index.Index] = {}
52
 
53
+ async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
54
+ key = whoosh_dir
55
+ fs = gcsfs.GCSFileSystem()
56
+ local_dir = key
57
+ is_gcs = key.startswith("gs://")
58
+ try:
59
+ # stage local copy for GCS
60
+ if is_gcs:
61
+ local_dir = "/tmp/whoosh_index"
62
+ if not os.path.exists(local_dir):
63
+ if await run_in_threadpool(fs.exists, key):
64
+ await run_in_threadpool(fs.get, key, local_dir, recursive=True)
65
+ else:
66
+ os.makedirs(local_dir, exist_ok=True)
67
+ # build once
68
+ if key not in _WHOOSH_CACHE:
69
+ os.makedirs(local_dir, exist_ok=True)
70
+ schema = Schema(
71
+ id=ID(stored=True, unique=True),
72
+ content=TEXT(analyzer=StemmingAnalyzer()),
73
+ )
74
+ ix = index.create_in(local_dir, schema)
75
+ with ix.writer() as writer:
76
+ for doc in docs:
77
+ writer.add_document(
78
+ id=doc.metadata.get("id", ""),
79
+ content=doc.page_content,
80
+ )
81
+ # push back to GCS atomically
82
+ if is_gcs:
83
+ await run_in_threadpool(fs.put, local_dir, key, recursive=True)
84
+ _WHOOSH_CACHE[key] = ix
85
+ return _WHOOSH_CACHE[key]
86
+ except Exception as e:
87
+ logger.error(f"Failed to build Whoosh index: {e}")
88
+ raise
89
+
90
+ # === Document Loader ===
91
+ async def load_documents(
92
+ path: str,
93
+ sample_size: Optional[int] = None
94
+ ) -> List[Document]:
95
+ """
96
+ Load a Parquet file from local or GCS, convert to a list of Documents.
97
+ """
98
+ def _read_local(p: str, n: Optional[int]):
99
+ # streaming scan keeps memory low
100
+ lf = pl.scan_parquet(p)
101
+ if n:
102
+ lf = lf.limit(n)
103
+ return lf.collect(streaming=True)
104
+
105
+ def _read_gcs(p: str, n: Optional[int]):
106
+ # download to a temp file synchronously, then read with Polars
107
  fs = gcsfs.GCSFileSystem()
108
+ with tempfile.TemporaryDirectory() as td:
109
+ local_path = os.path.join(td, "data.parquet")
110
+ fs.get(p, local_path, recursive=False)
111
+ df = pl.read_parquet(local_path)
112
+ if n:
113
+ df = df.head(n)
114
+ return df
115
+
116
+ try:
117
+ if path.startswith("gs://"):
118
+ df = await run_in_threadpool(_read_gcs, path, sample_size)
119
+ else:
120
+ df = await run_in_threadpool(_read_local, path, sample_size)
121
+ except Exception as e:
122
+ logger.error(f"Error loading documents: {e}")
123
+ raise HTTPException(status_code=500, detail="Document loading failed.")
124
+
125
+ docs: List[Document] = []
126
+ for row in df.rows(named=True):
127
+ context_parts: List[str] = []
128
+ # build metadata context
129
+ max_contrib = row.get("ecMaxContribution", "")
130
+ end_date = row.get("endDate", "")
131
+ duration = row.get("durationDays", "")
132
+ status = row.get("status", "")
133
+ legal = row.get("legalBasis", "")
134
+ framework = row.get("frameworkProgramme", "")
135
+ scheme = row.get("fundingScheme", "")
136
+ names = row.get("list_name", []) or []
137
+ cities = row.get("list_city", []) or []
138
+ countries = row.get("list_country", []) or []
139
+ activity = row.get("list_activityType", []) or []
140
+ contributions = row.get("list_ecContribution", []) or []
141
+ smes = row.get("list_sme", []) or []
142
+ project_id =row.get("id", "")
143
+ pred=row.get("predicted_label", "")
144
+ proba=row.get("predicted_prob", "")
145
+ top1_feats = row.get("top1_feature", "")  # score() writes top{i}_feature (singular)
146
+ top2_feats = row.get("top2_feature", "")
147
+ top3_feats = row.get("top3_feature", "")
148
+ top1_shap=row.get("top1_shap", "")
149
+ top2_shap=row.get("top2_shap", "")
150
+ top3_shap=row.get("top3_shap", "")
151
+
152
+
153
+ context_parts.append(
154
+ f"This project under framework {framework} with funding scheme {scheme}, status {status}, legal basis {legal}."
155
+ )
156
+ context_parts.append(
157
+ f"It ends on {end_date} after {duration} days and has a max EC contribution of {max_contrib}."
158
+ )
159
+ context_parts.append("Participating organizations:")
160
+ for i, name in enumerate(names):
161
+ city = cities[i] if i < len(cities) else ""
162
+ country = countries[i] if i < len(countries) else ""
163
+ act = activity[i] if i < len(activity) else ""
164
+ contrib = contributions[i] if i < len(contributions) else ""
165
+ sme_flag = "SME" if (smes and i < len(smes) and smes[i]) else "non-SME"
166
+ context_parts.append(
167
+ f"- {name} in {city}, {country}, activity: {act}, contributed: {contrib}, {sme_flag}."
168
+ )
169
+ # only rows that actually carry a prediction can be scored here
+ if status in (None, "signed", "SIGNED", "Signed") and pred not in ("", None):
170
+ if int(pred) == 1:
171
+ label = "TERMINATED"
172
+ score = float(proba)
173
+ else:
174
+ label = "CLOSED"
175
+ score = 1 - float(proba)
176
+
177
+ score_str = f"{score:.2f}"
178
+
179
+ context_parts.append(
180
+ f"- Project {project_id} is predicted to be {label} (score={score_str}). "
181
+ f"The 3 most predictive features were: "
182
+ f"{top1_feats} ({top1_shap:.3f}), "
183
+ f"{top2_feats} ({top2_shap:.3f}), "
184
+ f"{top3_feats} ({top3_shap:.3f})."
185
+ )
186
+
187
+ title_report = row.get("list_title_report", "")
188
+ objective = row.get("objective", "")
189
+ full_body = f"{title_report} {objective}"
190
+ full_text = " ".join(context_parts + [full_body])
191
+ meta: Dict[str, Any] = {
+     "id": str(row.get("id", "")),
+     "startDate": str(row.get("startDate", "")),
+     "endDate": str(row.get("endDate", "")),
+     "status": str(row.get("status", "")),
+     "legalBasis": str(row.get("legalBasis", "")),
+ }
193
+ docs.append(Document(page_content=full_text, metadata=meta))
194
  return docs
195
 
196
+ # === BM25 Search ===
197
+ async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
198
  parser = MultifieldParser(["content"], schema=ix.schema)
199
+ def _search() -> List[Document]:
200
+ with ix.searcher() as searcher:
201
+ hits = searcher.search(parser.parse(query), limit=k)
202
+ return [Document(page_content=h["content"], metadata={"id": h["id"]}) for h in hits]
203
+ return await run_in_threadpool(_search)
204
+
205
+ # === Helper: build or load FAISS with mmap ===
206
+ async def build_or_load_faiss(
207
+ chunks: List[Document],
208
+ vectorstore_path: str,
209
+ batch_size: int = 15000
210
+ ) -> FAISS:
211
+ faiss_index_file = os.path.join(vectorstore_path, "index.faiss")
212
+ # If on-disk exists: memory-map the FAISS index and load metadata separately
213
+ if os.path.exists(faiss_index_file):
214
+ logger.info("Memory-mapping existing FAISS index...")
215
+ mmap_idx = faiss.read_index(faiss_index_file, faiss.IO_FLAG_MMAP)
216
+ # Manually load metadata (docstore and index_to_docstore) without loading the index
217
+ import pickle
218
+ for meta_file in ["faiss.pkl", "index.pkl"]:
219
+ meta_path = os.path.join(vectorstore_path, meta_file)
220
+ if os.path.exists(meta_path):
221
+ with open(meta_path, "rb") as f:
222
+ saved = pickle.load(f)
223
+ break
224
+ else:
225
+ raise FileNotFoundError(
226
+ f"Could not find FAISS metadata pickle in {vectorstore_path}"
227
+ )
228
+ # extract metadata
229
+ if isinstance(saved, tuple):
230
+ # Handle metadata tuple of length 2 or 3
231
+ if len(saved) == 3:
232
+ _, docstore, index_to_docstore = saved
233
+ elif len(saved) == 2:
234
+ docstore, index_to_docstore = saved
235
+ else:
236
+ raise ValueError(f"Unexpected metadata tuple length: {len(saved)}")
237
+ else:
238
+ if hasattr(saved, 'docstore'):
239
+ docstore = saved.docstore
240
+ elif hasattr(saved, '_docstore'):
241
+ docstore = saved._docstore
242
+ else:
243
+ raise AttributeError("Could not find docstore in FAISS metadata")
244
+ if hasattr(saved, 'index_to_docstore'):
245
+ index_to_docstore = saved.index_to_docstore
246
+ elif hasattr(saved, '_index_to_docstore'):
247
+ index_to_docstore = saved._index_to_docstore
248
+ elif hasattr(saved, '_faiss_index_to_docstore'):
249
+ index_to_docstore = saved._faiss_index_to_docstore
250
+ else:
251
+ raise AttributeError("Could not find index_to_docstore in FAISS metadata")
252
+ # reconstruct FAISS wrapper
253
+ vs = FAISS(
254
+ embedding_function=EMBEDDING,
255
+ index=mmap_idx,
256
+ docstore=docstore,
257
+ index_to_docstore_id=index_to_docstore,
258
+ )
259
+ return vs
260
+
261
+ # 2) Else: build from scratch in batches
262
+ logger.info(f"Building FAISS index in batches of {batch_size}…")
263
+ vs: Optional[FAISS] = None
264
+ for i in tqdm(range(0, len(chunks), batch_size),
265
+ desc="Building FAISS index",
266
+ unit="batch"):
267
+ batch = chunks[i : i + batch_size]
268
+
269
+ if vs is None:
270
+ vs = FAISS.from_documents(batch, EMBEDDING)
271
+ else:
272
+ vs.add_documents(batch)
273
+
274
+ # periodic save every 5 batches
275
+ if (i // batch_size) % 5 == 0:
276
+ vs.save_local(vectorstore_path)
277
+
278
+ logger.info(f" • Saved batch up to document {i + len(batch)} / {len(chunks)}")
279
+ assert vs is not None, "No documents to index!"
280
+ return vs
281
+
282
+ # === Index Builder ===
283
+ async def build_indexes(
284
+ parquet_path: str,
285
+ vectorstore_path: str,
286
+ whoosh_dir: str,
287
+ chunk_size: int,
288
+ chunk_overlap: int,
289
+ debug_size: Optional[int]
290
+ ) -> Tuple[FAISS, index.Index]:
291
+ docs = await load_documents(parquet_path, debug_size)
292
+ ix = await build_whoosh_index(docs, whoosh_dir)
293
294
  splitter = RecursiveCharacterTextSplitter(
295
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap
 
296
  )
297
  chunks = splitter.split_documents(docs)
298
 
299
+ # build or load (with mmap) FAISS
300
+ vs = await build_or_load_faiss(chunks, vectorstore_path)
301
 
302
+ return vs, ix
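For reference, a small standalone sketch of how build_indexes and bm25_search could be exercised outside the FastAPI lifespan; the paths and query below are placeholders, and a local copy of the parquet file is assumed:

# Hypothetical smoke test for the index builders above (paths and query are placeholders).
import asyncio

async def _smoke_test() -> None:
    vs, ix = await build_indexes(
        parquet_path="consolidated_clean_pred.parquet",  # assumed local copy of the GCS file
        vectorstore_path="vectorstore_index",
        whoosh_dir="whoosh_index",
        chunk_size=750,
        chunk_overlap=100,
        debug_size=50,  # small sample keeps the test fast
    )
    hits = await bm25_search(ix, "renewable energy storage", k=5)
    print([doc.metadata["id"] for doc in hits])          # BM25 side
    print(vs.similarity_search("renewable energy storage", k=3))  # dense side

if __name__ == "__main__":
    asyncio.run(_smoke_test())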
frontend/src/components/ProjectDetails.tsx CHANGED
@@ -129,6 +129,7 @@ export default function ProjectDetails({
129
  <Box><Text fontWeight="bold">End Date</Text><Text>{fmtDate(project.endDate)}</Text></Box>
130
  <Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
131
  <Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
 
132
  <Box>
133
  <Text fontWeight="bold">Legal Basis</Text>
134
  <Text>{project.legalBasis}</Text>
@@ -272,7 +273,7 @@ export default function ProjectDetails({
272
  <ResponsiveContainer width="100%" height={300}>
273
  <BarChart data={shapData} margin={{ top: 10, right: 30, left: 0, bottom: 5 }}>
274
  <CartesianGrid strokeDasharray="3 3" />
275
- <XAxis dataKey="feature" />
276
  <YAxis />
277
  <Tooltip />
278
  <Bar dataKey="shap" name="SHAP Value">
@@ -285,6 +286,10 @@ export default function ProjectDetails({
285
  </Bar>
286
  </BarChart>
287
  </ResponsiveContainer>
288
  </>
289
  ) : (
290
  <Spinner />
 
129
  <Box><Text fontWeight="bold">End Date</Text><Text>{fmtDate(project.endDate)}</Text></Box>
130
  <Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
131
  <Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
132
+ <Box><Text fontWeight="bold">Funding Scheme</Text><Text>project.fundingScheme</Text></Box>
133
  <Box>
134
  <Text fontWeight="bold">Legal Basis</Text>
135
  <Text>{project.legalBasis}</Text>
 
273
  <ResponsiveContainer width="100%" height={300}>
274
  <BarChart data={shapData} margin={{ top: 10, right: 30, left: 0, bottom: 5 }}>
275
  <CartesianGrid strokeDasharray="3 3" />
276
+ <XAxis dataKey="feature" axisLine={false} tick={false} />
277
  <YAxis />
278
  <Tooltip />
279
  <Bar dataKey="shap" name="SHAP Value">
 
286
  </Bar>
287
  </BarChart>
288
  </ResponsiveContainer>
289
+ <Text fontSize="xs" color="gray.500" mt={2}>
290
+ Each bar shows how much that feature pushed the model's prediction.
291
+ Positive bars increase the chance of termination; Negative bars decrease it.
292
+ </Text>
293
  </>
294
  ) : (
295
  <Spinner />
predictive_modelling.py CHANGED
@@ -7,6 +7,7 @@ import shap
7
  import matplotlib.pyplot as plt
8
  import scipy.sparse
9
  import polars as pl
 
10
  import gcsfs
11
 
12
  from sklearn.base import BaseEstimator, TransformerMixin
@@ -510,22 +511,126 @@ def score(new_df, model_dir="model_artifacts"):
510
  explainer = shap.Explainer(clf, X_sel, feature_names=feature_names)
511
  shap_vals = explainer(X_sel) # returns a ShapleyValues object
512
 
513
- # 6) For each row, pick top-3 absolute contributors
514
  shap_df = pd.DataFrame(shap_vals.values, columns=feature_names, index=df.index)
 
 
515
  abs_shap = shap_df.abs()
516
 
 
517
  top_feats = abs_shap.apply(lambda row: row.nlargest(6).index.tolist(), axis=1)
518
- top_vals = abs_shap.apply(lambda row: row.nlargest(6).values.tolist(), axis=1)
519
 
520
- df[["top1_feature","top2_feature","top3_feature","top4_feature","top5_feature","top6_feature"]] = pd.DataFrame(
521
- top_feats.tolist(), index=df.index
522
- )
523
- df[["top1_shap","top2_shap","top3_shap","top4_shap","top5_shap","top6_shap"]] = pd.DataFrame(
524
- top_vals.tolist(), index=df.index
525
- )
526
 
527
  return df
528
 
529
  if __name__ == "__main__":
530
  # Entry point for training and scoring: loads data from Google Cloud Storage,
531
  # builds model and artifacts, then scores the same data as a test.
@@ -539,5 +644,5 @@ if __name__ == "__main__":
539
  df = pl.read_parquet(f).to_pandas()
540
 
541
  status_prediction_model(df)
542
- scored_df = score(df)
543
- print(scored_df.head())
 
7
  import matplotlib.pyplot as plt
8
  import scipy.sparse
9
  import polars as pl
10
+ import re
11
  import gcsfs
12
 
13
  from sklearn.base import BaseEstimator, TransformerMixin
 
511
  explainer = shap.Explainer(clf, X_sel, feature_names=feature_names)
512
  shap_vals = explainer(X_sel) # returns a ShapleyValues object
513
 
514
+ # 6) For each row, pick top-6 absolute contributors
515
  shap_df = pd.DataFrame(shap_vals.values, columns=feature_names, index=df.index)
516
+
517
+ # 7) get absolute values
518
  abs_shap = shap_df.abs()
519
 
520
+ # 8) for each row, record the top‐6 feature names by absolute magnitude
521
  top_feats = abs_shap.apply(lambda row: row.nlargest(6).index.tolist(), axis=1)
 
522
 
523
+ # 9) convert that to six separate columns
524
+ feat_cols = [f"top{i}_feature" for i in range(1,7)]
525
+ df[feat_cols] = pd.DataFrame(top_feats.tolist(), index=df.index)
526
+
527
+ # 10) now build the *true* SHAP values by looking up each name in shap_df
528
+ # for each row, shap_df.loc[idx, feat] is the signed value
529
+ top_vals = [
530
+ [ shap_df.loc[idx, feat] for feat in feats ]
531
+ for idx, feats in top_feats.items()
532
+ ]
533
+
534
+ # 11) store them in your six shap‐value columns
535
+ val_cols = [f"top{i}_shap" for i in range(1,7)]
536
+ df[val_cols] = pd.DataFrame(top_vals, index=df.index)
537
 
538
  return df
539
 
540
+ def clean_feature_name(raw: str) -> str:
541
+ """
542
+ - cat__: "cat__feature_value" → "Feature: Value"
543
+ - num__: "num__some_count" → "Some Count"
544
+ - mlb_: "mlb_list_activityType__list_activityType_Research"
545
+ → "List Activity Type: Research"
546
+ """
547
+ if not raw:
548
+ return ""
549
+
550
+ # 1) cat__
551
+ if raw.startswith("cat__"):
552
+ s = raw[len("cat__"):]
553
+ col, val = (s.split("__", 1) + [None])[:2]
554
+ col_c = col.replace("_", " ").title()
555
+ if val:
556
+ val_c = val.replace("_", " ").title()
557
+ return f"{col_c}: {val_c}"
558
+ return col_c
559
+
560
+ # 2) num__
561
+ if raw.startswith("num__"):
562
+ s = raw[len("num__"):]
563
+ return s.replace("_", " ").replace('n ','Number of ')
564
+
565
+ # 3) mlb_
566
+ if raw.startswith("mlb_"):
567
+ s = raw[len("mlb_"):]
568
+ col_part, val_part = (s.split("__", 1) + [None])[:2]
569
+ # drop leading "list_" on the column
570
+ if col_part.startswith("list_"):
571
+ col_inner = col_part[len("list_"):]
572
+ else:
573
+ col_inner = col_part
574
+ col_c = col_inner.replace("_", " ").title()
575
+ col_c = "List " + col_c
576
+
577
+ if val_part:
578
+ # drop "list_{col_inner}_" or leading "list_"
579
+ prefix = f"list_{col_inner}_"
580
+ if val_part.startswith(prefix):
581
+ val_inner = val_part[len(prefix):]
582
+ elif val_part.startswith("list_"):
583
+ val_inner = val_part[len("list_"):]
584
+ else:
585
+ val_inner = val_part
586
+ val_c = val_inner.replace("_", " ").title()
587
+ return f"{col_c}: {val_c}"
588
+ return col_c
589
+
590
+ # fallback: replace __ → ": ", _ → " "
591
+ return raw.replace("__", ": ").replace("_", " ").title()
592
+
593
+
594
+ def preprocess_feature_names(df: pl.DataFrame) -> pl.DataFrame:
595
+ transforms = []
596
+
597
+ # clean and round top-6 features & shap values
598
+ for i in range(1, 7):
599
+ fcol = f"top{i}_feature"
600
+ scol = f"top{i}_shap"
601
+
602
+ if fcol in df.columns:
603
+ transforms.append(
604
+ pl.col(fcol)
605
+ .map_elements(clean_feature_name, return_dtype=pl.Utf8)
606
+ .alias(fcol)
607
+ )
608
+ if scol in df.columns:
609
+ transforms.append(
610
+ pl.col(scol)
611
+ .round(4)
612
+ .alias(scol)
613
+ )
614
+
615
+ # round overall predicted probability
616
+ if "predicted_prob" in df.columns:
617
+ transforms.append(
618
+ pl.col("predicted_prob")
619
+ .round(4)
620
+ .alias("predicted_prob")
621
+ )
622
+
623
+ # 1) build the full list of embed-columns
624
+ embed_cols = [f"title_embed_{i}" for i in range(50)] + \
625
+ [f"objective_embed_{i}" for i in range(50)]
626
+
627
+ # 2) keep only the ones that actually exist in df.columns
628
+ to_drop = [c for c in embed_cols if c in df.columns]
629
+
630
+ # 3) drop them
631
+ df = df.drop(to_drop)
632
+ return df.with_columns(transforms)
633
+
634
  if __name__ == "__main__":
635
  # Entry point for training and scoring: loads data from Google Cloud Storage,
636
  # builds model and artifacts, then scores the same data as a test.
 
644
  df = pl.read_parquet(f).to_pandas()
645
 
646
  status_prediction_model(df)
647
+ df_clean = preprocess_feature_names(pl.from_pandas(score(df)))
648
+ print(df_clean.head(10))
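A short, hypothetical sanity check for clean_feature_name; the raw names below are made-up examples following the cat__/num__/mlb_ prefixes handled above, not values taken from the trained pipeline:

# Illustrative only: exercise each prefix branch of clean_feature_name with made-up names.
for raw in (
    "cat__status__SIGNED",                                 # cat__<col>__<value>
    "num__durationDays",                                   # num__<col>
    "mlb_list_activityType__list_activityType_Research",   # mlb_<list col>__<value>
):
    print(raw, "->", clean_feature_name(raw))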