├── LICENSE ├── README.md ├── dataflow ├── __init__.py ├── dataflow │ ├── __init__.py │ ├── base.py │ ├── common.py │ ├── dataset │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── bsds500.py │ │ ├── caltech101.py │ │ ├── cifar.py │ │ ├── ilsvrc.py │ │ ├── mnist.py │ │ └── svhn.py │ ├── format.py │ ├── image.py │ ├── imgaug │ │ ├── __init__.py │ │ ├── base.py │ │ ├── convert.py │ │ ├── crop.py │ │ ├── deform.py │ │ ├── external.py │ │ ├── geometry.py │ │ ├── imgaug_test.py │ │ ├── imgproc.py │ │ ├── meta.py │ │ ├── misc.py │ │ ├── noise.py │ │ ├── paste.py │ │ └── transform.py │ ├── parallel.py │ ├── parallel_map.py │ ├── raw.py │ ├── remote.py │ ├── serialize.py │ └── serialize_test.py └── utils │ ├── __init__.py │ ├── argtools.py │ ├── compatible_serialize.py │ ├── concurrency.py │ ├── develop.py │ ├── fs.py │ ├── loadcaffe.py │ ├── logger.py │ ├── serialize.py │ ├── stats.py │ ├── timer.py │ └── utils.py ├── setup.py ├── sync.py └── tox.ini
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorpack DataFlow
2 |
3 | Tensorpack DataFlow is an **efficient** and **flexible** data
4 | loading pipeline for deep learning, written in pure Python.
5 |
6 | Its main features are:
7 |
8 | 1. **Highly optimized for speed**.
9 | Parallelization in Python is hard and most libraries do it wrong.
10 | DataFlow implements highly optimized
11 | parallel building blocks which give you an easy interface to parallelize your workload.
12 |
13 | 2. **Written in pure Python**.
14 | This allows it to be used together with any other Python-based library.
15 |
16 | DataFlow was originally part of the [tensorpack library](https://github.com/tensorpack/tensorpack/)
17 | and has gone through years of polishing.
18 | Given its independence from the rest of the tensorpack library,
19 | it is now a separate library whose source code is synced with tensorpack.
20 | Please use [tensorpack issues](https://github.com/tensorpack/tensorpack/issues/) for support.
21 |
22 | Why would you want to use DataFlow instead of a platform-specific data loading solution?
23 | We recommend reading
24 | [Why DataFlow?](https://tensorpack.readthedocs.io/tutorial/philosophy/dataflow.html).
25 |
26 | ## Install:
27 | ```
28 | pip install --upgrade git+https://github.com/tensorpack/dataflow.git
29 | # or add `--user` to install to user's local directories
30 | ```
31 | You may also need to install OpenCV, which is used by many built-in DataFlows.
32 |
33 | ## Examples:
34 | ```python
35 | import dataflow as D
36 | d = D.ILSVRC12('/path/to/imagenet')  # produce [img, label]
37 | d = D.MapDataComponent(d, lambda img: some_transform(img), index=0)
38 | d = D.MultiProcessMapData(d, num_proc=10, map_func=lambda dp: other_transform(dp[0], dp[1]))  # dp is [img, label]
39 | d = D.BatchData(d, 64)
40 | d.reset_state()
41 | for img, label in d:
42 |     # ...
43 | ```
44 |
45 | ## Documentation:
46 | ### Tutorials:
47 | 1. [Basics](https://tensorpack.readthedocs.io/tutorial/dataflow.html)
48 | 1. [Why DataFlow?](https://tensorpack.readthedocs.io/tutorial/philosophy/dataflow.html)
49 | 1. [Write a DataFlow](https://tensorpack.readthedocs.io/tutorial/extend/dataflow.html)
50 | 1. [Parallel DataFlow](https://tensorpack.readthedocs.io/tutorial/parallel-dataflow.html)
51 | 1. [Efficient DataFlow](https://tensorpack.readthedocs.io/tutorial/efficient-dataflow.html)
52 |
53 | ### APIs:
54 | 1. [Built-in DataFlows](https://tensorpack.readthedocs.io/modules/dataflow.html)
55 | 1. 
[Built-in Datasets](https://tensorpack.readthedocs.io/modules/dataflow.dataset.html) 56 | 57 | ## Support & Contributing 58 | 59 | Please send issues and pull requests (for the `dataflow/` directory) to the 60 | [tensorpack project](https://github.com/tensorpack/tensorpack/) where the source code is developed. 61 | -------------------------------------------------------------------------------- /dataflow/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #File: 3 | 4 | from .dataflow import * 5 | from . import utils 6 | -------------------------------------------------------------------------------- /dataflow/dataflow/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: __init__.py 3 | 4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36 5 | STATICA_HACK = True 6 | globals()['kcah_acitats'[::-1].upper()] = False 7 | if STATICA_HACK: 8 | from .base import * 9 | from .common import * 10 | from .format import * 11 | from .image import * 12 | from .parallel_map import * 13 | from .parallel import * 14 | from .raw import * 15 | from .remote import * 16 | from .serialize import * 17 | from . import imgaug 18 | from . import dataset 19 | 20 | 21 | from pkgutil import iter_modules 22 | import os 23 | import os.path 24 | from ..utils.develop import LazyLoader 25 | 26 | __all__ = [] 27 | 28 | 29 | def _global_import(name): 30 | p = __import__(name, globals(), locals(), level=1) 31 | lst = p.__all__ if '__all__' in dir(p) else dir(p) 32 | if lst: 33 | del globals()[name] 34 | for k in lst: 35 | if not k.startswith('__'): 36 | globals()[k] = p.__dict__[k] 37 | __all__.append(k) 38 | 39 | 40 | __SKIP = set(['dataset', 'imgaug']) 41 | _CURR_DIR = os.path.dirname(__file__) 42 | for _, module_name, __ in iter_modules( 43 | [os.path.dirname(__file__)]): 44 | srcpath = os.path.join(_CURR_DIR, module_name + '.py') 45 | if not os.path.isfile(srcpath): 46 | continue 47 | if "_test" not in module_name and \ 48 | not module_name.startswith('_') and \ 49 | module_name not in __SKIP: 50 | _global_import(module_name) 51 | 52 | 53 | globals()['dataset'] = LazyLoader('dataset', globals(), __name__ + '.dataset') 54 | globals()['imgaug'] = LazyLoader('imgaug', globals(), __name__ + '.imgaug') 55 | 56 | del LazyLoader 57 | 58 | __all__.extend(['imgaug', 'dataset']) 59 | -------------------------------------------------------------------------------- /dataflow/dataflow/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: base.py 3 | 4 | 5 | import threading 6 | from abc import ABCMeta, abstractmethod 7 | import six 8 | 9 | from ..utils.utils import get_rng 10 | 11 | __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated'] 12 | 13 | 14 | class DataFlowTerminated(BaseException): 15 | """ 16 | An exception indicating that the DataFlow is unable to produce any more 17 | data, i.e. something wrong happened so that calling :meth:`get_data` 18 | cannot give a valid iterator any more. 19 | In most DataFlow this will never be raised. 20 | """ 21 | pass 22 | 23 | 24 | class DataFlowReentrantGuard(object): 25 | """ 26 | A tool to enforce non-reentrancy. 27 | Mostly used on DataFlow whose :meth:`get_data` is stateful, 28 | so that multiple instances of the iterator cannot co-exist. 
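Example: a minimal sketch of the intended usage pattern, mirroring how :class:`LMDBData` in ``format.py`` uses the guard (``_read_next`` is a made-up placeholder):

.. code-block:: python

    class MyStatefulDataFlow(DataFlow):
        def reset_state(self):
            # create the guard in the process that will iterate
            self._guard = DataFlowReentrantGuard()

        def __iter__(self):
            # a second concurrent __iter__ raises ThreadError instead of
            # silently corrupting the shared iteration state
            with self._guard:
                while True:
                    yield [self._read_next()]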
29 | """
30 | def __init__(self):
31 | self._lock = threading.Lock()
32 |
33 | def __enter__(self):
34 | self._succ = self._lock.acquire(False)
35 | if not self._succ:
36 | raise threading.ThreadError("This DataFlow is not reentrant!")
37 |
38 | def __exit__(self, exc_type, exc_val, exc_tb):
39 | self._lock.release()
40 | return False
41 |
42 |
43 | class DataFlowMeta(ABCMeta):
44 | """
45 | DataFlow uses "__iter__()" and "__len__()" instead of
46 | "get_data()" and "size()". This adds backward compatibility.
47 | """
48 | def __new__(mcls, name, bases, namespace, **kwargs):
49 |
50 | def hot_patch(required, existing):
51 | if required not in namespace and existing in namespace:
52 | namespace[required] = namespace[existing]
53 |
54 | hot_patch('__iter__', 'get_data')
55 | hot_patch('__len__', 'size')
56 |
57 | return ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)
58 |
59 |
60 | @six.add_metaclass(DataFlowMeta)
61 | class DataFlow(object):
62 | """ Base class for all DataFlow """
63 |
64 | @abstractmethod
65 | def __iter__(self):
66 | """
67 | * A dataflow is an iterable. The :meth:`__iter__` method should yield a list or dict each time.
68 | Note that dict is **partially** supported at the moment: certain dataflows do not support dict.
69 |
70 | * The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
71 | (will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
72 | again immediately after the previous call returned.
73 |
74 | * For many dataflows, the :meth:`__iter__` method is non-reentrant, which means that for a dataflow
75 | instance ``df``, :meth:`df.__iter__` cannot be called before the previous
76 | :meth:`df.__iter__` call has finished (iteration has stopped).
77 | When a dataflow is non-reentrant, :meth:`df.__iter__` should throw an exception if
78 | called before the previous call has finished.
79 | For such non-reentrant dataflows, if you need to use the same dataflow in two places,
80 | you need to create two dataflow instances.
81 |
82 | Yields:
83 | list/dict: The datapoint, i.e. list/dict of components.
84 | """
85 |
86 | def __len__(self):
87 | """
88 | * A dataflow can optionally implement :meth:`__len__`. If not implemented, it will
89 | throw :class:`NotImplementedError`.
90 |
91 | * It returns an integer representing the size of the dataflow.
92 | The return value **may not be accurate or meaningful** at all.
93 | When saying the length is "accurate", it means that
94 | :meth:`__iter__` will always yield this many datapoints before it stops iteration.
95 |
96 | * There could be many reasons why :meth:`__len__` is inaccurate.
97 | For example, some dataflows have a dynamic size because they throw away datapoints on the fly.
98 | Some dataflows mix the datapoints between consecutive passes over
99 | the dataset, due to parallelism and buffering.
100 | In this case it does not make sense to stop the iteration anywhere.
101 |
102 | * Due to the above reasons, the length is only rough guidance,
103 | and it's up to the user how to interpret it.
104 | Inside tensorpack it's only used in these places:
105 |
106 | + A default ``steps_per_epoch`` in training, but you probably want to customize
107 | it yourself, especially when using a data-parallel trainer.
108 | + The length of the progress bar when processing a dataflow.
109 | + Used by :class:`InferenceRunner` to get the number of iterations in inference.
110 | In this case users are **responsible** for making sure that :meth:`__len__` is "accurate".
111 | This is to guarantee that inference is run on a fixed set of images.
112 |
113 | Returns:
114 | int: rough size of this dataflow.
115 |
116 | Raises:
117 | :class:`NotImplementedError` if this DataFlow doesn't have a size.
118 | """
119 | raise NotImplementedError()
120 |
121 | def reset_state(self):
122 | """
123 | * The caller must guarantee that :meth:`reset_state` is called **once and only once**
124 | by the **process that uses the dataflow** before :meth:`__iter__` is called.
125 | The caller thread of this method should stay alive to keep this dataflow alive.
126 |
127 | * It is meant for certain initialization that involves processes,
128 | e.g., initializing random number generators (RNG) or creating worker processes.
129 |
130 | Because it's very common to use RNG in data processing,
131 | developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to
132 | a properly-initialized RNG.
133 |
134 | * A dataflow is not fork-safe after :meth:`reset_state` is called (because this would violate the guarantee).
135 | There are a few other dataflows that are never fork-safe; this is mentioned in their docs.
136 |
137 | * You should take this responsibility and follow the above guarantee if you're the caller of a dataflow yourself
138 | (either when you're using dataflows outside of tensorpack, or when you're writing a wrapper dataflow).
139 |
140 | * Tensorpack's built-in forking dataflows (:class:`MultiProcessRunner`, :class:`MultiProcessMapData`, etc)
141 | and other components that use dataflows (:class:`InputSource`)
142 | already take care of the responsibility of calling this method.
143 | """
144 | pass
145 |
146 | # These are the old (overly verbose) names for the methods:
147 | def get_data(self):
148 | return self.__iter__()
149 |
150 | def size(self):
151 | return self.__len__()
152 |
153 |
154 | class RNGDataFlow(DataFlow):
155 | """ A DataFlow with RNG"""
156 |
157 | rng = None
158 | """
159 | ``self.rng`` is a ``np.random.RandomState`` instance that is initialized
160 | correctly (with different seeds in each process) in ``RNGDataFlow.reset_state()``.
161 | """
162 |
163 | def reset_state(self):
164 | """ Reset the RNG """
165 | self.rng = get_rng(self)
166 |
167 |
168 | class ProxyDataFlow(DataFlow):
169 | """ Base class for DataFlow that proxies another.
170 | Every method is proxied to ``self.ds`` unless overridden by a subclass.
171 | """
172 |
173 | def __init__(self, ds):
174 | """
175 | Args:
176 | ds (DataFlow): DataFlow to proxy.
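Example: a minimal sketch of a wrapper dataflow; everything not overridden (``reset_state``, ``__len__``) falls through to ``self.ds``:

.. code-block:: python

    class PrintFirstComponent(ProxyDataFlow):
        # forward datapoints unchanged, printing the first component
        def __iter__(self):
            for dp in self.ds:
                print(dp[0])
                yield dp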
177 | """ 178 | self.ds = ds 179 | 180 | def reset_state(self): 181 | self.ds.reset_state() 182 | 183 | def __len__(self): 184 | return self.ds.__len__() 185 | 186 | def __iter__(self): 187 | return self.ds.__iter__() 188 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/.gitignore: -------------------------------------------------------------------------------- 1 | mnist_data 2 | cifar10_data 3 | cifar100_data 4 | svhn_data 5 | ilsvrc_metadata 6 | bsds500_data 7 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: __init__.py 3 | 4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36 5 | STATICA_HACK = True 6 | globals()['kcah_acitats'[::-1].upper()] = False 7 | if STATICA_HACK: 8 | from .bsds500 import * 9 | from .cifar import * 10 | from .ilsvrc import * 11 | from .mnist import * 12 | from .svhn import * 13 | from .caltech101 import * 14 | 15 | from pkgutil import iter_modules 16 | import os 17 | import os.path 18 | 19 | __all__ = [] 20 | 21 | 22 | def global_import(name): 23 | p = __import__(name, globals(), locals(), level=1) 24 | lst = p.__all__ if '__all__' in dir(p) else dir(p) 25 | if lst: 26 | del globals()[name] 27 | for k in lst: 28 | if not k.startswith('__'): 29 | globals()[k] = p.__dict__[k] 30 | __all__.append(k) 31 | 32 | 33 | _CURR_DIR = os.path.dirname(__file__) 34 | for _, module_name, _ in iter_modules( 35 | [_CURR_DIR]): 36 | srcpath = os.path.join(_CURR_DIR, module_name + '.py') 37 | if not os.path.isfile(srcpath): 38 | continue 39 | if not module_name.startswith('_'): 40 | global_import(module_name) 41 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/bsds500.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: bsds500.py 3 | 4 | 5 | import glob 6 | import numpy as np 7 | import os 8 | 9 | from ...utils.fs import download, get_dataset_path 10 | from ..base import RNGDataFlow 11 | 12 | __all__ = ['BSDS500'] 13 | 14 | 15 | DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" 16 | DATA_SIZE = 70763455 17 | IMG_W, IMG_H = 481, 321 18 | 19 | 20 | class BSDS500(RNGDataFlow): 21 | """ 22 | `Berkeley Segmentation Data Set and Benchmarks 500 dataset 23 | `_. 24 | 25 | Produce ``(image, label)`` pair, where ``image`` has shape (321, 481, 3(BGR)) and 26 | ranges in [0,255]. 27 | ``Label`` is a floating point image of shape (321, 481) in range [0, 1]. 28 | The value of each pixel is ``number of times it is annotated as edge / total number of annotators for this image``. 29 | """ 30 | 31 | def __init__(self, name, data_dir=None, shuffle=True): 32 | """ 33 | Args: 34 | name (str): 'train', 'test', 'val' 35 | data_dir (str): a directory containing the original 'BSR' directory. 
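Example: a sketch of typical iteration (the data is downloaded automatically when ``data_dir`` is None):

.. code-block:: python

    ds = BSDS500('train')
    ds.reset_state()
    for img, edgemap in ds:
        # img: (321, 481, 3) uint8 BGR image in [0, 255]
        # edgemap: (321, 481) float32 edge probabilities in [0, 1]
        pass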
36 | """ 37 | # check and download data 38 | if data_dir is None: 39 | data_dir = get_dataset_path('bsds500_data') 40 | if not os.path.isdir(os.path.join(data_dir, 'BSR')): 41 | download(DATA_URL, data_dir, expect_size=DATA_SIZE) 42 | filename = DATA_URL.split('/')[-1] 43 | filepath = os.path.join(data_dir, filename) 44 | import tarfile 45 | tarfile.open(filepath, 'r:gz').extractall(data_dir) 46 | self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data') 47 | assert os.path.isdir(self.data_root) 48 | 49 | self.shuffle = shuffle 50 | assert name in ['train', 'test', 'val'] 51 | self._load(name) 52 | 53 | def _load(self, name): 54 | image_glob = os.path.join(self.data_root, 'images', name, '*.jpg') 55 | image_files = glob.glob(image_glob) 56 | gt_dir = os.path.join(self.data_root, 'groundTruth', name) 57 | self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8') 58 | self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='float32') 59 | 60 | for idx, f in enumerate(image_files): 61 | im = cv2.imread(f, cv2.IMREAD_COLOR) 62 | assert im is not None 63 | if im.shape[0] > im.shape[1]: 64 | im = np.transpose(im, (1, 0, 2)) 65 | assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W)) 66 | 67 | imgid = os.path.basename(f).split('.')[0] 68 | gt_file = os.path.join(gt_dir, imgid) 69 | gt = loadmat(gt_file)['groundTruth'][0] 70 | n_annot = gt.shape[0] 71 | gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot)) 72 | gt = gt.astype('float32') 73 | gt *= 1.0 / n_annot 74 | if gt.shape[0] > gt.shape[1]: 75 | gt = gt.transpose() 76 | assert gt.shape == (IMG_H, IMG_W) 77 | 78 | self.data[idx] = im 79 | self.label[idx] = gt 80 | 81 | def __len__(self): 82 | return self.data.shape[0] 83 | 84 | def __iter__(self): 85 | idxs = np.arange(self.data.shape[0]) 86 | if self.shuffle: 87 | self.rng.shuffle(idxs) 88 | for k in idxs: 89 | yield [self.data[k], self.label[k]] 90 | 91 | 92 | try: 93 | from scipy.io import loadmat 94 | import cv2 95 | except ImportError: 96 | from ...utils.develop import create_dummy_class 97 | BSDS500 = create_dummy_class('BSDS500', ['scipy.io', 'cv2']) # noqa 98 | 99 | if __name__ == '__main__': 100 | a = BSDS500('val') 101 | a.reset_state() 102 | for k in a: 103 | cv2.imshow("haha", k[1].astype('uint8') * 255) 104 | cv2.waitKey(1000) 105 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/caltech101.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: caltech101.py 3 | 4 | 5 | import os 6 | 7 | from ...utils import logger 8 | from ...utils.fs import download, get_dataset_path 9 | from ..base import RNGDataFlow 10 | 11 | __all__ = ["Caltech101Silhouettes"] 12 | 13 | 14 | def maybe_download(url, work_directory): 15 | """Download the data from Marlin's website, unless it's already here.""" 16 | filename = url.split("/")[-1] 17 | filepath = os.path.join(work_directory, filename) 18 | if not os.path.exists(filepath): 19 | logger.info("Downloading to {}...".format(filepath)) 20 | download(url, work_directory) 21 | return filepath 22 | 23 | 24 | class Caltech101Silhouettes(RNGDataFlow): 25 | """ 26 | Produces [image, label] in Caltech101 Silhouettes dataset, 27 | image is 28x28 in the range [0,1], label is an int in the range [0,100]. 
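Example: a sketch of typical iteration, following the ``__main__`` block at the end of this file:

.. code-block:: python

    ds = Caltech101Silhouettes('train')
    ds.reset_state()
    for img, label in ds:
        # img: (28, 28) array, label: int in [0, 100]
        break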
28 | """ 29 | 30 | _DIR_NAME = "caltech101_data" 31 | _SOURCE_URL = "https://people.cs.umass.edu/~marlin/data/" 32 | 33 | def __init__(self, name, shuffle=True, dir=None): 34 | """ 35 | Args: 36 | name (str): 'train', 'test', 'val' 37 | shuffle (bool): shuffle the dataset 38 | """ 39 | if dir is None: 40 | dir = get_dataset_path(self._DIR_NAME) 41 | assert name in ['train', 'test', 'val'] 42 | self.name = name 43 | self.shuffle = shuffle 44 | 45 | def get_images_and_labels(data_file): 46 | f = maybe_download(self._SOURCE_URL + data_file, dir) 47 | data = scipy.io.loadmat(f) 48 | return data 49 | 50 | self.data = get_images_and_labels("caltech101_silhouettes_28_split1.mat") 51 | 52 | if self.name == "train": 53 | self.images = self.data["train_data"].reshape((4100, 28, 28)) 54 | self.labels = self.data["train_labels"].ravel() - 1 55 | elif self.name == "test": 56 | self.images = self.data["test_data"].reshape((2307, 28, 28)) 57 | self.labels = self.data["test_labels"].ravel() - 1 58 | else: 59 | self.images = self.data["val_data"].reshape((2264, 28, 28)) 60 | self.labels = self.data["val_labels"].ravel() - 1 61 | 62 | def __len__(self): 63 | return self.images.shape[0] 64 | 65 | def __iter__(self): 66 | idxs = list(range(self.__len__())) 67 | if self.shuffle: 68 | self.rng.shuffle(idxs) 69 | for k in idxs: 70 | img = self.images[k] 71 | label = self.labels[k] 72 | yield [img, label] 73 | 74 | 75 | try: 76 | import scipy.io 77 | except ImportError: 78 | from ...utils.develop import create_dummy_class 79 | Caltech101Silhouettes = create_dummy_class('Caltech101Silhouettes', 'scipy.io') # noqa 80 | 81 | 82 | if __name__ == "__main__": 83 | ds = Caltech101Silhouettes("train") 84 | ds.reset_state() 85 | for _ in ds: 86 | from IPython import embed 87 | 88 | embed() 89 | break 90 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/cifar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: cifar.py 3 | 4 | # Yukun Chen 5 | 6 | import numpy as np 7 | import os 8 | import pickle 9 | import tarfile 10 | 11 | from ...utils import logger 12 | from ...utils.fs import download, get_dataset_path 13 | from ..base import RNGDataFlow 14 | 15 | __all__ = ['CifarBase', 'Cifar10', 'Cifar100'] 16 | 17 | 18 | DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071) 19 | DATA_URL_CIFAR_100 = ('http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 169001437) 20 | 21 | 22 | def maybe_download_and_extract(dest_directory, cifar_classnum): 23 | """Download and extract the tarball from Alex's website. 
Copied from tensorflow example """ 24 | assert cifar_classnum == 10 or cifar_classnum == 100 25 | if cifar_classnum == 10: 26 | cifar_foldername = 'cifar-10-batches-py' 27 | else: 28 | cifar_foldername = 'cifar-100-python' 29 | if os.path.isdir(os.path.join(dest_directory, cifar_foldername)): 30 | logger.info("Found cifar{} data in {}.".format(cifar_classnum, dest_directory)) 31 | return 32 | else: 33 | DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100 34 | filename = DATA_URL[0].split('/')[-1] 35 | filepath = os.path.join(dest_directory, filename) 36 | download(DATA_URL[0], dest_directory, expect_size=DATA_URL[1]) 37 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 38 | 39 | 40 | def read_cifar(filenames, cifar_classnum): 41 | assert cifar_classnum == 10 or cifar_classnum == 100 42 | ret = [] 43 | for fname in filenames: 44 | fo = open(fname, 'rb') 45 | dic = pickle.load(fo, encoding='bytes') 46 | data = dic[b'data'] 47 | if cifar_classnum == 10: 48 | label = dic[b'labels'] 49 | IMG_NUM = 10000 # cifar10 data are split into blocks of 10000 50 | else: 51 | label = dic[b'fine_labels'] 52 | IMG_NUM = 50000 if 'train' in fname else 10000 53 | fo.close() 54 | for k in range(IMG_NUM): 55 | img = data[k].reshape(3, 32, 32) 56 | img = np.transpose(img, [1, 2, 0]) 57 | ret.append([img, label[k]]) 58 | return ret 59 | 60 | 61 | def get_filenames(dir, cifar_classnum): 62 | assert cifar_classnum == 10 or cifar_classnum == 100 63 | if cifar_classnum == 10: 64 | train_files = [os.path.join( 65 | dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)] 66 | test_files = [os.path.join( 67 | dir, 'cifar-10-batches-py', 'test_batch')] 68 | meta_file = os.path.join(dir, 'cifar-10-batches-py', 'batches.meta') 69 | elif cifar_classnum == 100: 70 | train_files = [os.path.join(dir, 'cifar-100-python', 'train')] 71 | test_files = [os.path.join(dir, 'cifar-100-python', 'test')] 72 | meta_file = os.path.join(dir, 'cifar-100-python', 'meta') 73 | return train_files, test_files, meta_file 74 | 75 | 76 | def _parse_meta(filename, cifar_classnum): 77 | with open(filename, 'rb') as f: 78 | obj = pickle.load(f) 79 | return obj['label_names' if cifar_classnum == 10 else 'fine_label_names'] 80 | 81 | 82 | class CifarBase(RNGDataFlow): 83 | """ 84 | Produces [image, label] in Cifar10/100 dataset, 85 | image is 32x32x3 in the range [0,255]. 86 | label is an int. 87 | """ 88 | def __init__(self, train_or_test, shuffle=None, dir=None, cifar_classnum=10): 89 | """ 90 | Args: 91 | train_or_test (str): 'train' or 'test' 92 | shuffle (bool): defaults to True for training set. 
93 | dir (str): path to the dataset directory 94 | cifar_classnum (int): 10 or 100 95 | """ 96 | assert train_or_test in ['train', 'test'] 97 | assert cifar_classnum == 10 or cifar_classnum == 100 98 | self.cifar_classnum = cifar_classnum 99 | if dir is None: 100 | dir = get_dataset_path('cifar{}_data'.format(cifar_classnum)) 101 | maybe_download_and_extract(dir, self.cifar_classnum) 102 | train_files, test_files, meta_file = get_filenames(dir, cifar_classnum) 103 | if train_or_test == 'train': 104 | self.fs = train_files 105 | else: 106 | self.fs = test_files 107 | for f in self.fs: 108 | if not os.path.isfile(f): 109 | raise ValueError('Failed to find file: ' + f) 110 | self._label_names = _parse_meta(meta_file, cifar_classnum) 111 | self.train_or_test = train_or_test 112 | self.data = read_cifar(self.fs, cifar_classnum) 113 | self.dir = dir 114 | 115 | if shuffle is None: 116 | shuffle = train_or_test == 'train' 117 | self.shuffle = shuffle 118 | 119 | def __len__(self): 120 | return 50000 if self.train_or_test == 'train' else 10000 121 | 122 | def __iter__(self): 123 | idxs = np.arange(len(self.data)) 124 | if self.shuffle: 125 | self.rng.shuffle(idxs) 126 | for k in idxs: 127 | # since cifar is quite small, just do it for safety 128 | yield self.data[k] 129 | 130 | def get_per_pixel_mean(self, names=('train', 'test')): 131 | """ 132 | Args: 133 | names (tuple[str]): the names ('train' or 'test') of the datasets 134 | 135 | Returns: 136 | a mean image of all images in the given datasets, with size 32x32x3 137 | """ 138 | for name in names: 139 | assert name in ['train', 'test'], name 140 | train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum) 141 | all_files = [] 142 | if 'train' in names: 143 | all_files.extend(train_files) 144 | if 'test' in names: 145 | all_files.extend(test_files) 146 | all_imgs = [x[0] for x in read_cifar(all_files, self.cifar_classnum)] 147 | arr = np.array(all_imgs, dtype='float32') 148 | mean = np.mean(arr, axis=0) 149 | return mean 150 | 151 | def get_label_names(self): 152 | """ 153 | Returns: 154 | [str]: name of each class. 155 | """ 156 | return self._label_names 157 | 158 | def get_per_channel_mean(self, names=('train', 'test')): 159 | """ 160 | Args: 161 | names (tuple[str]): the names ('train' or 'test') of the datasets 162 | 163 | Returns: 164 | An array of three values as mean of each channel, for all images in the given datasets. 165 | """ 166 | mean = self.get_per_pixel_mean(names) 167 | return np.mean(mean, axis=(0, 1)) 168 | 169 | 170 | class Cifar10(CifarBase): 171 | """ 172 | Produces [image, label] in Cifar10 dataset, 173 | image is 32x32x3 in the range [0,255]. 174 | label is an int. 175 | """ 176 | def __init__(self, train_or_test, shuffle=None, dir=None): 177 | """ 178 | Args: 179 | train_or_test (str): either 'train' or 'test'. 
180 | shuffle (bool): shuffle the dataset, default to shuffle in training 181 | """ 182 | super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10) 183 | 184 | 185 | class Cifar100(CifarBase): 186 | """ Similar to Cifar10""" 187 | def __init__(self, train_or_test, shuffle=None, dir=None): 188 | super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100) 189 | 190 | 191 | if __name__ == '__main__': 192 | ds = Cifar10('train') 193 | mean = ds.get_per_channel_mean() 194 | print(mean) 195 | 196 | import cv2 197 | ds.reset_state() 198 | for i, dp in enumerate(ds): 199 | if i == 100: 200 | break 201 | img = dp[0] 202 | cv2.imwrite("{:04d}.jpg".format(i), img) 203 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: mnist.py 3 | 4 | 5 | import gzip 6 | import numpy 7 | import os 8 | 9 | from ...utils import logger 10 | from ...utils.fs import download, get_dataset_path 11 | from ..base import RNGDataFlow 12 | 13 | __all__ = ['Mnist', 'FashionMnist'] 14 | 15 | 16 | def maybe_download(url, work_directory): 17 | """Download the data from Yann's website, unless it's already here.""" 18 | filename = url.split('/')[-1] 19 | filepath = os.path.join(work_directory, filename) 20 | if not os.path.exists(filepath): 21 | logger.info("Downloading to {}...".format(filepath)) 22 | download(url, work_directory) 23 | return filepath 24 | 25 | 26 | def _read32(bytestream): 27 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 28 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 29 | 30 | 31 | def extract_images(filename): 32 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 33 | with gzip.open(filename) as bytestream: 34 | magic = _read32(bytestream) 35 | if magic != 2051: 36 | raise ValueError( 37 | 'Invalid magic number %d in MNIST image file: %s' % 38 | (magic, filename)) 39 | num_images = _read32(bytestream) 40 | rows = _read32(bytestream) 41 | cols = _read32(bytestream) 42 | buf = bytestream.read(rows * cols * num_images) 43 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 44 | data = data.reshape(num_images, rows, cols, 1) 45 | data = data.astype('float32') / 255.0 46 | return data 47 | 48 | 49 | def extract_labels(filename): 50 | """Extract the labels into a 1D uint8 numpy array [index].""" 51 | with gzip.open(filename) as bytestream: 52 | magic = _read32(bytestream) 53 | if magic != 2049: 54 | raise ValueError( 55 | 'Invalid magic number %d in MNIST label file: %s' % 56 | (magic, filename)) 57 | num_items = _read32(bytestream) 58 | buf = bytestream.read(num_items) 59 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 60 | return labels 61 | 62 | 63 | class Mnist(RNGDataFlow): 64 | """ 65 | Produces [image, label] in MNIST dataset, 66 | image is 28x28 in the range [0,1], label is an int. 
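Example: a sketch of typical usage, optionally batched with :class:`BatchData` as in the README:

.. code-block:: python

    ds = Mnist('train')
    ds.reset_state()
    for img, label in ds:
        # img: (28, 28) float array in [0, 1], label: int
        break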
67 | """ 68 | 69 | _DIR_NAME = 'mnist_data' 70 | _SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 71 | 72 | def __init__(self, train_or_test, shuffle=True, dir=None): 73 | """ 74 | Args: 75 | train_or_test (str): either 'train' or 'test' 76 | shuffle (bool): shuffle the dataset 77 | """ 78 | if dir is None: 79 | dir = get_dataset_path(self._DIR_NAME) 80 | assert train_or_test in ['train', 'test'] 81 | self.train_or_test = train_or_test 82 | self.shuffle = shuffle 83 | 84 | def get_images_and_labels(image_file, label_file): 85 | f = maybe_download(self._SOURCE_URL + image_file, dir) 86 | images = extract_images(f) 87 | f = maybe_download(self._SOURCE_URL + label_file, dir) 88 | labels = extract_labels(f) 89 | assert images.shape[0] == labels.shape[0] 90 | return images, labels 91 | 92 | if self.train_or_test == 'train': 93 | self.images, self.labels = get_images_and_labels( 94 | 'train-images-idx3-ubyte.gz', 95 | 'train-labels-idx1-ubyte.gz') 96 | else: 97 | self.images, self.labels = get_images_and_labels( 98 | 't10k-images-idx3-ubyte.gz', 99 | 't10k-labels-idx1-ubyte.gz') 100 | 101 | def __len__(self): 102 | return self.images.shape[0] 103 | 104 | def __iter__(self): 105 | idxs = list(range(self.__len__())) 106 | if self.shuffle: 107 | self.rng.shuffle(idxs) 108 | for k in idxs: 109 | img = self.images[k].reshape((28, 28)) 110 | label = self.labels[k] 111 | yield [img, label] 112 | 113 | 114 | class FashionMnist(Mnist): 115 | """ 116 | Same API as :class:`Mnist`, but more fashion. 117 | """ 118 | 119 | _DIR_NAME = 'fashion_mnist_data' 120 | _SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' 121 | 122 | def get_label_names(self): 123 | """ 124 | Returns: 125 | [str]: the name of each class 126 | """ 127 | # copied from https://github.com/zalandoresearch/fashion-mnist 128 | return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 129 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 130 | 131 | 132 | if __name__ == '__main__': 133 | ds = Mnist('train') 134 | ds.reset_state() 135 | for _ in ds: 136 | from IPython import embed 137 | embed() 138 | break 139 | -------------------------------------------------------------------------------- /dataflow/dataflow/dataset/svhn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: svhn.py 3 | 4 | 5 | import numpy as np 6 | import os 7 | 8 | from ...utils import logger 9 | from ...utils.fs import download, get_dataset_path 10 | from ..base import RNGDataFlow 11 | 12 | __all__ = ['SVHNDigit'] 13 | 14 | SVHN_URL = "http://ufldl.stanford.edu/housenumbers/" 15 | 16 | 17 | class SVHNDigit(RNGDataFlow): 18 | """ 19 | `SVHN `_ Cropped Digit Dataset. 20 | Produces [img, label], img of 32x32x3 in range [0,255], label of 0-9 21 | """ 22 | _Cache = {} 23 | 24 | def __init__(self, name, data_dir=None, shuffle=True): 25 | """ 26 | Args: 27 | name (str): 'train', 'test', or 'extra'. 28 | data_dir (str): a directory containing the original {train,test,extra}_32x32.mat. 29 | shuffle (bool): shuffle the dataset. 
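Example: a sketch, following the ``__main__`` block at the end of this file:

.. code-block:: python

    ds = SVHNDigit('train')
    ds.reset_state()
    for img, label in ds:
        # img: (32, 32, 3) in [0, 255], label: int in [0, 9]
        break

    # per-pixel mean over the requested splits (downloads them if needed)
    mean = SVHNDigit.get_per_pixel_mean(('train', 'test'))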
30 | """ 31 | self.shuffle = shuffle 32 | 33 | if name in SVHNDigit._Cache: 34 | self.X, self.Y = SVHNDigit._Cache[name] 35 | return 36 | if data_dir is None: 37 | data_dir = get_dataset_path('svhn_data') 38 | assert name in ['train', 'test', 'extra'], name 39 | filename = os.path.join(data_dir, name + '_32x32.mat') 40 | if not os.path.isfile(filename): 41 | url = SVHN_URL + os.path.basename(filename) 42 | logger.info("File {} not found!".format(filename)) 43 | logger.info("Downloading from {} ...".format(url)) 44 | download(url, os.path.dirname(filename)) 45 | logger.info("Loading {} ...".format(filename)) 46 | data = scipy.io.loadmat(filename) 47 | self.X = data['X'].transpose(3, 0, 1, 2) 48 | self.Y = data['y'].reshape((-1)) 49 | self.Y[self.Y == 10] = 0 50 | SVHNDigit._Cache[name] = (self.X, self.Y) 51 | 52 | def __len__(self): 53 | return self.X.shape[0] 54 | 55 | def __iter__(self): 56 | n = self.X.shape[0] 57 | idxs = np.arange(n) 58 | if self.shuffle: 59 | self.rng.shuffle(idxs) 60 | for k in idxs: 61 | # since svhn is quite small, just do it for safety 62 | yield [self.X[k], self.Y[k]] 63 | 64 | @staticmethod 65 | def get_per_pixel_mean(names=('train', 'test', 'extra')): 66 | """ 67 | Args: 68 | names (tuple[str]): names of the dataset split 69 | 70 | Returns: 71 | a 32x32x3 image, the mean of all images in the given datasets 72 | """ 73 | for name in names: 74 | assert name in ['train', 'test', 'extra'], name 75 | images = [SVHNDigit(x).X for x in names] 76 | return np.concatenate(tuple(images)).mean(axis=0) 77 | 78 | 79 | try: 80 | import scipy.io 81 | except ImportError: 82 | from ...utils.develop import create_dummy_class 83 | SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa 84 | 85 | if __name__ == '__main__': 86 | a = SVHNDigit('train') 87 | b = SVHNDigit.get_per_pixel_mean() 88 | -------------------------------------------------------------------------------- /dataflow/dataflow/format.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: format.py 3 | 4 | 5 | import numpy as np 6 | import os 7 | import six 8 | 9 | from ..utils import logger 10 | from ..utils.argtools import log_once 11 | from ..utils.serialize import loads 12 | from ..utils.develop import create_dummy_class # noqa 13 | from ..utils.loadcaffe import get_caffe_pb 14 | from ..utils.timer import timed_operation 15 | from ..utils.utils import get_tqdm 16 | from .base import DataFlowReentrantGuard, RNGDataFlow 17 | from .common import MapData 18 | 19 | __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 20 | 'CaffeLMDB', 'SVMLightData'] 21 | 22 | """ 23 | Adapters for different data format. 24 | """ 25 | 26 | 27 | class HDF5Data(RNGDataFlow): 28 | """ 29 | Zip data from different paths in an HDF5 file. 30 | 31 | Warning: 32 | The current implementation will load all data into memory. (TODO) 33 | """ 34 | # TODO 35 | 36 | def __init__(self, filename, data_paths, shuffle=True): 37 | """ 38 | Args: 39 | filename (str): h5 data file. 40 | data_paths (list): list of h5 paths to zipped. 41 | For example `['images', 'labels']`. 42 | shuffle (bool): shuffle all data. 
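Example: a sketch; the file name and the dataset paths below are placeholders:

.. code-block:: python

    ds = HDF5Data('/path/to/data.h5', ['images', 'labels'], shuffle=True)
    ds.reset_state()
    for img, label in ds:
        pass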
43 | """ 44 | self.f = h5py.File(filename, 'r') 45 | logger.info("Loading {} to memory...".format(filename)) 46 | self.dps = [self.f[k].value for k in data_paths] 47 | lens = [len(k) for k in self.dps] 48 | assert all(k == lens[0] for k in lens) 49 | self._size = lens[0] 50 | self.shuffle = shuffle 51 | 52 | def __len__(self): 53 | return self._size 54 | 55 | def __iter__(self): 56 | idxs = list(range(self._size)) 57 | if self.shuffle: 58 | self.rng.shuffle(idxs) 59 | for k in idxs: 60 | yield [dp[k] for dp in self.dps] 61 | 62 | 63 | class LMDBData(RNGDataFlow): 64 | """ 65 | Read a LMDB database and produce (k,v) raw bytes pairs. 66 | The raw bytes are usually not what you're interested in. 67 | You might want to use 68 | :class:`LMDBDataDecoder` or apply a 69 | mapper function after :class:`LMDBData`. 70 | """ 71 | def __init__(self, lmdb_path, shuffle=True, keys=None): 72 | """ 73 | Args: 74 | lmdb_path (str): a directory or a file. 75 | shuffle (bool): shuffle the keys or not. 76 | keys (list[str] or str): list of str as the keys, used only when shuffle is True. 77 | It can also be a format string e.g. ``{:0>8d}`` which will be 78 | formatted with the indices from 0 to *total_size - 1*. 79 | 80 | If not given, it will then look in the database for ``__keys__`` which 81 | :func:`LMDBSerializer.save` used to store the list of keys. 82 | If still not found, it will iterate over the database to find 83 | all the keys. 84 | """ 85 | self._lmdb_path = lmdb_path 86 | self._shuffle = shuffle 87 | 88 | self._open_lmdb() 89 | self._size = self._txn.stat()['entries'] 90 | self._set_keys(keys) 91 | logger.info("Found {} entries in {}".format(self._size, self._lmdb_path)) 92 | 93 | # Clean them up after finding the list of keys, since we don't want to fork them 94 | self._close_lmdb() 95 | 96 | def _set_keys(self, keys=None): 97 | def find_keys(txn, size): 98 | logger.warn("Traversing the database to find keys is slow. 
Your should specify the keys.") 99 | keys = [] 100 | with timed_operation("Loading LMDB keys ...", log_start=True), \ 101 | get_tqdm(total=size) as pbar: 102 | for k in self._txn.cursor(): 103 | assert k[0] != b'__keys__' 104 | keys.append(k[0]) 105 | pbar.update() 106 | return keys 107 | 108 | self.keys = self._txn.get(b'__keys__') 109 | if self.keys is not None: 110 | self.keys = loads(self.keys) 111 | self._size -= 1 # delete this item 112 | 113 | if self._shuffle: # keys are necessary when shuffle is True 114 | if keys is None: 115 | if self.keys is None: 116 | self.keys = find_keys(self._txn, self._size) 117 | else: 118 | # check if key-format like '{:0>8d}' was given 119 | if isinstance(keys, six.string_types): 120 | self.keys = map(lambda x: keys.format(x), list(np.arange(self._size))) 121 | else: 122 | self.keys = keys 123 | 124 | def _open_lmdb(self): 125 | self._lmdb = lmdb.open(self._lmdb_path, 126 | subdir=os.path.isdir(self._lmdb_path), 127 | readonly=True, lock=False, readahead=True, 128 | map_size=1099511627776 * 2, max_readers=100) 129 | self._txn = self._lmdb.begin() 130 | 131 | def _close_lmdb(self): 132 | self._lmdb.close() 133 | del self._lmdb 134 | del self._txn 135 | 136 | def reset_state(self): 137 | self._guard = DataFlowReentrantGuard() 138 | super(LMDBData, self).reset_state() 139 | self._open_lmdb() # open the LMDB in the worker process 140 | 141 | def __len__(self): 142 | return self._size 143 | 144 | def __iter__(self): 145 | with self._guard: 146 | if not self._shuffle: 147 | c = self._txn.cursor() 148 | for k, v in c: 149 | if k != b'__keys__': 150 | yield [k, v] 151 | else: 152 | self.rng.shuffle(self.keys) 153 | for k in self.keys: 154 | v = self._txn.get(k) 155 | yield [k, v] 156 | 157 | 158 | class LMDBDataDecoder(MapData): 159 | """ Read a LMDB database with a custom decoder and produce decoded outputs.""" 160 | def __init__(self, lmdb_data, decoder): 161 | """ 162 | Args: 163 | lmdb_data: a :class:`LMDBData` instance. 164 | decoder (k,v -> dp | None): a function taking k, v and returning a datapoint, 165 | or return None to discard. 166 | """ 167 | def f(dp): 168 | return decoder(dp[0], dp[1]) 169 | super(LMDBDataDecoder, self).__init__(lmdb_data, f) 170 | 171 | 172 | def CaffeLMDB(lmdb_path, shuffle=True, keys=None): 173 | """ 174 | Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf. 175 | Produces datapoints of the format: [HWC image, label]. 176 | 177 | Note that Caffe LMDB format is not efficient: it stores serialized raw 178 | arrays rather than JPEG images. 179 | 180 | Args: 181 | lmdb_path, shuffle, keys: same as :class:`LMDBData`. 182 | 183 | Example: 184 | .. code-block:: python 185 | 186 | ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}') 187 | """ 188 | 189 | cpb = get_caffe_pb() 190 | lmdb_data = LMDBData(lmdb_path, shuffle, keys) 191 | 192 | def decoder(k, v): 193 | try: 194 | datum = cpb.Datum() 195 | datum.ParseFromString(v) 196 | img = np.fromstring(datum.data, dtype=np.uint8) 197 | img = img.reshape(datum.channels, datum.height, datum.width) 198 | except Exception: 199 | log_once("Cannot read key {}".format(k), 'warn') 200 | return None 201 | return [img.transpose(1, 2, 0), datum.label] 202 | logger.warn("Caffe LMDB format doesn't store jpeg-compressed images, \ 203 | it's not recommended due to its inferior performance.") 204 | return LMDBDataDecoder(lmdb_data, decoder) 205 | 206 | 207 | class SVMLightData(RNGDataFlow): 208 | """ Read X,y from an SVMlight file, and produce [X_i, y_i] pairs. 
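Example: a sketch; the file path is a placeholder and ``scikit-learn`` must be installed:

.. code-block:: python

    ds = SVMLightData('/path/to/data.svmlight')
    ds.reset_state()
    for X_i, y_i in ds:
        # X_i: dense 1-D feature vector, y_i: label
        pass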
""" 209 | 210 | def __init__(self, filename, shuffle=True): 211 | """ 212 | Args: 213 | filename (str): input file 214 | shuffle (bool): shuffle the data 215 | """ 216 | import sklearn.datasets # noqa 217 | self.X, self.y = sklearn.datasets.load_svmlight_file(filename) 218 | self.X = np.asarray(self.X.todense()) 219 | self.shuffle = shuffle 220 | 221 | def __len__(self): 222 | return len(self.y) 223 | 224 | def __iter__(self): 225 | idxs = np.arange(self.__len__()) 226 | if self.shuffle: 227 | self.rng.shuffle(idxs) 228 | for id in idxs: 229 | yield [self.X[id, :], self.y[id]] 230 | 231 | 232 | try: 233 | import h5py 234 | except ImportError: 235 | HDF5Data = create_dummy_class('HDF5Data', 'h5py') # noqa 236 | 237 | try: 238 | import lmdb 239 | except ImportError: 240 | for klass in ['LMDBData', 'LMDBDataDecoder', 'CaffeLMDB']: 241 | globals()[klass] = create_dummy_class(klass, 'lmdb') 242 | -------------------------------------------------------------------------------- /dataflow/dataflow/image.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: image.py 3 | 4 | 5 | import copy as copy_mod 6 | import numpy as np 7 | from contextlib import contextmanager 8 | 9 | from ..utils import logger 10 | from ..utils.argtools import shape2d 11 | from .base import RNGDataFlow 12 | from .common import MapData, MapDataComponent 13 | 14 | __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents'] 15 | 16 | 17 | def check_dtype(img): 18 | assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img)) 19 | assert not isinstance(img.dtype, np.integer) or (img.dtype == np.uint8), \ 20 | "[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype) 21 | 22 | 23 | def validate_coords(coords): 24 | assert coords.ndim == 2, coords.ndim 25 | assert coords.shape[1] == 2, coords.shape 26 | assert np.issubdtype(coords.dtype, np.float), coords.dtype 27 | 28 | 29 | class ExceptionHandler: 30 | def __init__(self, catch_exceptions=False): 31 | self._nr_error = 0 32 | self.catch_exceptions = catch_exceptions 33 | 34 | @contextmanager 35 | def catch(self): 36 | try: 37 | yield 38 | except Exception: 39 | self._nr_error += 1 40 | if not self.catch_exceptions: 41 | raise 42 | else: 43 | if self._nr_error % 100 == 0 or self._nr_error < 10: 44 | logger.exception("Got {} augmentation errors.".format(self._nr_error)) 45 | 46 | 47 | class ImageFromFile(RNGDataFlow): 48 | """ Produce images read from a list of files as (h, w, c) arrays. """ 49 | def __init__(self, files, channel=3, resize=None, shuffle=False): 50 | """ 51 | Args: 52 | files (list): list of file paths. 53 | channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3. 54 | Will produce (h, w, 1) array if channel==1. 55 | resize (tuple): int or (h, w) tuple. If given, resize the image. 56 | """ 57 | assert len(files), "No image files given to ImageFromFile!" 
58 | self.files = files 59 | self.channel = int(channel) 60 | assert self.channel in [1, 3], self.channel 61 | self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR 62 | if resize is not None: 63 | resize = shape2d(resize) 64 | self.resize = resize 65 | self.shuffle = shuffle 66 | 67 | def __len__(self): 68 | return len(self.files) 69 | 70 | def __iter__(self): 71 | if self.shuffle: 72 | self.rng.shuffle(self.files) 73 | for f in self.files: 74 | im = cv2.imread(f, self.imread_mode) 75 | assert im is not None, f 76 | if self.channel == 3: 77 | im = im[:, :, ::-1] 78 | if self.resize is not None: 79 | im = cv2.resize(im, tuple(self.resize[::-1])) 80 | if self.channel == 1: 81 | im = im[:, :, np.newaxis] 82 | yield [im] 83 | 84 | 85 | class AugmentImageComponent(MapDataComponent): 86 | """ 87 | Apply image augmentors on 1 image component. 88 | """ 89 | 90 | def __init__(self, ds, augmentors, index=0, copy=True, catch_exceptions=False): 91 | """ 92 | Args: 93 | ds (DataFlow): input DataFlow. 94 | augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. 95 | index (int or str): the index or key of the image component to be augmented in the datapoint. 96 | copy (bool): Some augmentors modify the input images. When copy is 97 | True, a copy will be made before any augmentors are applied, 98 | to keep the original images not modified. 99 | Turn it off to save time when you know it's OK. 100 | catch_exceptions (bool): when set to True, will catch 101 | all exceptions and only warn you when there are too many (>100). 102 | Can be used to ignore occasion errors in data. 103 | """ 104 | if isinstance(augmentors, AugmentorList): 105 | self.augs = augmentors 106 | else: 107 | self.augs = AugmentorList(augmentors) 108 | self._copy = copy 109 | 110 | self._exception_handler = ExceptionHandler(catch_exceptions) 111 | super(AugmentImageComponent, self).__init__(ds, self._aug_mapper, index) 112 | 113 | def reset_state(self): 114 | self.ds.reset_state() 115 | self.augs.reset_state() 116 | 117 | def _aug_mapper(self, x): 118 | check_dtype(x) 119 | with self._exception_handler.catch(): 120 | if self._copy: 121 | x = copy_mod.deepcopy(x) 122 | return self.augs.augment(x) 123 | 124 | 125 | class AugmentImageCoordinates(MapData): 126 | """ 127 | Apply image augmentors on an image and a list of coordinates. 128 | Coordinates must be a Nx2 floating point array, each row is (x, y). 129 | """ 130 | 131 | def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False): 132 | 133 | """ 134 | Args: 135 | ds (DataFlow): input DataFlow. 136 | augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. 137 | img_index (int or str): the index/key of the image component to be augmented. 138 | coords_index (int or str): the index/key of the coordinate component to be augmented. 
139 | copy, catch_exceptions: same as in :class:`AugmentImageComponent` 140 | """ 141 | if isinstance(augmentors, AugmentorList): 142 | self.augs = augmentors 143 | else: 144 | self.augs = AugmentorList(augmentors) 145 | 146 | self._img_index = img_index 147 | self._coords_index = coords_index 148 | self._copy = copy 149 | self._exception_handler = ExceptionHandler(catch_exceptions) 150 | 151 | super(AugmentImageCoordinates, self).__init__(ds, self._aug_mapper) 152 | 153 | def reset_state(self): 154 | self.ds.reset_state() 155 | self.augs.reset_state() 156 | 157 | def _aug_mapper(self, dp): 158 | with self._exception_handler.catch(): 159 | img, coords = dp[self._img_index], dp[self._coords_index] 160 | check_dtype(img) 161 | validate_coords(coords) 162 | if self._copy: 163 | img, coords = copy_mod.deepcopy((img, coords)) 164 | tfms = self.augs.get_transform(img) 165 | dp[self._img_index] = tfms.apply_image(img) 166 | dp[self._coords_index] = tfms.apply_coords(coords) 167 | return dp 168 | 169 | 170 | class AugmentImageComponents(MapData): 171 | """ 172 | Apply image augmentors on several components, with shared augmentation parameters. 173 | 174 | Example: 175 | 176 | .. code-block:: python 177 | 178 | ds = MyDataFlow() # produce [image(HWC), segmask(HW), keypoint(Nx2)] 179 | ds = AugmentImageComponents( 180 | ds, augs, 181 | index=(0,1), coords_index=(2,)) 182 | 183 | """ 184 | 185 | def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False): 186 | """ 187 | Args: 188 | ds (DataFlow): input DataFlow. 189 | augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order. 190 | index: tuple of indices of the image components. 191 | coords_index: tuple of indices of the coordinates components. 192 | copy, catch_exceptions: same as in :class:`AugmentImageComponent` 193 | """ 194 | if isinstance(augmentors, AugmentorList): 195 | self.augs = augmentors 196 | else: 197 | self.augs = AugmentorList(augmentors) 198 | self.ds = ds 199 | self._exception_handler = ExceptionHandler(catch_exceptions) 200 | self._copy = copy 201 | self._index = index 202 | self._coords_index = coords_index 203 | 204 | super(AugmentImageComponents, self).__init__(ds, self._aug_mapper) 205 | 206 | def reset_state(self): 207 | self.ds.reset_state() 208 | self.augs.reset_state() 209 | 210 | def _aug_mapper(self, dp): 211 | dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact 212 | copy_func = copy_mod.deepcopy if self._copy else lambda x: x # noqa 213 | with self._exception_handler.catch(): 214 | major_image = self._index[0] # image to be used to get params. TODO better design? 
215 | im = copy_func(dp[major_image]) 216 | check_dtype(im) 217 | tfms = self.augs.get_transform(im) 218 | dp[major_image] = tfms.apply_image(im) 219 | for idx in self._index[1:]: 220 | check_dtype(dp[idx]) 221 | dp[idx] = tfms.apply_image(copy_func(dp[idx])) 222 | for idx in self._coords_index: 223 | coords = copy_func(dp[idx]) 224 | validate_coords(coords) 225 | dp[idx] = tfms.apply_coords(coords) 226 | return dp 227 | 228 | 229 | try: 230 | import cv2 231 | from .imgaug import AugmentorList 232 | except ImportError: 233 | from ..utils.develop import create_dummy_class 234 | ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa 235 | AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa 236 | AugmentImageCoordinates = create_dummy_class('AugmentImageCoordinates', 'cv2') # noqa 237 | AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa 238 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: __init__.py 3 | 4 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36 5 | STATICA_HACK = True 6 | globals()['kcah_acitats'[::-1].upper()] = False 7 | if STATICA_HACK: 8 | from .base import * 9 | from .convert import * 10 | from .crop import * 11 | from .deform import * 12 | from .geometry import * 13 | from .imgproc import * 14 | from .meta import * 15 | from .misc import * 16 | from .noise import * 17 | from .paste import * 18 | from .transform import * 19 | from .external import * 20 | 21 | 22 | import os 23 | from pkgutil import iter_modules 24 | 25 | __all__ = [] 26 | 27 | 28 | def global_import(name): 29 | p = __import__(name, globals(), locals(), level=1) 30 | lst = p.__all__ if '__all__' in dir(p) else dir(p) 31 | if lst: 32 | del globals()[name] 33 | for k in lst: 34 | if not k.startswith('__'): 35 | globals()[k] = p.__dict__[k] 36 | __all__.append(k) 37 | 38 | 39 | try: 40 | import cv2 # noqa 41 | except ImportError: 42 | from ...utils import logger 43 | logger.warn("Cannot import 'cv2', therefore image augmentation is not available.") 44 | else: 45 | _CURR_DIR = os.path.dirname(__file__) 46 | for _, module_name, _ in iter_modules( 47 | [os.path.dirname(__file__)]): 48 | srcpath = os.path.join(_CURR_DIR, module_name + '.py') 49 | if not os.path.isfile(srcpath): 50 | continue 51 | if not module_name.startswith('_') and "_test" not in module_name: 52 | global_import(module_name) 53 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: base.py 3 | 4 | import os 5 | import inspect 6 | import pprint 7 | from collections import namedtuple 8 | import weakref 9 | 10 | from ...utils.argtools import log_once 11 | from ...utils.utils import get_rng 12 | from ...utils.develop import deprecated 13 | from ..image import check_dtype 14 | 15 | # Cannot import here if we want to keep backward compatibility. 
16 | # Because this causes circular dependency 17 | # from .transform import TransformList, PhotometricTransform, TransformFactory 18 | 19 | __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList', 'PhotometricAugmentor'] 20 | 21 | 22 | def _reset_augmentor_after_fork(aug_ref): 23 | aug = aug_ref() 24 | if aug: 25 | aug.reset_state() 26 | 27 | 28 | def _default_repr(self): 29 | """ 30 | Produce something like: 31 | "imgaug.MyAugmentor(field1={self.field1}, field2={self.field2})" 32 | 33 | It assumes that the instance `self` contains attributes that match its constructor. 34 | """ 35 | classname = type(self).__name__ 36 | argspec = inspect.getfullargspec(self.__init__) 37 | assert argspec.varargs is None, "The default __repr__ in {} doesn't work for varargs!".format(classname) 38 | assert argspec.varkw is None, "The default __repr__ in {} doesn't work for kwargs!".format(classname) 39 | defaults = {} 40 | 41 | fields = argspec.args[1:] 42 | defaults_pos = argspec.defaults 43 | if defaults_pos is not None: 44 | for f, d in zip(fields[::-1], defaults_pos[::-1]): 45 | defaults[f] = d 46 | 47 | for k in argspec.kwonlyargs: 48 | fields.append(k) 49 | if k in argspec.kwonlydefaults: 50 | defaults[k] = argspec.kwonlydefaults[k] 51 | 52 | argstr = [] 53 | for f in fields: 54 | assert hasattr(self, f), \ 55 | "Attribute {} in {} not found! Default __repr__ only works if " \ 56 | "the instance has attributes that match the constructor.".format(f, classname) 57 | attr = getattr(self, f) 58 | if f in defaults and attr is defaults[f]: 59 | continue 60 | argstr.append("{}={}".format(f, pprint.pformat(attr))) 61 | return "imgaug.{}({})".format(classname, ', '.join(argstr)) 62 | 63 | 64 | ImagePlaceholder = namedtuple("ImagePlaceholder", ["shape"]) 65 | 66 | 67 | class ImageAugmentor(object): 68 | """ 69 | Base class for an augmentor 70 | 71 | ImageAugmentor should take images of type uint8 in range [0, 255], or 72 | floating point images in range [0, 1] or [0, 255]. 73 | 74 | Attributes: 75 | rng: a numpy :class:`RandomState` 76 | """ 77 | 78 | def __init__(self): 79 | self.reset_state() 80 | 81 | # only available on Unix after Python 3.7 82 | if hasattr(os, 'register_at_fork'): 83 | os.register_at_fork( 84 | after_in_child=lambda: _reset_augmentor_after_fork(weakref.ref(self))) 85 | 86 | def _init(self, params=None): 87 | if params: 88 | for k, v in params.items(): 89 | if k != 'self' and not k.startswith('_'): 90 | setattr(self, k, v) 91 | 92 | def reset_state(self): 93 | """ 94 | Reset rng and other state of the augmentor. 95 | 96 | Similar to :meth:`DataFlow.reset_state`, the caller of Augmentor 97 | is responsible for calling this method (once or more times) in the **process that uses the augmentor** 98 | before using it. 99 | 100 | If you use a built-in augmentation dataflow (:class:`AugmentImageComponent`, etc), 101 | this method will be called in the dataflow's own `reset_state` method. 102 | 103 | If you use Python≥3.7 on Unix, this method will be automatically called after fork, 104 | and you do not need to bother calling it. 105 | """ 106 | self.rng = get_rng(self) 107 | 108 | def _rand_range(self, low=1.0, high=None, size=None): 109 | """ 110 | Generate uniform float random number between low and high using `self.rng`. 
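For example, a custom :class:`PhotometricAugmentor` (defined later in this file) typically draws its random parameters with this helper. A sketch with made-up names ``MyBrightness`` and ``delta``, assuming a float image:

.. code-block:: python

    class MyBrightness(PhotometricAugmentor):
        def __init__(self, delta):
            super(MyBrightness, self).__init__()
            self._init(locals())   # stores delta as self.delta

        def _get_augment_params(self, img):
            # one random offset per image, shared by all pixels
            return self._rand_range(-self.delta, self.delta)

        def _augment(self, img, offset):
            return img + offset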
111 | """ 112 | if high is None: 113 | low, high = 0, low 114 | if size is None: 115 | size = [] 116 | return self.rng.uniform(low, high, size).astype("float32") 117 | 118 | def __str__(self): 119 | try: 120 | return _default_repr(self) 121 | except AssertionError as e: 122 | log_once(e.args[0], 'warn') 123 | return super(Augmentor, self).__repr__() 124 | 125 | __repr__ = __str__ 126 | 127 | def get_transform(self, img): 128 | """ 129 | Instantiate a :class:`Transform` object to be used given the input image. 130 | Subclasses should implement this method. 131 | 132 | The :class:`ImageAugmentor` often has random policies which generate deterministic transform. 133 | Any of those random policies should happen inside this method and instantiate 134 | an actual deterministic transform to be performed. 135 | The returned :class:`Transform` object should perform deterministic transforms 136 | through its :meth:`apply_*` method. 137 | 138 | In this way, the returned :class:`Transform` object can be used to transform not only the 139 | input image, but other images or coordinates associated with the image. 140 | 141 | Args: 142 | img (ndarray): see notes of this class on the requirements. 143 | 144 | Returns: 145 | Transform 146 | """ 147 | # This should be an abstract method 148 | # But we provide an implementation that uses the old interface, 149 | # for backward compatibility 150 | log_once("The old augmentor interface was deprecated. " 151 | "Please implement {} with `get_transform` instead!".format(self.__class__.__name__), 152 | "warning") 153 | 154 | def legacy_augment_coords(self, coords, p): 155 | try: 156 | return self._augment_coords(coords, p) 157 | except AttributeError: 158 | pass 159 | try: 160 | return self.augment_coords(coords, p) 161 | except AttributeError: 162 | pass 163 | return coords # this is the old default 164 | 165 | p = None # the default return value for this method 166 | try: 167 | p = self._get_augment_params(img) 168 | except AttributeError: 169 | pass 170 | try: 171 | p = self.get_augment_params(img) 172 | except AttributeError: 173 | pass 174 | 175 | from .transform import BaseTransform, TransformFactory 176 | if isinstance(p, BaseTransform): # some old augs return Transform already 177 | return p 178 | 179 | return TransformFactory(name="LegacyConversion -- " + str(self), 180 | apply_image=lambda img: self._augment(img, p), 181 | apply_coords=lambda coords: legacy_augment_coords(self, coords, p)) 182 | 183 | def augment(self, img): 184 | """ 185 | Create a transform, and apply it to augment the input image. 186 | 187 | This can save you one line of code, when you only care the augmentation of "one image". 188 | It will not return the :class:`Transform` object to you 189 | so you won't be able to apply the same transformation on 190 | other data associated with the image. 191 | 192 | Args: 193 | img (ndarray): see notes of this class on the requirements. 194 | 195 | Returns: 196 | img: augmented image. 
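Example (a minimal sketch using the built-in :class:`Flip` augmentor; the array is a dummy input):

    .. code-block:: python

        import numpy as np
        from dataflow import imgaug  # or: from tensorpack import imgaug

        aug = imgaug.Flip(horiz=True, prob=0.5)
        img = np.zeros((32, 32, 3), dtype="uint8")
        out = aug.augment(img)  # augmented image; the Transform itself is not returned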
197 | """ 198 | check_dtype(img) 199 | t = self.get_transform(img) 200 | return t.apply_image(img) 201 | 202 | # ########################### 203 | # Legacy interfaces: 204 | # ########################### 205 | @deprecated("Please use `get_transform` instead!", "2020-06-06", max_num_warnings=3) 206 | def augment_return_params(self, d): 207 | t = self.get_transform(d) 208 | return t.apply_image(d), t 209 | 210 | @deprecated("Please use `transform.apply_image` instead!", "2020-06-06", max_num_warnings=3) 211 | def augment_with_params(self, d, param): 212 | return param.apply_image(d) 213 | 214 | @deprecated("Please use `transform.apply_coords` instead!", "2020-06-06", max_num_warnings=3) 215 | def augment_coords(self, coords, param): 216 | return param.apply_coords(coords) 217 | 218 | 219 | class AugmentorList(ImageAugmentor): 220 | """ 221 | Augment an image by a list of augmentors 222 | """ 223 | 224 | def __init__(self, augmentors): 225 | """ 226 | Args: 227 | augmentors (list): list of :class:`ImageAugmentor` instance to be applied. 228 | """ 229 | assert isinstance(augmentors, (list, tuple)), augmentors 230 | self.augmentors = augmentors 231 | super(AugmentorList, self).__init__() 232 | 233 | def reset_state(self): 234 | """ Will reset state of each augmentor """ 235 | super(AugmentorList, self).reset_state() 236 | for a in self.augmentors: 237 | a.reset_state() 238 | 239 | def get_transform(self, img): 240 | check_dtype(img) 241 | assert img.ndim in [2, 3], img.ndim 242 | 243 | from .transform import LazyTransform, TransformList 244 | # The next augmentor requires the previous one to finish. 245 | # So we have to use LazyTransform 246 | tfms = [] 247 | for idx, a in enumerate(self.augmentors): 248 | if idx == 0: 249 | t = a.get_transform(img) 250 | else: 251 | t = LazyTransform(a.get_transform) 252 | 253 | if isinstance(t, TransformList): 254 | tfms.extend(t.tfms) 255 | else: 256 | tfms.append(t) 257 | return TransformList(tfms) 258 | 259 | def __str__(self): 260 | repr_each_aug = ",\n".join([" " + repr(x) for x in self.augmentors]) 261 | return "imgaug.AugmentorList([\n{}])".format(repr_each_aug) 262 | 263 | __repr__ = __str__ 264 | 265 | 266 | Augmentor = ImageAugmentor 267 | """ 268 | Legacy name. Augmentor and ImageAugmentor are now the same thing. 269 | """ 270 | 271 | 272 | class PhotometricAugmentor(ImageAugmentor): 273 | """ 274 | A base class for ImageAugmentor which only affects pixels. 275 | 276 | Subclass should implement `_get_params(img)` and `_impl(img, params)`. 277 | """ 278 | def get_transform(self, img): 279 | p = self._get_augment_params(img) 280 | from .transform import PhotometricTransform 281 | return PhotometricTransform(func=lambda img: self._augment(img, p), 282 | name="from " + str(self)) 283 | 284 | def _get_augment_params(self, _): 285 | return None 286 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: convert.py 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | from .base import PhotometricAugmentor 8 | 9 | __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32'] 10 | 11 | 12 | class ColorSpace(PhotometricAugmentor): 13 | """ Convert into another color space. 
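Example (a minimal sketch; any OpenCV conversion code can be used):

    .. code-block:: python

        import cv2
        from dataflow import imgaug  # or: from tensorpack import imgaug

        # convert BGR input to HSV; keepdims is True by default
        aug = imgaug.ColorSpace(mode=cv2.COLOR_BGR2HSV)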
""" 14 | 15 | def __init__(self, mode, keepdims=True): 16 | """ 17 | Args: 18 | mode: OpenCV color space conversion code (e.g., ``cv2.COLOR_BGR2HSV``) 19 | keepdims (bool): keep the dimension of image unchanged if OpenCV 20 | changes it. 21 | """ 22 | super(ColorSpace, self).__init__() 23 | self._init(locals()) 24 | 25 | def _augment(self, img, _): 26 | transf = cv2.cvtColor(img, self.mode) 27 | if self.keepdims: 28 | if len(transf.shape) is not len(img.shape): 29 | transf = transf[..., None] 30 | return transf 31 | 32 | 33 | class Grayscale(ColorSpace): 34 | """ Convert RGB or BGR image to grayscale. """ 35 | 36 | def __init__(self, keepdims=True, rgb=False, keepshape=False): 37 | """ 38 | Args: 39 | keepdims (bool): return image of shape [H, W, 1] instead of [H, W] 40 | rgb (bool): interpret input as RGB instead of the default BGR 41 | keepshape (bool): whether to duplicate the gray image into 3 channels 42 | so the result has the same shape as input. 43 | """ 44 | mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY 45 | if keepshape: 46 | assert keepdims, "keepdims must be True when keepshape==True" 47 | super(Grayscale, self).__init__(mode, keepdims) 48 | self.keepshape = keepshape 49 | self.rgb = rgb 50 | 51 | def _augment(self, img, _): 52 | ret = super()._augment(img, _) 53 | if self.keepshape: 54 | return np.concatenate([ret] * 3, axis=2) 55 | else: 56 | return ret 57 | 58 | 59 | class ToUint8(PhotometricAugmentor): 60 | """ Clip and convert image to uint8. Useful to reduce communication overhead. """ 61 | def _augment(self, img, _): 62 | return np.clip(img, 0, 255).astype(np.uint8) 63 | 64 | 65 | class ToFloat32(PhotometricAugmentor): 66 | """ Convert image to float32, may increase quality of the augmentor. """ 67 | def _augment(self, img, _): 68 | return img.astype(np.float32) 69 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/crop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: crop.py 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | from ...utils.argtools import shape2d 8 | from ...utils.develop import log_deprecated 9 | from .base import ImageAugmentor, ImagePlaceholder 10 | from .transform import CropTransform, TransformList, ResizeTransform, PhotometricTransform 11 | from .misc import ResizeShortestEdge 12 | 13 | __all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 14 | 'GoogleNetRandomCropAndResize', 'RandomCutout'] 15 | 16 | 17 | class RandomCrop(ImageAugmentor): 18 | """ Randomly crop the image into a smaller one """ 19 | 20 | def __init__(self, crop_shape): 21 | """ 22 | Args: 23 | crop_shape: (h, w), int or a tuple of int 24 | """ 25 | crop_shape = shape2d(crop_shape) 26 | crop_shape = (int(crop_shape[0]), int(crop_shape[1])) 27 | super(RandomCrop, self).__init__() 28 | self._init(locals()) 29 | 30 | def get_transform(self, img): 31 | orig_shape = img.shape 32 | assert orig_shape[0] >= self.crop_shape[0] \ 33 | and orig_shape[1] >= self.crop_shape[1], orig_shape 34 | diffh = orig_shape[0] - self.crop_shape[0] 35 | h0 = self.rng.randint(diffh + 1) 36 | diffw = orig_shape[1] - self.crop_shape[1] 37 | w0 = self.rng.randint(diffw + 1) 38 | return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1]) 39 | 40 | 41 | class CenterCrop(ImageAugmentor): 42 | """ Crop the image at the center""" 43 | 44 | def __init__(self, crop_shape): 45 | """ 46 | Args: 47 | crop_shape: (h, w) tuple or a int 48 | """ 49 | crop_shape = 
shape2d(crop_shape) 50 | self._init(locals()) 51 | 52 | def get_transform(self, img): 53 | orig_shape = img.shape 54 | assert orig_shape[0] >= self.crop_shape[0] \ 55 | and orig_shape[1] >= self.crop_shape[1], orig_shape 56 | h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5) 57 | w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5) 58 | return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1]) 59 | 60 | 61 | class RandomCropRandomShape(ImageAugmentor): 62 | """ Random crop with a random shape""" 63 | 64 | def __init__(self, wmin, hmin, 65 | wmax=None, hmax=None, 66 | max_aspect_ratio=None): 67 | """ 68 | Randomly crop a box of shape (h, w), sampled from [min, max] (both inclusive). 69 | If max is None, will use the input image shape. 70 | 71 | Args: 72 | wmin, hmin, wmax, hmax: range to sample shape. 73 | max_aspect_ratio (float): this argument has no effect and is deprecated. 74 | """ 75 | super(RandomCropRandomShape, self).__init__() 76 | if max_aspect_ratio is not None: 77 | log_deprecated("RandomCropRandomShape(max_aspect_ratio)", "It is never implemented!", "2020-06-06") 78 | self._init(locals()) 79 | 80 | def get_transform(self, img): 81 | hmax = self.hmax or img.shape[0] 82 | wmax = self.wmax or img.shape[1] 83 | h = self.rng.randint(self.hmin, hmax + 1) 84 | w = self.rng.randint(self.wmin, wmax + 1) 85 | diffh = img.shape[0] - h 86 | diffw = img.shape[1] - w 87 | assert diffh >= 0 and diffw >= 0, str(diffh) + ", " + str(diffw) 88 | y0 = 0 if diffh == 0 else self.rng.randint(diffh) 89 | x0 = 0 if diffw == 0 else self.rng.randint(diffw) 90 | return CropTransform(y0, x0, h, w) 91 | 92 | 93 | class GoogleNetRandomCropAndResize(ImageAugmentor): 94 | """ 95 | The random crop and resize augmentation proposed in 96 | Sec. 6 of "Going Deeper with Convolutions" by Google. 97 | This implementation follows the details in ``fb.resnet.torch``. 98 | 99 | It attempts to crop a random rectangle with 8%~100% area of the original image, 100 | and keep the aspect ratio between 3/4 to 4/3. Then it resize this crop to the target shape. 101 | If such crop cannot be found in 10 iterations, it will do a ResizeShortestEdge + CenterCrop. 102 | """ 103 | def __init__(self, crop_area_fraction=(0.08, 1.), 104 | aspect_ratio_range=(0.75, 1.333), 105 | target_shape=224, interp=cv2.INTER_LINEAR): 106 | """ 107 | Args: 108 | crop_area_fraction (tuple(float)): Defaults to crop 8%-100% area. 109 | aspect_ratio_range (tuple(float)): Defaults to make aspect ratio in 3/4-4/3. 110 | target_shape (int): Defaults to 224, the standard ImageNet image shape. 
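interp: cv2 interpolation method, passed to the resize step. Defaults to ``cv2.INTER_LINEAR``.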
111 | """ 112 | super(GoogleNetRandomCropAndResize, self).__init__() 113 | self._init(locals()) 114 | 115 | def get_transform(self, img): 116 | h, w = img.shape[:2] 117 | area = h * w 118 | for _ in range(10): 119 | targetArea = self.rng.uniform(*self.crop_area_fraction) * area 120 | aspectR = self.rng.uniform(*self.aspect_ratio_range) 121 | ww = int(np.sqrt(targetArea * aspectR) + 0.5) 122 | hh = int(np.sqrt(targetArea / aspectR) + 0.5) 123 | if self.rng.uniform() < 0.5: 124 | ww, hh = hh, ww 125 | if hh <= h and ww <= w: 126 | x1 = self.rng.randint(0, w - ww + 1) 127 | y1 = self.rng.randint(0, h - hh + 1) 128 | return TransformList([ 129 | CropTransform(y1, x1, hh, ww), 130 | ResizeTransform(hh, ww, self.target_shape, self.target_shape, interp=self.interp) 131 | ]) 132 | resize = ResizeShortestEdge(self.target_shape, interp=self.interp).get_transform(img) 133 | out_shape = (resize.new_h, resize.new_w) 134 | crop = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape)) 135 | return TransformList([resize, crop]) 136 | 137 | 138 | class RandomCutout(ImageAugmentor): 139 | """ 140 | The cutout augmentation, as described in https://arxiv.org/abs/1708.04552 141 | """ 142 | def __init__(self, h_range, w_range, fill=0.): 143 | """ 144 | Args: 145 | h_range (int or tuple): the height of rectangle to cut. 146 | If a tuple, will randomly sample from this range [low, high) 147 | w_range (int or tuple): similar to above 148 | fill (float): the fill value 149 | """ 150 | super(RandomCutout, self).__init__() 151 | self._init(locals()) 152 | 153 | def _get_cutout_shape(self): 154 | if isinstance(self.h_range, int): 155 | h = self.h_range 156 | else: 157 | h = self.rng.randint(self.h_range) 158 | 159 | if isinstance(self.w_range, int): 160 | w = self.w_range 161 | else: 162 | w = self.rng.randint(self.w_range) 163 | return h, w 164 | 165 | @staticmethod 166 | def _cutout(img, y0, x0, h, w, fill): 167 | img[y0:y0 + h, x0:x0 + w] = fill 168 | return img 169 | 170 | def get_transform(self, img): 171 | h, w = self._get_cutout_shape() 172 | x0 = self.rng.randint(0, img.shape[1] + 1 - w) 173 | y0 = self.rng.randint(0, img.shape[0] + 1 - h) 174 | return PhotometricTransform( 175 | lambda img: RandomCutout._cutout(img, y0, x0, h, w, self.fill), 176 | "cutout") 177 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/deform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: deform.py 3 | 4 | 5 | import numpy as np 6 | 7 | from ...utils import logger 8 | from .base import ImageAugmentor 9 | from .transform import TransformFactory 10 | 11 | __all__ = [] 12 | 13 | # Code was temporarily kept here for a future reference in case someone needs it 14 | # But it was already deprecated, 15 | # because this augmentation is not a general one that people will often find helpful. 
16 | 17 | 18 | class GaussianMap(object): 19 | """ Generate Gaussian weighted deformation map""" 20 | # TODO really needs speedup 21 | 22 | def __init__(self, image_shape, sigma=0.5): 23 | assert len(image_shape) == 2 24 | self.shape = image_shape 25 | self.sigma = sigma 26 | 27 | def get_gaussian_weight(self, anchor): 28 | """ 29 | Args: 30 | anchor: coordinate of the center 31 | """ 32 | ret = np.zeros(self.shape, dtype='float32') 33 | 34 | y, x = np.mgrid[:self.shape[0], :self.shape[1]] 35 | y = y.astype('float32') / ret.shape[0] - anchor[0] 36 | x = x.astype('float32') / ret.shape[1] - anchor[1] 37 | g = np.exp(-(x**2 + y ** 2) / self.sigma) 38 | # cv2.imshow(" ", g) 39 | # cv2.waitKey() 40 | return g 41 | 42 | 43 | def np_sample(img, coords): 44 | # a numpy implementation of ImageSample layer 45 | coords = np.maximum(coords, 0) 46 | coords = np.minimum(coords, np.array([img.shape[0] - 1, img.shape[1] - 1])) 47 | 48 | lcoor = np.floor(coords).astype('int32') 49 | ucoor = lcoor + 1 50 | ucoor = np.minimum(ucoor, np.array([img.shape[0] - 1, img.shape[1] - 1])) 51 | diff = coords - lcoor 52 | neg_diff = 1.0 - diff 53 | 54 | lcoory, lcoorx = np.split(lcoor, 2, axis=2) 55 | ucoory, ucoorx = np.split(ucoor, 2, axis=2) 56 | diff = np.repeat(diff, 3, 2).reshape((diff.shape[0], diff.shape[1], 2, 3)) 57 | neg_diff = np.repeat(neg_diff, 3, 2).reshape((diff.shape[0], diff.shape[1], 2, 3)) 58 | diffy, diffx = np.split(diff, 2, axis=2) 59 | ndiffy, ndiffx = np.split(neg_diff, 2, axis=2) 60 | 61 | ret = img[lcoory, lcoorx, :] * ndiffx * ndiffy + \ 62 | img[ucoory, ucoorx, :] * diffx * diffy + \ 63 | img[lcoory, ucoorx, :] * ndiffy * diffx + \ 64 | img[ucoory, lcoorx, :] * diffy * ndiffx 65 | return ret[:, :, 0, :] 66 | 67 | 68 | class GaussianDeform(ImageAugmentor): 69 | """ 70 | Some kind of slow deformation I made up. Don't count on it. 71 | """ 72 | 73 | # TODO input/output with different shape 74 | 75 | def __init__(self, anchors, shape, sigma=0.5, randrange=None): 76 | """ 77 | Args: 78 | anchors (list): list of center coordinates in range [0,1]. 79 | shape(list or tuple): image shape in [h, w]. 80 | sigma (float): sigma for Gaussian weight 81 | randrange (int): offset range. Defaults to shape[0] / 8 82 | """ 83 | logger.warn("GaussianDeform is slow. 
Consider using it with 4 or more prefetching processes.") 84 | super(GaussianDeform, self).__init__() 85 | self.anchors = anchors 86 | self.K = len(self.anchors) 87 | self.shape = shape 88 | self.grid = np.mgrid[0:self.shape[0], 0:self.shape[1]].transpose(1, 2, 0) 89 | self.grid = self.grid.astype('float32') # HxWx2 90 | 91 | gm = GaussianMap(self.shape, sigma=sigma) 92 | self.gws = np.array([gm.get_gaussian_weight(ank) 93 | for ank in self.anchors], dtype='float32') # KxHxW 94 | self.gws = self.gws.transpose(1, 2, 0) # HxWxK 95 | if randrange is None: 96 | self.randrange = self.shape[0] / 8 97 | else: 98 | self.randrange = randrange 99 | self.sigma = sigma 100 | 101 | def get_transform(self, img): 102 | v = self.rng.rand(self.K, 2).astype('float32') - 0.5 103 | v = v * 2 * self.randrange 104 | return TransformFactory(name=str(self), apply_image=lambda img: self._augment(img, v)) 105 | 106 | def _augment(self, img, v): 107 | grid = self.grid + np.dot(self.gws, v) 108 | return np_sample(img, grid) 109 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/external.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | 5 | from .base import ImageAugmentor 6 | from .transform import Transform 7 | 8 | __all__ = ['IAAugmentor', 'Albumentations'] 9 | 10 | 11 | class IAATransform(Transform): 12 | def __init__(self, aug, img_shape): 13 | self._init(locals()) 14 | 15 | def apply_image(self, img): 16 | return self.aug.augment_image(img) 17 | 18 | def apply_coords(self, coords): 19 | import imgaug as IA 20 | points = [IA.Keypoint(x=x, y=y) for x, y in coords] 21 | points = IA.KeypointsOnImage(points, shape=self.img_shape) 22 | augmented = self.aug.augment_keypoints([points])[0].keypoints 23 | return np.asarray([[p.x, p.y] for p in augmented]) 24 | 25 | 26 | class IAAugmentor(ImageAugmentor): 27 | """ 28 | Wrap an augmentor form the IAA library: https://github.com/aleju/imgaug. 29 | Both images and coordinates are supported. 30 | 31 | Note: 32 | 1. It's NOT RECOMMENDED 33 | to use coordinates because the IAA library does not handle coordinates accurately. 34 | 35 | 2. Only uint8 images are supported by the IAA library. 36 | 37 | 3. The IAA library can only produces images of the same shape. 38 | 39 | Example: 40 | 41 | .. code-block:: python 42 | 43 | from imgaug import augmenters as iaa # this is the aleju/imgaug library 44 | from tensorpack import imgaug # this is not the aleju/imgaug library 45 | # or from dataflow import imgaug # if you're using the standalone version of dataflow 46 | myaug = imgaug.IAAugmentor( 47 | iaa.Sequential([ 48 | iaa.Sharpen(alpha=(0, 1), lightness=(0.75, 1.5)), 49 | iaa.Fliplr(0.5), 50 | iaa.Crop(px=(0, 100)), 51 | ]) 52 | """ 53 | 54 | def __init__(self, augmentor): 55 | """ 56 | Args: 57 | augmentor (iaa.Augmenter): 58 | """ 59 | super(IAAugmentor, self).__init__() 60 | self._aug = augmentor 61 | 62 | def get_transform(self, img): 63 | return IAATransform(self._aug.to_deterministic(), img.shape) 64 | 65 | 66 | class AlbumentationsTransform(Transform): 67 | def __init__(self, aug, param): 68 | self._init(locals()) 69 | 70 | def apply_image(self, img): 71 | return self.aug.apply(img, **self.param) 72 | 73 | 74 | class Albumentations(ImageAugmentor): 75 | """ 76 | Wrap an augmentor form the albumentations library: https://github.com/albu/albumentations. 77 | Coordinate augmentation is not supported by the library. 
78 | 79 | Example: 80 | 81 | .. code-block:: python 82 | 83 | from tensorpack import imgaug 84 | # or from dataflow import imgaug # if you're using the standalone version of dataflow 85 | import albumentations as AB 86 | myaug = imgaug.Albumentations(AB.RandomRotate90(p=1)) 87 | """ 88 | def __init__(self, augmentor): 89 | """ 90 | Args: 91 | augmentor (albumentations.BasicTransform): 92 | """ 93 | super(Albumentations, self).__init__() 94 | self._aug = augmentor 95 | 96 | def get_transform(self, img): 97 | return AlbumentationsTransform(self._aug, self._aug.get_params()) 98 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/geometry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: geometry.py 3 | 4 | 5 | import math 6 | import numpy as np 7 | import cv2 8 | 9 | from .base import ImageAugmentor 10 | from .transform import WarpAffineTransform, CropTransform, TransformList 11 | 12 | __all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine'] 13 | 14 | 15 | class Shift(ImageAugmentor): 16 | """ Random horizontal and vertical shifts """ 17 | 18 | def __init__(self, horiz_frac=0, vert_frac=0, 19 | border=cv2.BORDER_REPLICATE, border_value=0): 20 | """ 21 | Args: 22 | horiz_frac (float): max abs fraction for horizontal shift 23 | vert_frac (float): max abs fraction for horizontal shift 24 | border: cv2 border method 25 | border_value: cv2 border value for border=cv2.BORDER_CONSTANT 26 | """ 27 | assert horiz_frac < 1.0 and vert_frac < 1.0 28 | super(Shift, self).__init__() 29 | self._init(locals()) 30 | 31 | def get_transform(self, img): 32 | max_dx = self.horiz_frac * img.shape[1] 33 | max_dy = self.vert_frac * img.shape[0] 34 | dx = np.round(self._rand_range(-max_dx, max_dx)) 35 | dy = np.round(self._rand_range(-max_dy, max_dy)) 36 | 37 | mat = np.array([[1, 0, dx], [0, 1, dy]], dtype='float32') 38 | return WarpAffineTransform( 39 | mat, img.shape[1::-1], 40 | borderMode=self.border, borderValue=self.border_value) 41 | 42 | 43 | class Rotation(ImageAugmentor): 44 | """ Random rotate the image w.r.t a random center""" 45 | 46 | def __init__(self, max_deg, center_range=(0, 1), 47 | interp=cv2.INTER_LINEAR, 48 | border=cv2.BORDER_REPLICATE, step_deg=None, border_value=0): 49 | """ 50 | Args: 51 | max_deg (float): max abs value of the rotation angle (in degree). 52 | center_range (tuple): (min, max) range of the random rotation center. 53 | interp: cv2 interpolation method 54 | border: cv2 border method 55 | step_deg (float): if not None, the stepping of the rotation 56 | angle. The rotation angle will be a multiple of step_deg. This 57 | option requires ``max_deg==180`` and step_deg has to be a divisor of 180) 58 | border_value: cv2 border value for border=cv2.BORDER_CONSTANT 59 | """ 60 | assert step_deg is None or (max_deg == 180 and max_deg % step_deg == 0) 61 | super(Rotation, self).__init__() 62 | self._init(locals()) 63 | 64 | def get_transform(self, img): 65 | center = img.shape[1::-1] * self._rand_range( 66 | self.center_range[0], self.center_range[1], (2,)) 67 | deg = self._rand_range(-self.max_deg, self.max_deg) 68 | if self.step_deg: 69 | deg = deg // self.step_deg * self.step_deg 70 | """ 71 | The correct center is shape*0.5-0.5. 
This can be verified by: 72 | 73 | SHAPE = 7 74 | arr = np.random.rand(SHAPE, SHAPE) 75 | orig = arr 76 | c = SHAPE * 0.5 - 0.5 77 | c = (c, c) 78 | for k in range(4): 79 | mat = cv2.getRotationMatrix2D(c, 90, 1) 80 | arr = cv2.warpAffine(arr, mat, arr.shape) 81 | assert np.all(arr == orig) 82 | """ 83 | mat = cv2.getRotationMatrix2D(tuple(center - 0.5), float(deg), 1) 84 | return WarpAffineTransform( 85 | mat, img.shape[1::-1], interp=self.interp, 86 | borderMode=self.border, borderValue=self.border_value) 87 | 88 | 89 | class RotationAndCropValid(ImageAugmentor): 90 | """ Random rotate and then crop the largest possible rectangle. 91 | Note that this will produce images of different shapes. 92 | """ 93 | 94 | def __init__(self, max_deg, interp=cv2.INTER_LINEAR, step_deg=None): 95 | """ 96 | Args: 97 | max_deg, interp, step_deg: same as :class:`Rotation` 98 | """ 99 | assert step_deg is None or (max_deg == 180 and max_deg % step_deg == 0) 100 | super(RotationAndCropValid, self).__init__() 101 | self._init(locals()) 102 | 103 | def _get_deg(self, img): 104 | deg = self._rand_range(-self.max_deg, self.max_deg) 105 | if self.step_deg: 106 | deg = deg // self.step_deg * self.step_deg 107 | return float(deg) 108 | 109 | def get_transform(self, img): 110 | deg = self._get_deg(img) 111 | 112 | h, w = img.shape[:2] 113 | center = (img.shape[1] * 0.5, img.shape[0] * 0.5) 114 | rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1) 115 | tfm = WarpAffineTransform(rot_m, (w, h), interp=self.interp) 116 | 117 | neww, newh = RotationAndCropValid.largest_rotated_rect(w, h, deg) 118 | neww = min(neww, w) 119 | newh = min(newh, h) 120 | newx = int(center[0] - neww * 0.5) 121 | newy = int(center[1] - newh * 0.5) 122 | tfm2 = CropTransform(newy, newx, newh, neww) 123 | return TransformList([tfm, tfm2]) 124 | 125 | @staticmethod 126 | def largest_rotated_rect(w, h, angle): 127 | """ 128 | Get largest rectangle after rotation. 129 | http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders 130 | """ 131 | angle = angle / 180.0 * math.pi 132 | if w <= 0 or h <= 0: 133 | return 0, 0 134 | 135 | width_is_longer = w >= h 136 | side_long, side_short = (w, h) if width_is_longer else (h, w) 137 | 138 | # since the solutions for angle, -angle and 180-angle are all the same, 139 | # if suffices to look at the first quadrant and the absolute values of sin,cos: 140 | sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) 141 | if side_short <= 2. * sin_a * cos_a * side_long: 142 | # half constrained case: two crop corners touch the longer side, 143 | # the other two corners are on the mid-line parallel to the longer line 144 | x = 0.5 * side_short 145 | wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a) 146 | else: 147 | # fully constrained case: crop touches all 4 sides 148 | cos_2a = cos_a * cos_a - sin_a * sin_a 149 | wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a 150 | return int(np.round(wr)), int(np.round(hr)) 151 | 152 | 153 | class Affine(ImageAugmentor): 154 | """ 155 | Random affine transform of the image w.r.t to the image center. 
156 | Transformations involve: 157 | 158 | - Translation ("move" image on the x-/y-axis) 159 | - Rotation 160 | - Scaling ("zoom" in/out) 161 | - Shear (move one side of the image, turning a square into a trapezoid) 162 | """ 163 | 164 | def __init__(self, scale=None, translate_frac=None, rotate_max_deg=0.0, shear=0.0, 165 | interp=cv2.INTER_LINEAR, border=cv2.BORDER_REPLICATE, border_value=0): 166 | """ 167 | Args: 168 | scale (tuple of 2 floats): scaling factor interval, e.g (a, b), then scale is 169 | randomly sampled from the range a <= scale <= b. Will keep 170 | original scale by default. 171 | translate_frac (tuple of 2 floats): tuple of max abs fraction for horizontal 172 | and vertical translation. For example translate_frac=(a, b), then horizontal shift 173 | is randomly sampled in the range 0 < dx < img_width * a and vertical shift is 174 | randomly sampled in the range 0 < dy < img_height * b. Will 175 | not translate by default. 176 | shear (float): max abs shear value in degrees between 0 to 180 177 | interp: cv2 interpolation method 178 | border: cv2 border method 179 | border_value: cv2 border value for border=cv2.BORDER_CONSTANT 180 | """ 181 | if scale is not None: 182 | assert isinstance(scale, tuple) and len(scale) == 2, \ 183 | "Argument scale should be a tuple of two floats, e.g (a, b)" 184 | 185 | if translate_frac is not None: 186 | assert isinstance(translate_frac, tuple) and len(translate_frac) == 2, \ 187 | "Argument translate_frac should be a tuple of two floats, e.g (a, b)" 188 | 189 | assert shear >= 0.0, "Argument shear should be between 0.0 and 180.0" 190 | 191 | super(Affine, self).__init__() 192 | self._init(locals()) 193 | 194 | def get_transform(self, img): 195 | if self.scale is not None: 196 | scale = self._rand_range(self.scale[0], self.scale[1]) 197 | else: 198 | scale = 1.0 199 | 200 | if self.translate_frac is not None: 201 | max_dx = self.translate_frac[0] * img.shape[1] 202 | max_dy = self.translate_frac[1] * img.shape[0] 203 | dx = np.round(self._rand_range(-max_dx, max_dx)) 204 | dy = np.round(self._rand_range(-max_dy, max_dy)) 205 | else: 206 | dx = 0 207 | dy = 0 208 | 209 | if self.shear > 0.0: 210 | shear = self._rand_range(-self.shear, self.shear) 211 | sin_shear = math.sin(math.radians(shear)) 212 | cos_shear = math.cos(math.radians(shear)) 213 | else: 214 | sin_shear = 0.0 215 | cos_shear = 1.0 216 | 217 | center = (img.shape[1::-1] * np.array((0.5, 0.5))) - 0.5 218 | deg = self._rand_range(-self.rotate_max_deg, self.rotate_max_deg) 219 | 220 | transform_matrix = cv2.getRotationMatrix2D(tuple(center), deg, scale) 221 | 222 | # Apply shear : 223 | if self.shear > 0.0: 224 | m00 = transform_matrix[0, 0] 225 | m01 = transform_matrix[0, 1] 226 | m10 = transform_matrix[1, 0] 227 | m11 = transform_matrix[1, 1] 228 | transform_matrix[0, 1] = m01 * cos_shear + m00 * sin_shear 229 | transform_matrix[1, 1] = m11 * cos_shear + m10 * sin_shear 230 | # Add correction term to keep the center unchanged 231 | tx = center[0] * (1.0 - m00) - center[1] * transform_matrix[0, 1] 232 | ty = -center[0] * m10 + center[1] * (1.0 - transform_matrix[1, 1]) 233 | transform_matrix[0, 2] = tx 234 | transform_matrix[1, 2] = ty 235 | 236 | # Apply shift : 237 | transform_matrix[0, 2] += dx 238 | transform_matrix[1, 2] += dy 239 | return WarpAffineTransform(transform_matrix, img.shape[1::-1], 240 | self.interp, self.border, self.border_value) 241 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/imgaug_test.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: _test.py 3 | 4 | 5 | import sys 6 | import numpy as np 7 | import cv2 8 | import unittest 9 | 10 | from .base import ImageAugmentor, AugmentorList 11 | from .imgproc import Contrast 12 | from .noise import SaltPepperNoise 13 | from .misc import Flip, Resize 14 | 15 | 16 | def _rand_image(shape=(20, 20)): 17 | return np.random.rand(*shape).astype("float32") 18 | 19 | 20 | class LegacyBrightness(ImageAugmentor): 21 | def __init__(self, delta, clip=True): 22 | super(LegacyBrightness, self).__init__() 23 | assert delta > 0 24 | self._init(locals()) 25 | 26 | def _get_augment_params(self, _): 27 | v = self._rand_range(-self.delta, self.delta) 28 | return v 29 | 30 | def _augment(self, img, v): 31 | old_dtype = img.dtype 32 | img = img.astype('float32') 33 | img += v 34 | if self.clip or old_dtype == np.uint8: 35 | img = np.clip(img, 0, 255) 36 | return img.astype(old_dtype) 37 | 38 | 39 | class LegacyFlip(ImageAugmentor): 40 | def __init__(self, horiz=False, vert=False, prob=0.5): 41 | super(LegacyFlip, self).__init__() 42 | if horiz and vert: 43 | raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.") 44 | elif horiz: 45 | self.code = 1 46 | elif vert: 47 | self.code = 0 48 | else: 49 | raise ValueError("At least one of horiz or vert has to be True!") 50 | self._init(locals()) 51 | 52 | def _get_augment_params(self, img): 53 | h, w = img.shape[:2] 54 | do = self._rand_range() < self.prob 55 | return (do, h, w) 56 | 57 | def _augment(self, img, param): 58 | do, _, _ = param 59 | if do: 60 | ret = cv2.flip(img, self.code) 61 | if img.ndim == 3 and ret.ndim == 2: 62 | ret = ret[:, :, np.newaxis] 63 | else: 64 | ret = img 65 | return ret 66 | 67 | def _augment_coords(self, coords, param): 68 | do, h, w = param 69 | if do: 70 | if self.code == 0: 71 | coords[:, 1] = h - coords[:, 1] 72 | elif self.code == 1: 73 | coords[:, 0] = w - coords[:, 0] 74 | return coords 75 | 76 | 77 | class ImgAugTest(unittest.TestCase): 78 | def _get_augs(self): 79 | return AugmentorList([ 80 | Contrast((0.8, 1.2)), 81 | Flip(horiz=True), 82 | Resize((30, 30)), 83 | SaltPepperNoise() 84 | ]) 85 | 86 | def _get_augs_with_legacy(self): 87 | return AugmentorList([ 88 | LegacyBrightness(0.5), 89 | LegacyFlip(horiz=True), 90 | Resize((30, 30)), 91 | SaltPepperNoise() 92 | ]) 93 | 94 | def test_augmentors(self): 95 | augmentors = self._get_augs() 96 | 97 | img = _rand_image() 98 | orig = img.copy() 99 | tfms = augmentors.get_transform(img) 100 | 101 | # test printing 102 | print(augmentors) 103 | print(tfms) 104 | 105 | newimg = tfms.apply_image(img) 106 | print(tfms) # lazy ones will instantiate after the first apply 107 | 108 | newimg2 = tfms.apply_image(orig) 109 | self.assertTrue(np.allclose(newimg, newimg2)) 110 | self.assertEqual(newimg2.shape[0], 30) 111 | 112 | coords = np.asarray([[0, 0], [10, 12]], dtype="float32") 113 | tfms.apply_coords(coords) 114 | 115 | def test_legacy_usage(self): 116 | augmentors = self._get_augs() 117 | 118 | img = _rand_image() 119 | orig = img.copy() 120 | newimg, tfms = augmentors.augment_return_params(img) 121 | newimg2 = augmentors.augment_with_params(orig, tfms) 122 | self.assertTrue(np.allclose(newimg, newimg2)) 123 | self.assertEqual(newimg2.shape[0], 30) 124 | 125 | coords = np.asarray([[0, 0], [10, 12]], dtype="float32") 126 | augmentors.augment_coords(coords, tfms) 127 | 128 | def test_legacy_augs_new_usage(self): 129 | augmentors = 
self._get_augs_with_legacy() 130 | 131 | img = _rand_image() 132 | orig = img.copy() 133 | tfms = augmentors.get_transform(img) 134 | newimg = tfms.apply_image(img) 135 | newimg2 = tfms.apply_image(orig) 136 | self.assertTrue(np.allclose(newimg, newimg2)) 137 | self.assertEqual(newimg2.shape[0], 30) 138 | 139 | coords = np.asarray([[0, 0], [10, 12]], dtype="float32") 140 | tfms.apply_coords(coords) 141 | 142 | def test_legacy_augs_legacy_usage(self): 143 | augmentors = self._get_augs_with_legacy() 144 | 145 | img = _rand_image() 146 | orig = img.copy() 147 | newimg, tfms = augmentors.augment_return_params(img) 148 | newimg2 = augmentors.augment_with_params(orig, tfms) 149 | self.assertTrue(np.allclose(newimg, newimg2)) 150 | self.assertEqual(newimg2.shape[0], 30) 151 | 152 | coords = np.asarray([[0, 0], [10, 12]], dtype="float32") 153 | augmentors.augment_coords(coords, tfms) 154 | 155 | 156 | if __name__ == '__main__': 157 | anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] 158 | augmentors = AugmentorList([ 159 | Contrast((0.8, 1.2)), 160 | Flip(horiz=True), 161 | # RandomCropRandomShape(0.3), 162 | SaltPepperNoise() 163 | ]) 164 | 165 | img = cv2.imread(sys.argv[1]) 166 | newimg, prms = augmentors._augment_return_params(img) 167 | cv2.imshow(" ", newimg.astype('uint8')) 168 | cv2.waitKey() 169 | 170 | newimg = augmentors._augment(img, prms) 171 | cv2.imshow(" ", newimg.astype('uint8')) 172 | cv2.waitKey() 173 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/meta.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: meta.py 3 | 4 | 5 | from .base import ImageAugmentor 6 | from .transform import NoOpTransform, TransformList, TransformFactory 7 | 8 | __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug', 9 | 'RandomOrderAug'] 10 | 11 | 12 | class Identity(ImageAugmentor): 13 | """ A no-op augmentor """ 14 | def get_transform(self, img): 15 | return NoOpTransform() 16 | 17 | 18 | class RandomApplyAug(ImageAugmentor): 19 | """ Randomly apply the augmentor with a probability. 20 | Otherwise do nothing 21 | """ 22 | 23 | def __init__(self, aug, prob): 24 | """ 25 | Args: 26 | aug (ImageAugmentor): an augmentor. 27 | prob (float): the probability to apply the augmentor. 
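Example (a minimal sketch using the built-in :class:`Rotation` augmentor):

    .. code-block:: python

        from dataflow import imgaug  # or: from tensorpack import imgaug

        # rotate by up to 10 degrees, but only for ~30% of the images
        aug = imgaug.RandomApplyAug(imgaug.Rotation(max_deg=10), prob=0.3)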
28 | """ 29 | self._init(locals()) 30 | super(RandomApplyAug, self).__init__() 31 | 32 | def get_transform(self, img): 33 | p = self.rng.rand() 34 | if p < self.prob: 35 | return self.aug.get_transform(img) 36 | else: 37 | return NoOpTransform() 38 | 39 | def reset_state(self): 40 | super(RandomApplyAug, self).reset_state() 41 | self.aug.reset_state() 42 | 43 | 44 | class RandomChooseAug(ImageAugmentor): 45 | """ Randomly choose one from a list of augmentors """ 46 | def __init__(self, aug_lists): 47 | """ 48 | Args: 49 | aug_lists (list): list of augmentors, or list of (augmentor, probability) tuples 50 | """ 51 | if isinstance(aug_lists[0], (tuple, list)): 52 | prob = [k[1] for k in aug_lists] 53 | aug_lists = [k[0] for k in aug_lists] 54 | self._init(locals()) 55 | else: 56 | prob = [1.0 / len(aug_lists)] * len(aug_lists) 57 | self._init(locals()) 58 | super(RandomChooseAug, self).__init__() 59 | 60 | def reset_state(self): 61 | super(RandomChooseAug, self).reset_state() 62 | for a in self.aug_lists: 63 | a.reset_state() 64 | 65 | def get_transform(self, img): 66 | aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob) 67 | return self.aug_lists[aug_idx].get_transform(img) 68 | 69 | 70 | class RandomOrderAug(ImageAugmentor): 71 | """ 72 | Apply the augmentors with randomized order. 73 | """ 74 | 75 | def __init__(self, aug_lists): 76 | """ 77 | Args: 78 | aug_lists (list): list of augmentors. 79 | The augmentors are assumed to not change the shape of images. 80 | """ 81 | self._init(locals()) 82 | super(RandomOrderAug, self).__init__() 83 | 84 | def reset_state(self): 85 | super(RandomOrderAug, self).reset_state() 86 | for a in self.aug_lists: 87 | a.reset_state() 88 | 89 | def get_transform(self, img): 90 | # Note: this makes assumption that the augmentors do not make changes 91 | # to the image that will affect how the transforms will be instantiated 92 | # in the subsequent augmentors. 93 | idxs = self.rng.permutation(len(self.aug_lists)) 94 | tfms = [self.aug_lists[k].get_transform(img) 95 | for k in range(len(self.aug_lists))] 96 | return TransformList([tfms[k] for k in idxs]) 97 | 98 | 99 | class MapImage(ImageAugmentor): 100 | """ 101 | Map the image array by simple functions. 102 | """ 103 | 104 | def __init__(self, func, coord_func=None): 105 | """ 106 | Args: 107 | func: a function which takes an image array and return an augmented one 108 | coord_func: optional. A function which takes coordinates and return augmented ones. 109 | Coordinates should be Nx2 array of (x, y)s. 
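Example (a minimal sketch; the lambda is only an illustration):

    .. code-block:: python

        from dataflow import imgaug  # or: from tensorpack import imgaug

        # scale pixel values to [0, 1]; coordinates are left unchanged
        aug = imgaug.MapImage(lambda img: img / 255.0)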
110 | """ 111 | super(MapImage, self).__init__() 112 | self.func = func 113 | self.coord_func = coord_func 114 | 115 | def get_transform(self, img): 116 | if self.coord_func: 117 | return TransformFactory(name="MapImage", apply_image=self.func, apply_coords=self.coord_func) 118 | else: 119 | return TransformFactory(name="MapImage", apply_image=self.func) 120 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: misc.py 3 | 4 | import cv2 5 | 6 | from ...utils import logger 7 | from ...utils.argtools import shape2d 8 | from .base import ImageAugmentor 9 | from .transform import ResizeTransform, NoOpTransform, FlipTransform, TransposeTransform 10 | 11 | __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose'] 12 | 13 | 14 | class Flip(ImageAugmentor): 15 | """ 16 | Random flip the image either horizontally or vertically. 17 | """ 18 | def __init__(self, horiz=False, vert=False, prob=0.5): 19 | """ 20 | Args: 21 | horiz (bool): use horizontal flip. 22 | vert (bool): use vertical flip. 23 | prob (float): probability of flip. 24 | """ 25 | super(Flip, self).__init__() 26 | if horiz and vert: 27 | raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.") 28 | if not horiz and not vert: 29 | raise ValueError("At least one of horiz or vert has to be True!") 30 | self._init(locals()) 31 | 32 | def get_transform(self, img): 33 | h, w = img.shape[:2] 34 | do = self._rand_range() < self.prob 35 | if not do: 36 | return NoOpTransform() 37 | else: 38 | return FlipTransform(h, w, self.horiz) 39 | 40 | 41 | class Resize(ImageAugmentor): 42 | """ Resize image to a target size""" 43 | 44 | def __init__(self, shape, interp=cv2.INTER_LINEAR): 45 | """ 46 | Args: 47 | shape: (h, w) tuple or a int 48 | interp: cv2 interpolation method 49 | """ 50 | shape = tuple(shape2d(shape)) 51 | self._init(locals()) 52 | 53 | def get_transform(self, img): 54 | return ResizeTransform( 55 | img.shape[0], img.shape[1], 56 | self.shape[0], self.shape[1], self.interp) 57 | 58 | 59 | class ResizeShortestEdge(ImageAugmentor): 60 | """ 61 | Resize the shortest edge to a certain number while 62 | keeping the aspect ratio. 63 | """ 64 | 65 | def __init__(self, size, interp=cv2.INTER_LINEAR): 66 | """ 67 | Args: 68 | size (int): the size to resize the shortest edge to. 69 | """ 70 | size = int(size) 71 | self._init(locals()) 72 | 73 | def get_transform(self, img): 74 | h, w = img.shape[:2] 75 | scale = self.size * 1.0 / min(h, w) 76 | if h < w: 77 | newh, neww = self.size, int(scale * w + 0.5) 78 | else: 79 | newh, neww = int(scale * h + 0.5), self.size 80 | return ResizeTransform(h, w, newh, neww, self.interp) 81 | 82 | 83 | class RandomResize(ImageAugmentor): 84 | """ Randomly rescale width and height of the image.""" 85 | 86 | def __init__(self, xrange, yrange=None, minimum=(0, 0), aspect_ratio_thres=0.15, 87 | interp=cv2.INTER_LINEAR): 88 | """ 89 | Args: 90 | xrange (tuple): a (min, max) tuple. If is floating point, the 91 | tuple defines the range of scaling ratio of new width, e.g. (0.9, 1.2). 92 | If is integer, the tuple defines the range of new width in pixels, e.g. (200, 350). 93 | yrange (tuple): similar to xrange, but for height. Should be None when aspect_ratio_thres==0. 94 | minimum (tuple): (xmin, ymin) in pixels. To avoid scaling down too much. 
95 | aspect_ratio_thres (float): discard samples which change aspect ratio 96 | larger than this threshold. Set to 0 to keep aspect ratio. 97 | interp: cv2 interpolation method 98 | """ 99 | super(RandomResize, self).__init__() 100 | assert aspect_ratio_thres >= 0 101 | self._init(locals()) 102 | 103 | def is_float(tp): 104 | return isinstance(tp[0], float) or isinstance(tp[1], float) 105 | 106 | if yrange is not None: 107 | assert is_float(xrange) == is_float(yrange), "xrange and yrange has different type!" 108 | self._is_scale = is_float(xrange) 109 | 110 | if aspect_ratio_thres == 0: 111 | if self._is_scale: 112 | assert xrange == yrange or yrange is None 113 | else: 114 | if yrange is not None: 115 | logger.warn("aspect_ratio_thres==0, yrange is not used!") 116 | 117 | def get_transform(self, img): 118 | cnt = 0 119 | h, w = img.shape[:2] 120 | 121 | def get_dest_size(): 122 | if self._is_scale: 123 | sx = self._rand_range(*self.xrange) 124 | if self.aspect_ratio_thres == 0: 125 | sy = sx 126 | else: 127 | sy = self._rand_range(*self.yrange) 128 | destX = max(sx * w, self.minimum[0]) 129 | destY = max(sy * h, self.minimum[1]) 130 | else: 131 | sx = self._rand_range(*self.xrange) 132 | if self.aspect_ratio_thres == 0: 133 | sy = sx * 1.0 / w * h 134 | else: 135 | sy = self._rand_range(*self.yrange) 136 | destX = max(sx, self.minimum[0]) 137 | destY = max(sy, self.minimum[1]) 138 | return (int(destX + 0.5), int(destY + 0.5)) 139 | 140 | while True: 141 | destX, destY = get_dest_size() 142 | if self.aspect_ratio_thres > 0: # don't check when thres == 0 143 | oldr = w * 1.0 / h 144 | newr = destX * 1.0 / destY 145 | diff = abs(newr - oldr) / oldr 146 | if diff >= self.aspect_ratio_thres + 1e-5: 147 | cnt += 1 148 | if cnt > 50: 149 | logger.warn("RandomResize failed to augment an image") 150 | return ResizeTransform(h, w, h, w, self.interp) 151 | continue 152 | return ResizeTransform(h, w, destY, destX, self.interp) 153 | 154 | 155 | class Transpose(ImageAugmentor): 156 | """ 157 | Random transpose the image 158 | """ 159 | def __init__(self, prob=0.5): 160 | """ 161 | Args: 162 | prob (float): probability of transpose. 163 | """ 164 | super(Transpose, self).__init__() 165 | self.prob = prob 166 | 167 | def get_transform(self, _): 168 | if self.rng.rand() < self.prob: 169 | return TransposeTransform() 170 | else: 171 | return NoOpTransform() 172 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/noise.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: noise.py 3 | 4 | 5 | import numpy as np 6 | import cv2 7 | 8 | from .base import PhotometricAugmentor 9 | 10 | __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise'] 11 | 12 | 13 | class JpegNoise(PhotometricAugmentor): 14 | """ Random JPEG noise. """ 15 | 16 | def __init__(self, quality_range=(40, 100)): 17 | """ 18 | Args: 19 | quality_range (tuple): range to sample JPEG quality 20 | """ 21 | super(JpegNoise, self).__init__() 22 | self._init(locals()) 23 | 24 | def _get_augment_params(self, img): 25 | return self.rng.randint(*self.quality_range) 26 | 27 | def _augment(self, img, q): 28 | enc = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, q])[1] 29 | return cv2.imdecode(enc, 1).astype(img.dtype) 30 | 31 | 32 | class GaussianNoise(PhotometricAugmentor): 33 | """ 34 | Add random Gaussian noise N(0, sigma^2) of the same shape to img. 
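Example (a minimal sketch; the value of ``sigma`` below is an arbitrary choice):

    .. code-block:: python

        from dataflow import imgaug  # or: from tensorpack import imgaug

        # add noise with stddev 10; results are clipped to [0, 255] by default
        aug = imgaug.GaussianNoise(sigma=10)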
35 | """ 36 | def __init__(self, sigma=1, clip=True): 37 | """ 38 | Args: 39 | sigma (float): stddev of the Gaussian distribution. 40 | clip (bool): clip the result to [0,255] in the end. 41 | """ 42 | super(GaussianNoise, self).__init__() 43 | self._init(locals()) 44 | 45 | def _get_augment_params(self, img): 46 | return self.rng.randn(*img.shape) 47 | 48 | def _augment(self, img, noise): 49 | old_dtype = img.dtype 50 | ret = img + noise * self.sigma 51 | if self.clip or old_dtype == np.uint8: 52 | ret = np.clip(ret, 0, 255) 53 | return ret.astype(old_dtype) 54 | 55 | 56 | class SaltPepperNoise(PhotometricAugmentor): 57 | """ Salt and pepper noise. 58 | Randomly set some elements in image to 0 or 255, regardless of its channels. 59 | """ 60 | 61 | def __init__(self, white_prob=0.05, black_prob=0.05): 62 | """ 63 | Args: 64 | white_prob (float), black_prob (float): probabilities setting an element to 255 or 0. 65 | """ 66 | assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1" 67 | super(SaltPepperNoise, self).__init__() 68 | self._init(locals()) 69 | 70 | def _get_augment_params(self, img): 71 | return self.rng.uniform(low=0, high=1, size=img.shape) 72 | 73 | def _augment(self, img, param): 74 | img[param > (1 - self.white_prob)] = 255 75 | img[param < self.black_prob] = 0 76 | return img 77 | -------------------------------------------------------------------------------- /dataflow/dataflow/imgaug/paste.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: paste.py 3 | 4 | 5 | import numpy as np 6 | from abc import abstractmethod 7 | 8 | from .base import ImageAugmentor 9 | from .transform import TransformFactory 10 | 11 | __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller', 12 | 'RandomPaste'] 13 | 14 | 15 | class BackgroundFiller(object): 16 | """ Base class for all BackgroundFiller""" 17 | 18 | def fill(self, background_shape, img): 19 | """ 20 | Return a proper background image of background_shape, given img. 21 | 22 | Args: 23 | background_shape (tuple): a shape (h, w) 24 | img: an image 25 | Returns: 26 | a background image 27 | """ 28 | background_shape = tuple(background_shape) 29 | return self._fill(background_shape, img) 30 | 31 | @abstractmethod 32 | def _fill(self, background_shape, img): 33 | pass 34 | 35 | 36 | class ConstantBackgroundFiller(BackgroundFiller): 37 | """ Fill the background by a constant """ 38 | 39 | def __init__(self, value): 40 | """ 41 | Args: 42 | value (float): the value to fill the background. 43 | """ 44 | self.value = value 45 | 46 | def _fill(self, background_shape, img): 47 | assert img.ndim in [3, 2] 48 | if img.ndim == 3: 49 | return_shape = background_shape + (img.shape[2],) 50 | else: 51 | return_shape = background_shape 52 | return np.zeros(return_shape, dtype=img.dtype) + self.value 53 | 54 | 55 | # NOTE: 56 | # apply_coords should be implemeted in paste transform, but not yet done 57 | 58 | 59 | class CenterPaste(ImageAugmentor): 60 | """ 61 | Paste the image onto the center of a background canvas. 62 | """ 63 | 64 | def __init__(self, background_shape, background_filler=None): 65 | """ 66 | Args: 67 | background_shape (tuple): shape of the background canvas. 68 | background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler. 
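Example (a minimal sketch; the canvas must be at least as large as the input images):

    .. code-block:: python

        from dataflow import imgaug  # or: from tensorpack import imgaug

        # paste onto the center of a 256x256 canvas filled with the constant 128
        aug = imgaug.CenterPaste((256, 256), imgaug.ConstantBackgroundFiller(128))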
69 | """ 70 | if background_filler is None: 71 | background_filler = ConstantBackgroundFiller(0) 72 | 73 | self._init(locals()) 74 | 75 | def get_transform(self, _): 76 | return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img)) 77 | 78 | def _impl(self, img): 79 | img_shape = img.shape[:2] 80 | assert self.background_shape[0] >= img_shape[0] and self.background_shape[1] >= img_shape[1] 81 | 82 | background = self.background_filler.fill( 83 | self.background_shape, img) 84 | y0 = int((self.background_shape[0] - img_shape[0]) * 0.5) 85 | x0 = int((self.background_shape[1] - img_shape[1]) * 0.5) 86 | background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img 87 | return background 88 | 89 | 90 | class RandomPaste(CenterPaste): 91 | """ 92 | Randomly paste the image onto a background canvas. 93 | """ 94 | 95 | def get_transform(self, img): 96 | img_shape = img.shape[:2] 97 | assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1] 98 | 99 | y0 = self._rand_range(self.background_shape[0] - img_shape[0]) 100 | x0 = self._rand_range(self.background_shape[1] - img_shape[1]) 101 | l = int(x0), int(y0) 102 | return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img, l)) 103 | 104 | def _impl(self, img, loc): 105 | x0, y0 = loc 106 | img_shape = img.shape[:2] 107 | background = self.background_filler.fill( 108 | self.background_shape, img) 109 | background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img 110 | return background 111 | -------------------------------------------------------------------------------- /dataflow/dataflow/raw.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: raw.py 3 | 4 | 5 | import copy 6 | import numpy as np 7 | import six 8 | 9 | from .base import DataFlow, RNGDataFlow 10 | 11 | __all__ = ['FakeData', 'DataFromQueue', 'DataFromList', 'DataFromGenerator', 'DataFromIterable'] 12 | 13 | 14 | class FakeData(RNGDataFlow): 15 | """ Generate fake data of given shapes""" 16 | 17 | def __init__(self, shapes, size=1000, random=True, dtype='float32', domain=(0, 1)): 18 | """ 19 | Args: 20 | shapes (list): a list of lists/tuples. Shapes of each component. 21 | size (int): size of this DataFlow. 22 | random (bool): whether to randomly generate data every iteration. 23 | Note that merely generating the data could sometimes be time-consuming! 24 | dtype (str or list): data type as string, or a list of data types. 
25 | domain (tuple or list): (min, max) tuple, or a list of such tuples 26 | """ 27 | super(FakeData, self).__init__() 28 | self.shapes = shapes 29 | self._size = int(size) 30 | self.random = random 31 | self.dtype = [dtype] * len(shapes) if isinstance(dtype, six.string_types) else dtype 32 | self.domain = [domain] * len(shapes) if isinstance(domain, tuple) else domain 33 | assert len(self.dtype) == len(self.shapes) 34 | assert len(self.domain) == len(self.domain) 35 | 36 | def __len__(self): 37 | return self._size 38 | 39 | def __iter__(self): 40 | if self.random: 41 | for _ in range(self._size): 42 | val = [] 43 | for k in range(len(self.shapes)): 44 | v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0] 45 | val.append(v.astype(self.dtype[k])) 46 | yield val 47 | else: 48 | val = [] 49 | for k in range(len(self.shapes)): 50 | v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0] 51 | val.append(v.astype(self.dtype[k])) 52 | for _ in range(self._size): 53 | yield copy.copy(val) 54 | 55 | 56 | class DataFromQueue(DataFlow): 57 | """ Produce data from a queue """ 58 | def __init__(self, queue): 59 | """ 60 | Args: 61 | queue (queue): a queue with ``get()`` method. 62 | """ 63 | self.queue = queue 64 | 65 | def __iter__(self): 66 | while True: 67 | yield self.queue.get() 68 | 69 | 70 | class DataFromList(RNGDataFlow): 71 | """ Wrap a list of datapoints to a DataFlow""" 72 | 73 | def __init__(self, lst, shuffle=True): 74 | """ 75 | Args: 76 | lst (list): input list. Each element is a datapoint. 77 | shuffle (bool): shuffle data. 78 | """ 79 | super(DataFromList, self).__init__() 80 | self.lst = lst 81 | self.shuffle = shuffle 82 | 83 | def __len__(self): 84 | return len(self.lst) 85 | 86 | def __iter__(self): 87 | if not self.shuffle: 88 | yield from self.lst 89 | else: 90 | idxs = np.arange(len(self.lst)) 91 | self.rng.shuffle(idxs) 92 | for k in idxs: 93 | yield self.lst[k] 94 | 95 | 96 | class DataFromGenerator(DataFlow): 97 | """ 98 | Wrap a generator to a DataFlow. 99 | The dataflow will not have length. 
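Example (a minimal sketch, assuming the standalone ``dataflow`` package; the generator below is only an illustration):

    .. code-block:: python

        from dataflow import DataFromGenerator

        def gen():
            for k in range(100):
                yield [k]  # each datapoint is a list of components

        # pass the function itself, so the dataflow can be iterated more than once
        df = DataFromGenerator(gen)
        df.reset_state()
        for dp in df:
            pass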
100 | """ 101 | def __init__(self, gen): 102 | """ 103 | Args: 104 | gen: iterable, or a callable that returns an iterable 105 | """ 106 | self._gen = gen 107 | 108 | def __iter__(self): 109 | if not callable(self._gen): 110 | yield from self._gen 111 | else: 112 | yield from self._gen() 113 | 114 | def __len__(self): 115 | return len(self._gen) 116 | 117 | 118 | class DataFromIterable(DataFlow): 119 | """ Wrap an iterable of datapoints to a DataFlow""" 120 | def __init__(self, iterable): 121 | """ 122 | Args: 123 | iterable: an iterable object 124 | """ 125 | self._itr = iterable 126 | try: 127 | self._len = len(iterable) 128 | except Exception: 129 | self._len = None 130 | 131 | def __len__(self): 132 | if self._len is None: 133 | raise NotImplementedError 134 | return self._len 135 | 136 | def __iter__(self): 137 | yield from self._itr 138 | -------------------------------------------------------------------------------- /dataflow/dataflow/remote.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: remote.py 3 | 4 | 5 | import multiprocessing as mp 6 | import time 7 | from collections import deque 8 | import tqdm 9 | 10 | from ..utils import logger 11 | from ..utils.concurrency import DIE 12 | from ..utils.serialize import dumps, loads 13 | from ..utils.utils import get_tqdm_kwargs 14 | from .base import DataFlow, DataFlowReentrantGuard 15 | 16 | try: 17 | import zmq 18 | except ImportError: 19 | logger.warn("Error in 'import zmq'. remote feature won't be available") 20 | __all__ = [] 21 | else: 22 | __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ'] 23 | 24 | 25 | def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): 26 | """ 27 | Run DataFlow and send data to a ZMQ socket addr. 28 | It will serialize and send each datapoint to this address with a PUSH socket. 29 | This function never returns. 30 | 31 | Args: 32 | df (DataFlow): Will infinitely loop over the DataFlow. 33 | addr: a ZMQ socket endpoint. 34 | hwm (int): ZMQ high-water mark (buffer size) 35 | format (str): The serialization format. 36 | Default format uses :mod:`utils.serialize`. 37 | This format works with :class:`dataflow.RemoteDataZMQ`. 38 | An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops 39 | and :class:`input_source.ZMQInput`. 40 | bind (bool): whether to bind or connect to the endpoint address. 
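Example (a minimal sketch, assuming the standalone ``dataflow`` package; the endpoint address and host name below are only illustrations):

    .. code-block:: python

        from dataflow import FakeData, RemoteDataZMQ, send_dataflow_zmq

        # in the receiver process: RemoteDataZMQ binds a PULL socket by default
        df = RemoteDataZMQ('tcp://0.0.0.0:8877')
        df.reset_state()
        for dp in df:
            pass

        # in the sender process: send_dataflow_zmq connects a PUSH socket by default
        send_dataflow_zmq(FakeData([[4]], 1000), 'tcp://receiver-host:8877')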
41 | """ 42 | assert format in [None, 'zmq_op', 'zmq_ops'] 43 | if format is None: 44 | dump_fn = dumps 45 | else: 46 | from zmq_ops import dump_arrays 47 | dump_fn = dump_arrays 48 | 49 | ctx = zmq.Context() 50 | socket = ctx.socket(zmq.PUSH) 51 | socket.set_hwm(hwm) 52 | if bind: 53 | socket.bind(addr) 54 | else: 55 | socket.connect(addr) 56 | try: 57 | df.reset_state() 58 | logger.info("Serving data to {} with {} format ...".format( 59 | addr, 'default' if format is None else 'zmq_ops')) 60 | INTERVAL = 200 61 | q = deque(maxlen=INTERVAL) 62 | 63 | try: 64 | total = len(df) 65 | except NotImplementedError: 66 | total = 0 67 | tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8) 68 | tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}" 69 | while True: 70 | with tqdm.trange(total, **tqdm_args) as pbar: 71 | for dp in df: 72 | start = time.time() 73 | socket.send(dump_fn(dp), copy=False) 74 | q.append(time.time() - start) 75 | pbar.update(1) 76 | if pbar.n % INTERVAL == 0: 77 | avg = "{:.3f}".format(sum(q) / len(q)) 78 | pbar.set_postfix({'AvgSendLat': avg}) 79 | finally: 80 | logger.info("Exiting send_dataflow_zmq ...") 81 | socket.setsockopt(zmq.LINGER, 0) 82 | socket.close() 83 | if not ctx.closed: 84 | ctx.destroy(0) 85 | 86 | 87 | class RemoteDataZMQ(DataFlow): 88 | """ 89 | Produce data from ZMQ PULL socket(s). 90 | It is the receiver-side counterpart of :func:`send_dataflow_zmq`, which uses :mod:`tensorpack.utils.serialize` 91 | for serialization. 92 | See http://tensorpack.readthedocs.io/tutorial/efficient-dataflow.html#distributed-dataflow 93 | 94 | Attributes: 95 | cnt1, cnt2 (int): number of data points received from addr1 and addr2 96 | """ 97 | def __init__(self, addr1, addr2=None, hwm=50, bind=True): 98 | """ 99 | Args: 100 | addr1,addr2 (str): addr of the zmq endpoint to connect to. 101 | Use both if you need two protocols (e.g. both IPC and TCP). 102 | I don't think you'll ever need 3. 
103 | hwm (int): ZMQ high-water mark (buffer size) 104 | bind (bool): whether to connect or bind the endpoint 105 | """ 106 | assert addr1 107 | self._addr1 = addr1 108 | self._addr2 = addr2 109 | self._hwm = int(hwm) 110 | self._guard = DataFlowReentrantGuard() 111 | self._bind = bind 112 | 113 | def reset_state(self): 114 | self.cnt1 = 0 115 | self.cnt2 = 0 116 | 117 | def bind_or_connect(self, socket, addr): 118 | if self._bind: 119 | socket.bind(addr) 120 | else: 121 | socket.connect(addr) 122 | 123 | def __iter__(self): 124 | with self._guard: 125 | try: 126 | ctx = zmq.Context() 127 | if self._addr2 is None: 128 | socket = ctx.socket(zmq.PULL) 129 | socket.set_hwm(self._hwm) 130 | self.bind_or_connect(socket, self._addr1) 131 | 132 | while True: 133 | dp = loads(socket.recv(copy=False)) 134 | yield dp 135 | self.cnt1 += 1 136 | else: 137 | socket1 = ctx.socket(zmq.PULL) 138 | socket1.set_hwm(self._hwm) 139 | self.bind_or_connect(socket1, self._addr1) 140 | 141 | socket2 = ctx.socket(zmq.PULL) 142 | socket2.set_hwm(self._hwm) 143 | self.bind_or_connect(socket2, self._addr2) 144 | 145 | poller = zmq.Poller() 146 | poller.register(socket1, zmq.POLLIN) 147 | poller.register(socket2, zmq.POLLIN) 148 | 149 | while True: 150 | evts = poller.poll() 151 | for sock, evt in evts: 152 | dp = loads(sock.recv(copy=False)) 153 | yield dp 154 | if sock == socket1: 155 | self.cnt1 += 1 156 | else: 157 | self.cnt2 += 1 158 | finally: 159 | ctx.destroy(linger=0) 160 | 161 | 162 | # for internal use only 163 | def dump_dataflow_to_process_queue(df, size, nr_consumer): 164 | """ 165 | Convert a DataFlow to a :class:`multiprocessing.Queue`. 166 | The DataFlow will only be reset in the spawned process. 167 | 168 | Args: 169 | df (DataFlow): the DataFlow to dump. 170 | size (int): size of the queue 171 | nr_consumer (int): number of consumer of the queue. 172 | The producer will add this many of ``DIE`` sentinel to the end of the queue. 173 | 174 | Returns: 175 | tuple(queue, process): 176 | The process will take data from ``df`` and fill 177 | the queue, once you start it. Each element in the queue is (idx, 178 | dp). idx can be the ``DIE`` sentinel when ``df`` is exhausted. 
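    Example (a sketch of the intended usage; the consumer loop is illustrative):

    .. code-block:: python

        q, proc = dump_dataflow_to_process_queue(df, size=100, nr_consumer=1)
        proc.start()
        while True:
            idx, dp = q.get()
            if idx == DIE:
                break
            # ... consume dp
        proc.join()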
179 | """ 180 | q = mp.Queue(size) 181 | 182 | class EnqueProc(mp.Process): 183 | 184 | def __init__(self, df, q, nr_consumer): 185 | super(EnqueProc, self).__init__() 186 | self.df = df 187 | self.q = q 188 | 189 | def run(self): 190 | self.df.reset_state() 191 | try: 192 | for idx, dp in enumerate(self.df): 193 | self.q.put((idx, dp)) 194 | finally: 195 | for _ in range(nr_consumer): 196 | self.q.put((DIE, None)) 197 | 198 | proc = EnqueProc(df, q, nr_consumer) 199 | return q, proc 200 | 201 | 202 | if __name__ == '__main__': 203 | from argparse import ArgumentParser 204 | from .raw import FakeData 205 | from .common import TestDataSpeed 206 | 207 | """ 208 | Test the multi-producer single-consumer model 209 | """ 210 | parser = ArgumentParser() 211 | parser.add_argument('-t', '--task', choices=['send', 'recv'], required=True) 212 | parser.add_argument('-a', '--addr1', required=True) 213 | parser.add_argument('-b', '--addr2', default=None) 214 | args = parser.parse_args() 215 | 216 | # tcp addr like "tcp://127.0.0.1:8877" 217 | # ipc addr like "ipc://@ipc-test" 218 | if args.task == 'send': 219 | # use random=True to make it slow and cpu-consuming 220 | ds = FakeData([(128, 244, 244, 3)], 1000, random=True) 221 | send_dataflow_zmq(ds, args.addr1) 222 | else: 223 | ds = RemoteDataZMQ(args.addr1, args.addr2) 224 | logger.info("Each DP is 73.5MB") 225 | TestDataSpeed(ds).start_test() 226 | -------------------------------------------------------------------------------- /dataflow/dataflow/serialize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: serialize.py 3 | 4 | import numpy as np 5 | import os 6 | import platform 7 | from collections import defaultdict 8 | 9 | from ..utils import logger 10 | from ..utils.serialize import dumps, loads 11 | from ..utils.develop import create_dummy_class # noqa 12 | from ..utils.utils import get_tqdm 13 | from .base import DataFlow 14 | from .common import FixedSizeData, MapData 15 | from .format import HDF5Data, LMDBData 16 | from .raw import DataFromGenerator, DataFromList 17 | 18 | __all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer'] 19 | 20 | 21 | def _reset_df_and_get_size(df): 22 | df.reset_state() 23 | try: 24 | sz = len(df) 25 | except NotImplementedError: 26 | sz = 0 27 | return sz 28 | 29 | 30 | class LMDBSerializer(): 31 | """ 32 | Serialize a Dataflow to a lmdb database, where the keys are indices and values 33 | are serialized datapoints. 34 | 35 | You will need to ``pip install lmdb`` to use it. 36 | 37 | Example: 38 | 39 | .. code-block:: python 40 | 41 | LMDBSerializer.save(my_df, "output.lmdb") 42 | 43 | new_df = LMDBSerializer.load("output.lmdb", shuffle=True) 44 | """ 45 | @staticmethod 46 | def save(df, path, write_frequency=5000): 47 | """ 48 | Args: 49 | df (DataFlow): the DataFlow to serialize. 50 | path (str): output path. Either a directory or an lmdb file. 51 | write_frequency (int): the frequency to write back data to disk. 52 | A smaller value reduces memory usage. 53 | """ 54 | assert isinstance(df, DataFlow), type(df) 55 | isdir = os.path.isdir(path) 56 | if isdir: 57 | assert not os.path.isfile(os.path.join(path, 'data.mdb')), "LMDB file exists!" 
58 | else: 59 | assert not os.path.isfile(path), "LMDB file {} exists!".format(path) 60 | # It's OK to use super large map_size on Linux, but not on other platforms 61 | # See: https://github.com/NVIDIA/DIGITS/issues/206 62 | map_size = 1099511627776 * 2 if platform.system() == 'Linux' else 128 * 10**6 63 | db = lmdb.open(path, subdir=isdir, 64 | map_size=map_size, readonly=False, 65 | meminit=False, map_async=True) # need sync() at the end 66 | size = _reset_df_and_get_size(df) 67 | 68 | # put data into lmdb, and doubling the size if full. 69 | # Ref: https://github.com/NVIDIA/DIGITS/pull/209/files 70 | def put_or_grow(txn, key, value): 71 | try: 72 | txn.put(key, value) 73 | return txn 74 | except lmdb.MapFullError: 75 | pass 76 | txn.abort() 77 | curr_size = db.info()['map_size'] 78 | new_size = curr_size * 2 79 | logger.info("Doubling LMDB map_size to {:.2f}GB".format(new_size / 10**9)) 80 | db.set_mapsize(new_size) 81 | txn = db.begin(write=True) 82 | txn = put_or_grow(txn, key, value) 83 | return txn 84 | 85 | with get_tqdm(total=size) as pbar: 86 | idx = -1 87 | 88 | # LMDB transaction is not exception-safe! 89 | # although it has a context manager interface 90 | txn = db.begin(write=True) 91 | for idx, dp in enumerate(df): 92 | txn = put_or_grow(txn, u'{:08}'.format(idx).encode('ascii'), dumps(dp)) 93 | pbar.update() 94 | if (idx + 1) % write_frequency == 0: 95 | txn.commit() 96 | txn = db.begin(write=True) 97 | txn.commit() 98 | 99 | keys = [u'{:08}'.format(k).encode('ascii') for k in range(idx + 1)] 100 | with db.begin(write=True) as txn: 101 | txn = put_or_grow(txn, b'__keys__', dumps(keys)) 102 | 103 | logger.info("Flushing database ...") 104 | db.sync() 105 | db.close() 106 | 107 | @staticmethod 108 | def load(path, shuffle=True): 109 | """ 110 | Note: 111 | If you found deserialization being the bottleneck, you can use :class:`LMDBData` as the reader 112 | and run deserialization as a mapper in parallel. 113 | """ 114 | df = LMDBData(path, shuffle=shuffle) 115 | return MapData(df, LMDBSerializer._deserialize_lmdb) 116 | 117 | @staticmethod 118 | def _deserialize_lmdb(dp): 119 | return loads(dp[1]) 120 | 121 | 122 | class NumpySerializer(): 123 | """ 124 | Serialize the entire dataflow to a npz dict. 125 | Note that this would have to store the entire dataflow in memory, 126 | and is also >10x slower than LMDB/TFRecord serializers. 127 | """ 128 | 129 | @staticmethod 130 | def save(df, path): 131 | """ 132 | Args: 133 | df (DataFlow): the DataFlow to serialize. 134 | path (str): output npz file. 135 | """ 136 | buffer = [] 137 | size = _reset_df_and_get_size(df) 138 | with get_tqdm(total=size) as pbar: 139 | for dp in df: 140 | buffer.append(dp) 141 | pbar.update() 142 | np.savez_compressed(path, buffer=np.asarray(buffer, dtype=np.object)) 143 | 144 | @staticmethod 145 | def load(path, shuffle=True): 146 | # allow_pickle defaults to False since numpy 1.16.3 147 | # (https://www.numpy.org/devdocs/release.html#unpickling-while-loading-requires-explicit-opt-in) 148 | buffer = np.load(path, allow_pickle=True)['buffer'] 149 | return DataFromList(buffer, shuffle=shuffle) 150 | 151 | 152 | class TFRecordSerializer(): 153 | """ 154 | Serialize datapoints to bytes (by tensorpack's default serializer) and write to a TFRecord file. 155 | 156 | Note that TFRecord does not support random access and is in fact not very performant. 157 | It's better to use :class:`LMDBSerializer`. 
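    Example (a sketch; the file name and size below are placeholders):

    .. code-block:: python

        TFRecordSerializer.save(my_df, "output.tfrecords")

        new_df = TFRecordSerializer.load("output.tfrecords", size=1000)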
158 | """ 159 | @staticmethod 160 | def save(df, path): 161 | """ 162 | Args: 163 | df (DataFlow): the DataFlow to serialize. 164 | path (str): output tfrecord file. 165 | """ 166 | size = _reset_df_and_get_size(df) 167 | with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar: 168 | for dp in df: 169 | writer.write(dumps(dp)) 170 | pbar.update() 171 | 172 | @staticmethod 173 | def load(path, size=None): 174 | """ 175 | Args: 176 | size (int): total number of records. If not provided, the returned dataflow will have no `__len__()`. 177 | It's needed because this metadata is not stored in the TFRecord file. 178 | """ 179 | gen = tf.python_io.tf_record_iterator(path) 180 | ds = DataFromGenerator(gen) 181 | ds = MapData(ds, loads) 182 | if size is not None: 183 | ds = FixedSizeData(ds, size) 184 | return ds 185 | 186 | 187 | class HDF5Serializer(): 188 | """ 189 | Write datapoints to a HDF5 file. 190 | 191 | Note that HDF5 files are in fact not very performant and currently do not support lazy loading. 192 | It's better to use :class:`LMDBSerializer`. 193 | """ 194 | @staticmethod 195 | def save(df, path, data_paths): 196 | """ 197 | Args: 198 | df (DataFlow): the DataFlow to serialize. 199 | path (str): output hdf5 file. 200 | data_paths (list[str]): list of h5 paths. It should have the same 201 | length as each datapoint, and each path should correspond to one 202 | component of the datapoint. 203 | """ 204 | size = _reset_df_and_get_size(df) 205 | buffer = defaultdict(list) 206 | 207 | with get_tqdm(total=size) as pbar: 208 | for dp in df: 209 | assert len(dp) == len(data_paths), "Datapoint has {} components!".format(len(dp)) 210 | for k, el in zip(data_paths, dp): 211 | buffer[k].append(el) 212 | pbar.update() 213 | 214 | with h5py.File(path, 'w') as hf, get_tqdm(total=len(data_paths)) as pbar: 215 | for data_path in data_paths: 216 | hf.create_dataset(data_path, data=buffer[data_path]) 217 | pbar.update() 218 | 219 | @staticmethod 220 | def load(path, data_paths, shuffle=True): 221 | """ 222 | Args: 223 | data_paths (list): list of h5 paths to be zipped. 
224 | """ 225 | return HDF5Data(path, data_paths, shuffle) 226 | 227 | 228 | try: 229 | import lmdb 230 | except ImportError: 231 | LMDBSerializer = create_dummy_class('LMDBSerializer', 'lmdb') # noqa 232 | 233 | try: 234 | import tensorflow as tf 235 | except ImportError: 236 | TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow') # noqa 237 | 238 | try: 239 | import h5py 240 | except ImportError: 241 | HDF5Serializer = create_dummy_class('HDF5Serializer', 'h5py') # noqa 242 | 243 | 244 | if __name__ == '__main__': 245 | from .raw import FakeData 246 | import time 247 | ds = FakeData([[300, 300, 3], [1]], 1000) 248 | 249 | print(time.time()) 250 | TFRecordSerializer.save(ds, 'out.tfrecords') 251 | print(time.time()) 252 | df = TFRecordSerializer.load('out.tfrecords', size=1000) 253 | df.reset_state() 254 | for idx, dp in enumerate(df): 255 | pass 256 | print("TF Finished, ", idx) 257 | print(time.time()) 258 | 259 | LMDBSerializer.save(ds, 'out.lmdb') 260 | print(time.time()) 261 | df = LMDBSerializer.load('out.lmdb') 262 | df.reset_state() 263 | for idx, dp in enumerate(df): 264 | pass 265 | print("LMDB Finished, ", idx) 266 | print(time.time()) 267 | 268 | NumpySerializer.save(ds, 'out.npz') 269 | print(time.time()) 270 | df = NumpySerializer.load('out.npz') 271 | df.reset_state() 272 | for idx, dp in enumerate(df): 273 | pass 274 | print("Numpy Finished, ", idx) 275 | print(time.time()) 276 | 277 | paths = ['p1', 'p2'] 278 | HDF5Serializer.save(ds, 'out.h5', paths) 279 | print(time.time()) 280 | df = HDF5Serializer.load('out.h5', paths) 281 | df.reset_state() 282 | for idx, dp in enumerate(df): 283 | pass 284 | print("HDF5 Finished, ", idx) 285 | print(time.time()) 286 | -------------------------------------------------------------------------------- /dataflow/dataflow/serialize_test.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tempfile 5 | import numpy as np 6 | import os 7 | import unittest 8 | 9 | from tensorpack.dataflow import HDF5Serializer, LMDBSerializer, NumpySerializer, TFRecordSerializer 10 | from tensorpack.dataflow.base import DataFlow 11 | 12 | 13 | def delete_file_if_exists(fn): 14 | try: 15 | os.remove(fn) 16 | except OSError: 17 | pass 18 | 19 | 20 | class SeededFakeDataFlow(DataFlow): 21 | """docstring for SeededFakeDataFlow""" 22 | 23 | def __init__(self, seed=42, size=32): 24 | super(SeededFakeDataFlow, self).__init__() 25 | self.seed = seed 26 | self._size = size 27 | self.cache = [] 28 | 29 | def reset_state(self): 30 | np.random.seed(self.seed) 31 | for _ in range(self._size): 32 | label = np.random.randint(low=0, high=10) 33 | img = np.random.randn(28, 28, 3) 34 | self.cache.append([label, img]) 35 | 36 | def __len__(self): 37 | return self._size 38 | 39 | def __iter__(self): 40 | for dp in self.cache: 41 | yield dp 42 | 43 | 44 | class SerializerTest(unittest.TestCase): 45 | 46 | def run_write_read_test(self, file, serializer, w_args, w_kwargs, r_args, r_kwargs, error_msg): 47 | try: 48 | delete_file_if_exists(file) 49 | 50 | ds_expected = SeededFakeDataFlow() 51 | serializer.save(ds_expected, file, *w_args, **w_kwargs) 52 | ds_actual = serializer.load(file, *r_args, **r_kwargs) 53 | 54 | ds_actual.reset_state() 55 | ds_expected.reset_state() 56 | 57 | for dp_expected, dp_actual in zip(ds_expected.__iter__(), ds_actual.__iter__()): 58 | self.assertEqual(dp_expected[0], dp_actual[0]) 59 | self.assertTrue(np.allclose(dp_expected[1], dp_actual[1])) 60 | except ImportError: 61 | print(error_msg) 62 | 63 | def test_lmdb(self): 64 | with tempfile.TemporaryDirectory() as f: 65 | self.run_write_read_test( 66 | os.path.join(f, 'test.lmdb'), 67 | LMDBSerializer, 68 | {}, {}, 69 | {}, {'shuffle': False}, 70 | 'Skip test_lmdb, no lmdb available') 71 | 72 | def test_tfrecord(self): 73 | with tempfile.TemporaryDirectory() as f: 74 | self.run_write_read_test( 75 | os.path.join(f, 'test.tfrecord'), 76 | TFRecordSerializer, 77 | {}, {}, 78 | {}, {'size': 32}, 79 | 'Skip test_tfrecord, no tensorflow available') 80 | 81 | def test_numpy(self): 82 | with tempfile.TemporaryDirectory() as f: 83 | self.run_write_read_test( 84 | os.path.join(f, 'test.npz'), 85 | NumpySerializer, 86 | {}, {}, 87 | {}, {'shuffle': False}, 88 | 'Skip test_numpy, no numpy available') 89 | 90 | def test_hdf5(self): 91 | args = [['label', 'image']] 92 | with tempfile.TemporaryDirectory() as f: 93 | self.run_write_read_test( 94 | os.path.join(f, 'test.h5'), 95 | HDF5Serializer, 96 | args, {}, 97 | args, {'shuffle': False}, 98 | 'Skip test_hdf5, no h5py available') 99 | 100 | 101 | if __name__ == '__main__': 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /dataflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: __init__.py 3 | 4 | """ 5 | Common utils. 6 | These utils should be irrelevant to tensorflow. 
7 | """ 8 | 9 | # https://github.com/celery/kombu/blob/7d13f9b95d0b50c94393b962e6def928511bfda6/kombu/__init__.py#L34-L36 10 | STATICA_HACK = True 11 | globals()['kcah_acitats'[::-1].upper()] = False 12 | if STATICA_HACK: 13 | from .utils import * 14 | 15 | 16 | __all__ = [] 17 | 18 | 19 | def _global_import(name): 20 | p = __import__(name, globals(), None, level=1) 21 | lst = p.__all__ if '__all__' in dir(p) else dir(p) 22 | for k in lst: 23 | if not k.startswith('__'): 24 | globals()[k] = p.__dict__[k] 25 | __all__.append(k) 26 | 27 | 28 | _global_import('utils') 29 | 30 | # Import no other submodules. they are supposed to be explicitly imported by users. 31 | __all__.extend(['logger']) 32 | -------------------------------------------------------------------------------- /dataflow/utils/argtools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: argtools.py 3 | 4 | 5 | import inspect 6 | import functools 7 | 8 | from . import logger 9 | 10 | __all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d', 11 | 'memoized_ignoreargs', 'log_once'] 12 | 13 | 14 | def map_arg(**maps): 15 | """ 16 | Apply a mapping on certain argument before calling the original function. 17 | 18 | Args: 19 | maps (dict): {argument_name: map_func} 20 | """ 21 | def deco(func): 22 | @functools.wraps(func) 23 | def wrapper(*args, **kwargs): 24 | # getcallargs was deprecated since 3.5 25 | sig = inspect.signature(func) 26 | argmap = sig.bind_partial(*args, **kwargs).arguments 27 | for k, map_func in maps.items(): 28 | if k in argmap: 29 | argmap[k] = map_func(argmap[k]) 30 | return func(**argmap) 31 | return wrapper 32 | return deco 33 | 34 | 35 | memoized = functools.lru_cache(maxsize=None) 36 | """ Alias to :func:`functools.lru_cache` 37 | WARNING: memoization will keep keys and values alive! 38 | """ 39 | 40 | 41 | def graph_memoized(func): 42 | """ 43 | Like memoized, but keep one cache per default graph. 44 | """ 45 | 46 | # TODO it keeps the graph alive 47 | from ..compat import tfv1 48 | GRAPH_ARG_NAME = '__IMPOSSIBLE_NAME_FOR_YOU__' 49 | 50 | @memoized 51 | def func_with_graph_arg(*args, **kwargs): 52 | kwargs.pop(GRAPH_ARG_NAME) 53 | return func(*args, **kwargs) 54 | 55 | @functools.wraps(func) 56 | def wrapper(*args, **kwargs): 57 | assert GRAPH_ARG_NAME not in kwargs, "No Way!!" 58 | graph = tfv1.get_default_graph() 59 | kwargs[GRAPH_ARG_NAME] = graph 60 | return func_with_graph_arg(*args, **kwargs) 61 | return wrapper 62 | 63 | 64 | _MEMOIZED_NOARGS = {} 65 | 66 | 67 | def memoized_ignoreargs(func): 68 | """ 69 | A decorator. It performs memoization ignoring the arguments used to call 70 | the function. 71 | """ 72 | def wrapper(*args, **kwargs): 73 | if func not in _MEMOIZED_NOARGS: 74 | res = func(*args, **kwargs) 75 | _MEMOIZED_NOARGS[func] = res 76 | return res 77 | return _MEMOIZED_NOARGS[func] 78 | return wrapper 79 | 80 | 81 | def shape2d(a): 82 | """ 83 | Ensure a 2D shape. 84 | 85 | Args: 86 | a: a int or tuple/list of length 2 87 | 88 | Returns: 89 | list: of length 2. if ``a`` is a int, return ``[a, a]``. 
90 | """ 91 | if type(a) == int: 92 | return [a, a] 93 | if isinstance(a, (list, tuple)): 94 | assert len(a) == 2 95 | return list(a) 96 | raise RuntimeError("Illegal shape: {}".format(a)) 97 | 98 | 99 | def get_data_format(data_format, keras_mode=True): 100 | if keras_mode: 101 | dic = {'NCHW': 'channels_first', 'NHWC': 'channels_last'} 102 | else: 103 | dic = {'channels_first': 'NCHW', 'channels_last': 'NHWC'} 104 | ret = dic.get(data_format, data_format) 105 | if ret not in dic.values(): 106 | raise ValueError("Unknown data_format: {}".format(data_format)) 107 | return ret 108 | 109 | 110 | def shape4d(a, data_format='NHWC'): 111 | """ 112 | Ensuer a 4D shape, to use with 4D symbolic functions. 113 | 114 | Args: 115 | a: a int or tuple/list of length 2 116 | 117 | Returns: 118 | list: of length 4. if ``a`` is a int, return ``[1, a, a, 1]`` 119 | or ``[1, 1, a, a]`` depending on data_format. 120 | """ 121 | s2d = shape2d(a) 122 | if get_data_format(data_format, False) == 'NHWC': 123 | return [1] + s2d + [1] 124 | else: 125 | return [1, 1] + s2d 126 | 127 | 128 | @memoized 129 | def log_once(message, func='info'): 130 | """ 131 | Log certain message only once. Call this function more than one times with 132 | the same message will result in no-op. 133 | 134 | Args: 135 | message(str): message to log 136 | func(str): the name of the logger method. e.g. "info", "warn", "error". 137 | """ 138 | getattr(logger, func)(message) 139 | 140 | 141 | def call_only_once(func): 142 | """ 143 | Decorate a method or property of a class, so that this method can only 144 | be called once for every instance. 145 | Calling it more than once will result in exception. 146 | """ 147 | @functools.wraps(func) 148 | def wrapper(*args, **kwargs): 149 | self = args[0] 150 | # cannot use hasattr here, because hasattr tries to getattr, which 151 | # fails if func is a property 152 | assert func.__name__ in dir(self), "call_only_once can only be used on method or property!" 153 | 154 | if not hasattr(self, '_CALL_ONLY_ONCE_CACHE'): 155 | cache = self._CALL_ONLY_ONCE_CACHE = set() 156 | else: 157 | cache = self._CALL_ONLY_ONCE_CACHE 158 | 159 | cls = type(self) 160 | # cannot use ismethod(), because decorated method becomes a function 161 | is_method = inspect.isfunction(getattr(cls, func.__name__)) 162 | assert func not in cache, \ 163 | "{} {}.{} can only be called once per object!".format( 164 | 'Method' if is_method else 'Property', 165 | cls.__name__, func.__name__) 166 | cache.add(func) 167 | 168 | return func(*args, **kwargs) 169 | 170 | return wrapper 171 | 172 | 173 | def memoized_method(func): 174 | """ 175 | A decorator that performs memoization on methods. It stores the cache on the object instance itself. 176 | """ 177 | 178 | @functools.wraps(func) 179 | def wrapper(*args, **kwargs): 180 | self = args[0] 181 | assert func.__name__ in dir(self), "memoized_method can only be used on method!" 
182 | 183 | if not hasattr(self, '_MEMOIZED_CACHE'): 184 | cache = self._MEMOIZED_CACHE = {} 185 | else: 186 | cache = self._MEMOIZED_CACHE 187 | 188 | key = (func, ) + args[1:] + tuple(kwargs) 189 | ret = cache.get(key, None) 190 | if ret is not None: 191 | return ret 192 | value = func(*args, **kwargs) 193 | cache[key] = value 194 | return value 195 | 196 | return wrapper 197 | 198 | 199 | if __name__ == '__main__': 200 | class A(): 201 | def __init__(self): 202 | self._p = 0 203 | 204 | @call_only_once 205 | def f(self, x): 206 | print(x) 207 | 208 | @property 209 | def p(self): 210 | return self._p 211 | 212 | @p.setter 213 | @call_only_once 214 | def p(self, val): 215 | self._p = val 216 | 217 | a = A() 218 | a.f(1) 219 | 220 | b = A() 221 | b.f(2) 222 | b.f(1) 223 | 224 | print(b.p) 225 | print(b.p) 226 | b.p = 2 227 | print(b.p) 228 | b.p = 3 229 | print(b.p) 230 | -------------------------------------------------------------------------------- /dataflow/utils/compatible_serialize.py: -------------------------------------------------------------------------------- 1 | from .serialize import loads, dumps # noqa 2 | 3 | # keep this file for BC 4 | -------------------------------------------------------------------------------- /dataflow/utils/concurrency.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: concurrency.py 3 | 4 | # Some code taken from zxytim 5 | 6 | import sys 7 | import atexit 8 | import bisect 9 | import multiprocessing as mp 10 | import platform 11 | import signal 12 | import threading 13 | import weakref 14 | from contextlib import contextmanager 15 | import six 16 | from six.moves import queue 17 | import subprocess 18 | 19 | from . import logger 20 | from .argtools import log_once 21 | 22 | 23 | __all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread', 24 | 'ensure_proc_terminate', 25 | 'start_proc_mask_signal'] 26 | 27 | 28 | class StoppableThread(threading.Thread): 29 | """ 30 | A thread that has a 'stop' event. 31 | """ 32 | 33 | def __init__(self, evt=None): 34 | """ 35 | Args: 36 | evt(threading.Event): if None, will create one. 
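        Example (a minimal sketch of a subclass; the loop body is illustrative):

        .. code-block:: python

            class Worker(StoppableThread):
                def run(self):
                    while not self.stopped():
                        pass  # do some work here

            w = Worker()
            w.start()
            w.stop()   # sets the stop event; run() exits on the next check
            w.join()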
37 | """ 38 | super(StoppableThread, self).__init__() 39 | if evt is None: 40 | evt = threading.Event() 41 | self._stop_evt = evt 42 | 43 | def stop(self): 44 | """ Stop the thread""" 45 | self._stop_evt.set() 46 | 47 | def stopped(self): 48 | """ 49 | Returns: 50 | bool: whether the thread is stopped or not 51 | """ 52 | return self._stop_evt.isSet() 53 | 54 | def queue_put_stoppable(self, q, obj): 55 | """ Put obj to queue, but will give up when the thread is stopped""" 56 | while not self.stopped(): 57 | try: 58 | q.put(obj, timeout=5) 59 | break 60 | except queue.Full: 61 | pass 62 | 63 | def queue_get_stoppable(self, q): 64 | """ Take obj from queue, but will give up when the thread is stopped""" 65 | while not self.stopped(): 66 | try: 67 | return q.get(timeout=5) 68 | except queue.Empty: 69 | pass 70 | 71 | 72 | class LoopThread(StoppableThread): 73 | """ A pausable thread that simply runs a loop""" 74 | 75 | def __init__(self, func, pausable=True): 76 | """ 77 | Args: 78 | func: the function to run 79 | """ 80 | super(LoopThread, self).__init__() 81 | self._func = func 82 | self._pausable = pausable 83 | if pausable: 84 | self._lock = threading.Lock() 85 | self.daemon = True 86 | 87 | def run(self): 88 | while not self.stopped(): 89 | if self._pausable: 90 | self._lock.acquire() 91 | self._lock.release() 92 | self._func() 93 | 94 | def pause(self): 95 | """ Pause the loop """ 96 | assert self._pausable 97 | self._lock.acquire() 98 | 99 | def resume(self): 100 | """ Resume the loop """ 101 | assert self._pausable 102 | self._lock.release() 103 | 104 | 105 | class ShareSessionThread(threading.Thread): 106 | """ A wrapper around thread so that the thread 107 | uses the default session at "start()" time. 108 | """ 109 | def __init__(self, th=None): 110 | """ 111 | Args: 112 | th (threading.Thread or None): 113 | """ 114 | super(ShareSessionThread, self).__init__() 115 | if th is not None: 116 | assert isinstance(th, threading.Thread), th 117 | self._th = th 118 | self.name = th.name 119 | self.daemon = th.daemon 120 | 121 | @contextmanager 122 | def default_sess(self): 123 | if self._sess: 124 | with self._sess.as_default(): 125 | yield self._sess 126 | else: 127 | logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name)) 128 | yield None 129 | 130 | def start(self): 131 | from ..compat import tfv1 132 | self._sess = tfv1.get_default_session() 133 | super(ShareSessionThread, self).start() 134 | 135 | def run(self): 136 | if not self._th: 137 | raise NotImplementedError() 138 | with self._sess.as_default(): 139 | self._th.run() 140 | 141 | 142 | class DIE(object): 143 | """ A placeholder class indicating end of queue """ 144 | pass 145 | 146 | 147 | def ensure_proc_terminate(proc): 148 | """ 149 | Make sure processes terminate when main process exit. 150 | 151 | Args: 152 | proc (multiprocessing.Process or list) 153 | """ 154 | if isinstance(proc, list): 155 | for p in proc: 156 | ensure_proc_terminate(p) 157 | return 158 | 159 | def stop_proc_by_weak_ref(ref): 160 | proc = ref() 161 | if proc is None: 162 | return 163 | if not proc.is_alive(): 164 | return 165 | proc.terminate() 166 | proc.join() 167 | 168 | assert isinstance(proc, mp.Process) 169 | atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) 170 | 171 | 172 | def enable_death_signal(_warn=True): 173 | """ 174 | Set the "death signal" of the current process, so that 175 | the current process will be cleaned with guarantee 176 | in case the parent dies accidentally. 
177 | """ 178 | if platform.system() != 'Linux': 179 | return 180 | try: 181 | import prctl # pip install python-prctl 182 | except ImportError: 183 | if _warn: 184 | log_once('"import prctl" failed! Install python-prctl so that processes can be cleaned with guarantee.', 185 | 'warn') 186 | return 187 | else: 188 | assert hasattr(prctl, 'set_pdeathsig'), \ 189 | "prctl.set_pdeathsig does not exist! Note that you need to install 'python-prctl' instead of 'prctl'." 190 | # is SIGHUP a good choice? 191 | prctl.set_pdeathsig(signal.SIGHUP) 192 | 193 | 194 | def is_main_thread(): 195 | if six.PY2: 196 | return isinstance(threading.current_thread(), threading._MainThread) 197 | else: 198 | # a nicer solution with py3 199 | return threading.current_thread() == threading.main_thread() 200 | 201 | 202 | @contextmanager 203 | def mask_sigint(): 204 | """ 205 | Returns: 206 | If called in main thread, returns a context where ``SIGINT`` is ignored, and yield True. 207 | Otherwise yield False. 208 | """ 209 | if is_main_thread(): 210 | sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) 211 | yield True 212 | signal.signal(signal.SIGINT, sigint_handler) 213 | else: 214 | yield False 215 | 216 | 217 | def start_proc_mask_signal(proc): 218 | """ 219 | Start process(es) with SIGINT ignored. 220 | 221 | Args: 222 | proc: (mp.Process or list) 223 | 224 | Note: 225 | The signal mask is only applied when called from main thread. 226 | """ 227 | if not isinstance(proc, list): 228 | proc = [proc] 229 | 230 | with mask_sigint(): 231 | for p in proc: 232 | if isinstance(p, mp.Process): 233 | if sys.version_info < (3, 4) or mp.get_start_method() == 'fork': 234 | log_once(""" 235 | Starting a process with 'fork' method is efficient but not safe and may cause deadlock or crash. 236 | Use 'forkserver' or 'spawn' method instead if you run into such issues. 237 | See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods on how to set them. 238 | """.replace("\n", ""), 239 | 'warn') # noqa 240 | p.start() 241 | 242 | 243 | def subproc_call(cmd, timeout=None): 244 | """ 245 | Execute a command with timeout, and return STDOUT and STDERR 246 | 247 | Args: 248 | cmd(str): the command to execute. 249 | timeout(float): timeout in seconds. 250 | 251 | Returns: 252 | output(bytes), retcode(int). If timeout, retcode is -1. 253 | """ 254 | try: 255 | output = subprocess.check_output( 256 | cmd, stderr=subprocess.STDOUT, 257 | shell=True, timeout=timeout) 258 | return output, 0 259 | except subprocess.TimeoutExpired as e: 260 | logger.warn("Command '{}' timeout!".format(cmd)) 261 | if e.output: 262 | logger.warn(e.output.decode('utf-8')) 263 | return e.output, -1 264 | else: 265 | return "", -1 266 | except subprocess.CalledProcessError as e: 267 | logger.warn("Command '{}' failed, return code={}".format(cmd, e.returncode)) 268 | logger.warn(e.output.decode('utf-8')) 269 | return e.output, e.returncode 270 | except Exception: 271 | logger.warn("Command '{}' failed to run.".format(cmd)) 272 | return "", -2 273 | 274 | 275 | class OrderedContainer(object): 276 | """ 277 | Like a queue, but will always wait to receive item with rank 278 | (x+1) and produce (x+1) before producing (x+2). 279 | 280 | Warning: 281 | It is not thread-safe. 282 | """ 283 | 284 | def __init__(self, start=0): 285 | """ 286 | Args: 287 | start(int): the starting rank. 
288 | """ 289 | self.ranks = [] 290 | self.data = [] 291 | self.wait_for = start 292 | 293 | def put(self, rank, val): 294 | """ 295 | Args: 296 | rank(int): rank of th element. All elements must have different ranks. 297 | val: an object 298 | """ 299 | idx = bisect.bisect(self.ranks, rank) 300 | self.ranks.insert(idx, rank) 301 | self.data.insert(idx, val) 302 | 303 | def has_next(self): 304 | if len(self.ranks) == 0: 305 | return False 306 | return self.ranks[0] == self.wait_for 307 | 308 | def get(self): 309 | assert self.has_next() 310 | ret = self.data[0] 311 | rank = self.ranks[0] 312 | del self.ranks[0] 313 | del self.data[0] 314 | self.wait_for += 1 315 | return rank, ret 316 | 317 | 318 | class OrderedResultGatherProc(mp.Process): 319 | """ 320 | Gather indexed data from a data queue, and produce results with the 321 | original index-based order. 322 | """ 323 | 324 | def __init__(self, data_queue, nr_producer, start=0): 325 | """ 326 | Args: 327 | data_queue(mp.Queue): a queue which contains datapoints. 328 | nr_producer(int): number of producer processes. This process will 329 | terminate after receiving this many of :class:`DIE` sentinel. 330 | start(int): the rank of the first object 331 | """ 332 | super(OrderedResultGatherProc, self).__init__() 333 | self.data_queue = data_queue 334 | self.ordered_container = OrderedContainer(start=start) 335 | self.result_queue = mp.Queue() 336 | self.nr_producer = nr_producer 337 | 338 | def run(self): 339 | nr_end = 0 340 | try: 341 | while True: 342 | task_id, data = self.data_queue.get() 343 | if task_id == DIE: 344 | self.result_queue.put((task_id, data)) 345 | nr_end += 1 346 | if nr_end == self.nr_producer: 347 | return 348 | else: 349 | self.ordered_container.put(task_id, data) 350 | while self.ordered_container.has_next(): 351 | self.result_queue.put(self.ordered_container.get()) 352 | except Exception as e: 353 | import traceback 354 | traceback.print_exc() 355 | raise e 356 | 357 | def get(self): 358 | return self.result_queue.get() 359 | -------------------------------------------------------------------------------- /dataflow/utils/develop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: develop.py 3 | # Author: tensorpack contributors 4 | 5 | 6 | """ Utilities for developers only. 7 | These are not visible to users (not automatically imported). And should not 8 | appeared in docs.""" 9 | import functools 10 | import importlib 11 | import os 12 | import types 13 | from collections import defaultdict 14 | from datetime import datetime 15 | import six 16 | 17 | from . import logger 18 | 19 | __all__ = [] 20 | 21 | 22 | def create_dummy_class(klass, dependency): 23 | """ 24 | When a dependency of a class is not available, create a dummy class which throws ImportError when used. 25 | 26 | Args: 27 | klass (str): name of the class. 28 | dependency (str): name of the dependency. 
29 | 30 | Returns: 31 | class: a class object 32 | """ 33 | assert not building_rtfd() 34 | 35 | class _DummyMetaClass(type): 36 | # throw error on class attribute access 37 | def __getattr__(_, __): 38 | raise AttributeError("Cannot import '{}', therefore '{}' is not available".format(dependency, klass)) 39 | 40 | @six.add_metaclass(_DummyMetaClass) 41 | class _Dummy(object): 42 | # throw error on constructor 43 | def __init__(self, *args, **kwargs): 44 | raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, klass)) 45 | 46 | return _Dummy 47 | 48 | 49 | def create_dummy_func(func, dependency): 50 | """ 51 | When a dependency of a function is not available, create a dummy function which throws ImportError when used. 52 | 53 | Args: 54 | func (str): name of the function. 55 | dependency (str or list[str]): name(s) of the dependency. 56 | 57 | Returns: 58 | function: a function object 59 | """ 60 | assert not building_rtfd() 61 | 62 | if isinstance(dependency, (list, tuple)): 63 | dependency = ','.join(dependency) 64 | 65 | def _dummy(*args, **kwargs): 66 | raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func)) 67 | return _dummy 68 | 69 | 70 | def building_rtfd(): 71 | """ 72 | Returns: 73 | bool: if the library is being imported to generate docs now. 74 | """ 75 | return os.environ.get('READTHEDOCS') == 'True' \ 76 | or os.environ.get('DOC_BUILDING') 77 | 78 | 79 | _DEPRECATED_LOG_NUM = defaultdict(int) 80 | 81 | 82 | def log_deprecated(name="", text="", eos="", max_num_warnings=None): 83 | """ 84 | Log deprecation warning. 85 | 86 | Args: 87 | name (str): name of the deprecated item. 88 | text (str, optional): information about the deprecation. 89 | eos (str, optional): end of service date such as "YYYY-MM-DD". 90 | max_num_warnings (int, optional): the maximum number of times to print this warning 91 | """ 92 | assert name or text 93 | if eos: 94 | eos = "after " + datetime(*map(int, eos.split("-"))).strftime("%d %b") 95 | if name: 96 | if eos: 97 | warn_msg = "%s will be deprecated %s. %s" % (name, eos, text) 98 | else: 99 | warn_msg = "%s was deprecated. %s" % (name, text) 100 | else: 101 | warn_msg = text 102 | if eos: 103 | warn_msg += " Legacy period ends %s" % eos 104 | 105 | if max_num_warnings is not None: 106 | if _DEPRECATED_LOG_NUM[warn_msg] >= max_num_warnings: 107 | return 108 | _DEPRECATED_LOG_NUM[warn_msg] += 1 109 | logger.warn("[Deprecated] " + warn_msg) 110 | 111 | 112 | def deprecated(text="", eos="", max_num_warnings=None): 113 | """ 114 | Args: 115 | text, eos, max_num_warnings: same as :func:`log_deprecated`. 116 | 117 | Returns: 118 | a decorator which deprecates the function. 119 | 120 | Example: 121 | .. 
code-block:: python 122 | 123 | @deprecated("Explanation of what to do instead.", "2017-11-4") 124 | def foo(...): 125 | pass 126 | """ 127 | 128 | def get_location(): 129 | import inspect 130 | frame = inspect.currentframe() 131 | if frame: 132 | callstack = inspect.getouterframes(frame)[-1] 133 | return '%s:%i' % (callstack[1], callstack[2]) 134 | else: 135 | stack = inspect.stack(0) 136 | entry = stack[2] 137 | return '%s:%i' % (entry[1], entry[2]) 138 | 139 | def deprecated_inner(func): 140 | @functools.wraps(func) 141 | def new_func(*args, **kwargs): 142 | name = "{} [{}]".format(func.__name__, get_location()) 143 | log_deprecated(name, text, eos, max_num_warnings=max_num_warnings) 144 | return func(*args, **kwargs) 145 | return new_func 146 | return deprecated_inner 147 | 148 | 149 | def HIDE_DOC(func): 150 | func.__HIDE_SPHINX_DOC__ = True 151 | return func 152 | 153 | 154 | # Copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py 155 | class LazyLoader(types.ModuleType): 156 | def __init__(self, local_name, parent_module_globals, name): 157 | self._local_name = local_name 158 | self._parent_module_globals = parent_module_globals 159 | super(LazyLoader, self).__init__(name) 160 | 161 | def _load(self): 162 | # Import the target module and insert it into the parent's namespace 163 | module = importlib.import_module(self.__name__) 164 | self._parent_module_globals[self._local_name] = module 165 | 166 | # Update this object's dict so that if someone keeps a reference to the 167 | # LazyLoader, lookups are efficient (__getattr__ is only called on lookups 168 | # that fail). 169 | self.__dict__.update(module.__dict__) 170 | 171 | return module 172 | 173 | def __getattr__(self, item): 174 | module = self._load() 175 | return getattr(module, item) 176 | 177 | def __dir__(self): 178 | module = self._load() 179 | return dir(module) 180 | -------------------------------------------------------------------------------- /dataflow/utils/fs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: fs.py 3 | 4 | 5 | import errno 6 | import os 7 | import tqdm 8 | from six.moves import urllib 9 | 10 | from . import logger 11 | from .utils import execute_only_once 12 | 13 | __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path', 'normpath'] 14 | 15 | 16 | def mkdir_p(dirname): 17 | """ Like "mkdir -p", make a dir recursively, but do nothing if the dir exists 18 | 19 | Args: 20 | dirname(str): 21 | """ 22 | assert dirname is not None 23 | if dirname == '' or os.path.isdir(dirname): 24 | return 25 | try: 26 | os.makedirs(dirname) 27 | except OSError as e: 28 | if e.errno != errno.EEXIST: 29 | raise e 30 | 31 | 32 | def download(url, dir, filename=None, expect_size=None): 33 | """ 34 | Download URL to a directory. 35 | Will figure out the filename automatically from URL, if not given. 36 | """ 37 | mkdir_p(dir) 38 | if filename is None: 39 | filename = url.split('/')[-1] 40 | fpath = os.path.join(dir, filename) 41 | 42 | if os.path.isfile(fpath): 43 | if expect_size is not None and os.stat(fpath).st_size == expect_size: 44 | logger.info("File {} exists! Skip download.".format(filename)) 45 | return fpath 46 | else: 47 | logger.warn("File {} exists. 
Will overwrite with a new download!".format(filename)) 48 | 49 | def hook(t): 50 | last_b = [0] 51 | 52 | def inner(b, bsize, tsize=None): 53 | if tsize is not None: 54 | t.total = tsize 55 | t.update((b - last_b[0]) * bsize) 56 | last_b[0] = b 57 | return inner 58 | try: 59 | with tqdm.tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: 60 | fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=hook(t)) 61 | statinfo = os.stat(fpath) 62 | size = statinfo.st_size 63 | except IOError: 64 | logger.error("Failed to download {}".format(url)) 65 | raise 66 | assert size > 0, "Downloaded an empty file from {}!".format(url) 67 | 68 | if expect_size is not None and size != expect_size: 69 | logger.error("File downloaded from {} does not match the expected size!".format(url)) 70 | logger.error("You may have downloaded a broken file, or the upstream may have modified the file.") 71 | 72 | # TODO human-readable size 73 | logger.info('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.') 74 | return fpath 75 | 76 | 77 | def recursive_walk(rootdir): 78 | """ 79 | Yields: 80 | str: All files in rootdir, recursively. 81 | """ 82 | for r, dirs, files in os.walk(rootdir): 83 | for f in files: 84 | yield os.path.join(r, f) 85 | 86 | 87 | def get_dataset_path(*args): 88 | """ 89 | Get the path to some dataset under ``$TENSORPACK_DATASET``. 90 | 91 | Args: 92 | args: strings to be joined to form path. 93 | 94 | Returns: 95 | str: path to the dataset. 96 | """ 97 | d = os.environ.get('TENSORPACK_DATASET', None) 98 | if d is None: 99 | d = os.path.join(os.path.expanduser('~'), 'tensorpack_data') 100 | if execute_only_once(): 101 | logger.warn("Env var $TENSORPACK_DATASET not set, using {} for datasets.".format(d)) 102 | if not os.path.isdir(d): 103 | mkdir_p(d) 104 | logger.info("Created the directory {}.".format(d)) 105 | assert os.path.isdir(d), d 106 | return os.path.join(d, *args) 107 | 108 | 109 | def normpath(path): 110 | """ 111 | Normalizes a path to a folder by taking into consideration remote storages like Cloud storaged 112 | referenced by '://' at the beginning of the path. 113 | 114 | Args: 115 | args: path to be normalized. 116 | 117 | Returns: 118 | str: normalized path. 119 | """ 120 | return path if '://' in path else os.path.normpath(path) 121 | 122 | 123 | if __name__ == '__main__': 124 | download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.') 125 | -------------------------------------------------------------------------------- /dataflow/utils/loadcaffe.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: loadcaffe.py 3 | 4 | 5 | import numpy as np 6 | import os 7 | import sys 8 | 9 | from . 
import logger 10 | from .concurrency import subproc_call 11 | from .fs import download, get_dataset_path 12 | from .utils import change_env 13 | 14 | __all__ = ['load_caffe', 'get_caffe_pb'] 15 | 16 | CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto" 17 | 18 | 19 | class CaffeLayerProcessor(object): 20 | 21 | def __init__(self, net): 22 | self.net = net 23 | self.layer_names = net._layer_names 24 | self.param_dict = {} 25 | self.processors = { 26 | 'Convolution': self.proc_conv, 27 | 'InnerProduct': self.proc_fc, 28 | 'BatchNorm': self.proc_bn, 29 | 'Scale': self.proc_scale 30 | } 31 | 32 | def process(self): 33 | for idx, layer in enumerate(self.net.layers): 34 | param = layer.blobs 35 | name = self.layer_names[idx] 36 | if layer.type in self.processors: 37 | logger.info("Processing layer {} of type {}".format( 38 | name, layer.type)) 39 | dic = self.processors[layer.type](idx, name, param) 40 | self.param_dict.update(dic) 41 | elif len(layer.blobs) != 0: 42 | logger.warn( 43 | "{} layer contains parameters but is not supported!".format(layer.type)) 44 | return self.param_dict 45 | 46 | def proc_conv(self, idx, name, param): 47 | assert len(param) <= 2 48 | assert param[0].data.ndim == 4 49 | # caffe: ch_out, ch_in, h, w 50 | W = param[0].data.transpose(2, 3, 1, 0) 51 | if len(param) == 1: 52 | return {name + '/W': W} 53 | else: 54 | return {name + '/W': W, 55 | name + '/b': param[1].data} 56 | 57 | def proc_fc(self, idx, name, param): 58 | # TODO caffe has an 'transpose' option for fc/W 59 | assert len(param) == 2 60 | prev_layer_name = self.net.bottom_names[name][0] 61 | prev_layer_output = self.net.blobs[prev_layer_name].data 62 | if prev_layer_output.ndim == 4: 63 | logger.info("FC layer {} takes spatial data.".format(name)) 64 | W = param[0].data 65 | # original: outx(CxHxW) 66 | W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2, 3, 1, 0) 67 | # become: (HxWxC)xout 68 | else: 69 | W = param[0].data.transpose() 70 | return {name + '/W': W, 71 | name + '/b': param[1].data} 72 | 73 | def proc_bn(self, idx, name, param): 74 | scale_factor = param[2].data[0] 75 | return {name + '/mean/EMA': param[0].data / scale_factor, 76 | name + '/variance/EMA': param[1].data / scale_factor} 77 | 78 | def proc_scale(self, idx, name, param): 79 | bottom_name = self.net.bottom_names[name][0] 80 | # find the bn layer before this scaling 81 | for i, layer in enumerate(self.net.layers): 82 | if layer.type == 'BatchNorm': 83 | name2 = self.layer_names[i] 84 | bottom_name2 = self.net.bottom_names[name2][0] 85 | if bottom_name2 == bottom_name: 86 | # scaling and BN share the same bottom, should merge 87 | logger.info("Merge {} and {} into one BatchNorm layer".format( 88 | name, name2)) 89 | return {name2 + '/beta': param[1].data, 90 | name2 + '/gamma': param[0].data} 91 | # assume this scaling layer is part of some BN 92 | logger.error("Could not find a BN layer corresponding to this Scale layer!") 93 | raise ValueError() 94 | 95 | 96 | def load_caffe(model_desc, model_file): 97 | """ 98 | Load a caffe model. You must be able to ``import caffe`` to use this 99 | function. 100 | 101 | Args: 102 | model_desc (str): path to caffe model description file (.prototxt). 103 | model_file (str): path to caffe model parameter file (.caffemodel). 104 | Returns: 105 | dict: the parameters. 
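    Example (a sketch; file names are placeholders):

    .. code-block:: python

        params = load_caffe('deploy.prototxt', 'weights.caffemodel')
        np.savez_compressed('weights.npz', **params)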
106 | """ 107 | with change_env('GLOG_minloglevel', '2'): 108 | import caffe 109 | caffe.set_mode_cpu() 110 | net = caffe.Net(model_desc, model_file, caffe.TEST) 111 | param_dict = CaffeLayerProcessor(net).process() 112 | logger.info("Model loaded from caffe. Params: " + 113 | ", ".join(sorted(param_dict.keys()))) 114 | return param_dict 115 | 116 | 117 | def get_caffe_pb(): 118 | """ 119 | Get caffe protobuf. 120 | 121 | Returns: 122 | The imported caffe protobuf module. 123 | """ 124 | dir = get_dataset_path('caffe') 125 | caffe_pb_file = os.path.join(dir, 'caffe_pb2.py') 126 | if not os.path.isfile(caffe_pb_file): 127 | download(CAFFE_PROTO_URL, dir) 128 | assert os.path.isfile(os.path.join(dir, 'caffe.proto')) 129 | 130 | cmd = "protoc --version" 131 | version, ret = subproc_call(cmd, timeout=3) 132 | if ret != 0: 133 | sys.exit(1) 134 | try: 135 | version = version.decode('utf-8') 136 | version = float('.'.join(version.split(' ')[1].split('.')[:2])) 137 | assert version >= 2.7, "Require protoc>=2.7 for Python3" 138 | except Exception: 139 | logger.exception("protoc --version gives: " + str(version)) 140 | raise 141 | 142 | cmd = 'cd {} && protoc caffe.proto --python_out .'.format(dir) 143 | ret = os.system(cmd) 144 | assert ret == 0, \ 145 | "Command `{}` failed!".format(cmd) 146 | assert os.path.isfile(caffe_pb_file), caffe_pb_file 147 | import imp 148 | return imp.load_source('caffepb', caffe_pb_file) 149 | 150 | 151 | if __name__ == '__main__': 152 | import argparse 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('model', help='.prototxt file') 155 | parser.add_argument('weights', help='.caffemodel file') 156 | parser.add_argument('output', help='output npz file') 157 | args = parser.parse_args() 158 | ret = load_caffe(args.model, args.weights) 159 | 160 | if args.output.endswith('.npz'): 161 | np.savez_compressed(args.output, **ret) 162 | elif args.output.endswith('.npy'): 163 | logger.warn("Please use npz format instead!") 164 | np.save(args.output, ret) 165 | else: 166 | raise ValueError("Unknown format {}".format(args.output)) 167 | -------------------------------------------------------------------------------- /dataflow/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: logger.py 3 | 4 | """ 5 | The logger module itself has the common logging functions of Python's 6 | :class:`logging.Logger`. For example: 7 | 8 | .. 
code-block:: python 9 | 10 | from tensorpack.utils import logger 11 | logger.set_logger_dir('train_log/test') 12 | logger.info("Test") 13 | logger.error("Error happened!") 14 | """ 15 | 16 | 17 | import logging 18 | import os 19 | import os.path 20 | import shutil 21 | import sys 22 | from datetime import datetime 23 | from six.moves import input 24 | from termcolor import colored 25 | 26 | __all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir'] 27 | 28 | 29 | class _MyFormatter(logging.Formatter): 30 | def format(self, record): 31 | date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green') 32 | msg = '%(message)s' 33 | if record.levelno == logging.WARNING: 34 | fmt = date + ' ' + colored('WRN', 'red', attrs=['blink']) + ' ' + msg 35 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 36 | fmt = date + ' ' + colored('ERR', 'red', attrs=['blink', 'underline']) + ' ' + msg 37 | elif record.levelno == logging.DEBUG: 38 | fmt = date + ' ' + colored('DBG', 'yellow', attrs=['blink']) + ' ' + msg 39 | else: 40 | fmt = date + ' ' + msg 41 | if hasattr(self, '_style'): 42 | # Python3 compatibility 43 | self._style._fmt = fmt 44 | self._fmt = fmt 45 | return super(_MyFormatter, self).format(record) 46 | 47 | 48 | def _getlogger(): 49 | # this file is synced to "dataflow" package as well 50 | package_name = "dataflow" if __name__.startswith("dataflow") else "tensorpack" 51 | logger = logging.getLogger(package_name) 52 | logger.propagate = False 53 | logger.setLevel(logging.INFO) 54 | handler = logging.StreamHandler(sys.stdout) 55 | handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) 56 | logger.addHandler(handler) 57 | return logger 58 | 59 | 60 | _logger = _getlogger() 61 | _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel', 'addFilter'] 62 | # export logger functions 63 | for func in _LOGGING_METHOD: 64 | locals()[func] = getattr(_logger, func) 65 | __all__.append(func) 66 | # 'warn' is deprecated in logging module 67 | warn = _logger.warning 68 | __all__.append('warn') 69 | 70 | 71 | def _get_time_str(): 72 | return datetime.now().strftime('%m%d-%H%M%S') 73 | 74 | 75 | # globals: logger file and directory: 76 | LOG_DIR = None 77 | _FILE_HANDLER = None 78 | 79 | 80 | def _set_file(path): 81 | global _FILE_HANDLER 82 | if os.path.isfile(path): 83 | backup_name = path + '.' + _get_time_str() 84 | shutil.move(path, backup_name) 85 | _logger.info("Existing log file '{}' backuped to '{}'".format(path, backup_name)) # noqa: F821 86 | hdl = logging.FileHandler( 87 | filename=path, encoding='utf-8', mode='w') 88 | hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) 89 | 90 | _FILE_HANDLER = hdl 91 | _logger.addHandler(hdl) 92 | _logger.info("Argv: " + ' '.join(sys.argv)) 93 | 94 | 95 | def set_logger_dir(dirname, action=None): 96 | """ 97 | Set the directory for global logging. 98 | 99 | Args: 100 | dirname(str): log directory 101 | action(str): an action of ["k","d","q"] to be performed 102 | when the directory exists. Will ask user by default. 103 | 104 | "d": delete the directory. Note that the deletion may fail when 105 | the directory is used by tensorboard. 106 | 107 | "k": keep the directory. This is useful when you resume from a 108 | previous training and want the directory to look as if the 109 | training was not interrupted. 110 | Note that this option does not load old models or any other 111 | old states for you. It simply does nothing. 
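                "b": backup the existing directory by moving it to a
                timestamp-suffixed name.

                "n": use a new, timestamp-suffixed directory name instead
                of the given one.

                Any other key quits by raising an error.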
112 | 113 | """ 114 | dirname = os.path.normpath(dirname) 115 | global LOG_DIR, _FILE_HANDLER 116 | if _FILE_HANDLER: 117 | # unload and close the old file handler, so that we may safely delete the logger directory 118 | _logger.removeHandler(_FILE_HANDLER) 119 | del _FILE_HANDLER 120 | 121 | def dir_nonempty(dirname): 122 | # If directory exists and nonempty (ignore hidden files), prompt for action 123 | return os.path.isdir(dirname) and len([x for x in os.listdir(dirname) if x[0] != '.']) 124 | 125 | if dir_nonempty(dirname): 126 | if not action: 127 | _logger.warning("""\ 128 | Log directory {} exists! Use 'd' to delete it. """.format(dirname)) 129 | _logger.warning("""\ 130 | If you're resuming from a previous run, you can choose to keep it. 131 | Press any other key to exit. """) 132 | while not action: 133 | action = input("Select Action: k (keep) / d (delete) / q (quit):").lower().strip() 134 | act = action 135 | if act == 'b': 136 | backup_name = dirname + _get_time_str() 137 | shutil.move(dirname, backup_name) 138 | info("Directory '{}' backuped to '{}'".format(dirname, backup_name)) # noqa: F821 139 | elif act == 'd': 140 | shutil.rmtree(dirname, ignore_errors=True) 141 | if dir_nonempty(dirname): 142 | shutil.rmtree(dirname, ignore_errors=False) 143 | elif act == 'n': 144 | dirname = dirname + _get_time_str() 145 | info("Use a new log directory {}".format(dirname)) # noqa: F821 146 | elif act == 'k': 147 | pass 148 | else: 149 | raise OSError("Directory {} exits!".format(dirname)) 150 | LOG_DIR = dirname 151 | from .fs import mkdir_p 152 | mkdir_p(dirname) 153 | _set_file(os.path.join(dirname, 'log.log')) 154 | 155 | 156 | def auto_set_dir(action=None, name=None): 157 | """ 158 | Use :func:`logger.set_logger_dir` to set log directory to 159 | "./train_log/{scriptname}:{name}". "scriptname" is the name of the main python file currently running""" 160 | mod = sys.modules['__main__'] 161 | basename = os.path.basename(mod.__file__) 162 | auto_dirname = os.path.join('train_log', basename[:basename.rfind('.')]) 163 | if name: 164 | auto_dirname += '_%s' % name if os.name == 'nt' else ':%s' % name 165 | set_logger_dir(auto_dirname, action=action) 166 | 167 | 168 | def get_logger_dir(): 169 | """ 170 | Returns: 171 | The logger directory, or None if not set. 172 | The directory is used for general logging, tensorboard events, checkpoints, etc. 173 | """ 174 | return LOG_DIR 175 | -------------------------------------------------------------------------------- /dataflow/utils/serialize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: serialize.py 3 | 4 | import os 5 | 6 | import pickle 7 | from multiprocessing.reduction import ForkingPickler 8 | import msgpack 9 | import msgpack_numpy 10 | 11 | msgpack_numpy.patch() 12 | assert msgpack.version >= (0, 5, 2) 13 | 14 | __all__ = ['loads', 'dumps'] 15 | 16 | 17 | MAX_MSGPACK_LEN = 1000000000 18 | 19 | 20 | class MsgpackSerializer(object): 21 | 22 | @staticmethod 23 | def dumps(obj): 24 | """ 25 | Serialize an object. 26 | 27 | Returns: 28 | Implementation-dependent bytes-like object. 29 | """ 30 | return msgpack.dumps(obj, use_bin_type=True) 31 | 32 | @staticmethod 33 | def loads(buf): 34 | """ 35 | Args: 36 | buf: the output of `dumps`. 37 | """ 38 | # Since 0.6, the default max size was set to 1MB. 39 | # We change it to approximately 1G. 
40 | return msgpack.loads(buf, raw=False, 41 | max_bin_len=MAX_MSGPACK_LEN, 42 | max_array_len=MAX_MSGPACK_LEN, 43 | max_map_len=MAX_MSGPACK_LEN, 44 | max_str_len=MAX_MSGPACK_LEN) 45 | 46 | 47 | class PyarrowSerializer(object): 48 | @staticmethod 49 | def dumps(obj): 50 | """ 51 | Serialize an object. 52 | 53 | Returns: 54 | Implementation-dependent bytes-like object. 55 | May not be compatible across different versions of pyarrow. 56 | """ 57 | import pyarrow as pa 58 | return pa.serialize(obj).to_buffer() 59 | 60 | @staticmethod 61 | def dumps_bytes(obj): 62 | """ 63 | Returns: 64 | bytes 65 | """ 66 | return PyarrowSerializer.dumps(obj).to_pybytes() 67 | 68 | @staticmethod 69 | def loads(buf): 70 | """ 71 | Args: 72 | buf: the output of `dumps` or `dumps_bytes`. 73 | """ 74 | import pyarrow as pa 75 | return pa.deserialize(buf) 76 | 77 | 78 | class PickleSerializer(object): 79 | @staticmethod 80 | def dumps(obj): 81 | """ 82 | Returns: 83 | bytes 84 | """ 85 | return pickle.dumps(obj, protocol=-1) 86 | 87 | @staticmethod 88 | def loads(buf): 89 | """ 90 | Args: 91 | bytes 92 | """ 93 | return pickle.loads(buf) 94 | 95 | 96 | # Define the default serializer to be used that dumps data to bytes 97 | _DEFAULT_S = os.environ.get('TENSORPACK_SERIALIZE', 'pickle') 98 | 99 | if _DEFAULT_S == "pyarrow": 100 | dumps = PyarrowSerializer.dumps_bytes 101 | loads = PyarrowSerializer.loads 102 | elif _DEFAULT_S == "pickle": 103 | dumps = PickleSerializer.dumps 104 | loads = PickleSerializer.loads 105 | else: 106 | dumps = MsgpackSerializer.dumps 107 | loads = MsgpackSerializer.loads 108 | 109 | # Define the default serializer to be used for passing data 110 | # among a pair of peers. In this case the deserialization is 111 | # known to happen only once 112 | _DEFAULT_S = os.environ.get('TENSORPACK_ONCE_SERIALIZE', 'pickle') 113 | 114 | if _DEFAULT_S == "pyarrow": 115 | dumps_once = PyarrowSerializer.dumps 116 | loads_once = PyarrowSerializer.loads 117 | elif _DEFAULT_S == "pickle": 118 | dumps_once = ForkingPickler.dumps 119 | loads_once = ForkingPickler.loads 120 | else: 121 | dumps_once = MsgpackSerializer.dumps 122 | loads_once = MsgpackSerializer.loads 123 | -------------------------------------------------------------------------------- /dataflow/utils/stats.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: stats.py 3 | 4 | import numpy as np 5 | 6 | __all__ = ['StatCounter', 'BinaryStatistics', 'RatioCounter', 'Accuracy', 7 | 'OnlineMoments'] 8 | 9 | 10 | class StatCounter(object): 11 | """ A simple counter""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def feed(self, v): 17 | """ 18 | Args: 19 | v(float or np.ndarray): has to be the same shape between calls. 20 | """ 21 | self._values.append(v) 22 | 23 | def reset(self): 24 | self._values = [] 25 | 26 | @property 27 | def count(self): 28 | return len(self._values) 29 | 30 | @property 31 | def average(self): 32 | assert len(self._values) 33 | return np.mean(self._values) 34 | 35 | @property 36 | def sum(self): 37 | assert len(self._values) 38 | return np.sum(self._values) 39 | 40 | @property 41 | def max(self): 42 | assert len(self._values) 43 | return max(self._values) 44 | 45 | @property 46 | def min(self): 47 | assert len(self._values) 48 | return min(self._values) 49 | 50 | def samples(self): 51 | """ 52 | Returns all samples. 53 | """ 54 | return self._values 55 | 56 | 57 | class RatioCounter(object): 58 | """ A counter to count ratio of something. 
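
    An illustrative sketch (the counts below are made up):

    .. code-block:: python

        ratio = RatioCounter()
        ratio.feed(3, total=10)   # e.g. 3 hits out of 10 events
        ratio.feed(5, total=10)
        ratio.ratio    # 0.4
        ratio.count    # 8
        ratio.total    # 20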
""" 59 | 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self._tot = 0 65 | self._cnt = 0 66 | 67 | def feed(self, count, total=1): 68 | """ 69 | Args: 70 | cnt(int): the count of some event of interest. 71 | tot(int): the total number of events. 72 | """ 73 | self._tot += total 74 | self._cnt += count 75 | 76 | @property 77 | def ratio(self): 78 | if self._tot == 0: 79 | return 0 80 | return self._cnt * 1.0 / self._tot 81 | 82 | @property 83 | def total(self): 84 | """ 85 | Returns: 86 | int: the total 87 | """ 88 | return self._tot 89 | 90 | @property 91 | def count(self): 92 | """ 93 | Returns: 94 | int: the total 95 | """ 96 | return self._cnt 97 | 98 | 99 | class Accuracy(RatioCounter): 100 | """ A RatioCounter with a fancy name """ 101 | @property 102 | def accuracy(self): 103 | return self.ratio 104 | 105 | 106 | class BinaryStatistics(object): 107 | """ 108 | Statistics for binary decision, 109 | including precision, recall, false positive, false negative 110 | """ 111 | 112 | def __init__(self): 113 | self.reset() 114 | 115 | def reset(self): 116 | self.nr_pos = 0 # positive label 117 | self.nr_neg = 0 # negative label 118 | self.nr_pred_pos = 0 119 | self.nr_pred_neg = 0 120 | self.corr_pos = 0 # correct predict positive 121 | self.corr_neg = 0 # correct predict negative 122 | 123 | def feed(self, pred, label): 124 | """ 125 | Args: 126 | pred (np.ndarray): binary array. 127 | label (np.ndarray): binary array of the same size. 128 | """ 129 | assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape) 130 | self.nr_pos += (label == 1).sum() 131 | self.nr_neg += (label == 0).sum() 132 | self.nr_pred_pos += (pred == 1).sum() 133 | self.nr_pred_neg += (pred == 0).sum() 134 | self.corr_pos += ((pred == 1) & (pred == label)).sum() 135 | self.corr_neg += ((pred == 0) & (pred == label)).sum() 136 | 137 | @property 138 | def precision(self): 139 | if self.nr_pred_pos == 0: 140 | return 0 141 | return self.corr_pos * 1. / self.nr_pred_pos 142 | 143 | @property 144 | def recall(self): 145 | if self.nr_pos == 0: 146 | return 0 147 | return self.corr_pos * 1. / self.nr_pos 148 | 149 | @property 150 | def false_positive(self): 151 | if self.nr_pred_pos == 0: 152 | return 0 153 | return 1 - self.precision 154 | 155 | @property 156 | def false_negative(self): 157 | if self.nr_pos == 0: 158 | return 0 159 | return 1 - self.recall 160 | 161 | 162 | class OnlineMoments(object): 163 | """Compute 1st and 2nd moments online (to avoid storing all elements). 164 | 165 | See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm 166 | """ 167 | 168 | def __init__(self): 169 | self._mean = 0 170 | self._M2 = 0 171 | self._n = 0 172 | 173 | def feed(self, x): 174 | """ 175 | Args: 176 | x (float or np.ndarray): must have the same shape. 
177 | """ 178 | self._n += 1 179 | delta = x - self._mean 180 | self._mean += delta * (1.0 / self._n) 181 | delta2 = x - self._mean 182 | self._M2 += delta * delta2 183 | 184 | @property 185 | def mean(self): 186 | return self._mean 187 | 188 | @property 189 | def variance(self): 190 | return self._M2 / (self._n - 1) 191 | 192 | @property 193 | def std(self): 194 | return np.sqrt(self.variance) 195 | -------------------------------------------------------------------------------- /dataflow/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: timer.py 3 | 4 | 5 | import atexit 6 | from collections import defaultdict 7 | from contextlib import contextmanager 8 | from time import perf_counter as timer # noqa 9 | 10 | from . import logger 11 | from .stats import StatCounter 12 | 13 | 14 | __all__ = ['timed_operation', 'IterSpeedCounter', 'Timer'] 15 | 16 | 17 | @contextmanager 18 | def timed_operation(msg, log_start=False): 19 | """ 20 | Surround a context with a timer. 21 | 22 | Args: 23 | msg(str): the log to print. 24 | log_start(bool): whether to print also at the beginning. 25 | 26 | Example: 27 | .. code-block:: python 28 | 29 | with timed_operation('Good Stuff'): 30 | time.sleep(1) 31 | 32 | Will print: 33 | 34 | .. code-block:: python 35 | 36 | Good stuff finished, time:1sec. 37 | """ 38 | assert len(msg) 39 | if log_start: 40 | logger.info('Start {} ...'.format(msg)) 41 | start = timer() 42 | yield 43 | msg = msg[0].upper() + msg[1:] 44 | logger.info('{} finished, time:{:.4f} sec.'.format( 45 | msg, timer() - start)) 46 | 47 | 48 | _TOTAL_TIMER_DATA = defaultdict(StatCounter) 49 | 50 | 51 | @contextmanager 52 | def total_timer(msg): 53 | """ A context which add the time spent inside to the global TotalTimer. """ 54 | start = timer() 55 | yield 56 | t = timer() - start 57 | _TOTAL_TIMER_DATA[msg].feed(t) 58 | 59 | 60 | def print_total_timer(): 61 | """ 62 | Print the content of the global TotalTimer, if it's not empty. This function will automatically get 63 | called when program exits. 64 | """ 65 | if len(_TOTAL_TIMER_DATA) == 0: 66 | return 67 | for k, v in _TOTAL_TIMER_DATA.items(): 68 | logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format( 69 | k, v.sum, v.count, v.average)) 70 | 71 | 72 | atexit.register(print_total_timer) 73 | 74 | 75 | class IterSpeedCounter(object): 76 | """ Test how often some code gets reached. 77 | 78 | Example: 79 | Print the speed of the iteration every 100 times. 80 | 81 | .. code-block:: python 82 | 83 | speed = IterSpeedCounter(100) 84 | for k in range(1000): 85 | # do something 86 | speed() 87 | """ 88 | 89 | def __init__(self, print_every, name=None): 90 | """ 91 | Args: 92 | print_every(int): interval to print. 93 | name(str): name to used when print. 94 | """ 95 | self.cnt = 0 96 | self.print_every = int(print_every) 97 | self.name = name if name else 'IterSpeed' 98 | 99 | def reset(self): 100 | self.start = timer() 101 | 102 | def __call__(self): 103 | if self.cnt == 0: 104 | self.reset() 105 | self.cnt += 1 106 | if self.cnt % self.print_every != 0: 107 | return 108 | t = timer() - self.start 109 | logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format( 110 | self.name, t, self.cnt, t / self.cnt)) 111 | 112 | 113 | class Timer(): 114 | """ 115 | A timer class which computes the time elapsed since the start/reset of the timer. 
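
    A minimal usage sketch:

    .. code-block:: python

        t = Timer()
        # ... timed work ...
        t.pause()
        # ... work excluded from the measurement ...
        t.resume()
        print(t.seconds())   # elapsed time, minus the paused interval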
116 | """ 117 | def __init__(self): 118 | self.reset() 119 | 120 | def reset(self): 121 | """ 122 | Reset the timer. 123 | """ 124 | self._start = timer() 125 | self._paused = False 126 | self._total_paused = 0 127 | 128 | def pause(self): 129 | """ 130 | Pause the timer. 131 | """ 132 | assert self._paused is False 133 | self._paused = timer() 134 | 135 | def is_paused(self): 136 | return self._paused is not False 137 | 138 | def resume(self): 139 | """ 140 | Resume the timer. 141 | """ 142 | assert self._paused is not False 143 | self._total_paused += timer() - self._paused 144 | self._paused = False 145 | 146 | def seconds(self): 147 | """ 148 | Returns: 149 | float: the total number of seconds since the start/reset of the timer, excluding the 150 | time in between when the timer is paused. 151 | """ 152 | if self._paused: 153 | self.resume() 154 | self.pause() 155 | return timer() - self._start - self._total_paused 156 | -------------------------------------------------------------------------------- /dataflow/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: utils.py 3 | 4 | 5 | import inspect 6 | import numpy as np 7 | import re 8 | import os 9 | import sys 10 | from contextlib import contextmanager 11 | from datetime import datetime, timedelta 12 | from tqdm import tqdm 13 | 14 | from . import logger 15 | from .concurrency import subproc_call 16 | 17 | __all__ = ['change_env', 18 | 'get_rng', 19 | 'fix_rng_seed', 20 | 'get_tqdm', 21 | 'execute_only_once', 22 | 'humanize_time_delta' 23 | ] 24 | 25 | 26 | def humanize_time_delta(sec): 27 | """Humanize timedelta given in seconds 28 | 29 | Args: 30 | sec (float): time difference in seconds. Must be positive. 31 | 32 | Returns: 33 | str - time difference as a readable string 34 | 35 | Example: 36 | 37 | .. code-block:: python 38 | 39 | print(humanize_time_delta(1)) # 1 second 40 | print(humanize_time_delta(60 + 1)) # 1 minute 1 second 41 | print(humanize_time_delta(87.6)) # 1 minute 27 seconds 42 | print(humanize_time_delta(0.01)) # 0.01 seconds 43 | print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second 44 | print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second 45 | print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds 46 | """ 47 | if sec < 0: 48 | logger.warn("humanize_time_delta() obtains negative seconds!") 49 | return "{:.3g} seconds".format(sec) 50 | if sec == 0: 51 | return "0 second" 52 | time = datetime(2000, 1, 1) + timedelta(seconds=int(sec)) 53 | units = ['day', 'hour', 'minute', 'second'] 54 | vals = [int(sec // 86400), time.hour, time.minute, time.second] 55 | if sec < 60: 56 | vals[-1] = sec 57 | 58 | def _format(v, u): 59 | return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "") 60 | 61 | ans = [] 62 | for v, u in zip(vals, units): 63 | if v > 0: 64 | ans.append(_format(v, u)) 65 | return " ".join(ans) 66 | 67 | 68 | @contextmanager 69 | def change_env(name, val): 70 | """ 71 | Args: 72 | name(str): name of the env var 73 | val(str or None): the value, or set to None to clear the env var. 74 | 75 | Returns: 76 | a context where the environment variable ``name`` being set to 77 | ``val``. It will be set back after the context exits. 
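
    Example (the variable name and value are only illustrative):

    .. code-block:: python

        with change_env('CUDA_VISIBLE_DEVICES', '0'):
            pass   # code here sees CUDA_VISIBLE_DEVICES=0
        # the previous value (or absence) is restored at this point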
78 |     """
79 |     oldval = os.environ.get(name, None)
80 | 
81 |     if val is None:
82 |         try:
83 |             del os.environ[name]
84 |         except KeyError:
85 |             pass
86 |     else:
87 |         os.environ[name] = val
88 | 
89 |     yield
90 | 
91 |     if oldval is None:
92 |         try:
93 |             del os.environ[name]
94 |         except KeyError:
95 |             pass
96 |     else:
97 |         os.environ[name] = oldval
98 | 
99 | 
100 | _RNG_SEED = None
101 | 
102 | 
103 | def fix_rng_seed(seed):
104 |     """
105 |     Call this function at the beginning of the program to fix the rng seed within tensorpack.
106 | 
107 |     Args:
108 |         seed (int):
109 | 
110 |     Note:
111 |         See https://github.com/tensorpack/tensorpack/issues/196.
112 | 
113 |     Example:
114 | 
115 |         Fix random seed in both tensorpack and tensorflow.
116 | 
117 |         .. code-block:: python
118 | 
119 |             seed = 42
120 |             utils.fix_rng_seed(seed)
121 |             tensorflow.set_random_seed(seed)
122 |             # run trainer
123 |     """
124 |     global _RNG_SEED
125 |     _RNG_SEED = int(seed)
126 | 
127 | 
128 | def get_rng(obj=None):
129 |     """
130 |     Get a good RNG seeded with time, pid and the object.
131 | 
132 |     Args:
133 |         obj: some object to use to generate the random seed.
134 |     Returns:
135 |         np.random.RandomState: the RNG.
136 |     """
137 |     seed = (id(obj) + os.getpid() +
138 |             int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
139 |     if _RNG_SEED is not None:
140 |         seed = _RNG_SEED
141 |     return np.random.RandomState(seed)
142 | 
143 | 
144 | _EXECUTE_HISTORY = set()
145 | 
146 | 
147 | def execute_only_once():
148 |     """
149 |     Each call site of this function is guaranteed to return True the
150 |     first time it is reached, and False afterwards.
151 | 
152 |     Returns:
153 |         bool: whether this is the first time this function gets called from this line of code.
154 | 
155 |     Example:
156 |         .. code-block:: python
157 | 
158 |             if execute_only_once():
159 |                 # do something only once
160 |     """
161 |     f = inspect.currentframe().f_back
162 |     ident = (f.f_code.co_filename, f.f_lineno)
163 |     if ident in _EXECUTE_HISTORY:
164 |         return False
165 |     _EXECUTE_HISTORY.add(ident)
166 |     return True
167 | 
168 | 
169 | def _pick_tqdm_interval(file):
170 |     # Heuristics to pick an update interval for the progress bar that's nice-looking for users.
171 |     isatty = file.isatty()
172 |     # Jupyter notebook should be recognized as tty.
173 |     # Wait for https://github.com/ipython/ipykernel/issues/268
174 |     try:
175 |         from ipykernel import iostream
176 |         if isinstance(file, iostream.OutStream):
177 |             isatty = True
178 |     except ImportError:
179 |         pass
180 | 
181 |     if isatty:
182 |         return 0.5
183 |     else:
184 |         # When run under mpirun/slurm, isatty is always False.
185 |         # Here we apply some hacky heuristics for slurm.
186 |         if 'SLURM_JOB_ID' in os.environ:
187 |             if int(os.environ.get('SLURM_JOB_NUM_NODES', 1)) > 1:
188 |                 # multi-machine job, probably not interactive
189 |                 return 60
190 |             else:
191 |                 # possibly interactive, so let's be conservative
192 |                 return 15
193 | 
194 |         if 'OMPI_COMM_WORLD_SIZE' in os.environ:
195 |             return 60
196 | 
197 |         # If not a tty, don't refresh progress bar that often
198 |         return 180
199 | 
200 | 
201 | def get_tqdm_kwargs(**kwargs):
202 |     """
203 |     Return default arguments to be used with tqdm.
204 | 
205 |     Args:
206 |         kwargs: extra arguments to be used.
207 | Returns: 208 | dict: 209 | """ 210 | default = dict( 211 | smoothing=0.5, 212 | dynamic_ncols=True, 213 | ascii=True, 214 | bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_noinv_fmt}]' 215 | ) 216 | 217 | try: 218 | # Use this env var to override the refresh interval setting 219 | interval = float(os.environ['TENSORPACK_PROGRESS_REFRESH']) 220 | except KeyError: 221 | interval = _pick_tqdm_interval(kwargs.get('file', sys.stderr)) 222 | 223 | default['mininterval'] = interval 224 | default.update(kwargs) 225 | return default 226 | 227 | 228 | def get_tqdm(*args, **kwargs): 229 | """ Similar to :func:`tqdm.tqdm()`, 230 | but use tensorpack's default options to have consistent style. """ 231 | return tqdm(*args, **get_tqdm_kwargs(**kwargs)) 232 | 233 | 234 | def find_library_full_path(name): 235 | """ 236 | Similar to `from ctypes.util import find_library`, but try 237 | to return full path if possible. 238 | """ 239 | from ctypes.util import find_library 240 | 241 | if os.name == "posix" and sys.platform == "darwin": 242 | # on Mac, ctypes already returns full path 243 | return find_library(name) 244 | 245 | def _use_proc_maps(name): 246 | """ 247 | Find so from /proc/pid/maps 248 | Only works with libraries that has already been loaded. 249 | But this is the most accurate method -- it finds the exact library that's being used. 250 | """ 251 | procmap = os.path.join('/proc', str(os.getpid()), 'maps') 252 | if not os.path.isfile(procmap): 253 | return None 254 | try: 255 | with open(procmap, 'r') as f: 256 | for line in f: 257 | line = line.strip().split(' ') 258 | sofile = line[-1] 259 | 260 | basename = os.path.basename(sofile) 261 | if 'lib' + name + '.so' in basename: 262 | if os.path.isfile(sofile): 263 | return os.path.realpath(sofile) 264 | except IOError: 265 | # can fail in certain environment (e.g. chroot) 266 | # if the pids are incorrectly mapped 267 | pass 268 | 269 | # The following two methods come from https://github.com/python/cpython/blob/master/Lib/ctypes/util.py 270 | def _use_ld(name): 271 | """ 272 | Find so with `ld -lname -Lpath`. 273 | It will search for files in LD_LIBRARY_PATH, but not in ldconfig. 274 | """ 275 | cmd = "ld -t -l{} -o {}".format(name, os.devnull) 276 | ld_lib_path = os.environ.get('LD_LIBRARY_PATH', '') 277 | for d in ld_lib_path.split(':'): 278 | cmd = cmd + " -L " + d 279 | result, ret = subproc_call(cmd + '|| true') 280 | expr = r'[^\(\)\s]*lib%s\.[^\(\)\s]*' % re.escape(name) 281 | res = re.search(expr, result.decode('utf-8')) 282 | if res: 283 | res = res.group(0) 284 | if not os.path.isfile(res): 285 | return None 286 | return os.path.realpath(res) 287 | 288 | def _use_ldconfig(name): 289 | """ 290 | Find so in `ldconfig -p`. 291 | It does not handle LD_LIBRARY_PATH. 
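
        A matching line of ``ldconfig -p`` output looks roughly like
        (illustrative path)::

            libjpeg.so.8 (libc6,x86-64) => /lib/x86_64-linux-gnu/libjpeg.so.8

        and the second regex group below captures the resolved path.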
292 | """ 293 | with change_env('LC_ALL', 'C'), change_env('LANG', 'C'): 294 | ldconfig, ret = subproc_call("ldconfig -p") 295 | ldconfig = ldconfig.decode('utf-8') 296 | if ret != 0: 297 | return None 298 | expr = r'\s+(lib%s\.[^\s]+)\s+\(.*=>\s+(.*)' % (re.escape(name)) 299 | res = re.search(expr, ldconfig) 300 | if not res: 301 | return None 302 | else: 303 | ret = res.group(2) 304 | return os.path.realpath(ret) 305 | 306 | if sys.platform.startswith('linux'): 307 | return _use_proc_maps(name) or _use_ld(name) or _use_ldconfig(name) or find_library(name) 308 | 309 | return find_library(name) # don't know what to do 310 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import setuptools 3 | from setuptools import setup, find_packages 4 | 5 | version = int(setuptools.__version__.split('.')[0]) 6 | assert version > 30, "Dataflow installation requires setuptools > 30" 7 | 8 | this_directory = path.abspath(path.dirname(__file__)) 9 | 10 | 11 | # setup metainfo 12 | 13 | with open(path.join(this_directory, 'README.md'), 'rb') as f: 14 | long_description = f.read().decode('utf-8') 15 | __version__ = '0.9.5' 16 | 17 | 18 | setup( 19 | name='dataflow', 20 | author="TensorPack contributors", 21 | author_email="ppwwyyxxc@gmail.com", 22 | url="https://github.com/tensorpack/dataflow", 23 | keywords="deep learning, neural network, data processing", 24 | license="Apache", 25 | 26 | version=__version__, # noqa 27 | description='', 28 | long_description=long_description, 29 | long_description_content_type='text/markdown', 30 | 31 | packages=find_packages(exclude=["examples", "tests"]), 32 | zip_safe=False, # dataset and __init__ use file 33 | 34 | install_requires=[ 35 | "numpy>=1.14", 36 | "six", 37 | "termcolor>=1.1", 38 | "tabulate>=0.7.7", 39 | "tqdm>4.29.0", 40 | "msgpack>=0.5.2", 41 | "msgpack-numpy>=0.4.4.2", 42 | "pyzmq>=16", 43 | "psutil>=5", 44 | "subprocess32; python_version < '3.0'", 45 | "functools32; python_version < '3.0'", 46 | ], 47 | tests_require=['flake8'], 48 | extras_require={ 49 | 'all': ['scipy', 'h5py', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'], 50 | 'all: "linux" in sys_platform': ['python-prctl'], 51 | }, 52 | 53 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#universal-wheels 54 | options={'bdist_wheel': {'universal': '1'}}, 55 | ) 56 | -------------------------------------------------------------------------------- /sync.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | 4 | """ 5 | This script sync commits from tensorpack to dataflow. 
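
Usage (a sketch; the positional argument is the path to a local tensorpack
checkout):

    python sync.py /path/to/tensorpack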
6 | 
7 | Dependencies:
8 |     gitpython
9 | """
10 | 
11 | import os
12 | import sys
13 | import shutil
14 | from datetime import datetime
15 | import git
16 | import argparse
17 | from dataflow.utils import logger
18 | 
19 | 
20 | DF_ROOT = os.path.dirname(__file__)
21 | UTILS_TO_SYNC = [
22 |     'argtools', 'concurrency', 'compatible_serialize', 'develop',
23 |     'fs', '__init__', 'loadcaffe', 'logger', 'serialize', 'stats',
24 |     'timer', 'utils']
25 | 
26 | 
27 | def match_commit(tp_commit, df_commit):
28 |     df_commit_message = df_commit.message.strip().strip('"').strip()
29 |     tp_commit_message = tp_commit.message.strip()
30 |     if (tp_commit_message == df_commit_message) and \
31 |             (tp_commit.authored_date == df_commit.authored_date):
32 |         return True
33 |     return False
34 | 
35 | 
36 | def show_commit(commit):
37 |     return commit.repo.git.show('-s', commit.hexsha, '--color=always')
38 | 
39 | 
40 | def show_date(timestamp):
41 |     return datetime.fromtimestamp(timestamp).strftime("%Y/%m/%d-%H:%M:%S")
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument('tensorpack')
47 |     args = parser.parse_args()
48 |     tp_root = args.tensorpack
49 | 
50 |     tp_repo = git.Repo(tp_root)
51 |     assert tp_repo.active_branch.name == 'master'
52 |     if tp_repo.is_dirty():
53 |         print("Warning: tensorpack repo is not clean!")
54 | 
55 |     df_repo = git.Repo(DF_ROOT)
56 |     assert not df_repo.is_dirty()
57 |     df_commits = df_repo.iter_commits(paths='dataflow')
58 |     df_latest_commit = next(df_commits)
59 |     logger.info("DataFlow commit to match: \n" + show_commit(df_latest_commit))
60 | 
61 |     unsynced_commits = []
62 |     paths_to_sync = ['tensorpack/dataflow'] + ['tensorpack/utils/{}.py'.format(u) for u in UTILS_TO_SYNC]
63 |     for cmt in tp_repo.iter_commits(paths=paths_to_sync):
64 |         if match_commit(cmt, df_latest_commit):
65 |             logger.info("Matched tensorpack commit: \n" + show_commit(cmt))
66 |             break
67 |         unsynced_commits.append(cmt)
68 |     else:
69 |         logger.error("Cannot find a tensorpack commit that matches the above commit.")
70 |         sys.exit(1)
71 |     logger.info("{} more commits to sync".format(len(unsynced_commits)))
72 | 
73 |     unsynced_commits = unsynced_commits[::-1]
74 | 
75 |     try:
76 |         for commit_to_sync in unsynced_commits:
77 |             tp_repo.git.checkout(commit_to_sync.hexsha)
78 |             logger.info("-" * 60)
79 |             logger.info("Syncing commit '{}' at {}".format(
80 |                 commit_to_sync.message.strip(), show_date(commit_to_sync.authored_date)))
81 | 
82 |             # sync files
83 |             dst = os.path.join(DF_ROOT, 'dataflow', 'dataflow')
84 |             logger.info("Syncing {} ...".format(dst))
85 |             shutil.rmtree(dst)
86 |             shutil.copytree(os.path.join(tp_root, 'tensorpack', 'dataflow'), dst)
87 | 
88 |             logger.info("Syncing utils ...")
89 |             for util in UTILS_TO_SYNC:
90 |                 dst = os.path.join(DF_ROOT, 'dataflow', 'utils', util + '.py')
91 |                 src = os.path.join(tp_root, 'tensorpack', 'utils', util + '.py')
92 |                 os.unlink(dst)
93 |                 shutil.copy2(src, dst)
94 | 
95 |             author = "\"{} <{}>\"".format(commit_to_sync.author.name, commit_to_sync.author.email)
96 |             log = df_repo.git.commit(
97 |                 '--all',
98 |                 message='{}'.format(commit_to_sync.message.strip()),
99 |                 date=commit_to_sync.authored_date,
100 |                 author=author)
101 |             logger.info("Successfully synced commit:\n" + log)
102 |     finally:
103 |         tp_repo.git.checkout('master')
104 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | # See
https://pep8.readthedocs.io/en/latest/intro.html#error-codes 4 | ignore = E265,E741,E742,E743,W504,W605 5 | exclude = .git, 6 | __init__.py, 7 | setup.py, 8 | _test.py, 9 | examples, 10 | 11 | [isort] 12 | line_length=100 13 | multi_line_output=4 14 | known_tensorpack=tensorpack 15 | known_standard_library=numpy 16 | known_third_party=gym,matplotlib 17 | no_lines_before=STDLIB,THIRDPARTY 18 | sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER 19 | --------------------------------------------------------------------------------