import json
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List

import pandas as pd
from langchain_core.prompts import PromptTemplate
from tqdm import tqdm

from weclone.core.inference.online_infer import OnlineLLM
from weclone.data.models import QaPair, QaPairScore
from weclone.prompts.clean_data import ONLINE_LLM_CLEAN_PROMPT
from weclone.utils.log import logger


@dataclass
class CleaningStrategy(ABC):
    """Abstract base class for data cleaning strategies."""

    make_dataset_config: Dict

    @abstractmethod
    def clean(self, data: Any) -> Any:
        pass


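# A minimal sketch of how a concrete strategy plugs into the base class above.
# KeywordCleaningStrategy and its banned_keywords field are hypothetical,
# for illustration only; they are not part of WeClone.
@dataclass
class KeywordCleaningStrategy(CleaningStrategy):
    """Hypothetical example: drop QA pairs whose answer contains a banned keyword."""

    banned_keywords: tuple = ()

    def clean(self, data: List[QaPair]) -> List[QaPair]:
        return [qa for qa in data if not any(k in qa.output for k in self.banned_keywords)]

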
@dataclass
class OnlineLLMCleaningStrategy(CleaningStrategy):
    """Cleaning strategy that scores data with an online LLM."""

    def judge(self, data: List[QaPair]) -> None:
        logger.info("Starting to score data with the online model")
        logger.info(f"Using model {self.make_dataset_config.get('model_name', '')}")
        client = OnlineLLM(
            api_key=self.make_dataset_config.get("llm_api_key"),
            base_url=self.make_dataset_config.get("base_url"),
            model_name=self.make_dataset_config.get("model_name"),
            default_system=self.make_dataset_config.get("default_system"),
        )
        prompt_template = PromptTemplate.from_template(ONLINE_LLM_CLEAN_PROMPT)
        parsed_scores = []
        clean_batch_size = int(self.make_dataset_config.get("clean_batch_size", 10))
        for i in tqdm(range(0, len(data), clean_batch_size), desc="Online model scoring progress"):
            batch = data[i : i + clean_batch_size]

            qa_list = [{"id": qa.id, "Q": qa.instruction, "A": qa.output} for qa in batch]
            qa_list_json = json.dumps(qa_list, ensure_ascii=False)

            prompt_text = prompt_template.invoke({"qa_list": qa_list_json}).text
            try:
                response = client.chat(prompt_text)
                result_text = response.choices[0].message.content

                # Reasoning models may prepend a <think>...</think> block;
                # keep only the text after it.
                if "</think>" in result_text:
                    result_text = result_text.split("</think>", 1)[1]

                # Strip Markdown ```json fences around the payload.
                result_text = re.sub(r"^```json\s*|```$", "", result_text.strip(), flags=re.MULTILINE)

                # The model is expected to return a JSON list of
                # {"id": ..., "score": ...} objects matching QaPairScore.
                try:
                    score_list = json.loads(result_text)
                except json.JSONDecodeError as e:
                    logger.error(f"JSON parsing failed, skipping this batch: {e}\nContent: {result_text}")
                    continue

                for item in score_list:
                    parsed_scores.append(QaPairScore(**item))
            except Exception as e:
                ids_in_batch = [qa["id"] for qa in qa_list]
                logger.error(f"Online model call or result parsing failed for batch QA IDs {ids_in_batch}: {str(e)}")
        score_map = {score.id: score.score for score in parsed_scores}
        for qa in data:
            if qa.id in score_map:
                qa.score = score_map[qa.id]
            else:
                logger.warning(f"No score returned for QA ID {qa.id}, defaulting to 0")
                qa.score = 0

        scores = [qa.score for qa in data if qa.score is not None]
        score_series = pd.Series(scores)
        score_counts = score_series.value_counts().sort_index()
        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
        pd.set_option("display.unicode.east_asian_width", True)
        distribution_df = pd.DataFrame(
            {
                "count": score_counts,
                "percentage(%)": score_percentages.round(2),
            }
        )
        distribution_df.index.name = "score"
        printable_df_str = distribution_df.reset_index().to_string(index=False)
        logger.success(f"Online model score distribution:\n{printable_df_str}")

    def clean(self) -> str:
        """
        Clean the SFT data and return the path of the cleaned file.
        If cleaning is disabled, return the original path.
        """
        config = self.make_dataset_config
        dataset_dir = config["dataset_dir"]
        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")

        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")
        accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1)

        if not config.get("clean_dataset", {}).get("enable_clean"):
            logger.info("Data cleaning is disabled")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

        try:
            with open(sft_json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            filtered_data = [item for item in data if item.get("score", 0) >= accept_score]

            with open(output_json_path, "w", encoding="utf-8") as f:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)

            logger.success(f"Filtered out entries scoring below {accept_score}; kept {len(filtered_data)} entries")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
            return output_json_path

        except Exception as e:
            logger.error(f"Data cleaning failed, falling back to the original data: {str(e)}")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

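    # dataset_info.json is assumed to look roughly like this; only the keys the
    # helper below touches are shown:
    # {
    #     "wechat-sft": {"file_name": "sft-my.json", ...},
    #     "wechat-sft-with-history": {"file_name": "sft-my.json", ...}
    # }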
    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str):
        """
        Update the file_name field in dataset_info.json.
        """
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)

            for key in ["wechat-sft", "wechat-sft-with-history"]:
                if key in dataset_info:
                    dataset_info[key]["file_name"] = new_file_name

            with open(dataset_info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=4, ensure_ascii=False)

            logger.info(f"Updated file_name in dataset_info.json to {new_file_name}")

        except Exception as e:
            logger.warning(f"Failed to update dataset_info.json: {e}")
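

# A minimal usage sketch, assuming a make_dataset_config dict carrying the keys
# the strategy reads above. The concrete values (paths, model name, API key)
# are placeholders, not WeClone defaults.
if __name__ == "__main__":
    demo_config = {
        "llm_api_key": "sk-placeholder",
        "base_url": "https://api.example.com/v1",
        "model_name": "some-chat-model",
        "default_system": "You are a data quality grader.",
        "clean_batch_size": 10,
        "dataset_dir": "./dataset",
        "clean_dataset": {"enable_clean": True, "llm": {"accept_score": 2}},
    }
    strategy = OnlineLLMCleaningStrategy(make_dataset_config=demo_config)
    # judge() scores QaPair objects in place via the online model; clean() then
    # filters the already scored sft-my.json on disk and returns the path to train on.
    cleaned_path = strategy.clean()
    print(cleaned_path)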