"""FastAPI backend serving project data loaded from a Parquet file on GCS.

Endpoints:
    GET /api/projects                            paginated project listing
    GET /api/filters                             distinct values for filters
    GET /api/stats                               project counts per start year
    GET /api/project/{project_id}/organizations  organizations of one project
"""
import os
from contextlib import asynccontextmanager
from typing import Optional

import gcsfs
import polars as pl
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Access the global FastAPI app state
# from fastapi import current_app
# try:
#     from rag import get_rag_chain, RAGRequest, RAGResponse
# except:
#     from .rag import get_rag_chain, RAGRequest, RAGResponse

# Use the developer's local service-account key only as a fallback, so that
# credentials already supplied by the environment are not overwritten.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    r"C:\Users\Romain\OneDrive - KU Leuven\focal-pager-460414-e9-45369b738be0.json",
)

# Columns returned for each project by GET /api/projects.
_PROJECT_COLUMNS = [
    "id", "title", "status", "startDate", "ecMaxContribution",
    "acronym", "endDate", "legalBasis", "objective",
    "frameworkProgramme", "list_euroSciVocTitle",
    "list_euroSciVocPath",
]

# Columns that get a lower-cased "_<name>_lc" shadow column for
# case-insensitive matching.
_LOWERCASED_COLUMNS = ("title", "status", "legalBasis")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the dataset once at startup and cache it on ``app.state``.

    Sets ``app.state.df`` (the full frame plus the lower-cased helper
    columns) and the distinct-value lists ``statuses``, ``legal_bases``,
    ``orgs_list`` and ``countries_list`` that back ``/api/filters``.
    """
    bucket = "mda_eu_project"
    path = "data/consolidated_clean.parquet"
    uri = f"gs://{bucket}/{path}"

    fs = gcsfs.GCSFileSystem()
    with fs.open(uri, "rb") as f:
        df = pl.read_parquet(f)

    df = df.with_columns(
        [
            pl.col(col).str.to_lowercase().alias(f"_{col}_lc")
            for col in _LOWERCASED_COLUMNS
        ]
    )

    app.state.df = df
    app.state.statuses = df["_status_lc"].unique().to_list()
    app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
    # Explode only the list column of interest rather than the whole frame.
    app.state.orgs_list = df["list_name"].explode().unique().to_list()
    app.state.countries_list = df["list_country"].explode().unique().to_list()

    yield


app = FastAPI(lifespan=lifespan)

# NOTE(review): CORS is fully open (any origin, method and header).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/api/projects")
def get_projects(page: int = 0, limit: int = 10, search: str = "", status: str = ""):
    """Return one page of projects as a list of dicts.

    Args:
        page: Zero-based page index.
        limit: Maximum number of projects per page.
        search: Case-insensitive substring to look for in the title;
            empty means no title filter.
        status: Case-insensitive exact status to match; empty means no
            status filter.

    Raises:
        HTTPException: 400 if ``page`` or ``limit`` is negative.
    """
    if page < 0 or limit < 0:
        raise HTTPException(
            status_code=400, detail="page and limit must be non-negative"
        )

    sel = app.state.df
    if search:
        # literal=True: the search text is user input, not a regex pattern.
        sel = sel.filter(
            pl.col("_title_lc").str.contains(search.lower(), literal=True)
        )
    if status:
        sel = sel.filter(pl.col("_status_lc") == status.lower())

    return sel.slice(page * limit, limit).select(_PROJECT_COLUMNS).to_dicts()


@app.get("/api/filters")
def get_filters():
    """Return the distinct filter values computed at startup."""
    return {
        "statuses": app.state.statuses,
        "legalBases": app.state.legal_bases,
        "organizations": app.state.orgs_list,
        "countries": app.state.countries_list,
    }


def _int_query_param(params, name: str) -> Optional[int]:
    """Read query parameter *name* as an int.

    Returns ``None`` when the parameter is absent or empty.

    Raises:
        HTTPException: 400 if the value is present but not an integer.
    """
    raw = params.get(name)
    if not raw:
        return None
    try:
        return int(raw)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Query parameter '{name}' must be an integer",
        ) from None


@app.get("/api/stats")
def get_stats(request: Request):
    """Return project counts per start year for the filtered dataset.

    Optional query parameters: ``status``, ``legalBasis``, ``organization``,
    ``country``, ``minFunding``, ``maxFunding``, ``minYear``, ``maxYear``.
    Empty values are ignored.

    Raises:
        HTTPException: 400 if a numeric parameter is not an integer.
    """
    params = request.query_params
    lf = app.state.df.lazy()

    if s := params.get("status"):
        lf = lf.filter(pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        lf = lf.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        lf = lf.filter(pl.col("list_country").list.contains(c))

    funding = pl.col("ecMaxContribution")
    start_year = pl.col("startDate").dt.year()

    if (mn := _int_query_param(params, "minFunding")) is not None:
        lf = lf.filter(funding >= mn)
    if (mx := _int_query_param(params, "maxFunding")) is not None:
        lf = lf.filter(funding <= mx)
    if (y1 := _int_query_param(params, "minYear")) is not None:
        lf = lf.filter(start_year >= y1)
    if (y2 := _int_query_param(params, "maxYear")) is not None:
        lf = lf.filter(start_year <= y2)

    grouped = (
        lf.select(start_year.alias("year"))
        .group_by("year")
        .agg(pl.count().alias("count"))
        .sort("year")
        .collect()
    )
    years, counts = grouped["year"].to_list(), grouped["count"].to_list()

    # All six charts currently carry the same per-year series.
    # NOTE(review): presumably "2".."6" are placeholders for future charts.
    titles = ["Projects per Year"] + [
        f"Projects per Year {i}" for i in range(2, 7)
    ]
    return {title: {"labels": years, "values": counts} for title in titles}


@app.get("/api/project/{project_id}/organizations")
def get_project_organizations(project_id: str):
    """Return the organizations of one project with their coordinates.

    Each item has ``name``, ``country``, ``latitude`` and ``longitude``.
    The coordinates are parsed from ``list_geolocation`` entries by
    splitting on "," and taking the first part as latitude and the second
    as longitude.

    Raises:
        HTTPException: 404 if no project has the given id.
    """
    df = app.state.df

    sel = df.filter(pl.col("id") == project_id)
    if sel.is_empty():
        raise HTTPException(status_code=404, detail="Project not found")

    orgs_df = (
        sel
        # The three list columns are exploded side by side, one row per
        # organization.
        .select([
            pl.col("list_name").explode().alias("name"),
            pl.col("list_country").explode().alias("country"),
            pl.col("list_geolocation").explode().alias("geoloc"),
        ])
        .with_columns([
            # now this is a List(Utf8)
            pl.col("geoloc").str.split(",").alias("latlon"),
        ])
        .with_columns([
            pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
            pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
        ])
        .filter(pl.col("name").is_not_null())
        .select(["name", "country", "latitude", "longitude"])
    )

    return orgs_df.to_dicts()


# --- Disabled RAG endpoint, kept for reference ------------------------------
# Kept as comments rather than inside a triple-quoted string: the docstrings
# below contain their own triple quotes, which would end such a string early.
#
# def rag_chain_depender():
#     """
#     Dependency injector for the RAG chain stored in app.state.
#     Raises HTTPException if not initialized.
#     """
#     chain = current_app.state.rag_chain
#     if chain is None:
#         raise HTTPException(status_code=500, detail="RAG chain not initialized")
#     return chain
#
# @app.post("/rag", response_model=RAGResponse)
# async def ask_rag(
#     req: RAGRequest,
#     rag_chain = Depends(rag_chain_depender)
# ):
#     """
#     Handle a RAG query. Uses session memory and the provided RAG chain.
#     """
#     # Invoke the chain with the named input
#     result = await rag_chain.ainvoke({"question": req.query})
#     return RAGResponse(answer=result["answer"])