"""Minimal FastAPI service that translates natural-language text to SQL.

Loads the ``16pramodh/t2s_model`` sequence-to-sequence checkpoint once at
import time and exposes it through a ``POST /predict`` endpoint.
"""

import os

# Point the Hugging Face caches at /tmp. These must be set *before*
# ``transformers`` is imported below, since the HF libraries resolve their
# cache paths when they are imported.
# NOTE(review): presumably /tmp is used because the default cache directory
# is not writable in the deployment container — confirm.
os.environ['HF_HOME'] = '/tmp/hf_home'
os.environ['HF_DATASETS_CACHE'] = '/tmp/hf_datasets_cache'
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favour
# of HF_HOME; it is kept so the model cache location stays unchanged.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import uvicorn

MODEL_NAME = "16pramodh/t2s_model"
# Upper bound (in tokens) on the length of the generated SQL sequence.
MAX_OUTPUT_LENGTH = 256

# Loaded once at import so every request reuses the same tokenizer/weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

app = FastAPI()


class QueryRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # natural-language question to translate into SQL


@app.get("/")
def read_root():
    """Root endpoint: report that the service is running."""
    return {"status": "running"}


@app.post("/predict")
def predict(request: QueryRequest):
    """Generate SQL for ``request.text`` with the loaded seq2seq model.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs the
    blocking tokenize/generate calls in its worker thread pool rather than
    on the event loop.

    Returns:
        ``{"sql": ...}`` holding the decoded first generated sequence, with
        special tokens stripped.
    """
    # truncation=True clips inputs longer than the tokenizer's
    # model_max_length. Without it, an arbitrarily long request is fed to the
    # model uncapped, which (depending on the architecture) can fail at
    # inference time or consume unbounded memory/CPU.
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True)
    outputs = model.generate(**inputs, max_length=MAX_OUTPUT_LENGTH)
    return {"sql": tokenizer.decode(outputs[0], skip_special_tokens=True)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)