import logging

import wandb
from transformers import TrainerCallback


# Configure the root logger at import time: records at INFO and above go to
# both ``training.log`` and the default console stream, formatted as
# "timestamp - logger name - level - message".
# NOTE: ``basicConfig`` does nothing if the root logger already has handlers,
# so this only takes effect when logging has not been configured earlier.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII log content is written the same way on
        # every platform instead of depending on the locale default.
        logging.FileHandler('training.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)


class CustomCallback(TrainerCallback):
    """Trainer callback that writes training progress to the root logger."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log the reported values together with the current global step.

        Args:
            args: Unused here; accepted to match the callback signature.
            state: Object providing ``global_step`` for the log message.
            control: Unused here; accepted to match the callback signature.
            logs: Values to report. Nothing is logged when this is ``None``
                or empty.
            **kwargs: Ignored extra keyword arguments.
        """
        if logs:
            # Lazy %-style arguments: the message is only formatted if the
            # record is actually emitted.
            logging.info("Step %s: %s", state.global_step, logs)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Log that the epoch given by ``state.epoch`` has completed.

        Args:
            args: Unused here; accepted to match the callback signature.
            state: Object providing ``epoch`` for the log message.
            control: Unused here; accepted to match the callback signature.
            **kwargs: Ignored extra keyword arguments.
        """
        logging.info("Epoch %s completed", state.epoch)


def setup_wandb():
    """Start a Weights & Biases run for the "OrcaleSeek" project.

    Initializes the run under the placeholder entity ``"your-username"`` and
    attaches the experiment settings (learning rate, architecture name and
    dataset name) as the run's configuration.
    """
    # The settings are passed through ``config=`` instead of being assigned to
    # ``wandb.config`` afterwards: assignment only rebinds the module
    # attribute to a plain dict, so the values would not be recorded on the
    # run that ``wandb.init`` created.
    wandb.init(
        project="OrcaleSeek",
        entity="your-username",
        config={
            "learning_rate": 2e-5,
            "architecture": "OrcaleSeek",
            "dataset": "Your-Dataset",
        },
    )