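"""Evaluation backend for the leaderboard app.

Uploaded NPZ prediction files are queued, scored against ground-truth rasters
via a global confusion matrix, and the resulting metrics are uploaded as JSON
to the results dataset repo.
"""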
import json
import os
import queue
import shutil
import tempfile
import threading
from datetime import datetime, timezone
from typing import List

import gradio as gr
import numpy as np
from sklearn.metrics import confusion_matrix

from src.display.formatting import styled_message
from src.envs import API, EVAL_RESULTS_PATH, RESULTS_REPO


GROUND_TRUTH_DATA = None
TASKS_QUEUE = queue.Queue()


# Map raw class IDs onto the 11 evaluation classes (255 = ignore)
remap_dict = {
    0: 0, 1: 0, 2: 0, 3: 0, 4: 0,
    5: 1,
    6: 2, 7: 2,
    8: 3, 9: 3, 10: 3, 11: 3,
    14: 4,
    15: 5,
    16: 6,
    17: 7, 18: 7,
    19: 8,
    20: 9,
    21: 10,
    12: 255, 13: 255, 255: 255
}


# Create LUT (assuming input values are between 0 and 255)
lut = np.full(256, 255, dtype=np.uint8)  # default to 255
for k, v in remap_dict.items():
    lut[k] = v
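# The LUT is applied with NumPy fancy indexing (e.g. lut[label_array]), which
# remaps every element in a single vectorized operation.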

# Worker function (runs in background thread)
def queue_worker():
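    """Consume submissions from TASKS_QUEUE forever: compute metrics, write
    the result JSON locally, and upload it to the results repo."""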
    global TASKS_QUEUE
    print("queue_worker")
    while True:
        # Block until a submission arrives
        print("Waiting for data from queue")
        npz_files, user_name, result_name, current_time, remap = TASKS_QUEUE.get()
        print(f"Compute stats {user_name} {result_name} {current_time}")
        # Compute metrics:
        metrics = eval_las_files(npz_files, remap)

        eval_entry = {
            "result_name": result_name,
            "submitted_time": current_time,
        } | metrics

        print("Creating eval file")
        out_dir = f"{EVAL_RESULTS_PATH}/{user_name}"
        os.makedirs(out_dir, exist_ok=True)
        out_path = f"{out_dir}/{result_name}_result.json"

        print(out_path)
        with open(out_path, "w") as f:
            f.write(json.dumps(eval_entry))

        print("Uploading eval file")
        print(out_path.split("eval-results/")[1])
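        # path_in_repo must be relative to the repo root; this relies on
        # EVAL_RESULTS_PATH ending with "eval-results"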
        API.upload_file(
            path_or_fileobj=out_path,
            path_in_repo=out_path.split("eval-results/")[1],
            repo_id=RESULTS_REPO,
            repo_type="dataset",
            commit_message=f"Add {user_name}/{result_name} to result",
        )
        # No direct UI update here; the leaderboard picks up the new result
        # file from the results repo.
        TASKS_QUEUE.task_done()


WORKER_THREAD = threading.Thread(target=queue_worker, daemon=True)
WORKER_THREAD.start()
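# A single daemon worker drains the queue FIFO, so submissions are evaluated
# one at a time without blocking the UI; being a daemon, it will not keep the
# process alive on shutdown.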


def handle_file_path(path: str) -> str:
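    """Return the base file name without extension, e.g. "/data/area_1.npz" -> "area_1"."""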
    basename = os.path.basename(path)
    name_wo_ext = os.path.splitext(basename)[0]
    return name_wo_ext


def read_ground_truth():
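    """Load every ground-truth .npz raster, remap it through the LUT, and
    cache it in GROUND_TRUTH_DATA keyed by area name."""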
    print("read_ground_truth")
    global GROUND_TRUTH_DATA
    GROUND_TRUTH_DATA = {}
    directory_path = os.path.join(EVAL_RESULTS_PATH, "ground_truth")
    # Iterate over all files in the directory
    for filename in os.listdir(directory_path):
        if filename.endswith(".npz"):
            file_path = os.path.join(directory_path, filename)
            # Load the .npz file
            with np.load(file_path) as data:
                # Remap raw labels through the LUT and key by area name
                # (filename without extension)
                GROUND_TRUTH_DATA[os.path.splitext(filename)[0]] = lut[data["data"]]

def compute_metrics_from_cm(cm: np.ndarray):
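    """Derive accuracy, per-class precision/recall/F1, IoU, and producer's/
    user's accuracy from a confusion matrix (rows = ground truth, columns =
    predictions, following sklearn's convention)."""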
    tp = np.diag(cm)
    fp = np.sum(cm, axis=0) - tp
    fn = np.sum(cm, axis=1) - tp

    # Accuracy
    total = np.sum(cm)
    accuracy = np.sum(tp) / total if total != 0 else 0.0

    # Precision, Recall, F1 per class
    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp, dtype=np.float64), where=(tp + fp) != 0)
    recall = np.divide(tp, tp + fn, out=np.zeros_like(tp, dtype=np.float64), where=(tp + fn) != 0)
    f1 = np.divide(2 * precision * recall, precision + recall, out=np.zeros_like(tp, dtype=np.float64), where=(precision + recall) != 0)

    # Macro average
    precision_macro = np.mean(precision)
    recall_macro = np.mean(recall)
    f1_macro = np.mean(f1)

    # IoU per class and mean IoU (Jaccard index)
    denom = tp + fp + fn
    iou = np.divide(tp, denom, out=np.zeros_like(tp, dtype=np.float64), where=denom != 0)
    miou = np.mean(iou)

    # Producer's and User's Accuracy (numerically equal to recall and precision above)
    row_sums = cm.sum(axis=1)
    col_sums = cm.sum(axis=0)
    # Producer's Accuracy (Recall-like)
    pa = np.divide(tp, row_sums, out=np.zeros_like(tp, dtype=np.float64), where=row_sums != 0)

    # User's Accuracy (Precision-like)
    ua = np.divide(tp, col_sums, out=np.zeros_like(tp, dtype=np.float64), where=col_sums != 0)


    return {
        "accuracy": accuracy,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "precision_per_class": precision.tolist(),
        "recall_per_class": recall.tolist(),
        "f1_per_class": f1.tolist(),
        "iou": iou.tolist(),
        "miou": miou,
        "producer_accuracy": pa.tolist(),
        "user_accuracy": ua.tolist(),
        "confusion_matrix": cm.tolist(),
    }
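
# Sanity check (hypothetical values): a perfect 2-class confusion matrix
# yields accuracy 1.0:
#   compute_metrics_from_cm(np.array([[5, 0], [0, 3]]))["accuracy"]  # -> 1.0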


def eval_las_files(npz_file_paths: List[str], remap=False):
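    """Accumulate a confusion matrix over the submitted NPZ label rasters and
    return the metrics computed from it."""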
    global GROUND_TRUTH_DATA
    NUM_CLASSES = 11  # adjust to your case
    LABELS = list(range(NUM_CLASSES))  # [0, 1, ..., 10]
    global_cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)

    if GROUND_TRUTH_DATA is None:
        read_ground_truth()
    for file_path in npz_file_paths:
        print("Reading file:", file_path)
        area = handle_file_path(file_path)
        if area not in GROUND_TRUTH_DATA:
            print(f"Error {area} is not a known area !")
            continue
        # Read the NPZ file
        with np.load(file_path) as data:
            y_pred = data["data"]
            if remap:
                y_pred = lut[y_pred]
            y_true = GROUND_TRUTH_DATA[area]
            if y_true.shape != y_pred.shape:
                print(f"Error: {area} pred and gt have different shapes {y_true.shape=} {y_pred.shape=}!")
                continue
            # Mask out ignored pixels (label 255)
            valid = y_true != 255
            # Confusion matrix
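            # (sklearn drops samples whose true or predicted label is not in
            # LABELS, so any remaining 255 predictions are excluded here)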
            cm = confusion_matrix(y_true[valid], y_pred[valid], labels=LABELS)
            global_cm += cm

    return compute_metrics_from_cm(global_cm)


def add_new_eval(
    # user_name now comes from the OAuth profile instead of a form field
    result_name: str,
    npz_files: List[tempfile._TemporaryFileWrapper],
    remap: bool,
    profile: gr.OAuthProfile | None
) -> str:
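    """Validate a submission from the UI and enqueue it for the background
    worker; returns a styled status message immediately."""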
    global TASKS_QUEUE

    if profile is None:
        return styled_message("⚠️ Please log in to submit your evaluation.")

    if not result_name or not npz_files:
        return styled_message("❌ Please fill in all fields and upload at least one NPZ file.")

    current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    print("Adding new eval in tasks queue")

    # Copy the uploaded NPZ files to stable paths: Gradio's temporary files
    # may be closed or deleted before the worker thread reads them
    saved_files = []
    for file in npz_files:
        new_path = os.path.join(tempfile.gettempdir(), os.path.basename(file.name))
        shutil.copyfile(file.name, new_path)
        saved_files.append(new_path)

    TASKS_QUEUE.put((saved_files, profile.username, result_name, current_time, remap))

    return styled_message("✅ Your request has been added! The leaderboard will update after processing.")