# Tokenization utilities for fine-tuning Mistral-7B-v0.3 on
# prompt/completion or Alpaca-style instruction datasets.
from transformers import AutoTokenizer
# Import for using the Hugging Face datasets library
from datasets import Dataset
import torch  # imported for PyTorch tensor operations

# Load the Mistral-7B-v0.3 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
# Mistral's tokenizer ships without a padding token; add [PAD] so that
# padding='max_length' tokenization below works.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
def tokenize(examples, MAX_LEN=520):
    """Tokenize examples into model inputs and loss labels.

    Supports two dataset schemas:

    * ``'prompt'`` / ``'completion'``
    * Alpaca-style ``'instruction'`` / ``'input'`` (optional) / ``'output'``

    Prompts are tokenized as the model input; completions/outputs (with an
    EOS token appended) are tokenized as the labels. Both are padded and
    truncated to ``MAX_LEN``, and padding positions in the labels are masked
    to -100 so the loss function ignores them.

    Args:
        examples (dict): Examples to tokenize. Values may be single strings
            or lists of strings (as produced by ``Dataset.map(...,
            batched=True)``).
        MAX_LEN (int): Maximum sequence length used for padding/truncation.

    Returns:
        dict: ``'input_ids'``, ``'attention_mask'`` and ``'labels'``, each a
        PyTorch tensor of shape ``(batch, MAX_LEN)``.

    Raises:
        ValueError: If ``examples`` matches neither supported schema.
    """
    EOS_TOKEN = tokenizer.eos_token  # appended so the model learns to stop

    def _as_list(value):
        # Accept both a single string and a batched list of strings.
        return value if isinstance(value, list) else [value]

    # Case 1: prompt/completion schema.
    if "prompt" in examples and "completion" in examples:
        prompts = _as_list(examples["prompt"])
        # Append the EOS token to every completion.
        completions = [c + EOS_TOKEN for c in _as_list(examples["completion"])]
    # Case 2: instruction/input/output schema (Alpaca style).
    elif "instruction" in examples and "output" in examples:
        _instructions = _as_list(examples["instruction"])
        # 'input' is optional; default to empty strings. NOTE: the original
        # code did not list-normalize a scalar 'input' string, which made
        # zip() iterate over its characters for single (non-batched) examples.
        _inputs = _as_list(examples["input"]) if "input" in examples else ["" for _ in _instructions]
        prompts = []
        for inst, inp in zip(_instructions, _inputs):
            prompt = f"### Instruction:\n{inst}\n\n"
            if inp and inp.strip():  # include the Input section only when non-empty
                prompt += f"### Input:\n{inp}\n\n"
            prompt += "### Response:\n"  # response start marker
            prompts.append(prompt)
        # Append the EOS token to every output.
        completions = [o + EOS_TOKEN for o in _as_list(examples["output"])]
    else:
        # Neither supported key structure is present.
        raise ValueError(f"Unsupported key structure: {examples.keys()}")

    # 1. Tokenize the prompts (model inputs) as PyTorch tensors.
    model_inputs = tokenizer(
        prompts,
        truncation=True,          # truncate sequences longer than MAX_LEN
        max_length=MAX_LEN,       # maximum sequence length
        padding='max_length',     # pad every sequence to MAX_LEN
        return_tensors='pt',      # return PyTorch tensors
    )
    # 2. Tokenize the completions (for label construction).
    label_outputs = tokenizer(
        completions,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt',
    )

    # Mask pad_token_id positions in the labels with -100 so that the
    # cross-entropy loss ignores padding.
    labels = label_outputs["input_ids"].clone()  # copy before modifying
    labels[labels == tokenizer.pad_token_id] = -100

    # Sanity check: padding/truncation must yield uniform MAX_LEN sequences.
    # (Checked once — the original repeated identical asserts per batch row,
    # even though the asserted shapes do not vary with the row index.)
    assert model_inputs["input_ids"].shape[1] == MAX_LEN, \
        f"Input IDs length mismatch: Expected {MAX_LEN}, Got {model_inputs['input_ids'].shape[1]}"
    assert labels.shape[1] == MAX_LEN, \
        f"Labels length mismatch: Expected {MAX_LEN}, Got {labels.shape[1]}"

    # Attach the labels to the model inputs and return.
    model_inputs["labels"] = labels
    return model_inputs