File size: 7,602 Bytes
88aba71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import click
import commentjson
from pathlib import Path
import os
import sys
import functools

from weclone.utils.log import logger, capture_output
from weclone.utils.config import load_config

cli_config: dict | None = None

try:
    import tomllib  # type: ignore Python 3.11+
except ImportError:
    import tomli as tomllib


def clear_argv(func):
    """

    装饰器:在调用被装饰函数前,清理 sys.argv,只保留脚本名。调用后恢复原始 sys.argv。

    用于防止参数被 Hugging Face HfArgumentParser 解析造成 ValueError。

    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        original_argv = sys.argv.copy()
        sys.argv = [original_argv[0]]  # 只保留脚本名
        try:
            return func(*args, **kwargs)
        finally:
            sys.argv = original_argv  # 恢复原始 sys.argv

    return wrapper


def apply_common_decorators(capture_output_enabled=False):
    """

    A unified decorator for applications

    """

    def decorator(original_cmd_func):
        @functools.wraps(original_cmd_func)
        def new_runtime_wrapper(*args, **kwargs):
            if cli_config and cli_config.get("full_log", False):
                return capture_output(original_cmd_func)(*args, **kwargs)
            else:
                return original_cmd_func(*args, **kwargs)

        func_with_clear_argv = clear_argv(new_runtime_wrapper)

        return functools.wraps(original_cmd_func)(func_with_clear_argv)

    return decorator


@click.group()
def cli():
    """WeClone: 从聊天记录创造数字分身的一站式解决方案"""
    _check_project_root()
    _check_versions()
    global cli_config
    cli_config = load_config(arg_type="cli_args")


@cli.command("make-dataset", help="处理聊天记录CSV文件,生成问答对数据集。")
@apply_common_decorators()
def qa_generator():
    """处理聊天记录CSV文件,生成问答对数据集。"""
    from weclone.data.qa_generator import DataProcessor

    processor = DataProcessor()
    processor.main()


@cli.command("train-sft", help="使用准备好的数据集对模型进行微调。")
@apply_common_decorators()
def train_sft():
    """使用准备好的数据集对模型进行微调。"""
    from weclone.train.train_sft import main as train_sft_main

    train_sft_main()


@cli.command("webchat-demo", help="启动 Web UI 与微调后的模型进行交互测试。")  # 命令名修改为 web-demo
@apply_common_decorators()
def web_demo():
    """启动 Web UI 与微调后的模型进行交互测试。"""
    from weclone.eval.web_demo import main as web_demo_main

    web_demo_main()


# TODO 添加评估功能 @cli.command("eval-model", help="使用从训练数据中划分出来的验证集评估。")
@apply_common_decorators()
def eval_model():
    """使用从训练数据中划分出来的验证集评估。"""
    from weclone.eval.eval_model import main as evaluate_main

    evaluate_main()


@cli.command("test-model", help="使用常见聊天问题测试模型。")
@apply_common_decorators()
def test_model():
    """测试"""
    from weclone.eval.test_model import main as test_main

    test_main()


@cli.command("server", help="启动API服务,提供模型推理接口。")
@apply_common_decorators()
def server():
    """启动API服务,提供模型推理接口。"""
    from weclone.server.api_service import main as server_main

    server_main()


def _check_project_root():
    """检查当前目录是否为项目根目录,并验证项目名称。"""
    project_root_marker = "pyproject.toml"
    current_dir = Path(os.getcwd())
    pyproject_path = current_dir / project_root_marker

    if not pyproject_path.is_file():
        logger.error(f"未在当前目录找到 {project_root_marker} 文件。")
        logger.error("请确保在WeClone项目根目录下运行此命令。")
        sys.exit(1)

    try:
        with open(pyproject_path, "rb") as f:
            pyproject_data = tomllib.load(f)
        project_name = pyproject_data.get("project", {}).get("name")
        if project_name != "WeClone":
            logger.error("请确保在正确的 WeClone 项目根目录下运行。")
            sys.exit(1)
    except tomllib.TOMLDecodeError as e:
        logger.error(f"错误:无法解析 {pyproject_path} 文件: {e}")
        sys.exit(1)
    except Exception as e:
        logger.error(f"读取或处理 {pyproject_path} 时发生意外错误: {e}")
        sys.exit(1)


def _check_versions():
    """比较本地 settings.jsonc 版本和 pyproject.toml 中的配置文件指南版本"""
    if tomllib is None:  # Skip check if toml parser failed to import
        return

    ROOT_DIR = Path(__file__).parent.parent
    SETTINGS_PATH = ROOT_DIR / "settings.jsonc"
    PYPROJECT_PATH = ROOT_DIR / "pyproject.toml"

    settings_version = None
    config_guide_version = None
    config_changelog = None

    if SETTINGS_PATH.exists():
        try:
            with open(SETTINGS_PATH, "r", encoding="utf-8") as f:
                settings_data = commentjson.load(f)
                settings_version = settings_data.get("version")
        except Exception as e:
            logger.error(f"错误:无法读取或解析 {SETTINGS_PATH}: {e}")
            logger.error("请确保 settings.jsonc 文件存在且格式正确。")
            sys.exit(1)
    else:
        logger.error(f"错误:未找到配置文件 {SETTINGS_PATH}。")
        logger.error("请确保 settings.jsonc 文件位于项目根目录。")
        sys.exit(1)

    if PYPROJECT_PATH.exists():
        try:
            with open(PYPROJECT_PATH, "rb") as f:  # tomllib 需要二进制模式
                pyproject_data = tomllib.load(f)
                weclone_tool_data = pyproject_data.get("tool", {}).get("weclone", {})
                config_guide_version = weclone_tool_data.get("config_version")
                config_changelog = weclone_tool_data.get("config_changelog", "N/A")
        except Exception as e:
            logger.warning(f"警告:无法读取或解析 {PYPROJECT_PATH}: {e}。无法检查配置文件是否为最新。")
    else:
        logger.warning(f"警告:未找到文件 {PYPROJECT_PATH}。无法检查配置文件是否为最新。")

    if not settings_version:
        logger.error(f"错误:在 {SETTINGS_PATH} 中未找到 'version' 字段。")
        logger.error("请从 settings.template.json 复制或更新您的 settings.jsonc 文件。")
        sys.exit(1)

    if config_guide_version:
        if settings_version != config_guide_version:
            logger.warning(
                f"警告:您的 settings.jsonc 文件版本 ({settings_version}) 与项目建议的配置版本 ({config_guide_version}) 不一致。"
            )
            logger.warning("这可能导致意外行为或错误。请从 settings.template.json 复制或更新您的 settings.jsonc 文件。")
            # TODO 根据版本号打印更新日志
            logger.warning(f"配置文件更新日志:\n{config_changelog}")
    elif PYPROJECT_PATH.exists():  # 如果文件存在但未读到版本
        logger.warning(
            f"警告:在 {PYPROJECT_PATH} 的 [tool.weclone] 下未找到 'config_version' 字段。"
            "无法确认您的 settings.jsonc 是否为最新配置版本。"
        )


if __name__ == "__main__":
    cli()