|
from compute_perp import check_equal |
|
import multiprocessing, json, os, time |
|
|
|
def solve(predict, answer):
    """Compare one sample's predictions against its answer and each other.

    ``check_equal`` is evaluated first for every (prediction, answer) pair and
    then for every unordered pair of predictions, including each prediction
    paired with itself. A pair is not re-evaluated when its string key is
    already cached in either operand order.

    Args:
        predict: Indexable sequence of predictions for a single sample.
        answer: The reference answer for that sample.

    Returns:
        Dict keyed by ``str(a) + "<##>" + str(b)``. Every computed result is
        stored under both orderings of its operands, i.e. ``check_equal`` is
        treated as symmetric.
    """
    memo = {}
    separator = "<##>"

    def remember(left, right):
        # Record check_equal(left, right) under both orderings unless either
        # ordering is already present.
        forward = separator.join((str(left), str(right)))
        backward = separator.join((str(right), str(left)))
        if forward not in memo and backward not in memo:
            outcome = check_equal(left, right)
            memo[forward] = outcome
            memo[backward] = outcome

    size = len(predict)

    for idx in range(size):
        remember(predict[idx], answer)

    # After (col, row) has been visited, one of its keys is always cached, so
    # (row, col) with col < row is a guaranteed hit; walk the upper triangle.
    for row in range(size):
        for col in range(row, size):
            remember(predict[row], predict[col])

    return memo
|
|
|
def cache(data, cache_path):
    """Build the pairwise ``check_equal`` cache for a dataset and save it as JSON.

    Each sample ``(data["predict"][i], data["answer"][i])`` is processed by
    ``solve`` in a ``multiprocessing.Pool``. The per-sample dicts are merged,
    with later samples overwriting equal keys, and dumped to ``cache_path``.

    If ``cache_path`` already exists, nothing is computed or written.

    Args:
        data: Mapping with ``"predict"`` and ``"answer"`` entries. Each
            ``predict`` item is passed to ``solve`` as its sequence of
            predictions. ``answer`` must have at least as many items as
            ``predict``; otherwise ``IndexError`` is raised.
        cache_path: Destination path of the JSON cache file.
    """
    if os.path.exists(cache_path):
        print(f"Cache file {cache_path} exists, skip!")
        return

    start_time = time.time()

    predicts = data["predict"]
    answers = data["answer"]

    # Index into answers (rather than zip) so a short answers list still
    # raises IndexError instead of silently dropping samples.
    jobs = [(predicts[i], answers[i]) for i in range(len(predicts))]

    cache_dict = {}

    # Only spin up worker processes when there is actual work to do.
    if jobs:
        with multiprocessing.Pool() as pool:
            results = pool.starmap(solve, jobs)
        for result in results:
            cache_dict.update(result)

    # Existence of cache_path is treated as "cache is complete" above, so
    # never let a partially written file appear there. Write to a temp file
    # next to it, then atomically rename it into place.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as fw:
            json.dump(cache_dict, fw)
        os.replace(tmp_path, cache_path)
    finally:
        # After a successful replace the temp file is gone. After a failure,
        # remove the partial temp file so nothing stale is left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Cache file {cache_path} built in {time.time() - start_time:.2f}S")