Spaces:
Sleeping
Sleeping
File size: 2,847 Bytes
97c09ee 05ce374 6877262 97c09ee 9a464d1 7f2343c f29d0cf 97c09ee dc3c63f 9a464d1 f29d0cf 9a464d1 dc3c63f 9a464d1 dc3c63f 9a464d1 0f89fc4 9a464d1 dc3c63f f29d0cf dc3c63f 9a464d1 dc3c63f 9a464d1 dc3c63f 9a464d1 dc3c63f 9a464d1 f29d0cf 9a464d1 f29d0cf 9a464d1 f29d0cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# Stdlib imports required by the environment setup below.
import os
import io
import logging
# Set HF cache dirs to writable paths before importing HF libraries
# (huggingface_hub reads these env vars at import time, so order matters).
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["XDG_CACHE_HOME"] = "/app/.cache"
# Make sure cache folders exist
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)
# Now import other libs
from fastapi import FastAPI, BackgroundTasks
import uuid
from huggingface_hub import snapshot_download
from flux_train import build_job  # project-local: builds the ai-toolkit job config
import sys
sys.path.append("/app/ai-toolkit")  # make the vendored ai-toolkit package importable
from toolkit.job import run_job  # ai-toolkit entry point that executes a job config
# In-memory log buffer; the /logs endpoint returns its full contents.
log_stream = io.StringIO()
logger = logging.getLogger("lora_training")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
app = FastAPI()
# Dataset / job configuration, hard-coded for this demo deployment.
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"
# HF token for Hub access; empty string when the env var is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Mutable global training state surfaced by /status.
# NOTE(review): assumes a single worker process — confirm deployment config.
status = {"running": False, "last_job": None, "error": None}
def run_lora_training(push_to_hub: bool = False):
    """Download the dataset, build the LoRA training job, and run it.

    Designed to run as a FastAPI background task: progress and failures are
    reported through the module-level ``logger`` and ``status`` dict rather
    than via return values or raised exceptions.

    Args:
        push_to_hub: Forwarded to ``build_job``; when True the trained LoRA
            is pushed to the Hugging Face Hub.
    """
    try:
        status.update({"running": True, "error": None})
        logger.info("Starting training...")
        # Unique directory per run so repeated runs never collide.
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        logger.info("Downloading dataset to %s ...", local_dir)
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="model",  # the dataset folder lives inside a model repo
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            local_dir_use_symlinks=False,  # deprecated/ignored in recent huggingface_hub; kept for older versions
            # Fix: HF_TOKEN was defined but never used, so private repos failed.
            # Empty token falls back to None, i.e. anonymous access as before.
            token=HF_TOKEN or None,
        )
        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        logger.info("Building job with training path: %s", training_path)
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
        logger.info("Running job...")
        run_job(job)
        logger.info("Training completed successfully.")
        # NOTE(review): `job` is returned verbatim by /status and may not be
        # JSON-serializable — verify what build_job returns.
        status.update({"running": False, "last_job": job})
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Training failed: %s", e)
        status.update({"running": False, "error": str(e)})
@app.get("/logs")
def get_logs():
    """Return everything captured so far by the in-memory log handler."""
    captured = log_stream.getvalue()
    return {"logs": captured}
@app.get("/")
def root():
    """Landing / liveness endpoint."""
    greeting = "LoRA training FastAPI is live."
    return {"message": greeting}
@app.get("/status")
def get_status():
    """Expose the current training state: running flag, last job, last error."""
    return status
from pydantic import BaseModel
class TrainRequest(BaseModel):
    """Request body for POST /train."""

    # When True, the finished LoRA is pushed to the Hugging Face Hub.
    push_to_hub: bool = False
@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    """Schedule a LoRA training run in the background.

    Refuses to start a second run while one is already in flight.
    """
    if not status["running"]:
        background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
        return {
            "message": "Training started in background.",
            "push_to_hub": request.push_to_hub,
        }
    return {"message": "A training job is already running."}
|