Commit b0ca692 (1 parent: 418b952)

Files changed:
- backend/main.py  +276 -102
- backend/rag.py  +263 -218
- frontend/src/components/ProjectDetails.tsx  +6 -1
- predictive_modelling.py  +115 -10
backend/main.py  CHANGED

@@ -1,55 +1,252 @@
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi import Request
-# Access the global FastAPI app state
-#from fastapi import current_app
 from pydantic import BaseModel
+from pydantic_settings import BaseSettings
-
-# from rag import get_rag_chain, RAGRequest, RAGResponse
-#except:
-#    from .rag import get_rag_chain, RAGRequest, RAGResponse
 from contextlib import asynccontextmanager
+from typing import Any, Dict, List, Optional, AsyncGenerator
+
 import os
+import logging
 import polars as pl
 import gcsfs
-
+from whoosh import index  # the index.Index annotations below rely on this import
+
+from langchain.schema import Document, BaseRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain.retrievers.document_compressors import DocumentCompressorPipeline
+from langchain_community.document_transformers import EmbeddingsRedundantFilter
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain.prompts import PromptTemplate
+from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
+
+from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
+from sentence_transformers import CrossEncoder
+
+try:
+    # Preferred: direct import (works if rag.py is on sys.path)
+    from rag import build_indexes, bm25_search, load_documents
+except ImportError:
+    try:
+        # Next: relative import (works if you're inside a package)
+        from .rag import build_indexes, bm25_search, load_documents
+    except ImportError:
+        # Last: explicit absolute package import
+        from app.rag import build_indexes, bm25_search, load_documents
+
+from functools import lru_cache
+
+# ---------------------------------------------------------------------------- #
+#                                   Settings                                   #
+# ---------------------------------------------------------------------------- #
+class Settings(BaseSettings):
+    # Parquet + Whoosh/FAISS
+    parquet_path: str = "gs://mda_eu_project/data/consolidated_clean_pred.parquet"
+    whoosh_dir: str = "gs://mda_eu_project/whoosh_index"
+    vectorstore_path: str = "gs://mda_eu_project/vectorstore_index"
+    # Models
+    embedding_model: str = "sentence-transformers/LaBSE"
+    llm_model: str = "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16"
+    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
+    # RAG parameters
+    chunk_size: int = 750
+    chunk_overlap: int = 100
+    hybrid_k: int = 50
+    assistant_role: str = (
+        "You are a concise, factual assistant. Cite Document [ID] for each claim."
+    )
+    skip_warmup: bool = False
+    allowed_origins: List[str] = ["*"]
+
+    class Config:
+        env_file = ".env"
+
+settings = Settings()
+
+# Pre-instantiate embedding model (used by filter/compressor)
+EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)
+
+@lru_cache(maxsize=256)
+def embed_query_cached(query: str) -> List[float]:
+    """Cache embedding vectors for queries."""
+    return EMBEDDING.embed_query(query.strip().lower())
+
+# === Hybrid Retriever ===
+class HybridRetriever(BaseRetriever):
+    """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""
+    # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
+    from pydantic import PrivateAttr
+    _vs: FAISS = PrivateAttr()
+    _ix: index.Index = PrivateAttr()
+    _compressor: DocumentCompressorPipeline = PrivateAttr()
+    _cross_encoder: CrossEncoder = PrivateAttr()
+
+    def __init__(
+        self,
+        vs: FAISS,
+        ix: index.Index,
+        compressor: DocumentCompressorPipeline,
+        cross_encoder: CrossEncoder
+    ) -> None:
+        super().__init__()
+        object.__setattr__(self, '_vs', vs)
+        object.__setattr__(self, '_ix', ix)
+        object.__setattr__(self, '_compressor', compressor)
+        object.__setattr__(self, '_cross_encoder', cross_encoder)
+
+    async def _aget_relevant_documents(self, query: str) -> List[Document]:
+        # BM25 retrieval using Whoosh index
+        bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
+        # Dense retrieval using FAISS
+        dense_docs = self._vs.similarity_search_by_vector(
+            embed_query_cached(query), k=settings.hybrid_k
+        )
+        # Cross-encoder re-ranking
+        candidates = bm_docs + dense_docs
+        scores = self._cross_encoder.predict([
+            (query, doc.page_content) for doc in candidates
+        ])
+        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
+        top = [doc for _, doc in ranked[: settings.hybrid_k]]
+        # Compress and return
+        return self._compressor.compress_documents(top, query=query)

+    def _get_relevant_documents(self, query: str) -> List[Document]:
+        import asyncio
+        return asyncio.get_event_loop().run_until_complete(
+            self._aget_relevant_documents(query)
+        )
+
+#os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = r"C:\Users\Romain\OneDrive - KU Leuven\focal-pager-460414-e9-45369b738be0.json"
+# ---------------------------------------------------------------------------- #
+#                                Single Lifespan                               #
+# ---------------------------------------------------------------------------- #
 @asynccontextmanager
-async def lifespan(app: FastAPI):
-
-
-
+async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+    # --- 1) RAG Initialization --- #
+    logger = logging.getLogger("uvicorn")
+    logger.info("Initializing RAG components…")
+
+    # Compressor pipeline to de-duplicate via embeddings
+    compressor = DocumentCompressorPipeline(
+        transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
+    )
+
+    # Cross-encoder ranker
+    cross_encoder = CrossEncoder(settings.cross_encoder_model)
+
+    # Causal LLM pipeline
+    llm_model = AutoModelForCausalLM.from_pretrained(settings.llm_model)
+    gen_pipe = pipeline(
+        "text-generation",
+        model=llm_model,
+        tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
+        max_new_tokens=256,
+        do_sample=True,
+        temperature=0.7,
+    )
+    llm = HuggingFacePipeline(pipeline=gen_pipe)
+
+    # Conversational memory
+    memory = ConversationBufferWindowMemory(
+        memory_key="chat_history",
+        k=5,
+        input_key="question",
+        output_key="answer",
+        return_messages=True,
+    )
+
+    # Build or load FAISS & Whoosh once
+    vs, ix = await build_indexes(
+        settings.parquet_path,
+        settings.vectorstore_path,
+        settings.whoosh_dir,
+        settings.chunk_size,
+        settings.chunk_overlap,
+        None,
+    )
+    retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)
+
+    prompt = PromptTemplate.from_template(
+        f"{settings.assistant_role}\n\n"
+        "Context (up to 2,000 tokens, with document IDs):\n"
+        "{context}\n"
+        "Q: {question}\n"
+        "A: Provide your answer."
+    )
+
+    app.state.rag_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        retriever=retriever,
+        memory=memory,
+        combine_docs_chain_kwargs={"prompt": prompt},
+        return_source_documents=True,
+    )
+
+    if not settings.skip_warmup:
+        logger.info("Warming up RAG chain…")
+        await app.state.rag_chain.ainvoke({"question": "warmup"})
+    logger.info("RAG ready.")
+
+    # --- 2) Dataframe Initialization --- #
+    logger.info("Loading Parquet data from GCS…")
+    fs = gcsfs.GCSFileSystem()
+    with fs.open(settings.parquet_path, "rb") as f:
         df = pl.read_parquet(f)

+    # lowercase for filtering
     for col in ("title", "status", "legalBasis"):
         df = df.with_columns(pl.col(col).str.to_lowercase().alias(f"_{col}_lc"))

-
-
-
-    app.state.
-    app.state.statuses = statuses
-    app.state.legal_bases = legal_bases
-    app.state.orgs_list = organizations
-    app.state.countries_list = countries
+    # materialize unique filter values
+    app.state.df = df
+    app.state.statuses = df["_status_lc"].unique().to_list()
+    app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
+    app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
+    app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()

     yield
-
+    # teardown (if any) goes here
+
+# ---------------------------------------------------------------------------- #
+#                                   App Setup                                  #
+# ---------------------------------------------------------------------------- #
 app = FastAPI(lifespan=lifespan)
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=
+    allow_origins=settings.allowed_origins,
     allow_methods=["*"],
     allow_headers=["*"],
 )

+# ---------------------------------------------------------------------------- #
+#                                 RAG Endpoint                                 #
+# ---------------------------------------------------------------------------- #
+class RAGRequest(BaseModel):
+    session_id: Optional[str]
+    query: str
+
+class RAGResponse(BaseModel):
+    answer: str
+    source_ids: List[str]
+
+def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
+    chain = app.state.rag_chain
+    if chain is None:
+        raise HTTPException(status_code=500, detail="RAG chain not initialized")
+    return chain
+
+@app.post("/rag", response_model=RAGResponse)
+async def ask_rag(
+    req: RAGRequest,
+    rag_chain = Depends(rag_chain_depender)
+):
+    result = await rag_chain.ainvoke({"question": req.query})
+    sources = [doc.metadata.get("id", "") for doc in result.get("source_documents", [])]
+    return RAGResponse(answer=result["answer"], source_ids=sources)
+
+# ---------------------------------------------------------------------------- #
+#                                Data Endpoints                                #
+# ---------------------------------------------------------------------------- #
 @app.get("/api/projects")
 def get_projects(
     page: int = 0,
@@ -71,10 +268,10 @@ def get_projects(
     if search:
         sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
     if status:
-
-
-
-
+        sel = sel.filter(
+            pl.col("status").is_null() if status == "UNKNOWN"
+            else pl.col("_status_lc") == status.lower()
+        )
     if legalBasis:
         sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
     if organization:
@@ -86,59 +283,57 @@ def get_projects(
     if proj_id:
         sel = sel.filter(pl.col("id") == proj_id)

-
-        "id",
-        "
-        "
+    base_cols = [
+        "id","title","status","startDate","endDate","ecMaxContribution","acronym",
+        "legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
+        "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
     ]
-
-
+    # add shap/explanation columns
+    for i in range(1,7):
+        base_cols += [f"top{i}_feature", f"top{i}_shap"]
+    base_cols += ["predicted_label","predicted_prob"]

-    sortOrder = True if sortOrder == "desc" else False
+    sort_desc = True if sortOrder=="desc" else False
     sortField = sortField if sortField in df.columns else "startDate"
+
     rows = (
-        sel.sort(sortField,descending=
-        .
+        sel.sort(sortField, descending=sort_desc)
+        .slice(start, limit)
+        .select(base_cols)
         .to_dicts()
     )

     projects = []
     for row in rows:
         explanations = []
-        for i in range(1,
+        for i in range(1,7):
             feat = row.pop(f"top{i}_feature", None)
             shap = row.pop(f"top{i}_shap", None)
             if feat is not None and shap is not None:
                 explanations.append({"feature": feat, "shap": shap})
         row["explanations"] = explanations

-        # 2) transform list_publications into a { type: count } map
-        raw_pubs = row.pop("list_publications", None) or []
-        pub_counts: dict[str, int] = {}
-        for entry in raw_pubs:
-            # assuming entry is a string like "paper" or "peer reviewed paper"
-            pub_counts[entry] = pub_counts.get(entry, 0) + 1
+        # publications aggregation
+        raw_pubs = row.pop("list_publications", []) or []
+        pub_counts: Dict[str,int] = {}
+        for p in raw_pubs:
+            pub_counts[p] = pub_counts.get(p, 0) + 1
         row["publications"] = pub_counts

         projects.append(row)

     return projects

-
 @app.get("/api/filters")
 def get_filters(request: Request):
     df = app.state.df
     params = request.query_params

-    # apply the same filters you use elsewhere
     if s := params.get("status"):
-        if s
-
-        else:
-            df = df.filter(pl.col("_status_lc") == s.lower())
+        df = df.filter(pl.col("status").is_null() if s=="UNKNOWN"
+                       else pl.col("_status_lc")==s.lower())
     if lb := params.get("legalBasis"):
-        df = df.filter(pl.col("_legalBasis_lc")
+        df = df.filter(pl.col("_legalBasis_lc")==lb.lower())
     if org := params.get("organization"):
         df = df.filter(pl.col("list_name").list.contains(org))
     if c := params.get("country"):
@@ -146,40 +341,39 @@ def get_filters(request: Request):
     if search := params.get("search"):
         df = df.filter(pl.col("_title_lc").str.contains(search.lower()))

-    def normalize(
-        return sorted(
+    def normalize(vals):
+        return sorted({("UNKNOWN" if v is None else v) for v in vals})

     return {
-        "statuses":
-        "legalBases":
+        "statuses": normalize(df["status"].to_list()),
+        "legalBases": normalize(df["legalBasis"].to_list()),
         "organizations": normalize(df["list_name"].explode().to_list()),
-        "countries":
+        "countries": normalize(df["list_country"].explode().to_list()),
         "fundingSchemes": normalize(df["fundingScheme"].explode().to_list()),
-        "ids":
+        "ids": normalize(df["id"].to_list()),
     }

-
 @app.get("/api/stats")
 def get_stats(request: Request):
-    params = request.query_params
     lf = app.state.df.lazy()
+    params = request.query_params

     if s := params.get("status"):
-        lf = lf.filter(pl.col("_status_lc")
+        lf = lf.filter(pl.col("_status_lc")==s.lower())
     if lb := params.get("legalBasis"):
-        lf = lf.filter(pl.col("_legalBasis_lc")
+        lf = lf.filter(pl.col("_legalBasis_lc")==lb.lower())
     if org := params.get("organization"):
         lf = lf.filter(pl.col("list_name").list.contains(org))
     if c := params.get("country"):
         lf = lf.filter(pl.col("list_country").list.contains(c))
     if mn := params.get("minFunding"):
-        lf = lf.filter(pl.col("ecMaxContribution")
+        lf = lf.filter(pl.col("ecMaxContribution")>=int(mn))
     if mx := params.get("maxFunding"):
-        lf = lf.filter(pl.col("ecMaxContribution")
+        lf = lf.filter(pl.col("ecMaxContribution")<=int(mx))
     if y1 := params.get("minYear"):
-        lf = lf.filter(pl.col("startDate").dt.year()
+        lf = lf.filter(pl.col("startDate").dt.year()>=int(y1))
     if y2 := params.get("maxYear"):
-        lf = lf.filter(pl.col("startDate").dt.year()
+        lf = lf.filter(pl.col("startDate").dt.year()<=int(y2))

     grouped = (
         lf.select(pl.col("startDate").dt.year().alias("year"))
@@ -189,25 +383,25 @@ def get_stats(request: Request):
         .collect()
     )
     years, counts = grouped["year"].to_list(), grouped["count"].to_list()
-    return {"Projects per Year": {"labels": years, "values": counts},
-        "Projects per Year 2": {"labels": years, "values": counts},
-        "Projects per Year 3": {"labels": years, "values": counts},
-        "Projects per Year 4": {"labels": years, "values": counts},
-        "Projects per Year 5": {"labels": years, "values": counts},
-        "Projects per Year 6": {"labels": years, "values": counts}}
+
+    return {
+        "Projects per Year": {"labels": years, "values": counts},
+        "Projects per Year 2": {"labels": years, "values": counts},
+        "Projects per Year 3": {"labels": years, "values": counts},
+        "Projects per Year 4": {"labels": years, "values": counts},
+        "Projects per Year 5": {"labels": years, "values": counts},
+        "Projects per Year 6": {"labels": years, "values": counts},
+    }

 @app.get("/api/project/{project_id}/organizations")
 def get_project_organizations(project_id: str):
     df = app.state.df
-
-    sel = df.filter(pl.col("id") == project_id)
+    sel = df.filter(pl.col("id")==project_id)
     if sel.is_empty():
         raise HTTPException(status_code=404, detail="Project not found")

     orgs_df = (
-        sel
-        .select([
+        sel.select([
             pl.col("list_name").explode().alias("name"),
             pl.col("list_city").explode().alias("city"),
             pl.col("list_SME").explode().alias("sme"),
@@ -219,7 +413,6 @@ def get_project_organizations(project_id: str):
             pl.col("list_geolocation").explode().alias("geoloc"),
         ])
         .with_columns([
-            # now this is a List(Utf8)
             pl.col("geoloc").str.split(",").alias("latlon"),
         ])
         .with_columns([
@@ -227,29 +420,10 @@ def get_project_organizations(project_id: str):
             pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
         ])
         .filter(pl.col("name").is_not_null())
-        .select([
+        .select([
+            "name","city","sme","role","contribution",
+            "activityType","orgURL","country","latitude","longitude"
+        ])
     )

     return orgs_df.to_dicts()
-
-"""def rag_chain_depender():
-    """
-    #Dependency injector for the RAG chain stored in app.state.
-    #Raises HTTPException if not initialized.
-    """
-    chain = current_app.state.rag_chain
-    if chain is None:
-        raise HTTPException(status_code=500, detail="RAG chain not initialized")
-    return chain
-
-@app.post("/rag", response_model=RAGResponse)
-async def ask_rag(
-    req: RAGRequest,
-    rag_chain = Depends(rag_chain_depender)
-):
-    """
-    #Handle a RAG query. Uses session memory and the provided RAG chain.
-    """
-    # Invoke the chain with the named input
-    result = await rag_chain.ainvoke({"question": req.query})
-    return RAGResponse(answer=result["answer"])"""
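With the lifespan wired up, the new /rag endpoint can be exercised end to end. A minimal sketch, assuming the app is served locally with uvicorn (the URL, port, and example query are illustrative, not part of the commit):

```python
# Query the new /rag endpoint defined above.
# Assumes: `uvicorn main:app --port 8000` is running and warm-up has finished.
import httpx

resp = httpx.post(
    "http://localhost:8000/rag",
    json={"session_id": "demo", "query": "Which projects work on battery storage?"},
    timeout=120.0,  # retrieval plus generation can be slow on the first call
)
resp.raise_for_status()
body = resp.json()
print(body["answer"])      # LLM answer, citing Document [ID] per the prompt
print(body["source_ids"])  # ids of the retrieved source documents
```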
backend/rag.py  CHANGED

@@ -1,15 +1,15 @@
 import os
-import
-import
-from
-from typing import Any, Dict, List
+import logging
+import tempfile  # _read_gcs below relies on tempfile
+from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
+from contextlib import asynccontextmanager

 import gcsfs
+import aiofiles
 import polars as pl
 from pydantic_settings import BaseSettings
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from
+from starlette.concurrency import run_in_threadpool
 from pydantic import BaseModel

 from langchain.schema import BaseRetriever, Document
@@ -17,241 +17,286 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.retrievers.document_compressors import DocumentCompressorPipeline
 from langchain_community.document_transformers import EmbeddingsRedundantFilter
-from langchain.memory import
+from langchain.memory import ConversationBufferWindowMemory
 from langchain.chains import ConversationalRetrievalChain
 from langchain.prompts import PromptTemplate
 from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings

-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+#from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
 from sentence_transformers import CrossEncoder
-
 from whoosh import index
 from whoosh.fields import Schema, TEXT, ID
 from whoosh.analysis import StemmingAnalyzer
 from whoosh.qparser import MultifieldParser
+from tqdm import tqdm
+import faiss

+from functools import lru_cache
+
+# === Logging ===
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

-
-class Settings(BaseSettings):
-    # allow GCS paths (e.g. "gs://bucket/...") via gcsfs
-    parquet_path: str = "/content/drive/MyDrive/consolidated_clean.parquet"
-    vectorstore_path: str = "/content/drive/MyDrive/vectorstore_index"
-    whoosh_dir: str = "whoosh_index"
-    embedding_model: str = "sentence-transformers/all-mpnet-base-v2"
-    llm_model: str = "bigscience/bloomz-560m"
-    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
-    chunk_size: int = 300
-    chunk_overlap: int = 50
-    hybrid_k: int = 50
-    assistant_role: str = "You are a concise, factual assistant. Cite Document [ID] for each claim."
-
-
-session_meta: Dict[str, Dict] = {}  # per-session metadata
-# rag_chain will be stored on app.state
+# === Global Embeddings & Cache ===
+EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)
+
+@lru_cache(maxsize=256)
+def embed_query_cached(query: str) -> List[float]:
+    """Cache embedding vectors for queries."""
+    return EMBEDDING.embed_query(query.strip().lower())

-# === Whoosh
+# === Whoosh Cache & Builder ===
 _WHOOSH_CACHE: Dict[str, index.Index] = {}
-def build_whoosh_index(docs: List[Document]) -> index.Index:
-    key = settings.whoosh_dir
-    if key in _WHOOSH_CACHE:
-        return _WHOOSH_CACHE[key]
-    os.makedirs(key, exist_ok=True)
-    schema = Schema(id=ID(stored=True, unique=True),
-                    content=TEXT(analyzer=StemmingAnalyzer()))
-    ix = index.create_in(key, schema)
-    writer = ix.writer()
-    for doc in docs:
-        writer.add_document(id=doc.metadata["id"], content=doc.page_content)
-    writer.commit()
-    _WHOOSH_CACHE[key] = ix
-    return ix
+
+async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
+    key = whoosh_dir
+    fs = gcsfs.GCSFileSystem()
+    local_dir = key
+    is_gcs = key.startswith("gs://")
+    try:
+        # stage local copy for GCS
+        if is_gcs:
+            local_dir = "/tmp/whoosh_index"
+            if not os.path.exists(local_dir):
+                if await run_in_threadpool(fs.exists, key):
+                    await run_in_threadpool(fs.get, key, local_dir, recursive=True)
+                else:
+                    os.makedirs(local_dir, exist_ok=True)
+        # build once
+        if key not in _WHOOSH_CACHE:
+            os.makedirs(local_dir, exist_ok=True)
+            schema = Schema(
+                id=ID(stored=True, unique=True),
+                content=TEXT(analyzer=StemmingAnalyzer()),
+            )
+            ix = index.create_in(local_dir, schema)
+            with ix.writer() as writer:
+                for doc in docs:
+                    writer.add_document(
+                        id=doc.metadata.get("id", ""),
+                        content=doc.page_content,
+                    )
+            # push back to GCS atomically
+            if is_gcs:
+                await run_in_threadpool(fs.put, local_dir, key, recursive=True)
+            _WHOOSH_CACHE[key] = ix
+        return _WHOOSH_CACHE[key]
+    except Exception as e:
+        logger.error(f"Failed to build Whoosh index: {e}")
+        raise
+
+# === Document Loader ===
+async def load_documents(
+    path: str,
+    sample_size: Optional[int] = None
+) -> List[Document]:
+    """
+    Load a Parquet file from local or GCS, convert to a list of Documents.
+    """
+    def _read_local(p: str, n: Optional[int]):
+        # streaming scan keeps memory low
+        lf = pl.scan_parquet(p)
+        if n:
+            lf = lf.limit(n)
+        return lf.collect(streaming=True)
+
+    def _read_gcs(p: str, n: Optional[int]):
+        # download to a temp file synchronously, then read with Polars
         fs = gcsfs.GCSFileSystem()
-    with
-
+        with tempfile.TemporaryDirectory() as td:
+            local_path = os.path.join(td, "data.parquet")
+            fs.get(p, local_path, recursive=False)
+            df = pl.read_parquet(local_path)
+            if n:
+                df = df.head(n)
+            return df
+
+    try:
+        if path.startswith("gs://"):
+            df = await run_in_threadpool(_read_gcs, path, sample_size)
+        else:
+            df = await run_in_threadpool(_read_local, path, sample_size)
+    except Exception as e:
+        logger.error(f"Error loading documents: {e}")
+        raise HTTPException(status_code=500, detail="Document loading failed.")
+
+    docs: List[Document] = []
+    for row in df.rows(named=True):
+        context_parts: List[str] = []
+        # build metadata context
+        max_contrib = row.get("ecMaxContribution", "")
+        end_date = row.get("endDate", "")
+        duration = row.get("durationDays", "")
+        status = row.get("status", "")
+        legal = row.get("legalBasis", "")
+        framework = row.get("frameworkProgramme", "")
+        scheme = row.get("fundingScheme", "")
+        names = row.get("list_name", []) or []
+        cities = row.get("list_city", []) or []
+        countries = row.get("list_country", []) or []
+        activity = row.get("list_activityType", []) or []
+        contributions = row.get("list_ecContribution", []) or []
+        smes = row.get("list_sme", []) or []
+        project_id = row.get("id", "")
+        pred = row.get("predicted_label", "")
+        proba = row.get("predicted_prob", "")
+        top1_feats = row.get("top1_features", "")
+        top2_feats = row.get("top2_features", "")
+        top3_feats = row.get("top3_features", "")
+        top1_shap = row.get("top1_shap", "")
+        top2_shap = row.get("top2_shap", "")
+        top3_shap = row.get("top3_shap", "")
+
+        context_parts.append(
+            f"This project under framework {framework} with funding scheme {scheme}, status {status}, legal basis {legal}."
+        )
+        context_parts.append(
+            f"It ends on {end_date} after {duration} days and has a max EC contribution of {max_contrib}."
+        )
+        context_parts.append("Participating organizations:")
+        for i, name in enumerate(names):
+            city = cities[i] if i < len(cities) else ""
+            country = countries[i] if i < len(countries) else ""
+            act = activity[i] if i < len(activity) else ""
+            contrib = contributions[i] if i < len(contributions) else ""
+            sme_flag = "SME" if (smes and i < len(smes) and smes[i]) else "non-SME"
+            context_parts.append(
+                f"- {name} in {city}, {country}, activity: {act}, contributed: {contrib}, {sme_flag}."
+            )
+        if status in (None, "signed", "SIGNED", "Signed"):
+            if int(pred) == 1:
+                label = "TERMINATED"
+                score = float(proba)
+            else:
+                label = "CLOSED"
+                score = 1 - float(proba)
+
+            score_str = f"{score:.2f}"
+
+            context_parts.append(
+                f"- Project {project_id} is predicted to be {label} (score={score_str}). "
+                f"The 3 most predictive features were: "
+                f"{top1_feats} ({top1_shap:.3f}), "
+                f"{top2_feats} ({top2_shap:.3f}), "
+                f"{top3_feats} ({top3_shap:.3f})."
+            )
+
+        title_report = row.get("list_title_report", "")
+        objective = row.get("objective", "")
+        full_body = f"{title_report} {objective}"
+        full_text = " ".join(context_parts + [full_body])
+        meta: Dict[str, Any] = {"id": str(row.get("id", "")), "startDate": str(row.get("startDate", "")), "endDate": str(row.get("endDate", "")), "status": str(row.get("status", "")), "legalBasis": str(row.get("legalBasis", ""))}
+        meta.update({"id": str(row.get("id", "")), "startDate": str(row.get("startDate", "")), "endDate": str(row.get("endDate", "")), "status": str(row.get("status", "")), "legalBasis": str(row.get("legalBasis", ""))})
+        docs.append(Document(page_content=full_text, metadata=meta))
     return docs

-# ===
-def
-    parts = [query]
-    for key,label in [
-        ("startDate","from"), ("topics","about"),
-        ("list_euroSciVocTitle","with topics"), ("list_euroSciVocPath","with topic path")
-    ]:
-        if md.get(key):
-            parts.append(f"{label} {md[key]}")
-    return "; ".join(parts)
-
-async def iterative_retrieval(
-    query: str, vs: FAISS, ix: index.Index,
-    rewriter: HuggingFacePipeline, md: Dict[str,str]
-) -> List[Document]:
-    first_q = await rewriter.ainvoke({"query": query})
+# === BM25 Search ===
+async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
     parser = MultifieldParser(["content"], schema=ix.schema)
-    def
-        with ix.searcher() as
-            hits =
-            return [
+    def _search() -> List[Document]:
+        with ix.searcher() as searcher:
+            hits = searcher.search(parser.parse(query), limit=k)
+            return [Document(page_content=h["content"], metadata={"id": h["id"]}) for h in hits]
+    return await run_in_threadpool(_search)
+
+# === Helper: build or load FAISS with mmap ===
+async def build_or_load_faiss(
+    chunks: List[Document],
+    vectorstore_path: str,
+    batch_size: int = 15000
+) -> FAISS:
+    faiss_index_file = os.path.join(vectorstore_path, "index.faiss")
+    # If on-disk exists: memory-map the FAISS index and load metadata separately
+    if os.path.exists(faiss_index_file):
+        logger.info("Memory-mapping existing FAISS index...")
+        mmap_idx = faiss.read_index(faiss_index_file, faiss.IO_FLAG_MMAP)
+        # Manually load metadata (docstore and index_to_docstore) without loading the index
+        import pickle
+        for meta_file in ["faiss.pkl", "index.pkl"]:
+            meta_path = os.path.join(vectorstore_path, meta_file)
+            if os.path.exists(meta_path):
+                with open(meta_path, "rb") as f:
+                    saved = pickle.load(f)
+                break
+        else:
+            raise FileNotFoundError(
+                f"Could not find FAISS metadata pickle in {vectorstore_path}"
+            )
+        # extract metadata
+        if isinstance(saved, tuple):
+            # Handle metadata tuple of length 2 or 3
+            if len(saved) == 3:
+                _, docstore, index_to_docstore = saved
+            elif len(saved) == 2:
+                docstore, index_to_docstore = saved
+            else:
+                raise ValueError(f"Unexpected metadata tuple length: {len(saved)}")
+        else:
+            if hasattr(saved, 'docstore'):
+                docstore = saved.docstore
+            elif hasattr(saved, '_docstore'):
+                docstore = saved._docstore
+            else:
+                raise AttributeError("Could not find docstore in FAISS metadata")
+            if hasattr(saved, 'index_to_docstore'):
+                index_to_docstore = saved.index_to_docstore
+            elif hasattr(saved, '_index_to_docstore'):
+                index_to_docstore = saved._index_to_docstore
+            elif hasattr(saved, '_faiss_index_to_docstore'):
+                index_to_docstore = saved._faiss_index_to_docstore
+            else:
+                raise AttributeError("Could not find index_to_docstore in FAISS metadata")
+        # reconstruct FAISS wrapper
+        vs = FAISS(
+            embedding_function=EMBEDDING,
+            index=mmap_idx,
+            docstore=docstore,
+            index_to_docstore_id=index_to_docstore,
+        )
+        return vs
+
+    # 2) Else: build from scratch in batches
+    logger.info(f"Building FAISS index in batches of {batch_size}…")
+    vs: Optional[FAISS] = None
+    for i in tqdm(range(0, len(chunks), batch_size),
+                  desc="Building FAISS index",
+                  unit="batch"):
+        batch = chunks[i : i + batch_size]
+
+        if vs is None:
+            vs = FAISS.from_documents(batch, EMBEDDING)
+        else:
+            vs.add_documents(batch)
+
+        # periodic save every 5 batches
+        if (i // batch_size) % 5 == 0:
+            vs.save_local(vectorstore_path)
+
+        logger.info(f" • Saved batch up to document {i + len(batch)} / {len(chunks)}")
+    assert vs is not None, "No documents to index!"
+    return vs

-# === Retriever & Index Builder ===
-def build_retriever(docs: List[Document], embedder: HuggingFaceEmbeddings):
-    ix = build_whoosh_index(docs)
+# === Index Builder ===
+async def build_indexes(
+    parquet_path: str,
+    vectorstore_path: str,
+    whoosh_dir: str,
+    chunk_size: int,
+    chunk_overlap: int,
+    debug_size: Optional[int]
+) -> Tuple[FAISS, index.Index]:
+    docs = await load_documents(parquet_path, debug_size)
+    ix = await build_whoosh_index(docs, whoosh_dir)
+
     splitter = RecursiveCharacterTextSplitter(
-        chunk_size=
-        chunk_overlap=settings.chunk_overlap
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap
     )
     chunks = splitter.split_documents(docs)
-    if os.path.exists(settings.vectorstore_path):
-        vs = FAISS.load_local(
-            settings.vectorstore_path,
-            embeddings=embedder,
-            allow_dangerous_deserialization=True
-        )
-    else:
-        vs = FAISS.from_documents(chunks, embedder)
-        vs.save_local(settings.vectorstore_path)
-    return vs, ix
-
-# === Rewriter & LLM Pipelines ===
-rewriter_pipe = pipeline(
-    "text-generation",
-    model=AutoModelForCausalLM.from_pretrained(settings.llm_model),
-    tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
-    max_new_tokens=64, truncation=True, do_sample=False, return_full_text=False
-)
-rewriter = HuggingFacePipeline(pipeline=rewriter_pipe)
-query_rewriter = PromptTemplate.from_template("Rewrite question: {query}") | rewriter
-
-def get_llm_pipeline():
-    gen_pipe = pipeline(
-        "text-generation",
-        model=AutoModelForCausalLM.from_pretrained(settings.llm_model),
-        tokenizer=AutoTokenizer.from_pretrained(settings.llm_model),
-        max_new_tokens=256, truncation=True,
-        temperature=0.7, do_sample=True, return_full_text=False
-    )
-    return HuggingFacePipeline(pipeline=gen_pipe)
-
-combine_prompt = PromptTemplate.from_template(
-    f"""{settings.assistant_role}
-
-A:
-"""
-)
-
-async def retrieve_and_rank(
-    query: str, vs: FAISS, ix: index.Index, md: Dict[str,str]
-) -> List[Document]:
-    bm25_docs = await iterative_retrieval(query, vs, ix, query_rewriter, md)
-    dense = vs.similarity_search_by_vector(embed_query_cached(query), k=settings.hybrid_k)
-    candidates = bm25_docs + dense
-    scores = cross_encoder.predict([(query, d.page_content) for d in candidates])
-    top_docs = [doc for _, doc in sorted(zip(scores, candidates), reverse=True)[:settings.hybrid_k]]
-    return compressor.compress_documents(top_docs, query=query)
-
-class HybridRetriever(BaseRetriever):
-    vs: Any
-    ix: Any
-    meta_store: Dict[str,Dict]
-
-    def __init__(self, vs, ix, meta_store):
-        super().__init__(vs=vs, ix=ix, meta_store=meta_store)
-    def _get_relevant_documents(self, query: str):
-        return asyncio.get_event_loop().run_until_complete(
-            retrieve_and_rank(query, self.vs, self.ix, self.meta_store.get(query, {}))
-        )
-    async def _aget_relevant_documents(self, query: str):
-        return await retrieve_and_rank(query, self.vs, self.ix, self.meta_store.get(query, {}))
-
-# === FastAPI Lifespan & Initialization ===
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    global embedding, cross_encoder, compressor
-
-    # 1) Embeddings & docs
-    embedding = HuggingFaceEmbeddings(model_name=settings.embedding_model)
-    docs = load_documents(settings.parquet_path)
-
-    # 2) Retriever + Whoosh index
-    vs, ix = build_retriever(docs, embedding)
-
-    # 3) Compressor and reranker
-    compressor = DocumentCompressorPipeline(
-        transformers=[EmbeddingsRedundantFilter(embeddings=embedding)]
-    )
-    cross_encoder = CrossEncoder(settings.cross_encoder_model)
-
-    # 4) LLM & memory
-    llm = get_llm_pipeline()
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        input_key="question",
-        output_key="answer",
-        return_messages=True
-    )
-
-    # 5) Build RAG chain
-    retriever = HybridRetriever(vs, ix, session_meta)
-    rag_chain = ConversationalRetrievalChain.from_llm(
-        llm=llm,
-        retriever=retriever,
-        memory=memory,
-        combine_docs_chain_kwargs={"prompt": combine_prompt},
-        return_source_documents=True
-    )
-
-    # 6) Warm-up
-    vs.similarity_search("warmup", k=1)
-    embed_query_cached("warmup")
-    await rag_chain.ainvoke({"question": "warmup"})
-
-    # 7) Store on app.state so get_rag_chain() can find it
-    app.state.rag_chain = rag_chain
-
-    yield
-
-# === Pydantic Models & FastAPI Setup ===
-class RAGRequest(BaseModel):
-    session_id: str
-    topic: str
-    startDate: str
-    query: str
-
-class RAGResponse(BaseModel):
-    answer: str
+
+    # build or load (with mmap) FAISS
+    vs = await build_or_load_faiss(chunks, vectorstore_path)
+
+    return vs, ix
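The pivotal step shared by the old retrieve_and_rank and the new HybridRetriever is cross-encoder re-ranking: every (query, candidate) pair is scored jointly and only the top-k survive. A self-contained sketch of just that step, using the cross_encoder_model name from Settings; the query and candidate texts are invented for illustration:

```python
# Cross-encoder re-ranking in isolation (model name from Settings;
# query and candidates are made up for this demo).
from sentence_transformers import CrossEncoder

query = "grid-scale battery storage for renewable energy"
candidates = [
    "This project develops grid-scale battery storage for solar farms.",
    "A study of medieval manuscripts in southern European archives.",
]
ce = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
scores = ce.predict([(query, text) for text in candidates])  # one score per pair
ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
for score, text in ranked:
    print(f"{score:+.3f}  {text}")  # the on-topic candidate should rank first
```

Unlike the bi-encoder FAISS stage, the cross-encoder sees query and document together, so it is slower but considerably more precise, which is why it runs only over the pooled BM25 and dense candidates rather than the whole corpus.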
frontend/src/components/ProjectDetails.tsx  CHANGED

@@ -129,6 +129,7 @@ export default function ProjectDetails({
 <Box><Text fontWeight="bold">End Date</Text><Text>{fmtDate(project.endDate)}</Text></Box>
 <Box><Text fontWeight="bold">Funding (EC max)</Text><Text>€{fmtNum(project.ecMaxContribution)}</Text></Box>
 <Box><Text fontWeight="bold">Total Cost</Text><Text>€{fmtNum(project.totalCost)}</Text></Box>
+<Box><Text fontWeight="bold">Funding Scheme</Text><Text>{project.fundingScheme}</Text></Box>
 <Box>
   <Text fontWeight="bold">Legal Basis</Text>
   <Text>{project.legalBasis}</Text>
@@ -272,7 +273,7 @@ export default function ProjectDetails({
 <ResponsiveContainer width="100%" height={300}>
   <BarChart data={shapData} margin={{ top: 10, right: 30, left: 0, bottom: 5 }}>
     <CartesianGrid strokeDasharray="3 3" />
-    <XAxis dataKey="feature" />
+    <XAxis dataKey="feature" axisLine={false} tick={false} />
     <YAxis />
     <Tooltip />
     <Bar dataKey="shap" name="SHAP Value">
@@ -285,6 +286,10 @@ export default function ProjectDetails({
     </Bar>
   </BarChart>
 </ResponsiveContainer>
+<Text fontSize="xs" color="gray.500" mt={2}>
+  Each bar shows how much that feature pushed the model's prediction.
+  Positive bars increase the chance of termination; negative bars decrease it.
+</Text>
 </>
 ) : (
 <Spinner />
predictive_modelling.py  CHANGED

@@ -7,6 +7,7 @@ import shap
 import matplotlib.pyplot as plt
 import scipy.sparse
 import polars as pl
+import re
 import gcsfs

 from sklearn.base import BaseEstimator, TransformerMixin
@@ -510,22 +511,126 @@ def score(new_df, model_dir="model_artifacts"):
     explainer = shap.Explainer(clf, X_sel, feature_names=feature_names)
     shap_vals = explainer(X_sel)  # returns a ShapleyValues object

-    # 6) For each row, pick top-
+    # 6) For each row, pick top-6 absolute contributors
     shap_df = pd.DataFrame(shap_vals.values, columns=feature_names, index=df.index)
+
+    # 7) get absolute values
     abs_shap = shap_df.abs()

+    # 8) for each row, record the top-6 feature names by absolute magnitude
     top_feats = abs_shap.apply(lambda row: row.nlargest(6).index.tolist(), axis=1)
-    top_vals = abs_shap.apply(lambda row: row.nlargest(6).values.tolist(), axis=1)

-
-    )
-
+    # 9) convert that to six separate columns
+    feat_cols = [f"top{i}_feature" for i in range(1,7)]
+    df[feat_cols] = pd.DataFrame(top_feats.tolist(), index=df.index)
+
+    # 10) now build the *true* SHAP values by looking up each name in shap_df
+    #     for each row, shap_df.loc[idx, feat] is the signed value
+    top_vals = [
+        [shap_df.loc[idx, feat] for feat in feats]
+        for idx, feats in top_feats.items()
+    ]
+
+    # 11) store them in your six shap-value columns
+    val_cols = [f"top{i}_shap" for i in range(1,7)]
+    df[val_cols] = pd.DataFrame(top_vals, index=df.index)

     return df

+def clean_feature_name(raw: str) -> str:
+    """
+    - cat__: "cat__feature_value" → "Feature: Value"
+    - num__: "num__some_count" → "Some Count"
+    - mlb_:  "mlb_list_activityType__list_activityType_Research"
+             → "List Activity Type: Research"
+    """
+    if not raw:
+        return ""
+
+    # 1) cat__
+    if raw.startswith("cat__"):
+        s = raw[len("cat__"):]
+        col, val = (s.split("__", 1) + [None])[:2]
+        col_c = col.replace("_", " ").title()
+        if val:
+            val_c = val.replace("_", " ").title()
+            return f"{col_c}: {val_c}"
+        return col_c
+
+    # 2) num__
+    if raw.startswith("num__"):
+        s = raw[len("num__"):]
+        return s.replace("_", " ").replace('n ', 'Number of ')
+
+    # 3) mlb_
+    if raw.startswith("mlb_"):
+        s = raw[len("mlb_"):]
+        col_part, val_part = (s.split("__", 1) + [None])[:2]
+        # drop leading "list_" on the column
+        if col_part.startswith("list_"):
+            col_inner = col_part[len("list_"):]
+        else:
+            col_inner = col_part
+        col_c = col_inner.replace("_", " ").title()
+        col_c = "List " + col_c
+
+        if val_part:
+            # drop "list_{col_inner}_" or leading "list_"
+            prefix = f"list_{col_inner}_"
+            if val_part.startswith(prefix):
+                val_inner = val_part[len(prefix):]
+            elif val_part.startswith("list_"):
+                val_inner = val_part[len("list_"):]
+            else:
+                val_inner = val_part
+            val_c = val_inner.replace("_", " ").title()
+            return f"{col_c}: {val_c}"
+        return col_c
+
+    # fallback: replace __ → ": ", _ → " "
+    return raw.replace("__", ": ").replace("_", " ").title()
+
+
+def preprocess_feature_names(df: pl.DataFrame) -> pl.DataFrame:
+    transforms = []
+
+    # clean and round top-6 features & shap values
+    for i in range(1, 7):
+        fcol = f"top{i}_feature"
+        scol = f"top{i}_shap"
+
+        if fcol in df.columns:
+            transforms.append(
+                pl.col(fcol)
+                .map_elements(clean_feature_name, return_dtype=pl.Utf8)
+                .alias(fcol)
+            )
+        if scol in df.columns:
+            transforms.append(
+                pl.col(scol)
+                .round(4)
+                .alias(scol)
+            )
+
+    # round overall predicted probability
+    if "predicted_prob" in df.columns:
+        transforms.append(
+            pl.col("predicted_prob")
+            .round(4)
+            .alias("predicted_prob")
+        )
+
+    # 1) build the full list of embed-columns
+    embed_cols = [f"title_embed_{i}" for i in range(50)] + \
+                 [f"objective_embed_{i}" for i in range(50)]
+
+    # 2) keep only the ones that actually exist in df.columns
+    to_drop = [c for c in embed_cols if c in df.columns]
+
+    # 3) drop them
+    df = df.drop(to_drop)
+    return df.with_columns(transforms)
+
 if __name__ == "__main__":
     # Entry point for training and scoring: loads data from Google Cloud Storage,
     # builds model and artifacts, then scores the same data as a test.
@@ -539,5 +644,5 @@ if __name__ == "__main__":
     df = pl.read_parquet(f).to_pandas()

     status_prediction_model(df)
-
-
+    df_clean = preprocess_feature_names(pl.from_pandas(score(df)))
+    df_clean.head(10)