from datasets import Dataset
from typing import Dict

def format_prompt(example: Dict[str, str]) -> Dict[str, str]:
    """Convert a single example into an Alpaca-style prompt/completion pair."""
    if all(k in example for k in ("instruction", "output")):
        instruction = example["instruction"]
        input_text = example.get("input", "")
        prompt = f"### Instruction:\n{instruction}\n\n"
        if input_text.strip():
            prompt += f"### Input:\n{input_text}\n\n"
        prompt += "### Response:\n"
        return {"prompt": prompt, "completion": example["output"]}

    # Already in prompt/completion form: pass through unchanged.
    elif all(k in example for k in ("prompt", "completion")):
        return {"prompt": example["prompt"], "completion": example["completion"]}

    else:
        raise ValueError(f"Unsupported data format: {example}")

def preprocess(dataset: Dataset) -> Dataset:
    """Normalize a dataset to the prompt/completion column layout."""
    # Check which columns the dataset provides
    column_names = dataset.column_names
    if all(k in column_names for k in ("prompt", "completion")):
        return dataset  # already in the expected format, use as-is
    elif all(k in column_names for k in ("instruction", "output")):
        return dataset.map(format_prompt, remove_columns=column_names)
    else:
        raise ValueError(f"Unsupported column layout: {column_names}")


"""

# 좜λ ₯ 확인

print(processed_dataset[0]) # input_ids , attention_mask , labels

print("111")

print(tokenized_dataset[0])

"""