# mistralai-distributed / tokenizer.py
from transformers import AutoTokenizer
# Import for the Hugging Face datasets library
from datasets import Dataset
import torch  # imported for PyTorch tensor operations

# Load the Mistral-7B-v0.3 tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")

# Add a [PAD] token if the tokenizer has no padding token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
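
# NOTE: adding a [PAD] token grows the vocabulary, so any model fine-tuned
# with this tokenizer must have its embedding matrix resized to match.
# A minimal sketch, assuming a `model` object loaded elsewhere (e.g. in the
# training script):
#
#     model.resize_token_embeddings(len(tokenizer))
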
def tokenize(examples, MAX_LEN=520):
    """
    Tokenizes the given examples and builds model inputs and labels.
    Supports both the 'prompt'/'completion' and the
    'instruction'/'input'/'output' schemas.
    Pads every sequence and masks padding tokens in the labels with -100.
    The label masking is performed directly with PyTorch tensor operations.

    Args:
        examples (dict): Dictionary of examples to tokenize. Must contain
            either the 'prompt' and 'completion' keys, or the
            'instruction', 'input', 'output' keys.
            (When used with map(batched=True), each value must be a list.)
        MAX_LEN (int): Maximum sequence length; sequences are padded and
            truncated to this length.

    Returns:
        dict: Dictionary containing 'input_ids', 'attention_mask', and
            'labels', each as a PyTorch tensor.
    """
    EOS_TOKEN = tokenizer.eos_token  # end-of-sequence token

    prompts = []
    completions = []

    # Prepare prompts and completions according to the data schema.
    # Case 1: prompt/completion schema
    if "prompt" in examples and "completion" in examples:
        # Values may be single strings, or lists when map(batched=True) is used
        _prompts = examples["prompt"] if isinstance(examples["prompt"], list) else [examples["prompt"]]
        _completions = examples["completion"] if isinstance(examples["completion"], list) else [examples["completion"]]
        # Append the EOS token to every completion
        completions = [c + EOS_TOKEN for c in _completions]
        prompts = _prompts
    # Case 2: instruction/input/output schema (Alpaca style)
    elif "instruction" in examples and "output" in examples:
        _instructions = examples["instruction"] if isinstance(examples["instruction"], list) else [examples["instruction"]]
        # Fall back to empty strings when the 'input' key is missing; wrap a
        # single string in a list, consistent with the other fields
        _inputs = examples["input"] if "input" in examples else ["" for _ in _instructions]
        if not isinstance(_inputs, list):
            _inputs = [_inputs]
        _outputs = examples["output"] if isinstance(examples["output"], list) else [examples["output"]]
        for inst, inp in zip(_instructions, _inputs):
            prompt = f"### Instruction:\n{inst}\n\n"
            if inp and inp.strip():  # add the Input section only when non-empty
                prompt += f"### Input:\n{inp}\n\n"
            prompt += "### Response:\n"  # start of the response
            prompts.append(prompt)
        # Append the EOS token to every output
        completions = [o + EOS_TOKEN for o in _outputs]
    else:
        # Raise an error for unsupported key schemas
        raise ValueError(f"Unsupported key schema: {examples.keys()}")
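    # For reference, each prompt built above has the shape
    #   "### Instruction:\n<instruction>\n\n### Input:\n<input>\n\n### Response:\n"
    # (the Input section appears only when 'input' is non-empty), and the
    # matching completion is the output text followed by the EOS token.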
    # 1. Tokenize the prompts (model inputs),
    #    returned as PyTorch tensors via return_tensors='pt'
    model_inputs = tokenizer(
        prompts,
        truncation=True,       # truncate anything longer than MAX_LEN
        max_length=MAX_LEN,    # maximum length
        padding='max_length',  # pad every sequence to MAX_LEN
        return_tensors='pt'    # return PyTorch tensors
    )

    # 2. Tokenize the completions (used to build the labels),
    #    again returned as PyTorch tensors
    label_outputs = tokenizer(
        completions,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt'
    )
    # Mask pad_token_id positions in the labels with -100,
    # performed directly with PyTorch tensor operations
    labels = label_outputs["input_ids"].clone()      # clone so the original tensor is untouched
    labels[labels == tokenizer.pad_token_id] = -100  # replace padding token IDs with -100
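    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so these
    # padded positions are skipped when Hugging Face models compute the loss.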
    # Sanity check (for debugging): after padding and masking, input_ids and
    # labels must all have length MAX_LEN. When map(batched=True) is used,
    # every sequence in the batch must have the same length. Since both are
    # tensors now, the length can be read directly from .shape[1].
    assert model_inputs["input_ids"].shape[1] == MAX_LEN, \
        f"Input IDs length mismatch: expected {MAX_LEN}, got {model_inputs['input_ids'].shape[1]}"
    assert labels.shape[1] == MAX_LEN, \
        f"Labels length mismatch: expected {MAX_LEN}, got {labels.shape[1]}"
    assert model_inputs["input_ids"].shape[1] == labels.shape[1], \
        f"Input IDs and labels length mismatch: input IDs {model_inputs['input_ids'].shape[1]}, labels {labels.shape[1]}"
    # Attach the final labels to the model_inputs dictionary
    model_inputs["labels"] = labels
    return model_inputs
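

# Minimal usage sketch (illustrative only; the tiny dataset below is made up):
# build a prompt/completion Dataset and tokenize it with map(batched=True).
if __name__ == "__main__":
    data = Dataset.from_dict({
        "prompt": ["Translate to French: Hello", "What is 2 + 2?"],
        "completion": [" Bonjour", " 4"],
    })
    # remove_columns drops the raw text columns so that only the tokenized
    # fields (input_ids, attention_mask, labels) remain
    tokenized = data.map(tokenize, batched=True, remove_columns=data.column_names)
    print(tokenized)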