Spaces:
Runtime error
Runtime error
File size: 2,937 Bytes
5fc6e5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from abc import ABC, abstractmethod
import os
import shutil
from typing import Any

from loguru import logger
import mlflow
from numpy import ndarray

class BaseModel(ABC):
    """
    Abstract base class for training models.

    Subclasses build the concrete model in ``setup_model`` and implement
    ``train``, ``evaluate`` and ``predict``. Persistence (``save``/``load``)
    is shared and goes through the ``mlflow.sklearn`` flavor, so ``self.model``
    must be an object that flavor can serialize (NOTE(review): confirm for any
    non-sklearn subclass).

    Attributes:
        language (str): Language the model is built for; it is appended to the
            saved model directory name.
        model: The underlying model object, or None until one is set up or loaded.
    """

    def __init__(self, language, path=None):
        """
        Initialize the model wrapper.

        Args:
            language (str): Language for the model.
            path (str, optional): Local path or MLflow model URI of a
                pre-trained model to load. Defaults to None (any falsy value
                behaves the same), in which case ``setup_model`` is called to
                build a new model instead.
        """
        self.language = language
        self.model = None
        if path:
            self.load(path)
        else:
            self.setup_model()

    @abstractmethod
    def setup_model(self):
        """
        Initialize or build a new model.

        Called from ``BaseModel.__init__`` when no ``path`` is given.
        Implementations should assign the built model to ``self.model``
        (``save`` refuses to run while it is None). Because this runs inside
        the base initializer, attributes a subclass sets after calling
        ``super().__init__()`` are not yet available here.
        """
        pass

    @abstractmethod
    def train(self, X_train, y_train) -> dict[str, Any]:
        """
        Main training logic for the model.

        Args:
            X_train: Input training data.
            y_train: True labels for training data.

        Returns:
            dict[str, Any]: Subclass-defined results of training
            (presumably metrics/metadata; TODO confirm per subclass).
        """
        pass

    @abstractmethod
    def evaluate(self, X_test, y_test) -> dict[str, Any]:
        """
        Evaluation logic for the model.

        Args:
            X_test: Input test data.
            y_test: True labels for test data.

        Returns:
            dict[str, Any]: Subclass-defined evaluation results
            (presumably metric name -> value; TODO confirm per subclass).
        """
        pass

    @abstractmethod
    def predict(self, X) -> ndarray:
        """
        Make predictions using the trained model.

        Args:
            X: Input data for prediction.

        Returns:
            ndarray: Predictions made by the model.
        """
        pass

    def save(self, path, model_name):
        """
        Save the model locally with the MLflow sklearn flavor and log it to MLflow.

        The model is written to ``<path>/<model_name>_<language>``. Any existing
        directory at that location is deleted first. Logging the saved directory
        as an MLflow artifact is best-effort: a failure is logged, not raised.

        Args:
            path (str): Directory in which to create the model directory.
            model_name (str): Base name for the saved model directory (without
                extension); ``_<language>`` is appended to it.

        Raises:
            ValueError: If no model has been set up or loaded (``self.model`` is None).
        """
        if self.model is None:
            raise ValueError("Model is not trained. Cannot save uninitialized model.")
        complete_path = os.path.join(path, f"{model_name}_{self.language}")
        # mlflow.sklearn.save_model will not write into an existing model
        # directory, so clear a previous save first. isdir() is already False
        # for a missing path, so no separate exists() check is needed.
        if os.path.isdir(complete_path):
            shutil.rmtree(complete_path)
        mlflow.sklearn.save_model(self.model, complete_path)
        try:
            mlflow.log_artifact(complete_path)
        except Exception as e:
            # Deliberately best-effort: the local copy is already saved, so an
            # MLflow tracking problem should not fail the whole save.
            logger.error(f"Failed to log model to MLflow: {e}")
        logger.info(f"Model saved to: {complete_path}")

    def load(self, model_path):
        """
        Load the model from a local path or an MLflow model URI into ``self.model``.

        Args:
            model_path (str): Local path or MLflow model URI to load the model from.
        """
        self.model = mlflow.sklearn.load_model(model_path)
        logger.info(f"Model loaded from: {model_path}")
|