Alovestocode committed on
Commit
40a2927
·
verified ·
1 Parent(s): 7d6ddbd

Refactor: Mount Gradio on FastAPI, use gr.mount_gradio_app for proper integration

Browse files
Files changed (3) hide show
  1. README.md +4 -2
  2. app.py +21 -67
  3. test_api.py +106 -0
README.md CHANGED
@@ -20,7 +20,7 @@ endpoint via the `HF_ROUTER_API` environment variable.
20
 
21
  | File | Purpose |
22
  | ---- | ------- |
23
- | `app.py` | Loads the merged checkpoint on demand (tries `MODEL_REPO` first, then `MODEL_FALLBACKS` or the default Gemma → Llama → Qwen order), exposes a `/v1/generate` API, and serves a small HTML console at `/gradio`. |
24
  | `requirements.txt` | Minimal dependency set (transformers, bitsandbytes, torch, fastapi, accelerate, sentencepiece, spaces, uvicorn). |
25
  | `.huggingface/spaces.yml` | Configures the Space for ZeroGPU hardware and disables automatic sleep. |
26
 
@@ -75,4 +75,6 @@ that the deployed model returns the expected JSON plan. When running on ZeroGPU
75
  we recommend keeping `MODEL_LOAD_STRATEGY=8bit` (or `LOAD_IN_8BIT=1`) so the
76
  weights fit comfortably in the 70GB slice; if that fails the app automatically
77
  degrades through 4-bit, bf16/fp16, and finally CPU mode. You can inspect the
78
- active load mode via the `/` healthcheck (`strategy` field).
 
 
 
20
 
21
  | File | Purpose |
22
  | ---- | ------- |
23
+ | `app.py` | Loads the merged checkpoint on demand (tries `MODEL_REPO` first, then `MODEL_FALLBACKS` or the default Gemma → Llama → Qwen order), exposes a `/v1/generate` API, mounts the Gradio UI at `/gradio`, and keeps a lightweight HTML console at `/console`. |
24
  | `requirements.txt` | Minimal dependency set (transformers, bitsandbytes, torch, fastapi, accelerate, sentencepiece, spaces, uvicorn). |
25
  | `.huggingface/spaces.yml` | Configures the Space for ZeroGPU hardware and disables automatic sleep. |
26
 
 
75
  we recommend keeping `MODEL_LOAD_STRATEGY=8bit` (or `LOAD_IN_8BIT=1`) so the
76
  weights fit comfortably in the 70GB slice; if that fails the app automatically
77
  degrades through 4-bit, bf16/fp16, and finally CPU mode. You can inspect the
78
+ active load mode via the `/health` endpoint (`strategy` field). The root path
79
+ (`/`) now redirects to the Gradio UI, while `/console` serves the minimal HTML
80
+ form for quick manual testing.
app.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Optional, Tuple
6
 
7
  import torch
8
  from fastapi import FastAPI, HTTPException
9
- from fastapi.responses import HTMLResponse
10
  from pydantic import BaseModel
11
 
12
  try:
@@ -297,7 +297,13 @@ def _generate_with_gpu(
297
  fastapi_app = FastAPI(title="Router Model API", version="1.0.0")
298
 
299
 
300
- @fastapi_app.get("/")
 
 
 
 
 
 
301
  def healthcheck() -> dict[str, str]:
302
  return {
303
  "status": "ok",
@@ -329,7 +335,7 @@ def generate_endpoint(payload: GeneratePayload) -> GenerateResponse:
329
  return GenerateResponse(text=text)
330
 
331
 
332
- @fastapi_app.get("/gradio", response_class=HTMLResponse)
333
  def interactive_ui() -> str:
334
  return """
335
  <!doctype html>
@@ -495,8 +501,9 @@ with gr.Blocks(
495
  }
496
  ```
497
 
498
- **GET** `/` - Health check
499
- **GET** `/gradio` - Interactive UI
 
500
  """)
501
 
502
  # Event handlers
@@ -505,72 +512,19 @@ with gr.Blocks(
505
  inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
506
  outputs=output,
507
  )
508
-
509
  clear_btn.click(
510
  fn=lambda: ("", ""),
511
  outputs=[prompt_input, output],
512
  )
513
-
514
- # Add FastAPI routes using Gradio's load event
515
- # This ensures routes are added after Gradio is fully initialized
516
- def add_api_routes():
517
- """Add API routes after Gradio app is loaded."""
518
- try:
519
- from fastapi.responses import JSONResponse
520
- from starlette.routing import Route
521
-
522
- async def generate_handler(request):
523
- """Handle POST /v1/generate requests."""
524
- try:
525
- data = await request.json()
526
- payload = GeneratePayload(**data)
527
- text = _generate_with_gpu(
528
- prompt=payload.prompt,
529
- max_new_tokens=payload.max_new_tokens or MAX_NEW_TOKENS,
530
- temperature=payload.temperature or DEFAULT_TEMPERATURE,
531
- top_p=payload.top_p or DEFAULT_TOP_P,
532
- )
533
- return JSONResponse(content={"text": text})
534
- except Exception as exc:
535
- from fastapi import HTTPException
536
- raise HTTPException(status_code=500, detail=str(exc))
537
-
538
- async def healthcheck_handler(request):
539
- """Handle GET /api/health requests."""
540
- return JSONResponse(content={
541
- "status": "ok",
542
- "model": MODEL_ID,
543
- "strategy": ACTIVE_STRATEGY or "pending",
544
- })
545
-
546
- async def gradio_ui_handler(request):
547
- """Handle GET /api/gradio requests."""
548
- return HTMLResponse(interactive_ui())
549
-
550
- # Add routes using Route objects
551
- gradio_app.app.router.routes.append(
552
- Route("/v1/generate", generate_handler, methods=["POST"])
553
- )
554
- gradio_app.app.router.routes.append(
555
- Route("/api/health", healthcheck_handler, methods=["GET"])
556
- )
557
- gradio_app.app.router.routes.append(
558
- Route("/api/gradio", gradio_ui_handler, methods=["GET"])
559
- )
560
- gradio_app.app.router.routes.append(
561
- Route("/gradio", gradio_ui_handler, methods=["GET"])
562
- )
563
- print("FastAPI routes added successfully via load event")
564
- except Exception as e:
565
- print(f"Warning: Could not add FastAPI routes: {e}")
566
- import traceback
567
- traceback.print_exc()
568
-
569
- # Use load event to add routes after app initialization
570
- gradio_app.load(add_api_routes)
571
 
572
- # Set app to Gradio Blocks for Spaces - ZeroGPU requires Gradio SDK
573
- app = gradio_app
 
 
 
574
 
575
  if __name__ == "__main__": # pragma: no cover
576
- app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
 
 
6
 
7
  import torch
8
  from fastapi import FastAPI, HTTPException
9
+ from fastapi.responses import HTMLResponse, RedirectResponse
10
  from pydantic import BaseModel
11
 
12
  try:
 
297
  fastapi_app = FastAPI(title="Router Model API", version="1.0.0")
298
 
299
 
300
@fastapi_app.get("/", response_class=RedirectResponse)
def root_redirect() -> RedirectResponse:
    """Send visitors of the root path to the Gradio UI.

    Returns:
        A 307 (temporary) redirect whose target is ``/gradio``.
    """
    gradio_path = "/gradio"
    return RedirectResponse(url=gradio_path, status_code=307)
304
+
305
+
306
+ @fastapi_app.get("/health")
307
  def healthcheck() -> dict[str, str]:
308
  return {
309
  "status": "ok",
 
335
  return GenerateResponse(text=text)
336
 
337
 
338
+ @fastapi_app.get("/console", response_class=HTMLResponse)
339
  def interactive_ui() -> str:
340
  return """
341
  <!doctype html>
 
501
  }
502
  ```
503
 
504
+ **GET** `/health` - JSON health check
505
+ **GET** `/gradio` - Full Gradio UI
506
+ **GET** `/console` - Minimal HTML console
507
  """)
508
 
509
  # Event handlers
 
512
  inputs=[prompt_input, max_tokens_input, temp_input, top_p_input],
513
  outputs=output,
514
  )
515
+
516
  clear_btn.click(
517
  fn=lambda: ("", ""),
518
  outputs=[prompt_input, output],
519
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
 
521
+ # Enable queued execution so ZeroGPU can schedule GPU work reliably
522
+ gradio_app.queue(max_size=8)
523
+
524
+ # Mount the Gradio UI onto the FastAPI app (served under /gradio)
525
+ app = gr.mount_gradio_app(fastapi_app, gradio_app, path="/gradio")
526
 
527
  if __name__ == "__main__": # pragma: no cover
528
+ import uvicorn
529
+
530
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
test_api.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Test script for Router API endpoints."""
3
+
4
+ import requests
5
+ import json
6
+ import time
7
+ import sys
8
+
9
+ BASE_URL = "https://Alovestocode-router-router-zero.hf.space"
10
+
11
def test_healthcheck():
    """Probe ``GET /health`` on the deployed Space.

    Prints the HTTP status, then either the pretty-printed JSON body (on
    200) or the raw response text (on any other status). Any exception
    raised along the way is printed and reported as a failure.

    Returns:
        bool: True only when the endpoint answered 200 and its body could
        be decoded and re-serialized as JSON; False otherwise.
    """
    print("Testing GET /health...")
    try:
        reply = requests.get(f"{BASE_URL}/health", timeout=10)
        print(f"Status: {reply.status_code}")

        # Guard clause: anything other than 200 is a failed check.
        if reply.status_code != 200:
            print(f"Error: {reply.text}")
            return False

        body = reply.json()
        print(f"Response: {json.dumps(body, indent=2)}")
        return True
    except Exception as exc:
        print(f"Exception: {exc}")
        return False
26
+
27
def test_generate():
    """Exercise ``POST /v1/generate`` with a small fixed prompt.

    Prints the HTTP status and, on 200, the keys of the JSON reply plus
    either a 200-character preview of its ``text`` field or the full
    reply when that field is absent. Any exception is printed and
    reported as a failure.

    Returns:
        bool: True when the endpoint answered 200 and the reply could be
        inspected; False on any other status or on an exception.
    """
    print("\nTesting POST /v1/generate...")
    try:
        request_body = {
            "prompt": "You are a router agent. User query: What is 2+2?",
            "max_new_tokens": 100,
            "temperature": 0.2,
            "top_p": 0.9,
        }
        endpoint = f"{BASE_URL}/v1/generate"
        reply = requests.post(
            endpoint,
            json=request_body,
            headers={"Content-Type": "application/json"},
            timeout=60,  # Longer timeout for model loading
        )
        print(f"Status: {reply.status_code}")

        # Guard clause: anything other than 200 is a failed check.
        if reply.status_code != 200:
            print(f"Error: {reply.text}")
            return False

        result = reply.json()
        print(f"Response keys: {list(result.keys())}")
        if "text" not in result:
            print(f"Full response: {json.dumps(result, indent=2)}")
        else:
            preview = result["text"][:200]
            print(f"Generated text (first 200 chars): {preview}...")
        return True
    except Exception as exc:
        print(f"Exception: {exc}")
        return False
58
+
59
def test_gradio_ui():
    """Fetch ``GET /gradio`` and report whether it answered with HTTP 200.

    On 200 the body length and ``content-type`` header are printed; on
    any other status the first 200 characters of the body are printed.
    Any exception is printed and reported as a failure.

    Returns:
        bool: True when the endpoint answered 200; False otherwise.
    """
    print("\nTesting GET /gradio (UI redirect target)...")
    try:
        reply = requests.get(f"{BASE_URL}/gradio", timeout=10)
        print(f"Status: {reply.status_code}")

        # Guard clause: anything other than 200 is a failed check.
        if reply.status_code != 200:
            print(f"Error: {reply.text[:200]}")
            return False

        body_length = len(reply.text)
        print(f"Response length: {body_length} chars")
        content_type = reply.headers.get("content-type", "unknown")
        print(f"Response type: {content_type}")
        return True
    except Exception as exc:
        print(f"Exception: {exc}")
        return False
75
+
76
def main():
    """Run every endpoint check, print a summary, and exit.

    The process exits with status 0 when all checks passed and 1 when at
    least one failed.
    """
    banner = "=" * 60

    print(banner)
    print("Router API Test Suite")
    print(banner)
    print(f"Base URL: {BASE_URL}\n")

    # Wait a moment for Space to be ready
    print("Waiting 5 seconds for Space to be ready...")
    time.sleep(5)

    # The checks run in this order: health, generate, Gradio UI.
    results = [
        ("Health Check", test_healthcheck()),
        ("Generate", test_generate()),
        ("Gradio UI", test_gradio_ui()),
    ]

    print("\n" + banner)
    print("Test Summary")
    print(banner)
    for label, ok in results:
        verdict = "✅ PASS" if ok else "❌ FAIL"
        print(f"{label}: {verdict}")

    exit_code = 0 if all(ok for _, ok in results) else 1
    sys.exit(exit_code)
104
+
105
+ if __name__ == "__main__":
106
+ main()