|
import os

# Redirect all Hugging Face caches to /tmp — typically needed on hosts
# (e.g. HF Spaces / containers) where the default home directory is read-only.
# Must run BEFORE `transformers` is imported, since the library reads these
# variables at import time.
os.environ['HF_HOME'] = '/tmp/hf_home'

os.environ['HF_DATASETS_CACHE'] = '/tmp/hf_datasets_cache'

# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME; kept here for older versions — confirm
# the installed version before removing.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
|
|
|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import uvicorn |
|
|
|
# Hugging Face Hub repo id of the text-to-SQL seq2seq model.
MODEL_NAME = "16pramodh/t2s_model"

# Load tokenizer and model once at module import so every request
# reuses the same in-memory instances (downloads on first run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


# ASGI application object served by uvicorn below.
app = FastAPI()
|
|
|
class QueryRequest(BaseModel):
    """Request body for /predict: the natural-language text to convert."""

    # Natural-language query to be translated into SQL.
    text: str
|
|
|
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the service is up."""
    payload = {"status": "running"}
    return payload
|
|
|
@app.post("/predict")
def predict(request: QueryRequest):
    """Translate the request's natural-language text into SQL.

    Tokenizes the input, runs the seq2seq model (capped at 256 tokens of
    generated output), and returns the decoded result under the "sql" key.
    """
    encoded = tokenizer(request.text, return_tensors="pt")
    generated = model.generate(**encoded, max_length=256)
    sql_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"sql": sql_text}
|
|
|
if __name__ == "__main__":
    # Run the ASGI app directly (port 7860 is the conventional HF Spaces port);
    # bind to all interfaces so the container port mapping works.
    uvicorn.run(app, host="0.0.0.0", port=7860)