import os
import io
import logging

# Set HF cache dirs to writable paths before importing HF libraries
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["XDG_CACHE_HOME"] = "/app/.cache"

# Make sure the cache folders exist
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

# Now import the remaining libraries
import sys
import uuid

from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from huggingface_hub import snapshot_download

from flux_train import build_job

# toolkit.job is provided by the ai-toolkit checkout at /app/ai-toolkit
sys.path.append("/app/ai-toolkit")
from toolkit.job import run_job
# Keep logs in an in-memory buffer so they can be returned from an API endpoint
log_stream = io.StringIO()
logger = logging.getLogger("lora_training")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
app = FastAPI()

# Dataset / training configuration
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Simple in-memory status shared between requests
status = {"running": False, "last_job": None, "error": None}
def run_lora_training(push_to_hub: bool = False):
    """Download the dataset, build the training job, and run it."""
    try:
        status.update({"running": True, "error": None})
        logger.info("Starting training...")

        # Download only the dataset folder from the Hub into a unique temp dir
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        logger.info(f"Downloading dataset to {local_dir} ...")
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="model",
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )

        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        logger.info(f"Building job with training path: {training_path}")
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)

        logger.info("Running job...")
        run_job(job)

        logger.info("Training completed successfully.")
        status.update({"running": False, "last_job": job})
    except Exception as e:
        logger.error(f"Training failed: {e}")
        status.update({"running": False, "error": str(e)})
@app.get("/logs")
def get_logs():
    return {"logs": log_stream.getvalue()}

@app.get("/")
def root():
    return {"message": "LoRA training FastAPI is live."}

@app.get("/status")
def get_status():
    return status
class TrainRequest(BaseModel):
    push_to_hub: bool = False
@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    if status["running"]:
        return {"message": "A training job is already running."}
    background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
    return {"message": "Training started in background.", "push_to_hub": request.push_to_hub}
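# Minimal local entrypoint: a sketch assuming this file is the Space's app
# module and uvicorn is installed (HF Spaces serve on port 7860 by default).
# With the server running, a training run can be triggered with e.g.:
#   curl -X POST http://localhost:7860/train \
#        -H "Content-Type: application/json" -d '{"push_to_hub": false}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)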