import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import pandas as pd
from langchain_core.prompts import PromptTemplate

from weclone.data.models import CutMessage, QaPair, QaPairScore
from weclone.prompts.clean_data import CLEAN_PROMPT
from weclone.utils.log import logger


@dataclass
class CleaningStrategy(ABC):
    """Abstract base class for data-cleaning strategies."""

    make_dataset_config: Dict

    @abstractmethod
    def clean(self, data: Any) -> Any:
        """Run the cleaning operation.

        Args:
            data: The data to clean.

        Returns:
            The cleaned data.
        """
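

# Illustrative sketch, not part of the original module: a minimal concrete
# strategy showing how CleaningStrategy is meant to be subclassed. The class
# name and its length-based rule are hypothetical.
#
# @dataclass
# class MinLengthCleaningStrategy(CleaningStrategy):
#     """Drop QA pairs whose answer is shorter than a configured minimum."""
#
#     def clean(self, data: List[QaPair]) -> List[QaPair]:
#         min_len = self.make_dataset_config.get("min_answer_length", 2)
#         return [qa for qa in data if len(qa.output) >= min_len]
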
@dataclass
class LLMCleaningStrategy(CleaningStrategy):
    """Strategy that cleans data by scoring it with an LLM."""

    def judge(self, data: List[QaPair]) -> None:
        """Call the LLM to score the data, assigning each score directly onto the QaPair passed in."""
        # Local import so vLLM is only loaded when scoring actually runs.
        from weclone.core.inference.offline_infer import vllm_infer

        logger.info("Starting to score the data with the LLM")
        prompt_template = PromptTemplate.from_template(CLEAN_PROMPT)
        inputs = [
            prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text
            for qa in data
        ]
        outputs = vllm_infer(
            inputs,
            self.make_dataset_config["model_name_or_path"],
            template=self.make_dataset_config["template"],
            temperature=0,
            guided_decoding_class=QaPairScore,  # constrain decoding to the QaPairScore schema
            repetition_penalty=1.2,
            bad_words=[r"\n"],
        )
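
        # Given the guided decoding against QaPairScore above, each output is
        # expected to be a JSON object of that shape, e.g. {"id": ..., "score": ...}
        # (field names inferred from the accesses below; the authoritative schema
        # lives in weclone.data.models).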
        parsed_scores: List[QaPairScore] = []
        for result in outputs:
            raw_text = result.outputs[0].text
            try:
                score_data = json.loads(raw_text)
                parsed_scores.append(QaPairScore(**score_data))
            except json.JSONDecodeError:
                logger.error(f"Error decoding JSON: {raw_text}")

        # Write the scores back onto the QaPairs by id.
        score_map = {score.id: score.score for score in parsed_scores}
        for qa in data:
            if qa.id in score_map:
                qa.score = score_map[qa.id]
            else:
                logger.warning(
                    f"No score returned for QaPair with id {qa.id}; "
                    "it is left unscored and excluded from the statistics below."
                )

        # Report the score distribution; unscored pairs are skipped.
        scores = [qa.score for qa in data if qa.score is not None]
        score_series = pd.Series(scores)
        score_counts = score_series.value_counts().sort_index()
        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
        pd.set_option("display.unicode.east_asian_width", True)
        distribution_df = pd.DataFrame(
            {
                "count": score_counts,
                "percent(%)": score_percentages.round(2),
            }
        )
        distribution_df.index.name = "score"
        printable_df_str = distribution_df.reset_index().to_string(index=False)
        logger.success(f"LLM score distribution:\n{printable_df_str}")

    def clean(self) -> str:
        """Clean the SFT data and return the path to the cleaned file.

        If cleaning is not enabled, return the original path.
        """
        config = self.make_dataset_config
        dataset_dir = config["dataset_dir"]
        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")

        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")
        accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1)

        if not config.get("clean_dataset", {}).get("enable_clean"):
            logger.info("Data cleaning is not enabled")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

        try:
            with open(sft_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]

            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)

            logger.success(
                f"Filtered out data scoring below {accept_score}; kept {len(filtered_data)} entries"
            )
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
            return output_json_path

        except Exception as e:
            logger.error(f"Data cleaning failed, falling back to the original data: {str(e)}")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path
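
    # For reference, the keys read above imply a make_dataset_config fragment
    # roughly of this shape (illustrative; values are placeholders and the
    # authoritative schema is the project's settings file):
    #
    # {
    #     "dataset_dir": "...",
    #     "model_name_or_path": "...",
    #     "template": "...",
    #     "clean_dataset": {"enable_clean": true, "llm": {"accept_score": 2}}
    # }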

    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str):
        """Update the file_name field in dataset_info.json."""
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)

            for key in ["wechat-sft", "wechat-sft-with-history"]:
                if key in dataset_info:
                    dataset_info[key]["file_name"] = new_file_name

            with open(dataset_info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=4, ensure_ascii=False)

            logger.info(f"Updated file_name in dataset_info.json to {new_file_name}")

        except Exception as e:
            logger.warning(f"Could not update dataset_info.json: {e}")
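

# Illustrative usage sketch (assumes `config` is the make_dataset section of the
# project settings and `qa_pairs` is the List[QaPair] built earlier in the
# pipeline; neither name is defined in this module):
#
# strategy = LLMCleaningStrategy(make_dataset_config=config)
# strategy.judge(qa_pairs)          # writes qa.score onto each pair in place
# cleaned_path = strategy.clean()   # writes sft-my-l.json and returns its path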
|