Spaces:
Running
on
Zero
Running
on
Zero
import contextlib | |
from typing import Any, Iterable, Iterator, Optional | |
try: | |
from tqdm import tqdm | |
except ImportError: | |
tqdm = None | |
try: | |
from ray.experimental.tqdm_ray import tqdm as ray_tqdm | |
except: | |
ray_tqdm = None | |
# Module-level state shared by ``progress_bar`` (which sets it) and
# ``get_new_progress_bar`` (which reads it).

# Backend chosen by the currently open ``progress_bar`` context; reset to
# "none" when that context exits.
_current_progress_type: str = "none"
# True only while a ``progress_bar`` context is open; used to reject nesting.
_is_progress_bar_active: bool = False
class DummyProgressBar:
    """A no-op progress bar that mimics the parts of the tqdm interface used here.

    ``get_new_progress_bar`` returns this when no real backend is active, so
    callers can use one code path whether or not a visible bar is shown.
    Iterating it yields the wrapped iterable unchanged. Every display-related
    method does nothing.
    """

    def __init__(self, iterable=None, **kwargs):
        # Bar options (total, desc, ...) are accepted so call sites stay
        # backend-agnostic, and are deliberately ignored.
        self.iterable = iterable

    def __iter__(self):
        # Without an iterable this raises TypeError, since iter(None) fails.
        return iter(self.iterable)

    def __enter__(self):
        # Allows ``with get_new_progress_bar(total=n) as pbar:`` in the
        # disabled path, mirroring the tqdm usage pattern.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def update(self, n=1):
        """Ignore a progress increment of ``n``."""

    def close(self):
        """Do nothing; there is no display to tear down."""

    def set_description(self, desc=None, refresh=True):
        """Ignore a description change.

        ``desc`` is optional and ``refresh`` is accepted, so tqdm-style calls
        such as ``set_description(desc="x", refresh=False)`` also work here.
        """

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        """Ignore postfix statistics (accepts tqdm-style arguments)."""
def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    """Build a progress bar for the backend selected by the active context.

    Args:
        iterable: Optional iterable to wrap; forwarded as ``iterable=``.
        **kwargs: Extra options forwarded unchanged to the bar constructor.

    Returns:
        A ``tqdm`` or ``ray_tqdm`` bar when the open ``progress_bar`` context
        selected that backend, otherwise a ``DummyProgressBar``.

    Raises:
        ImportError: If the selected backend's package failed to import.
    """
    # With no open progress_bar context, every request falls through to the no-op bar.
    backend = _current_progress_type if _is_progress_bar_active else "none"

    if backend == "tqdm":
        if tqdm is None:
            raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
        factory = tqdm
    elif backend == "ray_tqdm":
        if ray_tqdm is None:
            raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
        factory = ray_tqdm
    else:
        factory = DummyProgressBar

    return factory(iterable=iterable, **kwargs)
@contextlib.contextmanager
def progress_bar(type: str = "none", enabled=True):
    """
    Context manager that selects the progress-bar backend for its body.

    While the context is open, ``get_new_progress_bar`` builds bars of the
    selected backend. Outside any context it returns a ``DummyProgressBar``.

    Args:
        type: Backend to use: "none", "tqdm" or "ray_tqdm".
        enabled: When False, the context is still entered (and still blocks
            nesting), but the backend is forced to "none".

    Raises:
        ValueError: If ``type`` is not a supported backend.
        RuntimeError: If a ``progress_bar`` context is already open;
            nesting is not supported.

    Example:
        with progress_bar(type="tqdm"):
            for i in get_new_progress_bar(range(100), total=100):
                process(i)
    """
    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
    if not enabled:
        type = "none"

    global _current_progress_type, _is_progress_bar_active
    # Reject nesting before touching state, so the outer context stays intact.
    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = type
    try:
        yield
    finally:
        # Restore the idle state even if the body raised.
        _is_progress_bar_active = False
        _current_progress_type = "none"