import logging
import traceback
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional

import gcsfs
import polars as pl
import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic_settings import BaseSettings as SettingsBase
from sentence_transformers import CrossEncoder
from transformers import (  # Transformers for the LLM pipeline
    MT5ForConditionalGeneration,
    MT5TokenizerFast,
    pipeline,
)

# LangChain imports for RAG
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline

# Project-specific imports
from app.rag import build_indexes, HybridRetriever

# ---------------------------------------------------------------------------- #
# Settings                                                                     #
# ---------------------------------------------------------------------------- #
class Settings(SettingsBase):
    """Configuration settings loaded from the environment or a .env file."""

    # Data sources
    parquet_path: str = "gs://mda_kul_project/data/consolidated_clean_pred.parquet"
    whoosh_dir: str = "gs://mda_kul_project/whoosh_index"
    vectorstore_path: str = "gs://mda_kul_project/vectorstore_index"
    # Model names
    embedding_model: str = "sentence-transformers/LaBSE"
    llm_model: str = "google/mt5-base"
    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
    # RAG parameters
    chunk_size: int = 750
    chunk_overlap: int = 100
    hybrid_k: int = 4
    assistant_role: str = (
        "You are a knowledgeable project analyst. You have access to the "
        "following retrieved document snippets."
    )
    skip_warmup: bool = True
    # CORS
    allowed_origins: List[str] = ["*"]

    class Config:
        env_file = ".env"


# Instantiate settings and logger
settings = Settings()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
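# Example .env overriding the defaults above (illustrative values, not from the
# original source; pydantic-settings maps field names to environment variables,
# and list-valued fields are parsed as JSON):
#
#   PARQUET_PATH=gs://my-bucket/data/projects.parquet
#   LLM_MODEL=google/mt5-small
#   SKIP_WARMUP=false
#   ALLOWED_ORIGINS=["https://example.org"]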
# Pre-instantiate the embedding model for reuse
EMBEDDING = HuggingFaceEmbeddings(
    model_name=settings.embedding_model,
    model_kwargs={"trust_remote_code": True},
)


@lru_cache(maxsize=1024)  # maxsize is an assumption; the source imported lru_cache but never applied it
def embed_query_cached(query: str) -> List[float]:
    """Cache embedding vectors for repeated queries (normalized to lowercase)."""
    return EMBEDDING.embed_query(query.strip().lower())
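# Usage sketch (illustrative): with the cache in place, repeated identical
# queries are served from memory instead of re-running the encoder:
#
#   embed_query_cached("solar energy")   # miss: runs the embedding model
#   embed_query_cached("solar energy")   # hit: returns the cached vector
#   embed_query_cached.cache_info()      # e.g. CacheInfo(hits=1, misses=1, ...)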
# ---------------------------------------------------------------------------- #
# Application Lifespan                                                         #
# ---------------------------------------------------------------------------- #
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Startup: initialize RAG chain, embeddings, memory, indexes, and load data.
    Shutdown: clean up resources if needed.
    """
    # 1) Initialize document compressor
    logger.info("Initializing Document Compressor")
    compressor = DocumentCompressorPipeline(
        transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
    )
    # 2) Initialize and quantize Cross-Encoder
    logger.info("Initializing Cross-Encoder")
    cross_encoder = CrossEncoder(settings.cross_encoder_model)
    cross_encoder.model = torch.quantization.quantize_dynamic(
        cross_encoder.model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    logger.info("Cross-Encoder quantized")
    # 3) Build Seq2Seq pipeline and wrap it for LangChain
    logger.info("Initializing LLM pipeline")
    tokenizer = MT5TokenizerFast.from_pretrained(settings.llm_model)
    model = MT5ForConditionalGeneration.from_pretrained(settings.llm_model)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    # Assemble the generation pipeline (device=-1 keeps it on CPU)
    gen_pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=-1,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.8,
        top_k=30,
        top_p=0.95,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        use_cache=True,
    )
    llm = HuggingFacePipeline(pipeline=gen_pipe)
    # 4) Initialize conversation memory (keeps the last k=5 exchanges)
    logger.info("Initializing Conversation Memory")
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
        k=5,
    )
    # 5) Build or load indexes for the vectorstore and Whoosh
    logger.info("Building or loading indexes")
    vs, ix = await build_indexes(
        settings.parquet_path,
        settings.vectorstore_path,
        settings.whoosh_dir,
        settings.chunk_size,
        settings.chunk_overlap,
        None,
    )
    retriever = HybridRetriever(
        vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder
    )
    # 6) Define prompt template for the RAG chain
    prompt = PromptTemplate.from_template(
        f"{settings.assistant_role}\n"
        "{context}\n"
        "User Question:\n{question}\n"
        "Answer:"  # Rules are embedded in assistant_role
    )
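    # For illustration (assumed values): once the chain fills in the template,
    # the prompt handed to the LLM reads roughly:
    #
    #     You are a knowledgeable project analyst. You have access to the
    #     following retrieved document snippets.
    #     <retrieved snippet 1>
    #     <retrieved snippet 2>
    #     User Question:
    #     Which projects study battery recycling?
    #     Answer: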
    # 7) Instantiate the conversational retrieval chain
    logger.info("Initializing Retrieval Chain")
    app.state.rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    # Optional warmup
    if not settings.skip_warmup:
        logger.info("Warming up RAG chain")
        await app.state.rag_chain.ainvoke({"question": "warmup"})
    # 8) Load project data into a Polars DataFrame
    logger.info("Loading Parquet data from GCS")
    fs = gcsfs.GCSFileSystem()
    with fs.open(settings.parquet_path, "rb") as f:
        df = pl.read_parquet(f)
    # Cast id to integer and lowercase key columns for filtering
    df = df.with_columns(
        pl.col("id").cast(pl.Int64),
        *(
            pl.col(col).str.to_lowercase().alias(f"_{col}_lc")
            for col in ["title", "status", "legalBasis", "fundingScheme"]
        ),
    )
    # Cache the DataFrame and filter values in app state
    app.state.df = df
    app.state.statuses = df["_status_lc"].unique().to_list()
    app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
    app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
    app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()

    yield  # Application is ready
# ---------------------------------------------------------------------------- #
# App Setup                                                                    #
# ---------------------------------------------------------------------------- #
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------- #
# Pydantic Models                                                              #
# ---------------------------------------------------------------------------- #
class RAGRequest(BaseModel):
    session_id: Optional[str] = None  # Optional conversation ID
    query: str  # User's query text


class RAGResponse(BaseModel):
    answer: str
    source_ids: List[str]
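# Illustrative request/response shapes for the endpoint below (values are
# placeholders, and the /rag path is an assumption noted on the decorator):
#
#   POST /rag   {"session_id": "abc123", "query": "Which projects study battery recycling?"}
#   200 OK      {"answer": "...", "source_ids": ["<doc-id>", "..."]}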
# ---------------------------------------------------------------------------- #
# RAG Endpoint                                                                 #
# ---------------------------------------------------------------------------- #
def rag_chain_depender(request: Request) -> Any:
    """
    Dependency injector that retrieves the initialized RAG chain from the
    application state. Raises HTTPException if the chain is not yet initialized.
    """
    chain = getattr(request.app.state, "rag_chain", None)
    if chain is None:
        # If the chain isn't set up, respond with a 500 server error
        raise HTTPException(status_code=500, detail="RAG chain not initialized")
    return chain


@app.post("/rag", response_model=RAGResponse)  # route path assumed; no decorator in the source
async def ask_rag(
    req: RAGRequest,
    rag_chain=Depends(rag_chain_depender),
):
    """
    Endpoint to process a RAG-based query.
    1. Logs start of processing.
    2. Invokes the RAG chain asynchronously with the user question.
    3. Validates the returned result structure and extracts answer + source IDs.
    4. Handles any exceptions by logging the traceback and returning a JSON error.
    """
    logger.info("Starting to answer RAG query")
    try:
        # Asynchronously invoke the chain to get answer + docs
        result = await rag_chain.ainvoke({"question": req.query})
        logger.info("RAG results retrieved")
        # Validate that the chain returned the expected dict
        if not isinstance(result, dict):
            # Retry via the legacy acall() API purely to aid debugging
            result2 = await rag_chain.acall({"question": req.query})
            raise ValueError(
                f"Expected dict from chain, got {type(result)}; "
                f"acall() returned {type(result2)}"
            )
        # Extract answer text and source document IDs
        answer = result.get("answer")
        docs = result.get("source_documents", [])
        sources = [d.metadata.get("id", "") for d in docs]
        return RAGResponse(answer=answer, source_ids=sources)
    except Exception as e:
        # Log the full stacktrace to the container logs
        traceback.print_exc()
        # Return HTTP 500 with error detail
        raise HTTPException(status_code=500, detail=str(e))
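# Minimal smoke test for ask_rag, assuming the "/rag" route registered above.
# Illustrative only; TestClient ships with FastAPI, and entering the context
# manager runs the lifespan, so the full model stack is loaded first.
def _example_rag_call() -> None:
    from fastapi.testclient import TestClient

    with TestClient(app) as client:  # context manager triggers lifespan startup
        resp = client.post("/rag", json={"query": "warmup"})
        print(resp.status_code, resp.json())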
# ---------------------------------------------------------------------------- #
# Data Endpoints                                                               #
# ---------------------------------------------------------------------------- #
@app.get("/projects")  # route path assumed; no decorator in the source
def get_projects(
    page: int = 0,
    limit: int = 10,
    search: str = "",
    status: str = "",
    legalBasis: str = "",
    organization: str = "",
    country: str = "",
    fundingScheme: str = "",
    proj_id: str = "",
    topic: str = "",
    sortOrder: str = "desc",
    sortField: str = "startDate",
):
    """
    Paginated project listing with optional filtering and sorting.

    Query Parameters:
    - page: zero-based page index
    - limit: number of items per page
    - search: substring search in project title
    - status, legalBasis, organization, country, fundingScheme: filters
    - proj_id: exact project ID filter
    - topic: filter by EuroSciVoc topic
    - sortOrder: 'asc' or 'desc'
    - sortField: field name to sort by (falls back to startDate)

    Returns a list of project dicts including explanations and publication counts.
    """
    df: pl.DataFrame = app.state.df
    start = page * limit
    sel = df
    # Apply text and field filters as needed
    if search:
        sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
    if status:
        sel = sel.filter(
            pl.col("status").is_null() if status == "UNKNOWN"
            else pl.col("_status_lc") == status.lower()
        )
    if legalBasis:
        sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
    if organization:
        sel = sel.filter(pl.col("list_name").list.contains(organization))
    if country:
        sel = sel.filter(pl.col("list_country").list.contains(country))
    if fundingScheme:
        sel = sel.filter(pl.col("_fundingScheme_lc").str.contains(fundingScheme.lower()))
    if proj_id:
        sel = sel.filter(pl.col("id") == int(proj_id))
    if topic:
        sel = sel.filter(pl.col("list_euroSciVocTitle").list.contains(topic))
    # Base columns to return
    base_cols = [
        "id", "title", "status", "startDate", "endDate", "ecMaxContribution", "acronym",
        "legalBasis", "objective", "frameworkProgramme", "list_euroSciVocTitle",
        "list_euroSciVocPath", "totalCost", "list_isPublishedAs", "fundingScheme",
    ]
    # Append top feature & SHAP value columns
    for i in range(1, 7):
        base_cols += [f"top{i}_feature", f"top{i}_shap"]
    base_cols += ["predicted_label", "predicted_prob"]
    # Determine sort direction and a safe sort field
    sort_desc = sortOrder.lower() == "desc"
    sortField = sortField if sortField in df.columns else "startDate"
    # Query, sort, slice, and collect to Python dicts
    rows = (
        sel.sort(sortField, descending=sort_desc)
        .slice(start, limit)
        .select(base_cols)
        .to_dicts()
    )
    projects = []
    for row in rows:
        # Reformat SHAP explanations into a list of dicts
        explanations = []
        for i in range(1, 7):
            feat = row.pop(f"top{i}_feature", None)
            shap = row.pop(f"top{i}_shap", None)
            if feat is not None and shap is not None:
                explanations.append({"feature": feat, "shap": shap})
        row["explanations"] = explanations
        # Aggregate publication counts (the source popped "list_publications",
        # which is never selected, so the count was always empty;
        # "list_isPublishedAs" is the column actually selected above)
        raw_pubs = row.pop("list_isPublishedAs", []) or []
        pub_counts: Dict[str, int] = {}
        for p in raw_pubs:
            pub_counts[p] = pub_counts.get(p, 0) + 1
        row["publications"] = pub_counts
        row["startDate"] = row["startDate"].date().isoformat() if row["startDate"] else None
        row["endDate"] = row["endDate"].date().isoformat() if row["endDate"] else None
        projects.append(row)
    return projects
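# Example listing call, assuming the /projects route above (illustrative
# parameter values only):
#
#   GET /projects?page=0&limit=10&country=BE&sortField=startDate&sortOrder=asc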
@app.get("/filters")  # route path assumed; no decorator in the source
def get_filters(request: Request):
    """
    Retrieve available filter options based on the current dataset and optional
    query filters. Returns JSON with lists of statuses, legalBases,
    organizations, countries, fundingSchemes, and topics.
    """
    df = app.state.df
    params = request.query_params
    # Dynamically filter df based on the provided params
    if s := params.get("status"):
        df = df.filter(
            pl.col("status").is_null() if s == "UNKNOWN"
            else pl.col("_status_lc") == s.lower()
        )
    if lb := params.get("legalBasis"):
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        df = df.filter(pl.col("list_country").list.contains(c))
    if t := params.get("topic"):
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(t))
    if fs := params.get("fundingScheme"):
        df = df.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
    if search := params.get("search"):
        df = df.filter(pl.col("_title_lc").str.contains(search.lower()))

    def normalize(vals):
        # Map None to "UNKNOWN" and return a sorted unique list
        return sorted({("UNKNOWN" if v is None else v) for v in vals})

    return {
        "statuses": normalize(df["status"].to_list()),
        "legalBases": normalize(df["legalBasis"].to_list()),
        "organizations": normalize(df["list_name"].explode().to_list()),
        "countries": normalize(df["list_country"].explode().to_list()),
        "fundingSchemes": normalize(df["fundingScheme"].to_list()),
        "topics": normalize(df["list_euroSciVocTitle"].explode().to_list()),
    }
@app.get("/stats")  # route path assumed; no decorator in the source
def get_stats(request: Request):
    """
    Compute various statistics on projects with optional filters for status,
    legal basis, funding, start/end years, etc. Returns a dict of chart data.
    """
    df = app.state.df
    lf = df.lazy()
    params = request.query_params
    # Apply filters to both the lazy frame (used for the yearly series)
    # and the eager frame (used for the remaining aggregations)
    if s := params.get("status"):
        lf = lf.filter(pl.col("_status_lc") == s.lower())
        df = df.filter(pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        lf = lf.filter(pl.col("list_name").list.contains(org))
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        lf = lf.filter(pl.col("list_country").list.contains(c))
        df = df.filter(pl.col("list_country").list.contains(c))
    if eu := params.get("topic"):
        lf = lf.filter(pl.col("list_euroSciVocTitle").list.contains(eu))
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(eu))
    if fs := params.get("fundingScheme"):
        lf = lf.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
        df = df.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
    if mn := params.get("minFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") >= int(mn))
        df = df.filter(pl.col("ecMaxContribution") >= int(mn))
    if mx := params.get("maxFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") <= int(mx))
        df = df.filter(pl.col("ecMaxContribution") <= int(mx))
    if y1 := params.get("minYear"):
        lf = lf.filter(pl.col("startDate").dt.year() >= int(y1))
        df = df.filter(pl.col("startDate").dt.year() >= int(y1))
    if y2 := params.get("maxYear"):
        lf = lf.filter(pl.col("startDate").dt.year() <= int(y2))
        df = df.filter(pl.col("startDate").dt.year() <= int(y2))
    if ye1 := params.get("minEndYear"):
        lf = lf.filter(pl.col("endDate").dt.year() >= int(ye1))
        df = df.filter(pl.col("endDate").dt.year() >= int(ye1))
    if ye2 := params.get("maxEndYear"):
        lf = lf.filter(pl.col("endDate").dt.year() <= int(ye2))
        df = df.filter(pl.col("endDate").dt.year() <= int(ye2))

    # Helper to drop any None/null entries pairwise,
    # e.g. clean_data([2019, None], [5, 7]) -> ([2019], [5])
    def clean_data(labels: list, values: list) -> tuple[list, list]:
        pairs = [(l, v) for l, v in zip(labels, values) if l is not None and v is not None]
        if not pairs:
            return [], []
        lbls, vals = zip(*pairs)
        return list(lbls), list(vals)
    # 1) Projects per Year (Line)
    yearly = (
        lf.select(pl.col("startDate").dt.year().alias("year"))
        .group_by("year")
        .agg(pl.count().alias("count"))
        .sort("year")
        .collect()
    )
    years = yearly["year"].to_list()
    year_counts = yearly["count"].to_list()
    # Fixed bucket order shared by the size and funding-range charts
    size_order = ["<100 K", "100 K–500 K", "500 K–1 M", "1 M–5 M", "5 M–10 M", "≥10 M"]
    # 2) Project-Size Distribution (Bar)
    size_buckets = (
        df.with_columns(
            pl.when(pl.col("totalCost") < 100_000).then(pl.lit("<100 K"))
            .when(pl.col("totalCost") < 500_000).then(pl.lit("100 K–500 K"))
            .when(pl.col("totalCost") < 1_000_000).then(pl.lit("500 K–1 M"))
            .when(pl.col("totalCost") < 5_000_000).then(pl.lit("1 M–5 M"))
            .when(pl.col("totalCost") < 10_000_000).then(pl.lit("5 M–10 M"))
            .otherwise(pl.lit("≥10 M"))
            .alias("size_range")
        )
        .group_by("size_range")
        .agg(pl.count().alias("count"))
        .with_columns(
            pl.col("size_range")
            .replace_strict(size_order, list(range(len(size_order))))
            .alias("order")
        )
        .sort("order")
    )
    size_labels = size_buckets["size_range"].to_list()
    size_counts = size_buckets["count"].to_list()
    # 3) Scheme Frequency (Bar)
    scheme_counts_df = (
        df.with_columns(
            pl.col("fundingScheme")
            .cast(pl.List(pl.Utf8))
            .alias("fundingScheme")
        )
        .group_by("fundingScheme")
        .agg(pl.count().alias("count"))
        .sort("count", descending=True)
        .head(10)
    )
    scheme_labels = scheme_counts_df["fundingScheme"].to_list()
    scheme_values = scheme_counts_df["count"].to_list()
    # 4) Top 10 Macro Topics by EC Contribution (Bar), in millions of EUR
    top_topics = (
        df.explode("list_euroSciVocTitle")
        .group_by("list_euroSciVocTitle")
        .agg(pl.col("ecMaxContribution").sum().alias("total_ec"))
        .sort("total_ec", descending=True)
        .head(10)
    )
    topic_labels = top_topics["list_euroSciVocTitle"].to_list()
    topic_values = (top_topics["total_ec"] / 1e6).round(1).to_list()
    # 5) Projects by Funding Range (Pie)
    fund_range = (
        df.with_columns(
            pl.when(pl.col("ecMaxContribution") < 100_000).then(pl.lit("<100 K"))
            .when(pl.col("ecMaxContribution") < 500_000).then(pl.lit("100 K–500 K"))
            .when(pl.col("ecMaxContribution") < 1_000_000).then(pl.lit("500 K–1 M"))
            .when(pl.col("ecMaxContribution") < 5_000_000).then(pl.lit("1 M–5 M"))
            .when(pl.col("ecMaxContribution") < 10_000_000).then(pl.lit("5 M–10 M"))
            .otherwise(pl.lit("≥10 M"))
            .alias("funding_range")
        )
        .group_by("funding_range")
        .agg(pl.count().alias("count"))
        .with_columns(
            pl.col("funding_range")
            .replace_strict(size_order, list(range(len(size_order))))
            .alias("order")
        )
        .sort("order")
    )
    fr_labels = fund_range["funding_range"].to_list()
    fr_counts = fund_range["count"].to_list()
    # 6) Projects per Country (Doughnut)
    country = (
        df.explode("list_country")
        .group_by("list_country")
        .agg(pl.count().alias("count"))
        .sort("count", descending=True)
        .head(10)
    )
    country_labels = country["list_country"].to_list()
    country_counts = country["count"].to_list()
    # Clean out any nulls before returning
    years, year_counts = clean_data(years, year_counts)
    size_labels, size_counts = clean_data(size_labels, size_counts)
    scheme_labels, scheme_values = clean_data(scheme_labels, scheme_values)
    topic_labels, topic_values = clean_data(topic_labels, topic_values)
    fr_labels, fr_counts = clean_data(fr_labels, fr_counts)
    country_labels, country_counts = clean_data(country_labels, country_counts)
    return {
        "ppy": {"labels": years, "values": year_counts},
        "psd": {"labels": size_labels, "values": size_counts},
        "frs": {"labels": scheme_labels, "values": scheme_values},
        "top10": {"labels": topic_labels, "values": topic_values},
        "frb": {"labels": fr_labels, "values": fr_counts},
        "ppc": {"labels": country_labels, "values": country_counts},
    }
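# Shape of the stats payload returned above (keys from the code; values are
# illustrative placeholders):
#
#   {"ppy": {"labels": [2015, ...], "values": [...]},
#    "psd": {"labels": ["<100 K", ...], "values": [...]},
#    ...,
#    "ppc": {"labels": ["DE", ...], "values": [...]}}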
@app.get("/projects/{project_id}/organizations")  # route path assumed; no decorator in the source
def get_project_organizations(project_id: str):
    """
    Retrieve organization details for a given project ID, including geolocation.
    Raises 404 if the project ID does not exist.
    """
    df = app.state.df
    sel = df.filter(pl.col("id") == int(project_id))
    if sel.is_empty():
        raise HTTPException(status_code=404, detail="Project not found")
    # Explode list columns and parse latitude/longitude
    orgs_df = (
        sel.select([
            pl.col("list_name").explode().alias("name"),
            pl.col("list_city").explode().alias("city"),
            pl.col("list_SME").explode().alias("sme"),
            pl.col("list_role").explode().alias("role"),
            pl.col("list_organizationURL").explode().alias("orgURL"),
            pl.col("list_ecContribution").explode().alias("contribution"),
            pl.col("list_activityType").explode().alias("activityType"),
            pl.col("list_country").explode().alias("country"),
            pl.col("list_geolocation").explode().alias("geoloc"),
        ])
        .with_columns([
            # Split the "lat,lon" string into a two-element list
            pl.col("geoloc").str.split(",").alias("latlon"),
        ])
        .with_columns([
            # Cast to floats for numeric use
            pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
            pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
        ])
        .filter(pl.col("name").is_not_null())
        .select([
            "name", "city", "sme", "role", "contribution",
            "activityType", "orgURL", "country", "latitude", "longitude",
        ])
    )
    logger.info(f"Organization data for project {project_id}: {orgs_df.to_dicts()}")
    return orgs_df.to_dicts()
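if __name__ == "__main__":
    # Local dev entry point (an assumption; not present in the original file.
    # On Hugging Face Spaces the server is usually launched by the container
    # command instead). Port 7860 is the conventional Spaces port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)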