JiaqiXue commited on
Commit
9e08fe1
·
verified ·
1 Parent(s): 8ac636e

feat: add training progress prints to from_training_data()

Browse files
Files changed (1) hide show
  1. router.py +8 -0
router.py CHANGED
@@ -137,8 +137,12 @@ class R2Router:
137
  with open(os.path.join(path, "training_data", "labels.json")) as f:
138
  labels = json.load(f)
139
 
 
 
140
  quality_knns = {}
141
  token_knns = {}
 
 
142
 
143
  for model_name, model_labels in labels.items():
144
  quality_knns[model_name] = {}
@@ -154,6 +158,7 @@ class R2Router:
154
  )
155
  knn.fit(X_train[valid], acc[valid])
156
  quality_knns[model_name][budget_name] = knn
 
157
 
158
  if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
159
  tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
@@ -166,6 +171,9 @@ class R2Router:
166
  )
167
  tknn.fit(X_train[valid], tok[valid])
168
  token_knns[model_name] = tknn
 
 
 
169
 
170
  model_prices = {
171
  mn: cfg["output_price_per_million"]
 
137
  with open(os.path.join(path, "training_data", "labels.json")) as f:
138
  labels = json.load(f)
139
 
140
+ print(f"Training KNN (k={k}) on {len(X_train)} samples...")
141
+
142
  quality_knns = {}
143
  token_knns = {}
144
+ n_quality = 0
145
+ n_token = 0
146
 
147
  for model_name, model_labels in labels.items():
148
  quality_knns[model_name] = {}
 
158
  )
159
  knn.fit(X_train[valid], acc[valid])
160
  quality_knns[model_name][budget_name] = knn
161
+ n_quality += 1
162
 
163
  if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
164
  tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
 
171
  )
172
  tknn.fit(X_train[valid], tok[valid])
173
  token_knns[model_name] = tknn
174
+ n_token += 1
175
+
176
+ print(f"Trained {n_quality} quality KNNs + {n_token} token KNNs for {len(quality_knns)} models.")
177
 
178
  model_prices = {
179
  mn: cfg["output_price_per_million"]