feat: add training progress prints to from_training_data()
Browse files
router.py
CHANGED
|
@@ -137,8 +137,12 @@ class R2Router:
|
|
| 137 |
with open(os.path.join(path, "training_data", "labels.json")) as f:
|
| 138 |
labels = json.load(f)
|
| 139 |
|
|
|
|
|
|
|
| 140 |
quality_knns = {}
|
| 141 |
token_knns = {}
|
|
|
|
|
|
|
| 142 |
|
| 143 |
for model_name, model_labels in labels.items():
|
| 144 |
quality_knns[model_name] = {}
|
|
@@ -154,6 +158,7 @@ class R2Router:
|
|
| 154 |
)
|
| 155 |
knn.fit(X_train[valid], acc[valid])
|
| 156 |
quality_knns[model_name][budget_name] = knn
|
|
|
|
| 157 |
|
| 158 |
if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
|
| 159 |
tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
|
|
@@ -166,6 +171,9 @@ class R2Router:
|
|
| 166 |
)
|
| 167 |
tknn.fit(X_train[valid], tok[valid])
|
| 168 |
token_knns[model_name] = tknn
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
model_prices = {
|
| 171 |
mn: cfg["output_price_per_million"]
|
|
|
|
| 137 |
with open(os.path.join(path, "training_data", "labels.json")) as f:
|
| 138 |
labels = json.load(f)
|
| 139 |
|
| 140 |
+
print(f"Training KNN (k={k}) on {len(X_train)} samples...")
|
| 141 |
+
|
| 142 |
quality_knns = {}
|
| 143 |
token_knns = {}
|
| 144 |
+
n_quality = 0
|
| 145 |
+
n_token = 0
|
| 146 |
|
| 147 |
for model_name, model_labels in labels.items():
|
| 148 |
quality_knns[model_name] = {}
|
|
|
|
| 158 |
)
|
| 159 |
knn.fit(X_train[valid], acc[valid])
|
| 160 |
quality_knns[model_name][budget_name] = knn
|
| 161 |
+
n_quality += 1
|
| 162 |
|
| 163 |
if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
|
| 164 |
tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
|
|
|
|
| 171 |
)
|
| 172 |
tknn.fit(X_train[valid], tok[valid])
|
| 173 |
token_knns[model_name] = tknn
|
| 174 |
+
n_token += 1
|
| 175 |
+
|
| 176 |
+
print(f"Trained {n_quality} quality KNNs + {n_token} token KNNs for {len(quality_knns)} models.")
|
| 177 |
|
| 178 |
model_prices = {
|
| 179 |
mn: cfg["output_price_per_million"]
|