# AdGPT / lauguage_model_fine_tuning / merge_teacher_model.py
# goodmodeler's picture
# ADD: LLM SFT, RLHF and Distillation
# c1c9e88
#!/usr/bin/env python3
"""
模型合并脚本 - 将LoRA权重合并到基础模型中
用于推理和部署
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import argparse
def merge_lora_model(
    base_model_path,
    lora_model_path,
    output_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
):
    """Merge LoRA adapter weights into the base model and save the result.

    The base model is loaded, the LoRA adapter is attached with PEFT, the
    adapter deltas are folded into the base weights via ``merge_and_unload``,
    and the resulting plain transformers model is written to ``output_path``
    (safetensors format) together with the base model's tokenizer.

    Args:
        base_model_path: Hub id or local path of the base model.
        lora_model_path: Path to the LoRA adapter directory (training output).
        output_path: Directory where the merged model and tokenizer are saved.
        torch_dtype: dtype used to load the base model (and hence the dtype of
            the saved merged weights). Defaults to ``torch.float16``.
        trust_remote_code: Forwarded to BOTH the model and the tokenizer
            loaders, so base models that ship custom modelling/tokenizer code
            load consistently.
    """
    print("📥 Loading base model...")
    # Load the base model without quantization so the LoRA deltas are merged
    # into regular (unquantized) weights.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch_dtype,
        device_map="auto",
        trust_remote_code=trust_remote_code,
    )

    print("📥 Loading LoRA model...")
    # Attach the trained LoRA adapter on top of the base model.
    model = PeftModel.from_pretrained(base_model, lora_model_path)

    print("🔄 Merging LoRA weights...")
    # Fold the adapter into the base weights and drop the PEFT wrapper.
    model = model.merge_and_unload()

    print("💾 Saving merged model...")
    model.save_pretrained(output_path, safe_serialization=True)

    # Ship the base model's tokenizer alongside the merged weights.
    # NOTE(review): if training added tokens, the tokenizer saved with the
    # adapter may be the correct one instead — confirm against training code.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        trust_remote_code=trust_remote_code,
    )
    tokenizer.save_pretrained(output_path)

    print(f"✅ Model merged and saved to {output_path}")
def test_merged_model(
    model_path,
    instruction="Create an advertisement for a revolutionary AI-powered smartwatch",
    max_new_tokens=200,
):
    """Smoke-test a merged model by sampling one reply and printing it.

    Args:
        model_path: Directory (or hub id) of the merged model and tokenizer.
        instruction: User instruction inserted into the
            ``### Human: ...\\n### Assistant:`` prompt template.
        max_new_tokens: Maximum number of tokens to generate.
    """
    print("🧪 Testing merged model...")
    # trust_remote_code matches how the base model was loaded during merging;
    # custom-code architectures cannot be reloaded without it.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    test_prompt = f"### Human: {instruction}\n### Assistant:"
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Strip them by token
    # count rather than by len(test_prompt): decode(encode(prompt)) is not
    # guaranteed to reproduce the prompt character-for-character.
    new_tokens = outputs[0][prompt_token_count:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print(f"\n📝 Test Prompt: {instruction}")
    print(f"📄 Generated Response:\n{generated_text}")
def main():
    """Command-line entry point: merge a LoRA adapter, then optionally test it.

    Required flags: --base_model, --lora_model, --output.
    Optional flag: --test runs a sample generation on the merged model.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA weights with base model")
    required_flags = (
        ("--base_model", "Path to base model"),
        ("--lora_model", "Path to LoRA model (training output)"),
        ("--output", "Output path for merged model"),
    )
    for flag, help_text in required_flags:
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument("--test", action="store_true", help="Test the merged model")
    cli = parser.parse_args()

    merge_lora_model(
        base_model_path=cli.base_model,
        lora_model_path=cli.lora_model,
        output_path=cli.output,
    )

    # Generation check is opt-in.
    if not cli.test:
        return
    test_merged_model(cli.output)
if __name__ == "__main__":
    import os
    import sys

    if len(sys.argv) > 1:
        # Command-line arguments were given: honour them (including -h/--help)
        # through the argparse entry point.
        main()
    else:
        # No arguments: print example usage, then run the default configuration.
        script_name = os.path.basename(sys.argv[0])
        print("📋 Merge LoRA Model Script")
        print("\n使用方法:")
        print(
            f"python {script_name} --base_model microsoft/DialoGPT-medium "
            "--lora_model ./results --output ./merged_model --test"
        )
        print("\n或者直接运行默认配置:")
        merge_lora_model(
            base_model_path="microsoft/DialoGPT-medium",  # replace with the actual OpenAI OSS 120B model
            lora_model_path="./results",
            output_path="./merged_model",
        )
        # Sanity-check the freshly merged model.
        test_merged_model("./merged_model")