import logging
import traceback
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional

import gcsfs
import polars as pl
import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic_settings import BaseSettings as SettingsBase
from sentence_transformers import CrossEncoder
from transformers import (  # Transformers for the LLM pipeline
    MT5ForConditionalGeneration,
    MT5TokenizerFast,
    pipeline,
)

# LangChain imports for RAG
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline

# Project-specific imports
from app.rag import build_indexes, HybridRetriever

# ---------------------------------------------------------------------------- #
#                                   Settings                                   #
# ---------------------------------------------------------------------------- #

class Settings(SettingsBase):
    """
    Configuration settings loaded from environment or .env file.
    """
    # Data sources
    parquet_path: str = "gs://mda_kul_project/data/consolidated_clean_pred.parquet"
    whoosh_dir: str = "gs://mda_kul_project/whoosh_index"
    vectorstore_path: str = "gs://mda_kul_project/vectorstore_index"

    # Model names
    embedding_model: str = "sentence-transformers/LaBSE"
    llm_model: str = "google/mt5-base"
    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"

    # RAG parameters
    chunk_size: int = 750
    chunk_overlap: int = 100
    hybrid_k: int = 4
    assistant_role: str = (
        "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
    )
    skip_warmup: bool = True

    # CORS
    allowed_origins: List[str] = ["*"]

    class Config:
        env_file = ".env"
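
# Example .env overriding the defaults above (illustrative values only;
# pydantic-settings matches field names case-insensitively and parses
# list-valued fields such as allowed_origins as JSON):
#
#   LLM_MODEL=google/mt5-small
#   SKIP_WARMUP=false
#   ALLOWED_ORIGINS=["https://dashboard.example.com"]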

# Instantiate settings and logger
settings = Settings()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pre-instantiate the embedding model for reuse
EMBEDDING = HuggingFaceEmbeddings(
    model_name=settings.embedding_model,
    model_kwargs={"trust_remote_code": True},
)

@lru_cache(maxsize=256)
def _embed_normalized(query: str) -> List[float]:
    return EMBEDDING.embed_query(query)

def embed_query_cached(query: str) -> List[float]:
    """Cache embeddings for repeated queries, normalizing before the lookup."""
    return _embed_normalized(query.strip().lower())
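
# Illustration: "Climate" and "  climate " normalize to the same cache key,
# so the second lookup is a hit; _embed_normalized.cache_info() reports the
# hit/miss counts.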

# ---------------------------------------------------------------------------- #
#                              Application Lifespan                            #
# ---------------------------------------------------------------------------- #

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    Startup: initialize RAG chain, embeddings, memory, indexes, and load data.
    Shutdown: clean up resources if needed.
    """
    # 1) Initialize document compressor
    logger.info("Initializing Document Compressor")
    compressor = DocumentCompressorPipeline(
        transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
    )

    # 2) Initialize and quantize Cross-Encoder
    logger.info("Initializing Cross-Encoder")
    cross_encoder = CrossEncoder(settings.cross_encoder_model)
    cross_encoder.model = torch.quantization.quantize_dynamic(
        cross_encoder.model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    logger.info("Cross-Encoder quantized")

    # 3) Build Seq2Seq pipeline and wrap in LangChain
    logger.info("Initializing LLM pipeline")
    tokenizer = MT5TokenizerFast.from_pretrained(settings.llm_model)
    model = MT5ForConditionalGeneration.from_pretrained(settings.llm_model)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    # assemble the pipeline
    gen_pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=-1,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.8,
        top_k=30,
        top_p=0.95,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        use_cache=True,
    )

    llm = HuggingFacePipeline(pipeline=gen_pipe)
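
    # A quick probe of the raw pipeline (illustrative, not executed at startup):
    # gen_pipe("Summarize: EC funding for offshore wind rose in 2021")
    # returns a list like [{"generated_text": "..."}].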

    # 4) Initialize conversation memory
    logger.info("Initializing Conversation Memory")
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
        k=5,
    )

    # 5) Build or load indexes for vectorstore and Whoosh
    logger.info("Building or loading indexes")
    vs, ix = await build_indexes(
        settings.parquet_path,
        settings.vectorstore_path,
        settings.whoosh_dir,
        settings.chunk_size,
        settings.chunk_overlap,
        None,
    )
    retriever = HybridRetriever(vs=vs, ix=ix, compressor=compressor, cross_encoder=cross_encoder)

    # 6) Define prompt template for RAG chain
    prompt = PromptTemplate.from_template(
        f"{settings.assistant_role}\n"  
        "{context}\n"
        "User Question:\n{question}\n"
        "Answer:"  # Rules are embedded in assistant_role
    )

    # 7) Instantiate the conversational retrieval chain
    logger.info("Initializing Retrieval Chain")
    app.state.rag_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )

    # Optional warmup
    if not settings.skip_warmup:
        logger.info("Warming up RAG chain")
        await app.state.rag_chain.ainvoke({"question": "warmup"})

    # 8) Load project data into Polars DataFrame
    logger.info("Loading Parquet data from GCS")
    fs = gcsfs.GCSFileSystem()
    with fs.open(settings.parquet_path, "rb") as f:
        df = pl.read_parquet(f)
    # Cast id to integer and lowercase key columns for filtering
    df = df.with_columns(
        pl.col("id").cast(pl.Int64),
        *(pl.col(col).str.to_lowercase().alias(f"_{col}_lc") for col in [
            "title", "status", "legalBasis", "fundingScheme"
        ])
    )

    # Cache DataFrame and filter values in app state
    app.state.df = df
    app.state.statuses = df["_status_lc"].unique().to_list()
    app.state.legal_bases = df["_legalBasis_lc"].unique().to_list()
    app.state.orgs_list = df.explode("list_name")["list_name"].unique().to_list()
    app.state.countries_list = df.explode("list_country")["list_country"].unique().to_list()

    yield  # Application is ready
    
# ---------------------------------------------------------------------------- #
#                                   App Setup                                   #
# ---------------------------------------------------------------------------- #
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
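
# Note: allow_origins defaults to ["*"]; a deployment would typically narrow
# it via the ALLOWED_ORIGINS variable shown in the example .env above.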

# ---------------------------------------------------------------------------- #
#                                 Pydantic Models                             #
# ---------------------------------------------------------------------------- #
class RAGRequest(BaseModel):
    session_id: Optional[str] = None  # Optional conversation ID
    query: str  # User's query text

class RAGResponse(BaseModel):
    answer: str
    source_ids: List[str]
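
# Example exchange (illustrative values):
#   request:  {"session_id": "abc123", "query": "Which projects study offshore wind?"}
#   response: {"answer": "...", "source_ids": ["190134", "202877"]}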

# ---------------------------------------------------------------------------- #
#                                 RAG Endpoint                                  #
# ---------------------------------------------------------------------------- #
def rag_chain_depender(request: Request) -> Any:
    """
    Dependency injector that retrieves the initialized RAG chain from the
    application state. Raises HTTPException if the chain is not yet initialized.
    """
    chain = getattr(request.app.state, "rag_chain", None)
    if chain is None:
        # If the chain isn't set up yet, respond with a 500 server error
        raise HTTPException(status_code=500, detail="RAG chain not initialized")
    return chain

@app.post("/api/rag", response_model=RAGResponse)
async def ask_rag(
    req: RAGRequest,
    rag_chain = Depends(rag_chain_depender)
):
    """
    Endpoint to process a RAG-based query.

    1. Logs start of processing.
    2. Invokes the RAG chain asynchronously with the user question.
    3. Validates returned result structure and extracts answer + source IDs.
    4. Handles any exceptions by logging traceback and returning a JSON error.
    """
    logger.info("Starting to answer RAG query")
    try:
        # Asynchronously invoke the chain to get answer + docs
        result = await rag_chain.ainvoke({"question": req.query})
        logger.info("RAG results retrieved")

        # Validate that the chain returned the expected dict
        if not isinstance(result, dict):
            raise ValueError(f"Expected dict from chain, got {type(result)}")

        # Extract answer text and source document IDs
        answer = result.get("answer", "")
        docs   = result.get("source_documents", [])
        sources = [d.metadata.get("id", "") for d in docs]

        return RAGResponse(answer=answer, source_ids=sources)

    except Exception as e:
        # Log full stacktrace to container logs
        traceback.print_exc()
        # Return HTTP 500 with error detail
        raise HTTPException(status_code=500, detail=str(e))
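
# Example client call, assuming the API is served locally on port 8000:
#   curl -X POST http://localhost:8000/api/rag \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Which projects target battery recycling?"}'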
    

# ---------------------------------------------------------------------------- #
#                                  Data Endpoints                               #
# ---------------------------------------------------------------------------- #
@app.get("/api/projects")
def get_projects(
    page: int = 0,
    limit: int = 10,
    search: str = "",
    status: str = "",
    legalBasis: str = "",
    organization: str = "",
    country: str = "",
    fundingScheme: str = "",
    proj_id: str = "",
    topic: str = "",
    sortOrder: str = "desc",
    sortField: str = "startDate",
):
    """
    Paginated project listing with optional filtering and sorting.

    Query Parameters:
    - page: zero-based page index
    - limit: number of items per page
    - search: substring search in project title
    - status, legalBasis, organization, country, fundingScheme: filters
    - proj_id: exact project ID filter
    - topic: filter by EuroSciVoc topic
    - sortOrder: 'asc' or 'desc'
    - sortField: field name to sort by (fallback to startDate)

    Returns a list of project dicts including explanations and publication counts.
    """
    df: pl.DataFrame = app.state.df
    start = page * limit
    sel = df

    # Apply text and field filters as needed
    if search:
        sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
    if status:
        sel = sel.filter(
            pl.col("status").is_null() if status == "UNKNOWN"
            else pl.col("_status_lc") == status.lower()
        )
    if legalBasis:
        sel = sel.filter(pl.col("_legalBasis_lc") == legalBasis.lower())
    if organization:
        sel = sel.filter(pl.col("list_name").list.contains(organization))
    if country:
        sel = sel.filter(pl.col("list_country").list.contains(country))
    if fundingScheme:
        sel = sel.filter(pl.col("_fundingScheme_lc").str.contains(fundingScheme.lower()))
    if proj_id:
        try:
            sel = sel.filter(pl.col("id") == int(proj_id))
        except ValueError:
            raise HTTPException(status_code=400, detail="proj_id must be an integer")
    if topic:
        sel = sel.filter(pl.col("list_euroSciVocTitle").list.contains(topic))

    # Base columns to return
    base_cols = [
        "id","title","status","startDate","endDate","ecMaxContribution","acronym",
        "legalBasis","objective","frameworkProgramme","list_euroSciVocTitle",
        "list_euroSciVocPath","totalCost","list_isPublishedAs","fundingScheme"
    ]
    # Append top feature & SHAP value columns
    for i in range(1,7):
        base_cols += [f"top{i}_feature", f"top{i}_shap"]
    base_cols += ["predicted_label","predicted_prob"]

    # Determine sort direction and safe field
    sort_desc = sortOrder.lower() == "desc"
    sortField = sortField if sortField in df.columns else "startDate"

    # Query, sort, slice, and collect to Python dicts
    rows = (
        sel.sort(sortField, descending=sort_desc)
           .slice(start, limit)
           .select(base_cols)
           .to_dicts()
    )

    projects = []
    for row in rows:
        # Reformat SHAP explanations into list of dicts
        explanations = []
        for i in range(1,7):
            feat = row.pop(f"top{i}_feature", None)
            shap = row.pop(f"top{i}_shap", None)
            if feat is not None and shap is not None:
                explanations.append({"feature": feat, "shap": shap})
        row["explanations"] = explanations

        # Aggregate publication counts by type. Note: base_cols selects
        # "list_isPublishedAs", so that is the list column present in each row.
        raw_pubs = row.pop("list_isPublishedAs", []) or []
        pub_counts: Dict[str, int] = {}
        for p in raw_pubs:
            pub_counts[p] = pub_counts.get(p, 0) + 1
        row["publications"] = pub_counts

        row["startDate"] = row["startDate"].date().isoformat() if row["startDate"] else None
        row["endDate"]   = row["endDate"].date().isoformat() if row["endDate"] else None
        projects.append(row)

    return projects
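
# Example listing request (filter values are illustrative):
#   GET /api/projects?page=0&limit=10&country=BE&sortField=ecMaxContribution&sortOrder=desc
# returns at most 10 projects with a Belgian participant, ordered by EC
# contribution, with SHAP explanations and publication counts attached.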

@app.get("/api/filters")
def get_filters(request: Request):
    """
    Retrieve available filter options based on current dataset and optional query filters.

    Returns JSON with lists for statuses, legalBases, organizations, countries, and fundingSchemes.
    """
    df = app.state.df
    params = request.query_params

    # Dynamically filter df based on provided params
    if s := params.get("status"):
        df = df.filter(pl.col("status").is_null() if s == "UNKNOWN"
                       else pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        df = df.filter(pl.col("list_country").list.contains(c))
    if t := params.get("topic"):
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(t))
    if fs := params.get("fundingScheme"):
        df = df.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
    if search := params.get("search"):
        df = df.filter(pl.col("_title_lc").str.contains(search.lower()))

    def normalize(vals):
        # Map None to "UNKNOWN" and return sorted unique list
        return sorted({("UNKNOWN" if v is None else v) for v in vals})

    return {
        "statuses":     normalize(df["status"].to_list()),
        "legalBases":   normalize(df["legalBasis"].to_list()),
        "organizations": normalize(df["list_name"].explode().to_list()),
        "countries":    normalize(df["list_country"].explode().to_list()),
        "fundingSchemes": normalize(df["fundingScheme"].to_list()),
        "topics":       normalize(df["list_euroSciVocTitle"].explode().to_list()),
    }
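
# Example response shape (values are illustrative):
#   {"statuses": ["CLOSED", "SIGNED", "UNKNOWN"], "legalBases": [...],
#    "organizations": [...], "countries": [...], "fundingSchemes": [...],
#    "topics": [...]}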

@app.get("/api/stats")
def get_stats(request: Request):
    """
    Compute various statistics on projects with optional filters for status,
    legal basis, funding, start/end years, etc. Returns a dict of chart data.
    """
    df = app.state.df
    lf = df.lazy()
    params = request.query_params

    # Apply filters to both frames: the lazy frame feeds the yearly
    # aggregation below, the eager frame feeds the remaining charts
    if s := params.get("status"):
        lf = lf.filter(pl.col("_status_lc") == s.lower())
        df = df.filter(pl.col("_status_lc") == s.lower())
    if lb := params.get("legalBasis"):
        lf = lf.filter(pl.col("_legalBasis_lc") == lb.lower())
        df = df.filter(pl.col("_legalBasis_lc") == lb.lower())
    if org := params.get("organization"):
        lf = lf.filter(pl.col("list_name").list.contains(org))
        df = df.filter(pl.col("list_name").list.contains(org))
    if c := params.get("country"):
        lf = lf.filter(pl.col("list_country").list.contains(c))
        df = df.filter(pl.col("list_country").list.contains(c))
    if eu := params.get("topic"):
        lf = lf.filter(pl.col("list_euroSciVocTitle").list.contains(eu))
        df = df.filter(pl.col("list_euroSciVocTitle").list.contains(eu))
    if fs := params.get("fundingScheme"):
        lf = lf.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
        df = df.filter(pl.col("_fundingScheme_lc").str.contains(fs.lower()))
    if mn := params.get("minFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") >= int(mn))
        df = df.filter(pl.col("ecMaxContribution") >= int(mn))
    if mx := params.get("maxFunding"):
        lf = lf.filter(pl.col("ecMaxContribution") <= int(mx))
        df = df.filter(pl.col("ecMaxContribution") <= int(mx))
    if y1 := params.get("minYear"):
        lf = lf.filter(pl.col("startDate").dt.year() >= int(y1))
        df = df.filter(pl.col("startDate").dt.year() >= int(y1))
    if y2 := params.get("maxYear"):
        lf = lf.filter(pl.col("startDate").dt.year() <= int(y2))
        df = df.filter(pl.col("startDate").dt.year() <= int(y2))
    if ye1 := params.get("minEndYear"):
        lf = lf.filter(pl.col("endDate").dt.year() >= int(ye1))
        df = df.filter(pl.col("endDate").dt.year() >= int(ye1))
    if ye2 := params.get("maxEndYear"):
        lf = lf.filter(pl.col("endDate").dt.year() <= int(ye2))
        df = df.filter(pl.col("endDate").dt.year() <= int(ye2))

    # Helper to drop any None/null entries
    def clean_data(labels: list, values: list) -> tuple[list, list]:
        pairs = [(l, v) for l, v in zip(labels, values) if l is not None and v is not None]
        if not pairs:
            return [], []
        lbls, vals = zip(*pairs)
        return list(lbls), list(vals)
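    # e.g. clean_data([2019, None, 2021], [5, 7, None]) -> ([2019], [5]);
    # a pair is dropped when either its label or its value is None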

    # 1) Projects per Year (Line)
    yearly = (
        lf.select(pl.col("startDate").dt.year().alias("year"))
          .group_by("year")
          .agg(pl.count().alias("count"))
          .sort("year")
          .collect()
    )
    years       = yearly["year"].to_list()
    year_counts = yearly["count"].to_list()

    # fixed bucket order
    size_order = ["<100 K","100 K–500 K","500 K–1 M","1 M–5 M","5 M–10 M","≥10 M"]

    # 2) Project-Size Distribution (Bar)
    size_buckets = (
        df.with_columns(
             pl.when(pl.col("totalCost") < 100_000).then(pl.lit("<100 K"))
               .when(pl.col("totalCost") < 500_000).then(pl.lit("100 K–500 K"))
               .when(pl.col("totalCost") < 1_000_000).then(pl.lit("500 K–1 M"))
               .when(pl.col("totalCost") < 5_000_000).then(pl.lit("1 M–5 M"))
               .when(pl.col("totalCost") < 10_000_000).then(pl.lit("5 M–10 M"))
               .otherwise(pl.lit("≥10 M"))
               .alias("size_range")
        )
        .group_by("size_range")
        .agg(pl.count().alias("count"))
        .with_columns(
            pl.col("size_range")
              .replace_strict(size_order, list(range(len(size_order))))
              .alias("order")
        )
        .sort("order")
    )
    size_labels = size_buckets["size_range"].to_list()
    size_counts = size_buckets["count"].to_list()

    # 3) Scheme Frequency (Bar)
    # fundingScheme is a plain Utf8 column (it is lowercased with
    # str.to_lowercase at startup), so it is grouped directly
    scheme_counts_df = (
        df.group_by("fundingScheme")
          .agg(pl.count().alias("count"))
          .sort("count", descending=True)
          .head(10)
    )

    scheme_labels = scheme_counts_df["fundingScheme"].to_list()
    scheme_values = scheme_counts_df["count"].to_list()
    # 4) Top 10 Macro Topics by EC Contribution (Bar)
    top_topics   = (
        df.explode("list_euroSciVocTitle")
          .group_by("list_euroSciVocTitle")
          .agg(pl.col("ecMaxContribution").sum().alias("total_ec"))
          .sort("total_ec", descending=True)
          .head(10)
    )
    topic_labels = top_topics["list_euroSciVocTitle"].to_list()
    topic_values = (top_topics["total_ec"] / 1e6).round(1).to_list()

    # 5) Projects by Funding Range (Pie)
    fund_range   = (
        df.with_columns(
             pl.when(pl.col("ecMaxContribution") < 100_000).then(pl.lit("<100 K"))
               .when(pl.col("ecMaxContribution") < 500_000).then(pl.lit("100 K–500 K"))
               .when(pl.col("ecMaxContribution") < 1_000_000).then(pl.lit("500 K–1 M"))
               .when(pl.col("ecMaxContribution") < 5_000_000).then(pl.lit("1 M–5 M"))
               .when(pl.col("ecMaxContribution") < 10_000_000).then(pl.lit("5 M–10 M"))
               .otherwise(pl.lit("≥10 M"))
               .alias("funding_range")
        )
        .group_by("funding_range")
        .agg(pl.count().alias("count"))
        .with_columns(
            pl.col("funding_range")
              .replace_strict(size_order, list(range(len(size_order))))
              .alias("order")
        )
        .sort("order")
    )
    fr_labels = fund_range["funding_range"].to_list()
    fr_counts = fund_range["count"].to_list()

    # 6) Projects per Country (Doughnut)
    country      = (
        df.explode("list_country")
          .group_by("list_country")
          .agg(pl.count().alias("count"))
          .sort("count", descending=True)
          .head(10)
    )
    country_labels = country["list_country"].to_list()
    country_counts = country["count"].to_list()

    # Clean out any nulls before returning
    years,           year_counts   = clean_data(years,           year_counts)
    size_labels,     size_counts   = clean_data(size_labels,     size_counts)
    scheme_labels,   scheme_values = clean_data(scheme_labels,   scheme_values)
    topic_labels,    topic_values  = clean_data(topic_labels,    topic_values)
    fr_labels,       fr_counts     = clean_data(fr_labels,       fr_counts)
    country_labels,  country_counts= clean_data(country_labels,  country_counts)

    return {
        "ppy":   {"labels": years,           "values": year_counts},
        "psd":   {"labels": size_labels,     "values": size_counts},
        "frs":   {"labels": scheme_labels,   "values": scheme_values},
        "top10": {"labels": topic_labels,    "values": topic_values},
        "frb":   {"labels": fr_labels,       "values": fr_counts},
        "ppc":   {"labels": country_labels,  "values": country_counts},
    }
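
# Example stats request (parameter values are illustrative):
#   GET /api/stats?minYear=2020&maxFunding=5000000&country=DE
# restricts every chart to projects with a German participant, a start year
# of 2020 or later, and at most 5,000,000 EUR EC contribution.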


@app.get("/api/project/{project_id}/organizations")
def get_project_organizations(project_id: str):
    """
    Retrieve organization details for a given project ID, including geolocation.

    Raises 404 if the project ID does not exist.
    """
    df = app.state.df
    try:
        pid = int(project_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="project_id must be an integer")
    sel = df.filter(pl.col("id") == pid)
    if sel.is_empty():
        raise HTTPException(status_code=404, detail="Project not found")

    # Explode list columns and parse latitude/longitude
    orgs_df = (
        sel.select([
            pl.col("list_name").explode().alias("name"),
            pl.col("list_city").explode().alias("city"),
            pl.col("list_SME").explode().alias("sme"),
            pl.col("list_role").explode().alias("role"),
            pl.col("list_organizationURL").explode().alias("orgURL"),
            pl.col("list_ecContribution").explode().alias("contribution"),
            pl.col("list_activityType").explode().alias("activityType"),
            pl.col("list_country").explode().alias("country"),
            pl.col("list_geolocation").explode().alias("geoloc"),
        ])
        .with_columns([
            # Split "lat,lon" string into list
            pl.col("geoloc").str.split(",").alias("latlon"),
        ])
        .with_columns([
            # Cast to floats for numeric use
            pl.col("latlon").list.get(0).cast(pl.Float64).alias("latitude"),
            pl.col("latlon").list.get(1).cast(pl.Float64).alias("longitude"),
        ])
        .filter(pl.col("name").is_not_null())
        .select([
            "name","city","sme","role","contribution",
            "activityType","orgURL","country","latitude","longitude"
        ])
    )
    logger.info(f"Organization data for project {project_id}: {orgs_df.to_dicts()}")
    return orgs_df.to_dicts()
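
# Local run sketch, assuming this module lives at app/main.py:
#   uvicorn app.main:app --host 0.0.0.0 --port 8000
# The lifespan hook builds/loads the indexes and the Parquet data before
# the endpoints start serving traffic.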