#!/usr/bin/env python3
"""
Model merge script: merge LoRA weights into the base model.
Intended for inference and deployment.
"""
import argparse
import sys

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def merge_lora_model(base_model_path, lora_model_path, output_path):
    """
    Merge LoRA weights into the base model.

    Args:
        base_model_path: Path to the base model.
        lora_model_path: Path to the LoRA adapter (training output).
        output_path: Where to save the merged model.
    """
    print("📥 Loading base model...")
    # Load the base model without quantization, so the LoRA deltas can be
    # merged cleanly into the full-precision (here fp16) weights
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("📥 Loading LoRA model...")
    # Wrap the base model with the trained LoRA adapter
    model = PeftModel.from_pretrained(base_model, lora_model_path)

    print("🔄 Merging LoRA weights...")
    # Fold the adapter weights into the base weights and drop the PEFT wrapper
    model = model.merge_and_unload()

    print("💾 Saving merged model...")
    # Save the merged model in safetensors format
    model.save_pretrained(output_path, safe_serialization=True)

    # Copy the tokenizer so the output directory is self-contained
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    tokenizer.save_pretrained(output_path)

    print(f"✅ Model merged and saved to {output_path}")

def test_merged_model(model_path):
    """Smoke-test the merged model with a single prompt."""
    print("🧪 Testing merged model...")

    # Load the merged model and its tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Test prompt; the format must match the one used during training
    test_prompt = "### Human: Create an advertisement for a revolutionary AI-powered smartwatch\n### Assistant:"
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(test_prompt) is fragile because decoding can change whitespace
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n📝 Test Prompt: Create an advertisement for a revolutionary AI-powered smartwatch")
    print(f"📄 Generated Response:\n{generated_text}")

def main():
    parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
    parser.add_argument("--base_model", required=True, help="Path to base model")
    parser.add_argument("--lora_model", required=True, help="Path to LoRA model (training output)")
    parser.add_argument("--output", required=True, help="Output path for merged model")
    parser.add_argument("--test", action="store_true", help="Test the merged model")
    args = parser.parse_args()

    # Merge the model
    merge_lora_model(args.base_model, args.lora_model, args.output)

    # Optionally test the merged model
    if args.test:
        test_merged_model(args.output)

if __name__ == "__main__":
    # Example usage
    print("📋 Merge LoRA Model Script")
    print("\nUsage:")
    print("python merge_model.py --base_model microsoft/DialoGPT-medium --lora_model ./results --output ./merged_model --test")

    if len(sys.argv) > 1:
        # Command-line arguments were given: use the argparse entry point
        main()
    else:
        print("\nNo arguments given; running the default configuration:")
        # Default configuration
        merge_lora_model(
            base_model_path="microsoft/DialoGPT-medium",  # replace with the actual OpenAI OSS 120B model
            lora_model_path="./results",
            output_path="./merged_model",
        )
        # Test the merged model
        test_merged_model("./merged_model")