# rahul7star's picture
# Update app.py
# 6877262 verified
import os
import io
import logging
# Redirect the Hugging Face and XDG caches to writable paths. This must run
# before any Hugging Face library is imported.
os.environ.update(
    {
        "HF_HOME": "/app/.cache/huggingface",
        "XDG_CACHE_HOME": "/app/.cache",
    }
)
# Create both cache directories up front; already-existing ones are fine.
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)
# Now import other libs
from fastapi import FastAPI, BackgroundTasks
import uuid
from huggingface_hub import snapshot_download
from flux_train import build_job
import sys
# Extend the module search path so the ``toolkit`` import on the next line
# can resolve (presumably /app/ai-toolkit is where ai-toolkit is checked
# out -- TODO confirm against the image build).
sys.path.append("/app/ai-toolkit")
from toolkit.job import run_job
# In-memory buffer that collects every record emitted through ``logger``.
log_stream = io.StringIO()

# Handler writing "<timestamp> <LEVEL> <message>" lines into the buffer.
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))

# Dedicated logger for training runs; DEBUG and above are captured.
logger = logging.getLogger("lora_training")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
# Application object; the route decorators in this file register on it.
app = FastAPI()
# Hub repository and the folder inside it that is downloaded as the dataset.
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"

# Concept sentence and LoRA name handed to ``build_job``.
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"

# Hub access token taken from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Shared, mutable job state reported by the /status endpoint.
status = dict(running=False, last_job=None, error=None)
def run_lora_training(push_to_hub: bool = False):
    """Download the dataset folder and run one LoRA training job.

    Progress and failures are reported through ``logger`` and the
    module-level ``status`` dict rather than through a return value.

    Args:
        push_to_hub: Forwarded unchanged to ``build_job``.

    Returns:
        None. Any ``Exception`` raised during download, job construction or
        training is logged with its traceback and stored as a string in
        ``status["error"]`` instead of propagating.
    """
    status.update({"running": True, "error": None})
    try:
        logger.info("Starting training...")

        # A fresh directory per run keeps downloads of separate runs apart.
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        logger.info(f"Downloading dataset to {local_dir} ...")
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="model",
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            # Use the configured token explicitly; None falls back to the
            # library's own credential lookup when no token is configured.
            token=HF_TOKEN or None,
        )

        # The snapshot keeps the repository layout, so the training files
        # sit under the same sub-folder inside ``local_dir``.
        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        logger.info(f"Building job with training path: {training_path}")
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)

        logger.info("Running job...")
        run_job(job)
        logger.info("Training completed successfully.")
        status["last_job"] = job
    except Exception as e:
        # logger.exception records the traceback as well, which is the only
        # debugging information available through the captured log buffer.
        logger.exception(f"Training failed: {e}")
        status["error"] = str(e)
    finally:
        # Always clear the busy flag, even when something other than an
        # ``Exception`` subclass (e.g. SystemExit) unwinds through here;
        # otherwise no further job could ever be started.
        status["running"] = False
@app.get("/logs")
def get_logs():
    """Return the current contents of the in-memory log buffer."""
    captured = log_stream.getvalue()
    return {"logs": captured}
@app.get("/")
def root():
    """Reply with a fixed message showing that the service is up."""
    greeting = "LoRA training FastAPI is live."
    return {"message": greeting}
@app.get("/status")
def get_status():
    """Return the shared ``status`` dict (running / last_job / error)."""
    current_state = status
    return current_state
from pydantic import BaseModel
class TrainRequest(BaseModel):
    # Request body accepted by POST /train.
    # ``push_to_hub`` is passed on to the training function; it may be
    # omitted by the client, in which case it is False.
    push_to_hub: bool = False
import threading

# Guards the check-and-mark step in ``start_training`` so that concurrent
# requests cannot both observe "not running".
_train_lock = threading.Lock()


@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    """Schedule a LoRA training run as a background task.

    Args:
        background_tasks: Task collector injected by FastAPI.
        request: Request body; ``push_to_hub`` is forwarded to the job.

    Returns:
        A dict with a ``message`` key. When a job is already running, only
        the message is returned and nothing is scheduled; otherwise the
        dict also echoes the ``push_to_hub`` flag.
    """
    with _train_lock:
        if status["running"]:
            return {"message": "A training job is already running."}
        # Mark the service busy right away. The background task itself only
        # starts after the response has been sent, so relying on it to set
        # the flag would let a second request slip past the guard above.
        status["running"] = True

    background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
    return {"message": "Training started in background.", "push_to_hub": request.push_to_hub}