General-Level committed
Commit 0eb3766 · 1 Parent(s): 1395cd6

Resolve conflict

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. General-Bench-Closeset/.gitkeep +0 -0
  2. General-Bench-Openset/.gitkeep +0 -0
  3. README.md +80 -0
  4. README_Evaluate.md +77 -0
  5. outcome/Qwen2.5-7B-Instruct_result.xlsx +0 -0
  6. outcome/emu2-32b_result.xlsx +0 -0
  7. outcome/test_result.xlsx +0 -0
  8. predictors/audio_predict_comprehension.py +1252 -0
  9. predictors/audio_predict_generation.py +1245 -0
  10. predictors/nlp_predictor.py +1024 -0
  11. predictors/video_comprehension_flow_matching_tracking.py +562 -0
  12. predictors/video_comprehension_qa_caption.py +443 -0
  13. predictors/video_comprehension_tasks.py +550 -0
  14. predictors/video_generation_evaluate_kit.py +327 -0
  15. predictors/video_translation_restoration_superresolution_objectdetection.py +340 -0
  16. processors/._audio_processor.py +0 -0
  17. processors/._image_processor.py +0 -0
  18. processors/__init__.py +1 -0
  19. processors/__pycache__/.___init__.cpython-38.pyc +0 -0
  20. processors/__pycache__/.___init__.cpython-39.pyc +0 -0
  21. processors/__pycache__/._video_processor.cpython-39.pyc +0 -0
  22. processors/__pycache__/__init__.cpython-311.pyc +0 -0
  23. processors/__pycache__/__init__.cpython-312.pyc +0 -0
  24. processors/__pycache__/__init__.cpython-38.pyc +0 -0
  25. processors/__pycache__/__init__.cpython-39.pyc +0 -0
  26. processors/__pycache__/audio_processor.cpython-311.pyc +0 -0
  27. processors/__pycache__/audio_processor.cpython-312.pyc +0 -0
  28. processors/__pycache__/audio_processor.cpython-38.pyc +0 -0
  29. processors/__pycache__/audio_processor.cpython-39.pyc +0 -0
  30. processors/__pycache__/image_processor.cpython-311.pyc +0 -0
  31. processors/__pycache__/image_processor.cpython-312.pyc +0 -0
  32. processors/__pycache__/image_processor.cpython-38.pyc +0 -0
  33. processors/__pycache__/image_processor.cpython-39.pyc +0 -0
  34. processors/__pycache__/nlp_processor.cpython-311.pyc +0 -0
  35. processors/__pycache__/nlp_processor.cpython-312.pyc +0 -0
  36. processors/__pycache__/nlp_processor.cpython-38.pyc +0 -0
  37. processors/__pycache__/nlp_processor.cpython-39.pyc +0 -0
  38. processors/__pycache__/pseudo_audio_processor.cpython-39.pyc +0 -0
  39. processors/__pycache__/three_d_processor.cpython-311.pyc +0 -0
  40. processors/__pycache__/three_d_processor.cpython-312.pyc +0 -0
  41. processors/__pycache__/three_d_processor.cpython-38.pyc +0 -0
  42. processors/__pycache__/three_d_processor.cpython-39.pyc +0 -0
  43. processors/__pycache__/video_processor.cpython-311.pyc +0 -0
  44. processors/__pycache__/video_processor.cpython-312.pyc +0 -0
  45. processors/__pycache__/video_processor.cpython-38.pyc +0 -0
  46. processors/__pycache__/video_processor.cpython-39.pyc +0 -0
  47. processors/audio_processor.py +80 -0
  48. processors/image_processor.py +83 -0
  49. processors/nlp_processor.py +381 -0
  50. processors/three_d_processor.py +79 -0
General-Bench-Closeset/.gitkeep ADDED
File without changes
General-Bench-Openset/.gitkeep ADDED
File without changes
README.md CHANGED
@@ -1,3 +1,4 @@
+ <<<<<<< HEAD
  ---
  title: README
  emoji: 🌍
@@ -131,3 +132,82 @@ If you find our benchmark useful in your research, please kindly consider citing
  ```
 
 
+ =======
+ # GenBench Scoring System - User Guide
+
+ This system evaluates large models on the General-Bench multimodal task suite; it covers prediction, per-task scoring, and final score computation.
+
+ ## Environment Setup
+
+ - Python 3.9 or later
+ - Install the common dependencies in advance (e.g. pandas, numpy, openpyxl); an example command follows this list
+ - For Video Generation evaluation, install the dependencies following the steps in video_generation_evaluation/README.md
+ - For Video Comprehension evaluation, install the dependencies following the README.md of [sa2va](https://github.com/magic-research/Sa2VA)
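+
+ A typical way to install the common dependencies (illustrative; the exact requirement list may differ for your setup):
+
+ ```bash
+ pip install pandas numpy openpyxl tqdm
+ ```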
+
+ ## Dataset Download
+
+ - **Open Set (public data)**: download all data from [HuggingFace General-Bench-Openset](https://huggingface.co/datasets/General-Level/General-Bench-Openset), unpack it, and place it in the `General-Bench-Openset/` directory; see the example command below.
+ - **Close Set (private data)**: download all data from [HuggingFace General-Bench-Closeset](https://huggingface.co/datasets/General-Level/General-Bench-Closeset), unpack it, and place it in the `General-Bench-Closeset/` directory.
+
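+ For example, the open set can be fetched with `huggingface-cli` (assumes `huggingface_hub` is installed; the target directory is illustrative):
+
+ ```bash
+ huggingface-cli download General-Level/General-Bench-Openset --repo-type dataset --local-dir General-Bench-Openset/
+ ```
+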
+ ## One-Click Run
+
+ Simply run the main script `run.sh` to complete the whole pipeline:
+
+ ```bash
+ bash run.sh
+ ```
+
+ This command performs, in order:
+ 1. Generate the prediction results for each modality
+ 2. Compute the score for each task
+ 3. Compute the final Level score
+
+ ## Step-by-Step Run (Optional)
+
+ If you only need to run part of the pipeline, use the `--step` flag (a minimal sketch of how the script might dispatch on it follows this section):
+
+ - Run only step 1 (generate predictions):
+ ```bash
+ bash run.sh --step 1
+ ```
+ - Run only steps 1 and 2:
+ ```bash
+ bash run.sh --step 12
+ ```
+ - Run only steps 2 and 3:
+ ```bash
+ bash run.sh --step 23
+ ```
+ - With no argument, all steps are executed (equivalent to `--step 123`)
+
+ - Step 1: generates prediction.json, stored in the same directory as each dataset's annotation.json
+ - Step 2: computes the score for each task, stored in outcome/{model_name}_result.xlsx
+ - Step 3: computes the Level score of the evaluated models
+
+ > **Note:**
+ > - With the **Close Set (private data)**, run only step 1 (i.e. `bash run.sh --step 1`) and submit the generated prediction.json to the system.
+ > - With the **Open Set (public data)**, run steps 1, 2, and 3 in order (i.e. `bash run.sh --step 123`) to complete the full evaluation.
+
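+ A minimal sketch of how `run.sh` might dispatch on `--step` (illustrative only; the real script's internals may differ):
+
+ ```bash
+ #!/bin/bash
+ # Default: run all three steps.
+ STEP="123"
+ if [ "$1" == "--step" ] && [ -n "$2" ]; then
+   STEP="$2"
+ fi
+
+ # Each digit enables one stage; the real script would call the predictors/scorers here.
+ [[ "$STEP" == *1* ]] && echo "Step 1: generate prediction.json for each dataset"
+ [[ "$STEP" == *2* ]] && echo "Step 2: compute per-task scores into outcome/{model_name}_result.xlsx"
+ [[ "$STEP" == *3* ]] && echo "Step 3: compute the final Level score"
+ ```
+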
+ ## Viewing Results
+
+ - Prediction results (prediction.json) are written to each task's dataset folder, next to annotation.json.
+ - Scoring results (e.g. Qwen2.5-7B-Instruct_result.xlsx) are written to the outcome/ directory; a minimal snippet for inspecting such a sheet follows.
+ - The final Level score is printed directly to the terminal.
+
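+ A quick way to inspect a result sheet (assumes pandas and openpyxl are installed, as recommended above; the file name is just the example from this list):
+
+ ```bash
+ python - <<'PY'
+ import pandas as pd  # reading .xlsx files also needs openpyxl
+ df = pd.read_excel("outcome/Qwen2.5-7B-Instruct_result.xlsx")
+ print(df.head())
+ PY
+ ```
+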
+ ## Directory Layout
+
+ - `General-Bench-Openset/`: public dataset directory
+ - `General-Bench-Closeset/`: private dataset directory
+ - `outcome/`: output results directory
+ - `references/`: reference template directory
+ - `run.sh`: main run script (users are encouraged to use only this script)
+
+ ## FAQ
+
+ - If a dependency is missing, install the corresponding Python package indicated by the error message.
+ - To customize the model or data paths, edit the relevant variables in the `run.sh` script.
+
+ ---
+
+ For further help, contact the system maintainers or consult the detailed developer documentation.
+ >>>>>>> 6f59817 (submit NLP Video Audio)
README_Evaluate.md ADDED
@@ -0,0 +1,77 @@
+ # GenBench Scoring System - User Guide
+
+ This system evaluates large models on the General-Bench multimodal task suite; it covers prediction, per-task scoring, and final score computation.
+
+ ## Environment Setup
+
+ - Python 3.9 or later
+ - Install the common dependencies in advance (e.g. pandas, numpy, openpyxl)
+ - For Video Generation evaluation, install the dependencies following the steps in video_generation_evaluation/README.md
+ - For Video Comprehension evaluation, install the dependencies following the README.md of [sa2va](https://github.com/magic-research/Sa2VA)
+
+ ## Dataset Download
+
+ - **Open Set (public data)**: download all data from [HuggingFace General-Bench-Openset](https://huggingface.co/datasets/General-Level/General-Bench-Openset), unpack it, and place it in the `General-Bench-Openset/` directory.
+ - **Close Set (private data)**: download all data from [HuggingFace General-Bench-Closeset](https://huggingface.co/datasets/General-Level/General-Bench-Closeset), unpack it, and place it in the `General-Bench-Closeset/` directory.
+
+ ## One-Click Run
+
+ Simply run the main script `run.sh` to complete the whole pipeline:
+
+ ```bash
+ bash run.sh
+ ```
+
+ This command performs, in order:
+ 1. Generate the prediction results for each modality
+ 2. Compute the score for each task
+ 3. Compute the final Level score
+
+ ## Step-by-Step Run (Optional)
+
+ If you only need to run part of the pipeline, use the `--step` flag:
+
+ - Run only step 1 (generate predictions):
+ ```bash
+ bash run.sh --step 1
+ ```
+ - Run only steps 1 and 2:
+ ```bash
+ bash run.sh --step 12
+ ```
+ - Run only steps 2 and 3:
+ ```bash
+ bash run.sh --step 23
+ ```
+ - With no argument, all steps are executed (equivalent to `--step 123`)
+
+ - Step 1: generates prediction.json, stored in the same directory as each dataset's annotation.json
+ - Step 2: computes the score for each task, stored in outcome/{model_name}_result.xlsx
+ - Step 3: computes the Level score of the evaluated models
+
+ > **Note:**
+ > - With the **Close Set (private data)**, run only step 1 (i.e. `bash run.sh --step 1`) and submit the generated prediction.json to the system.
+ > - With the **Open Set (public data)**, run steps 1, 2, and 3 in order (i.e. `bash run.sh --step 123`) to complete the full evaluation.
+
+ ## Viewing Results
+
+ - Prediction results (prediction.json) are written to each task's dataset folder, next to annotation.json.
+ - Scoring results (e.g. Qwen2.5-7B-Instruct_result.xlsx) are written to the outcome/ directory.
+ - The final Level score is printed directly to the terminal.
+
+ ## Directory Layout
+
+ - `General-Bench-Openset/`: public dataset directory
+ - `General-Bench-Closeset/`: private dataset directory
+ - `outcome/`: output results directory
+ - `references/`: reference template directory
+ - `run.sh`: main run script (users are encouraged to use only this script)
+
+ ## FAQ
+
+ - If a dependency is missing, install the corresponding Python package indicated by the error message.
+ - To customize the model or data paths, edit the relevant variables in the `run.sh` script.
+
+ ---
+
+ For further help, contact the system maintainers or consult the detailed developer documentation.
outcome/Qwen2.5-7B-Instruct_result.xlsx ADDED
Binary file (34.1 kB).
 
outcome/emu2-32b_result.xlsx ADDED
Binary file (43.8 kB).
 
outcome/test_result.xlsx ADDED
Binary file (34 kB).
 
predictors/audio_predict_comprehension.py ADDED
@@ -0,0 +1,1252 @@
+ import json
+ import os
+ import tqdm
+ from typing import List, Dict, Any
+ import nltk
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod
+ from rouge_score import rouge_scorer
+ import math
+ import time
+ from urllib.request import urlopen
+ import librosa
+ from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ def read_json(file_path: str) -> Dict[str, Any]:
23
+ with open(file_path, "r") as f:
24
+ data = json.load(f)
25
+ return data
26
+
27
+
28
+ def exact_match_accuracy(predictions: List[str], references: List[str]) -> float:
+     correct = 0
+     for pred, ref in zip(predictions, references):
+         if isinstance(ref, (str, int)):
+             ref = [ref]
+         is_match_this_turn = False
+         for r in ref:
+             # compare as strings so integer references do not break str.strip()
+             if str(pred).strip() == str(r).strip():
+                 is_match_this_turn = True
+         if is_match_this_turn:
+             correct += 1
+     return correct / len(predictions) if predictions else 0.0
42
+
43
+
44
+ def blur_match_accuracy(predictions: List[str], references: List[str]) -> float:
45
+ correct = 0
46
+ for pred, ref in zip(predictions, references):
47
+ # if isinstance(ref, int):
48
+ # if == ref:
49
+ if str(ref) in str(pred).strip().lower():
50
+ correct += 1
51
+ return correct / len(predictions) if predictions else 0.0
52
+
53
+
54
+ def calculate_f1(predictions: List[str], references: List[str]) -> float:
55
+ def compute_f1(pred: str, ref: str) -> float:
56
+ pred_tokens = pred.strip().split()
57
+ ref_tokens = ref.strip().split()
58
+
59
+ common_tokens = set(pred_tokens) & set(ref_tokens)
60
+ num_common = len(common_tokens)
61
+
62
+ if num_common == 0:
63
+ return 0.0
64
+
65
+ precision = num_common / len(pred_tokens)
66
+ recall = num_common / len(ref_tokens)
67
+
68
+ return 2 * precision * recall / (precision + recall)
69
+
70
+ total_f1 = 0.0
71
+ for pred, ref in zip(predictions, references):
72
+ if isinstance(ref, str):
73
+ ref = [ref]
74
+ max_f1 = 0.0
75
+ for r in ref:
76
+ max_f1 = max(compute_f1(pred, r), max_f1)
77
+ total_f1 += max_f1
78
+
79
+ return total_f1 / len(predictions) if predictions else 0.0
80
+
81
+
82
+ def rouge_evaluation(predictions: List[str], references: List[str]) -> Dict[str, float]:
83
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
84
+ rouge1_scores, rouge2_scores, rougel_scores = [], [], []
85
+ for pred, ref in zip(predictions, references):
86
+ if isinstance(ref, str):
87
+ ref = [ref]
88
+ rouge1, rouge2, rougeL = 0, 0, 0
89
+ for r in ref:
90
+ scores = scorer.score(r, pred)
91
+ rouge1 = max(scores['rouge1'].fmeasure, rouge1)
92
+ rouge2 = max(scores['rouge2'].fmeasure, rouge2)
93
+ rougeL = max(scores['rougeL'].fmeasure, rougeL)
94
+ rouge1_scores.append(rouge1)
95
+ rouge2_scores.append(rouge2)
96
+ rougel_scores.append(rougeL)
97
+ return {
98
+ 'rouge1': sum(rouge1_scores) / len(rouge1_scores),
99
+ 'rouge2': sum(rouge2_scores) / len(rouge2_scores),
100
+ 'rougeL': sum(rougel_scores) / len(rougel_scores),
101
+ }
102
+
103
+
104
+ def bleu_evaluation(predictions: List[str], references: List[str]) -> Dict[str, float]:
105
+ smoothie = SmoothingFunction().method4
106
+ bleu1_scores, bleu2_scores, bleu3_scores, bleu4_scores = [], [], [], []
107
+
108
+ for pred, ref in zip(predictions, references):
109
+ hypothesis = nltk.word_tokenize(pred)
110
+ if isinstance(ref, str):
111
+ ref = [ref]
112
+ bleu1, bleu2, bleu3, bleu4 = 0, 0, 0, 0
113
+ for r in ref:
114
+ reference = [nltk.word_tokenize(r)]
115
+ bleu1 = max(sentence_bleu(reference, hypothesis, weights=(1, 0, 0, 0), smoothing_function=smoothie), bleu1)
116
+ bleu2 = max(sentence_bleu(reference, hypothesis, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothie), bleu2)
117
+ bleu3 = max(sentence_bleu(reference, hypothesis, weights=(1/3, 1/3, 1/3, 0), smoothing_function=smoothie), bleu3)
118
+ bleu4 = max(sentence_bleu(reference, hypothesis, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie), bleu4)
119
+
120
+ bleu1_scores.append(bleu1)
121
+ bleu2_scores.append(bleu2)
122
+ bleu3_scores.append(bleu3)
123
+ bleu4_scores.append(bleu4)
124
+
125
+ return {
126
+ 'bleu1': sum(bleu1_scores) / len(bleu1_scores) if bleu1_scores else 0.0,
127
+ 'bleu2': sum(bleu2_scores) / len(bleu2_scores) if bleu2_scores else 0.0,
128
+ 'bleu3': sum(bleu3_scores) / len(bleu3_scores) if bleu3_scores else 0.0,
129
+ 'bleu4': sum(bleu4_scores) / len(bleu4_scores) if bleu4_scores else 0.0,
130
+ }
131
+
132
+
133
+ def mean_absolute_error(predictions: List[float], references: List[float]) -> float:
134
+ if not predictions:
135
+ return 0.0
136
+ error_sum = 0.0
137
+ for p, r in zip(predictions, references):
138
+ error_sum += abs(p - r)
139
+ return error_sum / len(predictions)
140
+
141
+
142
+ def mean_squared_error(predictions: List[float], references: List[float]) -> float:
143
+ if not predictions:
144
+ return 0.0
145
+ error_sum = 0.0
146
+ for p, r in zip(predictions, references):
147
+ error_sum += (p - r) ** 2
148
+ return error_sum / len(predictions)
149
+
150
+
151
+ def root_mean_squared_error(predictions: List[float], references: List[float]) -> float:
152
+ return math.sqrt(mean_squared_error(predictions, references))
153
+
154
+
155
+ def post_process_output(output: List[Dict[str, Any]]) -> float:
+     # `output` is a list of {'gt': ..., 'response': ...} records
+     cnt = 0
+     for d in output:
+         if d['gt'] in d['response'].strip().lower():
+             cnt += 1
+     acc = round(cnt / len(output), 4)
+     print(f"Accuracy: {acc}")
+     return acc
163
+
164
+
165
+ def evaluation_accuracy(predictions: List[str]) -> Dict[str, float]:
166
+ correct = 0
167
+ for pred in predictions:
168
+ if pred == '1':
169
+ correct += 1
170
+ return correct / len(predictions) if predictions else 0.0
171
+
172
+
173
+ class AudioComprehensionModel:
174
+ def __init__(self, model_name: str):
175
+ self.model_name = model_name
176
+ self.load_model()
177
+
178
+ def load_model(self):
179
+ if 'qwen-audio-chat' in self.model_name.lower():
180
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map='cuda', trust_remote_code=True).eval()
181
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
182
+ self.tokenizer.padding_side = 'left'
183
+ self.tokenizer.pad_token_id = self.tokenizer.eod_id
184
+ elif 'qwen2' in self.model_name.lower():
185
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
186
+ print(self.processor.chat_template)
187
+ self.model = Qwen2AudioForConditionalGeneration.from_pretrained(self.model_name, device_map="auto").eval()
188
+
189
+ elif 'new_model_name' in self.model_name.lower():
190
+ # support to load self-build models here
191
+ pass
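+ # Illustrative sketch only: a self-built HF-style model could be loaded here, e.g.
+ #   self.processor = AutoProcessor.from_pretrained(self.model_name)
+ #   self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto", trust_remote_code=True).eval()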
192
+
193
+ else:
194
+ raise ValueError(f"Unsupported model name: {self.model_name}")
195
+
196
+ def generate(self, prompt: str, max_new_tokens=256, audio_path: str=None) -> str:
197
+
198
+ if "qwen-audio-chat" in self.model_name.lower():
199
+ query = self.tokenizer.from_list_format([
200
+ {'audio': audio_path}, # Either a local path or an url
201
+ {'text': prompt} # The query,
202
+ ])
203
+ response, history = self.model.chat(self.tokenizer, query=query, history=None)
204
+ return response
205
+
206
+ elif "qwen2" in self.model_name.lower():
207
+ conversation = [
208
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
209
+ {"role": "user", "content": [
210
+ {"type": "audio", "audio": audio_path},
211
+ {"type": "text", "text": prompt},
212
+ ]},
213
+ ]
214
+ text = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
215
+ audios = []
216
+ for message in conversation:
217
+ if isinstance(message["content"], list):
218
+ for ele in message["content"]:
219
+ if ele["type"] == "audio":
220
+ audios.append(
221
+ librosa.load(
222
+ ele['audio'],
223
+ sr=self.processor.feature_extractor.sampling_rate)[0]
224
+ )
225
+ # print(text)
226
+ inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True)
227
+ inputs.input_ids = inputs.input_ids.to("cuda")
228
+ inputs = inputs.to("cuda")
229
+ # print(inputs)
230
+ # exit(0)
231
+ generate_ids = self.model.generate(**inputs, max_length=300)
232
+ generate_ids = generate_ids[:, inputs.input_ids.size(1):]
233
+
234
+ response = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
235
+ return response
236
+
237
+ elif "new" in self.model_name.lower():
238
+ # support to generate response based on self-build models here
239
+ pass
240
+
241
+ else:
242
+ raise ValueError(f"Unsupported model name: {self.model_name}")
243
+
244
+
245
+
246
+ @dataclass
247
+ class Instance:
248
+ input: Dict[str, Any]
249
+ output: Dict[str, Any]
250
+ id: str
251
+
252
+
253
+ class BaseTask(ABC):
254
+ def __init__(self, task_data: Dict[str, Any], model: AudioComprehensionModel, audio_dir: str = None, output_dir: str = None, task_name: str = None):
255
+ self.task_data = read_json(task_data)
256
+ self.model = model
257
+ self.audio_dir = audio_dir # should include the audios files
258
+ self.data = self._parse_data(self.task_data)
259
+ self.choice_candidate = self._get_choice_candidate(self.task_data)
260
+ self.task_name = os.path.dirname(task_data).split("/")[-1] if task_name is None else task_name
261
+ self.output_dir = output_dir
262
+ os.makedirs(self.output_dir, exist_ok=True) if self.output_dir else None
263
+
264
+ self.references = []
265
+ self.predictions = []
266
+
267
+ def save_predictions(self, audio_paths):
268
+ results = []
269
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
270
+ results.append({
271
+ 'gt': gt,
272
+ 'response': response,
273
+ 'audio_path': audio_path,
274
+ })
275
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
276
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
277
+ json.dump(results, open(results_file, 'w'))
278
+
279
+ @abstractmethod
280
+ def _get_choice_candidate(self):
281
+ pass
282
+
283
+ @abstractmethod
284
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
285
+ pass
286
+
287
+ @abstractmethod
288
+ def evaluate(self) -> Dict[str, float]:
289
+ pass
290
+
291
+ @abstractmethod
292
+ def run_inference(self):
293
+ pass
294
+
295
+
296
+ class EvaluationTask(BaseTask):
297
+ """
298
+ Used to determine whether the results generated by the model are correct
299
+ """
300
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
301
+ return task_data
302
+
303
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
304
+ return ["None"]
305
+
306
+ def save_predictions(self, audio_paths):
307
+ results = []
308
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
309
+ results.append({
310
+ 'gt': gt[0],
311
+ 'response': gt[1],
312
+ 'audio_path': audio_path,
313
+ 'llm_prediction': response,
314
+ })
315
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
316
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
317
+ json.dump(results, open(results_file, 'w'))
318
+
319
+ def run_inference(self):
320
+ audio_paths = []
321
+ for inst in tqdm.tqdm(self.data):
322
+ prompt = "I will provide you with a Ground-truth label and a Prediction label. The label can either be a single string or a list of multiple labels. I need you to compare these two labels on a semantic level.\nSpecifically, I want you to evaluate whether the Prediction label semantically matches, is partially aligned, includes, or describes the Ground-truth label (or the semantic meaning represented by the list of labels). If any of these conditions are satisfied, consider it a match.\n\nHere are some examples of successful matches:\n\nGround-truth label: \"rain\"\nPrediction label: \"The sound in the audio is rain falling\"\n(This is considered a match.)\nGround-truth label: [\"decrease\", \"volume\", \"none\"]\nPrediction label: \"The intent in the audio is to adjust the volume\" (This is also considered a match.)\nIf the labels successfully match, assign a score of 1. If they do not match, assign a score of 0. **Important!!!, only output the score (0 or 1), no explanation.** \n\nGround-truth label:{}\nPrediction label:{}"
323
+ gt = inst["gt"]
324
+ response = inst["response"]
325
+ prompt = prompt.format(gt, response)
326
+ try:
327
+ response = self.model.generate(prompt)
328
+ # print(response)
329
+ except Exception as e:
330
+ response = "None"
331
+ continue
332
+
333
+ self.predictions.append(response)
334
+ self.references.append([inst["gt"], inst["response"]])
335
+ audio_paths.append(inst["audio_path"])
336
+ self.save_predictions(audio_paths)
337
+
338
+ def evaluate(self) -> Dict[str, float]:
339
+ acc = evaluation_accuracy(self.predictions)
340
+ return {"accuracy": acc}
341
+
342
+
343
+ class AccentSexClassification(BaseTask):
344
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
345
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
346
+ for d in task_data["data"]]
347
+
348
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
349
+ return ['female', 'male']
350
+
351
+ def save_predictions(self, audio_paths):
352
+ results = []
353
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
354
+ results.append({
355
+ 'gt': gt,
356
+ 'response': response,
357
+ 'audio_path': audio_path,
358
+ })
359
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
360
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
361
+ json.dump(results, open(results_file, 'w'))
362
+
363
+ def run_inference(self):
364
+ self.predictions = []
365
+ self.references = []
366
+ audio_paths = []
367
+ for inst in tqdm.tqdm(self.data):
368
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
369
+ question = inst.input["prompt"]
370
+ prompt = f"Please listen to the audio and then answer the question by directly choose a choice from choice candidates. Questions: {question}, Candidate choices: {self.choice_candidate}\nAnswer:"
371
+ try:
372
+ response = self.model.generate(prompt, audio_path=audio_path)
373
+ except:
374
+ print("error audio {}".format(inst.input["audio_file"]))
375
+ continue
376
+ self.predictions.append(response)
377
+ self.references.append(inst.output["text"])
378
+ audio_paths.append(inst.input["audio_file"])
379
+
380
+ self.save_predictions(audio_paths)
381
+
382
+
383
+ def evaluate(self) -> Dict[str, float]:
384
+ acc = exact_match_accuracy(self.predictions, self.references)
385
+ return {"accuracy": acc}
386
+
387
+
388
+ class AcousticSceneClassification(BaseTask):
389
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
390
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
391
+ for d in task_data["data"]]
392
+
393
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
394
+ choices = []
395
+ for item in data['data']:
396
+ choices.append(item['output']["text"].strip().lower())
397
+ choices = list(set(choices))
398
+ return choices
399
+
400
+ def run_inference(self):
401
+ print(f"Choice candidates: {self.choice_candidate}")
402
+ audio_paths = []
403
+ for inst in tqdm.tqdm(self.data):
404
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
405
+ question = inst.input["prompt"]
406
+ prompt = f"Please listen to the input music and then determine the category of the acoustic scene. The candidate scene category are {self.choice_candidate}. Please output **only one category** from the provided candidate categories, and **DO NOT** output any other words.\nQuestions: {question}\nAnswer:"
407
+ try:
408
+ response = self.model.generate(prompt, audio_path=audio_path)
409
+ except Exception as e:
410
+ print("Error audio: {}".format(inst.input["audio_file"]))
411
+ response = "None"
412
+ continue
413
+ self.predictions.append(response)
414
+ self.references.append(inst.output["text"].strip().lower())
415
+ audio_paths.append(inst.input["audio_file"])
416
+ self.save_predictions(audio_paths)
417
+
418
+ def evaluate(self) -> Dict[str, float]:
419
+ acc = exact_match_accuracy(self.predictions, self.references)
420
+ return {"accuracy": acc}
421
+
422
+
423
+ class AnimalSoundDetection(BaseTask):
424
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
425
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
426
+ for d in task_data["data"]]
427
+
428
+ def _get_choice_candidate(self, data) -> List[str]:
429
+ choices = []
430
+ for item in data['data']:
431
+ choices.append(item['output']["text"].strip().lower())
432
+ choices = list(set(choices))
433
+ return choices
434
+
435
+ def run_inference(self):
436
+ print(f"Choice candidates: {self.choice_candidate}")
437
+ audio_paths = []
438
+ for inst in tqdm.tqdm(self.data):
439
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
440
+ question = inst.input["prompt"]
441
+ prompt = f"Please listen to the audio and then answer the question by directly choose a choice from choice candidates, without other words. Questions: {question}, Candidate choices: {self.choice_candidate}\nAnswer:"
442
+ try:
443
+ response = self.model.generate(prompt, audio_path=audio_path)
444
+ except Exception as e:
445
+ print("Error audio: {}".format(inst.input["audio_file"]))
446
+ response = "None"
447
+ continue
448
+ self.predictions.append(response)
449
+ self.references.append(inst.output["text"].strip().lower())
450
+ audio_paths.append(inst.input["audio_file"])
451
+ self.save_predictions(audio_paths)
452
+
453
+ def evaluate(self) -> Dict[str, float]:
454
+ acc = exact_match_accuracy(self.predictions, self.references)
455
+ return {"accuracy": acc}
456
+
457
+
458
+ class AudioCaptions(BaseTask):
459
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
460
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
461
+ for d in task_data["data"]]
462
+
463
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
464
+ return ["None"]
465
+
466
+ def run_inference(self):
467
+ audio_paths = []
468
+ for inst in tqdm.tqdm(self.data):
469
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
470
+ question = inst.input["prompt"]
471
+ prompt = f"Please listen to the audio and then answer the question. Questions: {question}\nAnswer:"
472
+ try:
473
+ response = self.model.generate(prompt, audio_path=audio_path)
474
+ except Exception as e:
475
+ print("Error audio: {}".format(inst.input["audio_file"]))
476
+ response = "None"
477
+ continue
478
+ self.predictions.append(response)
479
+ self.references.append(inst.output["text"])
480
+ audio_paths.append(inst.input["audio_file"])
481
+ self.save_predictions(audio_paths)
482
+
483
+ def evaluate(self) -> Dict[str, float]:
484
+ bleu = bleu_evaluation(self.predictions, self.references)
485
+ return {"bleu1": bleu['bleu1']}
486
+
487
+
488
+ class AudioCaptionsClotho(BaseTask):
489
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
490
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
491
+ for d in task_data["data"]]
492
+
493
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
494
+ return ["None"]
495
+
496
+ def run_inference(self):
497
+ audio_paths = []
498
+ for inst in tqdm.tqdm(self.data):
499
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
500
+ question = inst.input["prompt"]
501
+ prompt = f"Please listen to the audio and then answer the question. Questions: {question}\nAnswer:"
502
+ try:
503
+ response = self.model.generate(prompt, audio_path=audio_path)
504
+ except Exception as e:
505
+ print("Error audio: {}".format(inst.input["audio_file"]))
506
+ response = "None"
507
+ continue
508
+ self.predictions.append(response)
509
+ self.references.append(inst.output["text"])
510
+ audio_paths.append(inst.input["audio_file"])
511
+ self.save_predictions(audio_paths)
512
+
513
+     def evaluate(self) -> Dict[str, float]:
+         # bleu_evaluation returns a dict of BLEU scores rather than a single accuracy value
+         bleu = bleu_evaluation(self.predictions, self.references)
+         return {"bleu1": bleu['bleu1']}
516
+
517
+
518
+ class AudioQA(BaseTask):
519
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
520
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
521
+ for d in task_data["data"]]
522
+
523
+ def _get_choice_candidate(self, data) -> List[str]:
524
+ choices = []
525
+ for item in data['data']:
526
+ choices.append(item['output']["text"].strip().lower())
527
+ choices = list(set(choices))
528
+ return choices
529
+
530
+ def run_inference(self):
531
+ audio_paths = []
532
+ for inst in tqdm.tqdm(self.data):
533
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
534
+ question = inst.input["prompt"]
535
+ prompt = f"Please listen to the audio and then answer the question by directly choose a choice from choice candidates. Questions: {question}, Candidate choices: {self.choice_candidate}\nAnswer:"
536
+ try:
537
+ response = self.model.generate(prompt, audio_path=audio_path)
538
+ except Exception as e:
539
+ print("Error audio: {}".format(inst.input["audio_file"]))
540
+ response = "None"
541
+ continue
542
+ self.predictions.append(response)
543
+ self.references.append(inst.output["text"])
544
+ audio_paths.append(inst.input["audio_file"])
545
+ self.save_predictions(audio_paths)
546
+
547
+ def evaluate(self) -> Dict[str, float]:
548
+ acc = exact_match_accuracy(self.predictions, self.references)
549
+ return {"accuracy": acc}
550
+
551
+
552
+ class BirdSoundDetection(BaseTask):
553
+
554
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
555
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
556
+ for d in task_data["data"]]
557
+
558
+ def _get_choice_candidate(self, data: List[Instance]) -> List[str]:
559
+ return ["Yes", "No"]
560
+
561
+ def save_predictions(self, audio_paths):
562
+ results = []
563
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
564
+ results.append({
565
+ 'gt': gt,
566
+ 'response': response,
567
+ 'audio_path': audio_path,
568
+ })
569
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
570
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
571
+ json.dump(results, open(results_file, 'w'))
572
+
573
+ def run_inference(self):
574
+ self.predictions = []
575
+ self.references = []
576
+ audio_paths = []
577
+ for inst in tqdm.tqdm(self.data):
578
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
579
+ question = inst.input["prompt"]
580
+ prompt = f"Please listen to the audio and then answer the question by directly choose a choice from choice candidates. Questions: {question}, Candidate choices: {self.choice_candidate}\nAnswer:"
581
+ try:
582
+ response = self.model.generate(prompt, audio_path=audio_path)
583
+ except Exception as e:
584
+ print("Error audio: {}".format(inst.input["audio_file"]))
585
+ response = "None"
586
+ continue
587
+ self.predictions.append(response)
588
+ self.references.append("Yes" if inst.output["text"] == 1 else "No")
589
+ audio_paths.append(inst.input["audio_file"])
590
+ self.save_predictions(audio_paths)
591
+
592
+ def evaluate(self) -> Dict[str, float]:
593
+ acc = exact_match_accuracy(self.predictions, self.references)
594
+ return {"accuracy": acc}
595
+
596
+
597
+ class EnvironmentSoundRecognition(BaseTask):
598
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
599
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
600
+ for d in task_data["data"]]
601
+
602
+ def _get_choice_candidate(self, data) -> List[str]:
603
+ choices = []
604
+ for item in data['data']:
605
+ choices.append(item['output']["text"].strip().lower())
606
+ choices = list(set(choices))
607
+ return choices
608
+
609
+ def run_inference(self):
610
+ audio_paths = []
611
+ for inst in tqdm.tqdm(self.data):
612
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
613
+ question = inst.input["prompt"]
614
+ prompt = f"Please listen to the audio and then answer the question by directly choose a choice from choice candidates. Questions: {question}, Candidate choices: {self.choice_candidate}\nAnswer:"
615
+ try:
616
+ response = self.model.generate(prompt, audio_path=audio_path)
617
+ except Exception as e:
618
+ print(f"error {e}")
619
+ print("Error audio: {}".format(inst.input["audio_file"]))
620
+ response = "None"
621
+ continue
622
+ self.predictions.append(response)
623
+ self.references.append(inst.output["text"])
624
+ audio_paths.append(inst.input["audio_file"])
625
+ self.save_predictions(audio_paths)
626
+
627
+ def evaluate(self) -> Dict[str, float]:
628
+ acc = blur_match_accuracy(self.predictions, self.references)
629
+ return {"accuracy": acc}
630
+
631
+
632
+ class IntentClassification(BaseTask):
633
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
634
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
635
+ for d in task_data["data"]]
636
+
637
+ def _get_choice_candidate(self, data: Dict) -> Dict:
638
+ intent_label = data['intent_label']
639
+ return intent_label
640
+
641
+ def run_inference(self):
642
+ audio_paths = []
643
+ candidate_actions = ','.join([k for k in self.choice_candidate['action'].keys() if not k[0].isdigit()])
644
+ candidate_objects = ','.join([k for k in self.choice_candidate['object'].keys() if not k[0].isdigit()])
645
+ candidate_locations = ','.join([k for k in self.choice_candidate['location'].keys() if not k[0].isdigit()])
646
+ for inst in tqdm.tqdm(self.data):
647
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
648
+ question = inst.input["prompt"]
649
+ prompt = f"Please listen to the audio and then detect the intention. The intention triplet includes three parts: action, object, and location. The candidate actions are {candidate_actions}, candidate objects are {candidate_objects}, and candidate locations are {candidate_locations}. Please answer the question using only the provided candidate actions, objects, and locations. Questions: {question}\nAnswer:"
650
+ try:
651
+ response = self.model.generate(prompt, audio_path=audio_path)
652
+ except Exception as e:
653
+ print("Error audio: {}".format(inst.input["audio_file"]))
654
+ response = "None"
655
+ continue
656
+ self.predictions.append(response)
657
+ self.references.append(' '.join([self.choice_candidate['action'][inst.output["text"].split()[0]], self.choice_candidate['object'][inst.output["text"].split()[1]], self.choice_candidate['location'][inst.output["text"].split()[2]]]))
658
+ audio_paths.append(inst.input["audio_file"])
659
+ self.save_predictions(audio_paths)
660
+
661
+ def evaluate(self) -> Dict[str, float]:
662
+ acc = exact_match_accuracy(self.predictions, self.references)
663
+ return {"accuracy": acc}
664
+
665
+
666
+ def post_process_intent_output():
667
+ data_path = '/m2v_intern/wushengqiong/model/audio-test/predictions/understanding/IntentClassification_250102204424.json'
668
+ intent_label = read_json('/m2v_intern/wushengqiong/model/audio-test/understanding/IntentClassification/annotation.json')['intent_label']
669
+ action = intent_label['action']
670
+ object = intent_label['object']
671
+ location = intent_label['location']
672
+
673
+ data = read_json(data_path)
674
+
675
+ results = []
676
+ for d in data:
677
+ results.append({
678
+ 'gt': [action[d['gt'].split()[0]], object[d['gt'].split()[1]], location[d['gt'].split()[2]]],
679
+ 'response': d['response'],
680
+ 'audio_path': d['audio_path'],
681
+ })
682
+ json.dump(results, open('/m2v_intern/wushengqiong/model/audio-test/predictions/understanding/IntentClassification_250102204424_1.json', 'w'))
683
+
684
+
685
+ class MusicGenreClassification(BaseTask):
686
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
687
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
688
+ for d in task_data["data"]]
689
+
690
+ def _get_choice_candidate(self, data: Dict) -> Dict:
691
+ choices = []
692
+ for item in data['data']:
693
+ choices.append(item['output']["text"].strip().lower())
694
+ choices = list(set(choices))
695
+ return choices
696
+
697
+
698
+ def run_inference(self):
699
+ audio_paths = []
700
+ for inst in tqdm.tqdm(self.data):
701
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"].replace('\\', '/'))
702
+ question = inst.input["prompt"]
703
+ prompt = f"Please listen to the input music and then determine the genre of the music. The candidate genres are {self.choice_candidate}. Please output **only one genre** from the provided candidate genres, and **DO NOT** output any other words.\nQuestions: {question}\nAnswer:"
704
+ try:
705
+ response = self.model.generate(prompt, audio_path=audio_path)
706
+ except Exception as e:
707
+ print("Error audio: {}".format(inst.input["audio_file"]))
708
+ response = "None"
709
+ continue
710
+ self.predictions.append(response)
711
+ self.references.append(inst.output["text"])
712
+ audio_paths.append(inst.input["audio_file"])
713
+ self.save_predictions(audio_paths)
714
+
715
+ def evaluate(self) -> Dict[str, float]:
716
+ acc = exact_match_accuracy(self.predictions, self.references)
717
+ return {"accuracy": acc}
718
+
719
+
720
+ class MusicInstrumentClassification(BaseTask):
721
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
722
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
723
+ for d in task_data["data"]]
724
+
725
+ def _get_choice_candidate(self, data: Dict) -> Dict:
726
+ choices = []
727
+ for item in data['data']:
728
+ choices.append(item['output']["text"].strip().lower())
729
+ choices = list(set(choices))
730
+ return choices
731
+
732
+ def run_inference(self):
733
+ audio_paths = []
734
+ # candidate_instruments = ','.join([k for k in self.choice_candidate.keys() if not k[0].isdigit()])
735
+ for inst in tqdm.tqdm(self.data):
736
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
737
+ question = inst.input["prompt"]
738
+ prompt = f"Please listen to the music and then detect the instrument of the music. The candidate instruments are {self.choice_candidate}. Please output **only the most appropriate music instrument** from the provided candidate music instruments, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
739
+ try:
740
+ response = self.model.generate(prompt, audio_path=audio_path)
741
+ except Exception as e:
742
+ print("Error audio: {}".format(inst.input["audio_file"]))
743
+ response = "None"
744
+ continue
745
+ self.predictions.append(response)
746
+ self.references.append(inst.output["text"])
747
+ audio_paths.append(inst.input["audio_file"])
748
+ self.save_predictions(audio_paths)
749
+
750
+ def evaluate(self) -> Dict[str, float]:
751
+ acc = exact_match_accuracy(self.predictions, self.references)
752
+ return {"accuracy": acc}
753
+
754
+
755
+ class MusicInstrumentSourceAnalysis(BaseTask):
756
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
757
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
758
+ for d in task_data["data"]]
759
+
760
+ def _get_choice_candidate(self, data: Dict) -> Dict:
761
+ choices = []
762
+ for item in data['data']:
763
+ choices.append(item['output']["text"].strip().lower())
764
+ choices = list(set(choices))
765
+ return choices
766
+
767
+ def run_inference(self):
768
+ audio_paths = []
769
+ for inst in tqdm.tqdm(self.data):
770
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
771
+ question = inst.input["prompt"]
772
+ prompt = f"Please listen to the music and then detect the instrument source of the music. The candidate sources are {self.choice_candidate}. Please output **only the most appropriate music source** from the provided candidate music sources, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
773
+ try:
774
+ response = self.model.generate(prompt, audio_path=audio_path)
775
+ except Exception as e:
776
+ print("Error audio: {}".format(inst.input["audio_file"]))
777
+ response = "None"
778
+ continue
779
+ self.predictions.append(response)
780
+ self.references.append(inst.output["text"])
781
+ audio_paths.append(inst.input["audio_file"].strip().lower())
782
+ self.save_predictions(audio_paths)
783
+
784
+ def evaluate(self) -> Dict[str, float]:
785
+ acc = exact_match_accuracy(self.predictions, self.references)
786
+ return {"accuracy": acc}
787
+
788
+
789
+ class MusicPitchAnalysis(BaseTask):
790
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
791
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
792
+ for d in task_data["data"]]
793
+
794
+ def _get_choice_candidate(self, data: Dict) -> Dict:
795
+ choices = []
796
+ for item in data['data']:
797
+ choices.append(item['output']["text"])
798
+ choices = list(set(choices))
799
+ return choices
800
+
801
+ def run_inference(self):
802
+ audio_paths = []
803
+ for inst in tqdm.tqdm(self.data):
804
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
805
+ question = inst.input["prompt"]
806
+ prompt = f"Please listen to the music and then detect the pitch score of the music. The 0-based MIDI pitch is in the range [0, 127]. Please output **only the most appropriate pitch score in a number** from the provided range, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
807
+ try:
808
+ response = self.model.generate(prompt, audio_path=audio_path)
809
+ except Exception as e:
810
+ print("Error audio: {}".format(inst.input["audio_file"]))
811
+ response = "None"
812
+ continue
813
+ self.predictions.append(response)
814
+ self.references.append(inst.output["text"])
815
+ audio_paths.append(inst.input["audio_file"].strip().lower())
816
+ self.save_predictions(audio_paths)
817
+
818
+ def evaluate(self) -> Dict[str, float]:
819
+ acc = exact_match_accuracy(self.predictions, self.references)
820
+ return {"accuracy": acc}
821
+
822
+
823
+ class NoteQualitiesAnalysis(BaseTask):
824
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
825
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
826
+ for d in task_data["data"]]
827
+
828
+ def _get_choice_candidate(self, data: Dict) -> Dict:
829
+ choices = []
830
+ for item in data['data']:
831
+ choices.append(','.join(item['output']["text"]).strip().lower())
832
+ choices = list(set(choices))
833
+ return choices
834
+
835
+ def run_inference(self):
836
+ audio_paths = []
837
+ for inst in tqdm.tqdm(self.data):
838
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
839
+ question = inst.input["prompt"]
840
+ prompt = f"Please listen to the music and then detect the note quality of the given music. The candidate annotation is {self.choice_candidate}. Please output **the qualities which are present in this note** from the provided candidate music note quality candidate categories, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
841
+ try:
842
+ response = self.model.generate(prompt, audio_path=audio_path)
843
+ except Exception as e:
844
+ print("Error audio: {}".format(inst.input["audio_file"]))
845
+ response = "None"
846
+ continue
847
+ self.predictions.append(response)
848
+ self.references.append(','.join(inst.output["text"]))
849
+ audio_paths.append(inst.input["audio_file"].strip().lower())
850
+ self.save_predictions(audio_paths)
851
+
852
+ def evaluate(self) -> Dict[str, float]:
853
+ acc = exact_match_accuracy(self.predictions, self.references)
854
+ return {"accuracy": acc}
855
+
856
+
857
+ class OpenAQA(BaseTask):
858
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
859
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
860
+ for d in task_data["data"]]
861
+
862
+ def _get_choice_candidate(self, data: Dict) -> Dict:
863
+ choices = []
864
+ for item in data['data']:
865
+ choices.append(item['output']["text"].strip().lower())
866
+ choices = list(set(choices))
867
+ return choices
868
+
869
+ def run_inference(self):
870
+ audio_paths = []
871
+ for inst in tqdm.tqdm(self.data):
872
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
873
+ question = inst.input["prompt"]
874
+ prompt = f"Please listen to the audio and then answer the question. Questions: {question}\nAnswer:"
875
+ try:
876
+ response = self.model.generate(prompt, audio_path=audio_path)
877
+ except Exception as e:
878
+ print("Error audio: {}".format(inst.input["audio_file"]))
879
+ response = "None"
880
+ continue
881
+ self.predictions.append(response)
882
+ self.references.append(inst.output["text"])
883
+ audio_paths.append(inst.input["audio_file"])
884
+ self.save_predictions(audio_paths)
885
+
886
+     def evaluate(self) -> Dict[str, float]:
+         # bleu_evaluation returns a dict of BLEU scores rather than a single accuracy value
+         bleu = bleu_evaluation(self.predictions, self.references)
+         return {"bleu1": bleu['bleu1']}
889
+
890
+
891
+ class SoundEventClassification(BaseTask):
892
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
893
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
894
+ for d in task_data["data"]]
895
+
896
+ def _get_choice_candidate(self, data: Dict) -> Dict:
897
+ choices = []
898
+ for item in data['data']:
899
+ choices.append(item['output']["text"].strip().lower())
900
+ choices = list(set(choices))
901
+ return choices
902
+
903
+ def run_inference(self):
904
+ audio_paths = []
905
+ for inst in tqdm.tqdm(self.data):
906
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
907
+ question = inst.input["prompt"]
908
+ prompt = f"Please listen to the audio and then detect the event happening in the given audio. The candidate annotation is {self.choice_candidate}. Please output **only one event** from the provided candidate events, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
909
+ try:
910
+ response = self.model.generate(prompt, audio_path=audio_path)
911
+ except Exception as e:
912
+ print("Error audio: {}".format(inst.input["audio_file"]))
913
+ response = "None"
914
+ continue
915
+ self.predictions.append(response)
916
+ self.references.append(inst.output["text"])
917
+ audio_paths.append(inst.input["audio_file"])
918
+ self.save_predictions(audio_paths)
919
+
920
+ def evaluate(self) -> Dict[str, float]:
921
+ acc = exact_match_accuracy(self.predictions, self.references)
922
+ return {"accuracy": acc}
923
+
924
+
925
+ class SpeechCommand(BaseTask):
926
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
927
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
928
+ for d in task_data["data"]]
929
+
930
+ def _get_choice_candidate(self, data: Dict) -> Dict:
931
+ choices = []
932
+ for item in data['data']:
933
+ choices.append(item['output']["text"].strip().lower())
934
+ choices = list(set(choices))
935
+ return choices
936
+
937
+ def run_inference(self):
938
+ audio_paths = []
939
+ for inst in tqdm.tqdm(self.data):
940
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"].replace('\\', '/'))
941
+ question = inst.input["prompt"]
942
+ prompt = f"Please listen to the audio and then detect the speech command of the given audio. The candidate annotation is {self.choice_candidate}. Please output **only one command** from the provided candidate commands, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
943
+ try:
944
+ response = self.model.generate(prompt, audio_path=audio_path)
945
+ except Exception as e:
946
+ print("Error audio: {}".format(inst.input["audio_file"]))
947
+ response = "None"
948
+ continue
949
+ self.predictions.append(response)
950
+ self.references.append(inst.output["text"].strip().lower())
951
+ audio_paths.append(inst.input["audio_file"])
952
+ self.save_predictions(audio_paths)
953
+
954
+ def evaluate(self) -> Dict[str, float]:
955
+ acc = exact_match_accuracy(self.predictions, self.references)
956
+ return {"accuracy": acc}
957
+
958
+
959
+ class SpeechEmotionRecognition(BaseTask):
960
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
961
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
962
+ for d in task_data["data"]]
963
+
964
+ def _get_choice_candidate(self, data: Dict) -> Dict:
965
+ choices = []
966
+ for item in data['data']:
967
+ choices.append(item['output']["text"].strip().lower())
968
+ choices = list(set(choices))
969
+ return choices
970
+
971
+ def run_inference(self):
972
+ audio_paths = []
973
+ for inst in tqdm.tqdm(self.data):
974
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
975
+ question = inst.input["prompt"]
976
+ prompt = f"Please listen to the audio and then detect the emotion of the given audio. The candidate annotation is {self.choice_candidate}. Please output **only one emotion** from the provided candidate emotions, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
977
+ try:
978
+ response = self.model.generate(prompt, audio_path=audio_path)
979
+ except Exception as e:
980
+ print("Error audio: {}".format(inst.input["audio_file"]))
981
+ response = "None"
982
+ continue
983
+ self.predictions.append(response)
984
+ self.references.append(inst.output["text"].strip().lower())
985
+ audio_paths.append(inst.input["audio_file"])
986
+ self.save_predictions(audio_paths)
987
+
988
+ def evaluate(self) -> Dict[str, float]:
989
+ acc = exact_match_accuracy(self.predictions, self.references)
990
+ return {"accuracy": acc}
991
+
992
+
993
+ class VocalSoundClassification(BaseTask):
994
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
995
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
996
+ for d in task_data["data"]]
997
+
998
+ def _get_choice_candidate(self, data: Dict) -> Dict:
999
+ choices = []
1000
+ for item in data['data']:
1001
+ choices.append(item['output']["text"].strip().lower())
1002
+ choices = list(set(choices))
1003
+ return choices
1004
+
1005
+ def run_inference(self):
1006
+ audio_paths = []
1007
+ for inst in tqdm.tqdm(self.data):
1008
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
1009
+ question = inst.input["prompt"]
1010
+ prompt = f"Please listen to the audio and then detect the vocal sound category of the given audio. The candidate annotation is {self.choice_candidate}. Please output **only one vocal sound category** from the provided candidate vocal sounds, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
1011
+ try:
1012
+ response = self.model.generate(prompt, audio_path=audio_path)
1013
+ except Exception as e:
1014
+ print("Error audio: {}".format(inst.input["audio_file"]))
1015
+ response = "None"
1016
+ continue
1017
+ self.predictions.append(response)
1018
+ self.references.append(inst.output["text"].strip().lower())
1019
+ audio_paths.append(inst.input["audio_file"])
1020
+ self.save_predictions(audio_paths)
1021
+
1022
+ def evaluate(self) -> Dict[str, float]:
1023
+ acc = exact_match_accuracy(self.predictions, self.references)
1024
+ return {"accuracy": acc}
1025
+
1026
+
1027
+ class VocalTechniqueDetection(BaseTask):
1028
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
1029
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
1030
+ for d in task_data["data"]]
1031
+
1032
+ def _get_choice_candidate(self, data: Dict) -> Dict:
1033
+ choices = []
1034
+ for item in data['data']:
1035
+ choices.append(item['output']["text"].strip().lower())
1036
+ choices = list(set(choices))
1037
+ return choices
1038
+
1039
+ def run_inference(self):
1040
+ audio_paths = []
1041
+ for inst in tqdm.tqdm(self.data):
1042
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"].replace('\\', '/'))
1043
+ question = inst.input["prompt"]
1044
+ prompt = f"Please listen to the audio and then detect the vocal technique of the given audio. The candidate annotations are scales, arpeggios, long tones, and excerpts. Please output **only one vocal technique** from the provided candidate vocal techniques, and **DO NOT** output any other words. Questions: {question}\nAnswer:"
1045
+ try:
1046
+ response = self.model.generate(prompt, audio_path=audio_path)
1047
+ except Exception as e:
1048
+ print("Error audio: {}".format(inst.input["audio_file"]))
1049
+ response = "None"
1050
+ continue
1051
+ self.predictions.append(response)
1052
+ self.references.append(inst.output["text"].strip().lower())
1053
+ audio_paths.append(inst.input["audio_file"])
1054
+ self.save_predictions(audio_paths)
1055
+
1056
+ def evaluate(self) -> Dict[str, float]:
1057
+ acc = exact_match_accuracy(self.predictions, self.references)
1058
+ return {"accuracy": acc}
1059
+
1060
+
1061
+ def log_performance_csv(model_name, task_name, metric, score, root_path, output_file='prediction.json'):
1062
+ import csv
1063
+ file_exists = os.path.isfile(os.path.join(root_path, output_file))
1064
+
1065
+ row_data = {
1066
+ 'model': model_name,
1067
+ 'task': task_name,
1068
+ 'metric': metric,
1069
+ 'score': str(score),
1070
+ }
1071
+
1072
+ with open(os.path.join(root_path, output_file), mode='a', newline='', encoding='utf-8') as f:
1073
+ writer = csv.DictWriter(f, fieldnames=row_data.keys())
1074
+ if not file_exists:
1075
+ writer.writeheader()
1076
+
1077
+ writer.writerow(row_data)
1078
+
1079
+
1080
+ def log_performance_json(model_name, task_name, metric, score, root_path, output_file='prediction.json'):
1081
+ import json
1082
+ log_data = {
1083
+ 'model': model_name,
1084
+ 'task': task_name,
1085
+ 'metric': metric,
1086
+ 'score': str(score),
1087
+ }
1088
+
1089
+ log_file_path = os.path.join(root_path, output_file)
1090
+
1091
+ if os.path.exists(log_file_path):
1092
+ with open(log_file_path, 'r') as f:
1093
+ existing_data = json.load(f)
1094
+ else:
1095
+ existing_data = []
1096
+
1097
+ existing_data.append(log_data)
1098
+
1099
+ with open(log_file_path, 'w', encoding='utf-8') as f:
1100
+ json.dump(existing_data, f, indent=4)
1101
+
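+ # Illustrative example of the resulting JSON log (values are hypothetical):
+ # [
+ #     {"model": "Qwen2-Audio-7B-Instruct", "task": "VocalSoundClassification",
+ #      "metric": "accuracy", "score": "0.87"}
+ # ]
+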
1102
+
1103
+ def log_performance_detail(model_name, task_name, metrics, root_path, output_file='performance_log.csv'):
1104
+ import csv
1105
+ file_path = os.path.join(root_path, output_file)
1106
+ file_exists = os.path.isfile(file_path)
1107
+
1108
+ # Retrieve the main indicator values from the metrics dictionary
1109
+ metric_value = None
1110
+ if isinstance(metrics, dict):
1111
+ # Select metrics based on priority
1112
+ for key in ['accuracy', 'f1', 'micro_f1', 'bleu4', 'rougeL', 'code_bleu', 'MAE']:
1113
+ if key in metrics:
1114
+ metric_value = metrics[key]
1115
+ break
1116
+ if metric_value is None and len(metrics) > 0:
1117
+ # If no priority metric is found, use the first metric
1118
+ metric_value = list(metrics.values())[0]
1119
+ else:
1120
+ metric_value = metrics
1121
+
1122
+ # Simplify the file name, keeping only the last part
1123
+ model_name = model_name.split('/')[-1]
1124
+
1125
+ if file_exists:
1126
+ # Read existing data
1127
+ rows = []
1128
+ tasks = set()
1129
+ with open(file_path, 'r', newline='', encoding='utf-8') as f:
1130
+ reader = csv.reader(f)
1131
+ header = next(reader, ['task', model_name]) # If the file is empty, use the default header
1132
+ if len(header) == 1: # If there is only the task column, add the model column
1133
+ header.append(model_name)
1134
+ rows.append(header)
1135
+
1136
+ # Read existing data and update
1137
+ for row in reader:
1138
+ if row[0] == task_name: # If the same task is found, update the value
1139
+ row = [task_name, str(metric_value)]
1140
+ tasks.add(row[0])
1141
+ rows.append(row)
1142
+
1143
+ # If it is a new task, add a new row
1144
+ if task_name not in tasks:
1145
+ rows.append([task_name, str(metric_value)])
1146
+ else:
1147
+ # Create a new file
1148
+ rows = [
1149
+ ['task', model_name],
1150
+ [task_name, str(metric_value)]
1151
+ ]
1152
+
1153
+ # Write all data
1154
+ with open(file_path, 'w', newline='', encoding='utf-8') as f:
1155
+ writer = csv.writer(f)
1156
+ writer.writerows(rows)
1157
+
1158
+
1159
+ if __name__ == "__main__":
1160
+
1161
+ import argparse
1162
+ # Parse command line arguments
1163
+ parser = argparse.ArgumentParser(description="Run audio understanding tasks")
1164
+ parser.add_argument('-m', '--model_name', type=str, required=True, help='Name of the audio understanding model to use')
1165
+ parser.add_argument('-d', '--data_dir', type=str, default='./audio/understanding/', help='Directory containing task data')
1166
+ parser.add_argument('-o', '--output_dir', type=str, default='./audio/predictions/understanding/', help='Directory to save predictions')
1167
+ parser.add_argument('-r', '--root_path', type=str, default='./', help='Root path for logging performance')
1168
+ parser.add_argument('-t', '--task_names', type=str, nargs='+',
1169
+ help='List of task names to run (default: AccentClassification AccentSexClassification AcousticSceneClassification)')
1170
+ args = parser.parse_args()
1171
+
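+ # Example invocation (illustrative; paths follow the argparse defaults above):
+ #   python audio_predict_comprehension.py -m Qwen2-Audio-7B-Instruct \
+ #       -d ./audio/understanding/ -o ./audio/predictions/understanding/ \
+ #       -t VocalSoundClassification VocalTechniqueDetection
+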
1172
+ # model_name = 'Qwen2-Audio-7B-Instruct'
1173
+ # data_dir = './understanding/'
1174
+ # output_dir = f'./predictions/understanding/{model_name}'
1175
+ # root_path = './'
1176
+
1177
+ model = AudioComprehensionModel(model_name=args.model_name)
1178
+
1179
+
1180
+ task_name_list = [
1181
+ 'AccentClassification', 'AccentSexClassification', 'AcousticSceneClassification',
1182
+ 'AnimalSoundClassification', 'AudioCaptioning', 'AudioCaptioningClotho',
1183
+ 'AudioQA', 'BirdSoundDetection', 'EnvironmentSoundRecognition',
1184
+ 'IntentClassification', 'MusicGenreClassification',
1185
+ 'MusicInstrumentClassification', 'MusicInstrumentSourceAnalysis',
1186
+ 'MusicPitchAnalysis', 'NoteQualitiesAnalysis', 'OpenAQA',
1187
+ 'SingerIdentification', 'SoundEventClassification',
1188
+ 'SpeakerIdentification', 'SpeechCommand',
1189
+ 'SpeechEmotionRecognition', 'VocalSoundClassification',
1190
+ 'VocalTechniqueDetection'
1191
+ ]
1192
+ if args.task_names is None or len(args.task_names) == 0:
1193
+ args.task_names = task_name_list
1194
+
1195
+ for task_name in args.task_names: # os.listdir(data_dir):
1196
+
1197
+ # Dynamically get the class by its name
1198
+ if task_name in globals(): # Ensure the class is defined in the current scope
1199
+ task_class = globals()[task_name]
1200
+ else:
1201
+ # Optionally, handle cases where the class is not found
1202
+ print(f"Task {task_name} is not defined in the current scope.")
1203
+ continue
1204
+
1205
+ # Initialize the task class
1206
+ import glob
1207
+ json_file_list = glob.glob(os.path.join(args.data_dir, task_name, "*.json"))
1208
+ if len(json_file_list) == 0:
1209
+ print(f"No JSON files found for task: {task_name}")
1210
+ continue
1211
+ elif len(json_file_list) > 1:
1212
+ print(f"Multiple JSON files found for task: {task_name}, using the first one: {json_file_list[0]}")
1213
+ task_annotation_data = json_file_list[0]
1214
+ else:
1215
+ task_annotation_data = json_file_list[0]
1216
+ task = task_class(
1217
+ task_data=task_annotation_data,
1218
+ model=model,
1219
+ audio_dir=os.path.join(args.data_dir, task_name, 'audios'),
1220
+ output_dir=args.output_dir
1221
+ )
1222
+
1223
+ # Run inference for the task
1224
+ # This should generate audio files based on the task's data
1225
+ print(f"Running inference for task: {task_name}")
1226
+ task.run_inference()
1227
+ # if you want to save the predictions, you need to rewrite the save_predictions() in each Task class depending on your need, and call task.save_predictions() after task.run_inference() or inside the run_inference method.
1228
+
1229
+
1230
+ # Evaluate the task, return a dictionary of metrics
1231
+ # For example, {'FAD_score': 0.123}
1232
+ eval_results = task.evaluate()
1233
+ print("Task name: ", task_name, "Evaluation results:", eval_results)
1234
+ log_performance_json(
1235
+ model_name=args.model_name,
1236
+ task_name=task_name,
1237
+ metric=list(eval_results.keys())[0].split('_')[0], # e.g., 'accuracy'
1238
+ score=eval_results[list(eval_results.keys())[0]], # e.g., 0.123
1239
+ root_path=args.data_dir)
1240
+
1241
+ # or you can run the tasks one by one like below:
1242
+ # task_name = 'AcousticSceneClassification'
1243
+ # task = AcousticSceneClassification(
1244
+ # task_data=os.path.join(data_dir, f"{task_name}/annotation.json"),
1245
+ # model=model,
1246
+ # audio_dir=os.path.join(data_dir, f"{task_name}/audios"),
1247
+ # output_dir=output_dir)
1248
+ # task.run_inference()
1249
+ # print(task.evaluate())
1250
+
1251
+
1252
+
predictors/audio_predict_generation.py ADDED
@@ -0,0 +1,1245 @@
1
2
+ import json
3
+ import os
4
5
6
+ import tqdm
7
+ from typing import List, Dict, Any
8
+ import nltk
9
+ from dataclasses import dataclass
10
+ from abc import ABC, abstractmethod
11
+ import math
12
+ import time
13
+ from urllib.request import urlopen
14
+ import librosa
15
+ import torch
16
+ from torch import nn
17
+ import numpy as np
18
+ from encodec import EncodecModel
19
+ import laion_clap
20
+ import resampy
21
+ import soundfile as sf
22
+ from scipy import linalg
23
+ from multiprocessing.dummy import Pool as ThreadPool
24
+ import copy
25
+ import pickle
26
+ from collections import defaultdict
27
+
28
+
29
+
30
+ def read_json(file_path: str) -> Dict[str, Any]:
31
+ with open(file_path, "r") as f:
32
+ data = json.load(f)
33
+ return data
34
+
35
+
36
+ # ================================================ FAD related functions ================================================
37
+ # These functions are used to calculate the FAD score
38
+
39
+
40
+ def load_audio_task(fname, sample_rate, channels, dtype="float32"):
41
+ if dtype not in ['float64', 'float32', 'int32', 'int16']:
42
+ raise ValueError(f"dtype not supported: {dtype}")
43
+
44
+ wav_data, sr = sf.read(fname, dtype=dtype)
45
+ # For integer type PCM input, convert to [-1.0, +1.0]
46
+ if dtype == 'int16':
47
+ wav_data = wav_data / 32768.0
48
+ elif dtype == 'int32':
49
+ wav_data = wav_data / float(2**31)
50
+
51
+ # Convert to mono
52
+ assert channels in [1, 2], "channels must be 1 or 2"
53
+ if len(wav_data.shape) > channels:
54
+ wav_data = np.mean(wav_data, axis=1)
55
+
56
+ if sr != sample_rate:
57
+ wav_data = resampy.resample(wav_data, sr, sample_rate)
58
+
59
+ return wav_data
60
+
61
+
62
+ class FrechetAudioDistance:
63
+ def __init__(
64
+ self,
65
+ ckpt_dir=None,
66
+ model_name="clap",
67
+ submodel_name="630k-audioset", # only for CLAP
68
+ sample_rate=16000,
69
+ channels=1,
70
+ use_pca=False, # only for VGGish
71
+ use_activation=False, # only for VGGish
72
+ verbose=False,
73
+ audio_load_worker=8,
74
+ enable_fusion=False, # only for CLAP
75
+ ):
76
+ """
77
+ Initialize FAD
78
+
79
+ -- ckpt_dir: folder where the downloaded checkpoints are stored
80
+ -- model_name: one of vggish, clap or encodec
81
+ -- submodel_name: only for clap models - determines which checkpoint to use.
82
+ options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
83
+ -- sample_rate: one of [16000, 24000, 48000], depending on the chosen model (vggish: 16000, clap: 48000, encodec: 24000 or 48000)
84
+ -- channels: number of channels in an audio track
85
+ -- use_pca: whether to apply PCA to the vggish embeddings
86
+ -- use_activation: whether to use the output activation in vggish
87
+ -- enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
88
+ """
89
+ assert model_name in ["vggish", "clap", "encodec"], "model_name must be either 'vggish', 'pann', 'clap' or 'encodec'"
90
+ if model_name == "vggish":
91
+ assert sample_rate == 16000, "sample_rate must be 16000"
92
+ elif model_name == "clap":
93
+ assert sample_rate == 48000, "sample_rate must be 48000"
94
+ assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
95
+ elif model_name == "encodec":
96
+ assert sample_rate in [24000, 48000], "sample_rate must be 24000 or 48000"
97
+ if sample_rate == 48000:
98
+ assert channels == 2, "channels must be 2 for 48khz encodec model"
99
+ self.model_name = model_name
100
+ self.submodel_name = submodel_name
101
+ self.sample_rate = sample_rate
102
+ self.channels = channels
103
+ self.verbose = verbose
104
+ self.device = torch.device(
105
+ 'cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
106
+ if self.device == torch.device('mps') and self.model_name == "clap":
107
+ if self.verbose:
108
+ print("[Frechet Audio Distance] CLAP does not support MPS device yet, because:")
109
+ print("[Frechet Audio Distance] The operator 'aten::upsample_bicubic2d.out' is not currently implemented for the MPS device.")
110
+ print("[Frechet Audio Distance] Using CPU device instead.")
111
+ self.device = torch.device('cpu')
112
+ if self.verbose:
113
+ print("[Frechet Audio Distance] Using device: {}".format(self.device))
114
+ self.audio_load_worker = audio_load_worker
115
+ self.enable_fusion = enable_fusion
116
+ if ckpt_dir is not None:
117
+ os.makedirs(ckpt_dir, exist_ok=True)
118
+ torch.hub.set_dir(ckpt_dir)
119
+ self.ckpt_dir = ckpt_dir
120
+ else:
121
+ # by default `ckpt_dir` is `torch.hub.get_dir()`
122
+ self.ckpt_dir = torch.hub.get_dir()
123
+ self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation)
124
+
125
+ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
126
+ """
127
+ Get ckpt and set model for the specified model_name
128
+
129
+ Params:
130
+ -- model_name: one between vggish, pann or clap
131
+ -- use_pca: whether to apply PCA to the vggish embeddings
132
+ -- use_activation: whether to use the output activation in vggish
133
+ """
134
+ # vggish
135
+ if model_name == "vggish":
136
+ # S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
137
+ self.model = torch.hub.load(repo_or_dir='harritaylor/torchvggish', model='vggish')
138
+ if not use_pca:
139
+ self.model.postprocess = False
140
+ if not use_activation:
141
+ self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
142
+ self.model.device = self.device
143
+ # clap
144
+ elif model_name == "clap":
145
+ # choose the right checkpoint and model
146
+ if self.submodel_name == "630k-audioset":
147
+ if self.enable_fusion:
148
+ download_name = "630k-audioset-fusion-best.pt"
149
+ else:
150
+ download_name = "630k-audioset-best.pt"
151
+ elif self.submodel_name == "630k":
152
+ if self.enable_fusion:
153
+ download_name = "630k-fusion-best.pt"
154
+ else:
155
+ download_name = "630k-best.pt"
156
+ elif self.submodel_name == "music_audioset":
157
+ download_name = "music_audioset_epoch_15_esc_90.14.pt"
158
+ elif self.submodel_name == "music_speech":
159
+ download_name = "music_speech_epoch_15_esc_89.25.pt"
160
+ elif self.submodel_name == "music_speech_audioset":
161
+ download_name = "music_speech_audioset_epoch_15_esc_89.98.pt"
162
+
163
+ model_path = os.path.join(self.ckpt_dir, download_name)
164
+
165
+ # download checkpoint
166
+ if not (os.path.exists(model_path)):
167
+ if self.verbose:
168
+ print("[Frechet Audio Distance] Downloading {}...".format(model_path))
169
+ torch.hub.download_url_to_file(
170
+ url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}",
171
+ dst=model_path
172
+ )
173
+ # init model and load checkpoint
174
+ if self.submodel_name in ["630k-audioset", "630k"]:
175
+ self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
176
+ device=self.device)
177
+ elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]:
178
+ self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
179
+ amodel='HTSAT-base',
180
+ device=self.device)
181
+ self.model.load_ckpt(model_path)
182
+
183
+ # init model and load checkpoint
184
+ if self.submodel_name in ["630k-audioset", "630k"]:
185
+ self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
186
+ device=self.device)
187
+ elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]:
188
+ self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
189
+ amodel='HTSAT-base',
190
+ device=self.device)
191
+ self.model.load_ckpt(model_path)
192
+
193
+ # encodec
194
+ elif model_name == "encodec":
195
+ # choose the right model based on sample_rate
196
+ # weights are loaded from the encodec repo: https://github.com/facebookresearch/encodec/
197
+ if self.sample_rate == 24000:
198
+ self.model = EncodecModel.encodec_model_24khz()
199
+ elif self.sample_rate == 48000:
200
+ self.model = EncodecModel.encodec_model_48khz()
201
+ # 24kbps is the max bandwidth supported by both versions
202
+ # these models use 32 residual quantizers
203
+ self.model.set_target_bandwidth(24.0)
204
+
205
+ self.model.to(self.device)
206
+ self.model.eval()
207
+
208
+ def get_embeddings(self, x, sr):
209
+ """
210
+ Get embeddings using VGGish, PANN, CLAP or EnCodec models.
211
+ Params:
212
+ -- x : a list of np.ndarray audio samples
213
+ -- sr : sampling rate.
214
+ """
215
+ embd_lst = []
216
+ try:
217
+ for audio in tqdm.tqdm(x, disable=(not self.verbose)):
218
+ if self.model_name == "vggish":
219
+ embd = self.model.forward(audio, sr)
220
+ elif self.model_name == "clap":
221
+ audio = torch.tensor(audio).float().unsqueeze(0)
222
+ embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
223
+ elif self.model_name == "encodec":
224
+ # add two dimensions
225
+ audio = torch.tensor(
226
+ audio).float().unsqueeze(0).unsqueeze(0).to(self.device)
227
+ # if SAMPLE_RATE is 48000, we need to make audio stereo
228
+ if self.model.sample_rate == 48000:
229
+ if audio.shape[-1] != 2:
230
+ if self.verbose:
231
+ print(
232
+ "[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
233
+ )
234
+ audio = torch.cat((audio, audio), dim=1)
235
+ else:
236
+ # transpose to (batch, channels, samples)
237
+ audio = audio[:, 0].transpose(1, 2)
238
+
239
+ if self.verbose:
240
+ print(
241
+ "[Frechet Audio Distance] Audio shape: {}".format(
242
+ audio.shape
243
+ )
244
+ )
245
+
246
+ with torch.no_grad():
247
+ # encodec embedding (before quantization)
248
+ embd = self.model.encoder(audio)
249
+ embd = embd.squeeze(0)
250
+
251
+ if self.verbose:
252
+ print(
253
+ "[Frechet Audio Distance] Embedding shape: {}".format(
254
+ embd.shape
255
+ )
256
+ )
257
+
258
+ if embd.device != torch.device("cpu"):
259
+ embd = embd.cpu()
260
+
261
+ if torch.is_tensor(embd):
262
+ embd = embd.detach().numpy()
263
+
264
+ embd_lst.append(embd)
265
+ except Exception as e:
266
+ print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
267
+
268
+ return np.concatenate(embd_lst, axis=0)
269
+
270
+ def calculate_embd_statistics(self, embd_lst):
271
+ if isinstance(embd_lst, list):
272
+ embd_lst = np.array(embd_lst)
273
+ mu = np.mean(embd_lst, axis=0)
274
+ sigma = np.cov(embd_lst, rowvar=False)
275
+ return mu, sigma
276
+
277
+ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
278
+ """
279
+ Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
280
+
281
+ Numpy implementation of the Frechet Distance.
282
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
283
+ and X_2 ~ N(mu_2, C_2) is
284
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
285
+ Stable version by Dougal J. Sutherland.
286
+ Params:
287
+ -- mu1 : Numpy array containing the activations of a layer of the
288
+ inception net (like returned by the function 'get_predictions')
289
+ for generated samples.
290
+ -- mu2 : The sample mean over activations, precalculated on an
291
+ representative data set.
292
+ -- sigma1: The covariance matrix over activations for generated samples.
293
+ -- sigma2: The covariance matrix over activations, precalculated on an
294
+ representative data set.
295
+ Returns:
296
+ -- : The Frechet Distance.
297
+ """
298
+
299
+ mu1 = np.atleast_1d(mu1)
300
+ mu2 = np.atleast_1d(mu2)
301
+
302
+ sigma1 = np.atleast_2d(sigma1)
303
+ sigma2 = np.atleast_2d(sigma2)
304
+
305
+ assert mu1.shape == mu2.shape, \
306
+ 'Training and test mean vectors have different lengths'
307
+ assert sigma1.shape == sigma2.shape, \
308
+ 'Training and test covariances have different dimensions'
309
+
310
+ diff = mu1 - mu2
311
+
312
+ # Product might be almost singular
313
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
314
+ if not np.isfinite(covmean).all():
315
+ msg = ('fid calculation produces singular product; '
316
+ 'adding %s to diagonal of cov estimates') % eps
317
+ print(msg)
318
+ offset = np.eye(sigma1.shape[0]) * eps
319
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset).astype(complex))
320
+
321
+ # Numerical error might give slight imaginary component
322
+ if np.iscomplexobj(covmean):
323
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
324
+ m = np.max(np.abs(covmean.imag))
325
+ raise ValueError('Imaginary component {}'.format(m))
326
+ covmean = covmean.real
327
+
328
+ tr_covmean = np.trace(covmean)
329
+
330
+ return (diff.dot(diff) + np.trace(sigma1)
331
+ + np.trace(sigma2) - 2 * tr_covmean)
332
+
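+ # Illustrative sanity check of the closed-form distance above: for identical Gaussian
+ # statistics it should be ~0 (here `fad` is an already-constructed FrechetAudioDistance):
+ #   mu, sigma = np.zeros(4), np.eye(4)
+ #   assert fad.calculate_frechet_distance(mu, sigma, mu, sigma) < 1e-6
+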
333
+ def __load_audio_files(self, dir, dtype="float32"):
334
+ task_results = []
335
+
336
+ pool = ThreadPool(self.audio_load_worker)
337
+ pbar = tqdm.tqdm(total=len(os.listdir(dir)), disable=(not self.verbose))
338
+
339
+ def update(*a):
340
+ pbar.update()
341
+
342
+ if self.verbose:
343
+ print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
344
+ for fname in os.listdir(dir):
345
+ res = pool.apply_async(
346
+ load_audio_task,
347
+ args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
348
+ callback=update,
349
+ )
350
+ task_results.append(res)
351
+ pool.close()
352
+ pool.join()
353
+
354
+ return [k.get() for k in task_results]
355
+
356
+ def score(self,
357
+ background_dir,
358
+ eval_dir,
359
+ background_embds_path=None,
360
+ eval_embds_path=None,
361
+ dtype="float32"
362
+ ):
363
+ """
364
+ Computes the Frechet Audio Distance (FAD) between two directories of audio files.
365
+
366
+ Parameters:
367
+ - background_dir (str): Path to the directory containing background audio files.
368
+ - eval_dir (str): Path to the directory containing evaluation audio files.
369
+ - background_embds_path (str, optional): Path to save/load background audio embeddings (e.g., /folder/bkg_embs.npy). If None, embeddings won't be saved.
370
+ - eval_embds_path (str, optional): Path to save/load evaluation audio embeddings (e.g., /folder/test_embs.npy). If None, embeddings won't be saved.
371
+ - dtype (str, optional): Data type for loading audio. Default is "float32".
372
+
373
+ Returns:
374
+ - float: The Frechet Audio Distance (FAD) score between the two directories of audio files.
375
+ """
376
+ try:
377
+ # Load or compute background embeddings
378
+ if background_embds_path is not None and os.path.exists(background_embds_path):
379
+ if self.verbose:
380
+ print(f"[Frechet Audio Distance] Loading embeddings from {background_embds_path}...")
381
+ embds_background = np.load(background_embds_path)
382
+ else:
383
+ audio_background = self.__load_audio_files(background_dir, dtype=dtype)
384
+ embds_background = self.get_embeddings(audio_background, sr=self.sample_rate)
385
+ if background_embds_path:
386
+ os.makedirs(os.path.dirname(background_embds_path), exist_ok=True)
387
+ np.save(background_embds_path, embds_background)
388
+
389
+ # Load or compute eval embeddings
390
+ if eval_embds_path is not None and os.path.exists(eval_embds_path):
391
+ if self.verbose:
392
+ print(f"[Frechet Audio Distance] Loading embeddings from {eval_embds_path}...")
393
+ embds_eval = np.load(eval_embds_path)
394
+ else:
395
+ audio_eval = self.__load_audio_files(eval_dir, dtype=dtype)
396
+ embds_eval = self.get_embeddings(audio_eval, sr=self.sample_rate)
397
+ if eval_embds_path:
398
+ os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
399
+ np.save(eval_embds_path, embds_eval)
400
+
401
+ # Check if embeddings are empty
402
+ if len(embds_background) == 0:
403
+ print("[Frechet Audio Distance] background set dir is empty, exiting...")
404
+ return -1
405
+ if len(embds_eval) == 0:
406
+ print("[Frechet Audio Distance] eval set dir is empty, exiting...")
407
+ return -1
408
+
409
+ # Compute statistics and FAD score
410
+ mu_background, sigma_background = self.calculate_embd_statistics(embds_background)
411
+ mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
412
+
413
+ fad_score = self.calculate_frechet_distance(
414
+ mu_background,
415
+ sigma_background,
416
+ mu_eval,
417
+ sigma_eval
418
+ )
419
+
420
+ return fad_score
421
+ except Exception as e:
422
+ print(f"[Frechet Audio Distance] An error occurred: {e}")
423
+ return -1
424
+
425
+
426
+ def calculate_fad_score(background_dir, eval_dir, background_embds_path=None, eval_embds_path=None, dtype="float32", ckpt_dir=None, model_name="clap", submodel_name="630k-audioset", sample_rate=16000, channels=1, use_pca=False, use_activation=False, verbose=False, audio_load_worker=8, enable_fusion=False):
427
+ """
428
+ Calculate the Frechet Audio Distance (FAD) score between two directories of audio files.
429
+
430
+ Parameters:
431
+ - background_dir: Directory containing background audio files.
432
+ - eval_dir: Directory containing evaluation audio files.
433
+ - background_embds_path: Path to save/load background audio embeddings.
434
+ - eval_embds_path: Path to save/load evaluation audio embeddings.
435
+ - dtype: Data type for loading audio files (default is "float32").
436
+ - ckpt_dir: Directory where the model checkpoints are stored.
437
+ - model_name: Name of the model to use (default is "clap").
438
+ - submodel_name: Submodel name for CLAP (default is "630k-audioset").
439
+ - sample_rate: Sample rate for audio files (default is 16000).
440
+ - channels: Number of channels in the audio files (default is 1).
441
+ - use_pca: Whether to apply PCA to VGGish embeddings (default is False).
442
+ - use_activation: Whether to use output activation in VGGish (default is False).
443
+ - verbose: Whether to print verbose output (default is False).
444
+ - audio_load_worker: Number of workers for loading audio files (default is 8).
445
+ - enable_fusion: Whether to enable fusion for CLAP models (default is False).
446
+
447
+ Returns:
448
+ - FAD score as a float.
449
+ """
450
+
451
+ fad = FrechetAudioDistance(
452
+ ckpt_dir=ckpt_dir,
453
+ model_name=model_name,
454
+ submodel_name=submodel_name,
455
+ sample_rate=sample_rate,
456
+ channels=channels,
457
+ use_pca=use_pca,
458
+ use_activation=use_activation,
459
+ verbose=verbose,
460
+ audio_load_worker=audio_load_worker,
461
+ enable_fusion=enable_fusion
462
+ )
463
+
464
+ return {
465
+ "FAD_score": fad.score(background_dir, eval_dir, background_embds_path, eval_embds_path, dtype)
466
+ }
467
+
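+ # Illustrative usage (paths are hypothetical). Note that the defaults above
+ # (model_name="clap" with sample_rate=16000) would trip the CLAP assertion in
+ # FrechetAudioDistance, so pass a matching rate explicitly:
+ #   result = calculate_fad_score(
+ #       background_dir="./references/audios",
+ #       eval_dir="./predictions/audios",
+ #       model_name="clap", sample_rate=48000, channels=1)
+ #   print(result["FAD_score"])
+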
468
+
469
+
470
+
471
+
472
+ # ================================================ CLAP related functions ================================================
473
+ # These functions are used to calculate the CLAP score
474
+
475
+
476
+ # quantization
477
+ def int16_to_float32(x):
478
+ return (x / 32767.0).astype('float32')
479
+
480
+
481
+ def float32_to_int16(x):
482
+ x = np.clip(x, a_min=-1., a_max=1.)
483
+ return (x * 32767.).astype('int16')
484
+
485
+
486
+ def calculate_cosine_similarity(embeddings1, embeddings2):
487
+ embeddings1 = np.asarray(embeddings1).flatten()
+ embeddings2 = np.asarray(embeddings2).flatten()
+ dot_product = np.dot(embeddings1, embeddings2)
488
+ norm1 = np.linalg.norm(embeddings1)
489
+ norm2 = np.linalg.norm(embeddings2)
490
+ return dot_product / (norm1 * norm2) if norm1 and norm2 else 0.0
491
+
492
+
493
+ def calculate_clap_score(clap_checkpoint=None, model_id=-1, verbose=True, audio_file_list=None, text_file_list=None):
494
+ """Load the pretrained checkpoint of CLAP model
495
+
496
+ Parameters
497
+ ----------
498
+ ckpt: str
499
+ if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from zenodo. \n
500
+ For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1).
501
+ model_id:
502
+ if model_id is specified, you can download our best ckpt, as:
503
+ id = 0 --> 630k non-fusion ckpt \n
504
+ id = 1 --> 630k+audioset non-fusion ckpt \n
505
+ id = 2 --> 630k fusion ckpt \n
506
+ id = 3 --> 630k+audioset fusion ckpt \n
507
+ Note that if your model is specified as a non-fusion model but you download a fusion model ckpt, you will face an error.
508
+ """
509
+ model = laion_clap.CLAP_Module(enable_fusion=False)
510
+ model.load_ckpt(ckpt = clap_checkpoint, model_id = model_id, verbose=verbose) # download the default pretrained checkpoint.
511
+ audio_embeddings = []
512
+ for file in audio_file_list:
513
+ # CLAP expects 48 kHz float audio; librosa.load already returns float32 in [-1, 1]
+ audio, sr = librosa.load(file, sr=48000)
+ embeddings = model.get_audio_embedding_from_data(x=audio.reshape(1, -1), use_tensor=False)
516
+ audio_embeddings.append(embeddings)
517
+
518
+ text_embeddings = []
519
+ for file in text_file_list:
520
+ if os.path.exists(file):
521
+ with open(file, 'r') as f:
522
+ text = f.read()
523
+ else:
524
+ text = file
525
+ embeddings = model.get_text_embedding([text])
526
+ text_embeddings.append(embeddings)
527
+
528
+ # Compute similarity scores
529
+ scores = []
530
+ for audio_emb, text_emb in zip(audio_embeddings, text_embeddings):
531
+ score = calculate_cosine_similarity(audio_emb, text_emb)
532
+ scores.append(score)
533
+
534
+ # compute the average score
535
+ if len(scores) > 0:
536
+ average_score = sum(scores) / len(scores)
537
+ else:
538
+ average_score = 0.0
539
+
540
+ return {"CLAP_score": average_score, "scores": scores}
541
+
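+ # Illustrative usage (file list entries are hypothetical); text entries may be either
+ # caption strings or paths to caption .txt files, as handled above:
+ #   result = calculate_clap_score(audio_file_list=["gen_0.wav"], text_file_list=["a dog barking"])
+ #   print(result["CLAP_score"])
+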
542
+
543
+ # ================================================ CIDEr (Consensus-based Image Description Evaluation) related functions ================================================
544
+ # These functions are used to calculate the CIDEr score
545
+
546
+
547
+ import whisper # a tool from OpenAI for speech recognition
548
+
549
+
550
+ def speech_to_text(model_name="turbo", audio_file="audio.mp3"):
551
+ """
552
+ Convert speech to text using a speech recognition model.
553
+ """
554
+ model = whisper.load_model(model_name)
555
+
556
+ # load audio and pad/trim it to fit 30 seconds
557
+ audio = whisper.load_audio(audio_file)
558
+ audio = whisper.pad_or_trim(audio)
559
+
560
+ # make log-Mel spectrogram and move to the same device as the model
561
+ mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
562
+
563
+ # detect the spoken language
564
+ _, probs = model.detect_language(mel)
565
+ print(f"Detected language: {max(probs, key=probs.get)}")
566
+
567
+ # decode the audio
568
+ options = whisper.DecodingOptions()
569
+ result = whisper.decode(model, mel, options)
570
+
571
+ # print the recognized text
572
+ print(result.text)
573
+ return result.text
574
+
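+ # Illustrative usage (audio path is hypothetical):
+ #   transcript = speech_to_text(model_name="turbo", audio_file="./predictions/sample.wav")
+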
575
+
576
+ def precook(s, n=4, out=False):
577
+ """
578
+ Takes a string as input and returns an object that can be given to
579
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
580
+ can take string arguments as well.
581
+ :param s: string : sentence to be converted into ngrams
582
+ :param n: int : number of ngrams for which representation is calculated
583
+ :return: term frequency vector for occurring ngrams
584
+ """
585
+ words = s.split()
586
+ counts = defaultdict(int)
587
+ for k in range(1,n+1):
588
+ for i in range(len(words)-k+1):
589
+ ngram = tuple(words[i:i+k])
590
+ counts[ngram] += 1
591
+ return counts
592
+
593
+ def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
594
+ '''Takes a list of reference sentences for a single segment
595
+ and returns an object that encapsulates everything that BLEU
596
+ needs to know about them.
597
+ :param refs: list of string : reference sentences for some image
598
+ :param n: int : number of ngrams for which (ngram) representation is calculated
599
+ :return: result (list of dict)
600
+ '''
601
+ return [precook(ref, n) for ref in refs]
602
+
603
+ def cook_test(test, n=4):
604
+ '''Takes a test sentence and returns an object that
605
+ encapsulates everything that BLEU needs to know about it.
606
+ :param test: list of string : hypothesis sentence for some image
607
+ :param n: int : number of ngrams for which (ngram) representation is calculated
608
+ :return: result (dict)
609
+ '''
610
+ return precook(test, n, True)
611
+
612
+
613
+ # https://github.com/ramavedantam/cider/blob/master/pyciderevalcap/cider/cider_scorer.py
614
+ class CiderScorer(object):
615
+ """CIDEr scorer.
616
+ """
617
+
618
+ def copy(self):
619
+ ''' copy the refs.'''
620
+ new = CiderScorer(n=self.n)
621
+ new.ctest = copy.copy(self.ctest)
622
+ new.crefs = copy.copy(self.crefs)
623
+ return new
624
+
625
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
626
+ ''' singular instance '''
627
+ self.n = n
628
+ self.sigma = sigma
629
+ self.crefs = []
630
+ self.ctest = []
631
+ self.document_frequency = defaultdict(float)
632
+ self.cook_append(test, refs)
633
+ self.ref_len = None
634
+
635
+ def cook_append(self, test, refs):
636
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
637
+
638
+ if refs is not None:
639
+ self.crefs.append(cook_refs(refs))
640
+ if test is not None:
641
+ self.ctest.append(cook_test(test)) ## N.B.: -1
642
+ else:
643
+ self.ctest.append(None) # lens of crefs and ctest have to match
644
+
645
+ def size(self):
646
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
647
+ return len(self.crefs)
648
+
649
+ def __iadd__(self, other):
650
+ '''add an instance (e.g., from another sentence).'''
651
+
652
+ if type(other) is tuple:
653
+ ## avoid creating new CiderScorer instances
654
+ self.cook_append(other[0], other[1])
655
+ else:
656
+ self.ctest.extend(other.ctest)
657
+ self.crefs.extend(other.crefs)
658
+
659
+ return self
660
+
661
+ def compute_doc_freq(self):
662
+ '''
663
+ Compute term frequency for reference data.
664
+ This will be used to compute idf (inverse document frequency later)
665
+ The term frequency is stored in the object
666
+ :return: None
667
+ '''
668
+ for refs in self.crefs:
669
+ # refs, k ref captions of one image
670
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
671
+ self.document_frequency[ngram] += 1
672
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
673
+
674
+ def compute_cider(self, df_mode="corpus"):
675
+ def counts2vec(cnts):
676
+ """
677
+ Function maps counts of ngram to vector of tfidf weights.
678
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
679
+ The n-th entry of array denotes length of n-grams.
680
+ :param cnts:
681
+ :return: vec (array of dict), norm (array of float), length (int)
682
+ """
683
+ vec = [defaultdict(float) for _ in range(self.n)]
684
+ length = 0
685
+ norm = [0.0 for _ in range(self.n)]
686
+ for (ngram,term_freq) in cnts.items():
687
+ # give word count 1 if it doesn't appear in reference corpus
688
+ df = np.log(max(1.0, self.document_frequency[ngram]))
689
+ # ngram index
690
+ n = len(ngram)-1
691
+ # tf (term_freq) * idf (precomputed idf) for n-grams
692
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
693
+ # compute norm for the vector. the norm will be used for
694
+ # computing similarity
695
+ norm[n] += pow(vec[n][ngram], 2)
696
+
697
+ if n == 1:
698
+ length += term_freq
699
+ norm = [np.sqrt(n) for n in norm]
700
+ return vec, norm, length
701
+
702
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
703
+ '''
704
+ Compute the cosine similarity of two vectors.
705
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
706
+ :param vec_ref: array of dictionary for vector corresponding to reference
707
+ :param norm_hyp: array of float for vector corresponding to hypothesis
708
+ :param norm_ref: array of float for vector corresponding to reference
709
+ :param length_hyp: int containing length of hypothesis
710
+ :param length_ref: int containing length of reference
711
+ :return: array of score for each n-grams cosine similarity
712
+ '''
713
+ delta = float(length_hyp - length_ref)
714
+ # measure cosine similarity
715
+ val = np.array([0.0 for _ in range(self.n)])
716
+ for n in range(self.n):
717
+ # ngram
718
+ for (ngram,count) in vec_hyp[n].items():
719
+ val[n] += vec_hyp[n][ngram] * vec_ref[n][ngram]
720
+
721
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
722
+ val[n] /= (norm_hyp[n]*norm_ref[n])
723
+
724
+ assert(not math.isnan(val[n]))
725
+ return val
726
+
727
+ # compute log reference length
728
+ if df_mode == "corpus":
729
+ self.ref_len = np.log(float(len(self.crefs)))
730
+ elif df_mode == "coco-val-df":
731
+ # if coco option selected, use length of coco-val set
732
+ self.ref_len = np.log(float(40504))
733
+
734
+ scores = []
735
+ for test, refs in zip(self.ctest, self.crefs):
736
+ # compute vector for test captions
737
+ vec, norm, length = counts2vec(test)
738
+ # compute vector for ref captions
739
+ score = np.array([0.0 for _ in range(self.n)])
740
+ for ref in refs:
741
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
742
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
743
+ # change by vrama91 - mean of ngram scores, instead of sum
744
+ score_avg = np.mean(score)
745
+ # divide by number of references
746
+ score_avg /= len(refs)
747
+ # multiply score by 10
748
+ score_avg *= 10.0
749
+ # append score of an image to the score list
750
+ scores.append(score_avg)
751
+ return scores
752
+
753
+ def compute_score(self, df_mode, option=None, verbose=0):
754
+ # compute idf
755
+ if df_mode == "corpus":
756
+ self.compute_doc_freq()
757
+ # assert to check document frequency
758
+ assert(len(self.ctest) >= max(self.document_frequency.values()))
759
+ # import json for now and write the corresponding files
760
+ else:
761
+ self.document_frequency = pickle.load(open(os.path.join('data', df_mode + '.p'), 'rb'))
762
+ # compute cider score
763
+ score = self.compute_cider(df_mode)
764
+ # debug
765
+ # print score
766
+ return np.mean(np.array(score)), np.array(score)
767
+
768
+
769
+ # https://github.com/ramavedantam/cider/blob/master/pyciderevalcap/cider/cider.py
770
+ class Cider:
771
+ """
772
+ Main Class to compute the CIDEr metric
773
+
774
+ """
775
+ def __init__(self, n=4, df="corpus"):
776
+ """
777
+ Initialize the CIDEr scoring function
778
+ : param n (int): n-gram size
779
+ : param df (string): specifies where to get the IDF values from
780
+ takes values 'corpus', 'coco-train'
781
+ : return: None
782
+ """
783
+ # set cider to sum over 1 to 4-grams
784
+ self._n = n
785
+ self._df = df
786
+
787
+ def compute_score(self, gts, res):
788
+ """
789
+ Main function to compute CIDEr score
790
+ : param gts (dict) : {image:tokenized reference sentence}
791
+ : param res (dict) : {image:tokenized candidate sentence}
792
+ : return: cider (float) : computed CIDEr score for the corpus
793
+ """
794
+
795
+ cider_scorer = CiderScorer(n=self._n)
796
+
797
+ for res_id in res:
798
+
799
+ hypo = res_id['caption']
800
+ ref = gts[res_id['image_id']]
801
+
802
+ # Sanity check.
803
+ assert(type(hypo) is list)
804
+ assert(len(hypo) == 1)
805
+ assert(type(ref) is list)
806
+ assert(len(ref) > 0)
807
+ cider_scorer += (hypo[0], ref)
808
+
809
+ (score, scores) = cider_scorer.compute_score(self._df)
810
+
811
+ return score, scores
812
+
813
+ def method(self):
814
+ return "CIDEr"
815
+
816
+
817
+ def calculate_CIDEr_score(audio_file_list=None, text_file_list=None):
818
+ # convert audio files to text using speech-to-text
819
+ if audio_file_list is None or text_file_list is None:
820
+ raise ValueError("Both audio_file_list and text_file_list must be provided.")
821
+ if len(audio_file_list) != len(text_file_list):
822
+ raise ValueError("audio_file_list and text_file_list must have the same length.")
823
+ # Load the CIDEr scorer
824
+ cider_scorer = Cider(n=4, df="corpus")
825
+ # Prepare the ground truth and results
826
+ gts = {}
827
+ res = []
828
+ from spacy.tokenizer import Tokenizer
829
+ from spacy.lang.en import English
830
+ nlp = English()
831
+ # Create a blank Tokenizer with just the English vocab
832
+ tokenizer = Tokenizer(nlp.vocab)
833
+
834
+ for audio_file, text_file in zip(audio_file_list, text_file_list):
835
+ # Convert audio to text
836
+ text = speech_to_text(audio_file=audio_file)
837
+
838
+ gts[audio_file] = [" ".join(tok.text for tok in tokenizer(text))] # whitespace-joined tokens of the transcript
839
+
840
+ with open(text_file, 'r') as f:
841
+ reference_text = f.read().strip()
842
+ # Tokenize the reference text
843
+ text = " ".join(tok.text for tok in tokenizer(reference_text))
844
+ res.append({
845
+ 'image_id': audio_file,
846
+ 'caption': [text]
847
+ })
848
+ # Compute the CIDEr score
849
+ score, scores = cider_scorer.compute_score(gts, res)
850
+ return {
851
+ "CIDEr_score": score,
852
+ "scores": scores
853
+ }
854
+
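+ # Illustrative usage (paths are hypothetical); each text file holds the reference caption
+ # for the corresponding generated speech file:
+ #   result = calculate_CIDEr_score(
+ #       audio_file_list=["gen_0.wav", "gen_1.wav"],
+ #       text_file_list=["ref_0.txt", "ref_1.txt"])
+ #   print(result["CIDEr_score"])
+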
855
+
856
+
857
+
858
+
859
+
860
+
861
+ # ================================================ WER (Word Error Rate) related functions ================================================
862
+ # These functions are used to calculate the WER
863
+
864
+ # pip install werpy
865
+
866
+ import werpy
867
+ def calculate_wer(audio_file_list: list, text_file_list: list) -> float:
868
+ """Calculate the Word Error Rate (WER) between a reference and a hypothesis.
869
+ Args:
870
+ audio_file_list (list): List of audio files to be transcribed.
871
+ text_file_list (list): List of text files containing the reference transcriptions.
872
+ """
873
+ if len(audio_file_list) != len(text_file_list):
874
+ raise ValueError("audio_file_list and text_file_list must have the same length.")
875
+
876
+ total_wer = 0.0
877
+ for audio_file, text_file in zip(audio_file_list, text_file_list):
878
+ # Convert audio to text using speech-to-text
879
+ transcribed_text = speech_to_text(audio_file=audio_file)
880
+
881
+ # Read the reference text from the file
882
+ with open(text_file, 'r') as f:
883
+ reference_text = f.read().strip()
884
+
885
+ # Calculate WER
886
+ wer_score = werpy.wer(reference_text, transcribed_text)
887
+ total_wer += wer_score
888
+
889
+ average_wer = total_wer / len(audio_file_list)
890
+ return {"WER_score": average_wer}
891
+
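+ # Illustrative usage (paths are hypothetical):
+ #   result = calculate_wer(audio_file_list=["gen_0.wav"], text_file_list=["ref_0.txt"])
+ #   print(result["WER_score"])
+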
892
+
893
+
894
+
895
+ # ================================================ MCD (Mel Cepstral Distortion ) related functions ================================================
896
+ # These functions are used to calculate the MCD
897
+
898
+ # pip install -U pymcd
899
+ from pymcd.mcd import Calculate_MCD
900
+
901
+ def calculate_mcd(reference_audio_list: list, generated_audio_list: list) -> dict:
902
+ """Calculate the Mel Cepstral Distortion (MCD) between two audio files.
903
+
904
+ Args:
905
+ reference_audio_list (list): Paths to the reference (ground-truth) audio files.
+ generated_audio_list (list): Paths to the generated (synthesized) audio files.
+
+ Returns:
+ dict: The average MCD score and the list of per-pair scores.
910
+ """
911
+ # instance of MCD class
912
+ # three different modes "plain", "dtw" and "dtw_sl" for the above three MCD metrics
913
+ mcd_toolbox = Calculate_MCD(MCD_mode="plain")
914
+
915
+ # two inputs w.r.t. reference (ground-truth) and synthesized speeches, respectively
916
+ mcd_scores = []
917
+ for ref_audio, gen_audio in zip(reference_audio_list, generated_audio_list):
918
+ # calculate MCD score
919
+ mcd_score = mcd_toolbox.calculate_mcd(ref_audio, gen_audio)
920
+ mcd_scores.append(mcd_score)
921
+ # calculate average MCD score
922
+ mcd_score = sum(mcd_scores) / len(mcd_scores)
923
+ if mcd_score is None:
924
+ raise ValueError("MCD score could not be calculated. Please check the audio files.")
925
+
926
+ return {"MCD_score": mcd_score, "mcd_scores": mcd_scores}
927
+
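+ # Illustrative usage (paths are hypothetical); both lists must be aligned pairwise:
+ #   result = calculate_mcd(reference_audio_list=["ref_0.wav"], generated_audio_list=["gen_0.wav"])
+ #   print(result["MCD_score"])
+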
928
+
929
+
930
+ class AudioGenerationModel:
931
934
+ def __init__(self, model_name: str):
935
+ self.model_name = model_name
936
+ self.load_model()
937
+
938
+ def load_model(self):
939
+ # Placeholder for loading the model
940
+ pass
941
+
942
+ def generate(self, input_text: str) -> np.ndarray:
943
+ # Placeholder for audio generation logic
944
+ # This should return the generated audio as a numpy array or a file path
945
+ pass
946
+
947
+
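+ # The class above is only a stub. A concrete backend is assumed to subclass it and return
+ # the path of the generated audio file, roughly like this hypothetical sketch (the keyword
+ # arguments mirror how the task classes below call generate()):
+ #
+ #   class MyAudioGenerationModel(AudioGenerationModel):
+ #       def generate(self, prompt, audio_path=None, video_path=None, image_path=None):
+ #           wav_path = os.path.join("outputs", "generated.wav")
+ #           # ... call the underlying text/video/image-to-audio model here ...
+ #           return wav_path
+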
948
+
949
+ @dataclass
950
+ class Instance:
951
+ input: Dict[str, Any]
952
+ output: Dict[str, Any]
953
+ id: str
954
+
955
+
956
+ class BaseTask(ABC):
957
+ def __init__(self, task_data: Dict[str, Any], model: AudioGenerationModel, audio_dir: str = None, output_dir: str = None, task_name: str = None):
958
+ self.task_data = read_json(task_data)
959
+ self.model = model
960
+ self.audio_dir = audio_dir # should include the audios files
961
+ self.data = self._parse_data(self.task_data)
962
+ self.task_name = os.path.dirname(task_data).split("/")[-1] if task_name is None else task_name
963
+ self.output_dir = output_dir
964
+ os.makedirs(self.output_dir, exist_ok=True) if self.output_dir else None
965
+
966
+ self.references = []
967
+ self.predictions = []
968
+
969
+ def save_predictions(self, audio_paths):
970
+ results = []
971
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
972
+ results.append({
973
+ 'gt': gt,
974
+ 'response': response,
975
+ 'audio_path': audio_path,
976
+ })
977
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
978
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
979
+ json.dump(results, open(results_file, 'w'))
980
+
981
+ @abstractmethod
982
+ def _get_choice_candidate(self):
983
+ pass
984
+
985
+ @abstractmethod
986
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
987
+ pass
988
+
989
+ @abstractmethod
990
+ def evaluate(self) -> Dict[str, float]:
991
+ pass
992
+
993
+ @abstractmethod
994
+ def run_inference(self):
995
+ pass
996
+
997
+
998
+ class SingleCaptionToAudio(BaseTask):
999
+ def __init__(self, task_data: Dict[str, Any], model: AudioGenerationModel, audio_dir: str = None, output_dir: str = None, task_name: str = None):
1000
+ super().__init__(task_data, model, audio_dir, output_dir, task_name)
1001
+ self._get_choice_candidate()
1002
+
1003
+ def _get_choice_candidate(self):
1004
+ # Placeholder for getting choice candidates
1005
+ pass
1006
+
1007
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
1008
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
1009
+ for d in task_data["data"]]
1010
+
1011
+ def save_predictions(self, audio_paths):
1012
+ results = []
1013
+ for gt, response, audio_path in zip(self.references, self.predictions, audio_paths):
1014
+ results.append({
1015
+ 'gt': gt,
1016
+ 'response': response,
1017
+ 'audio_path': audio_path,
1018
+ })
1019
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
1020
+ results_file = os.path.join(self.output_dir, f'{self.task_name }_{time_prefix}.json') if self.output_dir else f'{self.task_name }_{time_prefix}.json'
1021
+ json.dump(results, open(results_file, 'w'))
1022
+
1023
+
1024
+ def run_inference(self):
1025
+ self.predictions = []
1026
+ self.references = []
1027
+ for inst in tqdm.tqdm(self.data):
1028
+ audio_path = os.path.join(self.audio_dir, inst.input["audio_file"])
1029
+ prompt = inst.input["prompt"]
1030
+ try:
1031
+ response = self.model.generate(prompt, audio_path=audio_path)
1032
+ except:
1033
+ print("error audio {}".format(inst.input["audio_file"]))
1034
+ continue
1035
+ # response is the generated audio file path
1036
+ self.predictions.append(response)
1037
+ self.references.append(prompt)
1038
+ # self.save_predictions(audio_paths)
1039
+
1040
+ def evaluate(self) -> Dict[str, float]:
1041
+ clap_score = calculate_clap_score(audio_file_list=self.predictions, text_file_list=self.references)
1042
+ return clap_score
1043
+
1044
+
1045
+ class VideoToAudio(BaseTask):
1046
+ def __init__(self, task_data: Dict[str, Any], model: AudioGenerationModel, audio_dir: str = None, output_dir: str = None, task_name: str = None):
1047
+ super().__init__(task_data, model, audio_dir, output_dir, task_name)
1048
+ self._get_choice_candidate()
1049
+
1050
+ def _get_choice_candidate(self):
1051
+ # Placeholder for getting choice candidates
1052
+ pass
1053
+
1054
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
1055
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
1056
+ for d in task_data["data"]]
1057
+
1058
+ def run_inference(self):
1059
+ self.predictions = []
1060
+ self.references = []
1061
+ for inst in tqdm.tqdm(self.data):
1062
+ video_path = os.path.join(self.audio_dir, inst.input["video_file"])
1063
+ prompt = inst.input["prompt"]
1064
+ try:
1065
+ response = self.model.generate(prompt, video_path=video_path)
1066
+ except:
1067
+ print("error video {}".format(inst.input["video_file"]))
1068
+ continue
1069
+ # response is the generated audio file path
1070
+ self.predictions.append(response)
1071
+ self.references.append(prompt)
1072
+
1073
+ def evaluate(self) -> Dict[str, float]:
1074
+ fad_score = calculate_fad_score(
1075
+ background_dir=self.audio_dir,
1076
+ eval_dir=self.output_dir
1077
+ )
1078
+ return fad_score
1079
+
1080
+
1081
+ class ImageToSpeech(BaseTask):
1082
+ def __init__(self, task_data: Dict[str, Any], model: AudioGenerationModel, audio_dir: str = None, output_dir: str = None, task_name: str = None):
1083
+ super().__init__(task_data, model, audio_dir, output_dir, task_name)
1084
+ self._get_choice_candidate()
1085
+
1086
+ def _get_choice_candidate(self):
1087
+ # Placeholder for getting choice candidates
1088
+ pass
1089
+
1090
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
1091
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
1092
+ for d in task_data["data"]]
1093
+
1094
+ def run_inference(self):
+ # Generate speech for each image; the metric is computed in evaluate()
1096
+ self.predictions = []
1097
+ self.references = []
1098
+ for inst in tqdm.tqdm(self.data):
1099
+ image_path = os.path.join(self.audio_dir, inst.input["image_file"])
1100
+ prompt = inst.input["prompt"]
1101
+ try:
1102
+ response = self.model.generate(prompt, image_path=image_path)
1103
+ except:
1104
+ print("error image {}".format(inst.input["image_file"]))
1105
+ continue
1106
+ # response is the generated audio file path
1107
+ self.predictions.append(response)
1108
+ self.references.append(prompt)
1109
+
1110
+ def evaluate(self) -> Dict[str, float]:
1111
+ CIDEr_score = calculate_CIDEr_score(
1112
+ audio_file_list=self.predictions,
1113
+ text_file_list=self.references
1114
+ )
1115
+ return CIDEr_score
1116
+
1117
+
1118
+ def log_performance_csv(model_name, task_name, metric, score, root_path, output_file='prediction.csv'):
1119
+ import csv
1120
+ file_exists = os.path.isfile(os.path.join(root_path, output_file))
1121
+
1122
+ row_data = {
1123
+ 'model': model_name,
1124
+ 'task': task_name,
1125
+ 'metric': metric,
1126
+ 'score': str(score),
1127
+ }
1128
+
1129
+ with open(os.path.join(root_path, output_file), mode='a', newline='', encoding='utf-8') as f:
1130
+ writer = csv.DictWriter(f, fieldnames=row_data.keys())
1131
+ if not file_exists:
1132
+ writer.writeheader()
1133
+
1134
+ writer.writerow(row_data)
1135
+
1136
+
1137
+ def log_performance_json(model_name, task_name, metric, score, root_path, output_file='prediction.json'):
1138
+ import json
1139
+ log_data = {
1140
+ 'model': model_name,
1141
+ 'task': task_name,
1142
+ 'metric': metric,
1143
+ 'score': str(score),
1144
+ }
1145
+
1146
+ log_file_path = os.path.join(root_path, output_file)
1147
+
1148
+ if os.path.exists(log_file_path):
1149
+ with open(log_file_path, 'r') as f:
1150
+ existing_data = json.load(f)
1151
+ else:
1152
+ existing_data = []
1153
+
1154
+ existing_data.append(log_data)
1155
+
1156
+ with open(log_file_path, 'w', encoding='utf-8') as f:
1157
+ json.dump(existing_data, f, indent=4)
1158
+
1159
+
1160
+
1161
+
1162
+ if __name__ == "__main__":
1163
+ import argparse
1164
+ # Parse command line arguments
1165
+ parser = argparse.ArgumentParser(description="Run audio generation tasks")
1166
+ parser.add_argument('-m', '--model_name', type=str, required=True, help='Name of the audio generation model to use')
1167
+ parser.add_argument('-d', '--data_dir', type=str, default='./audio/generation/', help='Directory containing task data')
1168
+ parser.add_argument('-o', '--output_dir', type=str, default='./audio/predictions/generation/', help='Directory to save predictions for each task')
1169
+ parser.add_argument('-r', '--root_path', type=str, default='./', help='Root path for logging performance')
1170
+ parser.add_argument('-t', '--task_names', type=str, nargs='+',
1171
+ help='List of task names to run (for example: SingleCaptionToAudio VideoToAudio ImageToSpeech)')
1172
+ args = parser.parse_args()
1173
+
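+ # Example invocation (illustrative; paths follow the argparse defaults above):
+ #   python audio_predict_generation.py -m <your_generation_model> \
+ #       -d ./audio/generation/ -o ./audio/predictions/generation/ \
+ #       -t SingleCaptionToAudio VideoToAudio ImageToSpeech
+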
1174
+ # Initialize the model
1175
+ model = AudioGenerationModel(model_name=args.model_name)
1176
+ # data_dir = './generation/'
1177
+ # output_dir = f'./predictions/generation/{args.model_name}'
1178
+ # root_path = './'
1179
+
1180
+ task_name_list = [
1181
+ 'SingleCaptionToAudio', 'VideoToAudio', 'ImageToSpeech',
1182
+ # Add more task names as needed
1183
+ ]
1184
+
1185
+ if args.task_names is None or len(args.task_names) == 0:
1186
+ args.task_names = task_name_list
1187
+
1188
+ for task_name in args.task_names: # os.listdir(data_dir):
1189
+
1190
+ # Dynamically get the class by its name
1191
+ if task_name in globals(): # Ensure the class is defined in the current scope
1192
+ task_class = globals()[task_name]
1193
+ else:
1194
+ # Optionally, handle cases where the class is not found
1195
+ print(f"Task {task_name} is not defined in the current scope.")
1196
+ continue
1197
+
1198
+ # Initialize the task class
1199
+ import glob
1200
+ json_file_list = glob.glob(os.path.join(args.data_dir, task_name, "*.json"))
1201
+ if len(json_file_list) == 0:
1202
+ print(f"No JSON files found for task: {task_name}")
1203
+ continue
1204
+ elif len(json_file_list) > 1:
1205
+ print(f"Multiple JSON files found for task: {task_name}, using the first one: {json_file_list[0]}")
1206
+ task_annotation_data = json_file_list[0]
1207
+ else:
1208
+ task_annotation_data = json_file_list[0]
1209
+ print(f"Using task annotation data: {task_annotation_data}")
1210
+ task = task_class(
1211
+ task_data=task_annotation_data,
1212
+ model=model,
1213
+ audio_dir=os.path.join(args.data_dir, task_name, 'audios'),
1214
+ output_dir=args.output_dir
1215
+ )
1216
+
1217
+ # Run inference for the task
1218
+ # This should generate audio files based on the task's data
1219
+ print(f"Running inference for task: {task_name}")
1220
+ task.run_inference()
1221
+ # if you want to save the predictions, you need to rewrite the save_predictions() in each Task class depending on your need, and call task.save_predictions() after task.run_inference() or inside the run_inference method.
1222
+
1223
+
1224
+ # Evaluate the task, return a dictionary of metrics
1225
+ # For example, {'FAD_score': 0.123}
1226
+ eval_results = task.evaluate()
1227
+ print("Task name: ", task_name, "Evaluation results:", eval_results)
1228
+ log_performance_json(
1229
+ model_name=args.model_name,
1230
+ task_name=task_name,
1231
+ metric=list(eval_results.keys())[0].split('_')[0], # e.g., 'FAD' from 'FAD_score'
1232
+ score=eval_results[list(eval_results.keys())[0]], # e.g., 0.123
1233
+ root_path=args.root_path)
1234
+
1235
+ # or you can run the tasks one by one like below:
1236
+ # task_name = 'SingleCaptionToAudio'
1237
+ # task = SingleCaptionToAudio(
1238
+ # task_data=os.path.join(data_dir, f"{task_name}/annotation.json"),
1239
+ # model=model,
1240
+ # audio_dir=os.path.join(data_dir, f"{task_name}/audios"),
1241
+ # output_dir=output_dir)
1242
+ # task.run_inference()
1243
+ # print(task.evaluate())
1244
+
1245
+
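For reference, the commented-out single-task example above can be made self-contained. The sketch below is illustrative only: the model name and paths are placeholders, and it assumes the AudioGenerationModel and SingleCaptionToAudio classes defined in this script are in scope.

    # Minimal single-task sketch (placeholder model name and paths)
    model = AudioGenerationModel(model_name="your-audio-model")
    task = SingleCaptionToAudio(
        task_data="./audio/generation/SingleCaptionToAudio/annotation.json",
        model=model,
        audio_dir="./audio/generation/SingleCaptionToAudio/audios",
        output_dir="./audio/predictions/generation/",
    )
    task.run_inference()      # generates audio for each instance
    print(task.evaluate())    # e.g. {'FAD_score': 0.123}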
predictors/nlp_predictor.py ADDED
@@ -0,0 +1,1024 @@
1
+ import json
2
+ import os
3
+ import tqdm
4
+ from typing import List, Dict, Any
5
+ import nltk
6
+ import re
7
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
8
+ from dataclasses import dataclass
9
+ from abc import ABC, abstractmethod
10
+ from transformers import pipeline
11
+ from rouge_score import rouge_scorer
12
+ from codebleu import calc_codebleu
13
+ import math
14
+ import numpy as np
15
+ import jieba
16
+
17
+ import torch
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
19
+
20
+
21
+ class LLMModel:
22
+ def __init__(self, model_name: str):
23
+ self.model_name = model_name
24
+ self.is_time_series = False
25
+ self.timesfm_model = None # TimesFM time-series model
26
+
27
+ if "timesfm" in model_name.lower():
28
+ import timesfm
29
+ self.is_time_series = True
30
+ self.tfm = timesfm.TimesFm(
31
+ hparams=timesfm.TimesFmHparams(
32
+ backend="gpu",
33
+ per_core_batch_size=32,
34
+ ),
35
+ checkpoint=timesfm.TimesFmCheckpoint(
36
+ huggingface_repo_id=model_name),
37
+ )
38
+
39
+ elif "qwen" in model_name.lower() or "gemma" in model_name.lower() or "internlm" in model_name.lower() or "vicuna" in model_name.lower() or "gpt" in model_name.lower():
40
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
41
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
42
+ self.copied_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
43
+ self.model = self.model.eval()
44
+
45
+ elif "chatglm" in model_name.lower():
46
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
47
+ self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
48
+ self.model = self.model.eval()
49
+
50
+ else:
51
+ self.pipeline = pipeline("text-generation", model=model_name, device_map="auto", trust_remote_code=True)
52
+
53
+
54
+ def generate(self, prompt: str, max_new_tokens=256) -> str:
55
+ if self.is_time_series:
56
+ raise NotImplementedError("This model is a time-series model. Please call generate_for_timeseries() instead of generate().")
57
+
58
+ if "vicuna" in self.model_name.lower() or "gpt" in self.model_name.lower():
59
+ inputs = self.tokenizer(prompt, return_tensors="pt")
60
+ generate_ids = self.model.generate(inputs.input_ids.cuda(), max_new_tokens=max_new_tokens, pad_token_id=self.tokenizer.eos_token_id)
61
+ output = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
62
+ return output
63
+
64
+ elif "llama" in self.model_name.lower():
65
+ self.messages = [
66
+ {"role": "system", "content": "You are a helpful and useful AI assistant."},
67
+ {"role": "user", "content":prompt }
68
+ ]
69
+ prompt = self.pipeline.tokenizer.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
70
+ terminators = [
71
+ self.pipeline.tokenizer.eos_token_id,
72
+ self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
73
+ ]
74
+ output = self.pipeline(prompt, max_new_tokens=max_new_tokens, num_return_sequences=1,
75
+ pad_token_id = self.pipeline.tokenizer.eos_token_id,
76
+ return_full_text=False, eos_token_id=terminators)
77
+ return output[0]["generated_text"]
78
+
79
+ elif "qwen" in self.model_name.lower():
80
+ self.messages = [
81
+ {"role": "system", "content": "You are a helpful and useful AI assistant."},
82
+ {"role": "user", "content": prompt}
83
+ ]
84
+ prompt = self.tokenizer.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
85
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
86
+ generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=max_new_tokens, pad_token_id=self.tokenizer.eos_token_id)
87
+ generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
88
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
89
+ return response
90
+
91
+ elif "gemma" in self.model_name.lower():
92
+ self.messages = [
93
+ {"role": "user", "content": prompt}
94
+ ]
95
+ prompt = self.tokenizer.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
96
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
97
+ generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=max_new_tokens, pad_token_id=self.tokenizer.eos_token_id)
98
+ generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
99
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
100
+ return response
101
+
102
+ elif "chatglm" in self.model_name.lower() or "internlm" in self.model_name.lower():
103
+ response, _ = self.model.chat(self.tokenizer, prompt, history=[])
104
+ return response
105
+
106
+ def generate_for_timeseries(
107
+ self,
108
+ series_data: List[float],
109
+ horizon: int = 1,
110
+ freq: int = 0
111
+ ) -> List[float]:
112
+ if self.is_time_series and self.tfm is not None:
113
+ forecast_input = [series_data]
114
+ frequency_input = [freq]
115
+
116
+ point_forecast, _ = self.tfm.forecast(
117
+ forecast_input,
118
+ freq=frequency_input
119
+ )
120
+
121
+ forecast_result = point_forecast[0]
122
+ if horizon < len(forecast_result):
123
+ forecast_result = forecast_result[:horizon]
124
+ return forecast_result.tolist()
125
+
126
+ else:
127
+ prompt = (
128
+ "You are a time-series forecasting assistant.\n"
129
+ f"The historical data points are: {series_data}.\n"
130
+ f"Please predict the next {horizon} future data point(s) directly without other words based on the historical trend.\n\n"
131
+ "Format your answer as a list of floats, e.g. `[3.1415, 2.7182]`.\n"
132
+ "Answer:"
133
+ )
134
+
135
+ raw_response = self.generate(prompt, max_new_tokens=64)
136
+ import re
137
+ pattern = r"\[([\d\.\,\s\-eE]+)\]"
138
+ match = re.search(pattern, raw_response)
139
+ if not match:
140
+ print("Warning: LLM output not in expected format, fallback to 0.0")
141
+ return [0.0] * horizon
142
+
143
+ numbers_str = match.group(1)
144
+ raw_nums = re.split(r"[\s,]+", numbers_str.strip())
145
+ parsed_vals = []
146
+ for val in raw_nums:
147
+ try:
148
+ parsed_vals.append(float(val))
149
+ except ValueError:
150
+ continue
151
+
152
+ # If the number of parsed values does not match horizon, pad or truncate
153
+ if len(parsed_vals) < horizon:
154
+ # Pad with the last value (or 0.0 if nothing was parsed)
155
+ while len(parsed_vals) < horizon:
156
+ parsed_vals.append(parsed_vals[-1] if parsed_vals else 0.0)
157
+ elif len(parsed_vals) > horizon:
158
+ parsed_vals = parsed_vals[:horizon]
159
+
160
+ return parsed_vals
161
+
162
+
163
+ @dataclass
164
+ class Instance:
165
+ input: Dict[str, Any]
166
+ output: Dict[str, Any]
167
+ id: str
168
+
169
+ class BaseTask(ABC):
170
+ def __init__(self, task_data: Dict[str, Any], model: LLMModel):
171
+ self.task_data = task_data
172
+ self.model = model
173
+ self.data = self._parse_data(task_data)
174
+
175
+ @abstractmethod
176
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
177
+ pass
178
+
179
+ @abstractmethod
180
+ def run_inference(self):
181
+ pass
182
+
183
+
184
+ class MultipleChoiceQA(BaseTask):
185
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
186
+ return [Instance(input=d["input"], output={}, id=d["id"])
187
+ for d in task_data["data"]]
188
+
189
+ def run_inference(self):
190
+ self.predictions = []
191
+ for inst in tqdm.tqdm(self.data):
192
+ question = inst.input["question"]
193
+ options = inst.input["options"]
194
+ options_chars = [chr(65 + i) for i in range(len(options))]
195
+ prompt = f"Question: {question}\nOptions:\n"
196
+ for i, opt in enumerate(options):
197
+ prompt += options_chars[i] + ". " + opt + "\n"
198
+
199
+ if self.task_data["task"] == "Causal Reasoning":
200
+ prompt += f"{question}\nPlease substitute yourself into the above scenario and select the most likely cause and effect outcome. "
201
+ prompt += r'Please answer the question and output it strictly in the following format: "The final answer is $\boxed{your choice}$" at the end of the sentence.'
202
+ response = self.model.generate(prompt, max_new_tokens=256)
203
+ pred = None
204
+ if "answer" not in response:
205
+ pred = "A"
206
+ else:
207
+ pattern = "answer"
208
+ response = re.split(pattern, response, flags=re.IGNORECASE)[-1]
209
+ for opt in options_chars:
210
+ if opt in response:
211
+ pred = opt
212
+ break
213
+ if pred is None:
214
+ pred = "A"
215
+
216
+ self.predictions.append(pred)
217
+
218
+
219
+ class OpenQA(BaseTask):
220
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
221
+ return [Instance(input=d["input"], output={}, id=d["id"])
222
+ for d in task_data["data"]]
223
+
224
+ def run_inference(self):
225
+ self.predictions = []
226
+ for inst in tqdm.tqdm(self.data):
227
+ prompt = ""
228
+ question = inst.input["question"]
229
+
230
+ if "context" in inst.input.keys():
231
+ context = inst.input["context"]
232
+ prompt += f"Given the context: {context}\n"
233
+
234
+ if self.task_data["task"] == "Temporal Reasoning":
235
+ prompt += f"{question}\nAccroding to the provided context, how long does it take for the event? Please give a direct answer without other words"
236
+ elif self.task_data["task"] == "Medical Question Answering":
237
+ prompt += f"Please answer the question in a short pargraph: {question}"
238
+ elif self.task_data["task"] == "Multilingual Question Answering":
239
+ prompt += f"Please directly answer the question using the language in the question: {question}"
240
+ elif self.task_data["task"] == "Table Question Answering":
241
+ table = inst.input["table"]
242
+ prompt += f"Please read the content of the table below carefully and then directly answer the question without other words:\n{table}\n\nQuestion: {question}\nAnswer:"
243
+ else:
244
+ prompt += f"Please directly answer the question in a short sentence: {question}"
245
+ if self.task_data["task"] == "Document-Level Causal":
246
+ prompt += f"\nIf the context does not contain an answer to the question, simply output \"None of the above\"."
247
+
248
+ response = self.model.generate(prompt, max_new_tokens=256)
249
+ pred = response.strip()
250
+ self.predictions.append(pred)
251
+
252
+
253
+ class SummarizationTask(BaseTask):
254
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
255
+ instances = []
256
+ for d in task_data["data"]:
257
+ if "document_list" in d:
258
+ instance = Instance(
259
+ input={"document_list": d["document_list"]},
260
+ output={},
261
+ id=d["id"]
262
+ )
263
+ elif d.get("input") and "highlights" in d.get("output", {}):
264
+ instance = Instance(
265
+ input={"document": d["document"]},
266
+ output={},
267
+ id=d["id"]
268
+ )
269
+ else:
270
+ instance = Instance(
271
+ input={"document": d["document"]},
272
+ output={},
273
+ id=d["id"]
274
+ )
275
+ instances.append(instance)
276
+ return instances
277
+
278
+ def run_inference(self):
279
+ self.predictions = []
280
+ for inst in tqdm.tqdm(self.data):
281
+ if "document_list" in inst.input:
282
+ doc_list = inst.input["document_list"]
283
+ combined_docs = "\n".join(doc_list)
284
+
285
+ prompt = (
286
+ "You are a multi-document summarization assistant.\n"
287
+ "Please read the following documents, and then summarize them in a concise paragraph:\n\n"
288
+ f"{combined_docs}\n\n"
289
+ "Summary:"
290
+ )
291
+ else:
292
+ doc = inst.input["document"]
293
+ prompt = (
294
+ "Please summarize the following document in a short sentence\n"
295
+ f"{doc}\n"
296
+ "Summary:"
297
+ )
298
+
299
+ pred = self.model.generate(prompt, max_new_tokens=256)
300
+
301
+ if "Summary:" in pred:
302
+ pred = pred.split("Summary:")[-1].strip()
303
+ else:
304
+ pred = pred.strip()
305
+
306
+ self.predictions.append(pred)
307
+
308
+
309
+ class TranslationTask(BaseTask):
310
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
311
+ return [Instance(input={
312
+ "source_lang": d["in"],
313
+ "target_lang": d["out"],
314
+ "text": d["input"]
315
+ },
316
+ output={},
317
+ id=d["id"])
318
+ for d in task_data["data"]]
319
+
320
+ def run_inference(self):
321
+ self.predictions = []
322
+ for inst in tqdm.tqdm(self.data):
323
+ source_lang = inst.input["source_lang"]
324
+ target_lang = inst.input["target_lang"]
325
+ text = inst.input["text"]
326
+
327
+ prompt = (f"Please directly Translate the following text from {source_lang} to {target_lang}.\n"
328
+ f"Text: {text}\n"
329
+ f"Translation:")
330
+ pred = self.model.generate(prompt, max_new_tokens=256)
331
+ if "Translation:" in pred:
332
+ pred = pred.split("Translation:")[-1].strip()
333
+ else:
334
+ pred = pred.strip()
335
+
336
+ self.predictions.append(pred)
337
+
338
+
339
+ class StoryGenerationTask(BaseTask):
340
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
341
+ instances = []
342
+ for d in task_data["data"]:
343
+ instances.append(
344
+ Instance(
345
+ input=d["input"],
346
+ output={},
347
+ id=d["id"]
348
+ )
349
+ )
350
+ return instances
351
+
352
+ def run_inference(self):
353
+ self.predictions = []
354
+ for inst in tqdm.tqdm(self.data):
355
+ prompt_text = inst.input["prompt"]
356
+ prompt = f"Please write a story based on the following prompt:\n{prompt_text}\nStory:"
357
+ pred = self.model.generate(prompt, max_new_tokens=512)
358
+ if "Story:" in pred:
359
+ pred = pred.split("Story:")[-1].strip()
360
+
361
+ self.predictions.append(pred)
362
+
363
+
364
+ class DialogueGenerationTask(BaseTask):
365
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
366
+ instances = []
367
+ for d in task_data["data"]:
368
+ dialog_list = d.get("dialog", [])
369
+ if not dialog_list:
370
+ continue
371
+
372
+ instances.append(
373
+ Instance(
374
+ input={"dialog": dialog_list},
375
+ output={},
376
+ id=d["id"]
377
+ )
378
+ )
379
+ return instances
380
+
381
+ def run_inference(self):
382
+ self.predictions = []
383
+
384
+ for inst in tqdm.tqdm(self.data):
385
+ dialog_context = inst.input["dialog"]
386
+ prompt = "Below is a multi-turn conversation. Please continue the dialogue for the last turn.\n\n"
387
+ for turn_idx, turn in enumerate(dialog_context):
388
+ prompt += f"Turn {turn_idx + 1}: {turn}\n"
389
+ prompt += "\nNow please respond in one short answer:\n"
390
+
391
+ pred = self.model.generate(prompt, max_new_tokens=128).strip()
392
+ self.predictions.append(pred)
393
+
394
+
395
+ class CodeGenerationTask(BaseTask):
396
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
397
+ instances = []
398
+ for d in task_data["data"]:
399
+ instance_id = d["id"]
400
+ language = d["language"]
401
+ goal = d["goal"]
402
+ context = d.get("context", [])
403
+
404
+ instances.append(
405
+ Instance(
406
+ input={
407
+ "language": language,
408
+ "goal": goal,
409
+ "context": context
410
+ },
411
+ output={},
412
+ id=instance_id
413
+ )
414
+ )
415
+ return instances
416
+
417
+ def run_inference(self):
418
+ self.predictions = []
419
+ self.languages = []
420
+
421
+ for inst in tqdm.tqdm(self.data):
422
+ language = inst.input["language"]
423
+ goal = inst.input["goal"]
424
+ context = inst.input["context"]
425
+
426
+ prompt = f"You are an AI developer. Your goal is: {goal}\n"
427
+ prompt += f"Please write {language} code that solves the described task.\n\n"
428
+
429
+ for c_item in context:
430
+ c_type = c_item["type"]
431
+ c_content = c_item["content"]
432
+ if c_type == "description":
433
+ prompt += f"Description:\n{c_content}\n\n"
434
+ elif c_type == "example":
435
+ prompt += "Examples:\n"
436
+ for ex in c_content:
437
+ prompt += f"- Input: {ex['input']}, Expected Output: {ex['output']}\n"
438
+ prompt += "\n"
439
+ else:
440
+ prompt += f"{c_type.capitalize()}:\n{c_content}\n\n"
441
+
442
+ prompt += (
443
+ "Now, please output ONLY the final code solution (without additional explanations, comments or text)."
444
+ "\nCode:\n"
445
+ )
446
+
447
+ pred_code = self.model.generate(prompt, max_new_tokens=256).strip()
448
+ if "Code:" in pred_code:
449
+ pred_code = pred_code.split("Code:", 1)[-1].strip()
450
+
451
+ self.predictions.append(pred_code)
452
+ self.languages.append(language)
453
+
454
+
455
+ class CodeRepairTask(BaseTask):
456
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
457
+ instances = []
458
+ for d in task_data["data"]:
459
+ instance_id = d["id"]
460
+ input_part = d["input"]
461
+
462
+ prompt = input_part["prompt"]
463
+ source_code = input_part["sourceCode"]
464
+ instances.append(
465
+ Instance(
466
+ input={
467
+ "prompt": prompt,
468
+ "sourceCode": source_code
469
+ },
470
+ output={},
471
+ id=instance_id
472
+ )
473
+ )
474
+ return instances
475
+
476
+ def run_inference(self):
477
+ self.predictions = []
478
+
479
+ for inst in tqdm.tqdm(self.data):
480
+ prompt = inst.input["prompt"]
481
+ source_code = inst.input["sourceCode"]
482
+ final_prompt = (
483
+ f"{prompt}\n"
484
+ f"{source_code}\n\n"
485
+ "Now, please output ONLY the final code solution (without additional explanations, comments or text)."
486
+ "Refined Code:"
487
+ )
488
+
489
+ pred_code = self.model.generate(final_prompt, max_new_tokens=256).strip()
490
+ if "Refined Code:" in pred_code:
491
+ pred_code = pred_code.split("Refined Code:", 1)[-1].strip()
492
+
493
+ self.predictions.append(pred_code)
494
+
495
+
496
+ class CodeDefectDetectionTask(BaseTask):
497
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
498
+ instances = []
499
+ for d in task_data["data"]:
500
+ instances.append(
501
+ Instance(
502
+ input={"func": d["func"]},
503
+ output={},
504
+ id=d["id"]
505
+ )
506
+ )
507
+ return instances
508
+
509
+ def run_inference(self):
510
+ self.predictions = []
511
+
512
+ for inst in tqdm.tqdm(self.data):
513
+ code_snippet = inst.input["func"]
514
+ prompt = (
515
+ "You are a code reviewer. Below is a piece of code or function:\n"
516
+ f"{code_snippet}\n\n"
517
+ "Please review carefully and determine if it contains a grammatical or logical defect. "
518
+ "For example, the code below has defect:\n"
519
+ "static void show_packets(AVFormatContext *format_ctx)\n\n{\n\n AVPacket packet;\n\n\n\n av_init_packet(&packet);\n\n probe_array_header(\"packets\", 0);\n\n while (!av_read_frame(format_ctx, &packet))\n\n show_packet(format_ctx, &packet);\n\n probe_array_footer(\"packets\", 0);\n\n}\n"
520
+ "For another example, the code below has no defect:\n"
521
+ "static void visitor_output_setup_internal(TestOutputVisitorData *output_data,\n\n bool is_human)\n\n{\n\n output_data->human = is_human;\n\n output_data->sov = string_output_visitor_new(is_human);\n\n g_assert(output_data->sov);\n\n output_data->ov = string_output_get_visitor(output_data->sov);\n\n g_assert(output_data->ov);\n\n}\n"
522
+ "Output only 'No defect' if it does NOT contain a grammatical or logical defect, "
523
+ "or ouput only 'Defect' if it DOES contain a defect.\n"
524
+ "Answer:"
525
+ )
526
+
527
+ response = self.model.generate(prompt, max_new_tokens=16).strip()
528
+
529
+ if "no defect" in response.lower():
530
+ pred = "0"
531
+ else:
532
+ pred = "1"
533
+
534
+ self.predictions.append(pred)
535
+
536
+
537
+ class TextToSQLTask(BaseTask):
538
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
539
+ instances = []
540
+ for d in task_data["data"]:
541
+ instances.append(
542
+ Instance(
543
+ input={
544
+ "context": d["input"]["context"],
545
+ "question": d["input"]["question"],
546
+ },
547
+ output={},
548
+ id=d["id"]
549
+ )
550
+ )
551
+ return instances
552
+
553
+ def run_inference(self):
554
+ self.predictions = []
555
+
556
+ for inst in tqdm.tqdm(self.data):
557
+ schema_context = inst.input["context"]
558
+ question = inst.input["question"]
559
+
560
+ prompt = (
561
+ "Below is a database schema:\n"
562
+ f"{schema_context}\n"
563
+ "Given the schema, please write a valid SQL query that answers the following question without other words.\n"
564
+ f"Question: {question}\n"
565
+ "SQL:"
566
+ )
567
+
568
+ response = self.model.generate(prompt, max_new_tokens=256)
569
+ if "SQL:" in response:
570
+ pred_sql = response.split("SQL:", 1)[-1].strip()
571
+ else:
572
+ pred_sql = response.strip()
573
+
574
+ self.predictions.append(pred_sql)
575
+
576
+
577
+ class CodeExplanationTask(BaseTask):
578
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
579
+ instances = []
580
+ for d in task_data["data"]:
581
+ code_snippet = d["code"]
582
+ instance_id = d["id"]
583
+
584
+ instances.append(
585
+ Instance(
586
+ input={"code": code_snippet},
587
+ output={},
588
+ id=instance_id
589
+ )
590
+ )
591
+ return instances
592
+
593
+ def run_inference(self):
594
+ self.predictions = []
595
+
596
+ for inst in tqdm.tqdm(self.data):
597
+ code_snippet = inst.input["code"]
598
+ prompt = (
599
+ "You are a code explainer. "
600
+ "Please read the following code snippet and provide a concise, clear explanation in natural language:. For example:\n"
601
+ "Code:\nboolean equalsResidueRing ( Object obj ) { if ( !( obj instanceof ResidueRing ) ) { return false ; } ResidueRing < C > otherRing = null ; try { otherRing = ( ResidueRing < C > ) obj ; } catch ( ClassCastException e ) { return false ; } if ( otherRing == null ) { return false ; } if ( ! ring . equals ( otherRing . ring ) ) { return false ; } return modul . equals ( otherRing . modul ) ; }"
602
+ "Explanation: compares this ResidueRing with another object.\n\n"
603
+ "Now please explain the code below without other words:\n"
604
+ f"{code_snippet}\n"
605
+ "Explanation:"
606
+ )
607
+
608
+ pred_explanation = self.model.generate(prompt, max_new_tokens=256).strip()
609
+ self.predictions.append(pred_explanation)
610
+
611
+
612
+ class MathematicalProofGenerationTask(BaseTask):
613
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
614
+ instances = []
615
+ for d in task_data["data"]:
616
+ statement = d["statement"]
617
+
618
+ instances.append(
619
+ Instance(
620
+ input={
621
+ "statement": statement
622
+ },
623
+ output={},
624
+ id=d["id"]
625
+ )
626
+ )
627
+ return instances
628
+
629
+ def run_inference(self):
630
+ self.predictions = []
631
+
632
+ for inst in tqdm.tqdm(self.data):
633
+ statement = inst.input["statement"]
634
+
635
+ prompt = (
636
+ "You are a mathematical assistant. "
637
+ "Please provide a clear, step-by-step proof for the following statement:\n"
638
+ f"Statement: {statement}\n\n"
639
+ "Ensure you include the final conclusion as well. Proof:"
640
+ )
641
+
642
+ pred_proof = self.model.generate(prompt, max_new_tokens=512).strip()
643
+ self.predictions.append(pred_proof)
644
+
645
+
646
+ class MathematicalWordProblemSolvingTask(BaseTask):
647
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
648
+ instances = []
649
+ for d in task_data["data"]:
650
+ problem_text = d["problem"]["text"]
651
+ constraints = d["problem"].get("constraints", [])
652
+
653
+ instances.append(
654
+ Instance(
655
+ input={
656
+ "problem_text": problem_text,
657
+ "constraints": constraints
658
+ },
659
+ output={},
660
+ id=d["id"]
661
+ )
662
+ )
663
+ return instances
664
+
665
+ def run_inference(self):
666
+ self.predictions_steps = []
667
+ self.predictions_final = []
668
+
669
+ for inst in tqdm.tqdm(self.data):
670
+ problem_text = inst.input["problem_text"]
671
+ constraints = inst.input["constraints"]
672
+ constraints_str = ""
673
+ if constraints:
674
+ constraints_str = "\nConstraints:\n" + "\n".join(constraints)
675
+
676
+ prompt = (
677
+ "You are a math problem solver. Please solve the following word problem step by step. "
678
+ "Finally, provide the final numeric or short answer in a separate line labeled as 'Final Answer:'.\n\n"
679
+ f"Problem:\n{problem_text}{constraints_str}\n\n"
680
+ "Solution (step-by-step) + Final Answer:\n"
681
+ )
682
+
683
+ response = self.model.generate(prompt, max_new_tokens=512).strip()
684
+
685
+ steps_part, final_part = response, ""
686
+ if "Final Answer:" in response:
687
+ parts = response.split("Final Answer:", 1)
688
+ steps_part = parts[0].strip()
689
+ final_part = parts[1].strip()
690
+
691
+ self.predictions_steps.append(steps_part)
692
+ self.predictions_final.append(final_part)
693
+
694
+
695
+ class ParaphraseGenerationTask(BaseTask):
696
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
697
+ instances = []
698
+ for d in task_data["data"]:
699
+ instances.append(
700
+ Instance(
701
+ input={"originalSentence": d["input"]["originalSentence"]},
702
+ output={},
703
+ id=d["id"]
704
+ )
705
+ )
706
+ return instances
707
+
708
+ def run_inference(self):
709
+ self.predictions = []
710
+ for inst in tqdm.tqdm(self.data):
711
+ original_sentence = inst.input["originalSentence"]
712
+
713
+ prompt = (
714
+ "Please rewrite the following sentence in a different way but keep the same meaning:\n"
715
+ f"{original_sentence}\n"
716
+ "Paraphrase:"
717
+ )
718
+
719
+ pred = self.model.generate(prompt, max_new_tokens=128)
720
+
721
+ if "Paraphrase:" in pred:
722
+ pred = pred.split("Paraphrase:")[-1].strip()
723
+
724
+ self.predictions.append(pred.strip())
725
+
726
+
727
+ class GrammarCorrectionTask(BaseTask):
728
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
729
+ return [
730
+ Instance(
731
+ input=d["input"],
732
+ output={},
733
+ id=d["id"]
734
+ )
735
+ for d in task_data["data"]
736
+ ]
737
+
738
+ def run_inference(self):
739
+ self.predictions = []
740
+
741
+ for inst in tqdm.tqdm(self.data):
742
+ error_type = inst.input["Error Type"]
743
+ ungrammatical_sentence = inst.input["Ungrammatical Statement"]
744
+
745
+ prompt = (
746
+ f"You are a grammar correction assistant.\n"
747
+ f"There is a sentence with the following error type: {error_type}.\n"
748
+ f"Please rewrite the sentence in correct standard English without any other word.\n\n"
749
+ f"Ungrammatical Sentence: {ungrammatical_sentence}\n\n"
750
+ f"Rewritten Sentence:"
751
+ )
752
+
753
+ corrected = self.model.generate(prompt, max_new_tokens=128).strip()
754
+ if "Rewritten Sentence:" in corrected:
755
+ corrected = corrected.split("Rewritten Sentence:")[-1].strip()
756
+
757
+ self.predictions.append(corrected)
758
+
759
+
760
+ class TextStyleTransferTask(BaseTask):
761
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
762
+ instances = []
763
+ for d in task_data["data"]:
764
+ instances.append(
765
+ Instance(
766
+ input={
767
+ "text": d["input"]["text"],
768
+ "style": d["input"]["style"]
769
+ },
770
+ output={},
771
+ id=d["id"]
772
+ )
773
+ )
774
+ return instances
775
+
776
+ def run_inference(self):
777
+ self.predictions = []
778
+
779
+ for inst in tqdm.tqdm(self.data):
780
+ text = inst.input["text"]
781
+ style = inst.input["style"]
782
+
783
+ prompt = (
784
+ "You are a style transfer assistant.\n"
785
+ "Below is a piece of text and a target style.\n"
786
+ f"Text: {text}\n"
787
+ f"Style: {style}\n\n"
788
+ "Please rewrite the above text to match the target style more accurately, "
789
+ "while keeping the original meaning intact.\n"
790
+ "Answer:"
791
+ )
792
+
793
+ pred = self.model.generate(prompt, max_new_tokens=256).strip()
794
+ if "Answer:" in pred:
795
+ pred = pred.split("Answer:")[-1].strip()
796
+
797
+ self.predictions.append(pred)
798
+
799
+
800
+ class TableToTextGenerationTask(BaseTask):
801
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
802
+ instances = []
803
+ for d in task_data["data"]:
804
+ instance_id = d["id"]
805
+ table_data = d["input"]["table"]
806
+ instances.append(
807
+ Instance(
808
+ input={"table": table_data},
809
+ output={},
810
+ id=instance_id
811
+ )
812
+ )
813
+ return instances
814
+
815
+ def run_inference(self):
816
+ self.predictions = []
817
+
818
+ for inst in tqdm.tqdm(self.data):
819
+ table_data = inst.input["table"]
820
+
821
+ prompt = "Below is a table. Please generate a coherent description that summarizes the table's content.\n\n"
822
+ for table_idx, table_item in enumerate(table_data):
823
+ header = table_item["header"]
824
+ rows = table_item["rows"]
825
+ prompt += f"Table {table_idx+1}:\nHeader: {header}\nRows:\n"
826
+ for r_idx, row in enumerate(rows):
827
+ prompt += f"{r_idx+1}. {row}\n"
828
+ prompt += "\n"
829
+
830
+ prompt += "Now write a concise text describing the above table:\n"
831
+
832
+ pred_text = self.model.generate(prompt, max_new_tokens=512)
833
+ pred_text = pred_text.strip()
834
+
835
+ self.predictions.append(pred_text)
836
+
837
+
838
+ class TimeSeriesForecastingTask(BaseTask):
839
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
840
+ instances = []
841
+ for d in task_data["data"]:
842
+ time_series = d["input"]["data"]
843
+ instances.append(
844
+ Instance(
845
+ input={"time_series": time_series},
846
+ output={},
847
+ id=d["id"]
848
+ )
849
+ )
850
+ return instances
851
+
852
+ def run_inference(self):
853
+ self.predictions = []
854
+ for inst in tqdm.tqdm(self.data):
855
+ series_data = inst.input["time_series"]
856
+ pred_values = self.model.generate_for_timeseries(series_data, horizon=1, freq=0)
857
+ predicted = pred_values[0] if pred_values else 0.0
858
+ self.predictions.append(predicted)
859
+
860
+
861
+ class ClassificationTask(BaseTask):
862
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
863
+ return [Instance(input=d["input"], output={}, id=d["id"])
864
+ for d in task_data["data"]]
865
+
866
+ def run_inference(self):
867
+ self.predictions = []
868
+ for inst in tqdm.tqdm(self.data):
869
+ if 'stance_detection' in self.task_data['task']:
870
+ tweets = inst.input["tweets"]
871
+ target = inst.input["target"]
872
+ prompt = inst.input["prompt"].replace("<<<target>>>", target).replace("<<<tweets>>>", tweets)
873
+ elif 'aspect_sentiment_classification' in self.task_data['task']:
874
+ raw_text = inst.input["raw_text"]
875
+ target = inst.input["target"]
876
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", raw_text).replace("<<<target>>>", target) + 'Please direct return the category name without any other words.'
877
+ elif 'target_oriented_opinion_words_extraction' in self.task_data['task']:
878
+ raw_text = inst.input["raw_text"]
879
+ aspect = inst.input["aspect"]
880
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", raw_text).replace("<<<aspect>>>", aspect) + 'Please direct return the opinion word without any other words.'
881
+ else:
882
+ raw_text = inst.input["raw_text"]
883
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", raw_text) + 'Please return the desired result directly, without any other explanation.'
884
+ response = self.model.generate(prompt, max_new_tokens=64)
885
+ self.predictions.append(response.lower())
886
+
887
+
888
+ class MultiLabelClassificationTask(BaseTask):
889
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
890
+ return [Instance(input=d["input"], output={}, id=d["id"])
891
+ for d in task_data["data"]]
892
+
893
+ def run_inference(self):
894
+ self.predictions = []
895
+ for inst in tqdm.tqdm(self.data):
896
+ raw_text = inst.input["raw_text"]
897
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", raw_text)
898
+ prompt = prompt + " Please return the desired result directly, without any other explanation." + " Split the result by commas instead of \\n."
899
+ response = self.model.generate(prompt, max_new_tokens=64)
900
+ self.predictions.append('<p>'.join(response.lower().split(', ')))
901
+
902
+
903
+ class ChoiceTask(BaseTask):
904
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
905
+ return [Instance(input=d["input"], output={}, id=d["id"])
906
+ for d in task_data["data"]]
907
+
908
+ def run_inference(self):
909
+ self.predictions = []
910
+ for inst in tqdm.tqdm(self.data):
911
+ raw_text = inst.input["raw_text"]
912
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", raw_text) + 'Please return the desired result directly, without any other explanation.'
913
+ response = self.model.generate(prompt, max_new_tokens=64)
914
+ if len(response.strip()) > 1:
915
+ if "A" in response.strip():
916
+ response = "A"
917
+ elif "B" in response.strip():
918
+ response = "B"
919
+ elif "C" in response.strip():
920
+ response = "C"
921
+ elif "D" in response.strip():
922
+ response = "D"
923
+ self.predictions.append(response.lower())
924
+
925
+
926
+ class NERTask(BaseTask):
927
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
928
+ return [Instance(input=d["input"], output={}, id=d["id"])
929
+ for d in task_data["data"]]
930
+
931
+ def run_inference(self):
932
+ self.predictions = []
933
+ for inst in tqdm.tqdm(self.data):
934
+ text = inst.input["raw_text"]
935
+ prompt = inst.input["prompt"].replace("<<<raw_text>>>", text)
936
+ response = self.model.generate(prompt, max_new_tokens=128)
937
+ self.predictions.append('<p>'.join(response.lower().split(', ')))
938
+
939
+
940
+ def save_predictions(task_obj: BaseTask, task_directory: str):
941
+ save_path = os.path.join(task_directory, "prediction.json")
942
+ records = []
943
+ if isinstance(task_obj, MathematicalWordProblemSolvingTask):
944
+ for idx, inst in enumerate(task_obj.data):
945
+ records.append({
946
+ "id": inst.id,
947
+ "prediction_steps": task_obj.predictions_steps[idx],
948
+ "prediction_final": task_obj.predictions_final[idx]
949
+ })
950
+ elif isinstance(task_obj, TimeSeriesForecastingTask):
951
+ for idx, inst in enumerate(task_obj.data):
952
+ records.append({
953
+ "id": inst.id,
954
+ "prediction": float(task_obj.predictions[idx])
955
+ })
956
+ else:
957
+ for idx, inst in enumerate(task_obj.data):
958
+ pred_val = task_obj.predictions[idx]
959
+ if isinstance(pred_val, (np.floating, np.integer)):
960
+ pred_val = float(pred_val)
961
+ records.append({"id": inst.id, "prediction": pred_val})
962
+ with open(save_path, "w", encoding="utf-8") as fp:
963
+ json.dump(records, fp, ensure_ascii=False, indent=2)
964
+
965
+
966
+ TASK_MAPPING = {
967
+ "MultipleChoiceQA": MultipleChoiceQA,
968
+ "OpenQA": OpenQA,
969
+ "Summarization": SummarizationTask,
970
+ "Story Generation": StoryGenerationTask,
971
+ "Translation": TranslationTask,
972
+ "Dialogue": DialogueGenerationTask,
973
+ "Code Generation": CodeGenerationTask,
974
+ "Code Defect Detection": CodeDefectDetectionTask,
975
+ "Code Repair": CodeRepairTask,
976
+ "Code Explanation": CodeExplanationTask,
977
+ "Proof": MathematicalProofGenerationTask,
978
+ "Mathematical Word Problem Solving": MathematicalWordProblemSolvingTask,
979
+ "Text to SQL": TextToSQLTask,
980
+ "Paraphrase Generation": ParaphraseGenerationTask,
981
+ "Grammar Correction": GrammarCorrectionTask,
982
+ "Table-to-Text Generation": TableToTextGenerationTask,
983
+ "Time Series": TimeSeriesForecastingTask,
984
+ "Text Style Transfer": TextStyleTransferTask,
985
+ "classification": ClassificationTask,
986
+ "multi label classification": MultiLabelClassificationTask,
987
+ "ner": NERTask,
988
+ "extraction": MultiLabelClassificationTask,
989
+ "relation extraction": MultiLabelClassificationTask,
990
+ "event detection": MultiLabelClassificationTask,
991
+ "parsing": MultiLabelClassificationTask,
992
+ "multiple choice": ChoiceTask,
993
+ }
994
+
995
+
996
+ if __name__ == "__main__":
997
+ import argparse
998
+
999
+ parser = argparse.ArgumentParser(description="NLP Predictor")
1000
+ parser.add_argument("--dataset_dir", required=True)
1001
+ parser.add_argument("--model_name", required=True)
1002
+ args = parser.parse_args()
1003
+
1004
+ data_root = os.path.abspath(args.dataset_dir)
1005
+ model = LLMModel(args.model_name)
1006
+
1007
+ task_dirs = sorted([d for d in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, d))])
1008
+
1009
+ for idx, task_folder in enumerate(task_dirs, start=1):
1010
+ folder_path = os.path.join(data_root, task_folder)
1011
+ annotation_path = os.path.join(folder_path, "annotation.json")
1012
+
1013
+ with open(annotation_path, "r", encoding="utf-8") as f:
1014
+ task_data = json.load(f)
1015
+
1016
+ task_type = task_data.get("type")
1017
+ task_name = task_data.get("task", task_folder)
1018
+ print(f"\nTask {idx}/{len(task_dirs)}: {task_name} (Type = {task_type})")
1019
+
1020
+ task_class = TASK_MAPPING.get(task_type, OpenQA)
1021
+ task = task_class(task_data, model)
1022
+
1023
+ task.run_inference()
1024
+ save_predictions(task, folder_path)
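The main loop above walks every sub-folder of --dataset_dir, loads its annotation.json, picks a task class from TASK_MAPPING via the "type" field, runs inference, and writes prediction.json back into the folder. A minimal sketch of the expected layout and schema follows; all field values are made-up placeholders.

    # <dataset_dir>/<task_folder>/annotation.json, e.g. for a MultipleChoiceQA task:
    example_annotation = {
        "type": "MultipleChoiceQA",      # key used to look up the class in TASK_MAPPING
        "task": "Causal Reasoning",      # human-readable task name
        "data": [
            {
                "id": "0",
                "input": {
                    "question": "Which outcome is most likely?",
                    "options": ["outcome A", "outcome B"],
                },
            }
        ],
    }
    # After run_inference(), save_predictions() writes prediction.json next to annotation.json.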
predictors/video_comprehension_flow_matching_tracking.py ADDED
@@ -0,0 +1,562 @@
1
+ import tqdm
2
+ from typing import List, Dict, Any
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+ from typing import Tuple
9
+ import os
10
+ import json
11
+ import argparse
12
+
13
+ import torch
14
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
15
+ BitsAndBytesConfig, CLIPImageProcessor,
16
+ CLIPVisionModel, GenerationConfig)
17
+
18
+ def exact_match_accuracy(predictions: List[str], references: List[str]) -> float:
19
+ correct = 0
20
+ for pred, ref in zip(predictions, references):
21
+ if isinstance(ref, str):
22
+ ref = [ref]
23
+ is_match_this_turn = False
24
+ for r in ref:
25
+ if pred.strip() == r.strip():
26
+ is_match_this_turn = True
27
+ if is_match_this_turn:
28
+ correct += 1
29
+ return correct / len(predictions) if predictions else 0.0
30
+
31
+
32
+ def bbox_to_corners(bbox):
33
+ """将(x_min, y_min, w, h)格式转换为(x_min, y_min, x_max, y_max)格式"""
34
+ x_min, y_min, w, h = bbox
35
+ return (x_min, y_min, x_min + w, y_min + h)
36
+
37
+
38
+ def calculate_iou(bbox1, bbox2):
39
+ """计算两个边界框的交并比(IoU/Jaccard Index)"""
40
+ # Convert to corner-coordinate format
41
+ bbox1 = bbox_to_corners(bbox1)
42
+ bbox2 = bbox_to_corners(bbox2)
43
+
44
+ # Coordinates of the intersection region
45
+ x1 = max(bbox1[0], bbox2[0])
46
+ y1 = max(bbox1[1], bbox2[1])
47
+ x2 = min(bbox1[2], bbox2[2])
48
+ y2 = min(bbox1[3], bbox2[3])
49
+
50
+ # Intersection area
51
+ intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
52
+
53
+ # Areas of the two bounding boxes
54
+ bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
55
+ bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
56
+
57
+ # Union area
58
+ union_area = bbox1_area + bbox2_area - intersection_area
59
+
60
+ # IoU
61
+ if union_area == 0:
62
+ return 0.0
63
+ return intersection_area / union_area
64
+
65
+
66
+ def calculate_j_metric(pred_bboxes, gt_bboxes):
67
+ """计算J指标(Jaccard Index)"""
68
+ if len(pred_bboxes) != len(gt_bboxes):
69
+ raise ValueError("预测边界框和真实边界框数量不一致")
70
+
71
+ iou_values = []
72
+ for pred, gt in zip(pred_bboxes, gt_bboxes):
73
+ iou = calculate_iou(pred, gt)
74
+ iou_values.append(iou)
75
+
76
+ # Return the mean Jaccard index
77
+ return sum(iou_values) / len(iou_values) if iou_values else 0.0
78
+
79
+
80
+ def calculate_f1_score(pred_bboxes, gt_bboxes, threshold=0.5):
81
+ """计算F1 Score(F指标)"""
82
+ if len(pred_bboxes) == 0 and len(gt_bboxes) == 0:
83
+ return 1.0 # Special case: no detections and no ground truth counts as fully correct
84
+
85
+ true_positives = 0
86
+ false_positives = 0
87
+ false_negatives = 0
88
+
89
+ # Track which ground-truth boxes have been matched
90
+ gt_matched = [False] * len(gt_bboxes)
91
+
92
+ # IoU of every (prediction, ground truth) pair
93
+ iou_matrix = []
94
+ for i, pred in enumerate(pred_bboxes):
95
+ row = []
96
+ for j, gt in enumerate(gt_bboxes):
97
+ row.append(calculate_iou(pred, gt))
98
+ iou_matrix.append(row)
99
+
100
+ # Greedy matching: assign each predicted box to the ground-truth box with the highest IoU
101
+ for i in range(len(pred_bboxes)):
102
+ if not iou_matrix:
103
+ break
104
+
105
+ # Find the maximum IoU in the current row and its index
106
+ max_iou = max(iou_matrix[i]) if iou_matrix[i] else 0
107
+ j = iou_matrix[i].index(max_iou) if iou_matrix[i] else -1
108
+
109
+ if max_iou >= threshold:
110
+ true_positives += 1
111
+ gt_matched[j] = True
112
+ else:
113
+ false_positives += 1
114
+
115
+ # False negatives
116
+ false_negatives = sum(1 for matched in gt_matched if not matched)
117
+
118
+ # Precision and recall
119
+ precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
120
+ recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
121
+
122
+ # F1 score
123
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
124
+ return f1
125
+
126
+
127
+ def calculate_j_and_f_metrics(pred_bboxes, gt_bboxes, iou_threshold=0.5):
128
+ """计算J指标和F指标"""
129
+ # J metric
130
+ j_metric = calculate_j_metric(pred_bboxes, gt_bboxes)
131
+
132
+ # F metric
133
+ f_metric = calculate_f1_score(pred_bboxes, gt_bboxes, threshold=iou_threshold)
134
+
135
+ return {
136
+ "J_metric": j_metric,
137
+ "F_metric": f_metric
138
+ }
139
+
140
+ def read_flow(file_path: str) -> np.ndarray:
141
+ if file_path.endswith('.flo'):
142
+ return read_flow_flo(file_path)
143
+ elif file_path.endswith(('.png', '.jpg', '.jpeg')):
144
+ return read_flow_png(file_path)
145
+ else:
146
+ raise NotImplementedError
147
+
148
+
149
+ def read_flow_flo(file_path: str) -> np.ndarray:
150
+ with open(file_path, 'rb') as f:
151
+
152
+ magic = np.fromfile(f, np.float32, count=1)
153
+ if magic.size == 0 or magic[0] != 202021.25:
154
+ raise NotImplementedError
155
+
156
+ w = np.fromfile(f, np.int32, count=1)[0]
157
+ h = np.fromfile(f, np.int32, count=1)[0]
158
+
159
+ flow = np.fromfile(f, np.float32, count=2 * w * h)
160
+ flow = flow.reshape(h, w, 2)
161
+
162
+ return flow
163
+
164
+
165
+ def read_flow_png(file_path: str) -> np.ndarray:
166
+ img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
167
+
168
+ # Make sure the image has enough channels
169
+ if len(img.shape) != 3 or img.shape[2] < 2:
170
+ raise NotImplementedError
171
+
172
+ u = (img[:, :, 2] - 32768.0) / 64.0 # R
173
+ v = (img[:, :, 1] - 32768.0) / 64.0 # G
174
+
175
+ flow = np.stack([u, v], axis=2)
176
+
177
+ return flow
178
+
179
+
180
+ def calculate_epe(flow_gt: np.ndarray, flow_pred: np.ndarray) -> Tuple[float, np.ndarray]:
181
+ if flow_gt.shape != flow_pred.shape:
182
+ raise NotImplementedError
183
+
184
+ diff = flow_gt - flow_pred
185
+ epe_map = np.sqrt(np.sum(diff ** 2, axis=2))
186
+
187
+ mean_epe = np.mean(epe_map)
188
+
189
+ return mean_epe, epe_map
190
+
191
+ class Sa2VAModel:
192
+ def __init__(self, model_name="ByteDance/Sa2VA-4B"):
193
+ self.model_name = model_name
194
+
195
+ model = AutoModel.from_pretrained(
196
+ model_name,
197
+ torch_dtype=torch.bfloat16,
198
+ low_cpu_mem_usage=True,
199
+ use_flash_attn=True,
200
+ trust_remote_code=True,
201
+ ).eval().cuda()
202
+
203
+ tokenizer = AutoTokenizer.from_pretrained(
204
+ model_name,
205
+ trust_remote_code=True,
206
+ )
207
+
208
+ self.model = model
209
+ self.tokenizer = tokenizer
210
+
211
+ def generate(self, input_dict):
212
+ pred_dict = self.model.predict_forward(**input_dict, tokenizer=self.tokenizer)
213
+ if 'prediction_masks' in pred_dict.keys() and pred_dict['prediction_masks'] and len(
214
+ pred_dict['prediction_masks']) != 0:
215
+ masks = pred_dict['prediction_masks'][0] # (f, h, w)
216
+ else:
217
+ masks = None
218
+ text_response = pred_dict["prediction"]
219
+ return text_response, masks
220
+
221
+ @dataclass
222
+ class Instance:
223
+ input: Dict[str, Any]
224
+ output: Dict[str, Any]
225
+ id: str
226
+
227
+
228
+ class BaseTask(ABC):
229
+ def __init__(self, task_data: Dict[str, Any], model):
230
+ self.task_data = task_data
231
+ self.model = model
232
+ self.data = self._parse_data(task_data)
233
+
234
+ @abstractmethod
235
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
236
+ pass
237
+
238
+ @abstractmethod
239
+ def evaluate(self) -> Dict[str, float]:
240
+ pass
241
+
242
+ @abstractmethod
243
+ def run_inference(self):
244
+ pass
245
+
246
+ def get_bbox_from_mask(mask):
247
+ if len(mask.shape) != 2:
248
+ raise NotImplementedError
249
+
250
+ y_indices, x_indices = np.nonzero(mask)
251
+
252
+ if len(x_indices) == 0 or len(y_indices) == 0:
253
+ return None
254
+
255
+ x_min = np.min(x_indices)
256
+ x_max = np.max(x_indices)
257
+ y_min = np.min(y_indices)
258
+ y_max = np.max(y_indices)
259
+
260
+ return (x_min, y_min, x_max-x_min, y_max-y_min)
261
+
262
+ def mask2bbox(masks, video_length):
263
+ if masks is None:
264
+ bboxes = [[0, 0, 0, 0]] * video_length
265
+ else:
266
+ bboxes = []
267
+ for mask in masks:
268
+ bbox = get_bbox_from_mask(mask)
269
+ if bbox is None:
270
+ bbox = [0, 0, 0, 0]
271
+ bboxes.append(bbox)
272
+ return bboxes
273
+
274
+ class MatchTask(BaseTask):
275
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
276
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
277
+ for d in task_data["data"]]
278
+
279
+ def run_inference(self):
280
+ self.predictions = []
281
+ self.references = []
282
+ for inst in tqdm.tqdm(self.data):
283
+ prompt = "<image>\n" + inst.input["prompt"]
284
+ video_folder = inst.input["video_folder"]
285
+ frame_files = [os.path.join(video_folder, _name) for _name in os.listdir(video_folder)]
286
+ video = []
287
+ for image_path in frame_files:
288
+ video.append(Image.open(image_path).convert('RGB'))
289
+
290
+ input_dict = {
291
+ "video": video,
292
+ "text": prompt,
293
+ }
294
+
295
+ response, _ = self.model.generate(input_dict)
296
+ response = response.split("<")[0].strip()
297
+
298
+ self.predictions.append(response)
299
+ self.references.append(inst.output["answer"])
300
+
301
+ def evaluate(self) -> Dict[str, float]:
302
+ acc = exact_match_accuracy(self.predictions, self.references)
303
+ return {"accuracy": acc}
304
+
305
+ class TrackingTask(BaseTask):
306
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
307
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
308
+ for d in task_data["data"]]
309
+
310
+ def run_inference(self):
311
+ self.predictions = []
312
+ self.references = []
313
+ for inst in tqdm.tqdm(self.data):
314
+ prompt = "<image>\n" + inst.input["prompt"]
315
+ video_folder = inst.input["video_folder"]
316
+ frame_files = [os.path.join(video_folder, _name) for _name in os.listdir(video_folder)]
317
+ video = []
318
+ for image_path in frame_files:
319
+ video.append(Image.open(image_path).convert('RGB'))
320
+
321
+ input_dict = {
322
+ "video": video,
323
+ "text": prompt,
324
+ }
325
+
326
+ response, masks = self.model.generate(input_dict)
327
+
328
+ bboxes = mask2bbox(masks, len(video))
329
+
330
+ self.predictions.append(bboxes)
331
+ self.references.append(inst.output["answer"])
332
+
333
+ def evaluate(self) -> Dict[str, float]:
334
+ j_f, n = 0, 1e-4
335
+ for pred_bboxes, gt_bboxes in zip(self.predictions, self.references):
336
+ metrics = calculate_j_and_f_metrics(pred_bboxes, gt_bboxes)
337
+ j_f += (metrics['J_metric'] + metrics['F_metric']) / 2.0
338
+ n += 1
339
+ j_f = j_f / n
340
+ return {"J&F": j_f}
341
+
342
+ class FlowTask(BaseTask):
343
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
344
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
345
+ for d in task_data["data"]]
346
+
347
+ def run_inference(self):
348
+ self.predictions = []
349
+ self.references = []
350
+ for inst in tqdm.tqdm(self.data):
351
+ prompt = "<image>\n" + inst.input["prompt"]
352
+ video_folder = inst.input["video_folder"]
353
+ frame_files = [os.path.join(video_folder, _name) for _name in os.listdir(video_folder)]
354
+ video = []
355
+ for image_path in frame_files:
356
+ video.append(Image.open(image_path).convert('RGB'))
357
+
358
+ input_dict = {
359
+ "video": video,
360
+ "text": prompt,
361
+ }
362
+
363
+ response, masks = self.model.generate(input_dict)
364
+
365
+ pred_flows = np.zeros((masks.shape[1], masks.shape[2], 2))
366
+
367
+ self.predictions.append(pred_flows)
368
+ self.references.append(read_flow(inst.output["flow"]))
369
+
370
+ def evaluate(self) -> Dict[str, float]:
371
+ EPE, n = 0, 1e-4
372
+ for pred_flow, gt_flow in zip(self.predictions, self.references):
373
+ mean_epe, _ = calculate_epe(pred_flow, gt_flow)
374
+ EPE += mean_epe
375
+ n += 1
376
+ EPE = EPE / n
377
+ return {"EPE": EPE}
378
+
379
+
380
+ def log_performance(model_name, task_name, metrics, root_path, output_file='performance_log.csv'):
381
+ import csv
382
+ file_exists = os.path.isfile(os.path.join(root_path, output_file))
383
+
384
+ row_data = {
385
+ 'model': model_name,
386
+ 'task': task_name,
387
+ 'metrics': str(metrics)
388
+ }
389
+
390
+ with open(os.path.join(root_path, output_file), mode='a', newline='', encoding='utf-8') as f:
391
+ writer = csv.DictWriter(f, fieldnames=row_data.keys())
392
+ if not file_exists:
393
+ writer.writeheader()
394
+
395
+ writer.writerow(row_data)
396
+
397
+
398
+ def log_performance_detail(model_name, task_name, metrics, root_path, output_file='performance_log.csv'):
399
+ import csv
400
+ file_path = os.path.join(root_path, output_file)
401
+ file_exists = os.path.isfile(file_path)
402
+
403
+ # Pick the primary metric value from the metrics dict
404
+ metric_value = None
405
+ if isinstance(metrics, dict):
406
+ # Select a metric in priority order
407
+ for key in ['accuracy', 'f1', 'micro_f1', 'bleu4', 'rougeL', 'code_bleu', 'MAE']:
408
+ if key in metrics:
409
+ metric_value = metrics[key]
410
+ break
411
+ if metric_value is None and len(metrics) > 0:
412
+ # Fall back to the first metric if no preferred key is found
413
+ metric_value = list(metrics.values())[0]
414
+ else:
415
+ metric_value = metrics
416
+
417
+ # Simplify the model name: keep only the last path component
418
+ model_name = model_name.split('/')[-1]
419
+
420
+ if file_exists:
421
+ # Read existing data
422
+ rows = []
423
+ tasks = set()
424
+ with open(file_path, 'r', newline='', encoding='utf-8') as f:
425
+ reader = csv.reader(f)
426
+ header = next(reader, ['task', model_name]) # Use a default header if the file is empty
427
+ if len(header) == 1: # If there is only a task column, add the model column
428
+ header.append(model_name)
429
+ rows.append(header)
430
+
431
+ # Read existing rows and update the matching task
432
+ for row in reader:
433
+ if row[0] == task_name: # Update the value when the same task is found
434
+ row = [task_name, str(metric_value)]
435
+ tasks.add(row[0])
436
+ rows.append(row)
437
+
438
+ # Add a new row for a new task
439
+ if task_name not in tasks:
440
+ rows.append([task_name, str(metric_value)])
441
+ else:
442
+ # Create a new file
443
+ rows = [
444
+ ['task', model_name],
445
+ [task_name, str(metric_value)]
446
+ ]
447
+
448
+ # Write all rows back
449
+ with open(file_path, 'w', newline='', encoding='utf-8') as f:
450
+ writer = csv.writer(f)
451
+ writer.writerows(rows)
452
+
453
+ if __name__ == "__main__":
454
+ parser = argparse.ArgumentParser()
455
+ parser.add_argument("--root_path", type=str, default="General-Bench-Openset/video/comprehension")
456
+ parser.add_argument("--model_name", type=str, default="ByteDance/Sa2VA-4B")
457
+ args = parser.parse_args()
458
+ root_path = args.root_path
459
+ model_name = args.model_name
460
+
461
+ model = Sa2VAModel(model_name=model_name)
462
+
463
+ task_files = [
464
+ "AnimalTrack",
465
+ "GreenWaterTrack",
466
+ "LongVideoHumanTrack",
467
+ "RelationMatch",
468
+ "UAVUAVTrack",
469
+ "BallTrack",
470
+ "HumanPartTrack",
471
+ "LongVideoVehicleTrack",
472
+ "ShapeMatch",
473
+ "UAVVehicleTrack",
474
+ "BlueWaterTrack",
475
+ "HumanTrack",
476
+ "MotionMatch",
477
+ "SizeMatch",
478
+ "VehicleTrack",
479
+ "ColorMatch",
480
+ "LOGOMarkerMatch",
481
+ "ObjectMarkerMatch",
482
+ "SyntheticSceneFlowEstimate",
483
+ "WhiteWaterTrack",
484
+ "ComplexSceneFlowEstimate",
485
+ "LongVideoAnimalTrack",
486
+ "OtherPartTrack",
487
+ "UAVBuildingTrack",
488
+ "YellowWaterTrack",
489
+ "CrowdTrack",
490
+ "LongVideoCrowdTrack",
491
+ "PanoramicFlowEstimate",
492
+ "UAVGeneralObjectTrack",
493
+ "GeneralObjectTrack",
494
+ "LongVideoGeneralObjectTrack",
495
+ "PositionMatch",
496
+ "UAVHumanTrack"]
497
+
498
+ task_files = [w + '.json' if not w.endswith('json') else w for w in task_files]
499
+
500
+ if isinstance(task_files, str):
501
+ task_files = [task_files]
502
+
503
+ for idx, filename in enumerate(task_files):
504
+ file_path = os.path.join(root_path, f"{filename.replace('.json', '')}/", filename)
505
+ if not os.path.exists(file_path):
506
+ continue
507
+
508
+ with open(file_path, 'r', encoding='utf-8') as f:
509
+ task_data = json.load(f)
510
+
511
+ task_type = task_data["type"]
512
+ task_name = task_data["task"]
513
+ print(f"Running evaluation for task {idx + 1}: {task_name}")
514
+
515
+ # Mapping from task type to task class
516
+ TASK_MAPPING = {
517
+ "AnimalTrack": TrackingTask,
518
+ "GreenWaterTrack": TrackingTask,
519
+ "LongVideoHumanTrack": TrackingTask,
520
+ "RelationMatch": MatchTask,
521
+ "UAVUAVTrack": TrackingTask,
522
+ "BallTrack": TrackingTask,
523
+ "HumanPartTrack": TrackingTask,
524
+ "LongVideoVehicleTrack": TrackingTask,
525
+ "ShapeMatch": MatchTask,
526
+ "UAVVehicleTrack": TrackingTask,
527
+ "BlueWaterTrack": TrackingTask,
528
+ "HumanTrack": TrackingTask,
529
+ "MotionMatch": MatchTask,
530
+ "SizeMatch": MatchTask,
531
+ "VehicleTrack": TrackingTask,
532
+ "ColorMatch": MatchTask,
533
+ "LOGOMarkerMatch": MatchTask,
534
+ "ObjectMarkerMatch": MatchTask,
535
+ "SyntheticSceneFlowEstimate": FlowTask,
536
+ "WhiteWaterTrack": TrackingTask,
537
+ "ComplexSceneFlowEstimate": FlowTask,
538
+ "LongVideoAnimalTrack": TrackingTask,
539
+ "OtherPartTrack": TrackingTask,
540
+ "UAVBuildingTrack": TrackingTask,
541
+ "YellowWaterTrack": TrackingTask,
542
+ "CrowdTrack": TrackingTask,
543
+ "LongVideoCrowdTrack": TrackingTask,
544
+ "PanoramicFlowEstimate": FlowTask,
545
+ "UAVGeneralObjectTrack": TrackingTask,
546
+ "GeneralObjectTrack": TrackingTask,
547
+ "LongVideoGeneralObjectTrack": TrackingTask,
548
+ "PositionMatch": MatchTask,
549
+ "UAVHumanTrack": TrackingTask,
550
+ }
551
+
552
+ # Look up the corresponding task class by task_type
553
+ task_class = TASK_MAPPING.get(task_type) # exact-match lookup
554
+ if task_class is None:
555
+ raise NotImplementedError
556
+ else:
557
+ task = task_class(task_data, model)
558
+
559
+ task.run_inference()
560
+ metrics = task.evaluate()
561
+ print("Task name: ", task_name, "Task type: ", task_type, "Evaluation results:", metrics)
562
+ log_performance(model_name, task_name, metrics, root_path)
predictors/video_comprehension_qa_caption.py ADDED
@@ -0,0 +1,443 @@
1
+ import tqdm
2
+ from typing import List, Dict, Any
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+ from PIL import Image
6
+ import numpy as np
7
+ import os
8
+ import json
9
+ import argparse
10
+
11
+ import torch
12
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
13
+ LlavaOnevisionForConditionalGeneration, AutoProcessor)
14
+
15
+ # An example of the model
16
+ class LLavaOneVisionModel:
17
+ def __init__(self, model_name="llava-hf/llava-onevision-qwen2-7b-ov-hf"):
18
+ self.model_name = model_name
19
+
20
+ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
21
+ model_name,
22
+ torch_dtype=torch.float16,
23
+ low_cpu_mem_usage=True,
24
+ ).eval().cuda()
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained(
27
+ model_name,
28
+ trust_remote_code=True
29
+ )
30
+
31
+ self.processor = AutoProcessor.from_pretrained(model_name)
32
+
33
+ self.model = model
34
+ self.tokenizer = tokenizer
35
+
36
+ def generate(self, conversation, video):
37
+ prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
38
+ inputs = self.processor(images=video, text=prompt, return_tensors="pt").to(self.model.device, torch.float16)
39
+ outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False)
40
+ text_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
41
+ text_response = text_response.split('assistant\n')[1]
42
+
43
+ return text_response
44
+
45
+ @dataclass
46
+ class Instance:
47
+ input: Dict[str, Any]
48
+ output: Dict[str, Any]
49
+ id: str
50
+
51
+
52
+ class BaseTask(ABC):
53
+ def __init__(self, task_data: Dict[str, Any], model):
54
+ self.task_data = task_data
55
+ self.model = model
56
+ self.data = self._parse_data(task_data)
57
+
58
+ @abstractmethod
59
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
60
+ pass
61
+
62
+ @abstractmethod
63
+ def evaluate(self) -> Dict[str, float]:
64
+ pass
65
+
66
+ @abstractmethod
67
+ def run_inference(self):
68
+ pass
69
+
70
+
71
+ def cal_accuracy(predictions: List[str], references: List[str]) -> float:
72
+ correct = 0
73
+ for pred, ref in zip(predictions, references):
74
+ if isinstance(ref, str):
75
+ ref = [ref]
76
+ is_match_this_turn = False
77
+ for r in ref:
78
+ if "yes" in r.lower() or "no" in r.lower():
79
+ # for yes or no question
80
+ r = r.lower()
81
+ pred = pred.lower()
82
+
83
+ if r.strip() in pred.strip():
84
+ is_match_this_turn = True
85
+
86
+ if is_match_this_turn:
87
+ correct += 1
88
+ return correct / len(predictions) if predictions else 0.0
89
+
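+ # Illustrative behaviour: only references containing "yes"/"no" are scored, by
+ # case-insensitive substring match, e.g. cal_accuracy(["Yes, it is."], ["yes"]) == 1.0.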
90
+
91
+ class Bleu1_Scorer():
92
+ def __init__(self, predictions, references):
93
+ from pycocoevalcap.bleu.bleu import Bleu
94
+ self.pred = predictions
95
+ self.gt = references
96
+ self.scorers = [
97
+ (Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
98
+ ]
99
+
100
+ def compute_scores(self):
101
+ total_scores = {}
102
+ for scorer, method in self.scorers:
103
+ print('Computing %s score...' % (scorer.method()))
104
+ score, scores = scorer.compute_score(self.gt, self.pred)
105
+ if isinstance(method, list):
106
+ for sc, scs, m in zip(score, scores, method):
107
+ print('%s: %0.3f' % (m, sc * 100))
108
+ total_scores['Bleu'] = [x * 100 for x in score]
109
+ else:
110
+ total_scores[method] = score * 100
111
+
112
+ return {"Bleu_1": total_scores['Bleu'][0]}
113
+
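+ # Expected input format (as built in BLEUTASK.evaluate): dicts keyed by example id, e.g.
+ # Bleu1_Scorer({"1": ["a dog runs"]}, {"1": ["a dog is running"]}).compute_scores()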
114
+
115
+ class AccTask(BaseTask):
116
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
117
+ self.task_name = task_data["task"]
118
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
119
+ for d in task_data["data"]]
120
+
121
+ def read_video_frames(self, data_path_list, root_path, max_frames_num=64):
122
+ frames = []
123
+ if len(data_path_list) > max_frames_num:
124
+ frame_idx = np.linspace(0, len(data_path_list) - 1, max_frames_num, dtype=int)
125
+ data_path_list = [data_path_list[i] for i in frame_idx]
126
+
127
+ for frame_path in data_path_list:
128
+ path = os.path.join(root_path, frame_path)
129
+ if os.path.exists(path):
130
+ try:
131
+ frame = Image.open(path)
132
+ frames.append(frame)
133
+ except Exception as e:
134
+ print(f"Warning: Failed to read frame {path}. Error: {e}")
135
+ else:
136
+ print(f"Warning: Frame path {path} does not exist.")
137
+ return frames
138
+
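+ # Frames are sampled uniformly: e.g. 128 frame paths with max_frames_num=64 keep
+ # the indices given by np.linspace(0, 127, 64, dtype=int).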
139
+
140
+ def run_inference(self, root_path):
141
+
142
+ if os.path.exists(f'./predictions_{self.task_name}.json'):
143
+ self.predictions = json.load(open(f'./predictions_{self.task_name}.json', 'r'))
144
+ self.references = json.load(open(f'./references_{self.task_name}.json', 'r'))
145
+ return
146
+
147
+ self.predictions = []
148
+ self.references = []
149
+ for inst in tqdm.tqdm(self.data):
150
+ video_path = inst.input['video_file_list']
151
+ video = self.read_video_frames(video_path, os.path.join(root_path, self.task_name, 'videos'), max_frames_num=64)
152
+
153
+ question = 'Please answer the following question related to the video. ' + inst.input['prompt']
154
+
155
+ other_requirements = ''
156
+ if 'VideoActionCounting' in self.task_name:
157
+ other_requirements = 'The output must consist only of Arabic numerals.'
158
+ if 'VideoActionOrdering' in self.task_name:
159
+ other_requirements = 'The output format must be: [num]->[num]->[num]->[num]. The number represents the index marked in the question. For example: 2->1->3->4, 1->2->3->4, 3->2->1->4...'
160
+ if 'SignLanguageVideoRecognition' in self.task_name:
161
+ other_requirements = 'The output format must be a word.'
162
+ question += other_requirements
163
+
164
+ conversation = [
165
+ {
166
+ "role": "user",
167
+ "content": [
168
+ {"type": "text", "text": question},
169
+ {"type": "video"},
170
+ ],
171
+ },
172
+ ]
173
+
174
+ text_response = self.model.generate(conversation, video)
175
+
176
+ self.predictions.append(text_response)
177
+ self.references.append(inst.output["text"])
178
+
179
+ json.dump(self.predictions, open(f'./predictions_{self.task_name}.json', 'w'))
180
+ json.dump(self.references, open(f'./references_{self.task_name}.json', 'w'))
181
+
182
+ def evaluate(self) -> Dict[str, float]:
183
+
184
+ acc = cal_accuracy(self.predictions, self.references)
185
+ return {"accuracy": acc*100}
186
+
187
+
188
+ class BLEUTASK(BaseTask):
189
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
190
+ self.task_name = task_data["task"]
191
+ return [Instance(input=d["input"], output=d["output"], id=d["id"])
192
+ for d in task_data["data"]]
193
+
194
+ def read_video_frames(self, data_path_list, root_path, max_frames_num=64):
195
+ frames = []
196
+ if len(data_path_list) > max_frames_num:
197
+ frame_idx = np.linspace(0, len(data_path_list) - 1, max_frames_num, dtype=int)
198
+ data_path_list = [data_path_list[i] for i in frame_idx]
199
+
200
+ for frame_path in data_path_list:
201
+ path = os.path.join(root_path, frame_path)
202
+ if os.path.exists(path):
203
+ try:
204
+ frame = Image.open(path)
205
+ frames.append(frame)
206
+ except Exception as e:
207
+ print(f"Warning: Failed to read frame {path}. Error: {e}")
208
+ else:
209
+ print(f"Warning: Frame path {path} does not exist.")
210
+ return frames
211
+
212
+
213
+ def run_inference(self, root_path):
214
+
215
+ if os.path.exists(f'./predictions_{self.task_name}.json'):
216
+ self.predictions = json.load(open(f'./predictions_{self.task_name}.json', 'r'))
217
+ self.references = json.load(open(f'./references_{self.task_name}.json', 'r'))
218
+ return
219
+
220
+ self.predictions = []
221
+ self.references = []
222
+ for inst in tqdm.tqdm(self.data):
223
+ video_path = inst.input['video_file_list']
224
+ video = self.read_video_frames(video_path, os.path.join(root_path, self.task_name, 'videos'), max_frames_num=64)
225
+
226
+ question = 'Please answer the following question related to the video. ' + inst.input['prompt']
227
+ other_requirements = ' The output should be concise. '
228
+ question += other_requirements
229
+
230
+ conversation = [
231
+ {
232
+ "role": "user",
233
+ "content": [
234
+ {"type": "text", "text": question},
235
+ {"type": "video"},
236
+ ],
237
+ },
238
+ ]
239
+
240
+ text_response = self.model.generate(conversation, video)
241
+
242
+ self.predictions.append(text_response)
243
+ self.references.append(inst.output["text"])
244
+
245
+ json.dump(self.predictions, open(f'./predictions_{self.task_name}.json', 'w'))
246
+ json.dump(self.references, open(f'./references_{self.task_name}.json', 'w'))
247
+
248
+ def evaluate(self) -> Dict[str, float]:
249
+
250
+ predictions = {}
251
+ references = {}
252
+
253
+ num = 1
254
+ for pred, ref in zip(self.predictions, self.references):
255
+ predictions[str(num)] = [pred.lower()]
256
+ references[str(num)] = [ref.lower()]
257
+ num += 1
258
+
259
+ bleu1_scorer = Bleu1_Scorer(predictions, references)
260
+ bleu1_scores = bleu1_scorer.compute_scores()
261
+ return bleu1_scores
262
+
263
+
264
+
265
+ def log_performance(model_name, task_name, metrics, root_path, output_file='performance_log.csv'):
266
+ import csv
267
+ file_exists = os.path.isfile(os.path.join(root_path, output_file))
268
+
269
+ row_data = {
270
+ 'model': model_name,
271
+ 'task': task_name,
272
+ 'metrics': str(metrics)
273
+ }
274
+
275
+ with open(os.path.join(root_path, output_file), mode='a', newline='', encoding='utf-8') as f:
276
+ writer = csv.DictWriter(f, fieldnames=row_data.keys())
277
+ if not file_exists:
278
+ writer.writeheader()
279
+
280
+ writer.writerow(row_data)
281
+
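+ # Example of a row appended by log_performance above (metric value is hypothetical):
+ # model,task,metrics
+ # llava-hf/llava-onevision-qwen2-7b-ov-hf,ArtRecognition,"{'accuracy': 61.0}"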
282
+
283
+ if __name__ == "__main__":
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--root_path", type=str, default="General-Bench-Openset/video/comprehension")
286
+ parser.add_argument("--model_name", type=str, default="llava-hf/llava-onevision-qwen2-7b-ov-hf")
287
+ args = parser.parse_args()
288
+ root_path = args.root_path
289
+ model_name = args.model_name
290
+
291
+ model = LLavaOneVisionModel(model_name=model_name) # An example of the model
292
+
293
+ # 56 tasks
294
+ task_files = [
295
+ "AgricultureVideoQuestionAnswering",
296
+ "ArtRecognition",
297
+ "ArtsAndCraftsVideoCaptioning",
298
+ "AutosAndVehiclesVideoCaptioning",
299
+ "BallGameVideoQuestionAnswering",
300
+ "BallSportsVideoCaptioning",
301
+ "BodyMotionVideoCaptioning",
302
+ "BusinessVideoCaptioning",
303
+ "ComedyVideoQuestionAnswering",
304
+ "DailyLifeAndSkillsVideoCaptioning",
305
+ "EducationVideoQuestionAnswering",
306
+ "EntertainmentRelatedVideoCaptioning",
307
+ "FacialActionVideoCaptioning",
308
+ "FacialObjectOperationsVideoCaptioning",
309
+ "FinanceVideoCaptioning",
310
+ "FoodVideoCaptioning",
311
+ "GameVideoQuestionAnswering",
312
+ "GeographyVideoQuestionAnswering",
313
+ "GymnasticsVideoQuestionAnswering",
314
+ "HistoryAndLiteratureVideoCaptioning",
315
+ "HumanHumanInteractionVideoCaptioning",
316
+ "HumanObjectInteractionVideoCaptioning",
317
+ "HumanObjectInteractionVideoQuestionAnswering",
318
+ "HumanSurvivalVideoQuestionAnswering",
319
+ "HumorVideoCaptioning",
320
+ "MilitaryVideoQuestionAnswering",
321
+ "MovieAndShowVideoCaptioning",
322
+ "MovieVideoQuestionAnswering",
323
+ "MusicalInstrumentsVideoCaptioning",
324
+ "MusicVideoQuestionAnswering",
325
+ "NaturalDisasterVideoRecognition",
326
+ "NewsAndDocumentaryVideoCaptioning",
327
+ "ObjectColorVideoQuestionAnswering",
328
+ "ObjectDirectionVideoQuestionAnswering",
329
+ "ObjectLocationVideoQuestionAnswering",
330
+ "ObjectMotionVideoQuestionAnswering",
331
+ "PersonalCareVideoCaptioning",
332
+ "PetsVideoQuestionAnswering",
333
+ "PetsVideoRecognition",
334
+ "ScienceAndTechnologyVideoCaptioning",
335
+ "ScienceVideoQuestionAnswering",
336
+ "ScienceVideoRecognition",
337
+ "SignLanguageVideoRecognition",
338
+ "SportsAndExcerciseVideoCaptioning",
339
+ "SportsVideoQuestionAnswering",
340
+ "TVShowRecognition",
341
+ "VideoActionCounting",
342
+ "VideoActionOrdering",
343
+ "VideoActionSequencePrediction",
344
+ "VideoActionSequenceUnderstanding",
345
+ "VideoAnimalRecognition",
346
+ "VideoFoodRecognition",
347
+ "VideoObjectCounting",
348
+ "VideoObjectExistenceRecognition",
349
+ "VideoObjectInteractionRecognition",
350
+ "VideoSportsRecognition",
351
+ ]
352
+
353
+ task_files = [w + '.json' if not w.endswith('json') else w for w in task_files]
354
+
355
+ if isinstance(task_files, str):
356
+ task_files = [task_files]
357
+
358
+ for idx, filename in enumerate(task_files):
359
+ file_path = os.path.join(root_path, f"{filename.replace('.json', '')}/", "annotation.json")
360
+
361
+ if not os.path.exists(file_path):
362
+ continue
363
+
364
+ with open(file_path, 'r', encoding='utf-8') as f:
365
+ task_data = json.load(f)
366
+
367
+ task_type = task_data["type"]
368
+ task_name = task_data["task"]
369
+ print(f"Running evaluation for task {idx + 1}: {task_name}")
370
+
371
+ TASK_MAPPING = {
372
+ "AgricultureVideoQuestionAnswering": BLEUTASK,
373
+ "ArtRecognition": AccTask,
374
+ "ArtsAndCraftsVideoCaptioning": BLEUTASK,
375
+ "AutosAndVehiclesVideoCaptioning": BLEUTASK,
376
+ "BallGameVideoQuestionAnswering": AccTask,
377
+ "BallSportsVideoCaptioning": BLEUTASK,
378
+ "BodyMotionVideoCaptioning": BLEUTASK,
379
+ "BusinessVideoCaptioning": BLEUTASK,
380
+ "ComedyVideoQuestionAnswering": BLEUTASK,
381
+ "DailyLifeAndSkillsVideoCaptioning": BLEUTASK,
382
+ "EducationVideoQuestionAnswering": AccTask,
383
+ "EntertainmentRelatedVideoCaptioning": BLEUTASK,
384
+ "FacialActionVideoCaptioning": BLEUTASK,
385
+ "FacialObjectOperationsVideoCaptioning": BLEUTASK,
386
+ "FinanceVideoCaptioning": BLEUTASK,
387
+ "FoodVideoCaptioning": BLEUTASK,
388
+ "GameVideoQuestionAnswering": BLEUTASK,
389
+ "GeographyVideoQuestionAnswering": BLEUTASK,
390
+ "GymnasticsVideoQuestionAnswering": AccTask,
391
+ "HistoryAndLiteratureVideoCaptioning": BLEUTASK,
392
+ "HumanHumanInteractionVideoCaptioning": BLEUTASK,
393
+ "HumanObjectInteractionVideoCaptioning": BLEUTASK,
394
+ "HumanObjectInteractionVideoQuestionAnswering": BLEUTASK,
395
+ "HumanSurvivalVideoQuestionAnswering": BLEUTASK,
396
+ "HumorVideoCaptioning": BLEUTASK,
397
+ "MilitaryVideoQuestionAnswering": BLEUTASK,
398
+ "MovieAndShowVideoCaptioning": BLEUTASK,
399
+ "MovieVideoQuestionAnswering": BLEUTASK,
400
+ "MusicalInstrumentsVideoCaptioning": BLEUTASK,
401
+ "MusicVideoQuestionAnswering": BLEUTASK,
402
+ "NaturalDisasterVideoRecognition": BLEUTASK,
403
+ "NewsAndDocumentaryVideoCaptioning": BLEUTASK,
404
+ "ObjectColorVideoQuestionAnswering": AccTask,
405
+ "ObjectDirectionVideoQuestionAnswering": BLEUTASK,
406
+ "ObjectLocationVideoQuestionAnswering": AccTask,
407
+ "ObjectMotionVideoQuestionAnswering": AccTask,
408
+ "PersonalCareVideoCaptioning": BLEUTASK,
409
+ "PetsVideoQuestionAnswering": BLEUTASK,
410
+ "PetsVideoRecognition": BLEUTASK,
411
+ "ScienceAndTechnologyVideoCaptioning": BLEUTASK,
412
+ "ScienceVideoQuestionAnswering": BLEUTASK,
413
+ "ScienceVideoRecognition": BLEUTASK,
414
+ "SignLanguageVideoRecognition": AccTask,
415
+ "SportsAndExcerciseVideoCaptioning": BLEUTASK,
416
+ "SportsVideoQuestionAnswering": BLEUTASK,
417
+ "TVShowRecognition": AccTask,
418
+ "VideoActionCounting": AccTask,
419
+ "VideoActionOrdering": AccTask,
420
+ "VideoActionSequencePrediction": BLEUTASK,
421
+ "VideoActionSequenceUnderstanding": BLEUTASK,
422
+ "VideoAnimalRecognition": AccTask,
423
+ "VideoFoodRecognition": AccTask,
424
+ "VideoObjectCounting": BLEUTASK,
425
+ "VideoObjectExistenceRecognition": BLEUTASK,
426
+ "VideoObjectInteractionRecognition": BLEUTASK,
427
+ "VideoSportsRecognition": AccTask,
428
+ }
429
+
430
+ task_class = TASK_MAPPING.get(task_name)
431
+ if task_class is None:
432
+ raise NotImplementedError
433
+ else:
434
+ task = task_class(task_data, model)
435
+
436
+ task.run_inference(root_path=root_path)
437
+ metrics = task.evaluate()
438
+
439
+ print("Task name: ", task_name, "Task type: ", task_type, "Evaluation results:", metrics)
440
+ log_performance(model_name, task_name, metrics, '../outcome/', output_file='video_comprehension_qa_caption_performance_log.csv')
441
+
442
+
443
+
predictors/video_comprehension_tasks.py ADDED
@@ -0,0 +1,550 @@
1
+
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ import os
6
+ from typing import Dict, Any, List
7
+ import json
8
+ import torch
9
+ import tqdm
10
+ import argparse
11
+
12
+
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from PIL import Image
15
+ import pycocotools.mask as mask_util
16
+ import numpy as np
17
+
18
+
19
+ PREFIX = 'data'
20
+
21
+ PROMPT = {
22
+ 'VOS': '<image>\nPlease segment the major object in the video.',
23
+ 'RVOS': '<image>\nPlease segment {}.',
24
+ 'ActionDet': '<image>\nPlease detect {}.',
25
+ 'VDE': '<image>\nPlease generate the depth map of the video.',
26
+ }
27
+
28
+
29
+ @dataclass
30
+ class Instance:
31
+ input: Dict[str, Any]
32
+ output: Dict[str, Any]
33
+ id: str
34
+
35
+
36
+ class BaseTask(ABC):
37
+ def __init__(self, task_data: str, model):
38
+ self.task_data = task_data
39
+ self.model = model
40
+ self.task_name = os.path.basename(task_data)
41
+
42
+
43
+ self.data = self._parse_data(task_data)
44
+
45
+ @abstractmethod
46
+ def _parse_data(self, task_data: str) -> List[Instance]:
47
+ pass
48
+
49
+ @abstractmethod
50
+ def evaluate(self, results:List[Instance]) -> Dict[str, float]:
51
+ pass
52
+
53
+ @abstractmethod
54
+ def run_inference(self) -> List[Instance]:
55
+ pass
56
+
57
+
58
+ class TaskVOS(BaseTask):
59
+
60
+ def _load_video(self, video_path: str) -> List[Image.Image]:
61
+ video_frames = []
62
+ for frame_file in sorted(os.listdir(video_path)):
63
+ if frame_file.endswith('.jpg') or frame_file.endswith('.png'):
64
+ frame_path = os.path.join(video_path, frame_file)
65
+ video_frames.append(Image.open(frame_path).convert('RGB'))
66
+ return video_frames
67
+
68
+
69
+ def _parse_data(self, task_data: str) -> List[Instance]:
70
+ json_path = os.path.join(task_data, 'annotation.json')
71
+ json_data = json.load(open(json_path, 'r'))
72
+
73
+ results = []
74
+ json_data_data = json_data['data']
75
+ for json_item in json_data_data:
76
+ input_dict = {}
77
+ input_dict['video_folder'] = json_item['input']['video_folder']
78
+ input_dict['video'] = self._load_video(os.path.join(task_data, input_dict['video_folder']))
79
+
80
+ output_dict = {}
81
+ output_dict['serilized_masks'] = json_item['output']
82
+ output_dict['masks'] = []
83
+ for mask_id, mask_data in output_dict['serilized_masks'].items():
84
+ mask = mask_util.decode(mask_data['mask'])
85
+ output_dict['masks'].append(mask)
86
+ instance_id = json_item['id']
87
+ results.append(Instance(input=input_dict, output=output_dict, id=instance_id))
88
+ return results
89
+
90
+
91
+
92
+ def evaluate(self, results:List[Instance]) -> Dict[str, float]:
93
+ iou_list = []
94
+ for instance in results:
95
+ masks = instance.output['masks']
96
+ prediction_masks = instance.output['prediction_masks']
97
+
98
+ assert len(masks) == len(prediction_masks), "Number of masks and prediction masks do not match."
99
+
100
+ intersection = 0.
101
+ union = 0.
102
+ for gt_mask, pred_mask in zip(masks, prediction_masks):
103
+ intersection += (gt_mask.astype(bool) & pred_mask.astype(bool)).sum()
104
+ union += (gt_mask | pred_mask).sum()
105
+ iou = intersection / union if union > 0 else 0.0
106
+ iou_list.append(iou)
107
+ iou_mean = np.mean(iou_list).item() * 100
108
+ return {"IoU": iou_mean}
109
+
110
+ def run_inference(self) -> List[Instance]:
111
+ results = []
112
+ for instance in tqdm.tqdm(self.data, desc=f"Running inference on {self.task_name}"):
113
+ input_data = instance.input
114
+
115
+ result = self.model.predict_forward(
116
+ video=input_data['video'],
117
+ text=PROMPT['VOS'],
118
+ )
119
+
120
+ # output postprocessing
121
+ output_masks = result['prediction_masks']
122
+
123
+ instance.output['prediction_masks'] = output_masks[0]
124
+ results.append(instance)
125
+ return results
126
+
127
+
128
+ class TaskRVOS(BaseTask):
129
+ def _load_video(self, video_path: str) -> List[Image.Image]:
130
+ video_frames = []
131
+ for frame_file in sorted(os.listdir(video_path)):
132
+ if frame_file.endswith('.jpg') or frame_file.endswith('.png'):
133
+ frame_path = os.path.join(video_path, frame_file)
134
+ video_frames.append(Image.open(frame_path).convert('RGB'))
135
+ return video_frames
136
+
137
+
138
+ def _parse_data(self, task_data: str) -> List[Instance]:
139
+ json_path = os.path.join(task_data, 'annotation.json')
140
+ json_data = json.load(open(json_path, 'r'))
141
+
142
+ results = []
143
+ json_data_data = json_data['data']
144
+ for json_item in json_data_data:
145
+ input_dict = {}
146
+ input_dict['video_folder'] = json_item['input']['video_folder']
147
+ input_dict['video'] = self._load_video(os.path.join(task_data, input_dict['video_folder']))
148
+ input_dict['prompt'] = json_item['input']['prompt']
149
+
150
+ output_dict = {}
151
+ output_dict['serilized_masks'] = json_item['output']
152
+ output_dict['masks'] = []
153
+ for mask_id, mask_data in output_dict['serilized_masks'].items():
154
+ mask = mask_util.decode(mask_data['mask'])
155
+ output_dict['masks'].append(mask)
156
+ instance_id = json_item['id']
157
+ results.append(Instance(input=input_dict, output=output_dict, id=instance_id))
158
+ return results
159
+
160
+
161
+
162
+ def evaluate(self, results:List[Instance]) -> Dict[str, float]:
163
+ iou_list = []
164
+ for instance in results:
165
+ masks = instance.output['masks']
166
+ prediction_masks = instance.output['prediction_masks']
167
+
168
+ assert len(masks) == len(prediction_masks), "Number of masks and prediction masks do not match."
169
+
170
+ intersection = 0.
171
+ union = 0.
172
+ for gt_mask, pred_mask in zip(masks, prediction_masks):
173
+ intersection += (gt_mask.astype(bool) & pred_mask.astype(bool)).sum()
174
+ union += (gt_mask | pred_mask).sum()
175
+ iou = intersection / union if union > 0 else 0.0
176
+ iou_list.append(iou)
177
+ iou_mean = np.mean(iou_list).item() * 100
178
+ return {"IoU": iou_mean}
179
+
180
+ def run_inference(self) -> List[Instance]:
181
+ results = []
182
+ for instance in tqdm.tqdm(self.data, desc=f"Running inference on {self.task_name}"):
183
+ input_data = instance.input
184
+
185
+ result = self.model.predict_forward(
186
+ video=input_data['video'],
187
+ text=PROMPT['RVOS'].format(input_data['prompt']),
188
+ )
189
+
190
+ # output postprocessing
191
+ output_masks = result['prediction_masks']
192
+
193
+ instance.output['prediction_masks'] = output_masks[0]
194
+ results.append(instance)
195
+ return results
196
+
197
+
198
+
199
+ class TaskActionDet(BaseTask):
200
+ def _load_video(self, video_path: str) -> List[Image.Image]:
201
+ import cv2
202
+ cap = cv2.VideoCapture(video_path)
203
+ img_list = []
204
+ while cap.isOpened():
205
+ ret, frame = cap.read()
206
+ if not ret:
207
+ break
208
+
209
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
210
+ img_list.append(Image.fromarray(frame).convert('RGB'))
211
+
212
+ return img_list
213
+
214
+
215
+ def _parse_data(self, task_data: str) -> List[Instance]:
216
+ if self.task_name in ['AnimalVG', 'AutoVG', 'HumanVG']:
217
+ self.is_vg = True
218
+ else:
219
+ self.is_vg = False
220
+
221
+ json_path = os.path.join(task_data, 'annotation.json')
222
+ json_data = json.load(open(json_path, 'r'))
223
+
224
+ results = []
225
+ json_data_data = json_data['data']
226
+ for json_item in json_data_data:
227
+ video_path = os.path.join(self.task_data, 'videos', json_item['video_path'])
228
+ image_list = self._load_video(video_path)
229
+ assert len(image_list) > 0, f"Video {video_path} has no frames."
230
+ if len(image_list) != json_item['frame_count']:
231
+ print(f"Warning: Frame count mismatch for video {video_path}. Expected {json_item['frame_count']}, got {len(image_list)}.")
232
+ while len(image_list) < json_item['frame_count']:
233
+ image_list.append(image_list[-1])
234
+ input_dict = {}
235
+ input_dict['video'] = image_list
236
+ input_dict['prompt'] = json_item['caption']
237
+
238
+ output_dict = {}
239
+ if self.is_vg:
240
+ output_dict['tube_start_frame'] = json_item['tube_start_frame']
241
+ output_dict['tube_end_frame'] = json_item['tube_end_frame']
242
+ else:
243
+ output_dict['tube_start_frame'] = json_item['tube_start_frame'] - 1
244
+ output_dict['tube_end_frame'] = json_item['tube_end_frame'] - 1
245
+
246
+ trajectory = json_item['trajectory']
247
+
248
+ if self.is_vg:
249
+ trajectory = [trajectory[frame_id_str]['bbox'] for frame_id_str in trajectory if output_dict['tube_start_frame'] <= int(frame_id_str) < output_dict['tube_end_frame']]
250
+
251
+ assert len(trajectory) == output_dict['tube_end_frame'] - output_dict['tube_start_frame']
252
+ bboxes = []
253
+ for _ in range(output_dict['tube_start_frame']):
254
+ bboxes.append([0, 0, 0, 0])
255
+
256
+ # trajectory is a list of [x, y, w, h] for each frame
257
+ for item in trajectory:
258
+ x, y, w, h = item
259
+ bbox = [x, y, x + w, y + h]
260
+ bboxes.append(bbox)
261
+
262
+ for _ in range(output_dict['tube_end_frame'], len(image_list)):
263
+ bboxes.append([0, 0, 0, 0])
264
+ output_dict['bboxes'] = bboxes
265
+
266
+ instance_id = json_item['original_video_id']
267
+ results.append(Instance(input=input_dict, output=output_dict, id=instance_id))
268
+ return results
269
+
270
+ def evaluate(self, results:List[Instance]) -> Dict[str, float]:
271
+ iou_list = []
272
+ for instance in results:
273
+ boxes = instance.output['bboxes']
274
+ prediction_boxes = instance.output['prediction_boxes']
275
+ assert len(boxes) == len(prediction_boxes), "Number of boxes and prediction boxes do not match."
276
+ iou = 0.
277
+ frame_union = 0
278
+ for gt_box, pred_box in zip(boxes, prediction_boxes):
279
+ gt_box = np.array(gt_box)
280
+ pred_box = np.array(pred_box)
281
+
282
+ if np.all(gt_box == 0) and np.all(pred_box == 0):
283
+ continue
284
+ frame_union += 1
285
+ if np.all(gt_box == 0) or np.all(pred_box == 0):
286
+ continue
287
+
288
+ intersection = np.maximum(0, np.minimum(gt_box[2:], pred_box[2:]) - np.maximum(gt_box[:2], pred_box[:2]))
289
+ intersection_area = intersection[0] * intersection[1]
290
+ gt_area = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
291
+ pred_area = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
292
+ union_area = gt_area + pred_area - intersection_area
293
+ iou += intersection_area / union_area
294
+ if frame_union > 0:
295
+ iou /= frame_union
296
+ iou_list.append(iou)
297
+ iou_mean = np.mean(iou_list).item() * 100
298
+ return {"vIoU": iou_mean}
299
+
300
+ def run_inference(self) -> List[Instance]:
301
+ results = []
302
+ for instance in tqdm.tqdm(self.data, desc=f"Running inference on {self.task_name}"):
303
+ input_data = instance.input
304
+
305
+ result = self.model.predict_boxes(
306
+ video=input_data['video'],
307
+ text=PROMPT['ActionDet'].format(input_data['prompt']),
308
+ )
309
+
310
+ # output postprocessing
311
+ output_masks = result['prediction_boxes']
312
+ instance.output['prediction_boxes'] = output_masks[0]
313
+ results.append(instance)
314
+ return results
315
+
316
+
317
+
318
+ class TaskVDE(BaseTask):
319
+ def _load_video(self, video_path: str) -> List[Image.Image]:
320
+ import cv2
321
+ cap = cv2.VideoCapture(video_path)
322
+ img_list = []
323
+ while cap.isOpened():
324
+ ret, frame = cap.read()
325
+ if not ret:
326
+ break
327
+
328
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
329
+ img_list.append(Image.fromarray(frame).convert('RGB'))
330
+
331
+ return img_list
332
+
333
+ def _parse_data(self, task_data: str) -> List[Instance]:
334
+ json_path = os.path.join(task_data, 'annotation.json')
335
+ json_data = json.load(open(json_path, 'r'))
336
+
337
+ results = []
338
+ json_data_data = json_data['data']
339
+ for json_item in json_data_data:
340
+ video_path = os.path.join(self.task_data, 'video', json_item['input'])
341
+ annotation_path = os.path.join(self.task_data, 'depth', json_item['output'])
342
+ instance_id = json_item['id']
343
+
344
+ assert os.path.exists(video_path), f"Video path {video_path} does not exist."
345
+ assert os.path.exists(annotation_path), f"Annotation path {annotation_path} does not exist"
346
+
347
+
348
+ input_dict = {}
349
+ input_dict['video'] = self._load_video(video_path)
350
+
351
+ output_dict = {}
352
+ output_dict['depth_map'] = np.load(annotation_path)['disparity'] # nf, 1, h, w
353
+ assert len(input_dict['video']) == output_dict['depth_map'].shape[0], "Number of video frames and depth map frames do not match."
354
+ assert output_dict['depth_map'].ndim == 4, "Depth map should be 4-dimensional (nf, 1, h, w)."
355
+ assert input_dict['video'][0].size == (output_dict['depth_map'].shape[3], output_dict['depth_map'].shape[2]), "Video frame size does not match depth map size."
356
+ results.append(Instance(input=input_dict, output=output_dict, id=instance_id))
357
+ return results
358
+
359
+
360
+ def _abs_relative_difference(self, output, target, valid_mask=None):
361
+ actual_output = output[valid_mask]
362
+ actual_target = target[valid_mask]
363
+ abs_relative_diff = np.abs(actual_output - actual_target) / actual_target
364
+ return abs_relative_diff.mean()
365
+
366
+ def evaluate(self, results:List[Instance]) -> Dict[str, float]:
367
+ abs_rel_list = []
368
+ dataset_max_depth = 80
369
+ for instance in results:
370
+ depth_map = instance.output['depth_map']
371
+ prediction_depth = instance.output['prediction_depth']
372
+
373
+ assert depth_map.shape == prediction_depth.shape, "Depth map and prediction depth shape do not match."
374
+
375
+ # Calculate absolute relative error
376
+ gt_disp = depth_map[:, 0]
377
+ pred_disp = prediction_depth[:, 0]
378
+ # valid mask
379
+ valid_mask = np.logical_and(
380
+ (gt_disp > 1e-3),
381
+ (gt_disp < dataset_max_depth)
382
+ )
383
+ pred_disp = np.clip(pred_disp, a_min=1e-3, a_max=None)
384
+ pred_disp_masked = pred_disp[valid_mask].reshape((-1, 1))
385
+
386
+
387
+ gt_disp_maksed = gt_disp[valid_mask].reshape((-1, 1)).astype(np.float64)
388
+ # calc scale and shift
389
+ _ones = np.ones_like(pred_disp_masked)
390
+ A = np.concatenate([pred_disp_masked, _ones], axis=-1)
391
+ X = np.linalg.lstsq(A, gt_disp_maksed, rcond=None)[0]
392
+ scale, shift = X # gt = scale * pred + shift
393
+
394
+ # align
395
+ aligned_pred = scale * pred_disp + shift
396
+ aligned_pred = np.clip(aligned_pred, a_min=1e-3, a_max=None)
397
+
398
+
399
+ pred_depth = aligned_pred
400
+ gt_depth = gt_disp
401
+
402
+ # metric evaluation, clip to dataset min max
403
+ pred_depth = np.clip(
404
+ pred_depth, a_min=1e-3, a_max=dataset_max_depth
405
+ )
406
+ abs_rel = self._abs_relative_difference(
407
+ pred_depth,
408
+ gt_depth,
409
+ valid_mask=valid_mask
410
+ )
411
+ abs_rel_list.append(abs_rel)
412
+
413
+ abs_rel_mean = np.mean(abs_rel_list).item()
414
+
415
+
416
+ def sigmoid(x):
417
+ return 1 / (1 + np.exp(-x))
418
+ score = (sigmoid(0.1 / (abs_rel_mean + 1e-6)) * 2 - 1) * 100
419
+ return {"absRel": abs_rel_mean, "score": score}
420
+
421
+
422
+ def run_inference(self) -> List[Instance]:
423
+ results = []
424
+ for instance in tqdm.tqdm(self.data, desc=f"Running inference on {self.task_name}"):
425
+ input_data = instance.input
426
+
427
+ result = self.model.predict_depth(
428
+ video=input_data['video'],
429
+ text=PROMPT['VDE'],
430
+ )
431
+
432
+ # output postprocessing
433
+ depth_map = result['prediction_depth']
434
+ instance.output['prediction_depth'] = depth_map
435
+ results.append(instance)
436
+ return results
437
+
438
+
439
+ tasks = {
440
+ 'AnimalVOS': TaskVOS,
441
+ 'AutoVOS':TaskVOS,
442
+ 'HumanVOS':TaskVOS,
443
+ 'SportsVOS':TaskVOS,
444
+
445
+ ## IW
446
+ 'IWAnimalVOS':TaskVOS,
447
+ 'IWAutoVOS':TaskVOS,
448
+ 'IWFurnitureVOS':TaskVOS,
449
+ 'IWHumanVOS':TaskVOS,
450
+
451
+ ## Street
452
+ 'AutoStreetVOS':TaskVOS,
453
+ 'BicycleStreetVOS':TaskVOS,
454
+ 'HumanStreetVOS':TaskVOS,
455
+
456
+ # RVOS
457
+ 'AnimalRVOS':TaskRVOS,
458
+ 'HumanRVOS':TaskRVOS,
459
+
460
+ ## ReVOS,
461
+ 'AnimalReVOS':TaskRVOS,
462
+ 'AutoReVOS': TaskRVOS,
463
+ 'HumanReVOS': TaskRVOS,
464
+
465
+ ## CReVOS
466
+ 'AnimalCReVOS': TaskRVOS,
467
+ 'AutoCReVOS' : TaskRVOS,
468
+ 'HumanCReVOS': TaskRVOS,
469
+ 'HumanPartCReVOS': TaskRVOS,
470
+ 'EquipmentCReVOS': TaskRVOS,
471
+
472
+
473
+ ## Action Det
474
+ # V-C-10 HCSTVG2
475
+ 'StaticActionDet': TaskActionDet,
476
+ 'DynamicActionDet': TaskActionDet,
477
+ # V-C-12 VidSTG
478
+ 'AnimalVG': TaskActionDet,
479
+ 'AutoVG': TaskActionDet,
480
+ 'HumanVG': TaskActionDet,
481
+
482
+ ## VDE
483
+ 'StaticVDE': TaskVDE,
484
+ 'StreetVDE': TaskVDE,
485
+ 'SynVDE': TaskVDE,
486
+ 'DynamicVDE': TaskVDE,
487
+ }
488
+
489
+
490
+
491
+ def predict_dummy_boxes(video, text):
492
+ # Dummy function to simulate box prediction
493
+ # In practice, this should call the model's prediction method
494
+ num_frames = len(video)
495
+ return {
496
+ 'prediction_boxes': [
497
+ [[0,0, 100, 100]] * num_frames, # Example boxes, [0, 0, 0, 0] is empty box
498
+ ]
499
+ }
500
+
501
+
502
+ def predict_dummy_depth(video, text):
503
+ # Dummy function to simulate depth prediction
504
+ # In practice, this should call the model's prediction method
505
+ num_frames = len(video)
506
+ width, height = video[0].size
507
+ return {
508
+ 'prediction_depth': np.random.rand(num_frames, 1, height, width).astype(np.float32) * 80 # Random depth values
509
+ }
510
+
511
+
512
+ def main(root:str, model_path:str):
513
+ metrics = {}
514
+
515
+ model = AutoModelForCausalLM.from_pretrained(
516
+ model_path,
517
+ torch_dtype=torch.bfloat16,
518
+ low_cpu_mem_usage=True,
519
+ use_flash_attn=True,
520
+ trust_remote_code=True,
521
+ ).eval().cuda()
522
+ tokenizer = AutoTokenizer.from_pretrained(
523
+ model_path,
524
+ trust_remote_code=True
525
+ )
526
+ model.preparing_for_generation(tokenizer=tokenizer)
527
+
528
+ model.predict_boxes = predict_dummy_boxes
529
+ model.predict_depth = predict_dummy_depth
530
+
531
+ for task_name in tasks:
532
+ task_class = tasks[task_name]
533
+ task_data_path = os.path.join(root, task_name)
534
+ task_instance = task_class(task_data=task_data_path, model=model)
535
+
536
+ results = task_instance.run_inference()
537
+ evaluation_results = task_instance.evaluate(results)
538
+ metrics[task_instance.task_name] = evaluation_results
539
+
540
+ print(metrics)
541
+
542
+
543
+ if __name__ == "__main__":
544
+ # root = os.path.join(PREFIX, "General-Bench-Openset/video/comprehension")
545
+ import argparse
546
+ parser = argparse.ArgumentParser(description="Run video tasks evaluation.")
547
+ parser.add_argument("--model_path", type=str, default='ByteDance/Sa2VA-4B', required=False, help="Model to use for evaluation")
548
+ parser.add_argument("--root_path", type=str, default="General-Bench-Openset/video/comprehension", required=False, help="Root path to the dataset")
549
+ args = parser.parse_args()
550
+ main(args.root_path, args.model_path)
predictors/video_generation_evaluate_kit.py ADDED
@@ -0,0 +1,327 @@
1
+ import subprocess
2
+ from typing import List, Dict, Any
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import cv2
9
+ import clip
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from typing import Tuple
15
+ import os
16
+ import json
17
+ from diffusers import CogVideoXPipeline
18
+ from diffusers.utils import export_to_video
19
+ from video_generation_evaluation.toolkit.fvd import get_dataset_features, I3DFeatureExtractor
20
+ from numpy import cov
21
+ from numpy import mean
22
+ from scipy.linalg import sqrtm
23
+ from video_generation_evaluation.evaluate import task2dimension
24
+
25
+
26
+ class BaseTask(ABC):
27
+ def __init__(self, task_data: str, model):
28
+ self.task_data = task_data
29
+ self.model = model
30
+ self.data = self._parse_data(task_data)
31
+
32
+ @abstractmethod
33
+ def _parse_data(self, task_data: Dict[str, Any]):
34
+ pass
35
+
36
+ @abstractmethod
37
+ def evaluate(self) -> Dict[str, float]:
38
+ pass
39
+
40
+ @abstractmethod
41
+ def run_inference(self):
42
+ pass
43
+
44
+ class T2VTask(BaseTask):
45
+ def _parse_result_file(self, output_dir: Path) -> float | None:
46
+ for jsonfile in output_dir.iterdir():
47
+ if "eval" in jsonfile.name:
48
+ with open(jsonfile.as_posix(), "r") as file:
49
+ data = json.load(file)
50
+
51
+ return float(data[self.taskname][0])
52
+
53
+ def _parse_data(self, task_data):
54
+ with open(task_data, "r") as file:
55
+ annos = json.load(file)
56
+ taskname = annos["task"].replace(" ", "")
57
+ self.taskname = taskname
58
+ self.save_root = os.path.join("General-Bench", "Video-Generation", taskname)
59
+ return annos["data"]
60
+
61
+ def run_inference(self):
62
+ for d in self.data:
63
+ prompt = d["input"]["prompt"]
64
+ for i in range(5):
65
+ video = self.model(prompt, generator=torch.Generator(self.model.device).manual_seed(i)).frames[0]
66
+ save_name = prompt + "-" + str(i) + ".mp4"
67
+ save_path = os.path.join(self.save_root, save_name)
68
+ export_to_video(video, save_path, fps=8)
69
+
70
+ class FVDEval(T2VTask):
71
+ def evaluate(self, real_video_root):
72
+ model = I3DFeatureExtractor().cuda().eval()
73
+
74
+ real_features = get_dataset_features(real_video_root, model)
75
+ generated_features = get_dataset_features(self.save_root, model)
76
+
77
+ mu_real = mean(real_features, axis=0)
78
+ mu_generated = mean(generated_features, axis=0)
79
+
80
+ sigma_real = cov(real_features, rowvar=False)
81
+ sigma_generated = cov(generated_features, rowvar=False)
82
+
83
+ diff = mu_real - mu_generated
84
+ covmean, _ = sqrtm(sigma_real.dot(sigma_generated), disp=False)
85
+ if np.iscomplexobj(covmean):
86
+ covmean = covmean.real
87
+ fvd = diff.dot(diff) + np.trace(sigma_real + sigma_generated - 2 * covmean)
88
+ print(f"{self.taskname} score: {fvd}")
89
+ return fvd
90
+
91
+ class ThirdPartyEval(T2VTask):
92
+ def evaluate(self):
93
+ videos_path = Path(self.save_root).resolve()
94
+ dimension = task2dimension[self.taskname]
95
+ full_info = Path("./full_info_t2v.json").resolve()
96
+ output_dir = Path("./evaluation_results").resolve()
97
+ output_dir = output_dir.joinpath(self.taskname)
98
+ output_dir.mkdir(parents=True, exist_ok=True)
99
+
100
+ cmd = [
101
+ "python", "-W", "ignore", "evaluate.py",
102
+ "--full_json_dir", str(full_info),
103
+ "--videos_path", str(videos_path),
104
+ "--dimension", dimension,
105
+ "--output_path", str(output_dir)
106
+ ]
107
+
108
+ try:
109
+ subprocess.run(cmd, check=True)
110
+ except subprocess.CalledProcessError as exc:
111
+ raise RuntimeError(f"Evaluation failed: {exc}") from exc
112
+
113
+ score = self._parse_result_file(Path(output_dir))
114
+ print(f"{self.taskname} score: {score}")
115
+ return score
116
+
117
+ class I2VTask(BaseTask):
118
+ def _parse_result_file(self, output_dir: Path) -> float | None:
119
+ score = 0
120
+ for jsonfile in output_dir.iterdir():
121
+ if "eval" in jsonfile.name:
122
+ with open(jsonfile.as_posix(), "r") as file:
123
+ data: dict = json.load(file)
124
+ score += list(data.values())[0][0]
125
+ return score
126
+
127
+ def _parse_data(self, task_data):
128
+ self.dirpath = os.path.dirname(task_data)
129
+ with open(task_data, "r") as file:
130
+ annos = json.load(file)
131
+ taskname = annos["task"].replace(" ", "")
132
+ self.taskname = taskname
133
+ self.dimensions = ("subject_consistency", "overall_consistency", "motion_smoothness", "dynamic_degree")
134
+ self.save_root = os.path.join("General-Bench", "Video-Generation", taskname)
135
+
136
+ def run_inference(self):
137
+ for d in self.data:
138
+ prompt = d["input"]["prompt"]
139
+ image = d["input"]["image"]
140
+ image = os.path.join(self.dirpath, image)
141
+ for i in range(5):
142
+ video = self.model(
143
+ prompt=prompt,
144
+ image=image,
145
+ generator=torch.Generator(self.model.device).manual_seed(i)
146
+ ).frames[0]
147
+ save_name = prompt + "-" + str(i) + ".mp4"
148
+ save_path = os.path.join(self.save_root, save_name)
149
+ export_to_video(video, save_path, fps=8)
150
+
151
+ def evaluate(self):
152
+ taskname = self.taskname
153
+ full_info = Path("./full_info_i2v.json").resolve()
154
+ output_dir = Path("./evaluation_results").resolve()
155
+ output_dir = output_dir.joinpath(taskname)
156
+ output_dir.mkdir(parents=True, exist_ok=True)
157
+
158
+ for dimension in self.dimensions:
159
+ cmd = [
160
+ "python", "-W", "ignore", "evaluate.py",
161
+ "--full_json_dir", str(full_info),
162
+ "--videos_path", str(self.save_root),
163
+ "--dimension", dimension,
164
+ "--output_path", str(output_dir)
165
+ ]
166
+ try:
167
+ subprocess.run(cmd, check=True)
168
+ except subprocess.CalledProcessError as exc:
169
+ raise RuntimeError(f"Evaluation failed: {exc}") from exc
170
+
171
+ score = self._parse_result_file(Path(output_dir))
172
+ print(f"{self.taskname} score: {score}")
173
+ return score
174
+
175
+ class AthleticsT2V(FVDEval): pass
176
+
177
+ class HumanT2V(FVDEval): pass
178
+
179
+ class ConcertT2V(FVDEval): pass
180
+
181
+ class TerrestrialAnimalT2V(FVDEval): pass
182
+
183
+ class WaterSportsT2V(FVDEval): pass
184
+
185
+ class ActionT2V(ThirdPartyEval): pass
186
+
187
+ class ArtisticT2V(ThirdPartyEval): pass
188
+
189
+ class BackgroundConsistency(ThirdPartyEval): pass
190
+
191
+ class CameraMotionT2V(ThirdPartyEval): pass
192
+
193
+ class ClassConditionedT2V(ThirdPartyEval): pass
194
+
195
+ class ColorT2V(ThirdPartyEval): pass
196
+
197
+ class DynamicT2V(ThirdPartyEval): pass
198
+
199
+ class MaterialT2V(ThirdPartyEval): pass
200
+
201
+ class MultiClassConditionedT2V(ThirdPartyEval): pass
202
+
203
+ class SceneT2V(ThirdPartyEval): pass
204
+
205
+ class SpatialRelationT2V(ThirdPartyEval): pass
206
+
207
+ class StaticT2V(ThirdPartyEval): pass
208
+
209
+ class StyleT2V(ThirdPartyEval): pass
210
+
211
+ class ArchitectureI2V(I2VTask): pass
212
+
213
+ class ClothI2V(I2VTask): pass
214
+
215
+ class FoodI2V(I2VTask): pass
216
+
217
+ class FurnitureI2V(I2VTask): pass
218
+
219
+ class HumanI2V(I2VTask): pass
220
+
221
+ class PetI2V(I2VTask): pass
222
+
223
+ class PlantI2V(I2VTask): pass
224
+
225
+ class SceneI2V(I2VTask): pass
226
+
227
+ class VehicleI2V(I2VTask): pass
228
+
229
+ class WeatherI2V(I2VTask): pass
230
+
231
+ class WildAnimalI2V(I2VTask): pass
232
+
233
+
234
+ if __name__ == "__main__":
235
+ root = Path("General-Bench-Openset/video/generation")
236
+
237
+ task_type = "T2V"
238
+ model = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
239
+
240
+ task_files = [
241
+ "AthleticsT2V",
242
+ "HumanT2V",
243
+ "ConcertT2V",
244
+ "TerrestrialAnimalT2V",
245
+ "WaterSportsT2V",
246
+ "ActionT2V",
247
+ "ArtisticT2V",
248
+ "BackgroundConsistency",
249
+ "CameraMotionT2V",
250
+ "ClassConditionedT2V",
251
+ "ColorT2V",
252
+ "DynamicT2V",
253
+ "MaterialT2V",
254
+ "MultiClassConditionedT2V",
255
+ "SceneT2V",
256
+ "SpatialRelationT2V",
257
+ "StaticT2V",
258
+ "StyleT2V",
259
+ "ArchitectureI2V",
260
+ "ClothI2V",
261
+ "FoodI2V",
262
+ "FurnitureI2V",
263
+ "HumanI2V",
264
+ "PetI2V",
265
+ "PlantI2V",
266
+ "SceneI2V",
267
+ "VehicleI2V",
268
+ "WeatherI2V",
269
+ "WildAnimalI2V",
270
+ ]
271
+
272
+ task_files = [root.joinpath(task, "annotation.json") for task in task_files]
273
+
274
+ for idx, file in enumerate(task_files):
275
+ if file.exists():
276
+ continue
277
+
278
+ with open(file.as_posix(), 'r', encoding='utf-8') as f:
279
+ task_data = json.load(f)
280
+
281
+ task_name = task_data["task"]
282
+ print(f"Running evaluation for task {idx + 1}: {task_name}")
283
+
284
+ TASK_MAPPING = {
285
+ "AthleticsT2V": AthleticsT2V,
286
+ "HumanT2V": HumanT2V,
287
+ "ConcertT2V": ConcertT2V,
288
+ "TerrestrialAnimalT2V": TerrestrialAnimalT2V,
289
+ "WaterSportsT2V": WaterSportsT2V,
290
+ "ActionT2V": ActionT2V,
291
+ "ArtisticT2V": ArtisticT2V,
292
+ "BackgroundConsistency": BackgroundConsistency,
293
+ "CameraMotionT2V": CameraMotionT2V,
294
+ "ClassConditionedT2V": ClassConditionedT2V,
295
+ "ColorT2V": ColorT2V,
296
+ "DynamicT2V": DynamicT2V,
297
+ "MaterialT2V": MaterialT2V,
298
+ "MultiClassConditionedT2V": MultiClassConditionedT2V,
299
+ "SceneT2V": SceneT2V,
300
+ "SpatialRelationT2V": SpatialRelationT2V,
301
+ "StaticT2V": StaticT2V,
302
+ "StyleT2V": StyleT2V,
303
+ "ArchitectureI2V": ArchitectureI2V,
304
+ "ClothI2V": ClothI2V,
305
+ "FoodI2V": FoodI2V,
306
+ "FurnitureI2V": FurnitureI2V,
307
+ "HumanI2V": HumanI2V,
308
+ "PetI2V": PetI2V,
309
+ "PlantI2V": PlantI2V,
310
+ "SceneI2V": SceneI2V,
311
+ "VehicleI2V": VehicleI2V,
312
+ "WeatherI2V": WeatherI2V,
313
+ "WildAnimalI2V": WildAnimalI2V,
314
+ }
315
+
316
+ clean_task_name = task_name.replace(" ", "")
317
+ task_class = TASK_MAPPING.get(clean_task_name)
318
+ if task_class is None:
319
+ raise NotImplementedError
320
+ elif task_type not in clean_task_name:
321
+ continue
322
+ else:
323
+ task = task_class(file.as_posix(), model)
324
+
325
+ task.run_inference()
326
+ metrics = task.evaluate()
327
+ print("Task name: ", task_name, "Task type: ", task_type, "Evaluation results:", metrics)
predictors/video_translation_restoration_superresolution_objectdetection.py ADDED
@@ -0,0 +1,340 @@
1
+ """
2
+ Unified evaluator for four video–vision tasks and their metrics
3
+
4
+ • Video Translation → Frame-Acc (CLIP-based)
5
+ • Video Restoration (denoising/deblurring/…) → PSNR
6
+ • Video Super-Resolution → MUSIQ (no-reference IQA)
7
+ • Video (Salient / Camouflaged) Object Detection → Structure-measure
8
+
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import sys
15
+ import json
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass
19
+ from typing import Dict, Any, List, Tuple
20
+
21
+ # ───────────────────────── third-party imports ────────────────────────────
22
+ import numpy as np
23
+ from PIL import Image
24
+ from tqdm import tqdm
25
+ import torch
26
+ import torchvision.transforms as T
27
+
28
+ import open_clip # Frame-Acc
29
+
30
+ import pyiqa # MUSIQ
31
+
32
+
33
+ # Accepted image extensions (case-insensitive)
34
+ IMG_EXTS = ('.png', '.jpg', '.jpeg', '.bmp')
35
+
36
+ # ───────────────────────────── dataclass ────────────────────────────────
37
+ @dataclass
38
+ class Instance:
39
+ """Single sample inside the JSON"""
40
+ input: Dict[str, Any]
41
+ output: Dict[str, Any]
42
+ id: str
43
+
44
+ # ────────────────────────────── abstract ────────────────────────────────
45
+ class BaseTask(ABC):
46
+ def __init__(self, task_data: Dict[str, Any]):
47
+ self.task_data = task_data
48
+ self.data: List[Instance] = self._parse_data(task_data)
49
+
50
+ # --- implement in subclass ------------------------------------------------
51
+ @abstractmethod
52
+ def _parse_data(self, task_data: Dict[str, Any]) -> List[Instance]:
53
+ ...
54
+
55
+ @abstractmethod
56
+ def run_inference(self) -> None:
57
+ """collect paths & meta ⇒ self.records (does *not* run a model)"""
58
+ ...
59
+
60
+ @abstractmethod
61
+ def evaluate(self) -> Dict[str, float]:
62
+ ...
63
+
64
+ # ════════════════════════════════════════════════════════════════════════════
65
+ # 1. Video Translation – Frame-Acc
66
+ # ════════════════════════════════════════════════════════════════════════════
67
+ class VideoTranslationTask(BaseTask):
68
+ def _parse_data(self, task_data):
69
+ return [Instance(**d) for d in task_data["data"]]
70
+
71
+ def run_inference(self):
72
+ """gather [(frame_paths, src_prompt, tgt_prompt), …]"""
73
+ self.records: List[Tuple[List[str], str, str]] = []
74
+ for inst in tqdm(self.data, desc="collect-frames"):
75
+ frame_dir = inst.output["frame_dir"]
76
+ frames = sorted(
77
+ os.path.join(frame_dir, f)
78
+ for f in os.listdir(frame_dir)
79
+ if f.lower().endswith(IMG_EXTS)
80
+ )
81
+ self.records.append((frames,
82
+ inst.input["source_prompt"],
83
+ inst.input["target_prompt"]))
84
+
85
+ @torch.no_grad()
86
+ def evaluate(self, batch_size: int = 32):
87
+ if open_clip is None:
88
+ raise ImportError("open_clip_torch not installed. pip install open_clip_torch")
89
+
90
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
91
+ model, _, preprocess = open_clip.create_model_and_transforms(
92
+ "ViT-B-32", pretrained="openai", device=device
93
+ )
94
+ model.eval()
95
+ tokenizer = open_clip.tokenize
96
+
97
+ total, correct = 0, 0
98
+ for frame_paths, src_prompt, tgt_prompt in tqdm(self.records, desc="Frame-Acc eval"):
99
+ text_feat = model.encode_text(
100
+ tokenizer([src_prompt, tgt_prompt]).to(device)
101
+ ).float()
102
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) # (2,D)
103
+
104
+ for i in range(0, len(frame_paths), batch_size):
105
+ batch_files = frame_paths[i:i + batch_size]
106
+ imgs = torch.stack([
107
+ preprocess(Image.open(p).convert("RGB")) for p in batch_files
108
+ ]).to(device)
109
+ img_feat = model.encode_image(imgs).float()
110
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) # (B,D)
111
+ sim = img_feat @ text_feat.T # (B,2)
112
+ correct += (sim[:, 1] > sim[:, 0]).sum().item()
113
+ total += sim.size(0)
114
+
115
+ return {"Frame-Acc": 100.0 * correct / total if total else 0.0}
116
+
117
+
118
+ # ════════════════════════════════════════════════════════════════════════════
119
+ # 2. Video Restoration suite – PSNR
120
+ # ════════════════════════════════════════════════════════════════════════════
121
+ def compute_psnr(img1: np.ndarray, img2: np.ndarray, max_val: float = 255.0) -> float:
122
+ mse = np.mean((img1 - img2) ** 2, dtype=np.float64)
123
+ if mse == 0:
124
+ return math.inf
125
+ return 10.0 * math.log10((max_val ** 2) / mse)
126
+
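+ # Note: identical frames yield math.inf, which evaluate() skips via math.isfinite(),
+ # e.g. compute_psnr(img, img) == math.inf.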
127
+
128
+ class VideoRestorationTask(BaseTask):
129
+ def _parse_data(self, task_data):
130
+ return [Instance(**d) for d in task_data["data"]]
131
+
132
+ def run_inference(self):
133
+ """gather [(pred_paths, gt_paths), …]"""
134
+ self.records: List[Tuple[List[str], List[str]]] = []
135
+ for inst in tqdm(self.data, desc="collect-frames"):
136
+ pred_dir = inst.input["pred_dir"]
137
+ gt_dir = inst.input["gt_dir"]
138
+
139
+ frame_names = sorted(
140
+ f for f in os.listdir(gt_dir) if f.lower().endswith(IMG_EXTS)
141
+ )
142
+ pred_paths, gt_paths = [], []
143
+ for fname in frame_names:
144
+ p_path = os.path.join(pred_dir, fname)
145
+ g_path = os.path.join(gt_dir, fname)
146
+ if not os.path.exists(p_path):
147
+ raise FileNotFoundError(f"Missing prediction frame: {p_path}")
148
+ pred_paths.append(p_path)
149
+ gt_paths.append(g_path)
150
+ self.records.append((pred_paths, gt_paths))
151
+
152
+ def evaluate(self):
153
+ psnr_sum, valid_frames = 0.0, 0
154
+
155
+ for preds, gts in tqdm(self.records, desc="PSNR eval"):
156
+ for p, g in zip(preds, gts):
157
+ img1 = np.array(Image.open(p).convert("RGB"), dtype=np.float32)
158
+ img2 = np.array(Image.open(g).convert("RGB"), dtype=np.float32)
159
+
160
+ if img1.shape != img2.shape:
161
+ raise ValueError(f"Shape mismatch: {p} vs {g}")
162
+
163
+ val = compute_psnr(img1, img2)
164
+ if math.isfinite(val):
165
+ psnr_sum += val
166
+ valid_frames += 1
167
+
168
+ return {"PSNR": psnr_sum / valid_frames if valid_frames else 0.0}
169
+
170
+ # ════════════════════════════════════════════════════════════════════════════
171
+ # 3. Video Super-Resolution – MUSIQ
172
+ # ════════════════════════════════════════════════════════════════════════════
173
+ class VideoSuperResolutionTask(BaseTask):
174
+ def _parse_data(self, task_data):
175
+ return [Instance(**d) for d in task_data["data"]]
176
+
177
+ def run_inference(self):
178
+ self.records: List[List[str]] = []
179
+ for inst in tqdm(self.data, desc="collect-frames"):
180
+ pred_dir = inst.input["pred_dir"]
181
+ frames = sorted(
182
+ os.path.join(pred_dir, f)
183
+ for f in os.listdir(pred_dir)
184
+ if f.lower().endswith(IMG_EXTS)
185
+ )
186
+ if not frames:
187
+ raise RuntimeError(f"No prediction frames found in {pred_dir}")
188
+ self.records.append(frames)
189
+
190
+ @torch.no_grad()
191
+ def evaluate(self, batch_size: int = 8):
192
+ if pyiqa is None:
193
+ raise ImportError("pyiqa not installed. pip install pyiqa")
194
+
195
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
196
+ model = pyiqa.create_metric('musiq', device=device, as_loss=False)
197
+ model.eval()
198
+ transform = T.ToTensor()
199
+
200
+ total_sum, total_frames = 0.0, 0
201
+ for frames in tqdm(self.records, desc="MUSIQ eval"):
202
+ for i in range(0, len(frames), batch_size):
203
+ batch = frames[i:i + batch_size]
204
+ imgs = torch.stack([
205
+ transform(Image.open(p).convert("RGB")) for p in batch
206
+ ]).to(device)
207
+ scores = model(imgs) # (B,)
208
+ total_sum += scores.sum().item()
209
+ total_frames += scores.numel()
210
+
211
+ return {"MUSIQ": total_sum / total_frames if total_frames else 0.0}
212
+
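+ # MUSIQ (used above via the pyiqa toolbox) is a no-reference image quality
+ # metric, so no ground-truth frames are needed; higher scores indicate better
+ # perceived quality, and per-frame scores are simply averaged over all videos.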
213
+
214
+ # ════════════════════════════════════════════════════════════════════════════
215
+ # 4. Video (Salient / Camouflaged) Object Detection – Structure-measure
216
+ # ════════════════════════════════════════════════════════════════════════════
217
+ def _ssim(pred: np.ndarray, gt: np.ndarray) -> float:
218
+ C1, C2 = 0.01 ** 2, 0.03 ** 2
219
+ mp, mg = pred.mean(), gt.mean()
220
+ var_p, var_g = pred.var(), gt.var()
221
+ cov = ((pred - mp) * (gt - mg)).mean()
222
+ return ((2 * mp * mg + C1) * (2 * cov + C2)) / (
223
+ (mp ** 2 + mg ** 2 + C1) * (var_p + var_g + C2) + 1e-8)
224
+
225
+
226
+ def _object_score(x: np.ndarray) -> float:
227
+ if x.size == 0:
228
+ return 0.0
229
+ mu, sigma = x.mean(), x.std()
230
+ return 2 * mu / (mu * mu + 1 + sigma + 1e-8)
231
+
232
+
233
+ def structure_measure(pred: np.ndarray, gt: np.ndarray, alpha: float = 0.5) -> float:
234
+ """pred in [0,1] float32, gt binary uint8 (0/1)"""
235
+ y = gt.mean()
236
+ if y == 0: # GT is entirely black
237
+ return 1.0 - pred.mean()
238
+ if y == 1: # GT is entirely white
239
+ return pred.mean()
240
+
241
+ # ─── object-aware term ─────────────────────────────────────────────────
242
+ S_fg = _object_score(pred[gt > 0.5])
243
+ S_bg = _object_score(1 - pred[gt <= 0.5])
244
+ s_object = y * S_fg + (1 - y) * S_bg
245
+
246
+ # ─── region-aware term ────────────────────────────────────────────────
247
+ h, w = gt.shape
248
+ rows, cols = np.where(gt > 0.5)
249
+ cx = int(np.round(cols.mean())) if cols.size else w // 2
250
+ cy = int(np.round(rows.mean())) if rows.size else h // 2
251
+
252
+ def split(img):
253
+ return [img[:cy, :cx], img[:cy, cx:], img[cy:, :cx], img[cy:, cx:]]
254
+
255
+ regions_p = split(pred)
256
+ regions_g = split(gt.astype(np.float32))
257
+
258
+ weights = [r.size / (h * w) for r in regions_g]
259
+ ssim_scores = [_ssim(p_r, g_r) for p_r, g_r in zip(regions_p, regions_g)]
260
+ s_region = sum(w * s for w, s in zip(weights, ssim_scores))
261
+
262
+ score = alpha * s_object + (1 - alpha) * s_region
263
+ return max(score, 0.0)
264
+
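+ # The structure_measure above combines an object-aware and a region-aware term:
+ #   S = alpha * S_object + (1 - alpha) * S_region, with alpha = 0.5 here.
+ # S_object scores foreground and background separately via 2*mu / (mu^2 + 1 + sigma),
+ # and S_region is an SSIM-style score over the four quadrants split at the
+ # ground-truth foreground centroid, weighted by quadrant area.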
265
+
266
+ class VideoObjectDetectionTask(BaseTask):
267
+ def _parse_data(self, task_data):
268
+ return [Instance(**d) for d in task_data["data"]]
269
+
270
+ def run_inference(self):
271
+ self.records: List[Tuple[List[str], List[str]]] = []
272
+ for inst in tqdm(self.data, desc="collect-frames"):
273
+ pred_dir = inst.input["pred_dir"]
274
+ gt_dir = inst.input["gt_dir"]
275
+
276
+ frame_names = sorted(
277
+ f for f in os.listdir(gt_dir) if f.lower().endswith(IMG_EXTS)
278
+ )
279
+ preds, gts = [], []
280
+ for fname in frame_names:
281
+ p_path = os.path.join(pred_dir, fname)
282
+ g_path = os.path.join(gt_dir, fname)
283
+ if not os.path.exists(p_path):
284
+ raise FileNotFoundError(f"Missing prediction frame: {p_path}")
285
+ preds.append(p_path)
286
+ gts.append(g_path)
287
+ self.records.append((preds, gts))
288
+
289
+ def evaluate(self):
290
+ total_sum, total_frames = 0.0, 0
291
+
292
+ for preds, gts in tqdm(self.records, desc="S-measure eval"):
293
+ for p, g in zip(preds, gts):
294
+ pred = np.array(Image.open(p).convert('L'), dtype=np.float32)
295
+ if pred.max() > 1.0:
296
+ pred /= 255.0
297
+ gt = (np.array(Image.open(g).convert('L')) > 128).astype(np.uint8)
298
+
299
+ if pred.shape != gt.shape:
300
+ raise ValueError(f"Shape mismatch: {p} vs {g}")
301
+
302
+ total_sum += structure_measure(pred, gt)
303
+ total_frames += 1
304
+
305
+ return {"S-measure": total_sum / total_frames if total_frames else 0.0}
306
+
307
+
308
+ # ════════════════════════════════════════════════════════════════════════════
309
+ # unified runner
310
+ # ════════════════════════════════════════════════════════════════════════════
311
+ TASK_MAPPING = {
312
+ "VideoTranslation": VideoTranslationTask,
313
+ "VideoRestoration": VideoRestorationTask,
314
+ "VideoSuperResolution": VideoSuperResolutionTask,
315
+ "VideoObjectDetection": VideoObjectDetectionTask,
316
+ }
317
+
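+ # Illustrative task-JSON layout for this runner (inferred from the task classes
+ # above; field names other than "type", "data", "pred_dir" and "gt_dir" are
+ # assumptions, and real annotation files may carry additional fields):
+ # {
+ #   "type": "VideoRestoration",
+ #   "data": [
+ #     {"input": {"pred_dir": "preds/clip_0001", "gt_dir": "gt/clip_0001"}, ...}
+ #   ]
+ # }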
318
+
319
+ def main():
320
+ if len(sys.argv) != 2:
321
+ print("Usage: python integrated_eval.py <task_json>")
322
+ sys.exit(1)
323
+
324
+ task_json_path = sys.argv[1]
325
+ with open(task_json_path, 'r', encoding='utf-8') as f:
326
+ task_data = json.load(f)
327
+
328
+ task_type = task_data.get("type")
329
+ TaskCls = TASK_MAPPING.get(task_type)
330
+ if TaskCls is None:
331
+ raise NotImplementedError(f"Unsupported task type: {task_type}")
332
+
333
+ task = TaskCls(task_data)
334
+ task.run_inference()
335
+ metrics = task.evaluate()
336
+ print(f"[{task_type}] Evaluation Results → {metrics}")
337
+
338
+
339
+ if __name__ == "__main__":
340
+ main()
processors/._audio_processor.py ADDED
Binary file (176 Bytes). View file
 
processors/._image_processor.py ADDED
Binary file (219 Bytes). View file
 
processors/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Processors package for different modalities."""
processors/__pycache__/.___init__.cpython-38.pyc ADDED
Binary file (176 Bytes). View file
 
processors/__pycache__/.___init__.cpython-39.pyc ADDED
Binary file (176 Bytes). View file
 
processors/__pycache__/._video_processor.cpython-39.pyc ADDED
Binary file (176 Bytes). View file
 
processors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (233 Bytes). View file
 
processors/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (208 Bytes). View file
 
processors/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (199 Bytes). View file
 
processors/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (225 Bytes). View file
 
processors/__pycache__/audio_processor.cpython-311.pyc ADDED
Binary file (4.07 kB). View file
 
processors/__pycache__/audio_processor.cpython-312.pyc ADDED
Binary file (3.64 kB). View file
 
processors/__pycache__/audio_processor.cpython-38.pyc ADDED
Binary file (3.06 kB). View file
 
processors/__pycache__/audio_processor.cpython-39.pyc ADDED
Binary file (3.06 kB). View file
 
processors/__pycache__/image_processor.cpython-311.pyc ADDED
Binary file (4.2 kB). View file
 
processors/__pycache__/image_processor.cpython-312.pyc ADDED
Binary file (3.63 kB). View file
 
processors/__pycache__/image_processor.cpython-38.pyc ADDED
Binary file (3.05 kB). View file
 
processors/__pycache__/image_processor.cpython-39.pyc ADDED
Binary file (3.05 kB). View file
 
processors/__pycache__/nlp_processor.cpython-311.pyc ADDED
Binary file (21.4 kB). View file
 
processors/__pycache__/nlp_processor.cpython-312.pyc ADDED
Binary file (2.29 kB). View file
 
processors/__pycache__/nlp_processor.cpython-38.pyc ADDED
Binary file (1.98 kB). View file
 
processors/__pycache__/nlp_processor.cpython-39.pyc ADDED
Binary file (1.93 kB). View file
 
processors/__pycache__/pseudo_audio_processor.cpython-39.pyc ADDED
Binary file (2.13 kB). View file
 
processors/__pycache__/three_d_processor.cpython-311.pyc ADDED
Binary file (4.06 kB). View file
 
processors/__pycache__/three_d_processor.cpython-312.pyc ADDED
Binary file (3.63 kB). View file
 
processors/__pycache__/three_d_processor.cpython-38.pyc ADDED
Binary file (3.05 kB). View file
 
processors/__pycache__/three_d_processor.cpython-39.pyc ADDED
Binary file (3.05 kB). View file
 
processors/__pycache__/video_processor.cpython-311.pyc ADDED
Binary file (4.07 kB). View file
 
processors/__pycache__/video_processor.cpython-312.pyc ADDED
Binary file (3.65 kB). View file
 
processors/__pycache__/video_processor.cpython-38.pyc ADDED
Binary file (3.07 kB). View file
 
processors/__pycache__/video_processor.cpython-39.pyc ADDED
Binary file (3.07 kB). View file
 
processors/audio_processor.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import List
2
+ from utils.data_types import ModalityType, TaskType, TaskResult
3
+ from utils.base_processor import BaseModalityProcessor
4
+
5
+ class AudioProcessor(BaseModalityProcessor):
6
+ """音频模态处理器"""
7
+ def __init__(self, modality: ModalityType, dataset_dir: str, pred_json_file: str):
8
+ super().__init__(modality, dataset_dir, pred_json_file)
9
+
10
+ def process_comprehension(self) -> List[TaskResult]:
11
+ """处理音频理解类任务
12
+
13
+ 需要返回一个TaskResult列表,每个TaskResult包含:
14
+ - task_name: 任务名称,例如 "speech_recognition", "audio_classification" 等
15
+ - metric: 评估指标,例如 "WER", "accuracy" 等
16
+ - score: 评估分数
17
+ - task_type: 默认为 TaskType.COMPREHENSION,不需要指定
18
+
19
+ 示例格式:
20
+ return [
21
+ TaskResult(
22
+ task_name="speech_recognition",
23
+ metric="WER",
24
+ score=0.15
25
+ ),
26
+ TaskResult(
27
+ task_name="audio_classification",
28
+ metric="accuracy",
29
+ score=0.92
30
+ )
31
+ ]
32
+ """
33
+ return []
34
+
35
+ def process_generation(self) -> List[TaskResult]:
36
+ """处理音频生成类任务
37
+
38
+ 需要返回一个TaskResult列表,每个TaskResult包含:
39
+ - task_name: 任务名称,例如 "speech_synthesis", "audio_generation" 等
40
+ - metric: 评估指标,例如 "MOS", "FAD" 等
41
+ - score: 评估分数
42
+ - task_type: 需要指定为 TaskType.GENERATION
43
+
44
+ 示例格式:
45
+ return [
46
+ TaskResult(
47
+ task_name="speech_synthesis",
48
+ metric="MOS",
49
+ score=4.2,
50
+ task_type=TaskType.GENERATION
51
+ ),
52
+ TaskResult(
53
+ task_name="audio_generation",
54
+ metric="FAD",
55
+ score=12.5,
56
+ task_type=TaskType.GENERATION
57
+ )
58
+ ]
59
+ """
60
+ return []
61
+
62
+ # Usage example
+ if __name__ == "__main__":
+     processor = AudioProcessor(ModalityType.AUDIO, "", "prediction.json")
+
+     # Test comprehension tasks
+     print("\nComprehension task results:")
+     for task in processor.process_comprehension():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)
+
+     # Test generation tasks
+     print("\nGeneration task results:")
+     for task in processor.process_generation():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)
processors/image_processor.py ADDED
@@ -0,0 +1,83 @@
1
+ from typing import List
2
+ from utils.data_types import ModalityType, TaskType, TaskResult
3
+ from utils.base_processor import BaseModalityProcessor
4
+
5
+ class ImageProcessor(BaseModalityProcessor):
6
+ """图像模态处理器"""
7
+ def __init__(self, modality: ModalityType, dataset_dir: str, pred_json_file: str):
8
+ super().__init__(modality, dataset_dir, pred_json_file)
9
+
10
+ def process_1(self):
11
+ return []
12
+
13
+ def process_comprehension(self) -> List[TaskResult]:
14
+ """处理图像理解类任务
15
+
16
+ 需要返回一个TaskResult列表,每个TaskResult包含:
17
+ - task_name: 任务名称,例如 "image_classification", "object_detection" 等
18
+ - metric: 评估指标,例如 "accuracy", "mAP" 等
19
+ - score: 评估分数
20
+ - task_type: 默认为 TaskType.COMPREHENSION,不需要指定
21
+
22
+ 示例格式:
23
+ return [
24
+ TaskResult(
25
+ task_name="image_classification",
26
+ metric="accuracy",
27
+ score=0.95
28
+ ),
29
+ TaskResult(
30
+ task_name="object_detection",
31
+ metric="mAP",
32
+ score=0.82
33
+ )
34
+ ]
35
+ """
36
+ return []
37
+
38
+ def process_generation(self) -> List[TaskResult]:
39
+ """处理图像生成类任务
40
+
41
+ 需要返回一个TaskResult列表,每个TaskResult包含:
42
+ - task_name: 任务名称,例如 "image_generation", "image_editing" 等
43
+ - metric: 评估指标,例如 "FID", "IS" 等
44
+ - score: 评估分数
45
+ - task_type: 需要指定为 TaskType.GENERATION
46
+
47
+ 示例格式:
48
+ return [
49
+ TaskResult(
50
+ task_name="image_generation",
51
+ metric="FID",
52
+ score=15.2,
53
+ task_type=TaskType.GENERATION
54
+ ),
55
+ TaskResult(
56
+ task_name="image_editing",
57
+ metric="PSNR",
58
+ score=28.5,
59
+ task_type=TaskType.GENERATION
60
+ )
61
+ ]
62
+ """
63
+ return []
64
+
65
+ # Usage example
+ if __name__ == "__main__":
+     processor = ImageProcessor(ModalityType.IMAGE, "", "prediction.json")
+
+     # Test comprehension tasks
+     print("\nComprehension task results:")
+     for task in processor.process_comprehension():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)
+
+     # Test generation tasks
+     print("\nGeneration task results:")
+     for task in processor.process_generation():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)
processors/nlp_processor.py ADDED
@@ -0,0 +1,381 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import math
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import List, Dict, Any, Optional
8
+ import nltk
9
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
10
+ from rouge_score import rouge_scorer
11
+ from codebleu import calc_codebleu
12
+ from utils.data_types import TaskResult, TaskType
13
+
14
+
15
+ class NLPProcessor:
16
+ def __init__(self, modality, dataset_dir: str, pred_json_file: str = "prediction.json"):
17
+ self.modality = modality
18
+ self.dataset_dir = dataset_dir + '/nlp'
19
+ self.pred_json_file = pred_json_file
20
+
21
+ def process(self) -> List[TaskResult]:
22
+ results = []
23
+
24
+ task_dirs = [d for d in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, d))]
25
+ total_tasks = len(task_dirs)
26
+ processed_tasks = 0
27
+
28
+ for task_folder in task_dirs:
29
+ folder_path = os.path.join(self.dataset_dir, task_folder)
30
+ annotation_path = os.path.join(folder_path, "annotation.json")
31
+ prediction_path = os.path.join(folder_path, self.pred_json_file)
32
+
33
+ if not os.path.exists(annotation_path):
34
+ print(f"Skip {task_folder}: annotation.json no exists")
35
+ continue
36
+
37
+ if not os.path.exists(prediction_path):
38
+ print(f"Skip {task_folder}: {self.pred_json_file} no exists.")
39
+ continue
40
+
41
+ try:
42
+ with open(annotation_path, "r", encoding="utf-8") as f:
43
+ task_data = json.load(f)
44
+
45
+ with open(prediction_path, "r", encoding="utf-8") as f:
46
+ predictions_data = json.load(f)
47
+
48
+ task_result = self._evaluate_task(task_data, predictions_data)
49
+ if task_result:
50
+ results.append(task_result)
51
+ processed_tasks += 1
52
+ print(f"Task: {task_folder} (Socre: {task_result.score:.4f})")
53
+ else:
54
+ print(f"Skip {task_folder}.")
55
+
56
+ except Exception as e:
57
+ print(f"Skip {task_folder}: Error - {e}")
58
+ continue
59
+
60
+ return results
61
+
62
+ def _evaluate_task(self, task_data: Dict[str, Any], predictions_data: List[Dict]) -> Optional[TaskResult]:
63
+ task_type = task_data.get("type", "")
64
+ task_name = task_data.get("task", "")
65
+
66
+ pred_map = {pred["id"]: pred for pred in predictions_data}
67
+
68
+ predictions = []
69
+ references = []
70
+
71
+ for data_item in task_data["data"]:
72
+ item_id = data_item["id"]
73
+ if item_id not in pred_map:
74
+ continue
75
+
76
+ pred_item = pred_map[item_id]
77
+
78
+ if "prediction" in pred_item:
79
+ pred = pred_item["prediction"]
80
+ elif "prediction_final" in pred_item:
81
+ pred = pred_item["prediction_final"]
82
+ else:
83
+ continue
84
+
85
+ ref = self._extract_reference(data_item, task_type)
86
+ if ref is None:
87
+ continue
88
+
89
+ predictions.append(pred)
90
+ references.append(ref)
91
+
92
+ if not predictions:
93
+ return None
94
+
95
+ score, metric = self._calculate_metrics(predictions, references, task_type)
96
+ metric = self._convert_metric(metric)
97
+
98
+ return TaskResult(
99
+ task_name=task_name,
100
+ metric=metric,
101
+ score=score,
102
+ task_type=TaskType.COMPREHENSION
103
+ )
104
+
105
+ def _extract_reference(self, data_item: Dict[str, Any], task_type: str) -> Any:
106
+ output = data_item.get("output", {})
107
+
108
+ if task_type == "MultipleChoiceQA":
109
+ return output.get("answer")
110
+ elif task_type == "OpenQA":
111
+ return output.get("answer")
112
+ elif task_type == "Summarization":
113
+ return output.get("summary") or output.get("highlights")
114
+ elif task_type == "Translation":
115
+ if isinstance(output, str):
116
+ return output
117
+ else:
118
+ return output.get("translation")
119
+ elif task_type == "Story Generation":
120
+ return output.get("story")
121
+ elif task_type == "Dialogue":
122
+ return output.get("reference")
123
+ elif task_type == "Code Generation":
124
+ return output.get("response", {}).get("content")
125
+ elif task_type == "Code Repair":
126
+ return output.get("repairCode")
127
+ elif task_type == "Code Defect Detection":
128
+ return str(output.get("target"))
129
+ elif task_type == "Text to SQL":
130
+ return output.get("sql")
131
+ elif task_type == "Code Explanation":
132
+ return output.get("nl")
133
+ elif task_type == "Proof":
134
+ proof_data = output.get("proof", {})
135
+ steps = proof_data.get("steps", [])
136
+ conclusion = proof_data.get("conclusion", "")
137
+ return "\n".join(steps) + f"\nConclusion: {conclusion}"
138
+ elif task_type == "Mathematical Word Problem Solving":
139
+ return output.get("solution", {}).get("final_answer")
140
+ elif task_type == "Paraphrase Generation":
141
+ return output.get("paraphraseSentence")
142
+ elif task_type == "Grammar Correction":
143
+ return output.get("Standard English")
144
+ elif task_type == "Text Style Transfer":
145
+ return output.get("answer")
146
+ elif task_type == "Table-to-Text Generation":
147
+ return output.get("response", {}).get("text")
148
+ elif task_type == "Time Series":
149
+ return output.get("target")
150
+ elif task_type in ["classification", "multiple choice"]:
151
+ return list(output.values())[0].lower() if output else ""
152
+ elif task_type in ["multi label classification", "ner", "extraction", "relation extraction", "event detection", "parsing"]:
153
+ value = list(output.values())[0] if output else ""
154
+ return '<p>'.join(value.lower().split(', ')) if isinstance(value, str) else ""
155
+ else:
156
+ # Default: take the first value in the output dict
157
+ return list(output.values())[0] if output else ""
158
+
159
+ def _calculate_metrics(self, predictions: List, references: List, task_type: str) -> tuple:
160
+ if task_type == "MultipleChoiceQA":
161
+ score = self._exact_match_accuracy(predictions, references)
162
+ return score, "accuracy"
163
+
164
+ elif task_type == "OpenQA":
165
+ f1_score = self._calculate_f1(predictions, references)
166
+ return f1_score, "f1"
167
+
168
+ elif task_type == "Summarization":
169
+ rouge_scores = self._rouge_evaluation(predictions, references)
170
+ return rouge_scores["rouge1"], "rouge1"
171
+
172
+ elif task_type == "Translation":
173
+ rouge_scores = self._rouge_evaluation(predictions, references)
174
+ return rouge_scores["rouge1"], "rouge1"
175
+
176
+ elif task_type in ["Story Generation", "Dialogue", "Paraphrase Generation", "Grammar Correction", "Text Style Transfer", "Table-to-Text Generation"]:
177
+ bleu_scores = self._bleu_evaluation(predictions, references)
178
+ return bleu_scores["bleu1"], "bleu1"
179
+
180
+ elif task_type in ["Code Generation", "Code Repair"]:
181
+ try:
182
+ result = calc_codebleu(references, predictions, lang="python", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None)
183
+ return result["codebleu"], "code_bleu"
184
+ except:
185
+ return 0.0, "code_bleu"
186
+
187
+ elif task_type == "Code Defect Detection":
188
+ score = self._exact_match_accuracy(predictions, references)
189
+ return score, "accuracy"
190
+
191
+ elif task_type == "Text to SQL":
192
+ score = self._exact_match_accuracy(predictions, references)
193
+ return score, "accuracy"
194
+
195
+ elif task_type in ["Code Explanation", "Proof"]:
196
+ bleu_scores = self._bleu_evaluation(predictions, references)
197
+ return bleu_scores["bleu1"], "bleu1"
198
+
199
+ elif task_type == "Mathematical Word Problem Solving":
200
+ score = self._exact_match_accuracy(predictions, references)
201
+ return score, "accuracy"
202
+
203
+ elif task_type == "Time Series":
204
+ mae = self._mean_absolute_error(predictions, references)
205
+ return mae, "MAE"
206
+
207
+ elif task_type in ["classification", "multiple choice"]:
208
+ f1_score = self._calculate_micro_f1(predictions, references)
209
+ return f1_score, "micro_f1"
210
+
211
+ elif task_type in ["multi label classification", "ner", "extraction", "relation extraction", "event detection", "parsing"]:
212
+ f1_score = self._calculate_micro_f1(predictions, references)
213
+ return f1_score, "micro_f1"
214
+
215
+ else:
216
+ f1_score = self._calculate_f1(predictions, references)
217
+ return f1_score, "f1"
218
+
219
+ def _exact_match_accuracy(self, predictions: List[str], references: List[str]) -> float:
220
+ correct = 0
221
+ for pred, ref in zip(predictions, references):
222
+ if isinstance(ref, str):
223
+ ref = [ref]
224
+ is_match = False
225
+ for r in ref:
226
+ if str(pred).strip() == str(r).strip():
227
+ is_match = True
228
+ break
229
+ if is_match:
230
+ correct += 1
231
+ return correct / len(predictions) if predictions else 0.0
232
+
233
+ def _calculate_f1(self, predictions: List[str], references: List[str]) -> float:
234
+ def compute_f1(pred: str, ref: str) -> float:
235
+ pred_tokens = str(pred).strip().split()
236
+ ref_tokens = str(ref).strip().split()
237
+
238
+ common_tokens = set(pred_tokens) & set(ref_tokens)
239
+ num_common = len(common_tokens)
240
+
241
+ if num_common == 0:
242
+ return 0.0
243
+
244
+ precision = num_common / len(pred_tokens) if pred_tokens else 0.0
245
+ recall = num_common / len(ref_tokens) if ref_tokens else 0.0
246
+
247
+ return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
248
+
249
+ total_f1 = 0.0
250
+ for pred, ref in zip(predictions, references):
251
+ if isinstance(ref, str):
252
+ ref = [ref]
253
+ max_f1 = 0.0
254
+ for r in ref:
255
+ max_f1 = max(compute_f1(pred, r), max_f1)
256
+ total_f1 += max_f1
257
+
258
+ return total_f1 / len(predictions) if predictions else 0.0
259
+
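+ # The micro-F1 below treats both prediction and reference strings as
+ # '<p>'-separated label sets (see _extract_reference above). For example,
+ # pred "person<p>location" vs. ref "person<p>date" contributes tp=1, fp=1,
+ # fn=1, i.e. a pairwise F1 of 0.5; counts are aggregated over the whole task
+ # before computing the final micro-F1.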
260
+ def _calculate_micro_f1(self, predictions: List[str], references: List[str]) -> float:
261
+ total_tp = 0
262
+ total_fp = 0
263
+ total_fn = 0
264
+
265
+ for pred, ref in zip(predictions, references):
266
+ pred_tokens = set(str(pred).strip().split('<p>'))
267
+ ref_tokens = set(str(ref).strip().split("<p>"))
268
+
269
+ tp = len(pred_tokens & ref_tokens)
270
+ fp = len(pred_tokens - ref_tokens)
271
+ fn = len(ref_tokens - pred_tokens)
272
+
273
+ total_tp += tp
274
+ total_fp += fp
275
+ total_fn += fn
276
+
277
+ if total_tp == 0:
278
+ return 0.0
279
+
280
+ precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
281
+ recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
282
+ return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
283
+
284
+ def _rouge_evaluation(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
285
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
286
+ rouge1_scores, rouge2_scores, rougel_scores = [], [], []
287
+
288
+ for pred, ref in zip(predictions, references):
289
+ if isinstance(ref, str):
290
+ ref = [ref]
291
+ rouge1, rouge2, rougeL = 0, 0, 0
292
+ for r in ref:
293
+ scores = scorer.score(str(r), str(pred))
294
+ rouge1 = max(scores['rouge1'].fmeasure, rouge1)
295
+ rouge2 = max(scores['rouge2'].fmeasure, rouge2)
296
+ rougeL = max(scores['rougeL'].fmeasure, rougeL)
297
+ rouge1_scores.append(rouge1)
298
+ rouge2_scores.append(rouge2)
299
+ rougel_scores.append(rougeL)
300
+
301
+ return {
302
+ 'rouge1': sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0.0,
303
+ 'rouge2': sum(rouge2_scores) / len(rouge2_scores) if rouge2_scores else 0.0,
304
+ 'rougeL': sum(rougel_scores) / len(rougel_scores) if rougel_scores else 0.0,
305
+ }
306
+
307
+ def _bleu_evaluation(self, predictions: List[str], references: List[str]) -> Dict[str, float]:
308
+ smoothie = SmoothingFunction().method4
309
+ bleu1_scores, bleu2_scores, bleu3_scores, bleu4_scores = [], [], [], []
310
+
311
+ for pred, ref in zip(predictions, references):
312
+ try:
313
+ hypothesis = nltk.word_tokenize(str(pred))
314
+ except:
315
+ hypothesis = str(pred).split()
316
+
317
+ if isinstance(ref, str):
318
+ ref = [ref]
319
+
320
+ bleu1, bleu2, bleu3, bleu4 = 0, 0, 0, 0
321
+ for r in ref:
322
+ try:
323
+ reference = [nltk.word_tokenize(str(r))]
324
+ except:
325
+ reference = [str(r).split()]
326
+
327
+ try:
328
+ bleu1 = max(sentence_bleu(reference, hypothesis, weights=(1, 0, 0, 0), smoothing_function=smoothie), bleu1)
329
+ bleu2 = max(sentence_bleu(reference, hypothesis, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothie), bleu2)
330
+ bleu3 = max(sentence_bleu(reference, hypothesis, weights=(1/3, 1/3, 1/3, 0), smoothing_function=smoothie), bleu3)
331
+ bleu4 = max(sentence_bleu(reference, hypothesis, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie), bleu4)
332
+ except:
333
+ continue
334
+
335
+ bleu1_scores.append(bleu1)
336
+ bleu2_scores.append(bleu2)
337
+ bleu3_scores.append(bleu3)
338
+ bleu4_scores.append(bleu4)
339
+
340
+ return {
341
+ 'bleu1': sum(bleu1_scores) / len(bleu1_scores) if bleu1_scores else 0.0,
342
+ 'bleu2': sum(bleu2_scores) / len(bleu2_scores) if bleu2_scores else 0.0,
343
+ 'bleu3': sum(bleu3_scores) / len(bleu3_scores) if bleu3_scores else 0.0,
344
+ 'bleu4': sum(bleu4_scores) / len(bleu4_scores) if bleu4_scores else 0.0,
345
+ }
346
+
347
+ def _mean_absolute_error(self, predictions: List[float], references: List[float]) -> float:
348
+ if not predictions:
349
+ return 0.0
350
+
351
+ error_sum = 0.0
352
+ valid_count = 0
353
+
354
+ for p, r in zip(predictions, references):
355
+ try:
356
+ error_sum += abs(float(p) - float(r))
357
+ valid_count += 1
358
+ except:
359
+ continue
360
+
361
+ return error_sum / valid_count if valid_count > 0 else 0.0
362
+
363
+ def _convert_metric(self, metric: str) -> str:
364
+ m = metric.lower()
365
+ if m == "accuracy":
366
+ return "ACC"
367
+ if m == "f1":
368
+ return "F1"
369
+ if m == "micro_f1":
370
+ return "Micro-F1"
371
+ if m.startswith("rouge"):
372
+ if "l" in m:
373
+ return "ROUGE-L"
374
+ else:
375
+ return "ROUGE-1"
376
+ if m.startswith("bleu"):
377
+ return "BLEU-1"
378
+ if m == "code_bleu":
379
+ return "CodeBLEU"
380
+ return metric.upper()
381
+
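+ # Minimal usage sketch (illustrative only; assumes the standard dataset layout
+ # with one annotation.json plus a prediction file per task folder, and that the
+ # ModalityType enum exposes an NLP member):
+ #   processor = NLPProcessor(ModalityType.NLP, "./General-Bench-Openset", "prediction.json")
+ #   for result in processor.process():
+ #       print(result.task_name, result.metric, result.score)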
processors/three_d_processor.py ADDED
@@ -0,0 +1,79 @@
1
+ from typing import List
2
+ from utils.data_types import ModalityType, TaskType, TaskResult
3
+ from utils.base_processor import BaseModalityProcessor
4
+
5
+ class ThreeDProcessor(BaseModalityProcessor):
6
+ """3D模态处理器"""
7
+ def __init__(self, modality: ModalityType, dataset_dir: str, pred_json_file: str):
8
+ super().__init__(modality, dataset_dir, pred_json_file)
9
+
10
+ def process_comprehension(self) -> List[TaskResult]:
11
+ """处理3D理解类任务
12
+
13
+ 需要返回一个TaskResult列表,每个TaskResult包含:
14
+ - task_name: 任务名称,例如 "3d_object_detection", "point_cloud_segmentation" 等
15
+ - metric: 评估指标,例如 "mAP", "IoU" 等
16
+ - score: 评估分数
17
+ - task_type: 默认为 TaskType.COMPREHENSION,不需要指定
18
+ 示例格式:
19
+ return [
20
+ TaskResult(
21
+ task_name="3d_object_detection",
22
+ metric="mAP",
23
+ score=0.76
24
+ ),
25
+ TaskResult(
26
+ task_name="point_cloud_segmentation",
27
+ metric="IoU",
28
+ score=0.82
29
+ )
30
+ ]
31
+ """
32
+ return []
33
+
34
+ def process_generation(self) -> List[TaskResult]:
35
+ """处理3D生成类任务
36
+
37
+ 需要返回一个TaskResult列表,每个TaskResult包含:
38
+ - task_name: 任务名称,例如 "3d_reconstruction", "mesh_generation" 等
39
+ - metric: 评估指标,例如 "CD", "F1" 等
40
+ - score: 评估分数
41
+ - task_type: 这里需要指定为 TaskType.GENERATION
42
+
43
+ 示例格式:
44
+ return [
45
+ TaskResult(
46
+ task_name="3d_reconstruction",
47
+ metric="CD",
48
+ score=0.15,
49
+ task_type=TaskType.GENERATION
50
+ ),
51
+ TaskResult(
52
+ task_name="mesh_generation",
53
+ metric="F1",
54
+ score=0.88,
55
+ task_type=TaskType.GENERATION
56
+ )
57
+ ]
58
+ """
59
+ return []
60
+
61
+ # Usage example
+ if __name__ == "__main__":
+     processor = ThreeDProcessor(ModalityType.THREE_D, "", "prediction.json")
+
+     # Test comprehension tasks
+     print("\nComprehension task results:")
+     for task in processor.process_comprehension():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)
+
+     # Test generation tasks
+     print("\nGeneration task results:")
+     for task in processor.process_generation():
+         print(f"Task: {task.task_name}")
+         print(f"Metric: {task.metric}")
+         print(f"Score: {task.score}")
+         print("-" * 20)