prelington commited on
Commit
7e13eda
·
verified ·
1 Parent(s): 15b92a1

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +35 -9
api.py CHANGED
@@ -1,32 +1,58 @@
1
  from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
- from inference import predict
 
4
 
5
- app = FastAPI(title="Your Model API", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class PredictionRequest(BaseModel):
8
  text: str
 
9
 
10
  class PredictionResponse(BaseModel):
11
- prediction: dict
12
  status: str
 
13
 
14
  @app.get("/")
15
- def read_root():
16
- return {"message": "Your Model API is running!"}
 
 
 
 
17
 
18
  @app.post("/predict", response_model=PredictionResponse)
19
- async def predict_endpoint(request: PredictionRequest):
20
  try:
21
- result = predict(request.text)
22
  return PredictionResponse(
23
  prediction=result,
24
  status="success"
25
  )
26
  except Exception as e:
27
  return PredictionResponse(
28
- prediction={},
29
  status=f"error: {str(e)}"
30
  )
31
 
32
- # Run with: uvicorn api:app --host 0.0.0.0 --port 8000
 
 
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ from transformers import pipeline
5
+ import uvicorn
6
 
7
+ # Initialize app
8
+ app = FastAPI(title="OrcaleSeek API", version="1.0.0")
9
+
10
+ # CORS for web access
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"], # Change this to your website domain
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
+
18
+ # Load model
19
+ classifier = pipeline(
20
+ "text-classification",
21
+ model="your-username/OrcaleSeek",
22
+ tokenizer="your-username/OrcaleSeek"
23
+ )
24
 
25
  class PredictionRequest(BaseModel):
26
  text: str
27
+ max_length: int = 128
28
 
29
  class PredictionResponse(BaseModel):
30
+ prediction: list
31
  status: str
32
+ model: str = "OrcaleSeek"
33
 
34
  @app.get("/")
35
+ def home():
36
+ return {"message": "OrcaleSeek API is running! 🚀"}
37
+
38
+ @app.get("/health")
39
+ def health_check():
40
+ return {"status": "healthy"}
41
 
42
  @app.post("/predict", response_model=PredictionResponse)
43
+ async def predict(request: PredictionRequest):
44
  try:
45
+ result = classifier(request.text)
46
  return PredictionResponse(
47
  prediction=result,
48
  status="success"
49
  )
50
  except Exception as e:
51
  return PredictionResponse(
52
+ prediction=[],
53
  status=f"error: {str(e)}"
54
  )
55
 
56
+ # Run with: uvicorn api:app --host 0.0.0.0 --port 8000 --reload
57
+ if __name__ == "__main__":
58
+ uvicorn.run(app, host="0.0.0.0", port=8000)