trackio-71827 / run.py
import threading
import time
import huggingface_hub
from gradio_client import Client, handle_file
from trackio.media import TrackioImage
from trackio.sqlite_storage import SQLiteStorage
from trackio.typehints import LogEntry, UploadEntry
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name
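# Interval, in seconds, between flushes of queued logs and uploads to the dashboard.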
BATCH_SEND_INTERVAL = 0.5
class Run:
def __init__(
self,
url: str,
project: str,
client: Client | None,
name: str | None = None,
config: dict | None = None,
space_id: str | None = None,
):
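        """
        A single tracked run that queues logs and media and sends them to a Trackio dashboard.

        Args:
            url: URL of the Gradio app (local server or Hugging Face Space) that stores the logs.
            project: Name of the project this run belongs to.
            client: An existing Gradio client, or None to create one in a background thread.
            name: Optional run name; a readable name is generated if omitted.
            config: Optional dictionary of run configuration values.
            space_id: ID of the Hugging Face Space hosting the dashboard, if deployed to one.
        """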
self.url = url
self.project = project
self._client_lock = threading.Lock()
self._client_thread = None
self._client = client
self._space_id = space_id
self.name = name or generate_readable_name(
SQLiteStorage.get_runs(project), space_id
)
self.config = config or {}
self._queued_logs: list[LogEntry] = []
self._queued_uploads: list[UploadEntry] = []
self._stop_flag = threading.Event()
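        # Initialize the Gradio client in a background daemon thread so that
        # __init__ returns immediately; the same thread then runs the batch sender.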
self._client_thread = threading.Thread(target=self._init_client_background)
self._client_thread.daemon = True
self._client_thread.start()
def _batch_sender(self):
"""Send batched logs every BATCH_SEND_INTERVAL."""
        while not self._stop_flag.is_set() or self._queued_logs or self._queued_uploads:
            # Once the stop flag has been set, skip the sleep so that any
            # remaining queued entries are flushed immediately before exiting.
if not self._stop_flag.is_set():
time.sleep(BATCH_SEND_INTERVAL)
with self._client_lock:
if self._queued_logs and self._client is not None:
logs_to_send = self._queued_logs.copy()
self._queued_logs.clear()
self._client.predict(
api_name="/bulk_log",
logs=logs_to_send,
hf_token=huggingface_hub.utils.get_token(),
)
if self._queued_uploads and self._client is not None:
uploads_to_send = self._queued_uploads.copy()
self._queued_uploads.clear()
self._client.predict(
api_name="/bulk_upload_media",
uploads=uploads_to_send,
hf_token=huggingface_hub.utils.get_token(),
)
def _init_client_background(self):
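        """Connect to the Gradio app, retrying with Fibonacci backoff, then run the batch sender."""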
if self._client is None:
fib = fibo()
for sleep_coefficient in fib:
try:
client = Client(self.url, verbose=False)
with self._client_lock:
self._client = client
break
except Exception:
pass
if sleep_coefficient is not None:
time.sleep(0.1 * sleep_coefficient)
self._batch_sender()
def _process_media(self, metrics, step: int | None) -> dict:
"""
Serialize media in metrics and upload to space if needed.
"""
serializable_metrics = {}
        if step is None:
            step = 0
for key, value in metrics.items():
if isinstance(value, TrackioImage):
value._save(self.project, self.name, step)
serializable_metrics[key] = value._to_dict()
if self._space_id:
# Upload local media when deploying to space
upload_entry: UploadEntry = {
"project": self.project,
"run": self.name,
"step": step,
"uploaded_file": handle_file(value._get_absolute_file_path()),
}
with self._client_lock:
self._queued_uploads.append(upload_entry)
else:
serializable_metrics[key] = value
return serializable_metrics
def log(self, metrics: dict, step: int | None = None):
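        """
        Queue a metrics dictionary (with an optional step) for the batch sender to upload.

        Raises:
            ValueError: If any metric key is reserved or starts with "__".
        """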
for k in metrics.keys():
if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"'{k}' is a reserved key and cannot be used as a metric name."
                )
metrics = self._process_media(metrics, step)
log_entry: LogEntry = {
"project": self.project,
"run": self.name,
"metrics": metrics,
"step": step,
}
with self._client_lock:
self._queued_logs.append(log_entry)
def finish(self):
"""Cleanup when run is finished."""
self._stop_flag.set()
        # Give the batch sender time to flush any remaining queued entries
        # before joining the client thread.
time.sleep(2 * BATCH_SEND_INTERVAL)
if self._client_thread is not None:
print(
f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)"
)
self._client_thread.join()
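
# A minimal usage sketch for illustration only (not part of the library). In normal use,
# a Run is typically created by trackio's top-level init() entry point rather than
# constructed directly, and the URL below is a placeholder for a running dashboard app:
#
#     run = Run(url="http://127.0.0.1:7860", project="demo-project", client=None)
#     run.log({"loss": 0.42}, step=1)
#     run.finish()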