import argparse
import fnmatch
import json
import os
from typing import Iterator

from datasets import Dataset

from src.envs import CODE_PROBLEMS_REPO
from src.logger import get_logger

logger = get_logger(__name__)


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--input_dir",
        type=str,
        help="Dir with .json files",
        required=True,
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=f"{CODE_PROBLEMS_REPO}",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["hard", "warmup"],
        default="hard",
    )
    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    logger.info("Reading problem files from %s", args.input_dir)
    input_files = fnmatch.filter(os.listdir(args.input_dir), "*.json")
    if len(input_files) == 0:
        raise ValueError(f"No .json files in input dir {args.input_dir}")
    logger.info("Found %d code problems in %s", len(input_files), args.input_dir)

    def ds_generator() -> Iterator[dict]:
        for fname in sorted(input_files):
            formula_name = os.path.splitext(fname)[0]
            cp_path = os.path.join(args.input_dir, fname)
            with open(cp_path, "r", encoding="utf-8") as f:
                code_problem = json.load(f)
            logger.info("Read code problem for formula %s from %s", formula_name, cp_path)
            yield dict(id=code_problem["id"], code_problem=code_problem)

    ds: Dataset = Dataset.from_generator(ds_generator)  # type: ignore
    logger.info("Created dataset")
    ds.push_to_hub(args.dataset_name, split=args.split, private=True)
    logger.info("Saved dataset to repo %s", args.dataset_name)


if __name__ == "__main__":
    main(get_args())