Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from typing import Any, List | |
| class _FastCopy: | |
| """ | |
| Overview: | |
| The idea of this class comes from this article \ | |
| https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. | |
| We use recursive calls to copy each object that needs to be copied, which will be 5x faster \ | |
| than copy.deepcopy. | |
| Interfaces: | |
| ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``. | |
| """ | |
| def __init__(self): | |
| """ | |
| Overview: | |
| Initialize the _FastCopy object. | |
| """ | |
| dispatch = {} | |
| dispatch[list] = self._copy_list | |
| dispatch[dict] = self._copy_dict | |
| dispatch[torch.Tensor] = self._copy_tensor | |
| dispatch[np.ndarray] = self._copy_ndarray | |
| self.dispatch = dispatch | |
| def _copy_list(self, l: List) -> dict: | |
| """ | |
| Overview: | |
| Copy the list. | |
| Arguments: | |
| - l (:obj:`List`): The list to be copied. | |
| """ | |
| ret = l.copy() | |
| for idx, item in enumerate(ret): | |
| cp = self.dispatch.get(type(item)) | |
| if cp is not None: | |
| ret[idx] = cp(item) | |
| return ret | |
| def _copy_dict(self, d: dict) -> dict: | |
| """ | |
| Overview: | |
| Copy the dict. | |
| Arguments: | |
| - d (:obj:`dict`): The dict to be copied. | |
| """ | |
| ret = d.copy() | |
| for key, value in ret.items(): | |
| cp = self.dispatch.get(type(value)) | |
| if cp is not None: | |
| ret[key] = cp(value) | |
| return ret | |
| def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Copy the tensor. | |
| Arguments: | |
| - t (:obj:`torch.Tensor`): The tensor to be copied. | |
| """ | |
| return t.clone() | |
| def _copy_ndarray(self, a: np.ndarray) -> np.ndarray: | |
| """ | |
| Overview: | |
| Copy the ndarray. | |
| Arguments: | |
| - a (:obj:`np.ndarray`): The ndarray to be copied. | |
| """ | |
| return np.copy(a) | |
| def copy(self, sth: Any) -> Any: | |
| """ | |
| Overview: | |
| Copy the object. | |
| Arguments: | |
| - sth (:obj:`Any`): The object to be copied. | |
| """ | |
| cp = self.dispatch.get(type(sth)) | |
| if cp is None: | |
| return sth | |
| else: | |
| return cp(sth) | |
| fastcopy = _FastCopy() | |