Rom89823974978 commited on
Commit
144fbb5
·
1 Parent(s): bdc141f
backend/main.py CHANGED
@@ -18,7 +18,7 @@ import gcsfs
18
  @asynccontextmanager
19
  async def lifespan(app: FastAPI):
20
  bucket = "mda_eu_project"
21
- path = "data/consolidated_clean.parquet"
22
  uri = f"gs://{bucket}/{path}"
23
 
24
  fs = gcsfs.GCSFileSystem()
@@ -51,28 +51,59 @@ app.add_middleware(
51
  )
52
 
53
  @app.get("/api/projects")
54
- def get_projects(page: int = 0, limit: int = 10, search: str = "", status: str = ""):
55
- df = app.state.df
 
 
 
 
 
56
  start = page * limit
57
  sel = df
58
 
59
- if search != "":
60
- sel = sel.filter(pl.col("_title_lc").str.contains(search.lower()))
61
-
62
- if status != "":
63
- sel = sel.filter(pl.col("_status_lc") == status.lower())
64
-
65
- return (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  sel.slice(start, limit)
67
- .select([
68
- "id","title","status","startDate","ecMaxContribution",
69
- "acronym","endDate","legalBasis","objective",
70
- "frameworkProgramme","list_euroSciVocTitle",
71
- "list_euroSciVocPath"
72
- ])
73
  .to_dicts()
74
  )
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  @app.get("/api/filters")
77
  def get_filters():
78
  return {
 
18
  @asynccontextmanager
19
  async def lifespan(app: FastAPI):
20
  bucket = "mda_eu_project"
21
+ path = "data/consolidated_clean_pred.parquet" #"data/consolidated_clean.parquet"
22
  uri = f"gs://{bucket}/{path}"
23
 
24
  fs = gcsfs.GCSFileSystem()
 
51
  )
52
 
53
  @app.get("/api/projects")
54
+ def get_projects(
55
+ page: int = 0,
56
+ limit: int = 10,
57
+ search: str = "",
58
+ status: str = ""
59
+ ):
60
+ df: pl.DataFrame = app.state.df
61
  start = page * limit
62
  sel = df
63
 
64
+ # case-insensitive title search
65
+ if search:
66
+ sel = sel.filter(
67
+ pl.col("_title_lc").str.contains(search.lower())
68
+ )
69
+
70
+ # filter by status
71
+ if status:
72
+ sel = sel.filter(
73
+ pl.col("_status_lc") == status.lower()
74
+ )
75
+
76
+ # slice for pagination and select all needed fields, including top-N features/shap and prediction
77
+ cols = [
78
+ "id", "title", "status", "startDate", "endDate",
79
+ "ecMaxContribution", "acronym", "legalBasis", "objective",
80
+ "frameworkProgramme", "list_euroSciVocTitle", "list_euroSciVocPath",
81
+ ]
82
+ for i in range(1, 7):
83
+ cols += [f"top{i}_features", f"top{i}_shap"]
84
+
85
+ cols += ["predicted_label", "predicted_prob"]
86
+
87
+ rows = (
88
  sel.slice(start, limit)
89
+ .select(cols)
 
 
 
 
 
90
  .to_dicts()
91
  )
92
 
93
+ # build explanations array for each row
94
+ projects = []
95
+ for row in rows:
96
+ explanations = []
97
+ for i in range(1, 7):
98
+ feat = row.pop(f"top{i}_features", None)
99
+ shap = row.pop(f"top{i}_shap", None)
100
+ if feat is not None and shap is not None:
101
+ explanations.append({"feature": feat, "shap": shap})
102
+ row["explanations"] = explanations
103
+ projects.append(row)
104
+
105
+ return projects
106
+
107
  @app.get("/api/filters")
108
  def get_filters():
109
  return {
frontend/package.json CHANGED
@@ -23,7 +23,8 @@
23
  "react-chartjs-2": "^5.3.0",
24
  "react-dom": "^18.3.1",
25
  "react-leaflet": "^4.2.1",
26
- "react-select": "^5.10.1"
 
27
  },
28
  "devDependencies": {
29
  "@eslint/js": "^9.17.0",
@@ -32,6 +33,7 @@
32
  "@types/node": "^22.15.17",
33
  "@types/react": "^18.3.21",
34
  "@types/react-dom": "^18.3.5",
 
35
  "@vitejs/plugin-react": "^4.3.4",
36
  "eslint": "^9.17.0",
37
  "eslint-plugin-react-hooks": "^5.0.0",
 
23
  "react-chartjs-2": "^5.3.0",
24
  "react-dom": "^18.3.1",
25
  "react-leaflet": "^4.2.1",
26
+ "react-select": "^5.10.1",
27
+ "rechart": "^0.0.1"
28
  },
29
  "devDependencies": {
30
  "@eslint/js": "^9.17.0",
 
33
  "@types/node": "^22.15.17",
34
  "@types/react": "^18.3.21",
35
  "@types/react-dom": "^18.3.5",
36
+ "@types/recharts": "^2.0.1",
37
  "@vitejs/plugin-react": "^4.3.4",
38
  "eslint": "^9.17.0",
39
  "eslint-plugin-react-hooks": "^5.0.0",
frontend/src/components/ProjectDetails.tsx CHANGED
@@ -16,6 +16,16 @@ import {
16
  Button,
17
  Avatar,
18
  } from "@chakra-ui/react";
 
 
 
 
 
 
 
 
 
 
19
  import {
20
  MapContainer,
21
  TileLayer,
@@ -49,16 +59,11 @@ function ResizeMap({ count }: { count: number }) {
49
  }
50
 
51
  export default function ProjectDetails({
52
- project,
53
- question,
54
- setQuestion,
55
- askChatbot,
56
- chatHistory = [],
57
- messagesEndRef,
58
- }: ProjectDetailsProps) {
59
  // fetch organization locations
60
  const [orgLocations, setOrgLocations] = useState<OrganizationLocation[]>([]);
61
  const [loadingOrgs, setLoadingOrgs] = useState(true);
 
62
 
63
  useEffect(() => {
64
  if (!project) return;
@@ -77,7 +82,9 @@ export default function ProjectDetails({
77
  </Box>
78
  );
79
  }
80
-
 
 
81
  // Map center fallback
82
  const center: [number, number] = orgLocations.length
83
  ? [orgLocations[0].latitude, orgLocations[0].longitude]
@@ -178,62 +185,34 @@ export default function ProjectDetails({
178
  )}
179
  </Box>
180
 
181
- {/* Right: Chatbot */}
182
- <Box
183
- flex={{ base: '1', md: '0.6' }}
184
- bg="gray.50"
185
- p={4}
186
- borderRadius="md"
187
- display="flex"
188
- flexDirection="column"
189
- maxH="600px"
190
- >
191
- <Heading size="sm" mb={2}>Ask about this project</Heading>
192
-
193
- <Box flex={1} overflowY="auto" mb={4}>
194
- <VStack spacing={3} align="stretch">
195
- {(chatHistory ?? []).map((msg, i) => (
196
- <HStack
197
- key={i}
198
- alignSelf={msg.role === "user" ? "flex-end" : "flex-start"}
199
- maxW="90%"
200
- >
201
- {msg.role === "assistant" && <Avatar size="sm" name="Bot" />}
202
- <Box>
203
- <Text
204
- fontSize="sm"
205
- bg={msg.role === "user" ? "blue.100" : "gray.200"}
206
- px={3}
207
- py={2}
208
- borderRadius="md"
209
- >
210
- {msg.content}
211
- </Text>
212
- </Box>
213
- {msg.role === "user" && <Avatar size="sm" name="You" bg="blue.300" />}
214
- </HStack>
215
- ))}
216
- <div ref={messagesEndRef} />
217
- </VStack>
218
- </Box>
219
-
220
- <HStack>
221
- <Input
222
- placeholder="Type your question..."
223
- value={question}
224
- onChange={(e) => setQuestion(e.target.value)}
225
- onKeyDown={(e) => {
226
- if (e.key === "Enter" && !e.shiftKey) {
227
- e.preventDefault();
228
- askChatbot();
229
- }
230
- }}
231
- />
232
- <Button onClick={askChatbot} aria-label="Send question">
233
- Send
234
- </Button>
235
- </HStack>
236
  </Box>
237
  </Flex>
238
  );
239
- }
 
16
  Button,
17
  Avatar,
18
  } from "@chakra-ui/react";
19
+ import {
20
+ ResponsiveContainer,
21
+ BarChart,
22
+ Bar,
23
+ Cell,
24
+ XAxis,
25
+ YAxis,
26
+ CartesianGrid,
27
+ Tooltip
28
+ } from "recharts";
29
  import {
30
  MapContainer,
31
  TileLayer,
 
59
  }
60
 
61
  export default function ProjectDetails({
62
+ project,}: ProjectDetailsProps) {
 
 
 
 
 
 
63
  // fetch organization locations
64
  const [orgLocations, setOrgLocations] = useState<OrganizationLocation[]>([]);
65
  const [loadingOrgs, setLoadingOrgs] = useState(true);
66
+ const [loadingPlot, setLoadingPlot] = useState(true);
67
 
68
  useEffect(() => {
69
  if (!project) return;
 
82
  </Box>
83
  );
84
  }
85
+ const shapData = project.explanations;
86
+ const predicted = project.predicted_label;
87
+ const probability = project.predicted_prob;
88
  // Map center fallback
89
  const center: [number, number] = orgLocations.length
90
  ? [orgLocations[0].latitude, orgLocations[0].longitude]
 
185
  )}
186
  </Box>
187
 
188
+ {/* Right: Model Explanation */}
189
+ <Box flex={{ base: '1', md: '0.6' }} bg="white" p={4} borderRadius="md" boxShadow="sm">
190
+ <Heading size="sm" mb={4}>Model Prediction & Explanation</Heading>
191
+ {shapData?.length ? (
192
+ <>
193
+ <Text mb={2}><strong>Predicted Label:</strong> {predicted}</Text>
194
+ <Text mb={4}><strong>Probability:</strong> {(probability * 100).toFixed(2)}%</Text>
195
+ <ResponsiveContainer width="100%" height={300}>
196
+ <BarChart data={shapData} margin={{ top: 10, right: 30, left: 0, bottom: 5 }}>
197
+ <CartesianGrid strokeDasharray="3 3" />
198
+ <XAxis dataKey="feature" />
199
+ <YAxis />
200
+ <Tooltip />
201
+ <Bar dataKey="shap" name="SHAP Value">
202
+ {shapData.map((entry, index) => (
203
+ <Cell
204
+ key={`cell-${index}`}
205
+ fill={entry.shap >= 0 ? "#4caf50" : "#f44336"}
206
+ />
207
+ ))}
208
+ </Bar>
209
+ </BarChart>
210
+ </ResponsiveContainer>
211
+ </>
212
+ ) : (
213
+ <Spinner />
214
+ )}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  </Box>
216
  </Flex>
217
  );
218
+ }
frontend/src/hooks/types.ts CHANGED
@@ -12,15 +12,18 @@ export interface Project {
12
  frameworkProgramme: string;
13
  list_euroSciVocTitle: string[];
14
  list_euroSciVocPath: string[];
 
 
 
15
  }
16
 
17
  export interface ProjectDetailsProps {
18
  project: Project;
19
- question: string;
20
- setQuestion: React.Dispatch<React.SetStateAction<string>>;
21
- askChatbot: () => void;
22
- chatHistory: ChatMessage[];
23
- messagesEndRef: React.RefObject<HTMLDivElement>;
24
  }
25
 
26
 
 
12
  frameworkProgramme: string;
13
  list_euroSciVocTitle: string[];
14
  list_euroSciVocPath: string[];
15
+ explanations: Array<{ feature: string; shap: number }>;
16
+ predicted_label: number;
17
+ predicted_prob: number;
18
  }
19
 
20
  export interface ProjectDetailsProps {
21
  project: Project;
22
+ // question: string;
23
+ // setQuestion: React.Dispatch<React.SetStateAction<string>>;
24
+ // askChatbot: () => void;
25
+ // chatHistory: ChatMessage[];
26
+ // messagesEndRef: React.RefObject<HTMLDivElement>;
27
  }
28
 
29