# MDA / backend / main.py
# Rom89823974978's picture
# Update date type
# 674ebda
import logging
import os
import shutil
import tempfile
import traceback
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional
import aiofiles
import faiss
import gcsfs
import polars as pl
import torch
import zipfile
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, PrivateAttr
from pydantic_settings import BaseSettings as SettingsBase
from sentence_transformers import CrossEncoder
from starlette.concurrency import run_in_threadpool
from tqdm import tqdm
from transformers import ( # Transformers for LLM pipeline
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
pipeline,
T5ForConditionalGeneration,
T5Tokenizer,
MT5ForConditionalGeneration, MT5TokenizerFast
)
# LangChain imports for RAG
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
# Project-specific imports
from app.rag import build_indexes, HybridRetriever
# ---------------------------------------------------------------------------- #
# Settings #
# ---------------------------------------------------------------------------- #
class Settings(SettingsBase):
    """
    Application configuration.

    Values are loaded from the environment or from the ``.env`` file named in
    ``Config``; the literals below are the defaults used when nothing
    overrides them.
    """

    # --- Data sources (Google Cloud Storage URIs) ---
    parquet_path: str = "gs://mda_kul_project/data/consolidated_clean_pred.parquet"
    whoosh_dir: str = "gs://mda_kul_project/whoosh_index"
    vectorstore_path: str = "gs://mda_kul_project/vectorstore_index"

    # --- Model names ---
    embedding_model: str = "sentence-transformers/LaBSE"
    llm_model: str = "google/mt5-base"
    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"

    # --- RAG parameters ---
    # chunk_size / chunk_overlap are forwarded to build_indexes() at startup.
    chunk_size: int = 750
    chunk_overlap: int = 100
    # NOTE(review): hybrid_k is not referenced anywhere else in this file.
    hybrid_k: int = 4
    # Becomes the first line of the RAG prompt template.
    assistant_role: str = (
        "You are a knowledgeable project analyst. "
        "You have access to the following retrieved document snippets."
    )
    # When True, the startup warm-up call to the RAG chain is skipped.
    skip_warmup: bool = True

    # --- CORS ---
    allowed_origins: List[str] = ["*"]

    class Config:
        env_file = ".env"
# Module-level singletons: configuration first, then logging.
settings = Settings()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# One shared embedding model, built at import time and reused by the cached
# query embedder and by the redundancy filter created during startup.
EMBEDDING = HuggingFaceEmbeddings(
    model_name=settings.embedding_model,
    model_kwargs={"trust_remote_code": True},
)
@lru_cache(maxsize=256)
def _embed_normalized(normalized_query: str) -> List[float]:
    """Embed an already-normalised query; memoised on the normalised text."""
    return EMBEDDING.embed_query(normalized_query)


def embed_query_cached(query: str) -> List[float]:
    """
    Return the embedding vector for ``query``, caching repeated queries.

    The query is stripped and lower-cased *before* the cache lookup, so
    spellings that differ only in case or surrounding whitespace share one
    cache entry instead of each triggering a fresh model call.

    Args:
        query: Raw user query text.

    Returns:
        The embedding produced by ``EMBEDDING.embed_query`` for the
        normalised text. The same list object is returned on cache hits, so
        callers should treat it as read-only.
    """
    return _embed_normalized(query.strip().lower())


# Keep the lru_cache introspection helpers reachable under the public name,
# as they were when the decorator sat directly on this function.
embed_query_cached.cache_info = _embed_normalized.cache_info
embed_query_cached.cache_clear = _embed_normalized.cache_clear
# ---------------------------------------------------------------------------- #
# Application Lifespan #
# ---------------------------------------------------------------------------- #
# NOTE: the FastAPI application object is created once, in the "App Setup"
# section further down, after `lifespan` has been defined. A first instance
# used to be built here, but it was rebound by that later assignment before
# any middleware or route was registered, so it never served anything.
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Prepare everything the endpoints read from ``app.state`` before serving.

    Startup, in order: document compressor, quantized cross-encoder,
    quantized mT5 generation pipeline, conversation memory, vector/Whoosh
    indexes, the conversational retrieval chain (``app.state.rag_chain``),
    an optional warm-up call, and finally the project table with cached
    filter values (``app.state.df``, ``statuses``, ``legal_bases``,
    ``orgs_list``, ``countries_list``).

    Shutdown: nothing is released explicitly; the function simply returns
    once control comes back after the ``yield``.
    """
    # 1) Compressor that drops near-duplicate documents via embeddings
    logger.info("Initializing Document Compressor")
    redundancy_filter = EmbeddingsRedundantFilter(embeddings=EMBEDDING)
    doc_compressor = DocumentCompressorPipeline(transformers=[redundancy_filter])

    # 2) Cross-encoder, with its Linear layers dynamically quantized to int8
    logger.info("Initializing Cross-Encoder")
    reranker = CrossEncoder(settings.cross_encoder_model)
    reranker.model = torch.quantization.quantize_dynamic(
        reranker.model, {torch.nn.Linear}, dtype=torch.qint8
    )
    logger.info("Cross-Encoder quantized")

    # 3) Seq2seq generation pipeline (CPU, device=-1) wrapped for LangChain
    logger.info("Initializing LLM pipeline")
    mt5_tokenizer = MT5TokenizerFast.from_pretrained(settings.llm_model)
    mt5_model = MT5ForConditionalGeneration.from_pretrained(settings.llm_model)
    mt5_model = torch.quantization.quantize_dynamic(
        mt5_model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    generation_options = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 30,
        "top_p": 0.95,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 3,
        "use_cache": True,
    }
    text_generator = pipeline(
        "text2text-generation",
        model=mt5_model,
        tokenizer=mt5_tokenizer,
        device=-1,
        **generation_options,
    )
    llm = HuggingFacePipeline(pipeline=text_generator)

    # 4) Windowed chat memory (last 5 exchanges)
    # NOTE(review): a single memory object is attached to the chain, so the
    # history appears to be shared by every request — confirm this is intended.
    logger.info("Initializing Conversation Memory")
    chat_memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
        k=5,
    )

    # 5) Vector store + Whoosh index, then the hybrid retriever on top of them
    logger.info("Building or loading indexes")
    vector_store, whoosh_index = await build_indexes(
        settings.parquet_path,
        settings.vectorstore_path,
        settings.whoosh_dir,
        settings.chunk_size,
        settings.chunk_overlap,
        None,
    )
    hybrid_retriever = HybridRetriever(
        vs=vector_store,
        ix=whoosh_index,
        compressor=doc_compressor,
        cross_encoder=reranker,
    )

    # 6) Prompt: role text first, then the retrieved context and the question.
    #    The rules for the assistant live inside assistant_role itself.
    template_text = (
        settings.assistant_role
        + "\n{context}\nUser Question:\n{question}\nAnswer:"
    )
    rag_prompt = PromptTemplate.from_template(template_text)

    # 7) Conversational retrieval chain, published on the application state
    logger.info("Initializing Retrieval Chain")
    app.state.rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=hybrid_retriever,
        memory=chat_memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
    )

    # Optional warm-up query so the first real request is not the slow one
    if not settings.skip_warmup:
        logger.info("Warming up RAG chain")
        await app.state.rag_chain.ainvoke({"question": "warmup"})

    # 8) Project table: read the Parquet file from GCS into Polars
    logger.info("Loading Parquet data from GCS")
    gcs = gcsfs.GCSFileSystem()
    with gcs.open(settings.parquet_path, "rb") as parquet_file:
        projects = pl.read_parquet(parquet_file)

    # Integer ids plus lower-cased shadow columns ("_<name>_lc") used by the
    # case-insensitive filters in the data endpoints.
    lowercase_shadows = [
        pl.col(name).str.to_lowercase().alias(f"_{name}_lc")
        for name in ("title", "status", "legalBasis", "fundingScheme")
    ]
    projects = projects.with_columns(
        pl.col("id").cast(pl.Int64),
        *lowercase_shadows,
    )

    # Cache the table and the distinct filter values on the app state
    app.state.df = projects
    app.state.statuses = projects["_status_lc"].unique().to_list()
    app.state.legal_bases = projects["_legalBasis_lc"].unique().to_list()
    app.state.orgs_list = (
        projects.explode("list_name")["list_name"].unique().to_list()
    )
    app.state.countries_list = (
        projects.explode("list_country")["list_country"].unique().to_list()
    )

    yield  # Application is ready
# ---------------------------------------------------------------------------- #
# App Setup #
# ---------------------------------------------------------------------------- #
# The application instance; `lifespan` fills app.state before requests are served.
app = FastAPI(lifespan=lifespan)

# CORS: allowed origins come from settings, every method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=settings.allowed_origins,
)
# ---------------------------------------------------------------------------- #
# Pydantic Models #
# ---------------------------------------------------------------------------- #
class RAGRequest(BaseModel):
    # Request body of POST /api/rag.

    # Optional conversation ID supplied by the client.
    session_id: Optional[str] = None
    # The user's question, passed to the RAG chain as "question".
    query: str
class RAGResponse(BaseModel):
    # Response body of POST /api/rag.

    # Text generated by the chain.
    answer: str
    # "id" metadata value of each source document returned by the chain.
    source_ids: List[str]
# ---------------------------------------------------------------------------- #
# RAG Endpoint #
# ---------------------------------------------------------------------------- #
def rag_chain_depender(app: FastAPI = Depends(lambda: app)) -> Any:
    """
    Dependency that hands the initialized RAG chain to an endpoint.

    Args:
        app: The FastAPI application whose ``state`` holds the chain.

    Returns:
        The object stored as ``app.state.rag_chain`` during startup.

    Raises:
        HTTPException: 500 if the chain has not been set on the application
            state (or was set to ``None``).
    """
    # Reading a missing attribute on app.state raises AttributeError rather
    # than yielding None, so look it up with a default; otherwise the guard
    # below could never fire for the "not initialized yet" case.
    chain = getattr(app.state, "rag_chain", None)
    if chain is None:
        raise HTTPException(status_code=500, detail="RAG chain not initialized")
    return chain
@app.post("/api/rag", response_model=RAGResponse)
async def ask_rag(
    req: RAGRequest,
    rag_chain=Depends(rag_chain_depender),
):
    """
    Answer a question through the RAG chain.

    Steps:
    1. Invoke the chain asynchronously with ``req.query`` as the question.
    2. Check that the chain returned a dict.
    3. Extract the answer and the ``id`` metadata of each source document.

    ``req.session_id`` is accepted but not used by this endpoint.

    Returns:
        RAGResponse with the answer text and the source document ids.

    Raises:
        HTTPException: 500 carrying the error message when anything fails;
            the full traceback is written to the application log.
    """
    logger.info("Starting to answer RAG query")
    try:
        result = await rag_chain.ainvoke({"question": req.query})
        logger.info("RAG results retrieved")

        # The chain is invoked exactly once: re-running it on the error path
        # would double the latency and push the same question through the
        # chain's memory a second time.
        if not isinstance(result, dict):
            raise ValueError(f"Expected dict from chain, got {type(result)}")

        answer = result.get("answer")
        docs = result.get("source_documents") or []
        # RAGResponse declares source_ids as List[str]; stringify so that
        # non-string ids (e.g. integers — metadata type not visible here)
        # do not fail response validation.
        source_ids = [str(doc.metadata.get("id", "")) for doc in docs]
        return RAGResponse(answer=answer, source_ids=source_ids)
    except Exception as exc:
        # Route the stack trace through the configured logger.
        logger.exception("RAG query failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ---------------------------------------------------------------------------- #
# Data Endpoints #
# ---------------------------------------------------------------------------- #
def _iso_date(value: Any) -> Optional[str]:
    """Format a date-like cell as ``YYYY-MM-DD``; missing values give None."""
    import datetime as _dt

    if not value:
        return None
    # datetime is a subclass of date, so test it first to drop the time part.
    if isinstance(value, _dt.datetime):
        return value.date().isoformat()
    if isinstance(value, _dt.date):
        return value.isoformat()
    return str(value)


@app.get("/api/projects")
def get_projects(
    page: int = 0,
    limit: int = 10,
    search: str = "",
    status: str = "",
    legalBasis: str = "",
    organization: str = "",
    country: str = "",
    fundingScheme: str = "",
    proj_id: str = "",
    topic: str = "",
    sortOrder: str = "desc",
    sortField: str = "startDate",
):
    """
    Paginated project listing with optional filtering and sorting.

    Query Parameters:
    - page: zero-based page index (negative values are treated as 0)
    - limit: number of items per page (negative values are treated as 0)
    - search: case-insensitive substring search in the project title
    - status: exact status, or "UNKNOWN" for projects without a status
    - legalBasis: exact, case-insensitive
    - organization, country, topic: membership in the respective list column
    - fundingScheme: case-insensitive substring match
    - proj_id: exact project ID; a non-numeric value matches nothing
    - sortOrder: 'asc' or 'desc'
    - sortField: column to sort by (unknown names fall back to startDate)

    Returns a list of project dicts, each with an ``explanations`` list
    (feature/SHAP pairs), a ``publications`` count mapping, and ISO-formatted
    start/end dates.
    """
    df: pl.DataFrame = app.state.df
    page = max(page, 0)
    limit = max(limit, 0)
    start = page * limit
    sel = df

    # Text filters use literal matching: the input is a plain substring, and
    # regex metacharacters in it (e.g. "c++", "(") must not be interpreted.
    if search:
        sel = sel.filter(
            pl.col("_title_lc").str.contains(search.lower(), literal=True)
        )
    if status:
        sel = sel.filter(
            pl.col("status").is_null()
            if status == "UNKNOWN"
            else pl.col("_status_lc") == status.lower()
        )
    if legalBasis:
        sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
    if organization:
        sel = sel.filter(pl.col("list_name").list.contains(organization))
    if country:
        sel = sel.filter(pl.col("list_country").list.contains(country))
    if fundingScheme:
        sel = sel.filter(
            pl.col("_fundingScheme_lc").str.contains(
                fundingScheme.lower(), literal=True
            )
        )
    if proj_id:
        try:
            wanted_id = int(proj_id)
        except ValueError:
            # Ids are integers, so a non-numeric filter cannot match any row.
            return []
        sel = sel.filter(pl.col("id") == wanted_id)
    if topic:
        sel = sel.filter(pl.col("list_euroSciVocTitle").list.contains(topic))

    # Columns to return: base fields, six feature/SHAP pairs, the prediction
    base_cols = [
        "id", "title", "status", "startDate", "endDate", "ecMaxContribution",
        "acronym", "legalBasis", "objective", "frameworkProgramme",
        "list_euroSciVocTitle", "list_euroSciVocPath", "totalCost",
        "list_isPublishedAs", "fundingScheme",
    ]
    for i in range(1, 7):
        base_cols += [f"top{i}_feature", f"top{i}_shap"]
    base_cols += ["predicted_label", "predicted_prob"]

    # Sort direction, and a sort column that is guaranteed to exist
    sort_desc = sortOrder.lower() == "desc"
    sortField = sortField if sortField in df.columns else "startDate"

    rows = (
        sel.sort(sortField, descending=sort_desc)
        .slice(start, limit)
        .select(base_cols)
        .to_dicts()
    )

    projects = []
    for row in rows:
        # Fold the flat topN_feature/topN_shap columns into one list
        explanations = []
        for i in range(1, 7):
            feat = row.pop(f"top{i}_feature", None)
            shap = row.pop(f"top{i}_shap", None)
            if feat is not None and shap is not None:
                explanations.append({"feature": feat, "shap": shap})
        row["explanations"] = explanations

        # Count occurrences of each publication entry.
        # NOTE(review): "list_publications" is not among the selected
        # columns, so this is always empty here — confirm whether
        # "list_isPublishedAs" was the intended source.
        raw_pubs = row.pop("list_publications", []) or []
        pub_counts: Dict[str, int] = {}
        for p in raw_pubs:
            pub_counts[p] = pub_counts.get(p, 0) + 1
        row["publications"] = pub_counts

        # Works whether the columns arrive as date or datetime values
        row["startDate"] = _iso_date(row["startDate"])
        row["endDate"] = _iso_date(row["endDate"])
        projects.append(row)
    return projects
@app.get("/api/filters")
def get_filters(request: Request):
    """
    List the filter options that remain available under the current filters.

    Recognised query parameters: status ("UNKNOWN" selects rows without a
    status), legalBasis, organization, country, topic, fundingScheme and
    search. Each one narrows the dataset before the options are collected.

    Returns a JSON object with sorted, de-duplicated lists under the keys
    statuses, legalBases, organizations, countries, fundingSchemes and
    topics; missing values are reported as "UNKNOWN".
    """
    df = app.state.df
    params = request.query_params

    if s := params.get("status"):
        df = df.filter(
            pl.col("status").is_null()
            if s == "UNKNOWN"
            else pl.col("_status_lc") == s.lower()
        )
    if lb := params.get("legalBasis"):
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        df = df.filter(pl.col("list_country").list.contains(c))
    if t := params.get("topic"):
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(t))
    # Substring filters match literally so that regex metacharacters in the
    # user's text are neither interpreted nor able to raise a pattern error.
    if fs := params.get("fundingScheme"):
        df = df.filter(
            pl.col("_fundingScheme_lc").str.contains(fs.lower(), literal=True)
        )
    if search := params.get("search"):
        df = df.filter(
            pl.col("_title_lc").str.contains(search.lower(), literal=True)
        )

    def normalize(vals):
        # Map None to "UNKNOWN" and return a sorted list of unique values
        return sorted({("UNKNOWN" if v is None else v) for v in vals})

    return {
        "statuses": normalize(df["status"].to_list()),
        "legalBases": normalize(df["legalBasis"].to_list()),
        "organizations": normalize(df["list_name"].explode().to_list()),
        "countries": normalize(df["list_country"].explode().to_list()),
        "fundingSchemes": normalize(df["fundingScheme"].to_list()),
        "topics": normalize(df["list_euroSciVocTitle"].explode().to_list()),
    }
def _int_query_param(params: Any, name: str) -> Optional[int]:
    """Read an optional integer query parameter; reject non-numeric text with 400."""
    raw = params.get(name)
    if not raw:
        return None
    try:
        return int(raw)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Query parameter '{name}' must be an integer",
        ) from exc


def _count_by_amount_range(
    df: "pl.DataFrame", column: str, alias: str, labels: List[str]
) -> "pl.DataFrame":
    """Count rows per monetary bucket of ``column``, ordered as in ``labels``."""
    amount = pl.col(column)
    # NOTE(review): a null amount satisfies none of the comparisons and so
    # appears to land in the last ("≥10 M") bucket — confirm this is intended.
    bucket = (
        pl.when(amount < 100_000).then(pl.lit(labels[0]))
        .when(amount < 500_000).then(pl.lit(labels[1]))
        .when(amount < 1_000_000).then(pl.lit(labels[2]))
        .when(amount < 5_000_000).then(pl.lit(labels[3]))
        .when(amount < 10_000_000).then(pl.lit(labels[4]))
        .otherwise(pl.lit(labels[5]))
        .alias(alias)
    )
    return (
        df.with_columns(bucket)
        .group_by(alias)
        .agg(pl.len().alias("count"))
        .with_columns(
            # Map each label to its position so buckets sort smallest-first
            pl.col(alias)
            .replace_strict(labels, list(range(len(labels))))
            .alias("order")
        )
        .sort("order")
    )


@app.get("/api/stats")
def get_stats(request: Request):
    """
    Compute chart data over the (optionally filtered) project table.

    Recognised query parameters: status ("UNKNOWN" selects rows without a
    status), legalBasis, organization, country, topic, fundingScheme,
    minFunding/maxFunding (on ecMaxContribution), minYear/maxYear (on the
    start year) and minEndYear/maxEndYear (on the end year).

    Returns a dict of ``{"labels": [...], "values": [...]}`` series:
    - ppy: projects per start year
    - psd: projects per totalCost bucket
    - frs: ten most frequent funding schemes
    - top10: ten topics with the largest summed EC contribution (millions)
    - frb: projects per ecMaxContribution bucket
    - ppc: ten countries with the most projects

    Raises:
        HTTPException: 400 when a numeric parameter is not an integer.
    """
    params = request.query_params
    df = app.state.df

    # Filters are applied once to the eager frame; the lazy frame used for
    # the yearly series is derived from it afterwards, so both views always
    # agree.
    if s := params.get("status"):
        df = df.filter(
            pl.col("status").is_null()
            if s == "UNKNOWN"
            else pl.col("_status_lc") == s.lower()
        )
    if lb := params.get("legalBasis"):
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        df = df.filter(pl.col("list_country").list.contains(c))
    if eu := params.get("topic"):
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(eu))
    if fs := params.get("fundingScheme"):
        # Literal match: the value is a plain substring, not a regex
        df = df.filter(
            pl.col("_fundingScheme_lc").str.contains(fs.lower(), literal=True)
        )
    if (mn := _int_query_param(params, "minFunding")) is not None:
        df = df.filter(pl.col("ecMaxContribution") >= mn)
    if (mx := _int_query_param(params, "maxFunding")) is not None:
        df = df.filter(pl.col("ecMaxContribution") <= mx)
    if (y1 := _int_query_param(params, "minYear")) is not None:
        df = df.filter(pl.col("startDate").dt.year() >= y1)
    if (y2 := _int_query_param(params, "maxYear")) is not None:
        df = df.filter(pl.col("startDate").dt.year() <= y2)
    if (ye1 := _int_query_param(params, "minEndYear")) is not None:
        df = df.filter(pl.col("endDate").dt.year() >= ye1)
    if (ye2 := _int_query_param(params, "maxEndYear")) is not None:
        df = df.filter(pl.col("endDate").dt.year() <= ye2)
    lf = df.lazy()

    # Helper to drop any pair whose label or value is None
    def clean_data(labels: list, values: list) -> tuple[list, list]:
        pairs = [
            (l, v) for l, v in zip(labels, values)
            if l is not None and v is not None
        ]
        if not pairs:
            return [], []
        lbls, vals = zip(*pairs)
        return list(lbls), list(vals)

    # 1) Projects per Year (Line)
    yearly = (
        lf.select(pl.col("startDate").dt.year().alias("year"))
        .group_by("year")
        .agg(pl.len().alias("count"))
        .sort("year")
        .collect()
    )
    years = yearly["year"].to_list()
    year_counts = yearly["count"].to_list()

    # Fixed bucket order shared by the two range charts
    size_order = ["<100 K", "100 K–500 K", "500 K–1 M", "1 M–5 M", "5 M–10 M", "≥10 M"]

    # 2) Project-Size Distribution (Bar)
    size_buckets = _count_by_amount_range(df, "totalCost", "size_range", size_order)
    size_labels = size_buckets["size_range"].to_list()
    size_counts = size_buckets["count"].to_list()

    # 3) Scheme Frequency (Bar)
    # NOTE(review): fundingScheme is cast to a list column before grouping,
    # so each label is returned as a list — confirm the client expects that.
    scheme_counts_df = (
        df.with_columns(
            pl.col("fundingScheme")
            .cast(pl.List(pl.Utf8))
            .alias("fundingScheme")
        )
        .group_by("fundingScheme")
        .agg(pl.len().alias("count"))
        .sort("count", descending=True)
        .head(10)
    )
    scheme_labels = scheme_counts_df["fundingScheme"].to_list()
    scheme_values = scheme_counts_df["count"].to_list()

    # 4) Top 10 Macro Topics by EC Contribution (Bar), values in millions
    top_topics = (
        df.explode("list_euroSciVocTitle")
        .group_by("list_euroSciVocTitle")
        .agg(pl.col("ecMaxContribution").sum().alias("total_ec"))
        .sort("total_ec", descending=True)
        .head(10)
    )
    topic_labels = top_topics["list_euroSciVocTitle"].to_list()
    topic_values = (top_topics["total_ec"] / 1e6).round(1).to_list()

    # 5) Projects by Funding Range (Pie)
    fund_range = _count_by_amount_range(
        df, "ecMaxContribution", "funding_range", size_order
    )
    fr_labels = fund_range["funding_range"].to_list()
    fr_counts = fund_range["count"].to_list()

    # 6) Projects per Country (Doughnut)
    country = (
        df.explode("list_country")
        .group_by("list_country")
        .agg(pl.len().alias("count"))
        .sort("count", descending=True)
        .head(10)
    )
    country_labels = country["list_country"].to_list()
    country_counts = country["count"].to_list()

    # Clean out any nulls before returning
    years, year_counts = clean_data(years, year_counts)
    size_labels, size_counts = clean_data(size_labels, size_counts)
    scheme_labels, scheme_values = clean_data(scheme_labels, scheme_values)
    topic_labels, topic_values = clean_data(topic_labels, topic_values)
    fr_labels, fr_counts = clean_data(fr_labels, fr_counts)
    country_labels, country_counts = clean_data(country_labels, country_counts)

    return {
        "ppy": {"labels": years, "values": year_counts},
        "psd": {"labels": size_labels, "values": size_counts},
        "frs": {"labels": scheme_labels, "values": scheme_values},
        "top10": {"labels": topic_labels, "values": topic_values},
        "frb": {"labels": fr_labels, "values": fr_counts},
        "ppc": {"labels": country_labels, "values": country_counts},
    }
@app.get("/api/project/{project_id}/organizations")
def get_project_organizations(project_id: str):
    """
    Retrieve the organizations of one project, including geolocation.

    Args:
        project_id: Project identifier taken from the URL path.

    Returns:
        A list of dicts with the keys name, city, sme, role, contribution,
        activityType, orgURL, country, latitude and longitude. Entries
        without a name are omitted; latitude/longitude are None when the
        stored geolocation is missing or cannot be parsed.

    Raises:
        HTTPException: 404 if no project has this ID (a non-numeric ID can
            never match and is reported the same way).
    """
    try:
        wanted_id = int(project_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Project not found") from None

    df = app.state.df
    sel = df.filter(pl.col("id") == wanted_id)
    if sel.is_empty():
        raise HTTPException(status_code=404, detail="Project not found")

    # Explode the parallel per-organization list columns into one row each
    orgs_df = (
        sel.select([
            pl.col("list_name").explode().alias("name"),
            pl.col("list_city").explode().alias("city"),
            pl.col("list_SME").explode().alias("sme"),
            pl.col("list_role").explode().alias("role"),
            pl.col("list_organizationURL").explode().alias("orgURL"),
            pl.col("list_ecContribution").explode().alias("contribution"),
            pl.col("list_activityType").explode().alias("activityType"),
            pl.col("list_country").explode().alias("country"),
            pl.col("list_geolocation").explode().alias("geoloc"),
        ])
        .with_columns([
            # Split the "lat,lon" string into its parts
            pl.col("geoloc").str.split(",").alias("latlon"),
        ])
        .with_columns([
            # A malformed geolocation (missing part, non-numeric text)
            # yields null coordinates instead of failing the whole request.
            pl.col("latlon").list.get(0, null_on_oob=True)
            .cast(pl.Float64, strict=False).alias("latitude"),
            pl.col("latlon").list.get(1, null_on_oob=True)
            .cast(pl.Float64, strict=False).alias("longitude"),
        ])
        .filter(pl.col("name").is_not_null())
        .select([
            "name", "city", "sme", "role", "contribution",
            "activityType", "orgURL", "country", "latitude", "longitude",
        ])
    )

    # Materialize once; the same list is logged and returned
    orgs = orgs_df.to_dicts()
    logger.info(f"Organization data for project {project_id}: {orgs}")
    return orgs