Spaces:
Sleeping
Sleeping
Commit
·
144fbb5
1
Parent(s):
bdc141f
- backend/main.py +47 -16
- frontend/package.json +3 -1
- frontend/src/components/ProjectDetails.tsx +43 -64
- frontend/src/hooks/types.ts +8 -5
backend/main.py
CHANGED
@@ -18,7 +18,7 @@ import gcsfs
|
|
18 |
@asynccontextmanager
|
19 |
async def lifespan(app: FastAPI):
|
20 |
bucket = "mda_eu_project"
|
21 |
-
path = "data/consolidated_clean.parquet"
|
22 |
uri = f"gs://{bucket}/{path}"
|
23 |
|
24 |
fs = gcsfs.GCSFileSystem()
|
@@ -51,28 +51,59 @@ app.add_middleware(
|
|
51 |
)
|
52 |
|
53 |
@app.get("/api/projects")
|
54 |
-
def get_projects(
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
start = page * limit
|
57 |
sel = df
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
sel.slice(start, limit)
|
67 |
-
.select(
|
68 |
-
"id","title","status","startDate","ecMaxContribution",
|
69 |
-
"acronym","endDate","legalBasis","objective",
|
70 |
-
"frameworkProgramme","list_euroSciVocTitle",
|
71 |
-
"list_euroSciVocPath"
|
72 |
-
])
|
73 |
.to_dicts()
|
74 |
)
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
@app.get("/api/filters")
|
77 |
def get_filters():
|
78 |
return {
|
|
|
18 |
@asynccontextmanager
|
19 |
async def lifespan(app: FastAPI):
|
20 |
bucket = "mda_eu_project"
|
21 |
+
path = "data/consolidated_clean_pred.parquet" #"data/consolidated_clean.parquet"
|
22 |
uri = f"gs://{bucket}/{path}"
|
23 |
|
24 |
fs = gcsfs.GCSFileSystem()
|
|
|
51 |
)
|
52 |
|
53 |
@app.get("/api/projects")
|
54 |
+
def get_projects(
|
55 |
+
page: int = 0,
|
56 |
+
limit: int = 10,
|
57 |
+
search: str = "",
|
58 |
+
status: str = ""
|
59 |
+
):
|
60 |
+
df: pl.DataFrame = app.state.df
|
61 |
start = page * limit
|
62 |
sel = df
|
63 |
|
64 |
+
# case-insensitive title search
|
65 |
+
if search:
|
66 |
+
sel = sel.filter(
|
67 |
+
pl.col("_title_lc").str.contains(search.lower())
|
68 |
+
)
|
69 |
+
|
70 |
+
# filter by status
|
71 |
+
if status:
|
72 |
+
sel = sel.filter(
|
73 |
+
pl.col("_status_lc") == status.lower()
|
74 |
+
)
|
75 |
+
|
76 |
+
# slice for pagination and select all needed fields, including top-N features/shap and prediction
|
77 |
+
cols = [
|
78 |
+
"id", "title", "status", "startDate", "endDate",
|
79 |
+
"ecMaxContribution", "acronym", "legalBasis", "objective",
|
80 |
+
"frameworkProgramme", "list_euroSciVocTitle", "list_euroSciVocPath",
|
81 |
+
]
|
82 |
+
for i in range(1, 7):
|
83 |
+
cols += [f"top{i}_features", f"top{i}_shap"]
|
84 |
+
|
85 |
+
cols += ["predicted_label", "predicted_prob"]
|
86 |
+
|
87 |
+
rows = (
|
88 |
sel.slice(start, limit)
|
89 |
+
.select(cols)
|
|
|
|
|
|
|
|
|
|
|
90 |
.to_dicts()
|
91 |
)
|
92 |
|
93 |
+
# build explanations array for each row
|
94 |
+
projects = []
|
95 |
+
for row in rows:
|
96 |
+
explanations = []
|
97 |
+
for i in range(1, 7):
|
98 |
+
feat = row.pop(f"top{i}_features", None)
|
99 |
+
shap = row.pop(f"top{i}_shap", None)
|
100 |
+
if feat is not None and shap is not None:
|
101 |
+
explanations.append({"feature": feat, "shap": shap})
|
102 |
+
row["explanations"] = explanations
|
103 |
+
projects.append(row)
|
104 |
+
|
105 |
+
return projects
|
106 |
+
|
107 |
@app.get("/api/filters")
|
108 |
def get_filters():
|
109 |
return {
|
frontend/package.json
CHANGED
@@ -23,7 +23,8 @@
|
|
23 |
"react-chartjs-2": "^5.3.0",
|
24 |
"react-dom": "^18.3.1",
|
25 |
"react-leaflet": "^4.2.1",
|
26 |
-
"react-select": "^5.10.1"
|
|
|
27 |
},
|
28 |
"devDependencies": {
|
29 |
"@eslint/js": "^9.17.0",
|
@@ -32,6 +33,7 @@
|
|
32 |
"@types/node": "^22.15.17",
|
33 |
"@types/react": "^18.3.21",
|
34 |
"@types/react-dom": "^18.3.5",
|
|
|
35 |
"@vitejs/plugin-react": "^4.3.4",
|
36 |
"eslint": "^9.17.0",
|
37 |
"eslint-plugin-react-hooks": "^5.0.0",
|
|
|
23 |
"react-chartjs-2": "^5.3.0",
|
24 |
"react-dom": "^18.3.1",
|
25 |
"react-leaflet": "^4.2.1",
|
26 |
+
"react-select": "^5.10.1",
|
27 |
+
"rechart": "^0.0.1"
|
28 |
},
|
29 |
"devDependencies": {
|
30 |
"@eslint/js": "^9.17.0",
|
|
|
33 |
"@types/node": "^22.15.17",
|
34 |
"@types/react": "^18.3.21",
|
35 |
"@types/react-dom": "^18.3.5",
|
36 |
+
"@types/recharts": "^2.0.1",
|
37 |
"@vitejs/plugin-react": "^4.3.4",
|
38 |
"eslint": "^9.17.0",
|
39 |
"eslint-plugin-react-hooks": "^5.0.0",
|
frontend/src/components/ProjectDetails.tsx
CHANGED
@@ -16,6 +16,16 @@ import {
|
|
16 |
Button,
|
17 |
Avatar,
|
18 |
} from "@chakra-ui/react";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
import {
|
20 |
MapContainer,
|
21 |
TileLayer,
|
@@ -49,16 +59,11 @@ function ResizeMap({ count }: { count: number }) {
|
|
49 |
}
|
50 |
|
51 |
export default function ProjectDetails({
|
52 |
-
project,
|
53 |
-
question,
|
54 |
-
setQuestion,
|
55 |
-
askChatbot,
|
56 |
-
chatHistory = [],
|
57 |
-
messagesEndRef,
|
58 |
-
}: ProjectDetailsProps) {
|
59 |
// fetch organization locations
|
60 |
const [orgLocations, setOrgLocations] = useState<OrganizationLocation[]>([]);
|
61 |
const [loadingOrgs, setLoadingOrgs] = useState(true);
|
|
|
62 |
|
63 |
useEffect(() => {
|
64 |
if (!project) return;
|
@@ -77,7 +82,9 @@ export default function ProjectDetails({
|
|
77 |
</Box>
|
78 |
);
|
79 |
}
|
80 |
-
|
|
|
|
|
81 |
// Map center fallback
|
82 |
const center: [number, number] = orgLocations.length
|
83 |
? [orgLocations[0].latitude, orgLocations[0].longitude]
|
@@ -178,62 +185,34 @@ export default function ProjectDetails({
|
|
178 |
)}
|
179 |
</Box>
|
180 |
|
181 |
-
{/* Right:
|
182 |
-
<Box
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
borderRadius="md"
|
209 |
-
>
|
210 |
-
{msg.content}
|
211 |
-
</Text>
|
212 |
-
</Box>
|
213 |
-
{msg.role === "user" && <Avatar size="sm" name="You" bg="blue.300" />}
|
214 |
-
</HStack>
|
215 |
-
))}
|
216 |
-
<div ref={messagesEndRef} />
|
217 |
-
</VStack>
|
218 |
-
</Box>
|
219 |
-
|
220 |
-
<HStack>
|
221 |
-
<Input
|
222 |
-
placeholder="Type your question..."
|
223 |
-
value={question}
|
224 |
-
onChange={(e) => setQuestion(e.target.value)}
|
225 |
-
onKeyDown={(e) => {
|
226 |
-
if (e.key === "Enter" && !e.shiftKey) {
|
227 |
-
e.preventDefault();
|
228 |
-
askChatbot();
|
229 |
-
}
|
230 |
-
}}
|
231 |
-
/>
|
232 |
-
<Button onClick={askChatbot} aria-label="Send question">
|
233 |
-
Send
|
234 |
-
</Button>
|
235 |
-
</HStack>
|
236 |
</Box>
|
237 |
</Flex>
|
238 |
);
|
239 |
-
}
|
|
|
16 |
Button,
|
17 |
Avatar,
|
18 |
} from "@chakra-ui/react";
|
19 |
+
import {
|
20 |
+
ResponsiveContainer,
|
21 |
+
BarChart,
|
22 |
+
Bar,
|
23 |
+
Cell,
|
24 |
+
XAxis,
|
25 |
+
YAxis,
|
26 |
+
CartesianGrid,
|
27 |
+
Tooltip
|
28 |
+
} from "recharts";
|
29 |
import {
|
30 |
MapContainer,
|
31 |
TileLayer,
|
|
|
59 |
}
|
60 |
|
61 |
export default function ProjectDetails({
|
62 |
+
project,}: ProjectDetailsProps) {
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
// fetch organization locations
|
64 |
const [orgLocations, setOrgLocations] = useState<OrganizationLocation[]>([]);
|
65 |
const [loadingOrgs, setLoadingOrgs] = useState(true);
|
66 |
+
const [loadingPlot, setLoadingPlot] = useState(true);
|
67 |
|
68 |
useEffect(() => {
|
69 |
if (!project) return;
|
|
|
82 |
</Box>
|
83 |
);
|
84 |
}
|
85 |
+
const shapData = project.explanations;
|
86 |
+
const predicted = project.predicted_label;
|
87 |
+
const probability = project.predicted_prob;
|
88 |
// Map center fallback
|
89 |
const center: [number, number] = orgLocations.length
|
90 |
? [orgLocations[0].latitude, orgLocations[0].longitude]
|
|
|
185 |
)}
|
186 |
</Box>
|
187 |
|
188 |
+
{/* Right: Model Explanation */}
|
189 |
+
<Box flex={{ base: '1', md: '0.6' }} bg="white" p={4} borderRadius="md" boxShadow="sm">
|
190 |
+
<Heading size="sm" mb={4}>Model Prediction & Explanation</Heading>
|
191 |
+
{shapData?.length ? (
|
192 |
+
<>
|
193 |
+
<Text mb={2}><strong>Predicted Label:</strong> {predicted}</Text>
|
194 |
+
<Text mb={4}><strong>Probability:</strong> {(probability * 100).toFixed(2)}%</Text>
|
195 |
+
<ResponsiveContainer width="100%" height={300}>
|
196 |
+
<BarChart data={shapData} margin={{ top: 10, right: 30, left: 0, bottom: 5 }}>
|
197 |
+
<CartesianGrid strokeDasharray="3 3" />
|
198 |
+
<XAxis dataKey="feature" />
|
199 |
+
<YAxis />
|
200 |
+
<Tooltip />
|
201 |
+
<Bar dataKey="shap" name="SHAP Value">
|
202 |
+
{shapData.map((entry, index) => (
|
203 |
+
<Cell
|
204 |
+
key={`cell-${index}`}
|
205 |
+
fill={entry.shap >= 0 ? "#4caf50" : "#f44336"}
|
206 |
+
/>
|
207 |
+
))}
|
208 |
+
</Bar>
|
209 |
+
</BarChart>
|
210 |
+
</ResponsiveContainer>
|
211 |
+
</>
|
212 |
+
) : (
|
213 |
+
<Spinner />
|
214 |
+
)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
</Box>
|
216 |
</Flex>
|
217 |
);
|
218 |
+
}
|
frontend/src/hooks/types.ts
CHANGED
@@ -12,15 +12,18 @@ export interface Project {
|
|
12 |
frameworkProgramme: string;
|
13 |
list_euroSciVocTitle: string[];
|
14 |
list_euroSciVocPath: string[];
|
|
|
|
|
|
|
15 |
}
|
16 |
|
17 |
export interface ProjectDetailsProps {
|
18 |
project: Project;
|
19 |
-
question: string;
|
20 |
-
setQuestion: React.Dispatch<React.SetStateAction<string>>;
|
21 |
-
askChatbot: () => void;
|
22 |
-
chatHistory: ChatMessage[];
|
23 |
-
messagesEndRef: React.RefObject<HTMLDivElement>;
|
24 |
}
|
25 |
|
26 |
|
|
|
12 |
frameworkProgramme: string;
|
13 |
list_euroSciVocTitle: string[];
|
14 |
list_euroSciVocPath: string[];
|
15 |
+
explanations: Array<{ feature: string; shap: number }>;
|
16 |
+
predicted_label: number;
|
17 |
+
predicted_prob: number;
|
18 |
}
|
19 |
|
20 |
export interface ProjectDetailsProps {
|
21 |
project: Project;
|
22 |
+
// question: string;
|
23 |
+
// setQuestion: React.Dispatch<React.SetStateAction<string>>;
|
24 |
+
// askChatbot: () => void;
|
25 |
+
// chatHistory: ChatMessage[];
|
26 |
+
// messagesEndRef: React.RefObject<HTMLDivElement>;
|
27 |
}
|
28 |
|
29 |
|