import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import pandas as pd
from langchain_core.prompts import PromptTemplate

from weclone.data.models import QaPair, CutMessage, QaPairScore
from weclone.prompts.clean_data import CLEAN_PROMPT
from weclone.utils.log import logger

@dataclass
class CleaningStrategy(ABC):
    """Abstract base class for data-cleaning strategies."""

    make_dataset_config: Dict

    @abstractmethod
    def clean(self, data: Any) -> Any:
        """
        Run the cleaning operation.

        Args:
            data: The data to clean.

        Returns:
            The cleaned data.
        """
        pass

@dataclass
class LLMCleaningStrategy(CleaningStrategy):
    """Cleaning strategy that scores data with a large language model."""

    def judge(self, data: List[QaPair]) -> None:
        """
        Call the LLM to score each pair and assign the score directly to the passed-in QaPair.
        """
        from weclone.core.inference.offline_infer import vllm_infer

        logger.info("Starting to score data with the LLM")
        inputs = []
        prompt_template = PromptTemplate.from_template(CLEAN_PROMPT)
        for qa in data:
            inputs.append(prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text)  # type: ignore
        outputs = vllm_infer(
            inputs,
            self.make_dataset_config["model_name_or_path"],
            template=self.make_dataset_config["template"],
            temperature=0,
            guided_decoding_class=QaPairScore,
            repetition_penalty=1.2,
            bad_words=[r"\n"],
        )
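
        # Each vLLM RequestOutput exposes its generated text via
        # outputs[0].text. guided_decoding_class=QaPairScore constrains
        # generation to the QaPairScore schema, but a malformed completion is
        # still possible, so JSON decoding errors below are logged and the
        # pair is skipped instead of aborting the whole run.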
        parsed_scores: List[QaPairScore] = []
        for result in outputs:
            try:
                score_data = json.loads(result.outputs[0].text)
                qa_score = QaPairScore(**score_data)
                parsed_scores.append(qa_score)
            except json.JSONDecodeError:
                logger.error(f"Error decoding JSON: {result.outputs[0].text}")
        score_map = {score.id: score.score for score in parsed_scores}

        for qa in data:
            if qa.id in score_map:
                qa.score = score_map[qa.id]
            else:
                logger.warning(f"Score not found for QaPair with id {qa.id}; leaving its score unset.")
        scores = [qa.score for qa in data if qa.score is not None]
        score_series = pd.Series(scores)
        score_counts = score_series.value_counts().sort_index()
        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
        pd.set_option("display.unicode.east_asian_width", True)  # attempt to fix alignment of full-width characters
        # Merge counts and percentages into one DataFrame for printing
        distribution_df = pd.DataFrame(
            {
                "count": score_counts,
                "percent(%)": score_percentages.round(2),
            }
        )
        distribution_df.index.name = "score"  # name the index column
        printable_df_str = distribution_df.reset_index().to_string(index=False)
        logger.success(f"LLM score distribution:\n{printable_df_str}")
    def clean(self) -> str:
        """
        Clean the SFT data and return the path to the cleaned file.
        If cleaning is not enabled, return the original path.
        """
        config = self.make_dataset_config
        dataset_dir = config["dataset_dir"]
        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")
        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")
        accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1)

        if not config.get("clean_dataset", {}).get("enable_clean"):
            logger.info("Data cleaning is not enabled")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

        try:
            with open(sft_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Keep only entries whose score reaches the configured threshold
            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]

            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)

            logger.success(f"Filtered out entries scoring below {accept_score}; kept {len(filtered_data)} entries")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
            return output_json_path
        except Exception as e:
            logger.error(f"Failed to clean data, falling back to the original data: {str(e)}")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path
    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str):
        """
        Update the file_name field in dataset_info.json.
        """
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)

            # Update file_name for all supported datasets
            for key in ["wechat-sft", "wechat-sft-with-history"]:
                if key in dataset_info:
                    dataset_info[key]["file_name"] = new_file_name

            # Write the file back
            with open(dataset_info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=4, ensure_ascii=False)

            logger.info(f"Updated file_name in dataset_info.json to {new_file_name}")
        except Exception as e:
            logger.warning(f"Failed to update dataset_info.json: {e}")
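
# Usage sketch (illustrative only; the surrounding pipeline normally drives
# this class). Assumes a make_dataset_config dict containing the keys read
# above ("model_name_or_path", "template", "dataset_dir", "clean_dataset")
# and a List[QaPair] built elsewhere:
#
#     strategy = LLMCleaningStrategy(make_dataset_config=config)
#     strategy.judge(qa_pairs)         # assigns qa.score in place via vLLM
#     cleaned_path = strategy.clean()  # writes sft-my-l.json, returns its path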