"""FastAPI service that runs a LoRA fine-tuning job in the background.

Endpoints:
    GET  /        -- liveness message
    GET  /logs    -- in-memory training log buffer
    GET  /status  -- current job state (running / last_job / error)
    POST /train   -- kick off a training run as a background task
"""

import io
import logging
import os
import sys
import uuid

# Set HF cache dirs to writable paths BEFORE importing HF libraries,
# otherwise they resolve their cache locations at import time.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["XDG_CACHE_HOME"] = "/app/.cache"

# Make sure cache folders exist
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

# Now import other libs (HF cache env vars are already set above).
from fastapi import BackgroundTasks, FastAPI
from huggingface_hub import snapshot_download
from pydantic import BaseModel

from flux_train import build_job

sys.path.append("/app/ai-toolkit")
from toolkit.job import run_job

# All training output is captured in this in-memory buffer and served by /logs.
log_stream = io.StringIO()
logger = logging.getLogger("lora_training")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)

app = FastAPI()

# Dataset / job configuration (hard-coded for this deployment).
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Shared mutable job state; read by /status and /train, written by the worker.
status = {"running": False, "last_job": None, "error": None}


class TrainRequest(BaseModel):
    """Request body for POST /train."""

    push_to_hub: bool = False


def run_lora_training(push_to_hub: bool = False) -> None:
    """Download the dataset, build the training job, and run it to completion.

    Runs synchronously; intended to be scheduled as a FastAPI background task.
    Updates the module-level ``status`` dict on start, success, and failure.

    Args:
        push_to_hub: Forwarded to ``build_job``; whether the resulting LoRA
            should be pushed to the Hugging Face Hub.
    """
    try:
        status.update({"running": True, "error": None})
        logger.info("Starting training...")

        # Unique per-run directory so repeated runs never collide.
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        logger.info("Downloading dataset to %s ...", local_dir)
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="model",
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            # NOTE(review): deprecated in recent huggingface_hub releases;
            # kept for compatibility with the pinned version — confirm.
            local_dir_use_symlinks=False,
        )

        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        logger.info("Building job with training path: %s", training_path)
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)

        logger.info("Running job...")
        run_job(job)
        logger.info("Training completed successfully.")

        # NOTE(review): stores the raw job object; /status returns this dict,
        # so the job must be JSON-serializable by FastAPI — confirm.
        status.update({"running": False, "last_job": job})
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Training failed: %s", e)
        status.update({"running": False, "error": str(e)})


@app.get("/logs")
def get_logs():
    """Return everything the training logger has written so far."""
    return {"logs": log_stream.getvalue()}


@app.get("/")
def root():
    """Liveness check."""
    return {"message": "LoRA training FastAPI is live."}


@app.get("/status")
def get_status():
    """Expose the shared job-state dict."""
    return status


@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    """Schedule a training run unless one is already in progress."""
    if status["running"]:
        return {"message": "A training job is already running."}
    # Mark the job as running *before* scheduling the task: previously the
    # flag was only set inside the worker, so two rapid POSTs could both pass
    # the check above and start concurrent jobs.
    status["running"] = True
    background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
    return {"message": "Training started in background.", "push_to_hub": request.push_to_hub}