from transformers import AutoTokenizer
# Import for using the Hugging Face datasets library
from datasets import Dataset
import torch  # imported for PyTorch tensor operations

# Load the Mistral-7B-v0.3 tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")

# Add a [PAD] token if the tokenizer does not define a padding token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
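    # Note: if a brand-new [PAD] token is added here, the model's embedding
    # matrix must be resized to match before training, e.g.
    # model.resize_token_embeddings(len(tokenizer))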

def tokenize(examples, MAX_LEN=520):
    """
    Tokenizes the given examples and builds the model inputs and labels.

    Supports both the 'prompt'/'completion' and the 'instruction'/'input'/'output'
    schemas. Sequences are padded, and padding tokens in the labels are masked
    with -100. Label masking is performed directly with PyTorch tensor operations.

    Args:
        examples (dict): Dictionary of examples to tokenize. Must contain either
                         the 'prompt' and 'completion' keys or the 'instruction',
                         'input', 'output' keys. (When used with map(batched=True),
                         each value must be a list.)
        MAX_LEN (int): Maximum sequence length; sequences are padded and
                       truncated to this length.

    Returns:
        dict: Dictionary containing 'input_ids', 'attention_mask', and 'labels',
              each as a PyTorch tensor.
    """
    EOS_TOKEN = tokenizer.eos_token  # end-of-sequence token

    prompts = []
    completions = []

    # Prepare prompts and completions according to the data schema
    # Case 1: prompt/completion schema
    if "prompt" in examples and "completion" in examples:
        # The values may be single strings, or lists when map(batched=True) is used
        _prompts = examples["prompt"] if isinstance(examples["prompt"], list) else [examples["prompt"]]
        _completions = examples["completion"] if isinstance(examples["completion"], list) else [examples["completion"]]

        # Append the EOS token to each completion
        completions = [c + EOS_TOKEN for c in _completions]
        prompts = _prompts

    # Case 2: instruction/input/output schema (Alpaca style)
    elif "instruction" in examples and "output" in examples:
        _instructions = examples["instruction"] if isinstance(examples["instruction"], list) else [examples["instruction"]]
        # Treat a missing 'input' key as a list of empty strings; normalize a
        # single string to a one-element list so zip() pairs items, not characters
        _inputs = examples["input"] if "input" in examples else ["" for _ in _instructions]
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        _outputs = examples["output"] if isinstance(examples["output"], list) else [examples["output"]]

        for inst, inp in zip(_instructions, _inputs):
            prompt = f"### Instruction:\n{inst}\n\n"
            if inp and inp.strip():  # add the input section only when it is non-empty
                prompt += f"### Input:\n{inp}\n\n"
            prompt += "### Response:\n"  # start of the response
            prompts.append(prompt)

        # Append the EOS token to each output
        completions = [o + EOS_TOKEN for o in _outputs]

    else:
        # Raise an error for an unsupported key schema
        raise ValueError(f"Unsupported key structure: {examples.keys()}")

    # 1. Tokenize the prompts (model inputs)
    # return_tensors='pt' returns PyTorch tensors
    model_inputs = tokenizer(
        prompts,
        truncation=True,        # truncate anything longer than MAX_LEN
        max_length=MAX_LEN,     # set the maximum length
        padding='max_length',   # pad every sequence to MAX_LEN
        return_tensors='pt'     # return PyTorch tensors
    )
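
    # Design note (not from the original code): padding='max_length' keeps every
    # batch the same shape, which is simple but memory-hungry; dynamic padding
    # via a data collator (e.g. transformers.DataCollatorForSeq2Seq) is a common
    # alternative when sequence lengths vary a lot.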

    # 2. Tokenize the completions (to build the labels)
    # return_tensors='pt' returns PyTorch tensors
    label_outputs = tokenizer(
        completions,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt'
    )

    # 🔥 Mask pad_token_id positions in the labels with -100
    # Masking is performed directly with PyTorch tensor operations
    labels = label_outputs["input_ids"].clone()  # copy the original tensor before modifying it
    labels[labels == tokenizer.pad_token_id] = -100  # replace padding token IDs with -100
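    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so the
    # masked positions contribute nothing to the training loss.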

    # Sanity check (for debugging): after padding and masking, input_ids and
    # labels must both have length MAX_LEN. When using map(batched=True), every
    # sequence must have the same length. Both are now tensors, so the batch-wide
    # length can be checked once via .shape[1] instead of looping per example.
    assert model_inputs["input_ids"].shape[1] == MAX_LEN, \
        f"Input IDs length mismatch: expected {MAX_LEN}, got {model_inputs['input_ids'].shape[1]}"
    assert labels.shape[1] == MAX_LEN, \
        f"Labels length mismatch: expected {MAX_LEN}, got {labels.shape[1]}"
    assert model_inputs["input_ids"].shape[1] == labels.shape[1], \
        f"Input IDs and labels length mismatch: input_ids {model_inputs['input_ids'].shape[1]}, labels {labels.shape[1]}"

    # Add the final labels to the model_inputs dictionary
    model_inputs["labels"] = labels
    return model_inputs
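

# Usage sketch (a minimal example; the dataset contents below are hypothetical,
# any data with one of the supported key schemas works the same way). Note that
# Dataset.map() converts the returned PyTorch tensors back to Python lists when
# writing its Arrow cache.
if __name__ == "__main__":
    raw = Dataset.from_dict({
        "instruction": ["Translate to French.", "Summarize the text."],
        "input": ["Hello, world!", ""],
        "output": ["Bonjour, le monde !", "A short summary."],
    })
    tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
    print(tokenized)  # columns: input_ids, attention_mask, labels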