# Merge semantically equivalent class names across multiple datasets.
from re import L | |
from typing import Dict, List | |
from collections import Counter | |
def grouping(bondlist):
    """Merge lists that share at least one element into connected groups.

    Treats each inner list as a clique of names; lists that share any
    element (transitively) are concatenated into a single group, i.e. the
    connected components of the "shares a name" relation.  Duplicates are
    kept on purpose: downstream code counts how often each name occurs in a
    group to pick the most common spelling.

    Args:
        bondlist: list of lists of hashable items.  The input list may be
            mutated; do not rely on its contents afterwards.

    Returns:
        List of concatenated groups (one per connected component).
    """
    # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
    groups = []
    while bondlist:
        group = bondlist.pop(0)
        members = set(group)  # O(1) membership instead of scanning the group
        changed = True
        # Fixed point: keep sweeping the remaining lists until a full pass
        # absorbs nothing new.  The original popped from the very list it was
        # enumerating, silently skipping the element after each pop; building
        # a fresh "remaining" list each pass avoids that fragility.
        while changed:
            changed = False
            remaining = []
            for other in bondlist:
                if members.isdisjoint(other):
                    remaining.append(other)
                else:
                    group = group + other  # concatenate, keeping duplicates
                    members.update(other)
                    changed = True
            bondlist = remaining
        groups.append(group)
    return groups
def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
    """Map each class name to its alias group, or to a singleton list.

    A class that appears in one of the ``aliases`` groups contributes the
    first whole group it is found in (the group list object itself); any
    other class contributes ``[name]``.

    Args:
        classes: class names of one dataset.
        aliases: groups of names considered semantically equivalent.

    Returns:
        One list per input class, in the original order.
    """
    semantic_info = []
    for name in classes:
        alias_group = next((group for group in aliases if name in group), None)
        semantic_info.append(alias_group if alias_group is not None else [name])
    return semantic_info
def merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Unify class names that mean the same thing across several datasets.

    Args:
        classes_info_of_all_datasets: dict mapping dataset name to a tuple
            ``(classes, aliases)``, where ``classes`` is the dataset's class
            list and ``aliases`` is a list of alias groups (lists of names
            considered equivalent, e.g. ``[['automobile', 'car']]``).

    Returns:
        Tuple ``(res, res_map)``:
        - ``res``: dataset name -> class list rewritten with unified names
          (duplicates removed, first-occurrence order kept).
        - ``res_map``: dataset name -> ``{original name: unified name}`` for
          classes that were actually renamed.
    """
    # Pool the alias groups from every dataset so each dataset's classes can
    # match against all known aliases, not just its own.
    all_aliases = []
    for _, aliases in classes_info_of_all_datasets.values():
        all_aliases += aliases
    semantic_classes_of_all_datasets = []
    for classes, _ in classes_info_of_all_datasets.values():
        semantic_classes_of_all_datasets += build_semantic_class_info(classes, all_aliases)
    # Groups of names referring to the same semantic class (matched data).
    grouped_classes_of_all_datasets = grouping(semantic_classes_of_all_datasets)
    # Pick one representative per group: the most common name; on a tie in
    # frequency, the shortest candidate wins (stable sort keeps Counter
    # insertion order among equal lengths).
    final_grouped_classes_of_all_datasets = []
    for group in grouped_classes_of_all_datasets:
        counts = Counter(group).most_common()
        max_times = counts[0][1]
        candidate_class_names = [name for name, times in counts if times == max_times]
        candidate_class_names.sort(key=len)
        final_grouped_classes_of_all_datasets += [candidate_class_names[0]]
    # One-time O(1) lookup from any original name to its unified name
    # (the original rescanned every group for every class).
    final_name_of = {}
    for grouped_names, final_name in zip(grouped_classes_of_all_datasets,
                                         final_grouped_classes_of_all_datasets):
        for name in grouped_names:
            final_name_of[name] = final_name
    res = {}
    res_map = {dataset: {} for dataset in classes_info_of_all_datasets}
    for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
        final_classes = []
        for c in classes:
            if c in final_name_of:
                unified = final_name_of[c]
                final_classes += [unified]
                if unified != c:
                    res_map[dataset_name][c] = unified
        # De-duplicate while keeping first-occurrence order.
        res[dataset_name] = list(dict.fromkeys(final_classes))
    return res, res_map
if __name__ == '__main__':
    # Demo: CIFAR-10 and STL-10 overlap heavily, but CIFAR-10 says
    # 'automobile' where STL-10 says 'car'; the alias group ties them together.
    cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    cifar10_aliases = [['automobile', 'car']]
    stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
    datasets_info = {
        'CIFAR10': (cifar10_classes, cifar10_aliases),
        'STL10': (stl10_classes, [])
    }
    final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(datasets_info)
    print(final_classes_of_all_datasets, rename_map)