#!/usr/bin/env python3
"""
Dwrko-M1.0 Fine-tuning Script
Fine-tune StarCoder2-3B to create your own Claude-like assistant
Optimized for 16GB RAM systems with QLoRA
"""
import os
import torch
import argparse
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb
# Dwrko-M1.0 Configuration
MODEL_NAME = "Dwrko-M1.0"
BASE_MODEL = "bigcode/starcoder2-3b" # Open coding model, perfect for Dwrko-M1.0
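# Rough sizing note: ~3B parameters; with 4-bit (NF4) weights the base model is on the
# order of 1.5-2 GB, which is what makes training on 16GB-RAM systems feasible here.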
def setup_dwrko_model(use_4bit=True):
"""Setup Mistral 7B for Dwrko-M1.0 fine-tuning"""
print(f"πŸ€– Setting up {MODEL_NAME} based on {BASE_MODEL}")
# Quantization config for memory efficiency
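    # NF4 stores weights as a 4-bit "normal float"; double quantization additionally
    # compresses the quantization constants (roughly ~0.4 bits/param saved, per the QLoRA
    # paper), while matmuls still run in float16 via bnb_4bit_compute_dtype.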
if use_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
print("βœ“ 4-bit quantization enabled for memory efficiency")
else:
bnb_config = None
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("βœ“ Tokenizer loaded and configured")
# Load model
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True
)
print("βœ“ Base model loaded successfully")
# Prepare model for k-bit training if using quantization
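    # prepare_model_for_kbit_training (from PEFT) roughly does the following: freezes the
    # base weights, upcasts the norm/output layers to fp32 for numerical stability, and
    # enables input gradients so the LoRA adapters added later can backpropagate.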
if use_4bit:
model = prepare_model_for_kbit_training(model)
print("βœ“ Model prepared for QLoRA training")
return model, tokenizer
def setup_dwrko_lora():
"""Setup LoRA configuration optimized for Dwrko-M1.0"""
lora_config = LoraConfig(
r=16, # Rank - balanced performance/memory
lora_alpha=32, # Scaling factor
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # Target all attention layers
lora_dropout=0.1, # Dropout for regularization
bias="none", # No bias training
task_type="CAUSAL_LM" # Causal language modeling
)
print("βœ“ LoRA configuration optimized for Dwrko-M1.0")
return lora_config
def prepare_dwrko_dataset(data_path, tokenizer, max_length=512):
"""Prepare dataset for Dwrko-M1.0 training"""
print(f"πŸ“š Preparing dataset for {MODEL_NAME}...")
# Load data (supporting both JSONL and text formats)
if data_path.endswith('.jsonl'):
import json
data = []
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line))
else:
# Simple text file
with open(data_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
data = [{"text": line.strip()} for line in lines if line.strip()]
def tokenize_function(examples):
# Tokenize the texts for Dwrko-M1.0
tokenized = tokenizer(
examples["text"],
truncation=True,
padding=True,
max_length=max_length,
return_tensors="pt"
)
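        # Labels mirror input_ids for causal LM training (the model shifts them internally
        # when computing the loss). Pad positions are not masked to -100 here, so padded
        # tokens also contribute to the loss.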
tokenized["labels"] = tokenized["input_ids"].clone()
return tokenized
dataset = Dataset.from_list(data)
    # Drop the raw columns (e.g. "text") so only tokenized, tensor-ready features remain
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
print(f"βœ“ Dataset prepared: {len(tokenized_dataset)} examples")
return tokenized_dataset
def main():
parser = argparse.ArgumentParser(description=f"Fine-tune {MODEL_NAME} - Your Claude-like AI Assistant")
parser.add_argument("--data", required=True, help="Path to training data")
parser.add_argument("--output_dir", default="./dwrko-m1.0", help="Output directory for Dwrko-M1.0")
parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate (2e-4 optimal for Dwrko-M1.0)")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size (1 for 16GB RAM)")
parser.add_argument("--grad_steps", type=int, default=8, help="Gradient accumulation steps")
parser.add_argument("--max_length", type=int, default=512, help="Max sequence length")
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases for monitoring")
parser.add_argument("--project_name", default="dwrko-m1.0", help="W&B project name")
parser.add_argument("--run_name", default=None, help="W&B run name")
args = parser.parse_args()
# Set run name if not provided
if args.run_name is None:
args.run_name = f"{MODEL_NAME}-training"
print("=" * 60)
print(f"πŸš€ {MODEL_NAME} Fine-tuning Started!")
print("=" * 60)
print(f"πŸ“Š Training Configuration:")
print(f" β€’ Model: {MODEL_NAME} (based on Mistral 7B)")
print(f" β€’ Epochs: {args.epochs}")
print(f" β€’ Learning Rate: {args.lr}")
print(f" β€’ Batch Size: {args.batch_size}")
print(f" β€’ Gradient Accumulation: {args.grad_steps}")
print(f" β€’ Max Length: {args.max_length}")
print(f" β€’ Output Directory: {args.output_dir}")
print("=" * 60)
# Initialize wandb if requested
if args.use_wandb:
wandb.init(
project=args.project_name,
name=args.run_name,
config=vars(args),
tags=["dwrko-m1.0", "mistral-7b", "qlora", "coding", "reasoning"]
)
print("βœ“ Weights & Biases initialized")
# Setup model and tokenizer
print("\nπŸ”§ Loading Dwrko-M1.0 base model...")
model, tokenizer = setup_dwrko_model()
# Setup LoRA
print("\n🎯 Setting up LoRA for Dwrko-M1.0...")
lora_config = setup_dwrko_lora()
model = get_peft_model(model, lora_config)
# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
trainable_percentage = 100 * trainable_params / total_params
print(f"\nπŸ“ˆ {MODEL_NAME} Parameter Statistics:")
print(f" β€’ Total parameters: {total_params:,}")
print(f" β€’ Trainable parameters: {trainable_params:,}")
print(f" β€’ Trainable percentage: {trainable_percentage:.2f}%")
# Prepare dataset
print(f"\nπŸ“š Preparing dataset for {MODEL_NAME}...")
train_dataset = prepare_dwrko_dataset(args.data, tokenizer, args.max_length)
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Training arguments optimized for Dwrko-M1.0
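    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
    # (with the defaults above: 1 * 8 = 8 sequences per optimizer step).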
training_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_steps,
learning_rate=args.lr,
num_train_epochs=args.epochs,
fp16=True, # Mixed precision for memory efficiency
gradient_checkpointing=True, # Memory optimization
dataloader_pin_memory=False, # Reduce memory usage
save_strategy="epoch", # Save every epoch
logging_steps=10, # Log every 10 steps
remove_unused_columns=False,
push_to_hub=False,
report_to="wandb" if args.use_wandb else None,
run_name=args.run_name if args.use_wandb else None,
save_total_limit=3, # Keep only 3 checkpoints
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
warmup_steps=100, # Warmup for stable training
logging_first_step=True,
optim="adamw_torch", # Optimizer
max_grad_norm=1.0, # Gradient clipping
)
# Initialize trainer
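    # With no data_collator given, Trainer falls back to DataCollatorWithPadding(tokenizer),
    # which re-pads input_ids/attention_mask per batch; labels keep the lengths produced
    # during tokenization (fine with the default per-device batch size of 1).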
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
# Start training
print(f"\nπŸŽ“ Starting {MODEL_NAME} training...")
print("=" * 60)
try:
# Train the model
trainer.train()
# Save the final model
print(f"\nπŸ’Ύ Saving {MODEL_NAME}...")
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
# Save model info
model_info = {
"model_name": MODEL_NAME,
"base_model": BASE_MODEL,
"training_args": vars(args),
"trainable_params": trainable_params,
"total_params": total_params,
"trainable_percentage": trainable_percentage
}
import json
with open(os.path.join(args.output_dir, "model_info.json"), "w") as f:
json.dump(model_info, f, indent=2)
print("=" * 60)
print(f"βœ… {MODEL_NAME} training completed successfully!")
print(f"πŸ“ Model saved to: {args.output_dir}")
print(f"🎯 Your {MODEL_NAME} is ready for coding and reasoning tasks!")
print("=" * 60)
# Instructions for next steps
print(f"\nπŸš€ Next Steps:")
print(f"1. Test your model: python test_dwrko.py --model_path {args.output_dir}")
print(f"2. Upload to HuggingFace: huggingface-cli upload {args.output_dir}/ your-username/{MODEL_NAME}")
print(f"3. Share with the community! 🌟")
except Exception as e:
print(f"\n❌ {MODEL_NAME} training failed: {str(e)}")
raise
finally:
if args.use_wandb:
wandb.finish()
if __name__ == "__main__":
main()