File size: 891 Bytes
ebf3750 f4d4584 33abc82 9be26a7 a2f6b55 9be26a7 33abc82 9be26a7 33abc82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import os
# Redirect the Hugging Face cache locations to directories under /tmp.
# NOTE(review): this runs before the `transformers` import on purpose,
# presumably so the library picks the paths up when it is imported —
# keep this ordering.
os.environ.update(
    {
        "HF_HOME": "/tmp/hf_home",
        "HF_DATASETS_CACHE": "/tmp/hf_datasets_cache",
        "TRANSFORMERS_CACHE": "/tmp/cache",
    }
)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import uvicorn
# Identifier of the pretrained seq2seq checkpoint served by this app
# (looks like a Hugging Face Hub repo id — TODO confirm).
MODEL_NAME = "16pramodh/t2s_model"
# The tokenizer and model are loaded once, when this module is imported,
# and the same objects are reused by every call to the /predict handler.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Application object on which the route decorators in this file register
# their handlers; also the object passed to uvicorn in the __main__ guard.
app = FastAPI()
class QueryRequest(BaseModel):
    """Request body accepted by the ``POST /predict`` endpoint."""

    # Input text; the handler passes it to the tokenizer unchanged.
    text: str
@app.get("/")
def read_root():
    """Return a fixed status payload so callers can check the service is up."""
    payload = {"status": "running"}
    return payload
@app.post("/predict")
def predict(request: QueryRequest):
    """Run the model on ``request.text`` and return the result under ``"sql"``.

    The text is tokenized into PyTorch tensors (``return_tensors="pt"``),
    passed to ``model.generate`` with ``max_length=256``, and the first
    generated sequence is decoded with special tokens skipped.

    Args:
        request: Parsed request body carrying the input ``text``.

    Returns:
        A dict with a single ``"sql"`` key holding the decoded string.
    """
    encoded = tokenizer(request.text, return_tensors="pt")
    generated = model.generate(**encoded, max_length=256)
    # Only the first sequence of the generation output is used.
    sql_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"sql": sql_text}
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )