# Source: Hanrui/SpecForge, specforge/tracker.py
# Last commit by Lekr0: "Add files using upload-large-folder tool" (2d67aa6, verified)
# tracker.py
import abc
import netrc
import numbers
import os
from typing import Any, Dict, Optional, Type

import torch.distributed as dist
# --- Optional tracker backends ---
# Every backend below is imported eagerly when this module is loaded, but each
# one is optional: if its package is missing, the name is bound to None and the
# matching tracker's validate_args reports an install hint instead of this
# module failing to import.
try:
    import wandb
except ImportError:
    wandb = None

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None

try:
    import swanlab
except ImportError:
    swanlab = None

try:
    import mlflow
except ImportError:
    mlflow = None
# --- End optional tracker backends ---
class Tracker(abc.ABC):
    """
    Abstract base class for experiment trackers.

    Each tracker implementation handles its own initialization, logging and
    cleanup, and provides a ``validate_args`` class method that checks the
    command-line arguments before the tracker is constructed.

    Attributes set by ``__init__``:
        args: The parsed argument namespace passed by the caller.
        output_dir: Directory passed by the caller for tracker output.
        rank: Rank of this process in the default torch.distributed process
            group, or 0 when torch.distributed is unavailable or has not been
            initialized.
        is_initialized: False until a subclass finishes its own setup.
    """

    def __init__(self, args, output_dir: str):
        self.args = args
        self.output_dir = output_dir
        # dist.get_rank() raises when no default process group exists, so a
        # single-process run (no init_process_group call) is treated as rank 0
        # instead of crashing tracker construction.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
        else:
            self.rank = 0
        self.is_initialized = False

    @classmethod
    @abc.abstractmethod
    def validate_args(cls, parser, args) -> None:
        """
        Validate the arguments this tracker needs.

        Called during argument parsing. Implementations report problems via
        ``parser.error`` and may fill in missing values on ``args`` (for
        example from environment variables).
        """

    @abc.abstractmethod
    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None) -> None:
        """
        Log metrics to the tracker.

        Args:
            log_dict: Mapping of metric name to value.
            step: Optional step to associate with the metrics.
        """

    @abc.abstractmethod
    def close(self) -> None:
        """
        Close the tracker and clean up its resources.
        """
class NoOpTracker(Tracker):
    """Tracker selected by ``--report-to none``; every operation does nothing."""

    @classmethod
    def validate_args(cls, parser, args):
        """Accept any arguments: this tracker has nothing to configure."""
        return None

    def __init__(self, args, output_dir: str):
        super().__init__(args, output_dir)
        # There is no backend to set up, so the tracker is ready right away.
        self.is_initialized = True

    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
        """Discard ``log_dict``."""
        return None

    def close(self):
        """Nothing to release."""
        return None
class WandbTracker(Tracker):
    """Tracks experiments using Weights & Biases."""

    # Argument names whose values are credentials. Their values are masked
    # before the argument namespace is uploaded as the run config.
    _SECRET_ARG_NAMES = ("wandb_key", "swanlab_key")

    @staticmethod
    def _default_wandb_dir() -> str:
        """Return the default run directory, ``<project root>/wandb``."""
        # specforge/tracker.py -> project root is one level up
        return os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "wandb"))

    @staticmethod
    def _read_netrc_key() -> Optional[str]:
        """Return the ``api.wandb.ai`` password stored in ``~/.netrc``, or None.

        An absent, unreadable or malformed file counts as "no key", so that
        ``validate_args`` can still report its actionable error message.
        """
        netrc_path = os.path.expanduser("~/.netrc")
        try:
            if not os.path.exists(netrc_path):
                return None
            netrc_file = netrc.netrc(netrc_path)
        except (OSError, netrc.NetrcParseError):
            return None
        if "api.wandb.ai" not in netrc_file.hosts:
            return None
        _, _, password = netrc_file.authenticators("api.wandb.ai")
        return password or None

    @classmethod
    def _run_config(cls, args) -> Dict[str, Any]:
        """Return ``vars(args)`` with credential values replaced by a placeholder."""
        return {
            name: (
                "<redacted>"
                if name in cls._SECRET_ARG_NAMES and value is not None
                else value
            )
            for name, value in vars(args).items()
        }

    @classmethod
    def validate_args(cls, parser, args):
        """Check that wandb is installed and resolve an API key.

        Sets ``args.wandb_dir`` to the default directory when it is None.
        Unless ``args.wandb_offline`` is set, the API key is taken from, in
        order: ``args.wandb_key``, the ``WANDB_API_KEY`` environment variable,
        then the ``api.wandb.ai`` entry in ``~/.netrc``. If none is found,
        ``parser.error`` is called.
        """
        if wandb is None:
            parser.error(
                "To use --report-to wandb, you must install wandb: 'pip install wandb'"
            )
        if args.wandb_dir is None:
            args.wandb_dir = cls._default_wandb_dir()
        if args.wandb_offline or args.wandb_key is not None:
            return
        # An empty WANDB_API_KEY cannot authenticate, so it is treated as unset.
        env_key = os.environ.get("WANDB_API_KEY")
        if env_key:
            args.wandb_key = env_key
            return
        netrc_key = cls._read_netrc_key()
        if netrc_key:
            args.wandb_key = netrc_key
            return
        parser.error(
            "When --report-to is 'wandb', you must provide a wandb API key via one of:\n"
            " 1. --wandb-key argument\n"
            " 2. WANDB_API_KEY environment variable\n"
            " 3. `wandb login` command"
        )

    def __init__(self, args, output_dir: str):
        super().__init__(args, output_dir)
        if self.rank == 0:
            # validate_args normally fills this in; repeat the default for
            # callers that construct the tracker without validating first.
            if args.wandb_dir is None:
                args.wandb_dir = self._default_wandb_dir()
            os.makedirs(args.wandb_dir, exist_ok=True)
            if not args.wandb_offline:
                wandb.login(key=args.wandb_key)
            init_kwargs = {
                "project": args.wandb_project,
                "name": args.wandb_name,
                # Never upload the API key itself as part of the run config.
                "config": self._run_config(args),
                "dir": args.wandb_dir,
            }
            if args.wandb_offline:
                init_kwargs["mode"] = "offline"
            wandb.init(**init_kwargs)
            self.is_initialized = True

    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
        """Send ``log_dict`` to the active wandb run (rank 0 only)."""
        if self.rank == 0 and self.is_initialized:
            wandb.log(log_dict, step=step)

    def close(self):
        """Finish the wandb run if this process started one."""
        if self.rank == 0 and self.is_initialized and wandb.run:
            wandb.finish()
            self.is_initialized = False
class SwanlabTracker(Tracker):
    """Tracks experiments using SwanLab."""

    # Argument names whose values are credentials. Their values are masked
    # before the argument namespace is sent to SwanLab as the experiment config.
    _SECRET_ARG_NAMES = ("wandb_key", "swanlab_key")

    @classmethod
    def _run_config(cls, args) -> Dict[str, Any]:
        """Return ``vars(args)`` with credential values replaced by a placeholder."""
        return {
            name: (
                "<redacted>"
                if name in cls._SECRET_ARG_NAMES and value is not None
                else value
            )
            for name, value in vars(args).items()
        }

    @classmethod
    def validate_args(cls, parser, args):
        """Check that swanlab is installed and resolve an API key if possible.

        The key comes from ``args.swanlab_key`` or the ``SWANLAB_API_KEY``
        environment variable. Without a key, ``parser.error`` is called only
        when torch.distributed reports a world size greater than 1.
        """
        if swanlab is None:
            parser.error(
                "To use --report-to swanlab, you must install swanlab: 'pip install swanlab'"
            )
        if args.swanlab_key is not None:
            return
        if "SWANLAB_API_KEY" in os.environ:
            args.swanlab_key = os.environ["SWANLAB_API_KEY"]
            return
        # Swanlab can run in anonymous mode if no key is provided in a
        # non-distributed env. However, a key is often required for distributed
        # runs to sync correctly.
        # NOTE(review): the world size is only visible if the default process
        # group is already initialized when validate_args runs; confirm the
        # caller initializes torch.distributed before argument validation.
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            parser.error(
                "In a distributed environment, when --report-to is 'swanlab', you must provide a swanlab API key via:\n"
                " 1. --swanlab-key argument\n"
                " 2. SWANLAB_API_KEY environment variable"
            )

    def __init__(self, args, output_dir: str):
        super().__init__(args, output_dir)
        if self.rank == 0:
            if args.swanlab_key:
                swanlab.login(api_key=args.swanlab_key)
            swanlog_dir = os.path.join(output_dir, "swanlog")
            os.makedirs(swanlog_dir, exist_ok=True)
            swanlab.init(
                project=args.swanlab_project,
                experiment_name=args.swanlab_name,
                # Never send the API key itself as part of the experiment config.
                config=self._run_config(args),
                logdir=swanlog_dir,
            )
            self.is_initialized = True

    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
        """Send ``log_dict`` to SwanLab (rank 0 only)."""
        if self.rank == 0 and self.is_initialized:
            swanlab.log(log_dict, step=step)

    def close(self):
        """Finish the SwanLab run if one is active in this process."""
        if self.rank == 0 and self.is_initialized and swanlab.get_run() is not None:
            swanlab.finish()
            self.is_initialized = False
class TensorboardTracker(Tracker):
    """Tracks experiments using TensorBoard."""

    @classmethod
    def validate_args(cls, parser, args):
        """Report an install hint if TensorBoard's SummaryWriter is unavailable."""
        if SummaryWriter is None:
            parser.error(
                "To use --report-to tensorboard, you must have tensorboard installed: 'pip install tensorboard'"
            )

    def __init__(self, args, output_dir: str):
        super().__init__(args, output_dir)
        # Only rank 0 owns a writer; other ranks keep None so that the
        # attribute exists on every instance.
        self.writer = None
        if self.rank == 0:
            log_dir = os.path.join(output_dir, "runs")
            self.writer = SummaryWriter(log_dir=log_dir)
            self.is_initialized = True

    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
        """Write each real-valued entry of ``log_dict`` as a scalar (rank 0 only).

        Any value that is a ``numbers.Real`` is written. This covers Python
        int/float/bool and numeric types registered with the ``numbers`` ABCs,
        such as numpy scalars. Other values (strings, tensors, containers) are
        skipped.
        """
        if self.rank == 0 and self.is_initialized:
            for key, value in log_dict.items():
                if isinstance(value, numbers.Real):
                    self.writer.add_scalar(key, value, global_step=step)

    def close(self):
        """Flush and close the SummaryWriter."""
        if self.rank == 0 and self.is_initialized:
            self.writer.close()
            self.is_initialized = False
class MLflowTracker(Tracker):
    """Tracks experiments using MLflow."""

    # Argument names whose values are credentials. Their values are masked
    # before the argument namespace is logged as MLflow params.
    _SECRET_ARG_NAMES = ("wandb_key", "swanlab_key")

    @classmethod
    def _run_params(cls, args) -> Dict[str, Any]:
        """Return ``vars(args)`` with credential values replaced by a placeholder."""
        return {
            name: (
                "<redacted>"
                if name in cls._SECRET_ARG_NAMES and value is not None
                else value
            )
            for name, value in vars(args).items()
        }

    @classmethod
    def validate_args(cls, parser, args):
        """Check that mlflow is installed and fill settings from the environment.

        ``args.mlflow_tracking_uri`` falls back to ``MLFLOW_TRACKING_URI`` and
        ``args.mlflow_experiment_name`` falls back to ``MLFLOW_EXPERIMENT_NAME``
        when they are not set explicitly.
        """
        if mlflow is None:
            parser.error(
                "To use --report-to mlflow, you must install mlflow: 'pip install mlflow'"
            )
        # Set tracking URI from environment variable if not explicitly provided
        if args.mlflow_tracking_uri is None and "MLFLOW_TRACKING_URI" in os.environ:
            args.mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"]
        elif args.mlflow_tracking_uri is None:
            print(
                "Warning: MLflow tracking URI not set. Defaulting to local './mlruns'."
            )
        # Set experiment name from environment variable if not explicitly provided
        if (
            args.mlflow_experiment_name is None
            and "MLFLOW_EXPERIMENT_NAME" in os.environ
        ):
            args.mlflow_experiment_name = os.environ["MLFLOW_EXPERIMENT_NAME"]

    def __init__(self, args, output_dir: str):
        super().__init__(args, output_dir)
        if self.rank == 0:
            if args.mlflow_tracking_uri:
                mlflow.set_tracking_uri(args.mlflow_tracking_uri)
            # mlflow.set_experiment requires a name (or id) and raises when
            # given None, so call it only when a name was supplied. Otherwise
            # MLflow picks the experiment itself (e.g. its default experiment).
            if args.mlflow_experiment_name:
                mlflow.set_experiment(args.mlflow_experiment_name)
            mlflow.start_run(run_name=args.mlflow_run_name)
            mlflow.log_params(self._run_params(args))
            self.is_initialized = True

    def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
        """Log the real-valued entries of ``log_dict`` as MLflow metrics (rank 0 only).

        MLflow rejects non-numeric metric values. Such entries are skipped
        rather than letting one of them raise in the middle of training.
        """
        if self.rank == 0 and self.is_initialized:
            metrics = {
                key: value
                for key, value in log_dict.items()
                if isinstance(value, numbers.Real)
            }
            if metrics:
                mlflow.log_metrics(metrics, step=step)

    def close(self):
        """End the active MLflow run."""
        if self.rank == 0 and self.is_initialized:
            mlflow.end_run()
            self.is_initialized = False
# --- Tracker Factory ---
# Registry of accepted ``report_to`` names. Each value is the Tracker subclass
# that is instantiated for that name.
TRACKER_REGISTRY: Dict[str, type] = {
    "wandb": WandbTracker,  # Weights & Biases
    "swanlab": SwanlabTracker,  # SwanLab
    "tensorboard": TensorboardTracker,  # torch.utils.tensorboard SummaryWriter
    "mlflow": MLflowTracker,  # MLflow
    "none": NoOpTracker,  # tracking disabled
}
def get_tracker_class(report_to: str) -> Optional[Type[Tracker]]:
    """Return the Tracker subclass registered under ``report_to``.

    Args:
        report_to: A key of ``TRACKER_REGISTRY``, e.g. ``"wandb"`` or ``"none"``.

    Returns:
        The tracker class itself (not an instance), or None if ``report_to``
        is not registered.
    """
    return TRACKER_REGISTRY.get(report_to)
def create_tracker(args, output_dir: str) -> Tracker:
    """Instantiate the tracker named by ``args.report_to``.

    Args:
        args: Parsed argument namespace; ``args.report_to`` selects the tracker
            and the whole namespace is passed to the tracker's constructor.
        output_dir: Directory handed to the tracker for its output.

    Raises:
        ValueError: If ``args.report_to`` does not name a registered tracker.
    """
    requested = args.report_to
    tracker_cls = get_tracker_class(requested)
    if not tracker_cls:
        raise ValueError(f"Unsupported report_to type: {requested}")
    return tracker_cls(args, output_dir)