File size: 1,726 Bytes
c887522
ed67886
c887522
 
 
416ebf1
272ff8c
 
 
 
8cfcd49
c887522
416ebf1
 
34a2915
 
 
 
416ebf1
c887522
 
 
8cfcd49
 
 
c887522
8cfcd49
416ebf1
ed67886
34a2915
8cfcd49
34a2915
 
8cfcd49
34a2915
 
 
8cfcd49
 
80fb2c0
8cfcd49
 
 
 
 
34a2915
 
 
 
 
 
8cfcd49
80fb2c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import functools
import time

from datasets import load_dataset

from src.envs import CODE_PROBLEMS_REPO, RESULTS_REPO, SUBMISSIONS_REPO, TOKEN
from src.logger import get_logger

# Module-level logger, obtained from the project's logging helper and named after this module.
logger = get_logger(__name__)


class F1Data:
    """In-memory index of the code-problems dataset for a single split.

    Constructing an instance immediately loads the dataset named by
    ``cp_ds_name`` through ``datasets.load_dataset`` and builds a mapping from
    problem id to code problem. The submissions and results dataset names are
    only stored on the instance; nothing in this class reads them.
    """

    def __init__(
        self,
        cp_ds_name: str,  # Name of the dataset. Fixed.
        sub_ds_name: str,  # Name of subdataset. Fixed.
        res_ds_name: str,  # Name of results repository. Fixed.
        split: str = "hard",  # Split is either 'hard' or 'easy'.
    ):
        """Store the dataset names and load the code problems for ``split``.

        Args:
            cp_ds_name: Name of the code-problems dataset to load.
            sub_ds_name: Name of the submissions dataset (stored only).
            res_ds_name: Name of the results dataset (stored only).
            split: Dataset split to load; defaults to ``"hard"``.
        """
        self.cp_dataset_name = cp_ds_name
        self.submissions_dataset_name = sub_ds_name
        self.results_dataset_name = res_ds_name
        self.split = split
        # Populated by _initialize(): id string -> code problem.
        self.code_problems = None
        self._initialize()

    def _initialize(self):
        """Load the code-problems split and index its rows by ``id``.

        Each row is expected to expose ``"id"`` and ``"code_problem"`` keys.
        If several rows share an id, the last one wins.
        """
        # Never write the credential itself to the log; only record whether
        # one is configured, which is all that is needed for debugging.
        token_state = "<set>" if TOKEN else "<not set>"
        logger.info(f"Initialize F1Data TOKEN={token_state}")
        start_time = time.monotonic()
        cp_ds = load_dataset(
            self.cp_dataset_name,
            split=self.split,
            token=TOKEN,
        )
        logger.info(f"Loaded code-problems dataset from {self.cp_dataset_name} in {time.monotonic() - start_time} sec")
        self.code_problems = {r["id"]: r["code_problem"] for r in cp_ds}  # id string -> code problem.
        logger.info(f"Loaded {len(self.code_problems)} code problems")

    @functools.cached_property
    def code_problem_ids(self) -> set[str]:
        """Return the set of loaded problem ids.

        The set is computed on first access and cached on the instance, so
        later changes to ``code_problems`` are not reflected.
        """
        return set(self.code_problems)


if __name__ == "__main__":
    # Manual smoke check: load one split and report how many problems it has.
    split = "hard"
    repo_names = {
        "cp_ds_name": CODE_PROBLEMS_REPO,
        "sub_ds_name": SUBMISSIONS_REPO,
        "res_ds_name": RESULTS_REPO,
    }
    f1_data = F1Data(split=split, **repo_names)

    print(f"Found {len(f1_data.code_problem_ids)} code problems in {split} split of {f1_data.cp_dataset_name}")