File size: 2,192 Bytes
668bf5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def inspect_tokenized_dataset(dataset, num_samples=3):
    """Print the first few samples of a tokenized dataset and sanity-check ``labels``.

    For each inspected sample this prints ``input_ids``, ``attention_mask``
    and ``labels``, then reports whether ``labels`` is a list made up only of
    ints. Afterwards it prints the lengths of ``input_ids`` and ``labels``
    for the inspected samples (``None`` where a field is missing).

    Args:
        dataset: Sized collection supporting integer indexing whose items
            provide ``.get(key)`` (e.g. a list of dicts).
        num_samples: Maximum number of leading samples to inspect.
    """
    # NOTE(review): the non-ASCII text in the messages below looks mis-encoded;
    # it is kept byte-for-byte here -- confirm the file's encoding.
    print(f"\n๐Ÿ” Inspecting first {num_samples} samples in tokenized dataset...")

    input_lens = []
    label_lens = []

    # Rows are always fetched by integer index. Slicing (dataset[:k]) is not
    # used because column-oriented datasets may return a dict of columns for
    # a slice instead of a sequence of rows.
    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        input_ids = sample.get("input_ids")
        labels = sample.get("labels")

        print(f"\n=== Sample {i} ===")
        print("input_ids:", input_ids)
        print("attention_mask:", sample.get("attention_mask"))
        print("labels:", labels)

        # Type check: labels must be a list whose elements are all ints.
        if not isinstance(labels, list):
            print(f"โŒ Sample {i}: labels๊ฐ€ list๊ฐ€ ์•„๋‹˜ โ†’ {type(labels)}")
        elif not all(isinstance(x, int) for x in labels):
            print(f"โŒ Sample {i}: labels ๋‚ด๋ถ€์— int๊ฐ€ ์•„๋‹Œ ๊ฐ’ ์กด์žฌ")
        else:
            print(f"โœ… Sample {i}: labels ๊ตฌ์กฐ ์ •์ƒ")

        # A missing field is recorded as None rather than aborting the report.
        input_lens.append(None if input_ids is None else len(input_ids))
        label_lens.append(None if labels is None else len(labels))

    # Length comparison across the inspected samples.
    print("\n๐Ÿ“ input_ids ๊ธธ์ด:", input_lens)
    print("๐Ÿ“ labels ๊ธธ์ด:   ", label_lens)

def print_label_lengths(dataset):
    """Print min/max/mean of the ``labels`` lengths and the first ten lengths.

    Args:
        dataset: Iterable of samples that can be indexed with the
            ``"labels"`` key; each ``labels`` value must support ``len()``.
    """
    # NOTE(review): the non-ASCII text in the messages below looks mis-encoded;
    # it is kept byte-for-byte here -- confirm the file's encoding.
    lengths = [len(sample["labels"]) for sample in dataset]

    if not lengths:
        # min()/max() raise on an empty sequence and the mean would divide
        # by zero, so report the empty case instead of crashing.
        print("[๋””๋ฒ„๊น…] labels ๊ธธ์ด - (empty dataset)")
        return

    print(f"[๋””๋ฒ„๊น…] labels ๊ธธ์ด - min: {min(lengths)}, max: {max(lengths)}, mean: {sum(lengths) / len(lengths):.2f}")
    # Example: show the actual lengths of the first 10 samples directly.
    print("[๋””๋ฒ„๊น…] ์ƒ˜ํ”Œ๋ณ„ labels ๊ธธ์ด (์ƒ์œ„ 10๊ฐœ):", lengths[:10])

def print_field_lengths(dataset, n=10, stage=""):
    """Report length statistics for the tokenized fields of a dataset.

    For each of ``input_ids``, ``attention_mask`` and ``labels`` this prints
    the minimum, maximum and mean length over all samples, followed by the
    lengths of the first ``n`` samples. Any exception raised while handling
    a field is caught and printed for that field, and processing continues
    with the next one.

    Args:
        dataset: Iterable of samples indexable by field name.
        n: How many leading per-sample lengths to show for each field.
        stage: Label inserted into the header line.
    """
    print(f"\n[๋””๋ฒ„๊น…][{stage}] ๊ธธ์ด ํ†ต๊ณ„ ================================")

    field_names = ("input_ids", "attention_mask", "labels")
    for key in field_names:
        try:
            sizes = list(map(len, (row[key] for row in dataset)))
            smallest = min(sizes)
            largest = max(sizes)
            average = sum(sizes) / len(sizes)
        except Exception as e:
            # Best-effort: a missing or malformed field is reported, not raised.
            print(f"{key}: (์กด์žฌํ•˜์ง€ ์•Š๊ฑฐ๋‚˜ ์—๋Ÿฌ) {e}")
            continue

        try:
            print(f"{key} โ†’ min: {smallest}, max: {largest}, mean: {average:.2f}")
            print(f"{key} ์ƒ˜ํ”Œ (์ƒ์œ„ {n}๊ฐœ):", sizes[:n])
        except Exception as e:
            print(f"{key}: (์กด์žฌํ•˜์ง€ ์•Š๊ฑฐ๋‚˜ ์—๋Ÿฌ) {e}")

    print("====================================================\n")