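# NOTE(review): the three lines originally here ("Spaces:" / "Runtime error" /
# "Runtime error") look like page-status text captured together with the code
# rather than part of the module; kept as a comment so they are not parsed.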
import threading
import time
import warnings

import huggingface_hub
from gradio_client import Client, handle_file

from trackio.media import TrackioImage
from trackio.sqlite_storage import SQLiteStorage
from trackio.typehints import LogEntry, UploadEntry
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name
# Seconds the background sender sleeps between two batch sends.
BATCH_SEND_INTERVAL = 0.5
class Run: | |
def __init__( | |
self, | |
url: str, | |
project: str, | |
client: Client | None, | |
name: str | None = None, | |
config: dict | None = None, | |
space_id: str | None = None, | |
): | |
self.url = url | |
self.project = project | |
self._client_lock = threading.Lock() | |
self._client_thread = None | |
self._client = client | |
self._space_id = space_id | |
self.name = name or generate_readable_name( | |
SQLiteStorage.get_runs(project), space_id | |
) | |
self.config = config or {} | |
self._queued_logs: list[LogEntry] = [] | |
self._queued_uploads: list[UploadEntry] = [] | |
self._stop_flag = threading.Event() | |
self._client_thread = threading.Thread(target=self._init_client_background) | |
self._client_thread.daemon = True | |
self._client_thread.start() | |
def _batch_sender(self): | |
"""Send batched logs every BATCH_SEND_INTERVAL.""" | |
while not self._stop_flag.is_set() or len(self._queued_logs) > 0: | |
# If the stop flag has been set, then just quickly send all | |
# the logs and exit. | |
if not self._stop_flag.is_set(): | |
time.sleep(BATCH_SEND_INTERVAL) | |
with self._client_lock: | |
if self._queued_logs and self._client is not None: | |
logs_to_send = self._queued_logs.copy() | |
self._queued_logs.clear() | |
self._client.predict( | |
api_name="/bulk_log", | |
logs=logs_to_send, | |
hf_token=huggingface_hub.utils.get_token(), | |
) | |
if self._queued_uploads and self._client is not None: | |
uploads_to_send = self._queued_uploads.copy() | |
self._queued_uploads.clear() | |
self._client.predict( | |
api_name="/bulk_upload_media", | |
uploads=uploads_to_send, | |
hf_token=huggingface_hub.utils.get_token(), | |
) | |
def _init_client_background(self): | |
if self._client is None: | |
fib = fibo() | |
for sleep_coefficient in fib: | |
try: | |
client = Client(self.url, verbose=False) | |
with self._client_lock: | |
self._client = client | |
break | |
except Exception: | |
pass | |
if sleep_coefficient is not None: | |
time.sleep(0.1 * sleep_coefficient) | |
self._batch_sender() | |
def _process_media(self, metrics, step: int | None) -> dict: | |
""" | |
Serialize media in metrics and upload to space if needed. | |
""" | |
serializable_metrics = {} | |
if not step: | |
step = 0 | |
for key, value in metrics.items(): | |
if isinstance(value, TrackioImage): | |
value._save(self.project, self.name, step) | |
serializable_metrics[key] = value._to_dict() | |
if self._space_id: | |
# Upload local media when deploying to space | |
upload_entry: UploadEntry = { | |
"project": self.project, | |
"run": self.name, | |
"step": step, | |
"uploaded_file": handle_file(value._get_absolute_file_path()), | |
} | |
with self._client_lock: | |
self._queued_uploads.append(upload_entry) | |
else: | |
serializable_metrics[key] = value | |
return serializable_metrics | |
def log(self, metrics: dict, step: int | None = None): | |
for k in metrics.keys(): | |
if k in RESERVED_KEYS or k.startswith("__"): | |
raise ValueError( | |
f"Please do not use this reserved key as a metric: {k}" | |
) | |
metrics = self._process_media(metrics, step) | |
log_entry: LogEntry = { | |
"project": self.project, | |
"run": self.name, | |
"metrics": metrics, | |
"step": step, | |
} | |
with self._client_lock: | |
self._queued_logs.append(log_entry) | |
def finish(self): | |
"""Cleanup when run is finished.""" | |
self._stop_flag.set() | |
# Wait for the batch sender to finish before joining the client thread. | |
time.sleep(2 * BATCH_SEND_INTERVAL) | |
if self._client_thread is not None: | |
print( | |
f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)" | |
) | |
self._client_thread.join() | |