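"""Upload locally generated code-problem JSON files to the Hugging Face Hub.

Every ``*.json`` file in ``--input_dir`` is read into a ``datasets.Dataset``
record and pushed to the private dataset repo given by ``--dataset_name``
under the chosen ``--split``.

Example invocation (script and directory names are illustrative):

    python upload_code_problems.py --input_dir data/code_problems --split warmup
"""
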
import argparse
import fnmatch
import json
import os

from datasets import Dataset

from src.envs import CODE_PROBLEMS_REPO
from src.logger import get_logger

logger = get_logger(__name__)


def get_args() -> argparse.Namespace:
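    """Parse command-line arguments."""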
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_dir", type=str, help="Dir with .json files", required=True)
    parser.add_argument("--dataset_name", type=str, default=f"{CODE_PROBLEMS_REPO}")
    parser.add_argument("--split", type=str, choices=["hard", "warmup"], default="hard")
    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
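    """Build a Dataset from the problem JSON files in ``args.input_dir`` and push it to the Hub."""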
    logger.info("Reading problem files from %s", args.input_dir)
    input_files = fnmatch.filter(os.listdir(args.input_dir), "*.json")
    if len(input_files) == 0:
        raise ValueError(f"No .json files in input dir {args.input_dir}")
    logger.info("Found %d code problems in %s", len(input_files), args.input_dir)

    def ds_generator():
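        """Yield one record per problem JSON file, keyed by the problem's ``id`` field."""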
        for fname in sorted(input_files):
            formula_name = os.path.splitext(fname)[0]
            cp_path = os.path.join(args.input_dir, fname)
            with open(cp_path, "r", encoding="utf-8") as f:
                code_problem = json.load(f)
            logger.info("Read code problem for formula %s from %s", formula_name, cp_path)
            yield dict(id=code_problem["id"], code_problem=code_problem)

    ds = Dataset.from_generator(ds_generator)
    logger.info("Created dataset")

    ds.push_to_hub(args.dataset_name, split=args.split, private=True)
    logger.info("Saved dataset to repo %s", args.dataset_name)


if __name__ == "__main__":
    main(get_args())