Spaces:
Running
Running
| """Loggers.""" | |
| import os | |
| from os.path import dirname, realpath, abspath | |
| from tqdm.auto import tqdm | |
| import numpy as np | |
# Absolute path of this module file.
curr_filepath = abspath(__file__)
# Directory three levels above this file; used as the repository root
# (assumes this module lives two directories below the root -- TODO confirm).
repo_path = dirname(dirname(dirname(curr_filepath)))
# Alternative that resolves symlinks first:
# repo_path = dirname(dirname(dirname(realpath(__file__))))
def tqdm_iterator(items, desc=None, bar_format=None, **kwargs):
    """Wrap ``items`` in a tqdm progress bar and return the bar.

    Args:
        items: Iterable to wrap.
        desc: Optional description forwarded to tqdm.
        bar_format: Optional tqdm bar format string, for example
            ``'{l_bar}{bar:10}{r_bar}{bar:-10b}'``. It is forwarded only when
            it is not ``None``, so the default call behaves as before.
        **kwargs: Further keyword arguments passed to tqdm unchanged.

    Returns:
        The tqdm object wrapping ``items``.
    """
    # Previously ``bar_format`` was accepted but silently dropped; forward it
    # when the caller actually supplied one.
    if bar_format is not None:
        kwargs["bar_format"] = bar_format

    # The instance registry is cleared before and after construction --
    # presumably so bars left over from interrupted loops do not offset the
    # new bar (TODO confirm). This relies on a private tqdm attribute.
    tqdm._instances.clear()
    iterator = tqdm(items, desc=desc, **kwargs)
    tqdm._instances.clear()
    return iterator
def print_retrieval_metrics_for_csv(metrics, scale=100):
    """Print R1/R5/R10 (and MR, if present) as one comma-separated line.

    Args:
        metrics: Mapping with keys ``"R1"``, ``"R5"`` and ``"R10"``; an
            optional ``"MR"`` entry is appended unscaled and unrounded.
        scale: Factor applied to the recall values before rounding to three
            decimals.
    """
    values = [np.round(scale * metrics[key], 3) for key in ("R1", "R5", "R10")]
    if "MR" in metrics:
        values.append(metrics["MR"])

    csv_row = ",".join(map(str, values))

    # Blank lines above and below make the row easy to spot and copy.
    print()
    print("Final metrics: ", csv_row)
    print()
def print_update(update, fillchar=":", color="yellow", pos="center"):
    """Print ``update`` as a colored banner padded with ``fillchar``.

    Args:
        update: Text to print.
        fillchar: Single character used to pad the line.
        color: Color name handed to ``termcolor.colored``.
        pos: Where the text sits in the banner: ``"center"``, ``"left"`` or
            ``"right"``.

    Raises:
        ValueError: If ``pos`` is not one of the supported values.
    """
    # Validate first so a bad argument fails before any import or terminal
    # query happens.
    if pos not in ("center", "left", "right"):
        raise ValueError("pos must be one of 'center', 'left', 'right'")

    from termcolor import colored

    # Leave a two-column margin; fall back to a fixed width when stdout is
    # not attached to a terminal (e.g. piped output). Only the errors a
    # failed size query can raise are caught, so KeyboardInterrupt and
    # SystemExit still propagate.
    try:
        terminal_width = os.get_terminal_size().columns - 2
    except (AttributeError, OSError, ValueError):
        terminal_width = 98

    if pos == "center":
        # One space on each side of the text, then fill both sides.
        update = update.center(len(update) + 2, " ")
        update = update.center(terminal_width, fillchar)
    elif pos == "left":
        # NOTE(review): fills first and then appends two spaces after the
        # fill, unlike the center case -- confirm this ordering is intended.
        update = update.ljust(terminal_width, fillchar)
        update = update.ljust(len(update) + 2, " ")
    else:
        # pos == "right"; same ordering caveat as the left case.
        update = update.rjust(terminal_width, fillchar)
        update = update.rjust(len(update) + 2, " ")

    print(colored(update, color))
def json_print(data, indent=4):
    """Print ``data`` as JSON, pretty-printed with ``indent`` spaces per level."""
    from json import dumps

    rendered = dumps(data, indent=indent)
    print(rendered)
def get_terminal_width():
    """Return the terminal width in columns, as reported by ``shutil``."""
    from shutil import get_terminal_size

    # The result unpacks as (columns, lines); only the width is needed.
    columns, _rows = get_terminal_size()
    return columns
# When executed as a script, show the computed repository path.
if __name__ == "__main__":
    print("Repo path:", repo_path)