File size: 1,822 Bytes
96b6673 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import numpy as np
from numpy.typing import NDArray
from typing import Tuple
from abc import ABC, abstractmethod
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Lasso
from sklearn.pipeline import make_pipeline
class BaseSolver(ABC):
"""
A base solver class.
Methods:
fit(self, masks: NDArray, outputs: NDArray, num_output_tokens: int) -> Tuple[NDArray, NDArray]:
Fit the solver to the given data.
"""
@abstractmethod
def fit(
self, masks: NDArray, outputs: NDArray, num_output_tokens: int
) -> Tuple[NDArray, NDArray]: ...
class LassoRegression(BaseSolver):
"""
A LASSO solver using the scikit-learn library.
Attributes:
lasso_alpha (float):
The alpha parameter for the LASSO regression. Defaults to 0.01.
Methods:
fit(self, masks: NDArray, outputs: NDArray, num_output_tokens: int) -> Tuple[NDArray, NDArray]:
Fit the solver to the given data.
"""
def __init__(self, lasso_alpha: float = 0.01) -> None:
self.lasso_alpha = lasso_alpha
def fit(
self, masks: NDArray, outputs: NDArray, num_output_tokens: int
) -> Tuple[NDArray, NDArray]:
X = masks.astype(np.float32)
Y = outputs / num_output_tokens
scaler = StandardScaler()
lasso = Lasso(alpha=self.lasso_alpha, random_state=0, fit_intercept=True)
# Pipeline is ((X - scaler.mean_) / scaler.scale_) @ lasso.coef_.T + lasso.intercept_
pipeline = make_pipeline(scaler, lasso)
pipeline.fit(X, Y)
# Rescale back to original scale
weight = lasso.coef_ / scaler.scale_
bias = lasso.intercept_ - (scaler.mean_ / scaler.scale_) @ lasso.coef_.T
return weight * num_output_tokens, bias * num_output_tokens
|