OrcaleSeek / train.py
prelington's picture
Create train.py
8ea3de0 verified
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding
)
from datasets import load_dataset
import torch
def train_model(
    model_name="your-username/your-model-name",
    dataset_name="imdb",
    output_dir="./results",
    save_dir="./fine-tuned-model",
):
    """Fine-tune a sequence-classification model and save the result.

    Loads the tokenizer and model identified by ``model_name``, tokenizes
    the ``"text"`` column of the dataset, trains for three epochs with the
    Hugging Face ``Trainer`` (evaluating and checkpointing every epoch),
    then writes the final model and tokenizer to ``save_dir``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
            The default is a placeholder and must be replaced with a real
            model before running.
        dataset_name: Dataset id passed to ``load_dataset``. The dataset
            must expose ``"train"`` and ``"test"`` splits and a ``"text"``
            column, because those names are used below.
        output_dir: Directory handed to ``TrainingArguments`` as its
            ``output_dir``.
        save_dir: Directory that receives the fine-tuned model and the
            tokenizer files.
    """
    # Local import: only needed for the library-version checks below.
    import inspect

    # Load the model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Load the dataset (every available split is loaded and tokenized).
    dataset = load_dataset(dataset_name)

    def tokenize_function(examples):
        # No padding here: DataCollatorWithPadding handles it at batch time.
        return tokenizer(examples["text"], truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    # Newer Transformers releases renamed `evaluation_strategy` to
    # `eval_strategy`; choose whichever keyword the installed version accepts
    # so the script does not fail with a TypeError on either side of the
    # rename.
    training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in training_arg_names:
        eval_strategy_key = "eval_strategy"
    else:
        eval_strategy_key = "evaluation_strategy"

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
        save_strategy="epoch",
        load_best_model_at_end=True,
        **{eval_strategy_key: "epoch"},
    )

    # Same story for Trainer: `tokenizer` was superseded by
    # `processing_class` in newer releases.
    trainer_arg_names = inspect.signature(Trainer.__init__).parameters
    if "processing_class" in trainer_arg_names:
        tokenizer_key = "processing_class"
    else:
        tokenizer_key = "tokenizer"

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        **{tokenizer_key: tokenizer},
    )

    # Start training
    trainer.train()

    # Save the fine-tuned model together with its tokenizer so the output
    # directory can be loaded again with `from_pretrained`.
    trainer.save_model(save_dir)
    tokenizer.save_pretrained(save_dir)
if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script,
    # not when this module is imported.
    train_model()