File size: 2,847 Bytes
97c09ee
05ce374
6877262
97c09ee
 
 
 
 
 
 
 
 
9a464d1
 
 
 
7f2343c
 
f29d0cf
 
97c09ee
 
dc3c63f
 
 
 
 
 
 
 
9a464d1
 
 
 
 
 
 
 
 
 
 
f29d0cf
9a464d1
 
dc3c63f
 
9a464d1
dc3c63f
9a464d1
 
0f89fc4
9a464d1
 
 
 
 
dc3c63f
f29d0cf
dc3c63f
9a464d1
dc3c63f
9a464d1
 
dc3c63f
9a464d1
 
dc3c63f
 
 
 
 
 
 
 
 
 
9a464d1
 
 
 
 
 
 
 
f29d0cf
 
 
 
 
9a464d1
f29d0cf
9a464d1
 
f29d0cf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import io
import logging
# Set HF cache dirs to writable paths before importing HF libraries.
# huggingface_hub is imported further down, only after these are in place.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["XDG_CACHE_HOME"] = "/app/.cache"

# Make sure cache folders exist (exist_ok makes this safe on every restart).
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["XDG_CACHE_HOME"], exist_ok=True)

# Now import other libs. These must stay BELOW the environment setup above:
# hoisting them to the top of the file would let the HF libraries load before
# the cache locations are configured.
from fastapi import FastAPI, BackgroundTasks
import uuid
from huggingface_hub import snapshot_download
# build_job(concept, training_path, lora_name, push_to_hub=...) is called by
# run_lora_training; presumably it returns an ai-toolkit job config (TODO confirm).
from flux_train import build_job
import sys

# Add the ai-toolkit checkout to the import path so the next line can import
# `toolkit.job`; presumably ai-toolkit is not pip-installed in this image.
sys.path.append("/app/ai-toolkit")
from toolkit.job import run_job



# Everything the "lora_training" logger emits is mirrored into this in-memory
# buffer; the /logs endpoint returns its full contents. Note the buffer is
# never truncated, so it grows for the lifetime of the process.
log_stream = io.StringIO()

_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter(_LOG_FORMAT))

# DEBUG level so every record reaches the buffer.
logger = logging.getLogger("lora_training")
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)


# ASGI application object; the route handlers below register themselves on it.
app = FastAPI()

# Hub repository (fetched with repo_type="model") and the sub-folder inside it
# that run_lora_training downloads as training data.
REPO_ID = "rahul7star/ohamlab"
FOLDER_IN_REPO = "filter-demo/upload_20250708_041329_9c5c81"

# Passed to build_job by run_lora_training; LORA_NAME also prefixes the
# per-run download directory.
CONCEPT_SENTENCE = "ohamlab style"
LORA_NAME = "ohami_filter_autorun"

# Empty string when the variable is unset.
# NOTE(review): not referenced in the visible code; snapshot_download is called
# without an explicit token= argument.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Shared run state: mutated by run_lora_training and returned as-is by /status.
status = dict(running=False, last_job=None, error=None)

def run_lora_training(push_to_hub: bool = False):
    """Download the training data, build the ai-toolkit job and run it.

    Scheduled by ``POST /train`` as a FastAPI background task. Progress is
    written to the ``lora_training`` logger (served by ``/logs``), and the
    module-level ``status`` dict is updated so ``/status`` can report on the run.

    Args:
        push_to_hub: Forwarded unchanged to ``build_job``; presumably makes the
            job upload the trained LoRA to the Hub (TODO confirm in flux_train).

    Failures are not re-raised. Any ``Exception`` is logged together with its
    traceback and stored as a string in ``status["error"]``.
    ``status["running"]`` is reset to ``False`` however the function exits.
    """
    status.update({"running": True, "error": None})
    try:
        logger.info("Starting training...")

        # Unique per run, so a new download never lands on top of an older one.
        local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
        logger.info(f"Downloading dataset to {local_dir} ...")
        snapshot_download(
            repo_id=REPO_ID,
            repo_type="model",
            # Fetch only the dataset sub-folder, not the whole repository.
            allow_patterns=[f"{FOLDER_IN_REPO}/*"],
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        # The repo layout is mirrored under local_dir, so the files sit in
        # FOLDER_IN_REPO below it.
        training_path = os.path.join(local_dir, FOLDER_IN_REPO)
        logger.info(f"Building job with training path: {training_path}")
        job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
        logger.info("Running job...")
        run_job(job)
        logger.info("Training completed successfully.")
        status["last_job"] = job
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback, which
        # is the only diagnostic a client can retrieve through /logs.
        logger.exception(f"Training failed: {e}")
        status["error"] = str(e)
    finally:
        # Always release the slot, even for BaseException (e.g. a SystemExit
        # raised inside the toolkit). Otherwise /train would refuse every
        # later request with "already running".
        status["running"] = False





@app.get("/logs")
def get_logs():
    """Return every line the training logger has captured since startup."""
    captured = log_stream.getvalue()
    return {"logs": captured}


    
@app.get("/")
def root():
    """Report that the service is up and answering requests."""
    payload = {"message": "LoRA training FastAPI is live."}
    return payload

@app.get("/status")
def get_status():
    """Expose the shared run state: running flag, last job and last error."""
    current_state = status
    return current_state

from pydantic import BaseModel

class TrainRequest(BaseModel):
    """JSON body accepted by ``POST /train``.

    Omitting ``push_to_hub`` behaves exactly like sending ``false``.
    """

    # Forwarded to run_lora_training, which passes it on to build_job.
    push_to_hub: bool = False

@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    """Schedule one LoRA training run unless a run is already in progress.

    Args:
        background_tasks: Injected by FastAPI; queued tasks execute only after
            the response has been sent.
        request: Parsed JSON body; ``push_to_hub`` is forwarded to the run.

    Returns:
        A message dict. When a run was scheduled, it also echoes ``push_to_hub``.
    """
    if status["running"]:
        return {"message": "A training job is already running."}
    # Claim the slot before scheduling. run_lora_training only flips this flag
    # once the background task starts, i.e. after the response is sent. Relying
    # on that let a second request pass the check above and start a
    # concurrent run. (Sync endpoints run in a threadpool, so a tiny
    # check-then-set window still remains.)
    status.update({"running": True, "error": None})
    background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
    return {"message": "Training started in background.", "push_to_hub": request.push_to_hub}