File size: 5,846 Bytes
88aba71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import pandas as pd
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from langchain_core.prompts import PromptTemplate
from weclone.data.models import QaPair, CutMessage, QaPairScore
from weclone.prompts.clean_data import CLEAN_PROMPT
import os
from weclone.utils.log import logger


@dataclass
class CleaningStrategy(ABC):
    """Abstract base class for data-cleaning strategies.

    Concrete strategies subclass this and provide an implementation of
    :meth:`clean`.
    """

    # Configuration mapping supplied when the strategy is constructed.
    make_dataset_config: Dict

    @abstractmethod
    def clean(self, data: Any) -> Any:
        """Run the cleaning operation.

        Args:
            data: The data to be cleaned.

        Returns:
            The cleaned data.
        """
        ...


@dataclass
class LLMCleaningStrategy(CleaningStrategy):
    """Cleaning strategy that scores QA pairs with an LLM and filters by score.

    ``judge`` asks the model for a score per QA pair and stores it on the
    pair; ``clean`` filters the SFT JSON file by those scores and points
    ``dataset_info.json`` at the resulting file.
    """

    def judge(self, data: List[QaPair]) -> None:
        """Score each QA pair with the LLM and store the score on the pair.

        The pairs in ``data`` are mutated in place (``qa.score``). Model
        outputs that cannot be parsed into a ``QaPairScore`` are logged and
        skipped; the corresponding pairs keep whatever score they had.
        Finally the score distribution is logged.

        Args:
            data: QA pairs to score. Each pair's ``id``, ``instruction`` and
                ``output`` are rendered into ``CLEAN_PROMPT``.
        """
        # Imported at call time rather than at module import time, as in the
        # original code (presumably because the inference stack is heavy --
        # TODO confirm).
        from weclone.core.inference.offline_infer import vllm_infer

        logger.info("开始使用llm对数据打分")
        prompt_template = PromptTemplate.from_template(CLEAN_PROMPT)
        inputs = [
            prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text  # type: ignore
            for qa in data
        ]
        outputs = vllm_infer(
            inputs,
            self.make_dataset_config["model_name_or_path"],
            template=self.make_dataset_config["template"],
            temperature=0,
            guided_decoding_class=QaPairScore,
            repetition_penalty=1.2,
            bad_words=[r"\n"],
        )

        # Map QA-pair id -> score for every output that parses cleanly.
        score_map = {}
        for result in outputs:
            raw_text = result.outputs[0].text
            try:
                qa_score = QaPairScore(**json.loads(raw_text))
            except json.JSONDecodeError:
                logger.error(f"Error decoding JSON: {raw_text}")
            except (TypeError, ValueError) as exc:
                # Valid JSON that is not an object (TypeError from ``**``) or
                # that QaPairScore rejects. NOTE(review): assumes QaPairScore
                # validation errors derive from ValueError/TypeError, as
                # pydantic's do -- confirm against weclone.data.models.
                logger.error(f"Invalid score payload: {raw_text} ({exc})")
            else:
                score_map[qa_score.id] = qa_score.score

        for qa in data:
            if qa.id in score_map:
                qa.score = score_map[qa.id]
            else:
                # Nothing is assigned here: the pair keeps its existing score.
                logger.warning(f"Warning: Score not found for QaPair with id {qa.id}. Assigning default score.")

        scores = [qa.score for qa in data if qa.score is not None]
        if not scores:
            # No usable scores: a distribution table would be empty, so skip it.
            logger.warning("没有可用的打分结果,跳过分数分布统计")
            return

        score_series = pd.Series(scores)
        score_counts = score_series.value_counts().sort_index()
        score_percentages = score_series.value_counts(normalize=True).sort_index() * 100
        # Attempt to keep the printed columns aligned when headers are CJK text.
        pd.set_option("display.unicode.east_asian_width", True)
        # Combine counts and percentages into one frame for printing.
        distribution_df = pd.DataFrame(
            {
                "数量": score_counts,
                "占比(%)": score_percentages.round(2),
            }
        )
        # Name the index so it prints as a labelled first column.
        distribution_df.index.name = "分数"
        printable_df_str = distribution_df.reset_index().to_string(index=False)
        logger.success(f"llm打分分数分布情况:\n{printable_df_str}")

    def clean(self, data: Any = None) -> str:
        """Filter the SFT data by score and return the path of the file to use.

        Reads ``sft-my.json`` from ``dataset_dir``, keeps the records whose
        ``score`` is at least ``clean_dataset.llm.accept_score`` (default 1),
        writes them to ``sft-my-l.json`` and updates ``dataset_info.json`` to
        reference that file. Records with a missing or null score count as 0.

        If cleaning is disabled, or any step fails, ``dataset_info.json`` is
        pointed at the original ``sft-my.json`` and its path is returned.

        Args:
            data: Unused. Accepted only so the signature is compatible with
                ``CleaningStrategy.clean``; the input is read from disk.

        Returns:
            Path of the cleaned file, or of the original file when cleaning
            is disabled or failed.

        Raises:
            KeyError: If ``dataset_dir`` is missing from the configuration.
        """
        config = self.make_dataset_config
        dataset_dir = config["dataset_dir"]
        dataset_info_path = os.path.join(dataset_dir, "dataset_info.json")

        sft_json_path = os.path.join(dataset_dir, "sft-my.json")
        output_json_path = os.path.join(dataset_dir, "sft-my-l.json")

        # ``or {}`` also covers keys that are present but set to null.
        clean_config = config.get("clean_dataset") or {}
        accept_score = (clean_config.get("llm") or {}).get("accept_score", 1)

        if not clean_config.get("enable_clean"):
            logger.info("未启用清洗功能")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

        # Deliberately broad: any failure falls back to the unfiltered data
        # instead of aborting dataset creation.
        try:
            with open(sft_json_path, 'r', encoding='utf-8') as f:
                records = json.load(f)

            # A missing or null score is treated as 0 so one unscored record
            # cannot raise TypeError and disable filtering for the whole file.
            filtered_data = [
                item for item in records if (item.get("score") or 0) >= accept_score
            ]

            with open(output_json_path, 'w', encoding='utf-8') as f:
                json.dump(filtered_data, f, ensure_ascii=False, indent=4)

            logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json")
            return output_json_path

        except Exception as e:
            logger.error(f"清洗数据失败,使用原始数据: {str(e)}")
            self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json")
            return sft_json_path

    def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str) -> None:
        """Set the ``file_name`` field of the known datasets in dataset_info.json.

        Best effort: any failure is logged as a warning and not re-raised.

        Args:
            dataset_info_path: Path to the ``dataset_info.json`` file.
            new_file_name: File name to store under each known dataset entry.
        """
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)

            # Update every supported dataset entry that is present.
            for key in ["wechat-sft", "wechat-sft-with-history"]:
                if key in dataset_info:
                    dataset_info[key]["file_name"] = new_file_name

            # Write the modified mapping back to the same file.
            with open(dataset_info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=4, ensure_ascii=False)

            logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}")

        except Exception as e:
            logger.warning(f"无法更新 dataset_info.json: {e}")