import contextlib
import itertools
import json
import multiprocessing
import os
import time

from compute_perp import check_equal

# Separator placed between the two stringified operands of a cache key.
_KEY_SEP = "<##>"


def _pair_key(a, b):
    """Return the cache key for the ordered pair (a, b): ``str(a) + "<##>" + str(b)``."""
    return str(a) + _KEY_SEP + str(b)


def _memo_check(cache_dict, a, b):
    """Store ``check_equal(a, b)`` in *cache_dict* under both key orientations.

    Does nothing if the pair is already cached. Both orientations are always
    inserted together (forward key first), so the presence of one key implies
    the presence of the other and a single membership test suffices.
    """
    key = _pair_key(a, b)
    if key in cache_dict:
        return
    val = check_equal(a, b)
    cache_dict[key] = val
    cache_dict[_pair_key(b, a)] = val


def solve(predict, answer):
    """Compute ``check_equal`` results for one item's predictions.

    Args:
        predict: Sequence of predictions for a single item.
        answer: The reference answer for that item.

    Returns:
        A dict mapping ``"str(x)<##>str(y)"`` keys to ``check_equal(x, y)``.
        It covers every (prediction, answer) pair and every pair of
        predictions, including each prediction paired with itself. Each pair
        is stored under both orientations, so lookups may use either order.
        NOTE(review): values are whatever ``check_equal`` returns, presumably
        a bool; they must be JSON-serializable for ``cache()`` to persist them.
    """
    cache_dict = {}
    for p in predict:
        _memo_check(cache_dict, p, answer)
    # Pairs (i, j) with j < i are always already cached by the earlier (j, i)
    # visit, so only i <= j is enumerated; order matches the original
    # i-outer / j-inner scan.
    for p, q in itertools.combinations_with_replacement(predict, 2):
        _memo_check(cache_dict, p, q)
    return cache_dict


def _atomic_write_json(path, obj):
    """Write *obj* as JSON to *path* via a temp file and ``os.replace``.

    ``cache()`` skips any existing file at *path*. A partially written file,
    left by a crash or a non-serializable value, would otherwise be reused
    forever. Replacing a fully written sibling temp file ensures *path* only
    ever appears with complete contents. The temp file is removed on failure.
    """
    tmp_path = f"{path}.tmp{os.getpid()}"
    try:
        with open(tmp_path, "w", encoding="utf-8") as fw:
            json.dump(obj, fw)
        os.replace(tmp_path, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def cache(data, cache_path):
    """Build and persist the equality cache for a dataset.

    Args:
        data: Mapping with ``"predict"`` and ``"answer"`` entries.
            ``data["predict"][i]`` is a sequence of predictions for
            ``data["answer"][i]``. ``"answer"`` must have at least as many
            entries as ``"predict"`` (otherwise ``IndexError``).
        cache_path: Destination JSON file. If it already exists, nothing is
            computed and the function returns immediately.

    Items are processed in parallel with a ``multiprocessing.Pool``. Their
    per-item dicts are merged in item order, so a later item overwrites an
    identical key from an earlier one (presumably with the same value,
    assuming ``check_equal`` is deterministic).
    """
    if os.path.exists(cache_path):
        print(f"Cache file {cache_path} exists, skip!")
        return
    start_time = time.perf_counter()
    predicts = data["predict"]
    answers = data["answer"]
    # Index answers (rather than zip) so a short "answer" list still raises
    # IndexError instead of silently dropping items.
    tasks = [(p, answers[i]) for i, p in enumerate(predicts)]
    cache_dict = {}
    if tasks:
        with multiprocessing.Pool() as pool:
            for result in pool.starmap(solve, tasks):
                cache_dict.update(result)
    _atomic_write_json(cache_path, cache_dict)
    print(f"Cache file {cache_path} built in {time.perf_counter() - start_time:.2f}S")