# weclone/data/clean/strategies.py
# Source: cren/weclone — uploaded via huggingface_hub (commit 88aba71, verified)
import json
import pandas as pd
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from langchain_core.prompts import PromptTemplate
from weclone.data.models import QaPair, CutMessage, QaPairScore
from weclone.prompts.clean_data import CLEAN_PROMPT
import os
from weclone.utils.log import logger
@dataclass
class CleaningStrategy(ABC):
    """Abstract base class for data-cleaning strategies.

    Concrete strategies receive the dataset-building configuration and
    implement :meth:`clean` to transform the data they are given.
    """

    # Configuration shared by all concrete strategies (dataset paths, model
    # settings, cleaning thresholds, ...).
    make_dataset_config: Dict

    @abstractmethod
    def clean(self, data: Any) -> Any:
        """Run the cleaning step on *data* and return the cleaned result."""
        ...
@dataclass
class LLMCleaningStrategy(CleaningStrategy):
    """Cleaning strategy that scores QA pairs with an LLM and filters the
    SFT dataset by that score."""

    def judge(self, data: List[QaPair]) -> None:
        """
        Score every QaPair with the LLM and write the score back onto the pair.

        Pairs whose score cannot be parsed from the model output keep their
        previous ``score`` value; a warning is logged for each of them.

        Args:
            data: QA pairs to score; mutated in place (``qa.score``).
        """
        # Imported lazily: the vllm-backed inference stack is heavy and only
        # needed when judging actually runs.
        from weclone.core.inference.offline_infer import vllm_infer

        logger.info("开始使用llm对数据打分")
        prompt_template = PromptTemplate.from_template(CLEAN_PROMPT)
        inputs = [
            prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text  # type: ignore
            for qa in data
        ]
        outputs = vllm_infer(
            inputs,
            self.make_dataset_config["model_name_or_path"],
            template=self.make_dataset_config["template"],
            temperature=0,
            guided_decoding_class=QaPairScore,
            repetition_penalty=1.2,
            bad_words=[r"\n"],
        )

        parsed_scores: List[QaPairScore] = []
        for result in outputs:
            raw_text = result.outputs[0].text
            try:
                score_data = json.loads(raw_text)
                parsed_scores.append(QaPairScore(**score_data))
            except (json.JSONDecodeError, TypeError):
                # TypeError covers well-formed JSON whose keys do not match
                # the QaPairScore constructor (previously uncaught).
                logger.error(f"Error decoding JSON: {raw_text}")

        score_map = {score.id: score.score for score in parsed_scores}
        for qa in data:
            if qa.id in score_map:
                qa.score = score_map[qa.id]
            else:
                # Fixed message: the old text claimed a default score was
                # assigned, but none ever was — the score is left unchanged.
                logger.warning(f"Warning: Score not found for QaPair with id {qa.id}; score left unchanged.")

        self._log_score_distribution(data)

    def _log_score_distribution(self, data: List[QaPair]) -> None:
        """Log a table with the count and percentage of each score value."""
        scores = [qa.score for qa in data if qa.score is not None]
        if not scores:
            # Nothing was scored successfully; avoid summarizing an empty table.
            logger.warning("No scores available to summarize.")
            return
        score_series = pd.Series(scores)
        score_counts = score_series.value_counts().sort_index()
        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
        # Widen CJK characters so the printed table columns line up.
        pd.set_option("display.unicode.east_asian_width", True)
        # Merge counts and percentages into one DataFrame for printing.
        distribution_df = pd.DataFrame(
            {
                "数量": score_counts,
                "占比(%)": score_percentages.round(2),
            }
        )
        distribution_df.index.name = "分数"  # header for the index (score) column
        printable_df_str = distribution_df.reset_index().to_string(index=False)
        logger.success(f"llm打分分数分布情况:\n{printable_df_str}")

    def clean(self) -> str:
        """
        Clean the SFT dataset and return the path of the cleaned file.

        When cleaning is disabled or fails, dataset_info.json is pointed back
        at the original file and the original path is returned.

        NOTE(review): overrides the abstract ``clean(self, data)`` with a
        zero-argument signature — kept as-is for caller compatibility.
        """
        config = self.make_dataset_config
        dataset_dir = config["dataset_dir"]
        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")
        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")

        clean_config = config.get("clean_dataset", {})  # hoisted: was looked up twice
        accept_score = clean_config.get("llm", {}).get("accept_score", 1)

        if not clean_config.get("enable_clean"):
            logger.info("未启用清洗功能")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

        try:
            with open(sft_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Keep only entries whose score reaches the acceptance threshold;
            # entries without a "score" key are treated as 0 and dropped.
            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]
            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)
            logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
            return output_json_path
        except Exception as e:
            # Best-effort fallback: on any failure, keep using the raw data.
            logger.error(f"清洗数据失败,使用原始数据: {str(e)}")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str) -> None:
        """Point every supported dataset entry in dataset_info.json at *new_file_name*."""
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)
            # Update file_name for every dataset variant this project supports.
            for key in ["wechat-sft", "wechat-sft-with-history"]:
                if key in dataset_info:
                    dataset_info[key]["file_name"] = new_file_name
            with open(dataset_info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=4, ensure_ascii=False)
            logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}")
        except Exception as e:
            # A missing or unreadable dataset_info.json is non-fatal.
            logger.warning(f"无法更新 dataset_info.json: {e}")