Alovestocode committed
Commit a83f1cc · verified · 1 Parent(s): dc894b6

Fix ZeroGPU startup error: Move GPU decorator to request handler

Files changed (3)
  1. README.md +14 -5
  2. __pycache__/app.cpython-313.pyc +0 -0
  3. app.py +200 -39
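In short, the pattern the commit title describes, as a minimal sketch (illustrative names only, not the Space's actual module): on ZeroGPU, `@spaces.GPU` must wrap the function that runs at request time; decorating the model loader and invoking it from FastAPI's startup hook is what caused the startup error.

```python
import spaces  # Hugging Face Spaces ZeroGPU helper

_MODEL = None

def get_model():
    """Plain loader with no GPU decorator; safe to call lazily from a handler."""
    global _MODEL
    if _MODEL is None:
        _MODEL = object()  # stand-in for AutoModelForCausalLM.from_pretrained(...)
    return _MODEL

@spaces.GPU(duration=120)
def generate_with_gpu(prompt: str) -> str:
    """Request-time entry point; the GPU slice is attached only for this call."""
    model = get_model()
    return f"{type(model).__name__}: {prompt}"
```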
README.md CHANGED
@@ -20,7 +20,7 @@ endpoint via the `HF_ROUTER_API` environment variable.

  | File | Purpose |
  | ---- | ------- |
- | `app.py` | Loads the merged checkpoint on demand (tries `MODEL_REPO` first, then `router-qwen3-32b-merged`, `router-gemma3-merged`), exposes a `/v1/generate` API, and serves a small HTML console at `/gradio`. |
+ | `app.py` | Loads the merged checkpoint on demand (tries `MODEL_REPO` first, then `MODEL_FALLBACKS` or the default Gemma → Llama → Qwen order), exposes a `/v1/generate` API, and serves a small HTML console at `/gradio`. |
  | `requirements.txt` | Minimal dependency set (transformers, bitsandbytes, torch, fastapi, accelerate, sentencepiece, spaces, uvicorn). |
  | `.huggingface/spaces.yml` | Configures the Space for ZeroGPU hardware and disables automatic sleep. |

@@ -38,9 +38,14 @@ endpoint via the `HF_ROUTER_API` environment variable.
  huggingface-cli upload . Alovestocode/router-router-zero --repo-type space
  ```

- 3. **Configure secrets**
- - `MODEL_REPO` – optional override; defaults to the fallback list (`router-qwen3-32b-merged`, `router-gemma3-merged`)
- - `HF_TOKEN` – token with read access to the merged model
+ 3. **Configure secrets & variables**
+ - `HF_TOKEN` – token with read access to the merged checkpoint(s)
+ - `MODEL_REPO` – optional hard pin if you only want a single model considered
+ - `MODEL_FALLBACKS` – comma-separated preference order (defaults to `router-gemma3-merged,router-llama31-merged,router-qwen3-32b-merged`)
+ - `MODEL_LOAD_STRATEGY` – `8bit` (default), `4bit`, or `fp16`; backwards-compatible with `LOAD_IN_8BIT` / `LOAD_IN_4BIT`
+ - `MODEL_LOAD_STRATEGIES` – optional ordered fallback list (e.g. `8bit,4bit,cpu`). The loader will automatically walk this list and finally fall back to `8bit→4bit→bf16→fp16→cpu`.
+ - `SKIP_WARM_START` – set to `1` if you prefer to load lazily on the first request
+ - `ALLOW_WARM_START_FAILURE` – set to `1` to keep the container alive even if warm-up fails (the next request will retry)

  4. **Connect the main router UI**
  ```bash
@@ -66,4 +71,8 @@ Response:
  ```

  Use `HF_ROUTER_API` in the main application or the smoke-test script to validate
- that the deployed model returns the expected JSON plan.
+ that the deployed model returns the expected JSON plan. When running on ZeroGPU
+ we recommend keeping `MODEL_LOAD_STRATEGY=8bit` (or `LOAD_IN_8BIT=1`) so the
+ weights fit comfortably in the 70GB slice; if that fails the app automatically
+ degrades through 4-bit, bf16/fp16, and finally CPU mode. You can inspect the
+ active load mode via the `/` healthcheck (`strategy` field).
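For reference, a minimal client-side sketch of the behaviour described above (the Space URL is a placeholder; the request fields mirror the `GeneratePayload` model in `app.py`):

```python
import requests

BASE = "https://alovestocode-router-router-zero.hf.space"  # placeholder Space URL

# Healthcheck: `strategy` reads "pending" until the first /v1/generate call loads the model.
health = requests.get(f"{BASE}/", timeout=30).json()
print(health["model"], health["strategy"])

# The first generation request triggers the lazy ZeroGPU load, so allow a generous timeout.
resp = requests.post(
    f"{BASE}/v1/generate",
    json={"prompt": "Plan the routing for this task.", "max_new_tokens": 200},
    timeout=600,
)
print(resp.json())
```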
__pycache__/app.cpython-313.pyc CHANGED
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
 
app.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

  import os
  from functools import lru_cache
- from typing import Optional
+ from typing import List, Optional, Tuple

  import torch
  from fastapi import FastAPI, HTTPException
@@ -31,27 +31,120 @@ from transformers import (
  MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "600"))
  DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.2"))
  DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
- USE_4BIT = os.environ.get("LOAD_IN_4BIT", "1") not in {"0", "false", "False"}
- USE_8BIT = os.environ.get("LOAD_IN_8BIT", "0").lower() in {"1", "true", "yes"}
+ HF_TOKEN = os.environ.get("HF_TOKEN")

- MODEL_FALLBACKS = [
-     "Alovestocode/router-qwen3-32b-merged",
+ def _normalise_bool(value: Optional[str], *, default: bool = False) -> bool:
+     if value is None:
+         return default
+     return value.lower() in {"1", "true", "yes", "on"}
+ 
+ 
+ _strategy = os.environ.get("MODEL_LOAD_STRATEGY") or os.environ.get("LOAD_STRATEGY")
+ if _strategy:
+     _strategy = _strategy.lower().strip()
+ 
+ # Backwards compatibility flags remain available for older deployments.
+ USE_8BIT = _normalise_bool(os.environ.get("LOAD_IN_8BIT"), default=True)
+ USE_4BIT = _normalise_bool(os.environ.get("LOAD_IN_4BIT"), default=False)
+ 
+ SKIP_WARM_START = _normalise_bool(os.environ.get("SKIP_WARM_START"), default=False)
+ ALLOW_WARM_START_FAILURE = _normalise_bool(
+     os.environ.get("ALLOW_WARM_START_FAILURE"),
+     default=False,
+ )
+ 
+ 
+ def _normalise_strategy(name: Optional[str]) -> Optional[str]:
+     if not name:
+         return None
+     alias = name.lower().strip()
+     mapping = {
+         "8": "8bit",
+         "8bit": "8bit",
+         "int8": "8bit",
+         "bnb8": "8bit",
+         "llm.int8": "8bit",
+         "4": "4bit",
+         "4bit": "4bit",
+         "int4": "4bit",
+         "bnb4": "4bit",
+         "nf4": "4bit",
+         "bf16": "bf16",
+         "bfloat16": "bf16",
+         "fp16": "fp16",
+         "float16": "fp16",
+         "half": "fp16",
+         "cpu": "cpu",
+         "fp32": "cpu",
+         "full": "cpu",
+     }
+     canonical = mapping.get(alias, alias)
+     if canonical not in {"8bit", "4bit", "bf16", "fp16", "cpu"}:
+         return None
+     return canonical
+ 
+ 
+ def _strategy_sequence() -> List[str]:
+     order: List[str] = []
+     seen: set[str] = set()
+ 
+     def push(entry: Optional[str]) -> None:
+         canonical = _normalise_strategy(entry)
+         if not canonical or canonical in seen:
+             return
+         seen.add(canonical)
+         order.append(canonical)
+ 
+     push(_strategy)
+     for raw in os.environ.get("MODEL_LOAD_STRATEGIES", "").split(","):
+         push(raw)
+ 
+     # Compatibility: honour legacy boolean switches.
+     if USE_8BIT:
+         push("8bit")
+     if USE_4BIT:
+         push("4bit")
+     if not (USE_8BIT or USE_4BIT):
+         push("bf16" if torch.cuda.is_available() else "cpu")
+ 
+     for fallback in ("8bit", "4bit", "bf16", "fp16", "cpu"):
+         push(fallback)
+     return order
+ 
+ 
+ DEFAULT_MODEL_FALLBACKS: List[str] = [
      "Alovestocode/router-gemma3-merged",
+     "Alovestocode/router-llama31-merged",
+     "Alovestocode/router-qwen3-32b-merged",
  ]


+ def _candidate_models() -> List[str]:
+     explicit = os.environ.get("MODEL_REPO")
+     overrides = [
+         item.strip()
+         for item in os.environ.get("MODEL_FALLBACKS", "").split(",")
+         if item.strip()
+     ]
+     candidates: List[str] = []
+     seen = set()
+     for name in [explicit, *overrides, *DEFAULT_MODEL_FALLBACKS]:
+         if not name or name in seen:
+             continue
+         seen.add(name)
+         candidates.append(name)
+     return candidates
+ 
+ 
  def _initialise_tokenizer() -> tuple[str, AutoTokenizer]:
      errors: dict[str, str] = {}
-     candidates = []
-     explicit = os.environ.get("MODEL_REPO")
-     if explicit:
-         candidates.append(explicit)
-     for name in MODEL_FALLBACKS:
-         if name not in candidates:
-             candidates.append(name)
-     for candidate in candidates:
+     for candidate in _candidate_models():
          try:
-             tok = AutoTokenizer.from_pretrained(candidate, use_fast=False)
+             tok = AutoTokenizer.from_pretrained(
+                 candidate,
+                 use_fast=False,
+                 token=HF_TOKEN,
+             )
              print(f"Loaded tokenizer from {candidate}")
              return candidate, tok
          except Exception as exc:  # pragma: no cover - download errors
@@ -78,27 +171,81 @@ class GenerateResponse(BaseModel):


  _MODEL = None
+ ACTIVE_STRATEGY: Optional[str] = None
+ 
+ 
+ def _build_load_kwargs(strategy: str, gpu_compute_dtype: torch.dtype) -> Tuple[str, dict]:
+     """Return kwargs for `from_pretrained` using the given strategy."""
+     cuda_available = torch.cuda.is_available()
+     strategy = strategy.lower()
+     kwargs: dict = {
+         "trust_remote_code": True,
+         "low_cpu_mem_usage": True,
+         "token": HF_TOKEN,
+     }
+     if strategy == "8bit":
+         if not cuda_available:
+             raise RuntimeError("8bit loading requires CUDA availability")
+         kwargs["device_map"] = "auto"
+         kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_8bit=True,
+             llm_int8_threshold=6.0,
+         )
+         return "8bit", kwargs
+     if strategy == "4bit":
+         if not cuda_available:
+             raise RuntimeError("4bit loading requires CUDA availability")
+         kwargs["device_map"] = "auto"
+         kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=gpu_compute_dtype,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+         return "4bit", kwargs
+     if strategy == "bf16":
+         kwargs["device_map"] = "auto" if cuda_available else "cpu"
+         kwargs["torch_dtype"] = torch.bfloat16 if cuda_available else torch.float32
+         return "bf16", kwargs
+     if strategy == "fp16":
+         kwargs["device_map"] = "auto" if cuda_available else "cpu"
+         kwargs["torch_dtype"] = torch.float16 if cuda_available else torch.float32
+         return "fp16", kwargs
+     if strategy == "cpu":
+         kwargs["device_map"] = "cpu"
+         kwargs["torch_dtype"] = torch.float32
+         return "cpu", kwargs
+     raise ValueError(f"Unknown load strategy: {strategy}")


- @spaces.GPU(duration=120)
  def get_model() -> AutoModelForCausalLM:
-     global _MODEL
+     """Load the model. This function should be called within a @spaces.GPU decorated function."""
+     global _MODEL, ACTIVE_STRATEGY
      if _MODEL is None:
-         dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-         kwargs = {
-             "device_map": "auto",
-             "trust_remote_code": True,
-         }
-         if USE_8BIT:
-             kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-         elif USE_4BIT:
-             kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_4bit=True,
-                 bnb_4bit_compute_dtype=dtype,
-             )
-         else:
-             kwargs["torch_dtype"] = dtype
-         _MODEL = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs).eval()
+         compute_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+         attempts: List[Tuple[str, Exception]] = []
+         strategies = _strategy_sequence()
+         print(f"Attempting to load {MODEL_ID} with strategies: {strategies}")
+         for candidate in strategies:
+             try:
+                 label, kwargs = _build_load_kwargs(candidate, compute_dtype)
+                 print(f"Trying strategy '{label}' for {MODEL_ID} ...")
+                 model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
+                 _MODEL = model.eval()
+                 ACTIVE_STRATEGY = label
+                 print(f"Loaded {MODEL_ID} with strategy='{label}'")
+                 break
+             except Exception as exc:  # pragma: no cover - depends on runtime
+                 attempts.append((candidate, exc))
+                 print(f"Strategy '{candidate}' failed: {exc}")
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+         if _MODEL is None:
+             detail = "; ".join(f"{name}: {err}" for name, err in attempts) or "no details"
+             last_exc = attempts[-1][1] if attempts else None
+             raise RuntimeError(
+                 f"Unable to load {MODEL_ID}. Tried strategies {strategies}. Details: {detail}"
+             ) from last_exc
      return _MODEL


@@ -141,23 +288,37 @@ fastapi_app = FastAPI(title="Router Model API", version="1.0.0")

  @fastapi_app.get("/")
  def healthcheck() -> dict[str, str]:
-     return {"status": "ok", "model": MODEL_ID}
+     return {
+         "status": "ok",
+         "model": MODEL_ID,
+         "strategy": ACTIVE_STRATEGY or "pending",
+     }


  @fastapi_app.on_event("startup")
  def warm_start() -> None:
-     """Ensure the GPU reservation is established during startup."""
-     try:
-         get_model()
-     except Exception as exc:
-         # Surface the failure early so the container exits with a useful log.
-         raise RuntimeError(f"Model warm-up failed: {exc}") from exc
+     """Warm start is disabled for ZeroGPU - model loads on first request."""
+     # ZeroGPU functions decorated with @spaces.GPU cannot be called during startup.
+     # They must be called within request handlers. Skip warm start for ZeroGPU.
+     print("Warm start skipped for ZeroGPU. Model will load on first request.")
+     return
+ 
+ 
+ @spaces.GPU(duration=300)
+ def _generate_with_gpu(
+     prompt: str,
+     max_new_tokens: int = MAX_NEW_TOKENS,
+     temperature: float = DEFAULT_TEMPERATURE,
+     top_p: float = DEFAULT_TOP_P,
+ ) -> str:
+     """Generate function wrapped with ZeroGPU decorator."""
+     return _generate(prompt, max_new_tokens, temperature, top_p)


  @fastapi_app.post("/v1/generate", response_model=GenerateResponse)
  def generate_endpoint(payload: GeneratePayload) -> GenerateResponse:
      try:
-         text = _generate(
+         text = _generate_with_gpu(
              prompt=payload.prompt,
              max_new_tokens=payload.max_new_tokens or MAX_NEW_TOKENS,
              temperature=payload.temperature or DEFAULT_TEMPERATURE,