from transformers import AutoTokenizer
# Import for using the Hugging Face datasets library
from datasets import Dataset
import torch  # imported for PyTorch tensor operations (label masking)
# Load the Mistral-7B-v0.3 tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
# Add a [PAD] token if the tokenizer does not define a padding token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
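    # NOTE (assumption): adding a new [PAD] token grows the vocabulary, so the
    # model's embedding matrix must be resized before training, e.g.
    #   model.resize_token_embeddings(len(tokenizer))
    # This file assumes that happens wherever the model itself is loaded.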
def tokenize(examples, MAX_LEN=520):
"""
์ฃผ์ด์ง ์์ ๋ฐ์ดํฐ๋ฅผ ํ ํฌ๋์ด์งํ๊ณ , ๋ชจ๋ธ ์
๋ ฅ ๋ฐ ๋ ์ด๋ธ์ ์์ฑํฉ๋๋ค.
'prompt'/'completion' ๋๋ 'instruction'/'input'/'output' ๊ตฌ์กฐ๋ฅผ ์ง์ํฉ๋๋ค.
ํจ๋ฉ์ ์ ์ฉํ๊ณ ๋ ์ด๋ธ์ ํจ๋ฉ ํ ํฐ์ -100์ผ๋ก ๋ง์คํนํฉ๋๋ค.
๋ ์ด๋ธ ๋ง์คํน์ PyTorch ํ
์ ์ฐ์ฐ์ ์ฌ์ฉํ์ฌ ์ง์ ์ํํฉ๋๋ค.
Args:
examples (dict): ํ ํฌ๋์ด์งํ ๋ฐ์ดํฐ ์์ ๋์
๋๋ฆฌ.
'prompt'์ 'completion' ํค ๋๋
'instruction', 'input', 'output' ํค๋ฅผ ํฌํจํด์ผ ํฉ๋๋ค.
(map ํจ์์์ batched=True๋ก ์ฌ์ฉ ์ ๊ฐ ํค์ ๊ฐ์ ๋ฆฌ์คํธ์ฌ์ผ ํจ)
MAX_LEN (int): ์ต๋ ์ํ์ค ๊ธธ์ด. ์ด ๊ธธ์ด์ ๋ง์ถฐ ํจ๋ฉ ๋ฐ ์๋ผ๋ด๊ธฐ๊ฐ ์ํ๋ฉ๋๋ค.
Returns:
dict: 'input_ids', 'attention_mask', 'labels'๋ฅผ ํฌํจํ๋ ๋์
๋๋ฆฌ.
๊ฐ ๊ฐ์ PyTorch ํ
์ ํํ์
๋๋ค.
"""
    EOS_TOKEN = tokenizer.eos_token  # end-of-sequence token
    prompts = []
    completions = []
    # Prepare the prompts and completions according to the data schema
    # Case 1: prompt/completion schema
    if "prompt" in examples and "completion" in examples:
        # The input may be a single string, or a list of strings when the
        # function is used with batched=True in Dataset.map
        _prompts = examples["prompt"] if isinstance(examples["prompt"], list) else [examples["prompt"]]
        _completions = examples["completion"] if isinstance(examples["completion"], list) else [examples["completion"]]
        # Append the EOS token to the end of each completion
        completions = [c + EOS_TOKEN for c in _completions]
        prompts = _prompts
    # Case 2: instruction/input/output schema (Alpaca style)
    elif "instruction" in examples and "output" in examples:
        _instructions = examples["instruction"] if isinstance(examples["instruction"], list) else [examples["instruction"]]
        # If there is no 'input' key, fall back to a list of empty strings
        _inputs = examples["input"] if "input" in examples else ["" for _ in _instructions]
        # Normalize a single-string 'input' to a list as well, so zip() below
        # pairs whole strings instead of iterating over characters
        _inputs = _inputs if isinstance(_inputs, list) else [_inputs]
        _outputs = examples["output"] if isinstance(examples["output"], list) else [examples["output"]]
        for inst, inp in zip(_instructions, _inputs):
            prompt = f"### Instruction:\n{inst}\n\n"
            if inp and inp.strip():  # only add the input section when it is non-empty
                prompt += f"### Input:\n{inp}\n\n"
            prompt += "### Response:\n"  # marks the start of the response
            prompts.append(prompt)
        # Append the EOS token to the end of each output
        completions = [o + EOS_TOKEN for o in _outputs]
    else:
        # Raise an error for unsupported key schemas
        raise ValueError(f"Unsupported key structure: {examples.keys()}")
    # 1. Tokenize the prompts (model inputs)
    # return_tensors='pt' returns PyTorch tensors
    model_inputs = tokenizer(
        prompts,
        truncation=True,       # truncate anything longer than MAX_LEN
        max_length=MAX_LEN,    # maximum sequence length
        padding='max_length',  # pad every sequence to MAX_LEN
        return_tensors='pt'    # return PyTorch tensors
    )
    # 2. Tokenize the completions (used to build the labels)
    # return_tensors='pt' returns PyTorch tensors
    label_outputs = tokenizer(
        completions,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt'
    )
    # 🔥 Mask pad_token_id as -100 in the labels
    # The masking is done directly with PyTorch tensor operations
    labels = label_outputs["input_ids"].clone()      # copy the original tensor before modifying
    labels[labels == tokenizer.pad_token_id] = -100  # replace padding token IDs with -100
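    # -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so these
    # positions contribute nothing to the loss during training.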
    # Sanity check: after padding and masking, input_ids and labels must have
    # the same length (debugging aid).
    # When using batched=True in Dataset.map, every sequence must have the same length.
    # input_ids and labels are now both tensors, so the length is checked via .shape[1]
    for i in range(model_inputs["input_ids"].shape[0]):  # iterate over the batch
        assert model_inputs["input_ids"].shape[1] == MAX_LEN, \
            f"Input IDs length mismatch at index {i}: Expected {MAX_LEN}, Got {model_inputs['input_ids'].shape[1]}"
        assert labels.shape[1] == MAX_LEN, \
            f"Labels length mismatch at index {i}: Expected {MAX_LEN}, Got {labels.shape[1]}"
        assert model_inputs["input_ids"].shape[1] == labels.shape[1], \
            f"Input IDs and Labels length mismatch at index {i}: Input IDs {model_inputs['input_ids'].shape[1]}, Labels {labels.shape[1]}"
    # Finally, add the labels to the model_inputs dictionary
    model_inputs["labels"] = labels
    return model_inputs
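
# --- Usage sketch (not part of the original pipeline; column names are illustrative) ---
# A minimal example of applying tokenize() to a toy prompt/completion dataset
# via Dataset.map with batched=True, which passes column lists into tokenize().
if __name__ == "__main__":
    raw = Dataset.from_dict({
        "prompt": ["Translate to French: Hello", "What is 2 + 2?"],
        "completion": [" Bonjour", " 4"],
    })
    # remove_columns drops the raw text columns so only the tensor-ready
    # fields (input_ids, attention_mask, labels) remain
    tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
    print(tokenized)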