sqlspace / app.py
16pramodh's picture
changes to app.py
eac2bba
raw
history blame contribute delete
891 Bytes
import os

# Redirect every Hugging Face cache directory under /tmp. This is assigned
# before `transformers` is imported further down, presumably because the HF
# libraries resolve their cache locations when first imported; /tmp is
# presumably used because it is writable in the hosting container — confirm.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; confirm against the installed version before removing it.
os.environ.update(
    {
        "HF_HOME": "/tmp/hf_home",
        "HF_DATASETS_CACHE": "/tmp/hf_datasets_cache",
        "TRANSFORMERS_CACHE": "/tmp/cache",
    }
)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import uvicorn
# Hugging Face Hub repository id of the seq2seq checkpoint being served.
MODEL_NAME = "16pramodh/t2s_model"

# Tokenizer first, then model: both are loaded once at import time and the
# resulting objects are shared by every request this process handles.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)

# ASGI application that the route decorators in this module register on.
app = FastAPI()
class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Validated by pydantic before the handler runs; a body that lacks the
    ``text`` field fails validation.
    """

    # Natural-language input that is tokenized and fed to the model; the
    # endpoint returns the decoded generation under the key "sql".
    text: str
@app.get("/")
def read_root():
    """Answer ``GET /`` with a fixed payload so callers can see the app is up."""
    return dict(status="running")
@app.post("/predict")
def predict(request: QueryRequest):
    """Run the seq2seq model on ``request.text`` and return the decoded output.

    Returns:
        ``{"sql": str}``: the first generated sequence, decoded with special
        tokens stripped.

    Declared with a plain ``def`` (not ``async def``), so FastAPI executes it in
    its worker threadpool and the blocking ``generate`` call does not stall the
    event loop.
    """
    # request.text is caller-controlled and of unbounded length. truncation=True
    # clips the encoded input to tokenizer.model_max_length. Without it, an
    # oversized payload can exceed the position limit of models with fixed
    # position embeddings (an IndexError -> HTTP 500) and makes encoder cost
    # unbounded.
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True)
    # max_length bounds the length of the generated sequence.
    outputs = model.generate(**inputs, max_length=256)
    return {"sql": tokenizer.decode(outputs[0], skip_special_tokens=True)}
if __name__ == "__main__":
    # Only when the file is executed directly (e.g. `python app.py`): serve the
    # app with uvicorn. Host 0.0.0.0 listens on all interfaces; 7860 is
    # presumably the port the hosting platform forwards to — confirm.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)