Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import re | |
| import sys | |
| import logging | |
| import warnings | |
| import urllib3 | |
| from modules import timer, errors | |
# Startup flag, set to False here; nothing in this section changes it
# (presumably flipped elsewhere once loading has finished - TODO confirm).
initialized = False
# errors.install() comes from modules.errors; it is called before the heavy imports below
# (presumably hooks up error/traceback handling - TODO confirm).
errors.install()
logging.getLogger("DeepSpeed").disabled = True # mute the "DeepSpeed" logger entirely
# os.environ.setdefault('OMP_NUM_THREADS', 1)
# os.environ.setdefault('MKL_NUM_THREADS', 1)
# import tensorflow as tf # pylint: disable=C0411
import torch # pylint: disable=C0411
# torch.set_num_threads(1)
try:
    # optional dependency: when it imports cleanly, report its version at debug level
    import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
    errors.log.debug(f'Load IPEX=={ipex.__version__}')
except Exception:
    pass # best effort: any failure here is ignored and startup continues
# Quiet known-noisy warnings and loggers around the imports that trigger them.
# Statement order matters here: filters are installed before an import and then re-applied after it.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
import torchvision # pylint: disable=W0611,C0411
# pytorch_lightning should be imported after torch, but it re-enables warnings on import,
# so it is imported once here and the filters are applied again below
import pytorch_lightning # pylint: disable=W0611,C0411
# drop only the xformers log records that contain the Triton notice; every other record passes
logging.getLogger("xformers").addFilter(lambda log_record: 'A matching Triton is not available' not in log_record.getMessage())
logging.getLogger("pytorch_lightning").disabled = True
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
# When the reported torch version carries a ".dev" or "+git" marker, keep the full string
# in torch.__long_version__ and reduce torch.__version__ to its first dotted-number run.
if any(marker in torch.__version__ for marker in (".dev", "+git")):
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
timer.startup.record("torch") # startup timing checkpoint for the torch import phase
# ONNX runtime: import, lower its default log verbosity, then record the timing checkpoint.
import onnxruntime
onnxruntime.set_default_logger_severity(3) # 3 = Error in onnxruntime's severity scale (0 = Verbose .. 4 = Fatal)
timer.startup.record("onnx")
# Olive support is initialised by the project's own modules.onnx_impl helper
# (what it sets up is not visible from this file).
from modules.onnx_impl import initialize_olive # pylint: disable=ungrouped-imports
initialize_olive()
timer.startup.record("olive")
# Remaining heavy third-party imports, each followed by a startup timing checkpoint.
from fastapi import FastAPI # pylint: disable=W0611,C0411
import gradio # pylint: disable=W0611,C0411
timer.startup.record("gradio")
# second errors.install() call, this time passing the gradio module
# (presumably so the error handling can treat gradio frames specially - TODO confirm)
errors.install([gradio])
import pydantic # pylint: disable=W0611,C0411
timer.startup.record("pydantic")
import diffusers # pylint: disable=W0611,C0411
timer.startup.record("diffusers")
def get_packages():
    """Return the versions of the core packages imported by this module.

    The result is a dict with the keys "torch", "diffusers" and "gradio".
    For torch, `torch.__long_version__` is reported when that attribute
    exists; otherwise `torch.__version__` is used.
    """
    torch_version = getattr(torch, "__long_version__", torch.__version__)
    versions = dict(
        torch=torch_version,
        diffusers=diffusers.__version__,
        gradio=gradio.__version__,
    )
    return versions
errors.log.info(f'Load packages: {get_packages()}')

# Best-effort CPU thread tuning: when torch is configured with fewer threads than half of
# the CPUs this process is allowed to run on, raise the thread count to that half.
# os.sched_getaffinity() is not available on every platform, so any failure is tolerated;
# it is reported at debug level instead of being swallowed silently.
try:
    import os
    import math
    cores = os.cpu_count()
    affinity = len(os.sched_getaffinity(0)) # number of CPUs in this process' affinity set
    threads = torch.get_num_threads()
    if threads < (affinity / 2):
        torch.set_num_threads(math.floor(affinity / 2))
        threads = torch.get_num_threads() # re-read to log the value torch actually applied
    errors.log.debug(f'Detected: cores={cores} affinity={affinity} set threads={threads}')
except Exception as e:
    # keep startup going, but leave a trace of why the thread count was left untouched
    errors.log.debug(f'Detected: threads not configured: {e}')
# Fix for the changed import in torchvision 0.17+, which breaks basicsr:
# when `torchvision.transforms.functional_tensor` cannot be imported, register
# `torchvision.transforms.functional` in sys.modules under the old module name.
try:
    import torchvision.transforms.functional_tensor # pylint: disable=unused-import, ungrouped-imports
except ImportError:
    try:
        import torchvision.transforms.functional as functional
    except ImportError:
        pass # shrug...
    else:
        sys.modules["torchvision.transforms.functional_tensor"] = functional