├── .gitignore ├── LICENSE ├── README.md ├── figs ├── arch_detail.png ├── l3c_code_outline.jpg └── teaser.png ├── pip_requirements.txt └── src ├── auto_crop.py ├── bitcoding ├── __init__.py ├── bitcoding.py ├── coders.py ├── coders_helpers.py └── part_suffix_helper.py ├── blueprints ├── __init__.py └── multiscale_blueprint.py ├── configs ├── dl │ ├── in32.cf │ ├── in64.cf │ └── oi.cf └── ms │ ├── cr.cf │ ├── cr_rgb.cf │ └── cr_rgb_shared.cf ├── criterion ├── __init__.py └── logistic_mixture.py ├── dataloaders ├── __init__.py └── images_loader.py ├── helpers ├── __init__.py ├── aligned_printer.py ├── config_checker.py ├── global_config.py ├── logdir_helpers.py ├── pad.py ├── paths.py ├── rolling_buffer.py ├── saver.py └── testset.py ├── import_train_images.py ├── import_train_images_v1.py ├── l3c.py ├── modules ├── __init__.py ├── edsr.py ├── head.py ├── multiscale_network.py ├── net.py ├── prob_clf.py └── quantizer.py ├── prep_openimages.sh ├── pytorch_ext.py ├── test.py ├── test ├── __init__.py ├── cuda_timer.py ├── image_saver.py └── multiscale_tester.py ├── tmp_discard_list ├── torchac ├── setup.py ├── torchac.py └── torchac_backend │ ├── torchac.cpp │ └── torchac_kernel.cu ├── train.py ├── train ├── __init__.py ├── lr_schedule.py ├── multiscale_trainer.py ├── train_restorer.py └── trainer.py └── vis ├── __init__.py ├── figure_plotter.py ├── grid.py ├── histogram_plot.py ├── histogram_plotter.py ├── image_summaries.py ├── safe_summary_writer.py └── summarizable_module.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /figs/arch_detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/arch_detail.png -------------------------------------------------------------------------------- /figs/l3c_code_outline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/l3c_code_outline.jpg -------------------------------------------------------------------------------- /figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/teaser.png -------------------------------------------------------------------------------- /pip_requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | scipy==1.1.0 3 | fasteners==0.14.1 4 | fjcommon==0.2.10 5 | tensorboardX==1.2 6 | matplotlib==2.2.2 7 | -------------------------------------------------------------------------------- /src/auto_crop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for cropping image depending on resolution. 3 | 4 | TODO: replace recursive code with something more adaptive, right now we only do 5 | - no crops / 2x2 / 4x4 / 16x16 / etc. 6 | and the stitching code is complicated, but would be nice to have e.g. 3x3. 7 | """ 8 | import math 9 | import os 10 | 11 | import itertools 12 | 13 | import torch 14 | 15 | from blueprints.multiscale_blueprint import MultiscaleLoss 16 | import functools 17 | import operator 18 | 19 | 20 | def prod(it): 21 | return functools.reduce(operator.mul, it, 1) 22 | 23 | 24 | # Images with H * W > prod(_NEEDS_CROP_DIM) will be split into crops 25 | # We set this empirically such that crops fit into our TITAN X (Pascal) with 12GB VRAM. 26 | # You can set this from the console using AC_NEEDS_CROP_DIM, e.g., 27 | # 28 | # AC_NEEDS_CROP_DIM=2000,2000 python test.py ... 29 | # 30 | # But expect OOM errors for big values. 31 | _NEEDS_CROP_DIM_DEFAULT = '2000,1500' 32 | _NEEDS_CROP_DIM = os.environ.get('AC_NEEDS_CROP_DIM', _NEEDS_CROP_DIM_DEFAULT) 33 | if _NEEDS_CROP_DIM != _NEEDS_CROP_DIM_DEFAULT: 34 | print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM) 35 | _NEEDS_CROP_DIM = prod(map(int, _NEEDS_CROP_DIM.split(','))) 36 | print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM) 37 | 38 | 39 | def _assert_valid_image(i): 40 | if len(i.shape) != 4 or i.shape[1] != 3: 41 | raise ValueError(f'Expected BCHW image, got {i.shape}') 42 | 43 | 44 | def needs_crop(img, needs_crop_dim=_NEEDS_CROP_DIM): 45 | _assert_valid_image(img) 46 | H, W = img.shape[-2:] 47 | return H * W > needs_crop_dim 48 | 49 | 50 | def _crop16(im): 51 | for im_cropped in _crop4(im): 52 | yield from _crop4(im_cropped) 53 | 54 | 55 | def iter_crops(img, needs_crop_dim=_NEEDS_CROP_DIM): 56 | _assert_valid_image(img) 57 | 58 | if not needs_crop(img, needs_crop_dim): 59 | yield img 60 | return 61 | for img_crop in _crop4(img): 62 | yield from iter_crops(img_crop, needs_crop_dim) 63 | 64 | 65 | def _crop4(img): 66 | _assert_valid_image(img) 67 | H, W = img.shape[-2:] 68 | imgs = [img[..., :H//2, :W//2], # Top left 69 | img[..., :H//2, W//2:], # Top right 70 | img[..., H//2:, :W//2], # Bottom left 71 | img[..., H//2:, W//2:]] # Bottom right 72 | # Validate that we got all pixels 73 | assert sum(prod(img.shape[-2:]) for img in imgs) == \ 74 | prod(img.shape[-2:]) 75 | return imgs 76 | 77 | 78 | def _get_crop_idx_mapping(side): 79 | """Helper method to get the order of crops. 80 | 81 | :param side: how many crops live on each side. 82 | 83 | Example. Say you have an image that gets devided into 16 crops, i.e., the image gets cut into 16 parts: 84 | 85 | [[ 0, 1, 2, 3], 86 | [ 4, 5, 6, 7], 87 | [ 8, 9, 10, 11], 88 | [12, 13, 14, 15]], 89 | 90 | However, due to our recursive cropping code, this results in crops that are ordered like this: 91 | 92 | index of crop: 93 | 0 1 2 3 4 ... 94 | corresponds to part in image: 95 | 0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15 96 | 97 | This method returns the inverse, going from the index of the crop back to the index in the image. 98 | """ 99 | a = torch.arange(side * side).reshape(1, 1, side, side) 100 | a = torch.cat((a, a, a), dim=1) 101 | # Create mapping 102 | # Index of crop in original image -> index of crop in the order it was extracted, 103 | # E.g. 2 -> 4 means it's the 2nd crop, but in the image, it's at position 4 (see above). 104 | crops = {i: crop[0, 0, ...].flatten().item() 105 | for i, crop in enumerate(iter_crops(a, 1))} 106 | return crops 107 | 108 | 109 | def stitch(parts): 110 | side = int(math.sqrt(len(parts))) 111 | if side * side != len(parts): 112 | raise ValueError(f'Invalid number of parts {len(parts)}') 113 | 114 | rows = [] 115 | 116 | # Sort by original position in image 117 | crops_idx_mapping = _get_crop_idx_mapping(side) 118 | parts_sorted = ( 119 | part for _, part in sorted( 120 | enumerate(parts), key=lambda ip: crops_idx_mapping[ip[0]])) 121 | 122 | parts_itr = iter(parts_sorted) # Turn into iterator so we can easily grab elements 123 | for _ in range(side): 124 | parts_row = itertools.islice(parts_itr, side) # Get `side` number of parts 125 | row = torch.cat(list(parts_row), dim=3) # cat on W dimension 126 | rows.append(row) 127 | 128 | assert next(parts_itr, None) is None, f'Iterator should be empty, got {len(rows)} rows' 129 | img = torch.cat(rows, dim=2) # cat on H dimension 130 | 131 | # Validate. 132 | B, C, H_part, W_part = parts[0].shape 133 | expected_shape = (B, C, H_part * side, W_part * side) 134 | assert img.shape == expected_shape, f'{img.shape} != {expected_shape}' 135 | 136 | return img 137 | 138 | 139 | class CropLossCombinator(object): 140 | """Used to combine the bpsp of different crops into one. Supports crops of varying dimensions.""" 141 | def __init__(self): 142 | self._num_bits_total = 0. 143 | self._num_subpixels_total = 0 144 | 145 | def add(self, bpsp, num_subpixels_crop): 146 | bits = bpsp * num_subpixels_crop 147 | self._num_bits_total += bits 148 | self._num_subpixels_total += num_subpixels_crop 149 | 150 | def get_bpsp(self): 151 | assert self._num_subpixels_total > 0 152 | return self._num_bits_total / self._num_subpixels_total 153 | 154 | 155 | def test_auto_crop(): 156 | import torch 157 | import pytorch_ext as pe 158 | 159 | for H, W, num_crops_expected in [(10000, 6000, 64), 160 | (4928, 3264, 16), 161 | (2048, 2048, 4), 162 | (1024, 1024, 1), 163 | ]: 164 | img = (torch.rand(1, 3, H, W) * 255).round().long() 165 | print(img.shape) 166 | if num_crops_expected > 1: 167 | assert needs_crop(img) 168 | crops = list(iter_crops(img, 2048 * 1024)) 169 | assert len(crops) == num_crops_expected 170 | pe.assert_equal(stitch(crops), img) 171 | else: 172 | pe.assert_equal(next(iter_crops(img, 2048 * 1024)), img) 173 | 174 | -------------------------------------------------------------------------------- /src/bitcoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/bitcoding/__init__.py -------------------------------------------------------------------------------- /src/bitcoding/coders.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | Very thin wrapper around torchac, for arithmetic coding. 22 | 23 | """ 24 | import torch 25 | from torchac import torchac 26 | from fjcommon import no_op 27 | 28 | 29 | from criterion.logistic_mixture import CDFOut 30 | from test.cuda_timer import StackTimeLogger 31 | 32 | 33 | class ArithmeticCoder(object): 34 | def __init__(self, L): 35 | self.L = L 36 | self._cached_cdf = None 37 | 38 | def range_encode(self, data, cdf, time_logger: StackTimeLogger): 39 | """ 40 | :param data: data to encode 41 | :param cdf: cdf to use, either a NHWLp matrix or instance of CDFOut 42 | :return: data encode to a bytes string 43 | """ 44 | assert len(data.shape) == 3, data.shape 45 | 46 | with time_logger.run('data -> cpu'): 47 | data = data.to('cpu', non_blocking=True) 48 | assert data.dtype == torch.int16, 'Wrong dtype: {}'.format(data.dtype) 49 | 50 | with time_logger.run('reshape'): 51 | data = data.reshape(-1).contiguous() 52 | 53 | if isinstance(cdf, CDFOut): 54 | logit_probs_c_sm, means_c, log_scales_c, K, targets = cdf 55 | 56 | with time_logger.run('ac.encode'): 57 | out_bytes = torchac.encode_logistic_mixture( 58 | targets, means_c, log_scales_c, logit_probs_c_sm, data) 59 | else: 60 | N, H, W, Lp = cdf.shape 61 | assert Lp == self.L + 1, (Lp, self.L) 62 | 63 | with time_logger.run('ac.encode'): 64 | out_bytes = torchac.encode_cdf(cdf, data) 65 | 66 | return out_bytes 67 | 68 | def range_decode(self, encoded_bytes, cdf, time_logger: StackTimeLogger = no_op.NoOp): 69 | """ 70 | :param encoded_bytes: bytes encoded by range_encode 71 | :param cdf: cdf to use, either a NHWLp matrix or instance of CDFOut 72 | :return: decoded matrix as np.int16, NHW 73 | """ 74 | if isinstance(cdf, CDFOut): 75 | logit_probs_c_sm, means_c, log_scales_c, K, targets = cdf 76 | 77 | N, _, H, W = means_c.shape 78 | 79 | with time_logger.run('ac.encode'): 80 | decoded = torchac.decode_logistic_mixture( 81 | targets, means_c, log_scales_c, logit_probs_c_sm, encoded_bytes) 82 | 83 | else: 84 | N, H, W, Lp = cdf.shape 85 | assert Lp == self.L + 1, (Lp, self.L) 86 | 87 | with time_logger.run('ac.encode'): 88 | decoded = torchac.decode_cdf(cdf, encoded_bytes) 89 | 90 | return decoded.reshape(N, H, W) 91 | -------------------------------------------------------------------------------- /src/bitcoding/coders_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | Very thin wrapper around DiscretizedMixLogisticLoss.cdf_step_non_shared that keeps track of targets, which are the 22 | same for all channels of the bottleneck, as well as the current channel index. 23 | 24 | """ 25 | 26 | import torch 27 | 28 | from criterion.logistic_mixture import DiscretizedMixLogisticLoss, CDFOut 29 | 30 | 31 | class CodingCDFNonshared(object): 32 | def __init__(self, l, total_C, dmll: DiscretizedMixLogisticLoss): 33 | """ 34 | :param l: predicted distribution, i.e., NKpHW, see DiscretizedMixLogisticLoss 35 | :param total_C: 36 | :param dmll: 37 | """ 38 | self.l = l 39 | self.dmll = dmll 40 | 41 | # Lp = L+1 42 | self.targets = torch.linspace(dmll.x_min - dmll.bin_width / 2, 43 | dmll.x_max + dmll.bin_width / 2, 44 | dmll.L + 1, dtype=torch.float32, device=l.device) 45 | self.total_C = total_C 46 | self.c_cur = 0 47 | 48 | def get_next_C(self, decoded_x) -> CDFOut: 49 | """ 50 | Get CDF to encode/decode next channel 51 | :param decoded_x: NCHW 52 | :return: C_cond_cur, NHWL' 53 | """ 54 | C_Cur = self.dmll.cdf_step_non_shared(self.l, self.targets, self.c_cur, self.total_C, decoded_x) 55 | self.c_cur += 1 56 | return C_Cur 57 | 58 | -------------------------------------------------------------------------------- /src/bitcoding/part_suffix_helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import glob 3 | import os 4 | import string 5 | 6 | _PART_SUFFIX_BASE = '.part' 7 | _PART_SUFFIX_REGEX = _PART_SUFFIX_BASE + r'(\d+)$' 8 | 9 | 10 | def make_part_suffix(i): 11 | """Return str that is suffix for index `i`.""" 12 | assert i >= 0, i 13 | return _PART_SUFFIX_BASE + str(i) 14 | 15 | 16 | def contains_part_suffix(p): 17 | """Return true iff path `p` contains a part suffix.""" 18 | return re.search(_PART_SUFFIX_REGEX, p) is not None 19 | 20 | 21 | def index_of_part_suffix(p): 22 | """Return index of suffix (final number in path).""" 23 | return int(re.search(_PART_SUFFIX_REGEX, p).group(1)) 24 | 25 | 26 | def iter_part_suffixes(pin): 27 | """Return list of all paths that are the same base.""" 28 | assert os.path.isfile(pin) 29 | assert contains_part_suffix(pin) 30 | base = pin.rstrip(string.digits) 31 | matches = glob.glob(base + '*') 32 | # Filter things of the form base* where * is not digits 33 | matches = [m for m in matches if contains_part_suffix(m)] 34 | matches = sorted(matches, key=index_of_part_suffix) 35 | return matches 36 | 37 | 38 | def test_part_suffix(): 39 | assert make_part_suffix(1) == '.part1' 40 | assert make_part_suffix(10) == '.part10' 41 | assert contains_part_suffix('bla/bli/blupp.part1') 42 | assert index_of_part_suffix('bla/bli/blupp.part1') == 1 43 | assert contains_part_suffix('bla/bli/blupp.part3') 44 | assert not contains_part_suffix('bla/bli/blupp.part3/more') 45 | 46 | 47 | def test_iter(tmpdir): 48 | prefix = 'some.file' 49 | for i in range(16): 50 | tmpdir.join('some.file' + str(make_part_suffix(i))).write('test') 51 | 52 | some_file = str(tmpdir.join(prefix + str(make_part_suffix(3)))) 53 | assert list(map(os.path.basename, iter_part_suffixes(some_file))) == \ 54 | ['some.file.part{}'.format(i) for i in range(16)] 55 | -------------------------------------------------------------------------------- /src/blueprints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/blueprints/__init__.py -------------------------------------------------------------------------------- /src/blueprints/multiscale_blueprint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | from collections import namedtuple 20 | 21 | import numpy as np 22 | import torch 23 | import torch.nn.functional as F 24 | import torchvision 25 | 26 | import pytorch_ext as pe 27 | import vis.grid 28 | import vis.summarizable_module 29 | from helpers.pad import pad 30 | from modules.multiscale_network import MultiscaleNetwork, Out 31 | from vis import histogram_plotter, image_summaries 32 | 33 | 34 | MultiscaleLoss = namedtuple( 35 | 'MultiscaleLoss', 36 | ['loss_pc', # loss to minimize 37 | 'nonrecursive_bpsps', # bpsp corresponding to non-recursive scales 38 | 'recursive_bpsps']) # None if not recursive, else all bpsp including recursive 39 | 40 | 41 | 42 | class MultiscaleBlueprint(vis.summarizable_module.SummarizableModule): 43 | def __init__(self, config_ms): 44 | super(MultiscaleBlueprint, self).__init__() 45 | net = MultiscaleNetwork(config_ms) 46 | net.to(pe.DEVICE) 47 | 48 | self.net = net 49 | self.losses = net.get_losses() 50 | 51 | def set_eval(self): 52 | self.net.eval() 53 | self.losses.loss_dmol_rgb.eval() 54 | self.losses.loss_dmol_n.eval() 55 | 56 | def forward(self, in_batch, auto_recurse=0) -> Out: 57 | """ 58 | :param in_batch: NCHW 0..255 float 59 | :param auto_recurse: int, how many times the last scales should be applied again. Used for RGB Shared. 60 | :return: layers.multiscale.Out 61 | """ 62 | return self.net(in_batch, auto_recurse) 63 | 64 | def get_loss(self, out: Out, num_subpixels_before_pad=None) -> MultiscaleLoss: 65 | """ 66 | :param num_subpixels_before_pad: If given, calculate bpsp with this, instead of num_pixels returned by self.losses. 67 | This is needed because while testing, we have to pad images. To calculate the correct bpsp, we need to 68 | calcualte it with respect to the actual (non-padded) number of pixels 69 | :returns instance of MultiscaleLoss, see above. 70 | """ 71 | # `costs`: a list, containing the cost of each scale, in nats 72 | costs, final_cost_uniform, num_subpixels = self.losses.get(out) 73 | if num_subpixels_before_pad: 74 | assert num_subpixels_before_pad <= num_subpixels, num_subpixels_before_pad 75 | num_subpixels = num_subpixels_before_pad 76 | # conversion between nats and bits per subpixel 77 | conversion = np.log(2.) * num_subpixels 78 | costs_bpsp = [cost/conversion for cost in costs] 79 | 80 | self.summarizer.register_scalars( 81 | 'auto', 82 | {'costs/scale_{}_bpsp'.format(i): cost for i, cost in enumerate(costs_bpsp)}) 83 | 84 | # all bpsps corresponding to non-recursive scales, including final (uniform-prior) cost 85 | nonrecursive_bpsps = costs_bpsp[:out.auto_recursive_from] + [final_cost_uniform / conversion] 86 | if out.auto_recursive_from is not None: 87 | # all bpsps corresponding to non-recursive AND recursive scales, including final cost 88 | recursive_bpsps = costs_bpsp + [out.get_nat_count(-1) / conversion] 89 | else: 90 | recursive_bpsps = None 91 | 92 | # loss is everything without final (uniform-prior) scale 93 | total_bpsp_without_final = sum(costs_bpsp) 94 | loss_pc = total_bpsp_without_final 95 | return MultiscaleLoss(loss_pc, nonrecursive_bpsps, recursive_bpsps) 96 | 97 | def sample_forward(self, in_batch, sample_scales, partial_final=None): 98 | return self.net.sample_forward(in_batch, self.losses, sample_scales, partial_final) 99 | 100 | @staticmethod 101 | def add_image_summaries(sw, out: Out, global_step, prefix): 102 | tag = lambda t: sw.pre(prefix, t) 103 | is_train = prefix == 'train' 104 | for scale, (S_i, _, P_i, L_i) in enumerate(out.iter_all_scales(), 1): # start from 1, as 0 is RGB 105 | sw.add_image(tag('bn/{}'.format(scale)), new_bottleneck_summary(S_i, L_i), global_step) 106 | # This will only trigger for the final scale, where P_i is the uniform distribution. 107 | # With this, we can check how accurate the uniform assumption is (hint: not very) 108 | is_logits = P_i.shape[1] == L_i 109 | if is_logits and is_train: 110 | with sw.add_figure_ctx(tag('histo_out/{}'.format(scale)), global_step) as plt: 111 | add_ps_summaries(S_i, get_p_y(P_i), L_i, plt) 112 | 113 | @staticmethod 114 | def bottleneck_images(s, L): 115 | assert s.dim() == 4, s.shape 116 | _assert_contains_symbol_indices(s, L) 117 | s = s.float().div(L) 118 | return [image_summaries.to_image(s[:, c, ...]) for c in range(s.shape[1])] 119 | 120 | @staticmethod 121 | def unpack_batch_pad(raw, fac): 122 | """ 123 | :param raw: uint8, input image. 124 | :param fac: downscaling factor we will use, used to determine proper padding. 125 | """ 126 | if len(raw.shape) == 3: 127 | raw.unsqueeze_(0) # add batch dim 128 | assert len(raw.shape) == 4 129 | raw = MultiscaleBlueprint.pad(raw, fac) 130 | raw = raw.to(pe.DEVICE) 131 | img_batch = raw.float() 132 | s = raw.long() # symbols 133 | return img_batch, s 134 | 135 | @staticmethod 136 | def pad(raw, fac): 137 | raw, _ = pad(raw, fac, mode=MultiscaleBlueprint.get_padding_mode()) 138 | return raw 139 | 140 | @staticmethod 141 | def get_padding_mode(): 142 | return 'constant' 143 | 144 | @staticmethod 145 | def unpack(img_batch): 146 | idxs = img_batch['idx'].squeeze().tolist() 147 | raw = img_batch['raw'].to(pe.DEVICE) 148 | img_batch = raw.float() 149 | s = raw.long() # symbols 150 | return idxs, img_batch, s 151 | 152 | 153 | def new_bottleneck_summary(s, L): 154 | """ 155 | Grayscale bottleneck representation: Expects the actual bottleneck symbols. 156 | :param s: NCHW 157 | :return: [0, 1] image 158 | """ 159 | assert s.dim() == 4, s.shape 160 | _assert_contains_symbol_indices(s, L) 161 | s = s.float().div(L) 162 | grid = vis.grid.prep_for_grid(s, channelwise=True) 163 | assert len(grid) == s.shape[1], (len(grid), s.shape) 164 | assert [g.max() <= 1 for g in grid], [g.max() for g in grid] 165 | assert grid[0].dtype == torch.float32, grid.dtype 166 | return torchvision.utils.make_grid(grid, nrow=5) 167 | 168 | 169 | def _assert_contains_symbol_indices(t, L): 170 | """ assert 0 <= t < L """ 171 | assert 0 <= t.min() and t.max() < L, (t.min(), t.max()) 172 | 173 | 174 | def add_ps_summaries(s, p_y, L, plt): 175 | histo_s = pe.histogram(s, L) 176 | p_x = histo_s / np.sum(histo_s) 177 | 178 | assert p_x.shape == p_y.shape, (p_x.shape, p_y.shape) 179 | 180 | histogram_plotter.plot_histogram([ 181 | ('p_x', p_x), 182 | ('p_y', p_y), 183 | ], plt) 184 | 185 | 186 | def get_p_y(y): 187 | """ 188 | :param y: NLCHW float, logits 189 | :return: L dimensional vector p 190 | """ 191 | Ldim = 1 192 | L = y.shape[Ldim] 193 | y = y.detach() 194 | p = F.softmax(y, dim=Ldim) 195 | p = p.transpose(Ldim, -1) 196 | p = p.contiguous().view(-1, L) # nL 197 | p = torch.mean(p, dim=0) # L 198 | return pe.tensor_to_np(p) 199 | -------------------------------------------------------------------------------- /src/configs/dl/in32.cf: -------------------------------------------------------------------------------- 1 | # Note: this is different from the text, batch size is actually 30. See Errata in README 2 | batchsize_train = 30 3 | batchsize_val = 120 4 | crop_size = 32 5 | 6 | max_epochs = None 7 | 8 | image_cache_pkl = None 9 | train_imgs_glob = 'path/to/train_32x32/ (check readme)' 10 | val_glob = 'path/to/val_32x32 (check readme)' 11 | 12 | val_glob_min_size = None 13 | fixed_first_image = None 14 | num_val_batches = 5 15 | -------------------------------------------------------------------------------- /src/configs/dl/in64.cf: -------------------------------------------------------------------------------- 1 | use in32.cf 2 | 3 | crop_size = 64 4 | 5 | image_cache_pkl = None 6 | train_imgs_glob = 'path/to/train_64x64/ (check readme)' 7 | val_glob = 'path/to/val_64x64 (check readme)' 8 | 9 | num_val_batches = 5 10 | -------------------------------------------------------------------------------- /src/configs/dl/oi.cf: -------------------------------------------------------------------------------- 1 | batchsize_train = 30 2 | batchsize_val = 30 3 | crop_size = 128 4 | 5 | max_epochs = None 6 | 7 | image_cache_pkl = None 8 | train_imgs_glob = 'path/to/train/ (check readme)' 9 | val_glob = 'path/to/val (check readme)' 10 | 11 | val_glob_min_size = None # for cache_p 12 | num_val_batches = 5 13 | 14 | -------------------------------------------------------------------------------- /src/configs/ms/cr.cf: -------------------------------------------------------------------------------- 1 | optim = 'RMSprop' 2 | 3 | lr.initial = 0.0001 4 | lr.schedule = 'exp_0.75_e5' 5 | weight_decay = 0 6 | 7 | num_scales = 3 8 | shared_across_scales = False 9 | 10 | Cf = 64 11 | kernel_size = 3 12 | 13 | dmll_enable_grad = 0 14 | 15 | rgb_bicubic_baseline = False 16 | 17 | enc.cls = 'EDSRLikeEnc' 18 | enc.num_blocks = 8 19 | enc.feed_F = True 20 | enc.importance_map = False 21 | 22 | learned_L = False 23 | 24 | dec.cls = 'EDSRDec' 25 | dec.num_blocks = 8 26 | dec.skip = True 27 | 28 | q.cls = 'Quantizer' 29 | q.C = 5 30 | # We assume q.L levels, evenly distributed between q.levels_range[0] and q.levels_range[1], see net.py 31 | q.L = 25 32 | q.levels_range = (-1, 1) 33 | q.sigma = 2 34 | 35 | prob.K = 10 36 | 37 | after_q1x1 = True 38 | x4_down_in_scale0 = False 39 | -------------------------------------------------------------------------------- /src/configs/ms/cr_rgb.cf: -------------------------------------------------------------------------------- 1 | use cr_rgb_shared.cf 2 | 3 | num_scales = 3 4 | # We have a skip in the decoder, see comment in _forward_with_scales in multiscale.py 5 | dec.skip = True 6 | 7 | -------------------------------------------------------------------------------- /src/configs/ms/cr_rgb_shared.cf: -------------------------------------------------------------------------------- 1 | use cr.cf 2 | 3 | rgb_bicubic_baseline = True 4 | q.C = 3 5 | q.L = 5 6 | 7 | # shared for all 8 | num_scales = 1 9 | 10 | enc.feed_F = False 11 | 12 | dec.skip = False 13 | 14 | enc.cls = 'BicubicSubsampling' 15 | 16 | -------------------------------------------------------------------------------- /src/criterion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/criterion/__init__.py -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/dataloaders/__init__.py -------------------------------------------------------------------------------- /src/dataloaders/images_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import argparse 20 | import glob 21 | import os 22 | import pickle 23 | 24 | import numpy as np 25 | import torch 26 | import torchvision.transforms as transforms 27 | from PIL import Image 28 | from torch.utils.data import Dataset 29 | 30 | from helpers.paths import has_image_ext 31 | from helpers.testset import Testset 32 | 33 | 34 | class NoImagesFoundException(Exception): 35 | def __init__(self, p): 36 | self.p = p 37 | 38 | 39 | class IndexImagesDataset(Dataset): 40 | """ 41 | A Dataset class for images, that also returns the index of the image in the dataset. 42 | """ 43 | @staticmethod 44 | def to_tensor_uint8_transform(): 45 | """ Convert PIL to uint8 tensor, CHW. """ 46 | return to_tensor_not_normalized 47 | 48 | @staticmethod 49 | def to_grb(t): 50 | assert t.shape[0] == 3 51 | with torch.no_grad(): 52 | return torch.stack((t[1, ...], t[0, ...], t[2, ...]), dim=0) 53 | 54 | @staticmethod 55 | def to_float_tensor_transform(): 56 | return transforms.Lambda(lambda img: img.float().div(255.)) 57 | 58 | def copy(self): 59 | return IndexImagesDataset( 60 | self.images_spec, self.to_tensor_transform, self.cache_p, self.min_size) 61 | 62 | def __init__(self, images, to_tensor_transform, fixed_first=None): 63 | """ 64 | :param images: Instance of Testset or ImagesCached 65 | :param to_tensor_transform: A function that takes a PIL RGB image and returns a torch.Tensor 66 | """ 67 | self.to_tensor_transform = to_tensor_transform 68 | 69 | if isinstance(images, ImagesCached): 70 | self.files = images.get_images_sorted_cached() 71 | self.id = '{}_{}_{}'.format(images.images_spec, images.min_size, len(self.files)) 72 | elif isinstance(images, Testset): 73 | self.files = images.ps 74 | self.id = images.id 75 | else: 76 | raise ValueError('Expected ImagesCached or Testset, got images={}'.format(images)) 77 | 78 | if fixed_first: 79 | assert os.path.isfile(fixed_first) 80 | self.files = [fixed_first] + self.files 81 | 82 | if len(self.files) == 0: 83 | raise NoImagesFoundException(images.search_path()) 84 | 85 | def __len__(self): 86 | return len(self.files) 87 | 88 | def __str__(self): 89 | return 'IndexImagesDataset({} images, id={})'.format(len(self.files), self.id) 90 | 91 | def __getitem__(self, idx): 92 | path = self.files[idx] 93 | with open(path, 'rb') as f: 94 | pil = Image.open(f).convert('RGB') 95 | raw = self.to_tensor_transform(pil) # 3HW uint8s 96 | return {'idx': idx, 97 | 'raw': raw} 98 | 99 | 100 | def to_tensor_not_normalized(pic): 101 | """ copied from PyTorch functional.to_tensor, removed final .float().div(255.) """ 102 | if isinstance(pic, np.ndarray): 103 | # handle numpy array 104 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 105 | return img 106 | 107 | # handle PIL Image 108 | if pic.mode == 'I': 109 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 110 | elif pic.mode == 'I;16': 111 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 112 | elif pic.mode == 'F': 113 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 114 | elif pic.mode == '1': 115 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 116 | else: 117 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 118 | # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK 119 | if pic.mode == 'YCbCr': 120 | nchannel = 3 121 | elif pic.mode == 'I;16': 122 | nchannel = 1 123 | else: 124 | nchannel = len(pic.mode) 125 | img = img.view(pic.size[1], pic.size[0], nchannel) 126 | # put it from HWC to CHW format 127 | # yikes, this transpose takes 80% of the loading time/CPU 128 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 129 | return img 130 | 131 | 132 | class ImagesCached(object): 133 | """ Caches contents of folders, for slow filesystems. """ 134 | def __init__(self, images_spec, cache_p=None, min_size=None): 135 | """ 136 | :param images_spec: str, interpreted as 137 | - a glob, if it contains a * 138 | - a single image, if it ends in one of IMG_EXTENSIONS 139 | - a directory otherwise 140 | :param cache_p: path to a cache or None. If given, check there when `get_images_sorted_cached` is called 141 | :param min_size: if given, make sure to only return/cache images of the given size 142 | :return: 143 | """ 144 | self.images_spec = os.path.expanduser(images_spec) 145 | self.cache_p = os.path.expanduser(cache_p) 146 | self.min_size = min_size 147 | self.cache = ImagesCached._get_cache(self.cache_p) 148 | 149 | def __str__(self): 150 | return f'ImagesCached(images_spec={self.images_spec})' 151 | 152 | def search_path(self): 153 | return self.images_spec 154 | 155 | def get_images_sorted_cached(self): 156 | key = self.images_spec, self.min_size 157 | if key in self.cache: 158 | return self.cache[key] 159 | if self.cache_p: 160 | print(f'WARN: Given cache_p={self.cache_p}, but key not found:\n{key}') 161 | available_keys = sorted(self.cache.keys(), key=lambda img_size: img_size[0]) 162 | print('Found:\n' + '\n'.join(map(str, available_keys))) 163 | return sorted(self._iter_imgs_unordered_filter_size()) 164 | 165 | def update(self, force, verbose): 166 | """ 167 | Writes/updates to cache_p 168 | :param force: overwrites 169 | """ 170 | if not force and (self.images_spec, self.min_size) in self.cache: 171 | print('Cache already contains {}'.format(self.images_spec)) 172 | return 173 | print('Updating cache for {}...'.format(self.images_spec)) 174 | images = sorted(self._iter_imgs_unordered_filter_size(verbose)) 175 | if len(images) == 0: 176 | print('No images found...') 177 | return 178 | print('Found {} images...'.format(len(images))) 179 | self.cache[(self.images_spec, self.min_size)] = images 180 | print('Writing cache...') 181 | with open(self.cache_p, 'wb') as f: 182 | pickle.dump(self.cache, f) 183 | 184 | @staticmethod 185 | def print_all(cache_p): 186 | """ 187 | Print all cache entries to console 188 | :param cache_p: Path of the cache to print 189 | """ 190 | cache = ImagesCached._get_cache(cache_p) 191 | for key in list(cache.keys()): 192 | if len(cache[key]) == 0: 193 | del cache[key] 194 | for (p, min_size), imgs in cache.items(): 195 | min_size_str = ' (>={})'.format(min_size) if min_size else '' 196 | print('{}{}: {} images'.format(p, min_size_str, len(imgs))) 197 | 198 | def _iter_imgs_unordered_filter_size(self, verbose=False): 199 | for p in self._iter_imgs_unordered(self.images_spec): 200 | if self.min_size: 201 | img = Image.open(p) 202 | img_min_dim = min(img.size) 203 | if img_min_dim < self.min_size: 204 | print('Skipping {} ({})...'.format(p, img.size)) 205 | continue 206 | if verbose: 207 | print(p) 208 | yield p 209 | 210 | @staticmethod 211 | def _get_cache(cache_p): 212 | if not cache_p: 213 | return {} 214 | if not os.path.isfile(cache_p): 215 | print(f'cache_p={cache_p} does not exist.') 216 | return {} 217 | with open(cache_p, 'rb') as f: 218 | return pickle.load(f) 219 | 220 | @staticmethod 221 | def _iter_imgs_unordered(images_spec): 222 | if '*' in images_spec: 223 | matches = glob.glob(images_spec) 224 | _, ext = os.path.splitext(images_spec) 225 | if not ext: 226 | matches = (p for p in matches if has_image_ext(p)) 227 | elif not has_image_ext(images_spec): 228 | raise ValueError('Unrecognized extension {} in glob ({})'.format(ext, images_spec)) 229 | if not matches: 230 | raise ValueError('No matches for glob {}'.format(images_spec)) 231 | yield from matches 232 | return 233 | 234 | if has_image_ext(images_spec): 235 | yield from [images_spec] 236 | return 237 | 238 | # At this point, images_spec should be a path to a directory 239 | if not os.path.isdir(images_spec): 240 | raise NotADirectoryError(images_spec) 241 | 242 | print('Recursively traversing {}...'.format(images_spec)) 243 | for root, _, fnames in os.walk(images_spec, followlinks=True): 244 | for fname in fnames: 245 | if has_image_ext(fname): 246 | p = os.path.join(root, fname) 247 | if os.path.getsize(p) == 0: 248 | print('WARN / 0 bytes /', p) 249 | continue 250 | yield p 251 | 252 | def main(): 253 | # See README 254 | p = argparse.ArgumentParser() 255 | mode_parsers = p.add_subparsers(dest='mode') 256 | show_p = mode_parsers.add_parser('show') 257 | show_p.add_argument('cache_p') 258 | update_p = mode_parsers.add_parser('update') 259 | update_p.add_argument('images_spec') 260 | update_p.add_argument('cache_p') 261 | update_p.add_argument('--min_size', type=int) 262 | update_p.add_argument('--force', '-f', action='store_true') 263 | update_p.add_argument('--verbose', '-v', action='store_true', help='Print found paths. Might be slow!') 264 | 265 | flags = p.parse_args() 266 | if flags.mode == 'show': 267 | ImagesCached.print_all(flags.cache_p) 268 | elif flags.mode == 'update': 269 | ImagesCached(flags.images_spec, flags.cache_p, flags.min_size).update(flags.force, flags.verbose) 270 | else: 271 | p.print_usage() 272 | 273 | 274 | # Resize bicubic 275 | 276 | 277 | def resize_bicubic_batch(t, fac): 278 | assert len(t.shape) == 4 279 | N = t.shape[0] 280 | return torch.stack([resize_bicubic(t[n, ...], fac) for n in range(N)], dim=0) 281 | 282 | 283 | def resize_bicubic(t, fac): 284 | img = _tensor_to_image(t) # to PIL 285 | h, w = img.size 286 | img = img.resize((int(h * fac), int(w * fac)), Image.BICUBIC) 287 | t = to_tensor_not_normalized(img) # back to 3HW uint8 tensor 288 | return t 289 | 290 | 291 | def _tensor_to_image(t): 292 | assert t.shape[0] == 3, t.shape 293 | return Image.fromarray(t.permute(1, 2, 0).detach().cpu().numpy()) 294 | 295 | 296 | if __name__ == '__main__': 297 | main() 298 | 299 | -------------------------------------------------------------------------------- /src/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/helpers/__init__.py -------------------------------------------------------------------------------- /src/helpers/aligned_printer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import itertools 20 | 21 | 22 | class AlignedPrinter(object): 23 | """ Print Rows nicely as a table. """ 24 | def __init__(self): 25 | self.rows = [] 26 | self.maxs = [] 27 | 28 | def append(self, *row): 29 | self.rows.append(row) 30 | self.maxs = [max(max_cur, len(row_entry)) 31 | for max_cur, row_entry in 32 | itertools.zip_longest(self.maxs, row, fillvalue=0)] 33 | 34 | def print(self): 35 | for row in self.rows: 36 | for width, row_entry in zip(self.maxs, row): 37 | print('{row_entry:{width}}'.format(row_entry=row_entry, width=width), end=' ') 38 | print() 39 | 40 | def __enter__(self): 41 | return self 42 | 43 | def __exit__(self, exc_type, exc_val, exc_tb): 44 | self.print() 45 | -------------------------------------------------------------------------------- /src/helpers/config_checker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import os 20 | 21 | 22 | DEFAULT_CONFIG_DIR = 'configs' 23 | 24 | 25 | class ConfigsRepo(object): 26 | def __init__(self, config_dir=DEFAULT_CONFIG_DIR): 27 | self.config_dir = config_dir 28 | 29 | def check_configs_available(self, *config_ps): 30 | for p in config_ps: 31 | assert self.config_dir in p, 'Expected {} to contain {}!'.format(p, self.config_dir) 32 | if not os.path.isfile(p): 33 | raise FileNotFoundError(p) 34 | -------------------------------------------------------------------------------- /src/helpers/global_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | 20 | -------------------------------------------------------------------------------- 21 | 22 | 23 | Support for global config parameters shared over the whole program. The goal is to easily add parameters into some 24 | nested module without passing it all the way through, for fast prototyping. 25 | Also supports updating Configs returned by config_parser.py. 26 | 27 | 28 | Usage 29 | 30 | ``` 31 | from global_config import global_config 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('-p', action='append', nargs=1) 35 | ... 36 | flags = parser.parse_args() 37 | ... 38 | global_config.add_from_flag(flags.p) 39 | 40 | # in some module 41 | 42 | from global_config import global_config 43 | 44 | # check if loss is set, if not, use the default of 'mse' 45 | loss_fct = global_config.get('loss', default_value='mse') 46 | if loss_fct == 'mse': 47 | loss = MSE 48 | elif loss_fct == 'ce': 49 | loss = CrossEntropy 50 | 51 | # start the training 52 | python train.py -p loss=ce 53 | ``` 54 | 55 | 56 | """ 57 | from contextlib import contextmanager 58 | 59 | 60 | class _GlobalConfig(object): 61 | def __init__(self, default_unset_value=False, type_classes=None): 62 | """ 63 | :param default_unset_value: Value to use if global_config['key'] is called on a key that is not set 64 | :param type_classes: supported type classes that values are tried to convert to, see `_eval_value` 65 | """ 66 | if type_classes is None: 67 | type_classes = [int, float] 68 | self.type_classes = type_classes 69 | self.default_unset_value = default_unset_value 70 | self._values = {} 71 | self._used_params = set() 72 | 73 | def add_from_flag(self, param_flag): 74 | """ Add from a list containing key=value, or key (eqivalent to key=True). """ 75 | if param_flag is None: 76 | return 77 | for param_spec in param_flag: 78 | if isinstance(param_spec, list): 79 | assert len(param_spec) == 1 80 | param_spec = param_spec[0] 81 | self.add_param_from_spec(param_spec) 82 | 83 | def update_config(self, config): 84 | """ Update a fjcommon._Config returned by fjcommon.config_parser """ 85 | for k, v in config.all_params_and_values(): 86 | if k in self: 87 | print('Updating config.{} = {}'.format(k, self[k])) 88 | config.set_attr(k, self[k]) 89 | self.declare_used(k) 90 | 91 | def add_param_from_spec(self, spec): 92 | if '=' not in spec: 93 | spec = '{}=True'.format(spec) 94 | key, value = spec.split('=') 95 | key = key.strip() 96 | value = self._eval_value(value.strip()) 97 | self[key] = value 98 | 99 | def _eval_value(self, value): 100 | # try if `value` is True, False, or None, and return the actual instance 101 | if value in ('True', 'False', 'None'): 102 | return {'True': True, 103 | 'False': False, 104 | 'None': None}[value] 105 | 106 | # try casting to classes in type_classes 107 | for type_cls in self.type_classes: 108 | try: 109 | return type_cls(value) 110 | except ValueError: 111 | continue 112 | 113 | # finally, just interpret as string type 114 | if ' ' in value: 115 | raise ValueError('values are not allowed to contain spaces! {}'.format(value)) 116 | if '/' in value or '~' in value: 117 | raise ValueError('values are not allowed to contain "/" or "~"! {}'.format(value)) 118 | return value 119 | 120 | def __setitem__(self, key, value): 121 | self._values[key] = value 122 | 123 | def __getitem__(self, key): 124 | return self.get(key, self.default_unset_value) 125 | 126 | def __contains__(self, item): 127 | return item in self._values 128 | 129 | def get(self, key, default_value, incompatible=None): 130 | """ 131 | Check if `key` is set 132 | :param default_value: value to return if `key` is not set 133 | :param incompatible: list of keys which are not allowed to bet set if `key` is set 134 | :return: values[key] if key is set, `default_value` otherwise 135 | """ 136 | if incompatible and key in self._values: 137 | self._ensure_not_specified(key, incompatible) 138 | self._used_params.add(key) 139 | return self._values.get(key, default_value) 140 | 141 | def declare_used(self, *keys): 142 | """ Hack to mark all keys in `keys` as used, even if global_config[key] was never called """ 143 | self._used_params.update(keys) 144 | 145 | def get_unused_params(self): 146 | return [k for k in self._values.keys() if k not in self._used_params] 147 | 148 | def values(self): 149 | return list(self._values_to_spec()) 150 | 151 | def values_str(self, joiner=' '): 152 | return joiner.join(self.values()) 153 | 154 | def reset(self): # ugly, needed because global... 155 | """ Reset global_config. """ 156 | self._values = {} 157 | self._used_params = set() 158 | 159 | @contextmanager 160 | def reset_after(self): 161 | yield 162 | self.reset() 163 | 164 | def _values_to_spec(self): 165 | for k, v in sorted(self._values.items()): 166 | if v is True: 167 | yield k 168 | else: 169 | yield '{}={}'.format(k, v) 170 | 171 | def _ensure_not_specified(self, key, incompatible): 172 | """ Raises ValueError if any of the keys in `incompatible` were specified """ 173 | assert isinstance(incompatible, list) 174 | errors = [k for k in incompatible if k in self] 175 | if errors: 176 | raise ValueError(f"Got {key}, incompatible with: {','.join(errors)}") 177 | 178 | def __str__(self): 179 | if len(self._values) == 0: 180 | return 'GlobalConfig()' 181 | return 'GlobalConfig(\n\t{})'.format('\n\t'.join(self.values())) 182 | 183 | 184 | global_config = _GlobalConfig() 185 | 186 | -------------------------------------------------------------------------------- /src/helpers/logdir_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import glob 20 | from collections import namedtuple 21 | from datetime import datetime, timedelta 22 | 23 | import fasteners 24 | import re 25 | import os 26 | from os import path 27 | 28 | _LOG_DATE_FORMAT = "%m%d_%H%M" 29 | _RESTORE_PREFIX = 'r@' 30 | 31 | 32 | def create_unique_log_dir(config_rel_paths, log_dir_root, line_breaking_chars_pat=r'[-]', 33 | postfix=None, restore_dir=None, strip_ext=None): 34 | """ 35 | 0117_1704 repr@soa3_med_8e*5_deePer_b50_noHM_C16 repr@v2_res_shallow r@0115_1340 36 | :param config_rel_paths: paths to the configs, relative to the config root dir 37 | :param log_dir_root: In this directory, all log dirs are stored. Created if needed. 38 | :param line_breaking_chars_pat: 39 | :param postfix: appended to the returned log dir 40 | :param restore_dir: if given, expected to be a log dir. the JOB_ID of that will be appended 41 | :param strip_ext: if given, do not store extension `strip_ext` of config_rel_paths 42 | :return: path to a newly created directory 43 | """ 44 | if any('@' in config_rel_path for config_rel_path in config_rel_paths): 45 | raise ValueError('"@" not allowed in paths, got {}'.format(config_rel_paths)) 46 | 47 | if strip_ext: 48 | assert all(strip_ext in c for c in config_rel_paths) 49 | config_rel_paths = [c.replace(strip_ext, '') for c in config_rel_paths] 50 | 51 | def prep_path(p): 52 | p = p.replace(path.sep, '@') 53 | return re.sub(line_breaking_chars_pat, '*', p) 54 | 55 | postfix_dir_name = ' '.join(map(prep_path, config_rel_paths)) 56 | if restore_dir: 57 | _, restore_job_component = _split_log_dir(restore_dir) 58 | restore_job_id = log_date_from_log_dir(restore_job_component) 59 | postfix_dir_name += ' {restore_prefix}{job_id}'.format( 60 | restore_prefix=_RESTORE_PREFIX, job_id=restore_job_id) 61 | if postfix: 62 | if isinstance(postfix, list): 63 | postfix = ' '.join(postfix) 64 | postfix_dir_name += ' ' + postfix 65 | return _mkdir_threadsafe_unique(log_dir_root, datetime.now(), postfix_dir_name) 66 | 67 | 68 | LogDirComps = namedtuple('LogDirComps', ['config_paths', 'postfix']) 69 | 70 | 71 | def parse_log_dir(log_dir, configs_dir, base_dirs, append_ext=''): 72 | """ 73 | Given a log_dir produced by `create_unique_log_dir`, return the full paths of all configs used. 74 | The log dir has thus the following format 75 | {now} {netconfig} {probconfig} [r@XXXX_YYYY] [{postfix} {postfix}] 76 | 77 | :param log_dir: the log dir to parse 78 | :param configs_dir: the root config dir, where all the configs live 79 | :param base_dirs: Prefixed to the paths of the configs, e.g., ['ae', 'pc'] 80 | :return: all config paths, as well as the postfix if one was given 81 | """ 82 | base_dirs = [path.join(configs_dir, base_dir) for base_dir in base_dirs] 83 | log_dir = path.basename(log_dir.strip(path.sep)) 84 | 85 | comps = log_dir.split(' ') 86 | assert is_log_date(comps[0]), 'Invalid log_dir: {}'.format(log_dir) 87 | 88 | assert len(comps) > len(base_dirs), 'Expected a base dir for every component, got {} and {}'.format( 89 | comps, base_dirs) 90 | config_components = comps[1:(1+len(base_dirs))] 91 | has_restore = any(_RESTORE_PREFIX in c for c in comps) 92 | postfix = comps[1+len(base_dirs)+has_restore:] 93 | 94 | def get_real_path(base, prepped_p): 95 | p_glob = prepped_p.replace('@', path.sep) 96 | p_glob = path.join(base, p_glob) + append_ext # e.g., ae_configs/p_glob.cf 97 | glob_matches = glob.glob(p_glob) 98 | # We always only replace one character with *, so filter for those. 99 | # I.e. lr1e-5 will become lr1e*5, which will match lr1e-5 but also lr1e-4.5 100 | glob_matches_of_same_len = [g for g in glob_matches if len(g) == len(p_glob)] 101 | if len(glob_matches_of_same_len) != 1: 102 | raise ValueError('Cannot find config on disk: {} (matches: {})'.format(p_glob, glob_matches_of_same_len)) 103 | return glob_matches_of_same_len[0] 104 | 105 | return LogDirComps( 106 | config_paths=tuple(get_real_path(base_dir, comp) 107 | for base_dir, comp in zip(base_dirs, config_components)), 108 | postfix=tuple(postfix) if postfix else None) 109 | 110 | 111 | # ------------------------------------------------------------------------------ 112 | 113 | 114 | def _split_log_dir(log_dir): 115 | """ 116 | given 117 | some/path/to/job/dir/0101_1818 ae_config pc_config/ckpts 118 | or 119 | some/path/to/job/dir/0101_1818 ae_config pc_config 120 | returns 121 | tuple some/path/to/job/dir, 0101_1818 ae_config pc_config 122 | """ 123 | log_dir_root = [] 124 | job_component = None 125 | 126 | for comp in log_dir.split(path.sep): 127 | try: 128 | log_date_from_log_dir(comp) 129 | job_component = comp 130 | break # this component is an actual log dir. stop and return components 131 | except ValueError: 132 | log_dir_root.append(comp) 133 | 134 | assert job_component is not None, 'Invalid log_dir: {}'.format(log_dir) 135 | return path.sep.join(log_dir_root), job_component 136 | 137 | 138 | def _mkdir_threadsafe_unique(log_dir_root, log_date, postfix_dir_name): 139 | os.makedirs(log_dir_root, exist_ok=True) 140 | # Make sure only one process at a time writes into log_dir_root 141 | with fasteners.InterProcessLock(os.path.join(log_dir_root, 'lock')): 142 | return _mkdir_unique(log_dir_root, log_date, postfix_dir_name) 143 | 144 | 145 | def _mkdir_unique(log_dir_root, log_date, postfix_dir_name): 146 | log_date_str = log_date.strftime(_LOG_DATE_FORMAT) 147 | if _log_dir_with_log_date_exists(log_dir_root, log_date): 148 | print('Log dir starting with {} exists...'.format(log_date_str)) 149 | return _mkdir_unique(log_dir_root, log_date + timedelta(minutes=1), postfix_dir_name) 150 | 151 | log_dir = path.join(log_dir_root, '{log_date_str} {postfix_dir_name}'.format( 152 | log_date_str=log_date_str, 153 | postfix_dir_name=postfix_dir_name).strip()) 154 | os.makedirs(log_dir) 155 | return log_dir 156 | 157 | 158 | def _log_dir_with_log_date_exists(log_dir_root, log_date): 159 | log_date_str = log_date.strftime(_LOG_DATE_FORMAT) 160 | all_log_dates = set() 161 | for log_dir in os.listdir(log_dir_root): 162 | try: 163 | all_log_dates.add(log_date_from_log_dir(log_dir)) 164 | except ValueError: 165 | continue 166 | return log_date_str in all_log_dates 167 | 168 | 169 | def log_date_from_log_dir(log_dir): 170 | # extract {log_date} from LOG_DIR/{log_date} {netconfig} {probconfig} 171 | possible_log_date = os.path.basename(log_dir).split(' ')[0] 172 | if not is_log_date(possible_log_date): 173 | raise ValueError('Invalid log dir: {}'.format(log_dir)) 174 | return possible_log_date 175 | 176 | 177 | def is_log_dir(log_dir): 178 | try: 179 | log_date_from_log_dir(log_dir) 180 | return True 181 | except ValueError: 182 | return False 183 | 184 | 185 | def is_log_date(possible_log_date): 186 | try: 187 | datetime.strptime(possible_log_date, _LOG_DATE_FORMAT) 188 | return True 189 | except ValueError: 190 | return False 191 | -------------------------------------------------------------------------------- /src/helpers/pad.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | from fjcommon import functools_ext as ft 20 | from torch.nn import functional as F 21 | 22 | 23 | def pad(img, fac, mode='replicate'): 24 | """ 25 | pad img such that height and width are divisible by fac 26 | """ 27 | _, _, h, w = img.shape 28 | padH = fac - (h % fac) 29 | padW = fac - (w % fac) 30 | if padH == fac and padW == fac: 31 | return img, ft.identity 32 | if padH == fac: 33 | padTop = 0 34 | padBottom = 0 35 | else: 36 | padTop = padH // 2 37 | padBottom = padH - padTop 38 | if padW == fac: 39 | padLeft = 0 40 | padRight = 0 41 | else: 42 | padLeft = padW // 2 43 | padRight = padW - padLeft 44 | assert (padTop + padBottom + h) % fac == 0 45 | assert (padLeft + padRight + w) % fac == 0 46 | 47 | padding_tuple = (padLeft, padRight, padTop, padBottom) 48 | 49 | return F.pad(img, padding_tuple, mode), padding_tuple 50 | 51 | 52 | def undo_pad(img, padLeft, padRight, padTop, padBottom, target_shape=None): 53 | # the 'or None' makes sure that we don't get 0:0 54 | img_out = img[..., padTop:(-padBottom or None), padLeft:(-padRight or None)] 55 | if target_shape: 56 | h, w = target_shape 57 | assert img_out.shape[-2:] == (h, w), (img_out.shape[-2:], (h, w), img_.shape, 58 | (padLeft, padRight, padTop, padBottom)) 59 | return img_out 60 | -------------------------------------------------------------------------------- /src/helpers/paths.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import os 20 | import glob 21 | 22 | from PIL import Image 23 | from torchvision import transforms as transforms 24 | 25 | from helpers import logdir_helpers 26 | 27 | from fjcommon.assertions import assert_exc 28 | 29 | import pytorch_ext as pe 30 | 31 | 32 | CKPTS_DIR_NAME = 'ckpts' 33 | VAL_DIR_NAME = 'val' 34 | 35 | 36 | IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'} 37 | 38 | 39 | def get_ckpts_dir(experiment_dir, ensure_exists=True): 40 | ckpts_p = os.path.join(experiment_dir, CKPTS_DIR_NAME) 41 | if ensure_exists: 42 | assert_exc(os.path.isdir(ckpts_p), 'Not found: {}'.format(ckpts_p)) 43 | return ckpts_p 44 | 45 | 46 | def get_experiment_dir(log_dir, experiment_spec): 47 | """ 48 | experiment_spec: if is a logdate, find correct full path in log_dir, otherwise assume logdir/experiment_spec exists 49 | :return experiment dir, no slash at the end. containing /ckpts 50 | """ 51 | if logdir_helpers.is_log_date(experiment_spec): # assume that log_dir/restore* matches 52 | assert_exc(log_dir is not None, 'Can only infer experiment_dir from log_date if log_dir is not None') 53 | restore_dir_glob = os.path.join(log_dir, experiment_spec + '*') 54 | restore_dir_possible = glob.glob(restore_dir_glob) 55 | assert_exc(len(restore_dir_possible) == 1, 'Expected one match for {}, got {}'.format( 56 | restore_dir_glob, restore_dir_possible)) 57 | experiment_spec = restore_dir_possible[0] 58 | else: 59 | experiment_spec = os.path.join(log_dir, experiment_spec) 60 | experiment_dir = experiment_spec.rstrip(os.path.sep) 61 | assert_exc(os.path.isdir(experiment_dir), 'Invalid experiment_dir: {}'.format(experiment_dir)) 62 | return experiment_dir 63 | 64 | 65 | def img_name(img_p): 66 | return os.path.splitext(os.path.basename(img_p))[0] 67 | 68 | 69 | def has_image_ext(p): 70 | return os.path.splitext(p)[1].lower() in IMG_EXTENSIONS 71 | 72 | -------------------------------------------------------------------------------- /src/helpers/rolling_buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import numpy as np 20 | from fjcommon.assertions import assert_exc 21 | 22 | import pytorch_ext as pe 23 | 24 | 25 | class BufferSizeMismatch(Exception): 26 | pass 27 | 28 | 29 | class RollingBufferHistogram(object): 30 | """ 31 | Buffer that sets shape to be number of entries of first v it receives. 32 | Has function `plot` to plot histogram of data. 33 | """ 34 | def __init__(self, buffer_size, name=None): 35 | self._name = name or 'RollingBufferHistogram' 36 | self._buffer = None 37 | self._buffer_size = buffer_size 38 | self._idx = 0 39 | self._filled_idx = 0 40 | 41 | def add(self, v): 42 | v = pe.tensor_to_np(v) 43 | num_values = np.prod(v.shape) 44 | if self._buffer is None: 45 | print(f'Creating {v.dtype} buffer for {self._name}: {self._buffer_size}x{num_values}') 46 | self._buffer = np.zeros((self._buffer_size, num_values), dtype=v.dtype) 47 | assert_exc(self._buffer.shape[1] == num_values, (self._buffer.shape, v.shape, num_values), BufferSizeMismatch) 48 | self._buffer[self._idx, :] = v.flatten() 49 | self._idx = (self._idx + 1) % self._buffer_size 50 | self._filled_idx = min(self._filled_idx + 1, self._buffer_size) 51 | 52 | def get_buffer(self): 53 | return self._buffer[:self._filled_idx, :] 54 | 55 | def plot(self, bins='auto', most_mass=0): 56 | counts, bins = np.histogram(self._buffer[:self._filled_idx], bins) 57 | counts = counts / self._filled_idx # normalize it by number of entries 58 | idx_min, idx_max = _most_mass_indices(counts, most_mass) 59 | return bins[idx_min:idx_max], counts[idx_min:idx_max] 60 | 61 | 62 | def _most_mass_indices(a, mass=0): 63 | total_mass = np.sum(a) 64 | threshold = total_mass * mass 65 | indices, = (a > threshold).nonzero() # non zero returns a tuple of len a.ndim -> unpack! 66 | return indices[0], indices[-1] 67 | 68 | 69 | -------------------------------------------------------------------------------- /src/helpers/saver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import os 20 | import time 21 | import re 22 | import shutil 23 | 24 | import pytorch_ext as pe 25 | from os.path import basename 26 | import torch 27 | from torch.optim import optimizer 28 | from fjcommon.no_op import NoOp 29 | from fjcommon import timer 30 | from fjcommon.assertions import assert_exc 31 | 32 | 33 | class _CheckpointTracker(object): 34 | """ out_dir is usally set via set_out_dir """ 35 | def __init__(self, out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp'): 36 | assert len(tmp_postfix) 37 | assert '.' in tmp_postfix 38 | m = re.search(r'{:0(\d+?)d}', ckpt_name_fmt) 39 | assert m, 'Expected ckpt_name_fmt to have an int specifier such as or {:09d} or {:010d}.' 40 | max_itr = 10 ** int(m.group(1)) - 1 41 | if max_itr < 10000000: # ten million, should be enough 42 | print(f'Maximum iteration supported: {max_itr}') 43 | assert os.sep not in ckpt_name_fmt 44 | self.ckpt_name_fmt = ckpt_name_fmt 45 | self.ckpt_prefix = ckpt_name_fmt.split('{')[0] 46 | assert len(self.ckpt_prefix), 'Expected ckpt_name_fmt to start with a prefix before the format part!' 47 | self.tmp_postfix = tmp_postfix 48 | 49 | self._out_dir = None 50 | if out_dir is not None: 51 | self.set_out_dir(out_dir) 52 | 53 | def set_out_dir(self, out_dir): 54 | assert self._out_dir is None 55 | os.makedirs(out_dir, exist_ok=True) 56 | self._out_dir = out_dir 57 | 58 | def get_all_ckpts(self): 59 | """ 60 | :return: All checkpoints in `self._out_dir`, sorted ascendingly by global_step. 61 | """ 62 | return [os.path.join(self._out_dir, f) 63 | for f in sorted(os.listdir(self._out_dir)) 64 | if f.startswith(self.ckpt_prefix)] 65 | 66 | def itr_ckpt(self): 67 | for ckpt_p in self.get_all_ckpts(): 68 | yield self.get_itr_from_ckpt_p(ckpt_p), ckpt_p 69 | 70 | def get_ckpt_for_itr(self, itr): 71 | """ 72 | Gets ckpt_itrc where itrc <= itr, i.e., the latest ckpt before `itr`. 73 | Special values: itr == -1 -> newest ckpt 74 | """ 75 | ckpts = list(self.itr_ckpt()) 76 | assert_exc(len(ckpts) > 0, 'No ckpts found in {}'.format(self._out_dir)) 77 | if itr == -1: 78 | return ckpts[-1] 79 | first_itrc, _ = ckpts[0] 80 | assert_exc(first_itrc <= itr, 'Earliest ckpt {} is after {}'.format(first_itrc, itr)) 81 | for itrc, ckpt_p in reversed(ckpts): 82 | if itrc <= itr: 83 | return itrc, ckpt_p 84 | raise ValueError('Unexpected, {}, {}'.format(itr, ckpts)) 85 | 86 | def get_latest_ckpt(self): 87 | """ 88 | :return: Most recent checkpoint. May be a temporary checkpoint. 89 | """ 90 | return self.get_all_ckpts()[-1] 91 | 92 | def get_lastest_persistent_ckpt(self): 93 | """ 94 | :return: Most recent persistent checkpoint. May be a temporary checkpoint. 95 | """ 96 | candidates = [p for p in self.get_all_ckpts() if not p.endswith(self.tmp_postfix)] 97 | if len(candidates) == 0: 98 | raise ValueError('No persistent checkpoints') 99 | return candidates[-1] 100 | 101 | def _get_out_p(self, global_step, is_tmp): 102 | postfix = self.tmp_postfix if is_tmp else '' 103 | return os.path.join(self._out_dir, self.ckpt_name_fmt.format(global_step) + postfix) 104 | 105 | def get_itr_from_ckpt_p(self, ckpt_p): 106 | file_name = os.path.splitext(os.path.basename(ckpt_p))[0] 107 | assert self.ckpt_prefix in file_name 108 | itr_part = file_name.replace(self.ckpt_prefix, '') 109 | itr_part_digits_only = int(''.join(c for c in itr_part if c.isdigit())) 110 | return itr_part_digits_only 111 | 112 | 113 | 114 | class Saver(_CheckpointTracker): 115 | """ 116 | Saves ckpts: 117 | - ckpt_XXXXXXXX.pt.tmp 118 | If keep_tmp_last=None: 119 | Every `keep_every`-th ckpt is renamed to 120 | - ckpt_XXXXXXXX.pt 121 | and kept, the intermediate ones are removed. We call this a persistent checkpoint. 122 | else: 123 | Let C be the most recent persistent checkpoint. 124 | In addition to C being kept, the last `keep_tmp_last` temporary checkpoints before C are also kept. 125 | This means that always `keep_tmp_last` more checkpoints are kept than if keep_tmp_last=None 126 | """ 127 | def __init__(self, 128 | keep_tmp_itr: int, keep_every=10, keep_tmp_last=None, 129 | out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp', 130 | verbose=False): 131 | """ 132 | :param keep_every: keep every `keep_every`-th checkpoint, making it a persistent checkpoint 133 | :param keep_tmp_itr: keep checkpoint every `keep_tmp_itr` iterations. 134 | :param keep_tmp_last: Also keep the last `keep_tmp_last` temporary checkpoints before a persistent checkpoint. 135 | :param ckpt_name_fmt: filename, must include a format spec and some prefix before the format 136 | :param tmp_postfix: non-empty string to append to temporary checkpoints 137 | :param verbose: if True, print rename and remove info. 138 | """ 139 | self.keep_every = keep_every 140 | self.keep_tmp_last = keep_tmp_last 141 | self.keep_tmp_itr = keep_tmp_itr 142 | self.ckpts_since_last_permanent = 0 143 | self.print = print if verbose else NoOp 144 | self.save_time_acc = timer.TimeAccumulator() 145 | super(Saver, self).__init__(out_dir, ckpt_name_fmt, tmp_postfix) 146 | 147 | def save(self, modules, global_step, force=False): 148 | """ 149 | Save iff (force given or global_step % keep_tmp_itr == 0) 150 | :param modules: dictionary name -> nn.Module 151 | :param global_step: current step 152 | :return: bool, Whether previous checkpoints were removed 153 | """ 154 | if not (force or (global_step % self.keep_tmp_itr == 0)): 155 | return False 156 | assert self._out_dir is not None 157 | current_ckpt_p = self._save(modules, global_step) 158 | self.ckpts_since_last_permanent += 1 159 | if self.ckpts_since_last_permanent == self.keep_every: 160 | self._remove_previous(current_ckpt_p) 161 | self.ckpts_since_last_permanent = 0 162 | return True 163 | return False 164 | 165 | def _save(self, modules, global_step): 166 | out_p = self._get_out_p(global_step, is_tmp=True) 167 | with self.save_time_acc.execute(): 168 | torch.save({key: m.state_dict() for key, m in modules.items()}, out_p) 169 | return out_p 170 | 171 | def _remove_previous(self, current_ckpt_p): 172 | assert self.tmp_postfix in current_ckpt_p 173 | current_ckpt_p_non_tmp = current_ckpt_p.replace(self.tmp_postfix, '') 174 | self.print('{} -> {}'.format(basename(current_ckpt_p), basename(current_ckpt_p_non_tmp))) 175 | os.rename(current_ckpt_p, current_ckpt_p_non_tmp) 176 | keep_tmp_last = self.get_all_ckpts()[-(self.keep_tmp_last+1):] if self.keep_tmp_last else [] 177 | for p in self.get_all_ckpts(): 178 | if self.tmp_postfix in p and p not in keep_tmp_last: 179 | self.print('Removing {}...'.format(basename(p))) 180 | os.remove(p) 181 | self.print('Average save time: {:.3f}s'.format(self.save_time_acc.mean_time_spent())) 182 | 183 | 184 | class Restorer(_CheckpointTracker): 185 | def restore_latest_persistent(self, net): 186 | return self.restore(net, self.get_lastest_persistent_ckpt()) 187 | 188 | def restore(self, modules, ckpt_p, strict=True, restore_restart=False): 189 | print('Restoring {}... (strict={})'.format(ckpt_p, strict)) 190 | map_location = None if pe.CUDA_AVAILABLE else 'cpu' 191 | state_dicts = torch.load(ckpt_p, map_location=map_location) 192 | # --- 193 | for key, m in modules.items(): 194 | # optim implements its own load_state_dict which does not have the `strict` keyword... 195 | if isinstance(m, optimizer.Optimizer): 196 | if restore_restart: 197 | print('Not restoring optimizer, --restore_restart given...') 198 | else: 199 | try: 200 | m.load_state_dict(state_dicts[key]) 201 | except ValueError as e: 202 | raise ValueError('Error while restoring Optimizer:', str(e)) 203 | else: 204 | try: 205 | m.load_state_dict(state_dicts[key], strict=strict) 206 | except RuntimeError as e: # loading error 207 | for n, module in sorted(m.named_modules()): 208 | print(n, module) 209 | raise e 210 | return self.get_itr_from_ckpt_p(ckpt_p) 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /src/helpers/testset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import argparse 20 | import os 21 | import shutil 22 | from functools import total_ordering 23 | 24 | import numpy as np 25 | from fjcommon import os_ext 26 | from fjcommon.assertions import assert_exc 27 | 28 | from helpers.paths import has_image_ext, img_name 29 | 30 | 31 | @total_ordering 32 | class Testset(object): 33 | """ 34 | Class the holds a reference to paths of images inside a folder 35 | """ 36 | def __init__(self, root_dir_or_img, max_imgs=None, skip_hidden=False, append_id=None): 37 | """ 38 | :param root_dir_or_img: Either a directory with images or the path of a single image. 39 | :param max_imgs: If given, subsample deterministically to only contain max_imgs 40 | :param skip_hidden: If given, skip images starting with '.' 41 | :param append_id: If given, append `append_id` to self.id 42 | :raises ValueError if root_dir is not a directory or does not contain images 43 | """ 44 | self.root_dir_or_img = root_dir_or_img 45 | 46 | if os.path.isdir(root_dir_or_img): 47 | root_dir = root_dir_or_img 48 | self.name = os.path.basename(root_dir.rstrip('/')) 49 | self.ps = sorted(p for p in os_ext.listdir_paths(root_dir) if has_image_ext(p)) 50 | if skip_hidden: 51 | self.ps = self._filter_hidden(self.ps) 52 | if max_imgs and max_imgs < len(self.ps): 53 | print('Subsampling to use {} imgs of {}...'.format(max_imgs, self.name)) 54 | idxs = np.linspace(0, len(self.ps) - 1, max_imgs, dtype=np.int) 55 | self.ps = np.array(self.ps)[idxs].tolist() 56 | assert len(self.ps) == max_imgs 57 | assert_exc(len(self.ps) > 0, 'No images found in {}'.format(root_dir), ValueError) 58 | self.id = '{}_{}'.format(self.name, len(self.ps)) 59 | self._str = 'Testset({}): in {}, {} images'.format(self.name, root_dir, len(self.ps)) 60 | else: 61 | img = root_dir_or_img 62 | assert_exc(os.path.isfile(img), 'Does not exist: {}'.format(img), FileNotFoundError) 63 | self.name = os.path.basename(img) 64 | self.ps = [img] 65 | self.id = img 66 | self._str = 'Testset([{}]): 1 image'.format(self.name) 67 | if append_id: 68 | self.id += append_id 69 | 70 | def search_path(self): 71 | return self.root_dir_or_img 72 | 73 | def filter_filenames(self, filter_filenames): 74 | filename = lambda p: os.path.splitext(os.path.basename(p))[0] 75 | self.ps = [p for p in self.ps 76 | if filename(p) in filter_filenames] 77 | assert_exc(len(self.ps) > 0, 'No files after filtering for {}'.format(filter_filenames)) 78 | 79 | def iter_img_names(self): 80 | return map(img_name, self.ps) 81 | 82 | def iter_orig_paths(self): 83 | return self.ps 84 | 85 | def __len__(self): 86 | return len(self.ps) 87 | 88 | def __str__(self): 89 | return self._str 90 | 91 | def __repr__(self): 92 | return 'Testset({}): {} paths'.format(self.name, len(self.ps)) 93 | 94 | # to enable sorting 95 | def __lt__(self, other): 96 | return self.id < other.id 97 | 98 | @staticmethod 99 | def _filter_hidden(ps): 100 | count_a = len(ps) 101 | ps = [p for p in ps if not os.path.basename(p).startswith('.')] 102 | count_b = len(ps) 103 | if count_b < count_a: 104 | print(f'NOTE: Filtered {count_a - count_b} hidden file(s).') 105 | return ps 106 | 107 | 108 | def main(): 109 | p = argparse.ArgumentParser('Copy deterministic subset of images to another directory.') 110 | p.add_argument('root_dir') 111 | p.add_argument('max_imgs', type=int) 112 | p.add_argument('out_dir') 113 | p.add_argument('--dry') 114 | p.add_argument('--verbose', '-v', action='store_true') 115 | flags = p.parse_args() 116 | os.makedirs(flags.out_dir, exist_ok=True) 117 | 118 | t = Testset(flags.root_dir, flags.max_imgs) 119 | 120 | def cp(p1, p2): 121 | if os.path.isfile(p2): 122 | print('Exists, skipping: {}'.format(p2)) 123 | return 124 | if flags.verbose: 125 | print('cp {} -> {}'.format(p1, p2)) 126 | if not flags.dry: 127 | shutil.copy(p1, p2) 128 | 129 | for p in t.iter_orig_paths(): 130 | cp(p, os.path.join(flags.out_dir, os.path.basename(p))) -------------------------------------------------------------------------------- /src/import_train_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import argparse 20 | import multiprocessing 21 | import os 22 | import random 23 | import shutil 24 | import time 25 | import warnings 26 | from os.path import join 27 | 28 | import PIL 29 | import numpy as np 30 | import skimage.color 31 | from PIL import Image 32 | 33 | from helpers.paths import IMG_EXTENSIONS 34 | 35 | # TO SPEED THINGS UP: run on CPU cluster! We use task_array for this. 36 | # task_array is not released. It's used by us to batch process on our servers. Feel free to replace with whatever you 37 | # use. Make sure to set NUM_TASKS (number of concurrent processes) and set job_enumerate to a function that takes an 38 | # iterable and only yield elements to be processed by the current process. 39 | try: 40 | from task_array import NUM_TASKS, job_enumerate 41 | except ImportError: 42 | NUM_TASKS = 1 43 | job_enumerate = enumerate 44 | 45 | warnings.filterwarnings("ignore") 46 | 47 | 48 | random.seed(123) 49 | 50 | 51 | _NUM_PROCESSES = int(os.environ.get('NUM_PROCESS', 16)) 52 | _DEFAULT_MAX_SCALE = 0.8 53 | 54 | 55 | get_fn = lambda p_: os.path.splitext(os.path.basename(p_))[0] 56 | 57 | 58 | def iter_images(root_dir, num_folder_levels=0): 59 | fns = sorted(os.listdir(root_dir)) 60 | for fn in fns: 61 | if num_folder_levels > 0: 62 | dir_p = os.path.join(root_dir, fn) 63 | if os.path.isdir(dir_p): 64 | print('Recursing into', fn) 65 | yield from iter_images(dir_p, num_folder_levels - 1) 66 | continue 67 | _, ext = os.path.splitext(fn) 68 | if ext.lower() in IMG_EXTENSIONS: 69 | yield os.path.join(root_dir, fn) 70 | 71 | 72 | class Helper(object): 73 | def __init__(self, out_dir_clean, out_dir_discard, min_res: int): 74 | print(f'Creating {out_dir_clean}, {out_dir_discard}...') 75 | os.makedirs(out_dir_clean, exist_ok=True) 76 | os.makedirs(out_dir_discard, exist_ok=True) 77 | self.out_dir_clean = out_dir_clean 78 | self.out_dir_discard = out_dir_discard 79 | 80 | print('Getting images already processed...', end=" ", flush=True) 81 | self.images_cleaned = set(map(get_fn, os.listdir(out_dir_clean))) 82 | self.images_discarded = set(map(get_fn, os.listdir(out_dir_discard))) 83 | print(f'-> Found {len(self.images_cleaned) + len(self.images_discarded)} images.') 84 | 85 | self.min_res = min_res 86 | 87 | def process_all_in(self, input_dir): 88 | images_dl = iter_images(input_dir) # generator of paths 89 | 90 | # files this job should compress 91 | files_of_job = [p for _, p in job_enumerate(images_dl)] 92 | # files that were compressed already by somebody (i.e. this job earlier) 93 | processed_already = self.images_cleaned | self.images_discarded 94 | # resulting files to be compressed 95 | files_of_job = [p for p in files_of_job if get_fn(p) not in processed_already] 96 | 97 | N = len(files_of_job) 98 | if N == 0: 99 | print('Everything processed / nothing to process.') 100 | return 101 | 102 | num_process = 2 if NUM_TASKS > 1 else _NUM_PROCESSES 103 | print(f'Processing {N} images using {num_process} processes in {NUM_TASKS} tasks...') 104 | 105 | start = time.time() 106 | predicted_time = None 107 | with multiprocessing.Pool(processes=num_process) as pool: 108 | for i, clean in enumerate(pool.imap_unordered(self.process, files_of_job)): 109 | if i > 0 and i % 100 == 0: 110 | time_per_img = (time.time() - start) / (i + 1) 111 | time_remaining = time_per_img * (N - i) 112 | if not predicted_time: 113 | predicted_time = time_remaining 114 | print(f'\r{time_per_img:.2e} s/img | ' 115 | f'{i / N * 100:.1f}% | ' 116 | f'{time_remaining / 60:.1f} min remaining', end='', flush=True) 117 | 118 | def process(self, p_in): 119 | fn, ext = os.path.splitext(os.path.basename(p_in)) 120 | if fn in self.images_cleaned: 121 | return 1 122 | if fn in self.images_discarded: 123 | return 0 124 | try: 125 | im = Image.open(p_in) 126 | except OSError as e: 127 | print(f'\n*** Error while opening {p_in}: {e}') 128 | return 0 129 | im_out = random_resize_or_discard(im, self.min_res) 130 | if im_out is not None: 131 | p_out = join(self.out_dir_clean, fn + '.png') # Make sure to use .png! 132 | im_out.save(p_out) 133 | return 1 134 | else: 135 | p_out = join(self.out_dir_discard, os.path.basename(p_in)) 136 | shutil.copy(p_in, p_out) 137 | return 0 138 | 139 | 140 | def random_resize_or_discard(im, min_res: int): 141 | """Randomly resize image with `random_resize` and check if it should be discarded.""" 142 | im_resized = random_resize(im, min_res) 143 | if im_resized is None: 144 | return None 145 | if should_discard(im_resized): 146 | return None 147 | return im_resized 148 | 149 | 150 | def random_resize(im, min_res: int, max_scale=_DEFAULT_MAX_SCALE): 151 | """Scale longer side to `min_res`, but only if that scales by <= max_scale.""" 152 | W, H = im.size 153 | D = min(W, H) 154 | scale_min = min_res / D 155 | # Image is too small to downscale by a factor smaller MAX_SCALE. 156 | if scale_min > max_scale: 157 | return None 158 | 159 | # Get a random scale for new size. 160 | scale = random.uniform(scale_min, max_scale) 161 | new_size = round(W * scale), round(H * scale) 162 | try: 163 | # Using LANCZOS! 164 | return im.resize(new_size, resample=PIL.Image.LANCZOS) 165 | except OSError as e: # Happens for corrupted images 166 | print('*** Caught im.resize error', e) 167 | return None 168 | 169 | 170 | def should_discard(im): 171 | """Return true iff the image is high in saturation or value, or not RGB.""" 172 | # Modes found in train_0: 173 | # Counter({'RGB': 152326, 'L': 4149, 'CMYK': 66}) 174 | if im.mode != 'RGB': 175 | return True 176 | im_rgb = np.array(im) 177 | im_hsv = skimage.color.rgb2hsv(im_rgb) 178 | mean_hsv = np.mean(im_hsv, axis=(0, 1)) 179 | _, s, v = mean_hsv 180 | if s > 0.9: 181 | return True 182 | if v > 0.8: 183 | return True 184 | return False 185 | 186 | 187 | def main(): 188 | p = argparse.ArgumentParser() 189 | p.add_argument('base_dir', 190 | help='Directory of images, or directory of DIRS.') 191 | p.add_argument('dirs', nargs='*', 192 | help='If given, must be subdirectories in BASE_DIR. Will be processed. ' 193 | 'If not given, assume BASE_DIR is already a directory of images.') 194 | p.add_argument('--out_dir_clean', required=True) 195 | p.add_argument('--out_dir_discard', required=True) 196 | p.add_argument('--resolution', type=int, default=512, 197 | help='Randomly rescale each image to be at least ' 198 | 'RANDOM_SCALE long on the longer side.') 199 | 200 | flags = p.parse_args() 201 | 202 | # If --dirs not given, just assume `base_dir` is already the directory of images. 203 | if not flags.dirs: 204 | flags.dirs = [os.path.basename(flags.base_dir)] 205 | flags.base_dir = os.path.dirname(flags.base_dir) 206 | 207 | h = Helper(flags.out_dir_clean, flags.out_dir_discard, flags.resolution) 208 | for i, d in enumerate(flags.dirs): 209 | print(f'*** {d}: {i}/{len(flags.dirs)}') 210 | h.process_all_in(join(flags.base_dir, d)) 211 | 212 | print('\n\nDONE') # For cluster logs. 213 | 214 | 215 | if __name__ == '__main__': 216 | main() 217 | -------------------------------------------------------------------------------- /src/import_train_images_v1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import os 20 | import random 21 | from os.path import join 22 | import PIL 23 | from PIL import Image 24 | import skimage.color 25 | import numpy as np 26 | import argparse 27 | import warnings 28 | # task_array is not released. It's used by us to batch process on our servers. Feel free to replace with whatever you 29 | # use. Make sure to set NUM_TASKS (number of concurrent processes) and set job_enumerate to a function that takes an 30 | # iterable and only yield elements to be processed by the current process. 31 | try: 32 | from task_array import NUM_TASKS, job_enumerate 33 | except ImportError: 34 | NUM_TASKS = 1 35 | job_enumerate = enumerate 36 | 37 | warnings.filterwarnings("ignore") 38 | 39 | 40 | QUALITY = 95 41 | MAX_SCALE = 0.95 42 | 43 | 44 | def main(): 45 | p = argparse.ArgumentParser() 46 | p.add_argument('base_dir') 47 | p.add_argument('dirs', nargs='+') 48 | p.add_argument('--out_dir_clean', required=True) 49 | p.add_argument('--out_dir_discard', required=True) 50 | 51 | p.add_argument('--resolution', '-r', type=str, default='768', help='can be randX_Y') 52 | p.add_argument('--seed') 53 | flags = p.parse_args() 54 | 55 | if flags.seed: 56 | print('Seeding {}'.format(flags.seed)) 57 | random.seed(flags.seed) 58 | 59 | for d in flags.dirs: 60 | process(join(flags.base_dir, d), flags.out_dir_clean, flags.out_dir_discard, flags.resolution) 61 | 62 | 63 | def get_res(res) -> int: 64 | try: 65 | return int(res) 66 | except ValueError: 67 | pass 68 | if not res.startswith('rand'): 69 | raise ValueError('Expected res to be either int or `randX_Y`') 70 | X, Y = map(int, res.replace('rand', '').split('_')) 71 | return random.randint(X, Y) 72 | 73 | 74 | def process(input_dir, out_dir_clean, out_dir_discard, res: str): 75 | os.makedirs(out_dir_clean, exist_ok=True) 76 | os.makedirs(out_dir_discard, exist_ok=True) 77 | 78 | images_cleaned = set(os.listdir(out_dir_clean)) 79 | images_discarded = set(os.listdir(out_dir_discard)) 80 | 81 | images_dl = os.listdir(input_dir) 82 | N = len(images_dl) // NUM_TASKS 83 | 84 | clean = 0 85 | discarded = 0 86 | for i, imfile in job_enumerate(images_dl): 87 | if imfile in images_cleaned: 88 | clean += 1 89 | continue 90 | if imfile in images_discarded: 91 | discarded += 1 92 | continue 93 | im = Image.open(join(input_dir, imfile)) 94 | res = get_res(res) 95 | im2 = resize_or_discard(im, res, should_clean=True) 96 | if im2 is not None: 97 | fn, ext = os.path.splitext(imfile) 98 | im2.save(join(out_dir_clean, fn + '.jpg'), quality=QUALITY) 99 | clean += 1 100 | else: 101 | im.save(join(out_dir_discard, imfile)) 102 | discarded += 1 103 | print(f'\r{os.path.basename(input_dir)} -> {os.path.basename(out_dir_clean)} // ' 104 | f'Resized: {clean}/{N}; Discarded: {discarded}/{N}', end='') 105 | # Done 106 | print(f'\n{os.path.basename(input_dir)} -> {os.path.basename(out_dir_clean)} // ' 107 | f'Resized: {clean}/{N}; Discarded: {discarded}/{N}') 108 | 109 | 110 | def resize_or_discard(im, res: int, verbose=False, should_clean=True): 111 | im2 = resize(im, res, verbose) 112 | if im2 is None: 113 | return None 114 | if should_clean and should_discard(im2, verbose): 115 | return None 116 | return im2 117 | 118 | 119 | def resize(im, res, verbose=False, max_scale=MAX_SCALE): 120 | W, H = im.size 121 | D = max(W, H) 122 | s = float(res) / D 123 | if max_scale and s > max_scale: 124 | if verbose: 125 | print('Too big: {}'.format((W, H))) 126 | return None 127 | W2 = round(W * s) 128 | H2 = round(H * s) 129 | try: 130 | return im.resize((W2, H2), resample=PIL.Image.BICUBIC) 131 | except OSError as e: 132 | print(e) 133 | return None 134 | 135 | 136 | def should_discard(im, verbose=False): 137 | im_rgb = np.array(im) 138 | if im_rgb.ndim != 3 or im_rgb.shape[2] != 3: 139 | if verbose: 140 | print('Invalid shape: {}'.format(im_rgb.shape)) 141 | return True 142 | im_hsv = skimage.color.rgb2hsv(im_rgb) 143 | mean_hsv = np.mean(im_hsv, axis=(0, 1)) 144 | h, s, v = mean_hsv 145 | if s > 0.9: 146 | if verbose: 147 | print('Invalid s: {}'.format(s)) 148 | return True 149 | if v > 0.8: 150 | if verbose: 151 | print('Invalid v: {}'.format(v)) 152 | return True 153 | return False 154 | 155 | 156 | def get_hsv(im): 157 | im_rgb = np.array(im) 158 | im_hsv = skimage.color.rgb2hsv(im_rgb) 159 | mean_hsv = np.mean(im_hsv, axis=(0, 1)) 160 | h, s, v = mean_hsv 161 | return h, s, v 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /src/l3c.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import torch.backends.cudnn 20 | torch.backends.cudnn.benchmark = False 21 | torch.backends.cudnn.deterministic = True 22 | 23 | import pytorch_ext as pe 24 | 25 | import argparse 26 | 27 | from test.multiscale_tester import MultiscaleTester, EncodeError, DecodeError 28 | 29 | 30 | # throws an exception if no backend available! 31 | from torchac import torchac 32 | 33 | 34 | class _FakeFlags(object): 35 | def __init__(self, flags): 36 | self.flags = flags 37 | 38 | def __getattr__(self, item): 39 | try: 40 | return self.flags.__dict__[item] 41 | except KeyError: 42 | return None 43 | 44 | 45 | def parse_device_flag(flag): 46 | print(f'Status: ' 47 | f'torchac-backend-gpu available: {torchac.CUDA_SUPPORTED} // ' 48 | f'torchac-backend-cpu available: {torchac.CPU_SUPPORTED} // ' 49 | f'CUDA available: {pe.CUDA_AVAILABLE}') 50 | if flag == 'auto': 51 | if torchac.CUDA_SUPPORTED and pe.CUDA_AVAILABLE: 52 | flag = 'gpu' 53 | elif torchac.CPU_SUPPORTED: 54 | flag = 'cpu' 55 | else: 56 | raise ValueError('No suitable backend found!') 57 | 58 | if flag == 'cpu' and not torchac.CPU_SUPPORTED: 59 | raise ValueError('torchac-backend-cpu is not available. Please install.') 60 | if flag == 'gpu' and not torchac.CUDA_SUPPORTED: 61 | raise ValueError('torchac-backend-gpu is not available. Please install.') 62 | if flag == 'gpu' and not pe.CUDA_AVAILABLE: 63 | raise ValueError('Selected torchac-backend-gpu but CUDA not available!') 64 | 65 | assert flag in ('gpu', 'cpu') 66 | if flag == 'gpu' and not pe.CUDA_AVAILABLE: 67 | raise ValueError('Selected GPU backend but cuda is not available!') 68 | 69 | pe.CUDA_AVAILABLE = flag == 'gpu' 70 | pe.set_device(pe.CUDA_AVAILABLE) 71 | print(f'*** Using torchac-backend-{flag}; did set CUDA_AVAILABLE={pe.CUDA_AVAILABLE} and DEVICE={pe.DEVICE}') 72 | 73 | 74 | def main(): 75 | p = argparse.ArgumentParser(description='Encoder/Decoder for L3C') 76 | 77 | p.add_argument('log_dir', help='Directory of experiments.') 78 | p.add_argument('log_date', help='A log_date, such as 0104_1345.') 79 | 80 | p.add_argument('--device', type=str, choices=['auto', 'gpu', 'cpu'], default='auto', 81 | help='Select the device to run this code on, as mentioned in the README section "Selecting torchac". ' 82 | 'If DEVICE=auto, select torchac-backend depending on whether torchac-backend-gpu or -cpu is ' 83 | 'available. See function parse_device_flag for details.' 84 | 'If DEVICE=gpu or =cpu, force usage of this backend.') 85 | 86 | p.add_argument('--restore_itr', '-i', default=-1, type=int, 87 | help='Which iteration to restore. -1 means latest iteration. Default: -1') 88 | 89 | mode = p.add_subparsers(title='mode', dest='mode') 90 | 91 | enc = mode.add_parser('enc', help='Encode image. enc IMG_P OUT_P [--overwrite | -f]\n' 92 | ' IMG_P: Path to an Image, readable by PIL.\n' 93 | ' OUT_P: Path to where to save the bitstream.\n' 94 | ' OVERWRITE: If given, overwrite OUT_P.' 95 | 'Example:\n' 96 | ' python l3c.py LOG_DIR LOG_DATE enc some/img.jpg out/img.l3c') 97 | dec = mode.add_parser('dec', help='Decode image. dec IMG_P OUT_P_PNG\n' 98 | ' IMG_P: Path to an L3C-encoded Image, readable by PIL.\n' 99 | ' OUT_P_PNG: Path to where to save the decode image to, as a PNG.\n' 100 | 'Example:\n' 101 | ' python l3c.py LOG_DIR LOG_DATE dec out/img.l3c decoded.png') 102 | 103 | enc.add_argument('img_p') 104 | enc.add_argument('out_p') 105 | enc.add_argument('--overwrite', '-f', action='store_true') 106 | 107 | dec.add_argument('img_p') 108 | dec.add_argument('out_p_png') 109 | 110 | flags = p.parse_args() 111 | parse_device_flag(flags.device) 112 | 113 | print('Testing {} at {} ---'.format(flags.log_date, flags.restore_itr)) 114 | tester = MultiscaleTester(flags.log_date, _FakeFlags(flags), flags.restore_itr, l3c=True) 115 | 116 | if flags.mode == 'enc': 117 | try: 118 | tester.encode(flags.img_p, flags.out_p, flags.overwrite) 119 | except EncodeError as e: 120 | print('*** EncodeError:', e) 121 | else: 122 | try: 123 | tester.decode(flags.img_p, flags.out_p_png) 124 | except DecodeError as e: 125 | print('*** DecodeError:', e) 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/modules/__init__.py -------------------------------------------------------------------------------- /src/modules/edsr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | # Code adapted from src/model/common.py of this repo: 20 | # 21 | # https://github.com/thstkdgus35/EDSR-PyTorch 22 | # 23 | # Original license: 24 | # 25 | # MIT License 26 | # 27 | # Copyright (c) 2018 Sanghyun Son 28 | # 29 | # Permission is hereby granted, free of charge, to any person obtaining a copy 30 | # of this software and associated documentation files (the "Software"), to deal 31 | # in the Software without restriction, including without limitation the rights 32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | # copies of the Software, and to permit persons to whom the Software is 34 | # furnished to do so, subject to the following conditions: 35 | # 36 | # The above copyright notice and this permission notice shall be included in all 37 | # copies or substantial portions of the Software. 38 | # 39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | # SOFTWARE. 46 | 47 | import math 48 | 49 | import torch 50 | import torch.nn as nn 51 | 52 | class MeanShift(nn.Conv2d): 53 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 54 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 55 | std = torch.Tensor(rgb_std) 56 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 57 | self.weight.data.div_(std.view(3, 1, 1, 1)) 58 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 59 | self.bias.data.div_(std) 60 | self.requires_grad = False 61 | 62 | 63 | class ResBlock(nn.Module): 64 | def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), atrous=False): 65 | 66 | super(ResBlock, self).__init__() 67 | m = [] 68 | _repr = [] 69 | for i in range(2): 70 | atrous_rate = 1 if (not atrous or i == 0) else atrous 71 | m.append(conv(n_feats, n_feats, kernel_size, rate=atrous_rate, bias=bias)) 72 | _repr.append(f'Conv({n_feats}x{kernel_size}' + (f';A*{atrous_rate})' if atrous_rate != 1 else '') + ')') 73 | if bn: 74 | m.append(nn.BatchNorm2d(n_feats)) 75 | _repr.append(f'BN({n_feats})') 76 | if i == 0: 77 | m.append(act) 78 | _repr.append(f'Act') 79 | 80 | self.body = nn.Sequential(*m) 81 | self._repr = '/'.join(_repr) 82 | 83 | def forward(self, x): 84 | res = self.body(x) 85 | res += x 86 | return res 87 | 88 | def __repr__(self): 89 | return f'ResBlock({self._repr})' 90 | 91 | 92 | class Upsampler(nn.Sequential): 93 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 94 | 95 | m = [] 96 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 97 | for _ in range(int(math.log(scale, 2))): 98 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 99 | m.append(nn.PixelShuffle(2)) 100 | if bn: m.append(nn.BatchNorm2d(n_feats)) 101 | 102 | if act == 'relu': 103 | m.append(nn.ReLU(True)) 104 | elif act == 'prelu': 105 | m.append(nn.PReLU(n_feats)) 106 | 107 | elif scale == 3: 108 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 109 | m.append(nn.PixelShuffle(3)) 110 | if bn: m.append(nn.BatchNorm2d(n_feats)) 111 | 112 | if act == 'relu': 113 | m.append(nn.ReLU(True)) 114 | elif act == 'prelu': 115 | m.append(nn.PReLU(n_feats)) 116 | else: 117 | raise NotImplementedError 118 | 119 | super(Upsampler, self).__init__(*m) 120 | 121 | -------------------------------------------------------------------------------- /src/modules/head.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | from torch import nn 20 | 21 | from modules import edsr 22 | from pytorch_ext import default_conv as conv 23 | 24 | 25 | 26 | class RGBHead(nn.Module): 27 | """ Go from 3 channels (RGB) to Cf channels, also normalize RGB """ 28 | def __init__(self, config_ms): 29 | super(RGBHead, self).__init__() 30 | assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID' 31 | self.head = nn.Sequential( 32 | edsr.MeanShift(0, (0., 0., 0.), (128., 128., 128.)), 33 | Head(config_ms, Cin=3)) 34 | self._repr = 'MeanShift//Head(C=3)' 35 | 36 | def __repr__(self): 37 | return f'RGBHead({self._repr})' 38 | 39 | def forward(self, x): 40 | return self.head(x) 41 | 42 | 43 | class Head(nn.Module): 44 | """ 45 | Go from Cin channels to Cf channels. 46 | For L3C, Cin=Cf, and this is the convolution yielding E^{s+1}_in in Fig. 2. 47 | 48 | """ 49 | def __init__(self, config_ms, Cin): 50 | super(Head, self).__init__() 51 | assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID' 52 | self.head = conv(Cin, config_ms.Cf, config_ms.kernel_size) 53 | self._repr = f'Conv({config_ms.Cf})' 54 | 55 | def __repr__(self): 56 | return f'Head({self._repr})' 57 | 58 | def forward(self, x): 59 | return self.head(x) 60 | 61 | 62 | -------------------------------------------------------------------------------- /src/modules/net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | File for Net, which contains encoder/decoder of MultiscaleNetwork 22 | 23 | """ 24 | from collections import namedtuple 25 | 26 | import torch 27 | from torch import nn 28 | 29 | import pytorch_ext as pe 30 | import vis.histogram_plot 31 | import vis.summarizable_module 32 | from dataloaders.images_loader import resize_bicubic_batch 33 | from modules import edsr 34 | from modules.quantizer import Quantizer 35 | 36 | EncOut = namedtuple('EncOut', ['bn', # NCH'W' 37 | 'bn_q', # quantized bn, NCH'W' 38 | 'S', # NCH'W', long 39 | 'L', # int 40 | 'F' # NCfH'W', float, before Q 41 | ]) 42 | DecOut = namedtuple('DecOut', ['F', # NCfHW 43 | ]) 44 | 45 | 46 | conv = pe.default_conv 47 | 48 | 49 | class Net(nn.Module): 50 | def __init__(self, config_ms, scale): 51 | super(Net, self).__init__() 52 | self.config_ms = config_ms 53 | self.enc = { 54 | 'EDSRLikeEnc': EDSRLikeEnc, 55 | 'BicubicSubsampling': BicubicDownsamplingEnc, 56 | }[config_ms.enc.cls](config_ms, scale) 57 | self.dec = { 58 | 'EDSRDec': EDSRDec 59 | }[config_ms.dec.cls](config_ms, scale) 60 | 61 | def forward(self, x): 62 | raise NotImplementedError() # Call .enc and .dec directly 63 | 64 | 65 | class BicubicDownsamplingEnc(vis.summarizable_module.SummarizableModule): 66 | def __init__(self, *_): 67 | super(BicubicDownsamplingEnc, self).__init__() 68 | # TODO: ugly 69 | self.rgb_mean = torch.tensor( 70 | [0.4488, 0.4371, 0.4040], dtype=torch.float32).reshape(3, 1, 1).mul(255.).to(pe.DEVICE) 71 | 72 | def forward(self, x): 73 | x = x + self.rgb_mean # back to 0...255 74 | x = x.clamp(0, 255.).round().type(torch.uint8) 75 | x = resize_bicubic_batch(x, 0.5).to(pe.DEVICE) 76 | sym = x.long() 77 | x = x.float() - self.rgb_mean 78 | x = x.detach() # make sure no gradients back to this point 79 | self.summarizer.register_images('train', {'input_subsampled': lambda: sym.type(torch.uint8)}) 80 | return EncOut(x, x, sym, 256, None) 81 | 82 | 83 | def new_levels(L, initial_levels): 84 | lo, hi = initial_levels 85 | levels = torch.linspace(lo, hi, L) 86 | return torch.tensor(levels, requires_grad=False) 87 | 88 | 89 | class EDSRLikeEnc(vis.summarizable_module.SummarizableModule): 90 | def __init__(self, config_ms, scale): 91 | super(EDSRLikeEnc, self).__init__() 92 | 93 | self.scale = scale 94 | self.config_ms = config_ms 95 | Cf = config_ms.Cf 96 | kernel_size = config_ms.kernel_size 97 | C, self.L = config_ms.q.C, config_ms.q.L 98 | 99 | n_resblock = config_ms.enc.num_blocks 100 | 101 | # Downsampling 102 | self.down = conv(Cf, Cf, kernel_size=5, stride=2) 103 | 104 | # Body 105 | m_body = [ 106 | edsr.ResBlock(conv, Cf, kernel_size, act=nn.ReLU(True)) 107 | for _ in range(n_resblock) 108 | ] 109 | m_body.append(conv(Cf, Cf, kernel_size)) 110 | self.body = nn.Sequential(*m_body) 111 | 112 | # to Quantizer 113 | to_q = [conv(Cf, C, 1)] 114 | if self.training: 115 | to_q.append( 116 | # start scale from 1, as 0 is RGB 117 | vis.histogram_plot.HistogramPlot('train', 'histo/enc_{}_after_1x1'.format(scale+1), buffer_size=10, 118 | num_inputs_to_buffer=1, per_channel=False)) 119 | self.to_q = nn.Sequential(*to_q) 120 | 121 | # We assume q.L levels, evenly distributed between q.levels_range[0] and q.levels_range[1] 122 | # In theory, the levels could be learned. But in this code, they are assumed to be fixed. 123 | levels_first, levels_last = config_ms.q.levels_range 124 | # Wrapping this in a nn.Parameter ensures it is copied to gpu when .to('cuda') is called 125 | self.levels = nn.Parameter(torch.linspace(levels_first, levels_last, self.L), requires_grad=False) 126 | self._extra_repr = 'Levels={}'.format(','.join(map('{:.1f}'.format, list(self.levels)))) 127 | self.q = Quantizer(self.levels, config_ms.q.sigma) 128 | 129 | def extra_repr(self): 130 | return self._extra_repr 131 | 132 | def quantize_x(self, x): 133 | _, x_hard, _ = self.q(x) 134 | return x_hard 135 | 136 | def forward(self, x): 137 | """ 138 | :param x: NCHW 139 | :return: 140 | """ 141 | x = self.down(x) 142 | x = self.body(x) + x 143 | F = x 144 | x = self.to_q(x) 145 | # assert self.summarizer is not None 146 | x_soft, x_hard, symbols_hard = self.q(x) 147 | # TODO(parallel): To support nn.DataParallel, this must be changed, as it not a tensor 148 | return EncOut(x_soft, x_hard, symbols_hard, self.L, F) 149 | 150 | 151 | class EDSRDec(nn.Module): 152 | def __init__(self, config_ms, scale): 153 | super(EDSRDec, self).__init__() 154 | 155 | self.scale = scale 156 | n_resblock = config_ms.dec.num_blocks 157 | 158 | Cf = config_ms.Cf 159 | kernel_size = config_ms.kernel_size 160 | C = config_ms.q.C 161 | 162 | after_q_kernel = 1 163 | self.head = conv(C, config_ms.Cf, after_q_kernel) 164 | m_body = [ 165 | edsr.ResBlock(conv, Cf, kernel_size, act=nn.ReLU(True)) 166 | for _ in range(n_resblock) 167 | ] 168 | 169 | m_body.append(conv(Cf, Cf, kernel_size)) 170 | self.body = nn.Sequential(*m_body) 171 | self.tail = edsr.Upsampler(conv, 2, Cf, act=False) 172 | 173 | def forward(self, x, features_to_fuse=None): 174 | """ 175 | :param x: NCHW 176 | :return: 177 | """ 178 | x = self.head(x) 179 | if features_to_fuse is not None: 180 | x = x + features_to_fuse 181 | x = self.body(x) + x 182 | x = self.tail(x) 183 | # TODO(parallel): To support nn.DataParallel, this must be changed, as it not a tensor 184 | return DecOut(x) 185 | 186 | 187 | -------------------------------------------------------------------------------- /src/modules/prob_clf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import torch 20 | from torch import nn, nn as nn 21 | 22 | import pytorch_ext as pe 23 | from criterion.logistic_mixture import non_shared_get_Kp 24 | 25 | conv = pe.default_conv 26 | 27 | 28 | 29 | class AtrousProbabilityClassifier(nn.Module): 30 | def __init__(self, config_ms, C=3, atrous_rates_str='1,2,4'): 31 | super(AtrousProbabilityClassifier, self).__init__() 32 | 33 | K = config_ms.prob.K 34 | Kp = non_shared_get_Kp(K, C) 35 | 36 | self.atrous = StackedAtrousConvs(atrous_rates_str, config_ms.Cf, Kp, 37 | kernel_size=config_ms.kernel_size) 38 | self._repr = f'C={C}; K={K}; Kp={Kp}; rates={atrous_rates_str}' 39 | 40 | def __repr__(self): 41 | return f'AtrousProbabilityClassifier({self._repr})' 42 | 43 | def forward(self, x): 44 | """ 45 | :param x: NCfHW 46 | :return: NKpHW 47 | """ 48 | return self.atrous(x) 49 | 50 | 51 | class StackedAtrousConvs(nn.Module): 52 | def __init__(self, atrous_rates_str, Cin, Cout, bias=True, kernel_size=3): 53 | super(StackedAtrousConvs, self).__init__() 54 | atrous_rates = self._parse_atrous_rates_str(atrous_rates_str) 55 | self.atrous = nn.ModuleList( 56 | [conv(Cin, Cin, kernel_size, rate=rate) for rate in atrous_rates]) 57 | self.lin = conv(len(atrous_rates) * Cin, Cout, 1, bias=bias) 58 | self._extra_repr = 'rates={}'.format(atrous_rates) 59 | 60 | @staticmethod 61 | def _parse_atrous_rates_str(atrous_rates_str): 62 | # expected to either be an int or a comma-separated string 1,2,4 63 | if isinstance(atrous_rates_str, int): 64 | return [atrous_rates_str] 65 | else: 66 | return list(map(int, atrous_rates_str.split(','))) 67 | 68 | def extra_repr(self): 69 | return self._extra_repr 70 | 71 | def forward(self, x): 72 | x = torch.cat([atrous(x) for atrous in self.atrous], dim=1) 73 | x = self.lin(x) 74 | return x -------------------------------------------------------------------------------- /src/modules/quantizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | Based on our TensorFlow implementation for the CVPR 2018 paper 22 | 23 | "Conditional Probability Models for Deep Image Compression" 24 | 25 | This is a PyTorch implementation of that quantization layer. 26 | 27 | https://github.com/fab-jul/imgcomp-cvpr/blob/master/code/quantizer.py 28 | 29 | """ 30 | import torch 31 | import torch.nn.functional as F 32 | from torch import nn 33 | 34 | 35 | SIGMA_HARD = 1e7 36 | 37 | 38 | def to_sym(x, x_min, x_max, L): 39 | sym_range = x_max - x_min 40 | bin_size = sym_range / (L-1) 41 | return x.clamp(x_min, x_max).sub(x_min).div(bin_size).round().long() 42 | 43 | 44 | def to_bn(S, x_min, x_max, L): 45 | sym_range = x_max - x_min 46 | bin_size = sym_range / (L-1) 47 | return S.float().mul(bin_size).add(x_min) 48 | 49 | 50 | class Quantizer(nn.Module): 51 | def __init__(self, levels, sigma=1.0): 52 | super(Quantizer, self).__init__() 53 | assert levels.dim() == 1, 'Expected 1D levels, got {}'.format(levels) 54 | self.levels = levels 55 | self.sigma = sigma 56 | self.L = self.levels.size()[0] 57 | 58 | def __repr__(self): 59 | return '{}(sigma={})'.format( 60 | self._get_name(), self.sigma) 61 | 62 | def forward(self, x): 63 | """ 64 | :param x: NCHW 65 | :return:, x_soft, symbols 66 | """ 67 | assert x.dim() == 4, 'Expected NCHW, got {}'.format(x.size()) 68 | N, C, H, W = x.shape 69 | # make x into NCm1, where m=H*W 70 | x = x.view(N, C, H*W, 1) 71 | # NCmL, d[..., l] gives distance to l-th level 72 | d = torch.pow(x - self.levels, 2) 73 | # NCmL, \sum_l d[..., l] sums to 1 74 | phi_soft = F.softmax(-self.sigma * d, dim=-1) 75 | # - Calcualte soft assignements --- 76 | # NCm, soft assign x to levels 77 | x_soft = torch.sum(self.levels * phi_soft, dim=-1) 78 | # NCHW 79 | x_soft = x_soft.view(N, C, H, W) 80 | 81 | # - Calcualte hard assignements --- 82 | # NCm, symbols_hard[..., i] contains index of symbol to use 83 | _, symbols_hard = torch.min(d.detach(), dim=-1) 84 | # NCHW 85 | symbols_hard = symbols_hard.view(N, C, H, W) 86 | # NCHW, contains value of symbol to use 87 | x_hard = self.levels[symbols_hard] 88 | 89 | x_soft.data = x_hard # assign data, keep gradient 90 | return x_soft, x_hard, symbols_hard 91 | -------------------------------------------------------------------------------- /src/prep_openimages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | if [[ -z $1 ]]; then 6 | echo "USAGE: $0 DATA_DIR [OUT_DIR]" 7 | exit 1 8 | fi 9 | 10 | DATA_DIR=$(realpath $1) 11 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 12 | 13 | if [[ -n $2 ]]; then 14 | OUT_DIR=$2 15 | else 16 | OUT_DIR=$DATA_DIR 17 | fi 18 | 19 | progress () { 20 | COUNTER=0 21 | while read LINE; do 22 | COUNTER=$((COUNTER+1)) 23 | if [[ $((COUNTER % 10)) == 0 ]]; then 24 | echo -ne "\rExtracting $LINE; Unpacked $COUNTER files." 25 | fi 26 | done 27 | echo "" 28 | } 29 | 30 | echo "DATA_DIR=$DATA_DIR;\nSCRIPT_DIR=$SCRIPT_DIR;\nSaving to $OUT_DIR" 31 | 32 | mkdir -pv $DATA_DIR 33 | 34 | TRAIN_0=train_0 35 | TRAIN_1=train_1 36 | TRAIN_2=train_2 37 | VAL=validation 38 | 39 | # Download ---------- 40 | DOWNLOAD_DIR=$DATA_DIR/download 41 | mkdir -p $DOWNLOAD_DIR 42 | pushd $DOWNLOAD_DIR 43 | for DIR in $TRAIN_0 $TRAIN_1 $TRAIN_2 $VAL; do 44 | TAR=${DIR}.tar.gz 45 | if [ ! -f "$TAR" ]; then 46 | echo "Downloading $TAR..." 47 | aws s3 --no-sign-request cp s3://open-images-dataset/tar/$TAR $TAR 48 | else 49 | echo "Found $TAR..." 50 | fi 51 | done 52 | 53 | for DIR in $TRAIN_0 $TRAIN_1 $TRAIN_2 $VAL; do 54 | TAR=${DIR}.tar.gz 55 | if [ -d $DIR ]; then 56 | echo "Found $DIR, not unpacking $TAR..." 57 | continue 58 | fi 59 | if [ ! -f $TAR ]; then 60 | echo "ERROR: Expected $TAR in $DOWNLOAD_DIR" 61 | exit 1 62 | fi 63 | echo "Unpacking $TAR..." 64 | ( tar xvf $TAR | progress ) & 65 | done 66 | 67 | # Wait for all unpacking background processes 68 | wait 69 | echo "Unpacked all!" 70 | 71 | popd 72 | 73 | # Convert ---------- 74 | FINAL_TRAIN_DIR=$DATA_DIR/train_oi 75 | FINAL_VAL_DIR=$DATA_DIR/val_oi 76 | 77 | DISCARD=$OUT_DIR/discard 78 | DISCARD_VAL=$OUT_DIR/discard_val 79 | pushd $SCRIPT_DIR 80 | 81 | echo "Importing train..." 82 | # NOTE: this is were you want to employ parallelization on a cluster if it's available. 83 | # See import_train_images.py 84 | python import_train_images.py $DOWNLOAD_DIR $TRAIN_0 $TRAIN_1 $TRAIN_2 \ 85 | --out_dir_clean=$FINAL_TRAIN_DIR \ 86 | --out_dir_discard=$DISCARD \ 87 | --resolution=512 88 | 89 | python import_train_images.py $DOWNLOAD_DIR $VAL \ 90 | --out_dir_clean=$FINAL_VAL_DIR \ 91 | --out_dir_discard=$DISCARD_VAL \ 92 | --resolution=512 93 | 94 | # Update Cache ---------- 95 | CACHE_P=$DATA_DIR/cache.pkl 96 | export PYTHONPATH=$(pwd) 97 | 98 | echo "Updating cache $CACHE_P..." 99 | python dataloaders/images_loader.py update $FINAL_TRAIN_DIR "$CACHE_P" --min_size 128 100 | python dataloaders/images_loader.py update $FINAL_VAL_DIR "$CACHE_P" --min_size 128 101 | 102 | echo "----------------------------------------" 103 | echo "Done" 104 | echo "To train, you MUST UPDATE configs/dl/oi.cf:" 105 | echo "" 106 | echo " image_cache_pkl = '$1/cache.pkl'" 107 | echo " train_imgs_glob = '$(realpath $1/train_oi)'" 108 | echo " val_glob = '$(realpath $1/val_oi)'" 109 | echo "" 110 | echo "----------------------------------------" 111 | -------------------------------------------------------------------------------- /src/pytorch_ext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | General PyTorch related stuff 22 | 23 | """ 24 | 25 | import os 26 | from collections import defaultdict 27 | 28 | import numpy as np 29 | import math 30 | 31 | import torch 32 | 33 | from torch import nn as nn 34 | from torch.utils.data import Dataset 35 | 36 | import itertools 37 | 38 | 39 | CUDA_AVAILABLE = torch.cuda.is_available() 40 | 41 | # IGNORE_CUDA = os.environ.get('IGNORE_CUDA', '0') == '1' 42 | # if IGNORE_CUDA: 43 | # print('*** IGNORE_CUDA=1') 44 | 45 | DEVICE = torch.device("cuda:0" if CUDA_AVAILABLE else "cpu") 46 | 47 | 48 | # This has an effect to all parts reading pytorch_ext.DEVICE *after* it has been set. 49 | def set_device(cuda_available): 50 | global DEVICE 51 | DEVICE = torch.device("cuda:0" if cuda_available else "cpu") 52 | 53 | 54 | # Conv ------------------------------------------------------------------------- 55 | 56 | 57 | def default_conv(in_channels, out_channels, kernel_size, bias=True, rate=1, stride=1): 58 | padding = kernel_size // 2 if rate == 1 else rate 59 | return nn.Conv2d( 60 | in_channels, out_channels, kernel_size, stride=stride, dilation=rate, 61 | padding=padding, bias=bias) 62 | 63 | 64 | def initialize_with_filter(conv_or_deconv, f): 65 | assert conv_or_deconv.weight.size() == f.size(), 'Must match: {}, {}'.format( 66 | conv_or_deconv.weight.size(), f.size()) 67 | conv_or_deconv.weight.data = f 68 | return conv_or_deconv 69 | 70 | 71 | def initialize_with_id(conv_or_deconv, with_noise=False): 72 | n, Cout, H, W = conv_or_deconv.weight.shape 73 | assert n == Cout and H == W == 1, 'Invalid shape: {}'.format(conv_or_deconv.weight.shape) 74 | eye = torch.eye(n).reshape(n, n, 1, 1) 75 | if with_noise: # nicked from torch.nn.modules.conv:reset_parameters 76 | stdv = 1. / math.sqrt(n) 77 | noise = torch.empty_like(eye) 78 | noise.uniform_(-stdv, stdv) 79 | eye.add_(noise) 80 | 81 | print(eye[:10, :10, 0, 0]) 82 | initialize_with_filter(conv_or_deconv, eye) 83 | 84 | 85 | # Numpy ----------------------------------------------------------------------- 86 | 87 | 88 | def tensor_to_np(t): 89 | return t.detach().cpu().numpy() 90 | 91 | 92 | def histogram(t, L): 93 | """ 94 | A: If t is a list of tensors/np.ndarrays, B is executed for all, yielding len(ts) histograms, which are summed 95 | per bin 96 | B: convert t to numpy, count bins. 97 | :param t: tensor or list of tensor, each expected to be in [0, L) 98 | :param L: number of symbols 99 | :return: length-L array, containing at l the number of values mapping to to symbol l 100 | """ 101 | if isinstance(t, list): 102 | ts = t 103 | histograms = np.stack((histogram(t, L) for t in ts), axis=0) # get array (len(ts) x L) 104 | return np.sum(histograms, 0) 105 | assert 0 <= t.min() and t.max() < L, (t.min(), t.max()) 106 | a = tensor_to_np(t) 107 | counts, _ = np.histogram(a, np.arange(L+1)) # +1 because np.histogram takes bin edges, including rightmost edge 108 | return counts 109 | 110 | 111 | # Gradients -------------------------------------------------------------------- 112 | 113 | 114 | def get_total_grad_norm(params, norm_type=2): 115 | # nicked from torch.nn.utils.clip_grad_norm 116 | with torch.no_grad(): 117 | total_norm = 0 118 | for p in params: 119 | if p.grad is None: 120 | continue 121 | param_norm = p.grad.data.norm(norm_type) 122 | total_norm += param_norm.item() ** norm_type 123 | total_norm = total_norm ** (1. / norm_type) 124 | return total_norm 125 | 126 | 127 | def get_average_grad_norm(params, norm_type=2): 128 | """ 129 | :param params: Assumed to be generator 130 | :param norm_type: 131 | """ 132 | # nicked from torch.nn.utils.clip_grad_norm 133 | with torch.no_grad(): 134 | average_norm = 0 135 | num_params = 0 136 | for p in params: 137 | if p.grad is None: 138 | continue 139 | average_norm += p.grad.data.norm(norm_type) 140 | num_params += 1 141 | if num_params == 0: 142 | return 0 143 | return average_norm / float(num_params) 144 | 145 | 146 | # Datasets -------------------------------------------------------------------- 147 | 148 | 149 | class TruncatedDataset(Dataset): 150 | def __init__(self, dataset, num_elemens): 151 | assert len(dataset) >= num_elemens, 'Cannot truncate to {}: dataset has {} elements'.format( 152 | num_elemens, len(dataset)) 153 | self.dataset = dataset 154 | self.num_elemens = num_elemens 155 | 156 | def __len__(self): 157 | return self.num_elemens 158 | 159 | def __getitem__(self, item): 160 | return self.dataset[item] 161 | 162 | 163 | # Helpful modules -------------------------------------------------------------- 164 | 165 | 166 | class LambdaModule(nn.Module): 167 | def __init__(self, forward_lambda, name=''): 168 | super(LambdaModule, self).__init__() 169 | self.forward_lambda = forward_lambda 170 | self.description = 'LambdaModule({})'.format(name) 171 | 172 | def __repr__(self): 173 | return self.description 174 | 175 | def forward(self, x): 176 | return self.forward_lambda(x) 177 | 178 | 179 | class ChannelToLogitsTranspose(nn.Module): 180 | def __init__(self, Cout, Lout): 181 | super(ChannelToLogitsTranspose, self).__init__() 182 | self.Cout = Cout 183 | self.Lout = Lout 184 | 185 | def forward(self, x): 186 | N, C, H, W = x.shape 187 | # unfold channel dimension to (Cout, Lout) 188 | # this fills up the Cout dimension first! 189 | x = x.view(N, self.Lout, self.Cout, H, W) 190 | return x 191 | 192 | def __repr__(self): 193 | return 'ChannelToLogitsTranspose(Cout={}, Lout={})'.format(self.Cout, self.Lout) 194 | 195 | 196 | class LogitsToChannelTranspose(nn.Module): 197 | def __init__(self): 198 | super(LogitsToChannelTranspose, self).__init__() 199 | 200 | def forward(self, x): 201 | N, L, C, H, W = x.shape 202 | # fold channel, L dimension back to channel dim 203 | x = x.view(N, C * L, H, W) 204 | return x 205 | 206 | def __repr__(self): 207 | return 'LogitsToChannelTranspose()' 208 | 209 | 210 | def channel_to_logits(x, Cout, Lout): 211 | N, C, H, W = x.shape 212 | # unfold channel dimension to (Cout, Lout) 213 | # this fills up the Cout dimension first! 214 | x = x.view(N, Lout, Cout, H, W) 215 | return x 216 | 217 | 218 | def logits_to_channel(x): 219 | N, L, C, H, W = x.shape 220 | # fold channel, L dimension back to channel dim 221 | x = x.view(N, C * L, H, W) 222 | return x 223 | 224 | 225 | class OneHot(nn.Module): 226 | """ 227 | Take long tensor x of some shape (N,d1,d2,...,dN) containing integers in [0, L), 228 | produces one hot encoding `out` of out_shape (N, d1, ..., L, ..., dN), where out_shape[Ldim] = L, containing 229 | out[n, i, ..., l, ..., j] == {1 if x[n, i, ..., j] == l 230 | 0 otherwise 231 | """ 232 | def __init__(self, L, Ldim=1): 233 | super(OneHot, self).__init__() 234 | self.L = L 235 | self.Ldim = Ldim 236 | 237 | def forward(self, x): 238 | return one_hot(x, self.L, self.Ldim) 239 | 240 | 241 | def one_hot(x, L, Ldim): 242 | """ add dim L at Ldim """ 243 | assert Ldim >= 0 or Ldim == -1, f'Only supporting Ldim >= 0 or Ldim == -1: {Ldim}' 244 | out_shape = list(x.shape) 245 | if Ldim == -1: 246 | out_shape.append(L) 247 | else: 248 | out_shape.insert(Ldim, L) 249 | x = x.unsqueeze(Ldim) # x must match # dims of outshape 250 | assert x.dim() == len(out_shape), (x.shape, out_shape) 251 | oh = torch.zeros(*out_shape, dtype=torch.float32, device=x.device) 252 | oh.scatter_(Ldim, x, 1) 253 | return oh 254 | 255 | 256 | # ------------------------------------------------------------------------------ 257 | 258 | 259 | def assert_equal(t1, t2, show_num_wrong=3, names=None, msg=''): 260 | if t1.shape != t2.shape: 261 | raise AssertionError('Different shapes! {} != {}'.format(t1.shape, t2.shape)) 262 | wrong = t1 != t2 263 | if not wrong.any(): 264 | return 265 | if names is None: 266 | names = ('t1', 't2') 267 | wrong_idxs = wrong.nonzero() 268 | num_wrong = len(wrong_idxs) 269 | show_num_wrong = min(show_num_wrong, num_wrong) 270 | wrong_idxs = itertools.islice((tuple(i.tolist()) for i in wrong_idxs), 271 | show_num_wrong) 272 | err_msg = ' // '.join('{}: {}!={}'.format(idx, t1[idx], t2[idx]) 273 | for idx in wrong_idxs) 274 | raise AssertionError(('{} != {}: {}, and {}/{} other(s) '.format( 275 | names[0], names[1], err_msg, num_wrong - show_num_wrong, np.prod(t1.shape)) + msg).strip()) 276 | 277 | 278 | class BatchSummarizer(object): 279 | """ 280 | Summarize values from multiple batches 281 | """ 282 | def __init__(self, writer, global_step): 283 | self.writer = writer 284 | self.global_step = global_step 285 | self._values = defaultdict(list) 286 | 287 | def append(self, key, values): 288 | self._values[key].append(values) 289 | 290 | def output_summaries(self, strip_for_str='val'): 291 | strs = [] 292 | for key, values in self._values.items(): 293 | avg = sum(values) / len(values) 294 | self.writer.add_scalar(key, avg, self.global_step) 295 | key = key.replace(strip_for_str, '').strip('/') 296 | strs.append('{}={:.3f}'.format(key, avg)) 297 | self.writer.file_writer.flush() 298 | return strs 299 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | test.py log_dir log_dates images 22 | 23 | test.py log_dir log_dates images --sample=samples 24 | 25 | test.py log_dir log_dates images --write_to_files=l3c_out 26 | 27 | This code uses a cache: If some experiment has already been tested for some iteration and crop and dataset, 28 | we just print that (see TestID in multiscale_tester.py). 29 | 30 | """ 31 | 32 | import torch.backends.cudnn 33 | torch.backends.cudnn.benchmark = True 34 | 35 | import argparse 36 | from operator import itemgetter 37 | 38 | from helpers.aligned_printer import AlignedPrinter 39 | from helpers.testset import Testset 40 | from test.multiscale_tester import MultiscaleTester 41 | 42 | 43 | 44 | def main(): 45 | p = argparse.ArgumentParser() 46 | 47 | p.add_argument('log_dir', help='Directory of experiments. Will create a new folder, LOG_DIR_test, to save test ' 48 | 'outputs.') 49 | p.add_argument('log_dates', help='A comma-separated list, where each entry is a log_date, such as 0104_1345. ' 50 | 'These experiments will be tested.') 51 | p.add_argument('images', help='A comma-separated list, where each entry is either a directory with images or ' 52 | 'the path of a single image. Will test on all these images.') 53 | p.add_argument('--match_filenames', '-fns', nargs='+', metavar='FILTER', 54 | help='If given, remove any images in the folders given by IMAGES that do not match any ' 55 | 'of specified filter.') 56 | p.add_argument('--max_imgs_per_folder', '-m', type=int, metavar='MAX', 57 | help='If given, only use MAX images per folder given in IMAGES. Default: None') 58 | p.add_argument('--crop', type=int, help='Crop all images to CROP x CROP squares. Default: None') 59 | 60 | p.add_argument('--names', '-n', type=str, 61 | help='Comma separated list, if given, must be as long as LOG_DATES. Used for output. If not given, ' 62 | 'will just print LOG_DATES as names.') 63 | 64 | p.add_argument('--overwrite_cache', '-f', action='store_true', 65 | help='Ignore cached test outputs, and re-create.') 66 | p.add_argument('--reset_entire_cache', action='store_true', 67 | help='Remove cache.') 68 | 69 | p.add_argument('--restore_itr', '-i', default='-1', 70 | help='Which iteration to restore. -1 means latest iteration. Will use closest smaller if exact ' 71 | 'iteration is not found. Default: -1') 72 | 73 | p.add_argument('--recursive', default='0', 74 | help='Either an number or "auto". If given, the rgb configs with num_scales == 1 will ' 75 | 'automatically be evaluated recursively (i.e., the RGB baseline). See _parse_recursive_flag ' 76 | 'in multiscale_tester.py. Default: 0') 77 | 78 | p.add_argument('--sample', type=str, metavar='SAMPLE_OUT_DIR', 79 | help='Sample from model. Store results in SAMPLE_OUT_DIR.') 80 | 81 | p.add_argument('--write_to_files', type=str, metavar='WRITE_OUT_DIR', 82 | help='Write images to files in folder WRITE_OUT_DIR, with arithmetic coder. If given, the cache is ' 83 | 'ignored and no test output is printed. Requires torchac to be installed, see README. Files ' 84 | 'that already exist in WRITE_OUT_DIR are overwritten.') 85 | p.add_argument('--compare_theory', action='store_true', 86 | help='If given with --write_to_files, will compare actual bitrate on disk to theoretical bitrate ' 87 | 'given by cross entropy.') 88 | p.add_argument('--time_report', type=str, metavar='TIME_REPORT_PATH', 89 | help='If given with --write_to_files, write a report of time needed for each component to ' 90 | 'TIME_REPORT_PATH.') 91 | 92 | p.add_argument('--sort_output', '-s', choices=['testset', 'exp', 'itr', 'res'], default='testset', 93 | help='How to sort the final summary. Possible values: "testset" to sort by ' 94 | 'name of the testset // "exp" to sort by experiment log_date // "itr" to sort by iteration // ' 95 | '"res" to sort by result, i.e., show smaller first. Default: testset') 96 | 97 | flags = p.parse_args() 98 | 99 | if flags.compare_theory and not flags.write_to_files: 100 | raise ValueError('Cannot have --compare_theory without --write_to_files.') 101 | if flags.write_to_files and flags.sample: 102 | raise ValueError('Cannot have --write_to_files and --sample.') 103 | if flags.time_report and not flags.write_to_files: 104 | raise ValueError('--time_report only valid with --write_to_files.') 105 | 106 | testsets = [Testset(images_dir_or_image.rstrip('/'), flags.max_imgs_per_folder, 107 | # Append flags.crop to ID so that it creates unique entry in cache 108 | append_id=f'_crop{flags.crop}' if flags.crop else None) 109 | for images_dir_or_image in flags.images.split(',')] 110 | if flags.match_filenames: 111 | for ts in testsets: 112 | ts.filter_filenames(flags.match_filenames) 113 | 114 | splitter = ',' if ',' in flags.log_dates else '|' # support tensorboard strings, too 115 | results = [] 116 | log_dates = flags.log_dates.split(splitter) 117 | for log_date in log_dates: 118 | for restore_itr in map(int, flags.restore_itr.split(',')): 119 | print('Testing {} at {} ---'.format(log_date, restore_itr)) 120 | tester = MultiscaleTester(log_date, flags, restore_itr) 121 | results += tester.test_all(testsets) 122 | 123 | # if --names was passed: will print 'name (log_date)'. otherwise, will just print 'log_date' 124 | if flags.names: 125 | names = flags.names.split(splitter) if flags.names else log_dates 126 | names_to_log_date = {log_date: f'{name} ({log_date})' 127 | for log_date, name in zip(log_dates, names)} 128 | else: 129 | # set names to log_dates if --names is not given, i.e., we just output log_date 130 | names_to_log_date = {log_date: log_date for log_date in log_dates} 131 | if not flags.write_to_files: 132 | print('*** Summary:') 133 | with AlignedPrinter() as a: 134 | sortby = {'testset': 0, 'exp': 1, 'itr': 2, 'res': 3}[flags.sort_output] 135 | a.append('Testset', 'Experiment', 'Itr', 'Result') 136 | for testset, log_date, restore_itr, result in sorted(results, key=itemgetter(sortby)): 137 | a.append(testset.id, names_to_log_date[log_date], str(restore_itr), result) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /src/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/test/__init__.py -------------------------------------------------------------------------------- /src/test/cuda_timer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | from contextlib import contextmanager 20 | import numpy as np 21 | import torch 22 | import os 23 | import time 24 | import pytorch_ext as pe 25 | from collections import defaultdict 26 | from collections import namedtuple 27 | 28 | 29 | _NO_CUDA_SYNC_OVERWRITE = int(os.environ.get('NO_CUDA_SYNC', 0)) == 1 30 | 31 | 32 | if _NO_CUDA_SYNC_OVERWRITE or not pe.CUDA_AVAILABLE: 33 | sync = lambda: None 34 | else: 35 | sync = torch.cuda.synchronize 36 | 37 | 38 | 39 | class StackLogger(object): 40 | Entry = namedtuple('Entry', ['fmt_str', 'logs']) 41 | CombineCtx = namedtuple('CombineCtx', ['prefixes_of_created_entries', 'fmt_str']) 42 | 43 | def __init__(self, default_fmt_str='{}'): 44 | self.default_fmt_str = default_fmt_str 45 | 46 | self.logs = defaultdict(list) 47 | self._order = [] 48 | self._prefixes = [] 49 | 50 | self._combine_ctx = None 51 | self._global_skip = False 52 | 53 | @contextmanager 54 | def skip(self, flag): 55 | self._global_skip = flag 56 | yield 57 | self._global_skip = False 58 | 59 | @contextmanager 60 | def prefix_scope(self, p): 61 | self._prefixes.append(p) 62 | yield 63 | del self._prefixes[-1] 64 | 65 | @contextmanager 66 | def combine(self, fmt_str): 67 | if self._combine_ctx is not None: 68 | raise ValueError('Already in combine!') 69 | self._combine_ctx = set(), fmt_str 70 | yield 71 | self._combine_ctx = None 72 | 73 | @contextmanager 74 | def prefix_scope_combine(self, prefix, fmt_str): 75 | with self.prefix_scope(prefix): 76 | with self.combine(fmt_str): 77 | yield 78 | 79 | def log(self, name, msg): 80 | if self._global_skip: 81 | return 82 | prefix = ' '.join(self._prefixes + [name]) 83 | if prefix not in self.logs: 84 | self._order.append(prefix) 85 | logs = self._get_log_entry_list(prefix) 86 | logs.append(msg) 87 | 88 | def _get_log_entry_list(self, prefix): 89 | if self._combine_ctx is None: 90 | # Always create a new entry 91 | return self.logs[prefix] 92 | 93 | prefixes_of_created_entries, fmt_str = self._combine_ctx 94 | if prefix not in prefixes_of_created_entries: 95 | prefixes_of_created_entries.add(prefix) 96 | return self._append_entry(prefix, fmt_str).logs 97 | 98 | assert len(self.logs[prefix]) > 0 99 | return self.logs[prefix][-1].logs 100 | 101 | def _append_entry(self, prefix, fmt_str): 102 | log_entry = StackLogger.Entry(fmt_str, logs=[]) 103 | self.logs[prefix].append(log_entry) 104 | return log_entry 105 | 106 | 107 | class StackTimeLogger(StackLogger): 108 | def __init__(self, default_fmt_str='{:.5f}'): 109 | super(StackTimeLogger, self).__init__(default_fmt_str) 110 | 111 | 112 | def get_mean_strs(self): 113 | for prefix in self._order: 114 | entries = self.logs[prefix] 115 | first_entry = entries[0] 116 | if isinstance(first_entry, StackLogger.Entry): 117 | num_values_per_entry = len(first_entry.logs) 118 | means = np.zeros(num_values_per_entry, dtype=np.float) 119 | for e in entries: 120 | means += np.array(e.logs) 121 | means = means / len(entries) 122 | yield self._to_str(prefix, first_entry.fmt_str, means) 123 | else: # entries is just a list 124 | mean = np.mean(entries) 125 | yield self._to_str(prefix, self.default_fmt_str, mean) 126 | 127 | def get_last_strs(self): 128 | for prefix in self._order: 129 | entries = self.logs[prefix] 130 | last_entry = entries[-1] 131 | if isinstance(last_entry, StackLogger.Entry): 132 | yield self._to_str(prefix, last_entry.fmt_str, last_entry.logs) 133 | else: 134 | yield self._to_str(prefix, self.default_fmt_str, last_entry) 135 | 136 | @contextmanager 137 | def run(self, name): 138 | sync() 139 | start = time.time() 140 | yield 141 | sync() 142 | duration = time.time() - start 143 | self.log(name, duration) 144 | 145 | @staticmethod 146 | def _to_str(prefix, fmt_str, values): 147 | should_iter = isinstance(values, (list, np.ndarray)) 148 | values_str = (fmt_str.format(values) if not should_iter 149 | else '/'.join(fmt_str.format(i, v) 150 | for i, v in enumerate(values))) 151 | return prefix + ': ' + values_str 152 | 153 | 154 | def test_stack_time_logger(): 155 | t = StackTimeLogger() 156 | for i in [1, 2]: 157 | with t.prefix_scope('foo'): 158 | with t.prefix_scope('bar'): 159 | with t.run('setup'): 160 | time.sleep(0.1) 161 | with t.combine('c{}: {:.5f}'): 162 | for c in range(5): 163 | with t.run('run'): 164 | time.sleep(c * 0.01) 165 | from pprint import pprint 166 | pprint(t.logs) 167 | pprint(t._order) 168 | print('\n'.join(t.get_mean_strs())) 169 | print('...') 170 | print('\n'.join(t.get_last_strs())) 171 | 172 | -------------------------------------------------------------------------------- /src/test/image_saver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import glob 20 | import os 21 | 22 | import torch 23 | from PIL import Image 24 | 25 | from vis.image_summaries import to_image 26 | 27 | 28 | class ImageSaver(object): 29 | def __init__(self, out_dir): 30 | self.out_dir = out_dir 31 | os.makedirs(self.out_dir, exist_ok=True) 32 | self.saved_fs = [] 33 | 34 | def __str__(self): 35 | return 'ImageSaver({})'.format(self.out_dir) 36 | 37 | def save_img(self, img, filename, convert_to_image=True): 38 | """ 39 | :param img: image tensor, in {0, ..., 255} 40 | :param filename: output filename 41 | :param convert_to_image: if True, call to_image on img, otherwise assume this has already been done. 42 | :return: 43 | """ 44 | if convert_to_image: 45 | img = to_image(img.type(torch.uint8)) 46 | out_p = self.get_save_p(filename) 47 | Image.fromarray(img).save(out_p) 48 | return out_p 49 | 50 | def get_save_p(self, file_name): 51 | out_p = os.path.join(self.out_dir, file_name) 52 | self.saved_fs.append(file_name) 53 | return out_p 54 | 55 | def file_starting_with_exists(self, prefix): 56 | check_p = os.path.join(self.out_dir, prefix) + '*' 57 | print(check_p, glob.glob(check_p)) 58 | return len(glob.glob(check_p)) > 0 -------------------------------------------------------------------------------- /src/torchac/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | NOTE: Needs PyTorch 1.0 or newer, as the C++ code relies on that API! 22 | 23 | Depending on the environment variable COMPILE_CUDA, compiles the torchac_backend with or 24 | without support for CUDA, into a module called torchac_backend_gpu or torchac_backend_cpu. 25 | 26 | COMPILE_CUDA = auto is equal to yes if one of the supported combinations of nvcc and gcc is found (see 27 | _supported_compilers_available). 28 | COMPILE_CUDA = force means compile with CUDA, even if it is not one of the supported combinations 29 | COMPILE_CUDA = no means no CUDA. 30 | 31 | The only difference between the CUDA and non-CUDA versions is: With CUDA, _get_uint16_cdf from torchac is done with a 32 | simple/non-optimized CUDA kernel (torchac_kernel.cu), which has one benefit: we can directly write into shared memory! 33 | This saves an expensive copying step from GPU to CPU. 34 | 35 | Flags read by this script: 36 | COMPILE_CUDA=[auto|force|no] 37 | 38 | """ 39 | 40 | import sys 41 | import re 42 | import subprocess 43 | from setuptools import setup 44 | from distutils.version import LooseVersion 45 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension 46 | import os 47 | 48 | 49 | MODULE_BASE_NAME = 'torchac_backend' 50 | 51 | 52 | def prefixed(prefix, l): 53 | ps = [os.path.join(prefix, el) for el in l] 54 | for p in ps: 55 | if not os.path.isfile(p): 56 | raise FileNotFoundError(p) 57 | return ps 58 | 59 | 60 | def compile_ext(cuda_support): 61 | print('Compiling, cuda_support={}'.format(cuda_support)) 62 | ext_module = get_extension(cuda_support) 63 | 64 | setup(name=ext_module.name, 65 | version='1.0.0', 66 | ext_modules=[ext_module], 67 | extra_compile_args=['-mmacosx-version-min=10.9'], 68 | cmdclass={'build_ext': BuildExtension}) 69 | 70 | 71 | def get_extension(cuda_support): 72 | # dir of this file 73 | setup_dir = os.path.dirname(os.path.realpath(__file__)) 74 | # Where the cpp and cu files are 75 | prefix = os.path.join(setup_dir, MODULE_BASE_NAME) 76 | if not os.path.isdir(prefix): 77 | raise ValueError('Did not find backend foler: {}'.format(prefix)) 78 | if cuda_support: 79 | nvcc_avaible, nvcc_version = supported_nvcc_available() 80 | if not nvcc_avaible: 81 | print(_bold_warn_str('***WARN') + ': Found untested nvcc {}'.format(nvcc_version)) 82 | 83 | return CUDAExtension( 84 | MODULE_BASE_NAME + '_gpu', 85 | prefixed(prefix, ['torchac.cpp', 'torchac_kernel.cu']), 86 | define_macros=[('COMPILE_CUDA', '1')]) 87 | else: 88 | return CppExtension( 89 | MODULE_BASE_NAME + '_cpu', 90 | prefixed(prefix, ['torchac.cpp'])) 91 | 92 | 93 | # TODO: 94 | # Add further supported version as specified in readme 95 | 96 | 97 | def _supported_compilers_available(): 98 | """ 99 | To see an up-to-date list of tested combinations of GCC and NVCC, see the README 100 | """ 101 | return _supported_gcc_available()[0] and supported_nvcc_available()[0] 102 | 103 | 104 | def _supported_gcc_available(): 105 | v = _get_version(['gcc', '-v'], r'version (.*?)\s+') 106 | return LooseVersion('6.0') > LooseVersion(v) >= LooseVersion('5.0'), v 107 | 108 | 109 | def supported_nvcc_available(): 110 | v = _get_version(['nvcc', '-V'], 'release (.*?),') 111 | if v is None: 112 | return False, 'nvcc unavailable!' 113 | return LooseVersion(v) >= LooseVersion('9.0'), v 114 | 115 | 116 | def _get_version(cmd, regex): 117 | try: 118 | otp = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode() 119 | if len(otp.strip()) == 0: 120 | raise ValueError('No output') 121 | m = re.search(regex, otp) 122 | if not m: 123 | raise ValueError('Regex does not match output:\n{}'.format(otp)) 124 | return m.group(1) 125 | except FileNotFoundError: 126 | return None 127 | 128 | 129 | def _bold_warn_str(s): 130 | return '\x1b[91m\x1b[1m' + s + '\x1b[0m' 131 | 132 | 133 | def _assert_torch_version_sufficient(): 134 | import torch 135 | if LooseVersion(torch.__version__) >= LooseVersion('1.0'): 136 | return 137 | print(_bold_warn_str('Error:'), 'Need PyTorch version >= 1.0, found {}'.format(torch.__version__)) 138 | sys.exit(1) 139 | 140 | 141 | def main(): 142 | _assert_torch_version_sufficient() 143 | 144 | cuda_flag = os.environ.get('COMPILE_CUDA', 'no') 145 | 146 | if cuda_flag == 'auto': 147 | cuda_support = _supported_compilers_available() 148 | print('Found CUDA supported:', cuda_support) 149 | elif cuda_flag == 'force': 150 | cuda_support = True 151 | elif cuda_flag == 'no': 152 | cuda_support = False 153 | else: 154 | raise ValueError('COMPILE_CUDA must be in (auto, force, no), got {}'.format(cuda_flag)) 155 | 156 | compile_ext(cuda_support) 157 | 158 | 159 | if __name__ == '__main__': 160 | main() 161 | -------------------------------------------------------------------------------- /src/torchac/torchac.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | 20 | # TODO some comments needed about [..., -1] == 0 21 | 22 | import torch 23 | 24 | 25 | # torchac can be built with and without CUDA support. 26 | # Here, we try to import both torchac_backend_gpu and torchac_backend_cpu. 27 | # If both fail, an exception is thrown here already. 28 | # 29 | # The right version is then picked in the functions below. 30 | # 31 | # NOTE: 32 | # Without a clean build, multiple versions might be installed. You may use python seutp.py clean --all to prevent this. 33 | # But it should not be an issue. 34 | 35 | 36 | import_errors = [] 37 | 38 | 39 | try: 40 | import torchac_backend_gpu 41 | CUDA_SUPPORTED = True 42 | except ImportError as e: 43 | CUDA_SUPPORTED = False 44 | import_errors.append(e) 45 | 46 | try: 47 | import torchac_backend_cpu 48 | CPU_SUPPORTED = True 49 | except ImportError as e: 50 | CPU_SUPPORTED = False 51 | import_errors.append(e) 52 | 53 | 54 | imported_at_least_one = CUDA_SUPPORTED or CPU_SUPPORTED 55 | 56 | 57 | # if import_errors: 58 | # import_errors_str = '\n'.join(map(str, import_errors)) 59 | # print(f'*** Import errors:\n{import_errors_str}') 60 | 61 | 62 | if not imported_at_least_one: 63 | raise ImportError('*** Failed to import any torchac_backend! Make sure to install torchac with torchac/setup.py. ' 64 | 'See the README for details.') 65 | 66 | 67 | any_backend = torchac_backend_cpu if CPU_SUPPORTED else torchac_backend_gpu 68 | 69 | 70 | # print(f'*** torchac: GPU support: {CUDA_SUPPORTED} // CPU support: {CPU_SUPPORTED}') 71 | 72 | 73 | def _get_gpu_backend(): 74 | if not CUDA_SUPPORTED: 75 | raise ValueError('Got CUDA tensor, but torchac_backend_gpu is not available. ' 76 | 'Compile torchac with CUDA support, or use CPU mode (see README).') 77 | return torchac_backend_gpu 78 | 79 | 80 | def _get_cpu_backend(): 81 | if not CPU_SUPPORTED: 82 | raise ValueError('Got CPU tensor, but torchac_backend_cpu is not available. ' 83 | 'Compile torchac without CUDA support, or use GPU mode (see README).') 84 | return torchac_backend_cpu 85 | 86 | 87 | def encode_cdf(cdf, sym): 88 | """ 89 | :param cdf: CDF as 1HWLp, as int16, on CPU! 90 | :param sym: the symbols to encode, as int16, on CPU 91 | :return: byte-string, encoding `sym` 92 | """ 93 | if cdf.is_cuda or sym.is_cuda: 94 | raise ValueError('CDF and symbols must be on CPU for `encode_cdf`') 95 | # encode_cdf is defined in both backends, so doesn't matter which one we use! 96 | return any_backend.encode_cdf(cdf, sym) 97 | 98 | 99 | def decode_cdf(cdf, input_string): 100 | """ 101 | :param cdf: CDF as 1HWLp, as int16, on CPU 102 | :param input_string: byte-string, encoding some symbols `sym`. 103 | :return: decoded `sym`. 104 | """ 105 | if cdf.is_cuda: 106 | raise ValueError('CDF must be on CPU for `decode_cdf`') 107 | # encode_cdf is defined in both backends, so doesn't matter which one we use! 108 | return any_backend.decode_cdf(cdf, input_string) 109 | 110 | 111 | def encode_logistic_mixture( 112 | targets, means, log_scales, logit_probs_softmax, # CDF 113 | sym): 114 | """ 115 | NOTE: This function uses either the CUDA or CPU backend, depending on the device of the input tensors. 116 | NOTE: targets, means, log_scales, logit_probs_softmax must all be on the same device (CPU or GPU) 117 | In the following, we use 118 | Lp: Lp = L+1, where L = number of symbols. 119 | K: number of mixtures 120 | :param targets: values of symbols, tensor of length Lp, float32 121 | :param means: means of mixtures, tensor of shape 1KHW, float32 122 | :param log_scales: log(scales) of mixtures, tensor of shape 1KHW, float32 123 | :param logit_probs_softmax: weights of the mixtures (PI), tensorf of shape 1KHW, float32 124 | :param sym: the symbols to encode. MUST be on CPU!! 125 | :return: byte-string, encoding `sym`. 126 | """ 127 | if not (targets.is_cuda == means.is_cuda == log_scales.is_cuda == logit_probs_softmax.is_cuda): 128 | raise ValueError('targets, means, log_scales, logit_probs_softmax must all be on the same device! Got ' 129 | f'{targets.device}, {means.device}, {log_scales.device}, {logit_probs_softmax.device}.') 130 | if sym.is_cuda: 131 | raise ValueError('sym must be on CPU!') 132 | 133 | if targets.is_cuda: 134 | return _get_gpu_backend().encode_logistic_mixture( 135 | targets, means, log_scales, logit_probs_softmax, sym) 136 | else: 137 | cdf = _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales) 138 | return encode_cdf(cdf, sym) 139 | 140 | 141 | def decode_logistic_mixture( 142 | targets, means, log_scales, logit_probs_softmax, # CDF 143 | input_string): 144 | """ 145 | NOTE: This function uses either the CUDA or CPU backend, depending on the device of the input tensors. 146 | NOTE: targets, means, log_scales, logit_probs_softmax must all be on the same device (CPU or GPU) 147 | In the following, we use 148 | Lp: Lp = L+1, where L = number of symbols. 149 | K: number of mixtures 150 | :param targets: values of symbols, tensor of length Lp, float32 151 | :param means: means of mixtures, tensor of shape 1KHW, float32 152 | :param log_scales: log(scales) of mixtures, tensor of shape 1KHW, float32 153 | :param logit_probs_softmax: weights of the mixtures (PI), tensorf of shape 1KHW, float32 154 | :param input_string: byte-string, encoding some symbols `sym`. 155 | :return: decoded `sym`. 156 | """ 157 | if not (targets.is_cuda == means.is_cuda == log_scales.is_cuda == logit_probs_softmax.is_cuda): 158 | raise ValueError('targets, means, log_scales, logit_probs_softmax must all be on the same device! Got ' 159 | f'{targets.device}, {means.device}, {log_scales.device}, {logit_probs_softmax.device}.') 160 | 161 | if targets.is_cuda: 162 | return _get_gpu_backend().decode_logistic_mixture( 163 | targets, means, log_scales, logit_probs_softmax, input_string) 164 | else: 165 | cdf = _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales) 166 | return decode_cdf(cdf, input_string) 167 | 168 | 169 | # ------------------------------------------------------------------------------ 170 | 171 | # The following code is invoced for when the CDF is not on GPU, and we cannot use torchac/torchac_kernel.cu 172 | # This basically replicates that kernel in pure PyTorch. 173 | 174 | def _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales): 175 | cdf_float = _get_C_cur_weighted(logit_probs_softmax, targets, means, log_scales) 176 | cdf = _renorm_cast_cdf_(cdf_float, precision=16) 177 | cdf = cdf.cpu() 178 | return cdf 179 | 180 | 181 | def _get_C_cur_weighted(logit_probs_softmax_c, targets, means_c, log_scales_c): 182 | C_cur = _get_C_cur(targets, means_c, log_scales_c) # NKHWL 183 | C_cur = C_cur.mul(logit_probs_softmax_c.unsqueeze(-1)).sum(1) # NHWL 184 | return C_cur 185 | 186 | 187 | def _get_C_cur(targets, means_c, log_scales_c): # NKHWL 188 | """ 189 | :param targets: Lp floats 190 | :param means_c: NKHW 191 | :param log_scales_c: NKHW 192 | :return: 193 | """ 194 | # NKHW1 195 | inv_stdv = torch.exp(-log_scales_c).unsqueeze(-1) 196 | # NKHWL' 197 | centered_targets = (targets - means_c.unsqueeze(-1)) 198 | # NKHWL' 199 | cdf = centered_targets.mul(inv_stdv).sigmoid() # sigma' * (x - mu) 200 | return cdf 201 | 202 | 203 | def _renorm_cast_cdf_(cdf, precision): 204 | Lp = cdf.shape[-1] 205 | finals = 1 # NHW1 206 | # RENORMALIZATION_FACTOR in cuda 207 | f = torch.tensor(2, dtype=torch.float32, device=cdf.device).pow_(precision) 208 | cdf = cdf.mul((f - (Lp - 1)) / finals) # TODO 209 | cdf = cdf.round() 210 | cdf = cdf.to(dtype=torch.int16, non_blocking=True) 211 | r = torch.arange(Lp, dtype=torch.int16, device=cdf.device) 212 | cdf.add_(r) 213 | return cdf 214 | -------------------------------------------------------------------------------- /src/torchac/torchac_backend/torchac_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * COPYRIGHT 2019 ETH Zurich 3 | */ 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | using cdf_t = uint16_t; 12 | const int PRECISION = 16; 13 | const int RENORMALIZATION_FACTOR = 2 << (PRECISION - 1); 14 | 15 | namespace { 16 | __device__ __forceinline__ float sigmoidf (float a) { 17 | return 1.0 / (1.0 + expf (-a)); 18 | } 19 | 20 | __device__ __forceinline__ cdf_t renorm(float cdf, const int Lp, const int l) { 21 | cdf *= (RENORMALIZATION_FACTOR - (Lp - 1)); 22 | cdf_t cdf_int = static_cast(lrintf(cdf) + l); 23 | return cdf_int; 24 | } 25 | 26 | __global__ void calculate_cdf_kernel( 27 | const int N, const int Lp, const int K, 28 | const float* __restrict__ targets, // Lp length vector 29 | const float* __restrict__ means, 30 | const float* __restrict__ log_scales, 31 | const float* __restrict__ logit_probs_softmax, 32 | cdf_t* __restrict__ cdf_mem /* out */) { 33 | /** 34 | * Expects to be launched on a N*Lp grid? TODO 35 | * 36 | * means, log_scales, logit_probs_softmax: 37 | * each is a 1KHW matrix reshaped to KN, where N = H*W 38 | * cdf_mem: 39 | * an array of length N * Lp, representing a NxLp matrix, where 40 | * cdf[n][l] = cdf_mem[n*Lp + l] 41 | * 42 | * Code: 43 | * for n, l in range(N) x range(Lp) 44 | * target = l 45 | * cdf_n_l = 0; 46 | * for k in range(K) 47 | * log_scale = log_scales[k][n] 48 | * mean = means[k][n] 49 | * logit_prob = logit_probs_softmax[k][n] 50 | * inv_stdv = exp(log_scale) 51 | * centered_target = target - mean 52 | * cdf_n_l += logit_prob * sigmoid(centered_target * inv_stdv) 53 | * cdf[n][l] = cdf_mem 54 | */ 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int stride = blockDim.x * gridDim.x; 57 | for (int i_1d = index; i_1d < N * Lp; i_1d += stride) { 58 | const int n = i_1d / Lp; 59 | const int l = i_1d % Lp; 60 | 61 | const float target = targets[l]; 62 | float cdf_n_l_float = 0; // initialize 63 | 64 | for (int k = 0; k < K; ++k) { 65 | const float log_scale = log_scales[k * N + n]; 66 | const float mean = means[k * N + n]; 67 | const float logit_prob = logit_probs_softmax[k * N + n]; 68 | const float inv_stdv = expf(-log_scale); 69 | const float centered_target = target - mean; 70 | cdf_n_l_float += logit_prob * sigmoidf(centered_target * inv_stdv); 71 | } 72 | 73 | const int cdf_n_l_idx = i_1d; 74 | cdf_mem[cdf_n_l_idx] = renorm(cdf_n_l_float, Lp, l); 75 | } 76 | } 77 | } 78 | 79 | 80 | cdf_t* malloc_cdf(const int N, const int Lp) { 81 | cdf_t* cdf_mem; 82 | cudaMallocManaged(&cdf_mem, N*Lp*sizeof(cdf_t)); 83 | return cdf_mem; 84 | } 85 | 86 | 87 | void free_cdf(cdf_t* cdf_mem) { 88 | cudaFree(cdf_mem); 89 | } 90 | 91 | 92 | template 93 | std::string to_string(const T& object) { 94 | std::ostringstream ss; 95 | ss << object; 96 | return ss.str(); 97 | } 98 | 99 | #define CHECK_1KHW(K, x) AT_CHECK(x.sizes().size() == 4 && x.sizes()[0] == 1 && x.sizes()[1] == K, \ 100 | "#x must be 4D, got %s", to_string(x.sizes())) 101 | 102 | #define CHECK_CONTIGUOUS_AND_CUDA(x) AT_CHECK(x.is_contiguous() && x.is_cuda(), \ 103 | "#x must be contiguous and on GPU, got %d and %d", x.is_contiguous(), x.is_cuda()) 104 | 105 | void calculate_cdf( 106 | const at::Tensor& targets, 107 | const at::Tensor& means, 108 | const at::Tensor& log_scales, 109 | const at::Tensor& logit_probs_softmax, 110 | cdf_t * cdf_mem, 111 | const int K, const int Lp, const int N_cdf) { 112 | 113 | CHECK_1KHW(K, means); 114 | CHECK_1KHW(K, log_scales); 115 | CHECK_1KHW(K, logit_probs_softmax); 116 | 117 | CHECK_CONTIGUOUS_AND_CUDA(targets); 118 | CHECK_CONTIGUOUS_AND_CUDA(means); 119 | CHECK_CONTIGUOUS_AND_CUDA(log_scales); 120 | CHECK_CONTIGUOUS_AND_CUDA(logit_probs_softmax); 121 | 122 | AT_CHECK(means.sizes() == log_scales.sizes() && 123 | log_scales.sizes() == logit_probs_softmax.sizes()) 124 | 125 | const auto param_sizes = means.sizes(); 126 | const auto N = param_sizes[2] * param_sizes[3]; // H * W 127 | AT_CHECK(N == N_cdf, "%d != %d", N, N_cdf); 128 | 129 | const int blockSize = 1024; 130 | const int numBlocks = (N * Lp + blockSize - 1) / blockSize; 131 | 132 | calculate_cdf_kernel<<>>( 133 | N, Lp, K, 134 | targets.data(), 135 | means.data(), 136 | log_scales.data(), 137 | logit_probs_softmax.data(), 138 | cdf_mem); 139 | 140 | // Wait for GPU to finish before accessing on host 141 | cudaDeviceSynchronize(); 142 | } 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import torch 20 | 21 | # seed at least the random number generators. 22 | # doesn't guarantee full reproducability: https://pytorch.org/docs/stable/notes/randomness.html 23 | torch.manual_seed(0) 24 | 25 | # --- 26 | 27 | import argparse 28 | import sys 29 | 30 | import torch.backends.cudnn 31 | from fjcommon import no_op 32 | 33 | import pytorch_ext as pe 34 | from helpers.config_checker import DEFAULT_CONFIG_DIR, ConfigsRepo 35 | from helpers.global_config import global_config 36 | from helpers.saver import Saver 37 | from train.multiscale_trainer import MultiscaleTrainer 38 | from train.train_restorer import TrainRestorer 39 | from train.trainer import LogConfig 40 | 41 | torch.backends.cudnn.benchmark = True 42 | 43 | def _print_debug_info(): 44 | print('*' * 80) 45 | print(f'DEVICE == {pe.DEVICE} // PyTorch v{torch.__version__}') 46 | print('*' * 80) 47 | 48 | 49 | def main(args, configs_dir=DEFAULT_CONFIG_DIR): 50 | p = argparse.ArgumentParser() 51 | 52 | p.add_argument('ms_config_p', help='Path to a multiscale config, see README') 53 | p.add_argument('dl_config_p', help='Path to a dataloader config, see README') 54 | p.add_argument('log_dir_root', default='logs', help='All outputs (checkpoints, tensorboard) will be saved here.') 55 | p.add_argument('--temporary', '-t', action='store_true', 56 | help='If given, outputs are actually saved in ${LOG_DIR_ROOT}_TMP.') 57 | p.add_argument('--log_train', '-ltrain', type=int, default=100, 58 | help='Interval of train output.') 59 | p.add_argument('--log_train_heavy', '-ltrainh', type=int, default=5, metavar='LOG_HEAVY_FAC', 60 | help='Every LOG_HEAVY_FAC-th time that i %% LOG_TRAIN is 0, also output heavy logs.') 61 | p.add_argument('--log_val', '-lval', type=int, default=500, 62 | help='Interval of validation output.') 63 | 64 | p.add_argument('-p', action='append', nargs=1, 65 | help='Specify global_config parameters, see README') 66 | 67 | p.add_argument('--restore', type=str, metavar='RESTORE_DIR', 68 | help='Path to the log_dir of the model to restore. If a log_date (' 69 | 'MMDD_HHmm) is given, the model is assumed to be in LOG_DIR_ROOT.') 70 | p.add_argument('--restore_continue', action='store_true', 71 | help='If given, continue in RESTORE_DIR instead of starting in a new folder.') 72 | p.add_argument('--restore_restart', action='store_true', 73 | help='If given, start from iteration 0, instead of the iteration of RESTORE_DIR. ' 74 | 'Means that the model in RESTORE_DIR is used as pretrained model') 75 | p.add_argument('--restore_itr', '-i', type=int, default=-1, 76 | help='Which iteration to restore. -1 means latest iteration. Will use closest smaller if exact ' 77 | 'iteration is not found. Only valid with --restore. Default: -1') 78 | p.add_argument('--restore_strict', type=str, help='y|n', choices=['y', 'n'], default='y') 79 | 80 | p.add_argument('--num_workers', '-W', type=int, default=8, 81 | help='Number of workers used for DataLoader') 82 | 83 | p.add_argument('--saver_keep_tmp_itr', '-si', type=int, default=250) 84 | p.add_argument('--saver_keep_every', '-sk', type=int, default=10) 85 | p.add_argument('--saver_keep_tmp_last', '-skt', type=int, default=3) 86 | p.add_argument('--no_saver', action='store_true', 87 | help='If given, no checkpoints are stored.') 88 | 89 | p.add_argument('--debug', action='store_true') 90 | 91 | flags = p.parse_args(args) 92 | 93 | _print_debug_info() 94 | 95 | if flags.debug: 96 | flags.temporary = True 97 | 98 | global_config.add_from_flag(flags.p) 99 | print(global_config) 100 | 101 | ConfigsRepo(configs_dir).check_configs_available(flags.ms_config_p, flags.dl_config_p) 102 | 103 | saver = (Saver(flags.saver_keep_tmp_itr, flags.saver_keep_every, flags.saver_keep_tmp_last, 104 | verbose=True) 105 | if not flags.no_saver 106 | else no_op.NoOp()) 107 | 108 | restorer = TrainRestorer.from_flags(flags.restore, flags.log_dir_root, flags.restore_continue, flags.restore_itr, 109 | flags.restore_restart, flags.restore_strict) 110 | 111 | trainer = MultiscaleTrainer(flags.ms_config_p, flags.dl_config_p, 112 | flags.log_dir_root + ('_TMP' if flags.temporary else ''), 113 | LogConfig(flags.log_train, flags.log_val, flags.log_train_heavy), 114 | flags.num_workers, 115 | saver=saver, restorer=restorer) 116 | if not flags.debug: 117 | trainer.train() 118 | else: 119 | trainer.debug() 120 | 121 | 122 | if __name__ == '__main__': 123 | main(sys.argv[1:]) 124 | -------------------------------------------------------------------------------- /src/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/train/__init__.py -------------------------------------------------------------------------------- /src/train/lr_schedule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import numpy as np 20 | 21 | from fjcommon.assertions import assert_exc 22 | 23 | from vis.figure_plotter import PlotToArray 24 | 25 | 26 | SPEC_SEP = '_' 27 | 28 | 29 | def from_spec(s, initial_lr, optims, epoch_len): 30 | """ 31 | grammar: one of 32 | none 33 | exp FAC (iITR|eEPOCH) 34 | cos lrmax lrmin (iITR|eEPOCH) # time to finish 35 | 36 | Example: 37 | exp_0.5_e8_warm_20_0.75_e1 38 | """ 39 | if s == 'none': 40 | return ConstantLRSchedule() 41 | 42 | schedule_kind, s = s.split(SPEC_SEP, 1) 43 | 44 | def _parse_cos_spec(): 45 | lrmax, lrmin, T = s.split(SPEC_SEP) 46 | kind, T = T[0], T[1:] 47 | assert_exc(kind in ('i', 'e'), 'Invalid spec: {}'.format(s)) 48 | T_itr = int(T) if kind == 'i' else None 49 | T_epoch = float(T) if kind == 'e' else None 50 | return CosineDecayLRSchedule(optims, float(lrmax), float(lrmin), T_itr, T_epoch, epoch_len) 51 | 52 | p = {'exp': lambda: _parse_exp_spec(s, optims, initial_lr, epoch_len), 53 | 'cos': _parse_cos_spec 54 | }[schedule_kind] 55 | return p() 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | 60 | 61 | class LRSchedule(object): 62 | def __init__(self, optims): 63 | self.optims = optims 64 | self._current_lr = None 65 | 66 | def _get_lr(self, i): 67 | raise NotImplementedError() 68 | 69 | def update(self, i): 70 | lr = self._get_lr(i) 71 | if lr == self._current_lr: 72 | return 73 | for optim in self.optims: 74 | for pg in optim.param_groups: 75 | pg['lr'] = lr 76 | self._current_lr = lr 77 | 78 | 79 | class ConstantLRSchedule(object): 80 | def update(self, i): # no-op 81 | pass 82 | 83 | 84 | 85 | class ExponentialDecayLRSchedule(LRSchedule): 86 | def __init__(self, optims, initial, decay_fac, 87 | decay_interval_itr=None, decay_interval_epoch=None, epoch_len=None, 88 | warm_restart=None, 89 | warm_restart_schedule=None): 90 | super(ExponentialDecayLRSchedule, self).__init__(optims) 91 | assert_exc((decay_interval_itr is not None) ^ (decay_interval_epoch is not None), 'Need either iter or epoch') 92 | if decay_interval_epoch: 93 | assert epoch_len is not None 94 | decay_interval_itr = int(decay_interval_epoch * epoch_len) 95 | if warm_restart: 96 | warm_restart = int(warm_restart * epoch_len) 97 | self.initial = initial 98 | self.decay_fac = decay_fac 99 | self.decay_every_itr = decay_interval_itr 100 | 101 | self.warm_restart_itr = warm_restart 102 | self.warm_restart_schedule = warm_restart_schedule 103 | 104 | self.last_warm_restart = 0 105 | 106 | def _get_lr(self, i): 107 | if i > 0 and self.warm_restart_itr and ((i - self.last_warm_restart) % self.warm_restart_itr) == 0: 108 | if i != self.last_warm_restart: 109 | self._warm_restart() 110 | self.last_warm_restart = i 111 | i -= self.last_warm_restart 112 | num_decays = i // self.decay_every_itr 113 | return self.initial * (self.decay_fac ** num_decays) 114 | 115 | def _warm_restart(self): 116 | print('WARM restart') 117 | if self.warm_restart_schedule: 118 | self.initial = self.warm_restart_schedule.initial 119 | self.decay_fac = self.warm_restart_schedule.decay_fac 120 | self.decay_every_itr = self.warm_restart_schedule.decay_every_itr 121 | self.warm_restart_itr = self.warm_restart_schedule.warm_restart_itr 122 | self.warm_restart_schedule = self.warm_restart_schedule.warm_restart_schedule 123 | 124 | 125 | class CosineDecayLRSchedule(LRSchedule): 126 | def __init__(self, optims, lrmax, lrmin, T_itr, T_epoch, epoch_len): 127 | super(CosineDecayLRSchedule, self).__init__(optims) 128 | self.lrmax = lrmax 129 | self.lrmin = lrmin 130 | if T_itr is None: 131 | assert epoch_len is not None 132 | T_itr = int(T_epoch * epoch_len) 133 | self.Ti = T_itr 134 | self.epoch_len = epoch_len 135 | 136 | def _get_lr(self, i): 137 | Tcur = (i % self.Ti) / (2 * self.Ti) 138 | return self.lrmin + (self.lrmax - self.lrmin)*(np.cos(np.pi * Tcur)) 139 | 140 | 141 | def _parse_exp_spec(s, optims, initial_lr, epoch_len): 142 | if s.count(SPEC_SEP) > 2: 143 | fac, interval, warm, warm_start, warm_fac, warm_interval = s.split(SPEC_SEP) 144 | assert warm == 'warm' 145 | warm_start = int(warm_start) 146 | warm_schedule = _parse_exp_spec(SPEC_SEP.join([warm_fac, warm_interval]), optims, initial_lr, epoch_len) 147 | else: 148 | fac, interval = s.split(SPEC_SEP) 149 | warm_start, warm_schedule = None, None 150 | kind, interval = interval[0], interval[1:] 151 | assert_exc(kind in ('i', 'e'), 'Invalid spec: {}'.format(s)) 152 | decay_interval_itr = int(interval) if kind == 'i' else None 153 | decay_interval_epoch = float(interval) if kind == 'e' else None 154 | return ExponentialDecayLRSchedule( 155 | optims, initial_lr, float(fac), decay_interval_itr, decay_interval_epoch, epoch_len, 156 | warm_restart=warm_start, warm_restart_schedule=warm_schedule) 157 | 158 | -------------------------------------------------------------------------------- /src/train/train_restorer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import os 20 | 21 | from helpers import logdir_helpers 22 | from helpers.paths import get_ckpts_dir, get_experiment_dir, CKPTS_DIR_NAME 23 | from helpers.saver import Restorer 24 | 25 | 26 | class TrainRestorer(Restorer): 27 | @staticmethod 28 | def from_flags(restore_dir, log_dir, restore_continue, restore_itr, 29 | restart_at_zero=False, strict='y'): 30 | if restore_dir is None: 31 | return None 32 | strict = {'y': True, 'n': False}[strict] 33 | return TrainRestorer(get_ckpts_dir(get_experiment_dir(log_dir, restore_dir)), 34 | restore_continue, restore_itr, restart_at_zero, strict) 35 | 36 | def __init__(self, out_dir, restore_continue=False, restore_itr=-1, restart_at_zero=False, strict=True, 37 | ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp'): 38 | """ 39 | :param out_dir: ends in ckpts/ 40 | :param restore_continue: 41 | :param restart_at_zero: 42 | :param ckpt_name_fmt: 43 | :param tmp_postfix: 44 | """ 45 | assert out_dir.rstrip(os.path.sep).endswith(CKPTS_DIR_NAME), out_dir 46 | super(TrainRestorer, self).__init__(out_dir, ckpt_name_fmt, tmp_postfix) 47 | self.restore_continue = restore_continue 48 | self.restore_itr = restore_itr 49 | self.restart_at_zero = restart_at_zero 50 | self.strict = strict 51 | 52 | def restore_desired_ckpt(self, modules): 53 | itrc, ckpt_p = self.get_ckpt_for_itr(self.restore_itr) 54 | print('Restoring {}...'.format(itrc)) 55 | return self.restore(modules, ckpt_p, self.strict, restore_restart=self.restart_at_zero) 56 | 57 | def get_log_dir(self): 58 | log_dir = os.path.dirname(self._out_dir) # should be .../logs/MMDD_HHdd config config config 59 | assert logdir_helpers.is_log_dir(log_dir) 60 | return log_dir 61 | -------------------------------------------------------------------------------- /src/train/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | General Trainer class, subclassed by MultiscaleTrainer 22 | 23 | """ 24 | 25 | from collections import namedtuple 26 | 27 | import torch 28 | 29 | import torchvision 30 | from fjcommon import timer 31 | from fjcommon.no_op import NoOp 32 | 33 | from helpers import logdir_helpers 34 | import vis.safe_summary_writer 35 | from helpers.saver import Saver 36 | from helpers.global_config import global_config 37 | import itertools 38 | 39 | from train.train_restorer import TrainRestorer 40 | 41 | 42 | LogConfig = namedtuple('LogConfig', ['log_train', 'log_val', 'log_train_heavy']) 43 | 44 | 45 | class TimedIterator(object): 46 | def __init__(self, it): 47 | self.t = timer.TimeAccumulator() 48 | self.it = iter(it) 49 | 50 | def __iter__(self): 51 | return self 52 | 53 | def __next__(self): 54 | with self.t.execute(): 55 | return next(self.it) 56 | 57 | 58 | # TODO: allow "restart last epoch" or sth 59 | class TrainingSetIterator(object): 60 | """ Implements skipping to a certain iteration """ 61 | def __init__(self, skip_to_itr, dl_train): 62 | self.skip_to_itr = skip_to_itr 63 | self.dl_train = dl_train 64 | self.epoch_len = len(self.dl_train) 65 | 66 | def epochs_to_skip(self): 67 | if self.skip_to_itr: 68 | skip_epochs, skip_batches = self.skip_to_itr // self.epoch_len, self.skip_to_itr % self.epoch_len 69 | return skip_epochs, skip_batches 70 | return 0, 0 71 | 72 | def iterator(self, epoch): 73 | """ :returns an iterator over tuples (itr, batch) """ 74 | skip_epochs, skip_batches = self.epochs_to_skip() 75 | if epoch < skip_epochs: 76 | print('Skipping epoch {}'.format(epoch)) 77 | return [] # nothing to iterate 78 | if epoch > skip_epochs or (epoch == skip_epochs and skip_batches == 0): # iterate like normal 79 | return enumerate(self.dl_train, epoch * len(self.dl_train)) 80 | # if we get to here, we are in the first epoch which we should not skip, so skip `skip_batches` batches 81 | it = iter(self.dl_train) 82 | for i in range(skip_batches): 83 | print('\rDropping batch {: 10d}...'.format(i), end='') 84 | if not global_config.get('drop_batches', False): 85 | # would be nice to not load images but this is hard to do as DataLoader caches Dataset's respondes, 86 | # might even be immutable? 87 | next(it) # drop batch 88 | print(' -- dropped {} batches'.format(skip_batches)) 89 | return enumerate(it, epoch * len(self.dl_train) + skip_batches) 90 | 91 | 92 | class AbortTrainingException(Exception): 93 | pass 94 | 95 | 96 | class Trainer(object): 97 | def __init__(self, dl_train, dl_val, optims, net, sw: vis.safe_summary_writer.SafeSummaryWriter, 98 | max_epochs, log_config: LogConfig, saver: Saver=None, skip_to_itr=None): 99 | 100 | assert isinstance(optims, list) 101 | 102 | self.dl_train = dl_train 103 | self.dl_val = dl_val 104 | self.optims = optims 105 | self.net = net 106 | self.sw = sw 107 | self.max_epochs = max_epochs 108 | self.log_config = log_config 109 | self.saver = saver if saver is not None else NoOp 110 | 111 | self.skip_to_itr = skip_to_itr 112 | 113 | def continue_from(self, ckpt_dir): 114 | pass 115 | 116 | def train(self): 117 | log_train, log_val, log_train_heavy = self.log_config 118 | 119 | dl_train_it = TrainingSetIterator(self.skip_to_itr, self.dl_train) 120 | 121 | _print_unused_global_config() 122 | 123 | try: 124 | for epoch in (range(self.max_epochs) if self.max_epochs else itertools.count()): 125 | self.print_epoch_sep(epoch) 126 | self.prepare_for_epoch(epoch) 127 | t = TimedIterator(dl_train_it.iterator(epoch)) 128 | for i, img_batch in t: 129 | for o in self.optims: 130 | o.zero_grad() 131 | should_log = (i > 0 and i % log_train == 0) 132 | should_log_heavy = (i > 0 and (i / log_train_heavy) % log_train == 0) 133 | self.train_step(i, img_batch, 134 | log=should_log, 135 | log_heavy=should_log_heavy, 136 | load_time='[{:.2e} s/batch load]'.format(t.t.mean_time_spent()) if should_log else None) 137 | self.saver.save(self.modules_to_save(), i) 138 | 139 | if i > 0 and i % log_val == 0: 140 | self._eval(i) 141 | except AbortTrainingException as e: 142 | print('Caught {}'.format(e)) 143 | return 144 | 145 | def _eval(self, i): 146 | self.net.eval() 147 | with torch.no_grad(): 148 | self.validation_loop(i) 149 | self.net.train() 150 | 151 | def debug(self): 152 | print('Debug ---') 153 | _print_unused_global_config() 154 | self.prepare_for_epoch(0) 155 | self.train_step(0, next(iter(self.dl_train)), 156 | log=True, log_heavy=True, load_time=0) 157 | self._eval(101) 158 | 159 | def print_epoch_sep(self, epoch): 160 | print('-' * 80) 161 | print(' EPOCH {}'.format(epoch)) 162 | print('-' * 80) 163 | 164 | def modules_to_save(self): 165 | """ used to save and restore. Should return a dictionary module_name -> nn.Module """ 166 | raise NotImplementedError() 167 | 168 | def train_step(self, i, img_batch, log, log_heavy, load_time=None): 169 | raise NotImplementedError() 170 | 171 | def validation_loop(self, i): 172 | raise NotImplementedError() 173 | 174 | def prepare_for_epoch(self, epoch): 175 | pass 176 | 177 | def add_filter_summaray(self, tag, p, global_step): 178 | if len(p.shape) == 1: # bias 179 | C, = p.shape 180 | p = p.reshape(C, 1).expand(C, C).reshape(C, C, 1, 1) 181 | 182 | try: 183 | _, _, H, W = p.shape 184 | except ValueError: 185 | if global_step == 0: 186 | print('INFO: Cannot unpack {} ({})'.format(p.shape, tag)) 187 | return 188 | 189 | if H == W == 1: # 1x1 conv 190 | p = p[:, :, 0, 0] 191 | filter_vis = torchvision.utils.make_grid(p, normalize=True) 192 | self.sw.add_image(tag, filter_vis, global_step) 193 | 194 | @staticmethod 195 | def update_lrs(epoch, optims_lrs_facs_interval): 196 | raise DeprecationWarning('use lr_schedule.py') 197 | for optimizer, init_lr, decay_fac, interval_epochs in optims_lrs_facs_interval: 198 | if decay_fac is None: 199 | continue 200 | Trainer.exp_lr_scheduler(optimizer, epoch, init_lr, decay_fac, interval_epochs) 201 | 202 | @staticmethod 203 | def exp_lr_scheduler(optimizer, epoch, init_lr, decay_fac=0.1, interval_epochs=7): 204 | raise DeprecationWarning('use lr_schedule.py') 205 | lr = init_lr * (decay_fac ** (epoch // interval_epochs)) 206 | print('LR = {}'.format(lr)) 207 | for param_group in optimizer.param_groups: 208 | param_group['lr'] = lr 209 | 210 | def get_lrs(self): 211 | for optim in self.optims: 212 | for param_group in optim.param_groups: 213 | yield param_group['lr'] 214 | 215 | def maybe_restore(self, restorer: TrainRestorer): 216 | """ 217 | :return: skip_to_itr 218 | """ 219 | if restorer is None: 220 | return None # start from 0 221 | restore_itr = restorer.restore_desired_ckpt(self.modules_to_save()) # TODO: allow arbitrary ckpts 222 | if restorer.restart_at_zero: 223 | return 0 224 | return restore_itr 225 | 226 | @staticmethod 227 | def get_log_dir(log_dir_root, rel_paths, restorer, strip_ext='.cf'): 228 | if not restorer or not restorer.restore_continue: 229 | log_dir = logdir_helpers.create_unique_log_dir( 230 | rel_paths, log_dir_root, strip_ext=strip_ext, postfix=global_config.values()) 231 | print('Created {}...'.format(log_dir)) 232 | else: 233 | log_dir = restorer.get_log_dir() 234 | print('Using {}...'.format(log_dir)) 235 | return log_dir 236 | 237 | 238 | def _print_unused_global_config(ignore=None): 239 | """ For safety, print parameters that were passed with -p but never used during construction of graph. """ 240 | if not ignore: 241 | ignore = [] 242 | unused = [u for u in global_config.get_unused_params() if u not in ignore] 243 | if unused: 244 | raise ValueError('Unused params:\n- ' + '\n- '.join(unused)) 245 | 246 | -------------------------------------------------------------------------------- /src/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/vis/__init__.py -------------------------------------------------------------------------------- /src/vis/figure_plotter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import torch 20 | 21 | from sys import platform 22 | import matplotlib as mpl 23 | if platform != 'darwin': 24 | mpl.use('Agg') # No display 25 | import matplotlib.pyplot as plt 26 | import matplotlib.backends.backend_agg as plt_backend_agg 27 | import numpy as np 28 | 29 | 30 | class PlotToArray(object): 31 | """ 32 | p = PlotToArray() 33 | plt = p.prepare() 34 | # add plot to plt 35 | im = plt.get_numpy() # CHW 36 | """ 37 | def __init__(self): 38 | self.fig = None 39 | 40 | def prepare(self): 41 | self.fig = plt.figure(dpi=100) 42 | return plt 43 | 44 | def get_numpy(self): 45 | assert self.fig is not None 46 | return figure_to_image(self.fig) # CHW 47 | 48 | def get_tensor(self): 49 | return torch.from_numpy(self.get_numpy()) # CHW 50 | 51 | 52 | def figure_to_image(figures, close=True): 53 | if not isinstance(figures, list): 54 | image = _render_to_rgb(figures, close) 55 | return image 56 | else: 57 | images = [_render_to_rgb(figure, close) for figure in figures] 58 | return np.stack(images) 59 | 60 | def _render_to_rgb(figure, close): 61 | canvas = plt_backend_agg.FigureCanvasAgg(figure) 62 | canvas.draw() 63 | data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) 64 | w, h = figure.canvas.get_width_height() 65 | image_hwc = data.reshape([h, w, 4])[..., :3] 66 | image_chw = np.moveaxis(image_hwc, source=2, destination=0) 67 | if close: 68 | plt.close(figure) 69 | return image_chw 70 | -------------------------------------------------------------------------------- /src/vis/grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import functools 20 | 21 | import torch 22 | from fjcommon import functools_ext as ft 23 | from torch.nn import functional as F 24 | 25 | 26 | # TODO: document thoroughly 27 | def prep_for_grid(x, pad_to=None, channelwise=False, insert_empty_indices=None): 28 | if insert_empty_indices is not None: 29 | assert isinstance(insert_empty_indices, list) 30 | assert isinstance(x, list), 'Need list for insert_empty_indices, got {}'.format(x) 31 | if isinstance(x, tuple) or isinstance(x, list): 32 | if insert_empty_indices: 33 | some_x = x[0] 34 | for idx in sorted(insert_empty_indices, reverse=True): 35 | x.insert(idx, torch.zeros_like(some_x)) 36 | if pad_to is None: # we are given a list of tensors, they must be padded! 37 | pad_to = max(el.size()[-1] for el in x) # maximum width 38 | _prep_for_grid = functools.partial(prep_for_grid, 39 | pad_to=pad_to, channelwise=channelwise, insert_empty_indices=None) 40 | return ft.lconcat(map(_prep_for_grid, x)) 41 | 42 | if x.dim() == 2: # HW 43 | x = x.unsqueeze(0) # NHW 44 | if x.dim() == 3: # NHW 45 | assert not channelwise 46 | x = x.unsqueeze(1) # NCHW 47 | assert x.dim() == 4, "Expected NCHW" 48 | x = x[0, ...] # now: CHW 49 | if pad_to: 50 | w = x.size()[-1] 51 | pad = (pad_to - w) // 2 52 | if pad: 53 | x = F.pad(x, (pad, pad, pad, pad)) 54 | if channelwise: 55 | return [c.unsqueeze(0) for c in torch.unbind(x, 0)] 56 | else: 57 | return [x] 58 | -------------------------------------------------------------------------------- /src/vis/histogram_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | Class to create histograms of tensors. 22 | 23 | """ 24 | 25 | from helpers import rolling_buffer 26 | from vis.summarizable_module import SummarizableModule 27 | 28 | 29 | class HistogramPlot(SummarizableModule): 30 | def __init__(self, prefix, name, buffer_size, num_inputs_to_buffer=1, per_channel=False, most_mass=5e-5): 31 | """ 32 | :param prefix: 33 | :param name: Name in TensorBoard 34 | :param buffer_size: buffer size 35 | :param num_inputs_to_buffer: x[:num_inputs_to_buffer, ...] will be stored only 36 | :param per_channel: if True, create a histo per channel 37 | """ 38 | super(HistogramPlot, self).__init__() 39 | self.buffer_size = buffer_size 40 | self.num_inputs_to_buffer = num_inputs_to_buffer 41 | self.per_channel = per_channel 42 | self.buffers = None 43 | self.num_chan = None # non-None if self.per_channel and forward has been called at least once 44 | self.figure_creator = {name: self._plot} 45 | self.prefix = prefix 46 | self.name = name 47 | self.most_mass = most_mass 48 | 49 | def forward(self, x): 50 | """ 51 | :param x: Tensor 52 | :returns: x 53 | """ 54 | if not self.training: # only during training 55 | return x 56 | if self.buffers is None: 57 | self.buffers = self._new_buffers(x.detach()) 58 | if self.per_channel: 59 | for c in range(self.num_chan): 60 | self.buffers[c].add(x[:self.num_inputs_to_buffer, c, ...].detach()) 61 | else: 62 | self.buffers.add(x[:self.num_inputs_to_buffer, ...].detach()) 63 | # register for plotting 64 | self.summarizer.register_figures(self.prefix, self.figure_creator) 65 | return x 66 | 67 | def _plot(self, plt): 68 | """ Called when summarizer decides to plot. """ 69 | for b in self._iter_over_buffers(): 70 | x, y = b.plot(bins=128, most_mass=self.most_mass) 71 | plt.plot(x, y) 72 | 73 | def _iter_over_buffers(self): 74 | if not self.per_channel: # self.buffers just a RollingBuffer instance 75 | yield self.buffers 76 | else: 77 | yield from self.buffers 78 | 79 | def _new_buffers(self, x): 80 | if not self.per_channel: 81 | return rolling_buffer.RollingBufferHistogram(self.buffer_size, self.name) 82 | self.num_chan = x.shape[1] 83 | return [rolling_buffer.RollingBufferHistogram(self.buffer_size, self.name) 84 | for _ in range(self.num_chan)] -------------------------------------------------------------------------------- /src/vis/histogram_plotter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import numpy as np 20 | 21 | 22 | _WIDTH = 0.8 23 | 24 | 25 | # TODO: rename to something with bar plot 26 | 27 | 28 | def plot_histogram(datas, plt): 29 | for i, data in enumerate(datas): 30 | rel_i = i - len(datas) / 2 31 | w = _WIDTH/len(datas) 32 | _plot_histogram(data, plt, w, rel_i * w) 33 | plt.legend() 34 | 35 | 36 | def _plot_histogram(data, plt, width, offset): 37 | name, values = data 38 | plt.bar(np.arange(len(values)) + offset, values, 39 | width=width, 40 | label=name, align='edge') 41 | 42 | 43 | def _test(): 44 | import matplotlib.pyplot as plt 45 | 46 | f = plt.figure() 47 | datas = [('gt', [1000, 10, 33, 500, 600, 700]), 48 | ('outs', [900, 20, 0, 0, 100, 1000]), 49 | ('ups', 0.5 * np.array([900, 20, 0, 0, 100, 1000])), 50 | ] 51 | plot_histogram(datas, plt) 52 | f.show() 53 | 54 | 55 | if __name__ == '__main__': 56 | _test() 57 | -------------------------------------------------------------------------------- /src/vis/image_summaries.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | import numpy as np 20 | import pytorch_ext as pe 21 | from fjcommon.assertions import assert_exc 22 | 23 | 24 | def imshow(img): 25 | """ Only meant for local visualizations. Requires HW3 """ 26 | import matplotlib.pyplot as plt 27 | assert isinstance(img, np.ndarray) 28 | assert img.dtype == np.uint8 29 | assert img.ndim == 3 30 | if img.shape[0] == 3: 31 | img = img.transpose(1, 2, 0) 32 | assert img.shape[2] == 3, img.shape 33 | plt.imshow(img) 34 | plt.show() 35 | 36 | 37 | def to_image(t): 38 | """ 39 | :param t: tensor or np.ndarray, may be of shape NCHW / CHW with C=1 or 3 / HW, dtype float32 or uint8. If float32: 40 | must be in [0, 1] 41 | :return: HW3 uint8 np.ndarray 42 | """ 43 | if not isinstance(t, np.ndarray): 44 | t = pe.tensor_to_np(t) 45 | # - t is numpy array 46 | if t.ndim == 4: 47 | # - t has batch dimension, only use first 48 | t = t[0, ...] 49 | elif t.ndim == 2: 50 | t = np.expand_dims(t, 0) # Now 1HW 51 | assert_exc(t.ndim == 3, 'Invalid shape: {}'.format(t.shape)) 52 | # - t is 3 dimensional CHW numpy array 53 | if t.dtype != np.uint8: 54 | assert_exc(t.dtype == np.float32, 'Expected either uint8 or float32, got {}'.format(t.dtype)) 55 | _check_range(t, 0, 1) 56 | t = (t * 255.).astype(np.uint8) 57 | # - t is uint8 numpy array 58 | num_channels = t.shape[0] 59 | if num_channels == 3: 60 | t = np.transpose(t, (1, 2, 0)) 61 | elif num_channels == 1: 62 | t = np.stack([t[0, :, :] for _ in range(3)], -1) 63 | else: 64 | raise ValueError('Expected CHW, got {}'.format(t.shape)) 65 | assert_exc(t.ndim == 3 and t.shape[2] == 3, str(t.shape)) 66 | # - t is uint8 numpy array of shape HW3 67 | return t 68 | 69 | 70 | def _check_range(a, lo, hi): 71 | a_lo, a_hi = np.min(a), np.max(a) 72 | assert_exc(a_lo >= lo and a_hi <= hi, 'Invalid range: [{}, {}]. Expected: [{}, {}]'.format(a_lo, a_hi, lo, hi)) 73 | 74 | 75 | -------------------------------------------------------------------------------- /src/vis/safe_summary_writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | """ 19 | from contextlib import contextmanager 20 | 21 | from tensorboardX import SummaryWriter 22 | from tensorboardX.summary import Summary, _clean_tag, make_image as make_image_summary 23 | 24 | from vis.figure_plotter import PlotToArray 25 | from vis.image_summaries import to_image 26 | 27 | 28 | class SafeSummaryWriter(SummaryWriter): 29 | def get_events_file_path(self): 30 | try: 31 | return self.file_writer.event_writer._ev_writer._file_prefix 32 | except AttributeError: 33 | print('Cannot get events file name...') 34 | return None 35 | 36 | @staticmethod 37 | def pre(prefix, tag): 38 | assert prefix[0] != '/' 39 | return prefix.rstrip('/') + '/' + tag.lstrip('/') 40 | 41 | def add_image(self, tag, img_tensor, global_step=None, **kwargs): 42 | """ 43 | Add image img_tensor to summary. 44 | img_tensor can be np.ndarray or torch.Tensor, 1HW or 3HW or HW 45 | If img_tensor is uint8: 46 | add it like it is 47 | If img_tensor is float32: 48 | check that it is in [0, 1] and add it 49 | :param **kwargs: 50 | """ 51 | if len(img_tensor.shape) == 2: 52 | img_tensor = img_tensor.reshape(1, *img_tensor.shape) 53 | self.file_writer.add_summary(SafeSummaryWriter._to_image_summary_safe(tag, img_tensor), global_step) 54 | 55 | @contextmanager 56 | def add_figure_ctx(self, tag, global_step=None): 57 | """ 58 | Context manager that yields a plt to draw on, converts it to image 59 | """ 60 | p = PlotToArray() 61 | yield p.prepare() # user draws 62 | self.add_image(tag, p.get_numpy(), global_step, ) 63 | 64 | @staticmethod 65 | def _to_image_summary_safe(tag, tensor): 66 | tag = _clean_tag(tag) 67 | img = make_image_summary(to_image(tensor)) 68 | return Summary(value=[Summary.Value(tag=tag, image=img)]) 69 | 70 | -------------------------------------------------------------------------------- /src/vis/summarizable_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2019, ETH Zurich 3 | 4 | This file is part of L3C-PyTorch. 5 | 6 | L3C-PyTorch is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | any later version. 10 | 11 | L3C-PyTorch is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with L3C-PyTorch. If not, see . 18 | 19 | -------------------------------------------------------------------------------- 20 | 21 | Contains the neat SummarizableModule class. It's a replacement for nn.Module. If in a tree of nn.Modules, the root 22 | is a SummarizableModule and some leaves are also, you can call `register_summarizer` on the root, and this will add 23 | an instance of `Summarizer` to every SummarizableModule in the tree. 24 | Then, the module can call stuff like: 25 | 26 | self.summarizer.register_scalars('train', {'lr': self.lr}) 27 | 28 | def _plot(plt): 29 | plt.plot(x, y) 30 | 31 | self.summarizer.register_figures('val', {'niceplot': _plot}) 32 | 33 | """ 34 | from contextlib import contextmanager 35 | 36 | import torch 37 | from fjcommon.assertions import assert_exc 38 | from fjcommon.no_op import NoOp 39 | from torch import nn as nn 40 | 41 | 42 | class _GlobalStepDependable(object): 43 | def __init__(self): 44 | self.global_step = None 45 | self.enabled_prefix = None 46 | 47 | def enable(self, prefix, global_step): 48 | """ Enable logging of prefix """ 49 | assert_exc(isinstance(prefix, str), 'prefix must be str, got {}'.format(prefix)) 50 | assert_exc(prefix[-1] != '/') 51 | self.enabled_prefix = prefix 52 | self.global_step = global_step 53 | 54 | def disable(self): 55 | self.enabled_prefix = None 56 | 57 | @contextmanager 58 | def maybe_enable(self, prefix, flag, global_step): 59 | if flag: 60 | self.enable(prefix, global_step) 61 | yield 62 | self.disable() 63 | 64 | 65 | def normalize_to_0_1(t): 66 | return t.add(-t.min()).div(t.max() - t.min() + 1e-5) 67 | 68 | 69 | class Summarizer(_GlobalStepDependable): 70 | def __init__(self, sw): 71 | super(Summarizer, self).__init__() 72 | self.sw = sw 73 | 74 | def register_scalars(self, prefix, values): 75 | """ 76 | :param prefix: Prefix to use in TensorBoard 77 | :param values: A dictionary of name -> value, where value can be callable (useful if it is expensive) 78 | """ 79 | if self.enabled_prefix is None: 80 | return 81 | if prefix == 'auto': 82 | prefix = self.enabled_prefix 83 | if prefix == self.enabled_prefix: 84 | for name, value in values.items(): 85 | self.sw.add_scalar(prefix + '/' + name, _convert_if_callable(value), self.global_step) 86 | 87 | def register_figures(self, prefix, creators): 88 | """ 89 | :param prefix: Prefix to use in TensorBoard 90 | :param creators: plot_name -> (plt -> None) 91 | """ 92 | if self.enabled_prefix is None: 93 | return 94 | if prefix == 'auto': 95 | prefix = self.enabled_prefix 96 | if prefix == self.enabled_prefix: 97 | for name, creator in creators.items(): 98 | with self.sw.add_figure_ctx(prefix + '/' + name, self.global_step) as plt: 99 | creator(plt) 100 | 101 | def register_images(self, prefix, imgs, normalize=False): # , early_only=False): 102 | """ 103 | :param prefix: Prefix to use in TensorBoard 104 | :param imgs: A dictionary of name -> img, where img can be callable (useful if it is expensive) 105 | :param normalize: If given, will normalize imgs to [0,1] 106 | """ 107 | if self.enabled_prefix is None: 108 | return 109 | if prefix == 'auto': 110 | prefix = self.enabled_prefix 111 | if prefix == self.enabled_prefix: 112 | for name, img in imgs.items(): 113 | img = _convert_if_callable(img) 114 | if normalize: 115 | img = normalize_to_0_1(img) 116 | self.sw.add_image(prefix + '/' + name, img, self.global_step) 117 | 118 | 119 | def _convert_if_callable(v): 120 | if hasattr(v, '__call__'): # python 3 only 121 | return v() 122 | return v 123 | 124 | 125 | class SummarizableModule(nn.Module): 126 | def __init__(self): 127 | super(SummarizableModule, self).__init__() 128 | self.summarizer = NoOp 129 | 130 | def forward(self, *input): 131 | raise NotImplementedError 132 | 133 | def register_summarizer(self, summarizer: Summarizer): 134 | for m in iter_modules_of_class(self, SummarizableModule): 135 | m.summarizer = summarizer 136 | 137 | 138 | def iter_modules_of_class(root_module: nn.Module, cls): 139 | """ 140 | Helpful for extending nn.Module. How to use: 141 | 1. define new nn.Module subclass with some new instance methods, cls 142 | 2. make your root module inherit from cls 143 | 3. make some leaf module inherit from cls 144 | """ 145 | for m in root_module.modules(): 146 | if isinstance(m, cls): 147 | yield m 148 | 149 | 150 | # Tests ------------------------------------------------------------------------ 151 | 152 | 153 | def test_submodules(): 154 | class _T(nn.Module): 155 | def __init__(self): 156 | super(_T, self).__init__() 157 | self.foo = [] 158 | 159 | def register_foo(self, f): 160 | self.foo.append(f) 161 | 162 | def get_all(self): 163 | for m_ in iter_modules_of_class(self, _T): 164 | yield from m_.foo 165 | 166 | class _SomethingWithTs(nn.Module): 167 | def __init__(self): 168 | super(_SomethingWithTs, self).__init__() 169 | self.a_t = _T() 170 | 171 | class _M(_T): # first T 172 | def __init__(self): 173 | super(_M, self).__init__() 174 | self.conv = nn.Conv2d(1, 2, 3) 175 | self.t = _T() # here 176 | self.t.register_foo(1) 177 | self.list = nn.ModuleList( 178 | [nn.Conv2d(1, 2, 3), 179 | _T()]) # here 180 | inner = _T() 181 | self.seq = nn.Sequential( 182 | nn.Conv2d(1, 2, 3), 183 | inner, # here 184 | _SomethingWithTs()) # here 185 | inner.register_foo(2) 186 | 187 | m = _M() 188 | all_ts = list(iter_modules_of_class(m, _T)) 189 | assert len(all_ts) == 5, all_ts 190 | assert list(m.get_all()) == [1, 2] 191 | --------------------------------------------------------------------------------