docs: replace KNN references with generic predictor terminology
README.md
CHANGED
@@ -4,7 +4,7 @@ tags:
 - llm-routing
 - model-selection
 - budget-optimization
--
+- nearest-neighbor
 language:
 - en
 library_name: sklearn
@@ -104,7 +104,7 @@ sys.path.insert(0, path)
 
 from router import R2Router
 
-# Train
+# Train predictors with custom hyperparameters
 router = R2Router.from_training_data(path, k=80, lambda_val=0.999)
 ```
 
@@ -127,8 +127,8 @@ Input Query
 [1] Embed with Qwen3-0.6B -> 1024-dim vector
 
 [2] For each (model, budget) pair:
--
--
+- Predict quality (accuracy)
+- Predict output token count
 - Compute risk = (1-lambda) * quality - lambda * cost
 
 [3] Select (model, budget) with highest risk
@@ -155,10 +155,10 @@ Output: (model_name, token_budget)
 
 | Parameter | Value |
 |-----------|-------|
-|
+| K (neighbors) | 80 |
 | Lambda | 0.999 |
 | Distance Metric | Cosine |
-|
+| Weights | Distance-weighted |
 | Embedding Dim | 1024 |
 
 ## Repository Contents
@@ -170,8 +170,8 @@ training_data/
 embeddings.npy # Sub_10 training embeddings (809 x 1024)
 labels.json # Per-(model, budget) accuracy & token labels
 checkpoints/
-quality_knn_*.joblib # Pre-fitted
-token_knn_*.joblib # Pre-fitted
+quality_knn_*.joblib # Pre-fitted quality predictors (18 total)
+token_knn_*.joblib # Pre-fitted token predictors (6 total)
 ```
 
 ### Ways to Use
@@ -181,14 +181,14 @@ checkpoints/
 | `route_text()` + vLLM server | Yes (server) | Start `vllm serve` once, route from anywhere via HTTP |
 | `route_text()` + local vLLM | Yes (local) | Auto-loads Qwen3-0.6B on first call, caches it |
 | `route(embedding)` | No | Route from pre-computed 1024-dim embedding |
-| `from_training_data(path)` | No | Train your own
+| `from_training_data(path)` | No | Train your own predictors with custom hyperparameters |
 
 ## Training Details
 
 - **Training Data**: RouterArena sub_10 split (809 queries, 10% of full 8,400)
-- **Method**:
+- **Method**: Nearest-neighbor regression with cosine distance, distance-weighted
 - **Evaluation**: Full 8,400 RouterArena queries (no data leakage)
-- **Training Time**: < 1 second
+- **Training Time**: < 1 second
 
 ## Citation
 
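
For readers of this diff, the routing loop described in the hunk starting at README line 127 (predict quality, predict output token count, compute risk, select the highest) can be sketched in plain Python. This is an illustrative sketch only, not the `R2Router` implementation. The helpers `knn_predict` and `route`, the `price_per_token` table, the toy data, and the assumption that cost equals predicted tokens times a per-token price are all hypothetical. Only `k`, `lambda_val`, cosine distance, distance weighting, and the risk formula come from the README text.

```python
import math

def cosine_distance(a, b):
    """Cosine distance between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def knn_predict(query, train_x, train_y, k):
    """Distance-weighted nearest-neighbor regression (hypothetical helper)."""
    nearest = sorted(
        (cosine_distance(query, x), y) for x, y in zip(train_x, train_y)
    )[:k]
    # If the query coincides with training points, average their labels
    # instead of dividing by a zero distance.
    exact = [y for d, y in nearest if d < 1e-12]
    if exact:
        return sum(exact) / len(exact)
    weights = [1.0 / d for d, _ in nearest]
    return sum(w * y for w, (_, y) in zip(weights, nearest)) / sum(weights)

def route(query, train_x, quality_labels, token_labels, price_per_token,
          k=80, lambda_val=0.999):
    """Return the (model, budget) pair with the highest risk score."""
    best_pair, best_score = None, -math.inf
    for pair, q_labels in quality_labels.items():
        model, _budget = pair
        quality = knn_predict(query, train_x, q_labels, k)
        tokens = knn_predict(query, train_x, token_labels[pair], k)
        # Assumption: cost = predicted tokens * per-token price.
        cost = tokens * price_per_token[model]
        score = (1 - lambda_val) * quality - lambda_val * cost
        if score > best_score:
            best_pair, best_score = pair, score
    return best_pair

# Toy data: three 2-dim "embeddings" and two hypothetical (model, budget) pairs.
train_x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
quality_labels = {("small", 256): [0.2, 0.9, 0.5],
                  ("large", 256): [0.9, 0.9, 0.9]}
token_labels = {("small", 256): [100, 100, 100],
                ("large", 256): [200, 200, 200]}
price_per_token = {"small": 1e-6, "large": 1e-5}
query = [0.9, 0.1]

quality_first = route(query, train_x, quality_labels, token_labels,
                      price_per_token, k=2, lambda_val=0.5)
cost_first = route(query, train_x, quality_labels, token_labels,
                   price_per_token, k=2, lambda_val=0.999)
```

On this toy data, `lambda_val=0.5` selects the more accurate pair, while the README's `lambda_val=0.999` weights cost heavily enough to select the cheaper one.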