├── .gitignore
├── LICENSE
├── README.md
├── figs
├── arch_detail.png
├── l3c_code_outline.jpg
└── teaser.png
├── pip_requirements.txt
└── src
├── auto_crop.py
├── bitcoding
├── __init__.py
├── bitcoding.py
├── coders.py
├── coders_helpers.py
└── part_suffix_helper.py
├── blueprints
├── __init__.py
└── multiscale_blueprint.py
├── configs
├── dl
│ ├── in32.cf
│ ├── in64.cf
│ └── oi.cf
└── ms
│ ├── cr.cf
│ ├── cr_rgb.cf
│ └── cr_rgb_shared.cf
├── criterion
├── __init__.py
└── logistic_mixture.py
├── dataloaders
├── __init__.py
└── images_loader.py
├── helpers
├── __init__.py
├── aligned_printer.py
├── config_checker.py
├── global_config.py
├── logdir_helpers.py
├── pad.py
├── paths.py
├── rolling_buffer.py
├── saver.py
└── testset.py
├── import_train_images.py
├── import_train_images_v1.py
├── l3c.py
├── modules
├── __init__.py
├── edsr.py
├── head.py
├── multiscale_network.py
├── net.py
├── prob_clf.py
└── quantizer.py
├── prep_openimages.sh
├── pytorch_ext.py
├── test.py
├── test
├── __init__.py
├── cuda_timer.py
├── image_saver.py
└── multiscale_tester.py
├── tmp_discard_list
├── torchac
├── setup.py
├── torchac.py
└── torchac_backend
│ ├── torchac.cpp
│ └── torchac_kernel.cu
├── train.py
├── train
├── __init__.py
├── lr_schedule.py
├── multiscale_trainer.py
├── train_restorer.py
└── trainer.py
└── vis
├── __init__.py
├── figure_plotter.py
├── grid.py
├── histogram_plot.py
├── histogram_plotter.py
├── image_summaries.py
├── safe_summary_writer.py
└── summarizable_module.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/figs/arch_detail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/arch_detail.png
--------------------------------------------------------------------------------
/figs/l3c_code_outline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/l3c_code_outline.jpg
--------------------------------------------------------------------------------
/figs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/figs/teaser.png
--------------------------------------------------------------------------------
/pip_requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow
2 | scipy==1.1.0
3 | fasteners==0.14.1
4 | fjcommon==0.2.10
5 | tensorboardX==1.2
6 | matplotlib==2.2.2
7 |
--------------------------------------------------------------------------------
/src/auto_crop.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for cropping image depending on resolution.
3 |
4 | TODO: replace recursive code with something more adaptive, right now we only do
5 | - no crops / 2x2 / 4x4 / 16x16 / etc.
6 | and the stitching code is complicated, but would be nice to have e.g. 3x3.
7 | """
8 | import math
9 | import os
10 |
11 | import itertools
12 |
13 | import torch
14 |
15 | from blueprints.multiscale_blueprint import MultiscaleLoss
16 | import functools
17 | import operator
18 |
19 |
def prod(it):
    """Return the product of all elements of `it` (1 for an empty iterable)."""
    result = 1
    for factor in it:
        result = result * factor
    return result
22 |
23 |
24 | # Images with H * W > prod(_NEEDS_CROP_DIM) will be split into crops
25 | # We set this empirically such that crops fit into our TITAN X (Pascal) with 12GB VRAM.
26 | # You can set this from the console using AC_NEEDS_CROP_DIM, e.g.,
27 | #
28 | # AC_NEEDS_CROP_DIM=2000,2000 python test.py ...
29 | #
30 | # But expect OOM errors for big values.
# Default threshold as a comma-separated string; the numbers are multiplied together below.
_NEEDS_CROP_DIM_DEFAULT = '2000,1500'
# Raw string value: taken from the AC_NEEDS_CROP_DIM environment variable if set, else the default.
_NEEDS_CROP_DIM = os.environ.get('AC_NEEDS_CROP_DIM', _NEEDS_CROP_DIM_DEFAULT)
if _NEEDS_CROP_DIM != _NEEDS_CROP_DIM_DEFAULT:
    # Only announce the raw string when it differs from the default.
    print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM)
# From here on _NEEDS_CROP_DIM is an int: the product of the comma-separated numbers.
_NEEDS_CROP_DIM = prod(map(int, _NEEDS_CROP_DIM.split(',')))
# Printed unconditionally at import time, showing the final integer threshold.
print('*** AC_NEEDS_CROP_DIM =', _NEEDS_CROP_DIM)
37 |
38 |
39 | def _assert_valid_image(i):
40 | if len(i.shape) != 4 or i.shape[1] != 3:
41 | raise ValueError(f'Expected BCHW image, got {i.shape}')
42 |
43 |
def needs_crop(img, needs_crop_dim=_NEEDS_CROP_DIM):
    """Return whether `img` has more pixels (H * W) than `needs_crop_dim`.

    :param img: BCHW image with 3 channels (validated via _assert_valid_image).
    :param needs_crop_dim: pixel-count threshold; defaults to the module-level _NEEDS_CROP_DIM.
    """
    _assert_valid_image(img)
    height, width = img.shape[-2:]
    num_pixels = height * width
    return num_pixels > needs_crop_dim
48 |
49 |
def _crop16(im):
    """Yield 16 crops: each of the 4 crops of `im` is itself split into 4 crops."""
    for quarter in _crop4(im):
        for sub_crop in _crop4(quarter):
            yield sub_crop
53 |
54 |
def iter_crops(img, needs_crop_dim=_NEEDS_CROP_DIM):
    """Recursively yield crops of `img` until no crop exceeds `needs_crop_dim` pixels.

    If `img` itself does not need cropping, it is yielded unchanged as the only element.
    Otherwise it is split into 4 crops and each crop is processed recursively, in the
    order top left, top right, bottom left, bottom right.

    :param img: BCHW image with 3 channels.
    :param needs_crop_dim: pixel-count threshold, see needs_crop.
    """
    _assert_valid_image(img)

    if needs_crop(img, needs_crop_dim):
        for quarter in _crop4(img):
            yield from iter_crops(quarter, needs_crop_dim)
    else:
        yield img
63 |
64 |
def _crop4(img):
    """Split `img` along H and W at H//2 and W//2 and return the 4 resulting views.

    :return: list [top left, top right, bottom left, bottom right]. For odd H or W the
    bottom / right crops are one pixel larger than the top / left ones.
    """
    _assert_valid_image(img)
    H, W = img.shape[-2:]
    h_mid, w_mid = H // 2, W // 2
    row_slices = (slice(None, h_mid), slice(h_mid, None))  # top, bottom
    col_slices = (slice(None, w_mid), slice(w_mid, None))  # left, right
    quarters = [img[..., rows, cols]
                for rows in row_slices
                for cols in col_slices]
    # Validate that we got all pixels
    num_pixels_in_quarters = sum(prod(quarter.shape[-2:]) for quarter in quarters)
    assert num_pixels_in_quarters == prod(img.shape[-2:])
    return quarters
76 |
77 |
def _get_crop_idx_mapping(side):
    """Helper method to get the order of crops.

    :param side: how many crops live on each side.
    :return: dict mapping the index of a crop in extraction order to its row-major
    position in the image.

    Example. Say you have an image that gets divided into 16 crops, i.e., the image gets cut into 16 parts:

    [[ 0,  1,  2,  3],
     [ 4,  5,  6,  7],
     [ 8,  9, 10, 11],
     [12, 13, 14, 15]],

    However, due to our recursive cropping code, this results in crops that are ordered like this:

    index of crop:
        0  1  2  3  4 ...
    corresponds to part in image:
        0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15

    E.g. 2 -> 4 means the crop extracted at index 2 sits at position 4 in the image.
    """
    # Fake 3-channel "image" whose pixel values are the row-major positions 0..side*side-1.
    positions = torch.arange(side * side).reshape(1, 1, side, side)
    positions = torch.cat([positions] * 3, dim=1)
    # Crop down to single pixels (threshold 1); the value of each pixel then tells us
    # which position the crop with that extraction index came from.
    mapping = {}
    for extraction_idx, crop in enumerate(iter_crops(positions, 1)):
        mapping[extraction_idx] = crop[0, 0, ...].flatten().item()
    return mapping
107 |
108 |
def stitch(parts):
    """Reassemble crops, given in the order `iter_crops` yields them, into one image.

    :param parts: sequence of BCHW tensors. The number of parts must be 4**k
    (1, 4, 16, 64, ...), since crops are produced by recursively splitting into 2x2.
    :return: BCHW tensor containing all parts at their original position.
    :raises ValueError: if `parts` is empty, or its length is not a square number whose
    root is a power of two.
    """
    num_parts = len(parts)
    side = int(math.sqrt(num_parts))
    if num_parts == 0 or side * side != num_parts:
        raise ValueError(f'Invalid number of parts {len(parts)}')
    # The index mapping below is derived from recursive 2x2 splits, so it is only
    # defined if the number of crops per side is a power of two.
    if side & (side - 1) != 0:
        raise ValueError(
            f'Invalid number of parts {num_parts}, expected 4**k parts (1, 4, 16, 64, ...)')

    # Sort by original position in image
    crops_idx_mapping = _get_crop_idx_mapping(side)
    parts_sorted = [
        part for _, part in sorted(
            enumerate(parts), key=lambda ip: crops_idx_mapping[ip[0]])]

    rows = []
    for row_idx in range(side):
        parts_row = parts_sorted[row_idx * side:(row_idx + 1) * side]  # Get `side` number of parts
        rows.append(torch.cat(parts_row, dim=3))  # cat on W dimension
    img = torch.cat(rows, dim=2)  # cat on H dimension

    # Validate. Crops of images with odd H or W differ in size by one pixel, so the
    # expected size is summed over the first column / first row instead of assuming
    # that every part has the shape of parts[0].
    B, C = parts[0].shape[:2]
    expected_H = sum(parts_sorted[row_idx * side].shape[2] for row_idx in range(side))
    expected_W = sum(part.shape[3] for part in parts_sorted[:side])
    expected_shape = (B, C, expected_H, expected_W)
    assert img.shape == expected_shape, f'{img.shape} != {expected_shape}'

    return img
137 |
138 |
class CropLossCombinator(object):
    """Accumulates the bpsp of several crops into one overall bpsp.

    Every crop is weighted by its number of subpixels, so crops of varying dimensions are supported.
    """
    def __init__(self):
        self._num_subpixels_total = 0
        self._num_bits_total = 0.

    def add(self, bpsp, num_subpixels_crop):
        """Add a crop with `num_subpixels_crop` subpixels that was coded at `bpsp` bits per subpixel."""
        self._num_bits_total += bpsp * num_subpixels_crop
        self._num_subpixels_total += num_subpixels_crop

    def get_bpsp(self):
        """Return total bits divided by total subpixels. At least one subpixel must have been added."""
        total_subpixels = self._num_subpixels_total
        assert total_subpixels > 0
        return self._num_bits_total / total_subpixels
153 |
154 |
def test_auto_crop():
    """Check that random images are split into the expected number of crops and stitch back losslessly."""
    import torch
    import pytorch_ext as pe

    crop_threshold = 2048 * 1024
    cases = [(10000, 6000, 64),
             (4928, 3264, 16),
             (2048, 2048, 4),
             (1024, 1024, 1),
             ]
    for H, W, num_crops_expected in cases:
        img = (torch.rand(1, 3, H, W) * 255).round().long()
        print(img.shape)
        if num_crops_expected <= 1:
            # Small enough: the only "crop" must be the image itself.
            first_crop = next(iter_crops(img, crop_threshold))
            pe.assert_equal(first_crop, img)
            continue
        assert needs_crop(img)
        crops = list(iter_crops(img, crop_threshold))
        assert len(crops) == num_crops_expected
        pe.assert_equal(stitch(crops), img)
173 |
174 |
--------------------------------------------------------------------------------
/src/bitcoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/bitcoding/__init__.py
--------------------------------------------------------------------------------
/src/bitcoding/coders.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | Very thin wrapper around torchac, for arithmetic coding.
22 |
23 | """
24 | import torch
25 | from torchac import torchac
26 | from fjcommon import no_op
27 |
28 |
29 | from criterion.logistic_mixture import CDFOut
30 | from test.cuda_timer import StackTimeLogger
31 |
32 |
class ArithmeticCoder(object):
    """Encodes / decodes symbols with torchac, given either an explicit CDF tensor or a CDFOut."""
    def __init__(self, L):
        """
        :param L: number of symbols; explicit CDF tensors must have L + 1 entries in their last dimension.
        """
        self.L = L
        self._cached_cdf = None

    def range_encode(self, data, cdf, time_logger: StackTimeLogger = no_op.NoOp):
        """
        :param data: data to encode, a 3-dimensional int16 tensor. It is moved to the CPU and flattened.
        :param cdf: cdf to use, either a NHWLp matrix or instance of CDFOut
        :param time_logger: used to time the individual steps. Defaults to a no-op, like in range_decode.
        :return: data encode to a bytes string
        """
        assert len(data.shape) == 3, data.shape

        with time_logger.run('data -> cpu'):
            data = data.to('cpu', non_blocking=True)
        assert data.dtype == torch.int16, 'Wrong dtype: {}'.format(data.dtype)

        with time_logger.run('reshape'):
            data = data.reshape(-1).contiguous()

        if isinstance(cdf, CDFOut):
            logit_probs_c_sm, means_c, log_scales_c, K, targets = cdf

            with time_logger.run('ac.encode'):
                out_bytes = torchac.encode_logistic_mixture(
                    targets, means_c, log_scales_c, logit_probs_c_sm, data)
        else:
            N, H, W, Lp = cdf.shape
            assert Lp == self.L + 1, (Lp, self.L)

            with time_logger.run('ac.encode'):
                out_bytes = torchac.encode_cdf(cdf, data)

        return out_bytes

    def range_decode(self, encoded_bytes, cdf, time_logger: StackTimeLogger = no_op.NoOp):
        """
        :param encoded_bytes: bytes encoded by range_encode
        :param cdf: cdf to use, either a NHWLp matrix or instance of CDFOut
        :param time_logger: used to time the decoding step, defaults to a no-op.
        :return: decoded symbols, reshaped to NHW
        """
        if isinstance(cdf, CDFOut):
            logit_probs_c_sm, means_c, log_scales_c, K, targets = cdf

            # The output shape is taken from the means, which are indexed as N, _, H, W.
            N, _, H, W = means_c.shape

            # Timed as 'ac.decode' so that decoding does not show up under the encoding label.
            with time_logger.run('ac.decode'):
                decoded = torchac.decode_logistic_mixture(
                    targets, means_c, log_scales_c, logit_probs_c_sm, encoded_bytes)

        else:
            N, H, W, Lp = cdf.shape
            assert Lp == self.L + 1, (Lp, self.L)

            with time_logger.run('ac.decode'):
                decoded = torchac.decode_cdf(cdf, encoded_bytes)

        return decoded.reshape(N, H, W)
91 |
--------------------------------------------------------------------------------
/src/bitcoding/coders_helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | Very thin wrapper around DiscretizedMixLogisticLoss.cdf_step_non_shared that keeps track of targets, which are the
22 | same for all channels of the bottleneck, as well as the current channel index.
23 |
24 | """
25 |
26 | import torch
27 |
28 | from criterion.logistic_mixture import DiscretizedMixLogisticLoss, CDFOut
29 |
30 |
class CodingCDFNonshared(object):
    """Stateful helper that hands out the CDF for one channel after the other.

    Keeps the targets (computed once, in the constructor) and the index of the current
    channel, and forwards to DiscretizedMixLogisticLoss.cdf_step_non_shared.
    """
    def __init__(self, l, total_C, dmll: DiscretizedMixLogisticLoss):
        """
        :param l: predicted distribution, i.e., NKpHW, see DiscretizedMixLogisticLoss
        :param total_C: passed through to dmll.cdf_step_non_shared on every step
        :param dmll: provides x_min, x_max, bin_width, L and cdf_step_non_shared
        """
        self.l = l
        self.dmll = dmll
        self.total_C = total_C
        self.c_cur = 0

        # Lp = L+1 targets, evenly spaced from half a bin below x_min to half a bin above x_max.
        half_bin_width = dmll.bin_width / 2
        self.targets = torch.linspace(
            dmll.x_min - half_bin_width,
            dmll.x_max + half_bin_width,
            dmll.L + 1,
            dtype=torch.float32, device=l.device)

    def get_next_C(self, decoded_x) -> CDFOut:
        """
        Get CDF to encode/decode next channel, and advance the current channel index by one.
        :param decoded_x: NCHW
        :return: result of dmll.cdf_step_non_shared for the current channel
        """
        cdf_out = self.dmll.cdf_step_non_shared(
            self.l, self.targets, self.c_cur, self.total_C, decoded_x)
        self.c_cur += 1
        return cdf_out
57 |
58 |
--------------------------------------------------------------------------------
/src/bitcoding/part_suffix_helper.py:
--------------------------------------------------------------------------------
1 | import re
2 | import glob
3 | import os
4 | import string
5 |
6 | _PART_SUFFIX_BASE = '.part'
7 | _PART_SUFFIX_REGEX = _PART_SUFFIX_BASE + r'(\d+)$'
8 |
9 |
def make_part_suffix(i):
    """Build the part suffix for the non-negative index `i`, e.g. '.part3' for i == 3."""
    assert i >= 0, i
    return f'{_PART_SUFFIX_BASE}{str(i)}'
14 |
15 |
def contains_part_suffix(p):
    """Return True iff path `p` ends in a part suffix, i.e., matches _PART_SUFFIX_REGEX."""
    match = re.search(_PART_SUFFIX_REGEX, p)
    return match is not None
19 |
20 |
def index_of_part_suffix(p):
    """Return the index of the part suffix of `p` (the final number in the path) as an int.

    `p` must contain a part suffix, see contains_part_suffix.
    """
    match = re.search(_PART_SUFFIX_REGEX, p)
    index_digits = match.group(1)
    return int(index_digits)
24 |
25 |
def iter_part_suffixes(pin):
    """Return list of all existing paths that share the base of `pin`, sorted by part index.

    The base is `pin` without its trailing digits, so for 'some.file.part3' this returns
    the existing 'some.file.part0', 'some.file.part1', ... in numerical order.

    :param pin: path to an existing file that ends in a part suffix.
    """
    assert os.path.isfile(pin)
    assert contains_part_suffix(pin)
    base = pin.rstrip(string.digits)
    # Escape the base, otherwise glob special characters in the path ('[', '*', '?')
    # would be interpreted as a pattern instead of being matched literally.
    candidates = glob.glob(glob.escape(base) + '*')
    # Only keep paths of the form base<digits>. Checking just for a part suffix at the
    # end would also keep unrelated files such as base + '1.part2'.
    base_name_len = len(os.path.basename(base))
    matches = [m for m in candidates
               if re.fullmatch(r'\d+', os.path.basename(m)[base_name_len:])]
    return sorted(matches, key=index_of_part_suffix)
36 |
37 |
def test_part_suffix():
    """Check creating, detecting and parsing part suffixes."""
    # Creating suffixes
    assert make_part_suffix(1) == '.part1'
    assert make_part_suffix(10) == '.part10'

    # Detecting and parsing suffixes at the end of a path
    path_part1 = 'bla/bli/blupp.part1'
    assert contains_part_suffix(path_part1)
    assert index_of_part_suffix(path_part1) == 1

    path_part3 = 'bla/bli/blupp.part3'
    assert contains_part_suffix(path_part3)
    # A suffix that is not at the very end does not count
    assert not contains_part_suffix('bla/bli/blupp.part3/more')
45 |
46 |
def test_iter(tmpdir):
    """Create 16 part files and check that iter_part_suffixes lists all of them in numerical order."""
    prefix = 'some.file'
    expected_names = []
    for i in range(16):
        tmpdir.join('some.file' + str(make_part_suffix(i))).write('test')
        expected_names.append('some.file.part{}'.format(i))

    # Any of the parts can be used as the starting point.
    some_file = str(tmpdir.join(prefix + str(make_part_suffix(3))))
    found_names = [os.path.basename(p) for p in iter_part_suffixes(some_file)]
    assert found_names == expected_names
55 |
--------------------------------------------------------------------------------
/src/blueprints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/blueprints/__init__.py
--------------------------------------------------------------------------------
/src/blueprints/multiscale_blueprint.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 | """
19 | from collections import namedtuple
20 |
21 | import numpy as np
22 | import torch
23 | import torch.nn.functional as F
24 | import torchvision
25 |
26 | import pytorch_ext as pe
27 | import vis.grid
28 | import vis.summarizable_module
29 | from helpers.pad import pad
30 | from modules.multiscale_network import MultiscaleNetwork, Out
31 | from vis import histogram_plotter, image_summaries
32 |
33 |
# Result of MultiscaleBlueprint.get_loss.
MultiscaleLoss = namedtuple(
    'MultiscaleLoss',
    (
        'loss_pc',             # loss to minimize
        'nonrecursive_bpsps',  # bpsp corresponding to non-recursive scales
        'recursive_bpsps',     # None if not recursive, else all bpsp including recursive
    ))
39 |
40 |
41 |
class MultiscaleBlueprint(vis.summarizable_module.SummarizableModule):
    """Bundles a MultiscaleNetwork with its losses, plus helpers to unpack batches and to add summaries."""
    def __init__(self, config_ms):
        super().__init__()
        network = MultiscaleNetwork(config_ms)
        network.to(pe.DEVICE)

        self.net = network
        self.losses = network.get_losses()

    def set_eval(self):
        """Put the network and both loss modules into eval mode."""
        self.net.eval()
        for loss_module in (self.losses.loss_dmol_rgb, self.losses.loss_dmol_n):
            loss_module.eval()

    def forward(self, in_batch, auto_recurse=0) -> Out:
        """
        :param in_batch: NCHW 0..255 float
        :param auto_recurse: int, how many times the last scales should be applied again. Used for RGB Shared.
        :return: layers.multiscale.Out
        """
        return self.net(in_batch, auto_recurse)

    def get_loss(self, out: Out, num_subpixels_before_pad=None) -> MultiscaleLoss:
        """
        :param out: output of the network, see forward.
        :param num_subpixels_before_pad: If given, calculate bpsp with this, instead of num_pixels returned by
        self.losses. This is needed because while testing, we have to pad images. To calculate the correct bpsp,
        we need to calculate it with respect to the actual (non-padded) number of pixels
        :returns instance of MultiscaleLoss, see above.
        """
        # `costs`: a list, containing the cost of each scale, in nats
        costs, final_cost_uniform, num_subpixels = self.losses.get(out)
        if num_subpixels_before_pad:
            assert num_subpixels_before_pad <= num_subpixels, num_subpixels_before_pad
            num_subpixels = num_subpixels_before_pad
        # dividing by this converts nats to bits per subpixel
        conversion = np.log(2.) * num_subpixels
        costs_bpsp = [cost / conversion for cost in costs]

        scalars = {'costs/scale_{}_bpsp'.format(i): cost
                   for i, cost in enumerate(costs_bpsp)}
        self.summarizer.register_scalars('auto', scalars)

        # all bpsps corresponding to non-recursive scales, including final (uniform-prior) cost
        final_bpsp_uniform = final_cost_uniform / conversion
        nonrecursive_bpsps = costs_bpsp[:out.auto_recursive_from] + [final_bpsp_uniform]

        # all bpsps corresponding to non-recursive AND recursive scales, including final cost;
        # only available if the network was applied recursively.
        recursive_bpsps = None
        if out.auto_recursive_from is not None:
            recursive_bpsps = costs_bpsp + [out.get_nat_count(-1) / conversion]

        # loss is everything without final (uniform-prior) scale
        loss_pc = sum(costs_bpsp)
        return MultiscaleLoss(loss_pc, nonrecursive_bpsps, recursive_bpsps)

    def sample_forward(self, in_batch, sample_scales, partial_final=None):
        """Forward to self.net.sample_forward, additionally passing self.losses."""
        return self.net.sample_forward(in_batch, self.losses, sample_scales, partial_final)

    @staticmethod
    def add_image_summaries(sw, out: Out, global_step, prefix):
        """Add a bottleneck image per scale to `sw`; when training, also add a histogram for scales with logits."""
        def tag(t):
            return sw.pre(prefix, t)

        is_train = prefix == 'train'
        # start from 1, as 0 is RGB
        for scale, (S_i, _, P_i, L_i) in enumerate(out.iter_all_scales(), 1):
            sw.add_image(tag('bn/{}'.format(scale)), new_bottleneck_summary(S_i, L_i), global_step)
            # This will only trigger for the final scale, where P_i is the uniform distribution.
            # With this, we can check how accurate the uniform assumption is (hint: not very)
            is_logits = P_i.shape[1] == L_i
            if is_logits and is_train:
                with sw.add_figure_ctx(tag('histo_out/{}'.format(scale)), global_step) as plt:
                    add_ps_summaries(S_i, get_p_y(P_i), L_i, plt)

    @staticmethod
    def bottleneck_images(s, L):
        """Return one image per channel of the NCHW symbols `s`, after dividing the symbols by L."""
        assert s.dim() == 4, s.shape
        _assert_contains_symbol_indices(s, L)
        s = s.float().div(L)
        num_channels = s.shape[1]
        return [image_summaries.to_image(s[:, c, ...]) for c in range(num_channels)]

    @staticmethod
    def unpack_batch_pad(raw, fac):
        """
        :param raw: uint8, input image.
        :param fac: downscaling factor we will use, used to determine proper padding.
        :return: tuple (padded image as float, padded image as long, i.e., the symbols), both on pe.DEVICE
        """
        if len(raw.shape) == 3:
            # add batch dim. NOTE: unsqueeze_ works in place, i.e., the tensor passed in by the caller is changed.
            raw.unsqueeze_(0)
        assert len(raw.shape) == 4
        raw = MultiscaleBlueprint.pad(raw, fac)
        raw = raw.to(pe.DEVICE)
        return raw.float(), raw.long()

    @staticmethod
    def pad(raw, fac):
        """Pad `raw` with helpers.pad.pad, using the mode given by get_padding_mode, and return the padded tensor."""
        padded, _ = pad(raw, fac, mode=MultiscaleBlueprint.get_padding_mode())
        return padded

    @staticmethod
    def get_padding_mode():
        """Padding mode used by pad."""
        return 'constant'

    @staticmethod
    def unpack(img_batch):
        """
        :param img_batch: dict with keys 'idx' and 'raw'
        :return: tuple (indices, image as float, image as long, i.e., the symbols), images on pe.DEVICE
        """
        idxs = img_batch['idx'].squeeze().tolist()
        raw = img_batch['raw'].to(pe.DEVICE)
        return idxs, raw.float(), raw.long()
151 |
152 |
def new_bottleneck_summary(s, L):
    """
    Grayscale bottleneck representation: Expects the actual bottleneck symbols.
    :param s: NCHW, symbol indices in [0, L)
    :param L: number of symbols, the symbols are divided by this
    :return: [0, 1] image, a grid created with torchvision.utils.make_grid
    """
    assert s.dim() == 4, s.shape
    _assert_contains_symbol_indices(s, L)
    s = s.float().div(L)
    grid = vis.grid.prep_for_grid(s, channelwise=True)
    assert len(grid) == s.shape[1], (len(grid), s.shape)
    # all(...) is required here: asserting the list itself is always true for a non-empty
    # list, regardless of the values it contains.
    assert all(g.max() <= 1 for g in grid), [g.max() for g in grid]
    # `grid` is indexed like a sequence, so the dtype for the message is taken from its first element.
    assert grid[0].dtype == torch.float32, grid[0].dtype
    return torchvision.utils.make_grid(grid, nrow=5)
167 |
168 |
169 | def _assert_contains_symbol_indices(t, L):
170 | """ assert 0 <= t < L """
171 | assert 0 <= t.min() and t.max() < L, (t.min(), t.max())
172 |
173 |
def add_ps_summaries(s, p_y, L, plt):
    """Plot the normalized histogram of the symbols `s` (p_x) together with `p_y`.

    :param s: symbols, passed to pe.histogram together with L
    :param p_y: must have the same shape as the normalized histogram
    :param plt: passed on to histogram_plotter.plot_histogram
    """
    symbol_counts = pe.histogram(s, L)
    p_x = symbol_counts / np.sum(symbol_counts)

    assert p_x.shape == p_y.shape, (p_x.shape, p_y.shape)

    named_distributions = [('p_x', p_x), ('p_y', p_y)]
    histogram_plotter.plot_histogram(named_distributions, plt)
184 |
185 |
def get_p_y(y):
    """
    Average the softmax over dimension 1 of `y` across all other dimensions.
    :param y: NLCHW float, logits
    :return: L dimensional vector p, converted with pe.tensor_to_np
    """
    Ldim = 1
    num_levels = y.shape[Ldim]
    probs = F.softmax(y.detach(), dim=Ldim)
    # Move L to the last dimension and flatten everything else: nL
    probs_flat = probs.transpose(Ldim, -1).contiguous().view(-1, num_levels)
    mean_probs = torch.mean(probs_flat, dim=0)  # L
    return pe.tensor_to_np(mean_probs)
199 |
--------------------------------------------------------------------------------
/src/configs/dl/in32.cf:
--------------------------------------------------------------------------------
1 | # Note: this is different from the text, batch size is actually 30. See Errata in README
2 | batchsize_train = 30
3 | batchsize_val = 120
4 | crop_size = 32
5 |
6 | max_epochs = None
7 |
8 | image_cache_pkl = None
9 | train_imgs_glob = 'path/to/train_32x32/ (check readme)'
10 | val_glob = 'path/to/val_32x32 (check readme)'
11 |
12 | val_glob_min_size = None
13 | fixed_first_image = None
14 | num_val_batches = 5
15 |
--------------------------------------------------------------------------------
/src/configs/dl/in64.cf:
--------------------------------------------------------------------------------
1 | use in32.cf
2 |
3 | crop_size = 64
4 |
5 | image_cache_pkl = None
6 | train_imgs_glob = 'path/to/train_64x64/ (check readme)'
7 | val_glob = 'path/to/val_64x64 (check readme)'
8 |
9 | num_val_batches = 5
10 |
--------------------------------------------------------------------------------
/src/configs/dl/oi.cf:
--------------------------------------------------------------------------------
1 | batchsize_train = 30
2 | batchsize_val = 30
3 | crop_size = 128
4 |
5 | max_epochs = None
6 |
7 | image_cache_pkl = None
8 | train_imgs_glob = 'path/to/train/ (check readme)'
9 | val_glob = 'path/to/val (check readme)'
10 |
11 | val_glob_min_size = None # for cache_p
12 | num_val_batches = 5
13 |
14 |
--------------------------------------------------------------------------------
/src/configs/ms/cr.cf:
--------------------------------------------------------------------------------
1 | optim = 'RMSprop'
2 |
3 | lr.initial = 0.0001
4 | lr.schedule = 'exp_0.75_e5'
5 | weight_decay = 0
6 |
7 | num_scales = 3
8 | shared_across_scales = False
9 |
10 | Cf = 64
11 | kernel_size = 3
12 |
13 | dmll_enable_grad = 0
14 |
15 | rgb_bicubic_baseline = False
16 |
17 | enc.cls = 'EDSRLikeEnc'
18 | enc.num_blocks = 8
19 | enc.feed_F = True
20 | enc.importance_map = False
21 |
22 | learned_L = False
23 |
24 | dec.cls = 'EDSRDec'
25 | dec.num_blocks = 8
26 | dec.skip = True
27 |
28 | q.cls = 'Quantizer'
29 | q.C = 5
30 | # We assume q.L levels, evenly distributed between q.levels_range[0] and q.levels_range[1], see net.py
31 | q.L = 25
32 | q.levels_range = (-1, 1)
33 | q.sigma = 2
34 |
35 | prob.K = 10
36 |
37 | after_q1x1 = True
38 | x4_down_in_scale0 = False
39 |
--------------------------------------------------------------------------------
/src/configs/ms/cr_rgb.cf:
--------------------------------------------------------------------------------
1 | use cr_rgb_shared.cf
2 |
3 | num_scales = 3
4 | # We have a skip in the decoder, see comment in _forward_with_scales in multiscale.py
5 | dec.skip = True
6 |
7 |
--------------------------------------------------------------------------------
/src/configs/ms/cr_rgb_shared.cf:
--------------------------------------------------------------------------------
1 | use cr.cf
2 |
3 | rgb_bicubic_baseline = True
4 | q.C = 3
5 | q.L = 5
6 |
7 | # shared for all
8 | num_scales = 1
9 |
10 | enc.feed_F = False
11 |
12 | dec.skip = False
13 |
14 | enc.cls = 'BicubicSubsampling'
15 |
16 |
--------------------------------------------------------------------------------
/src/criterion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/criterion/__init__.py
--------------------------------------------------------------------------------
/src/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/dataloaders/__init__.py
--------------------------------------------------------------------------------
/src/dataloaders/images_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 | """
19 | import argparse
20 | import glob
21 | import os
22 | import pickle
23 |
24 | import numpy as np
25 | import torch
26 | import torchvision.transforms as transforms
27 | from PIL import Image
28 | from torch.utils.data import Dataset
29 |
30 | from helpers.paths import has_image_ext
31 | from helpers.testset import Testset
32 |
33 |
class NoImagesFoundException(Exception):
    """Raised when a dataset contains no images. The searched path is available as `p`."""
    def __init__(self, p):
        # Initialize the base class explicitly so that `args` and str() always contain
        # the path, also if the exception is created with a keyword argument.
        super().__init__(p)
        self.p = p
37 |
38 |
class IndexImagesDataset(Dataset):
    """
    A Dataset class for images, that also returns the index of the image in the dataset.
    """
    @staticmethod
    def to_tensor_uint8_transform():
        """ Convert PIL to uint8 tensor, CHW. """
        return to_tensor_not_normalized

    @staticmethod
    def to_grb(t):
        """Return a new tensor where the first two entries of dimension 0 of `t` are swapped (RGB -> GRB)."""
        assert t.shape[0] == 3
        with torch.no_grad():
            return torch.stack((t[1, ...], t[0, ...], t[2, ...]), dim=0)

    @staticmethod
    def to_float_tensor_transform():
        """Return a transform that converts a tensor to float and divides it by 255."""
        return transforms.Lambda(lambda img: img.float().div(255.))

    def copy(self):
        """Return a new IndexImagesDataset created with the same constructor arguments as this one."""
        return IndexImagesDataset(
            self._images, self.to_tensor_transform, self._fixed_first)

    def __init__(self, images, to_tensor_transform, fixed_first=None):
        """
        :param images: Instance of Testset or ImagesCached
        :param to_tensor_transform: A function that takes a PIL RGB image and returns a torch.Tensor
        :param fixed_first: If given, path to an existing file that is put in front of all other images.
        :raises ValueError: if `images` is neither an ImagesCached nor a Testset
        :raises NoImagesFoundException: if there are no images
        """
        self.to_tensor_transform = to_tensor_transform
        # Remember the constructor arguments, they are needed by copy().
        self._images = images
        self._fixed_first = fixed_first

        if isinstance(images, ImagesCached):
            self.files = images.get_images_sorted_cached()
            self.id = '{}_{}_{}'.format(images.images_spec, images.min_size, len(self.files))
        elif isinstance(images, Testset):
            self.files = images.ps
            self.id = images.id
        else:
            raise ValueError('Expected ImagesCached or Testset, got images={}'.format(images))

        if fixed_first:
            assert os.path.isfile(fixed_first)
            # Creates a new list, the list of `images` is not modified.
            self.files = [fixed_first] + self.files

        if len(self.files) == 0:
            raise NoImagesFoundException(images.search_path())

    def __len__(self):
        return len(self.files)

    def __str__(self):
        return 'IndexImagesDataset({} images, id={})'.format(len(self.files), self.id)

    def __getitem__(self, idx):
        """Load the image at `idx` as RGB and return a dict with the index ('idx') and the tensor ('raw')."""
        path = self.files[idx]
        with open(path, 'rb') as f:
            # convert() returns a new image, so the file can be closed afterwards.
            pil = Image.open(f).convert('RGB')
        raw = self.to_tensor_transform(pil)  # 3HW uint8s
        return {'idx': idx,
                'raw': raw}
98 |
99 |
def to_tensor_not_normalized(pic):
    """
    copied from PyTorch functional.to_tensor, removed final .float().div(255.)

    :param pic: either a numpy array in HWC layout or a PIL image
    :return: tensor in CHW layout; values are not rescaled
    """
    if isinstance(pic, np.ndarray):
        # handle numpy array: only the axes are permuted, HWC -> CHW
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        return img

    # handle PIL Image
    # copy=True: with NumPy >= 2, copy=False raises whenever the conversion cannot
    # be done without a copy (e.g. when a dtype cast is needed).
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=True))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=True))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=True))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=True))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    # pic.size is (width, height), so this views the data as HWC
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    return img
130 |
131 |
class ImagesCached(object):
    """ Caches contents of folders, for slow filesystems. """
    def __init__(self, images_spec, cache_p=None, min_size=None):
        """
        :param images_spec: str, interpreted as
            - a glob, if it contains a *
            - a single image, if it ends in one of IMG_EXTENSIONS
            - a directory otherwise
        :param cache_p: path to a cache or None. If given, check there when `get_images_sorted_cached` is called
        :param min_size: if given, make sure to only return/cache images of the given size
        :return:
        """
        self.images_spec = os.path.expanduser(images_spec)
        # cache_p is optional; os.path.expanduser(None) would raise a TypeError
        self.cache_p = cache_p if cache_p is None else os.path.expanduser(cache_p)
        self.min_size = min_size
        self.cache = ImagesCached._get_cache(self.cache_p)

    def __str__(self):
        return f'ImagesCached(images_spec={self.images_spec})'

    def search_path(self):
        return self.images_spec

    def get_images_sorted_cached(self):
        """
        :return: the cached list for (images_spec, min_size) if there is one, otherwise the sorted
        result of scanning `images_spec` now (a warning is printed if a cache was given).
        """
        key = self.images_spec, self.min_size
        if key in self.cache:
            return self.cache[key]
        if self.cache_p:
            print(f'WARN: Given cache_p={self.cache_p}, but key not found:\n{key}')
            available_keys = sorted(self.cache.keys(), key=lambda img_size: img_size[0])
            print('Found:\n' + '\n'.join(map(str, available_keys)))
        return sorted(self._iter_imgs_unordered_filter_size())

    def update(self, force, verbose):
        """
        Writes/updates to cache_p
        :param force: overwrites
        :param verbose: print every path that is found
        :raises ValueError: if this instance was created without a cache_p
        """
        if not self.cache_p:
            # Fail before the potentially slow scan instead of when opening the file.
            raise ValueError('Cannot update cache: no cache_p given')
        key = (self.images_spec, self.min_size)
        if not force and key in self.cache:
            print('Cache already contains {}'.format(self.images_spec))
            return
        print('Updating cache for {}...'.format(self.images_spec))
        images = sorted(self._iter_imgs_unordered_filter_size(verbose))
        if len(images) == 0:
            print('No images found...')
            return
        print('Found {} images...'.format(len(images)))
        self.cache[key] = images
        print('Writing cache...')
        with open(self.cache_p, 'wb') as f:
            pickle.dump(self.cache, f)

    @staticmethod
    def print_all(cache_p):
        """
        Print all cache entries to console
        :param cache_p: Path of the cache to print
        """
        cache = ImagesCached._get_cache(cache_p)
        # Drop empty entries from the in-memory copy only; the file is not rewritten.
        for key in list(cache.keys()):
            if len(cache[key]) == 0:
                del cache[key]
        for (p, min_size), imgs in cache.items():
            min_size_str = ' (>={})'.format(min_size) if min_size else ''
            print('{}{}: {} images'.format(p, min_size_str, len(imgs)))

    def _iter_imgs_unordered_filter_size(self, verbose=False):
        """ Yield all image paths of `images_spec` whose smaller side is >= min_size (if min_size is set). """
        for p in self._iter_imgs_unordered(self.images_spec):
            if self.min_size:
                # Close the file again right away; otherwise one handle per image stays open.
                with Image.open(p) as img:
                    img_size = img.size
                if min(img_size) < self.min_size:
                    print('Skipping {} ({})...'.format(p, img_size))
                    continue
            if verbose:
                print(p)
            yield p

    @staticmethod
    def _get_cache(cache_p):
        """ :return: the dict pickled at `cache_p`, or an empty dict if cache_p is unset or missing. """
        if not cache_p:
            return {}
        if not os.path.isfile(cache_p):
            print(f'cache_p={cache_p} does not exist.')
            return {}
        with open(cache_p, 'rb') as f:
            # NOTE: pickle can execute arbitrary code while loading; only use trusted cache files.
            return pickle.load(f)

    @staticmethod
    def _iter_imgs_unordered(images_spec):
        """
        Yield image paths for `images_spec` (glob, single image or directory) in no particular order.
        :raises ValueError: for a glob with a non-image extension, or a glob without matches
        :raises NotADirectoryError: if images_spec is none of glob, image file, or directory
        """
        if '*' in images_spec:
            _, ext = os.path.splitext(images_spec)
            if ext and not has_image_ext(images_spec):
                raise ValueError('Unrecognized extension {} in glob ({})'.format(ext, images_spec))
            matches = glob.glob(images_spec)
            if not ext:
                # Must be a list: a generator is always truthy, which made the
                # emptiness check below a no-op.
                matches = [p for p in matches if has_image_ext(p)]
            if not matches:
                raise ValueError('No matches for glob {}'.format(images_spec))
            yield from matches
            return

        if has_image_ext(images_spec):
            yield from [images_spec]
            return

        # At this point, images_spec should be a path to a directory
        if not os.path.isdir(images_spec):
            raise NotADirectoryError(images_spec)

        print('Recursively traversing {}...'.format(images_spec))
        for root, _, fnames in os.walk(images_spec, followlinks=True):
            for fname in fnames:
                if has_image_ext(fname):
                    p = os.path.join(root, fname)
                    if os.path.getsize(p) == 0:
                        print('WARN / 0 bytes /', p)
                        continue
                    yield p
251 |
def main():
    """
    Command line entry point with two modes, `show` and `update`. See README.

    `show CACHE_P` prints all entries of a cache, `update IMAGES_SPEC CACHE_P` adds
    an entry to a cache. Without a mode, the usage is printed.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='mode')

    show_parser = subparsers.add_parser('show')
    show_parser.add_argument('cache_p')

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('images_spec')
    update_parser.add_argument('cache_p')
    update_parser.add_argument('--min_size', type=int)
    update_parser.add_argument('--force', '-f', action='store_true')
    update_parser.add_argument('--verbose', '-v', action='store_true', help='Print found paths. Might be slow!')

    flags = parser.parse_args()
    if flags.mode == 'show':
        ImagesCached.print_all(flags.cache_p)
        return
    if flags.mode == 'update':
        images_cached = ImagesCached(flags.images_spec, flags.cache_p, flags.min_size)
        images_cached.update(flags.force, flags.verbose)
        return
    # no mode given
    parser.print_usage()
272 |
273 |
274 | # Resize bicubic
275 |
276 |
def resize_bicubic_batch(t, fac):
    """
    Apply `resize_bicubic` to every element along dimension 0 of `t`.
    :param t: 4-dimensional tensor; each t[n] is resized on its own
    :param fac: scaling factor, forwarded to `resize_bicubic`
    :return: the resized elements, stacked along a new dimension 0
    """
    assert len(t.shape) == 4
    resized = []
    for n in range(t.shape[0]):
        resized.append(resize_bicubic(t[n, ...], fac))
    return torch.stack(resized, dim=0)
281 |
282 |
def resize_bicubic(t, fac):
    """
    Resize a single image tensor with PIL's bicubic filter.
    :param t: tensor with 3 channels along dimension 0 (see `_tensor_to_image`)
    :param fac: factor by which both sides are scaled; new sizes are truncated to int
    :return: resized image, converted back via `to_tensor_not_normalized`
    """
    pil_img = _tensor_to_image(t)  # to PIL
    # PIL reports size as (width, height); both sides get the same factor.
    side_a, side_b = pil_img.size
    new_size = (int(side_a * fac), int(side_b * fac))
    resized = pil_img.resize(new_size, Image.BICUBIC)
    return to_tensor_not_normalized(resized)  # back to 3HW uint8 tensor
289 |
290 |
def _tensor_to_image(t):
    """
    Convert a CHW tensor with 3 channels into a PIL image.
    :param t: tensor whose dimension 0 has size 3
    :return: PIL image created from the HWC numpy view of `t`
    """
    assert t.shape[0] == 3, t.shape
    hwc = t.permute(1, 2, 0)  # CHW -> HWC, as expected by Image.fromarray
    as_numpy = hwc.detach().cpu().numpy()
    return Image.fromarray(as_numpy)
294 |
295 |
# Run the cache command line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
298 |
299 |
--------------------------------------------------------------------------------
/src/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/helpers/__init__.py
--------------------------------------------------------------------------------
/src/helpers/aligned_printer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import itertools
20 |
21 |
class AlignedPrinter(object):
    """
    Print Rows nicely as a table.

    Rows are collected with `append` and written with `print`. Can be used as a
    context manager, in which case the table is printed on exit. Entries must
    support len() and a width format spec, i.e. they are expected to be strings.
    """
    def __init__(self):
        self.rows = []
        self.maxs = []  # widest entry seen so far, per column

    def append(self, *row):
        """ Add a row. Rows may have different numbers of entries. """
        self.rows.append(row)
        # Compare widths, not entries: zip_longest pads the shorter side with 0,
        # which is a valid width but not something len() can be called on.
        widths = [len(row_entry) for row_entry in row]
        self.maxs = [max(max_cur, width)
                     for max_cur, width in
                     itertools.zip_longest(self.maxs, widths, fillvalue=0)]

    def print(self):
        """ Print all rows, each entry padded to the width of its column. """
        for row in self.rows:
            for width, row_entry in zip(self.maxs, row):
                print('{row_entry:{width}}'.format(row_entry=row_entry, width=width), end=' ')
            print()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Printed in any case, also if the with-body raised.
        self.print()
45 |
--------------------------------------------------------------------------------
/src/helpers/config_checker.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import os
20 |
21 |
DEFAULT_CONFIG_DIR = 'configs'


class ConfigsRepo(object):
    """ Checks that config files exist and live below the expected config directory. """
    def __init__(self, config_dir=DEFAULT_CONFIG_DIR):
        self.config_dir = config_dir

    def check_configs_available(self, *config_ps):
        """
        :param config_ps: paths of config files; each must contain `config_dir` as a substring
        :raises AssertionError: if a path does not contain `config_dir`
        :raises FileNotFoundError: if a path is not an existing file
        """
        for config_p in config_ps:
            assert self.config_dir in config_p, 'Expected {} to contain {}!'.format(config_p, self.config_dir)
            if os.path.isfile(config_p):
                continue
            raise FileNotFoundError(config_p)
34 |
--------------------------------------------------------------------------------
/src/helpers/global_config.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 |
20 | --------------------------------------------------------------------------------
21 |
22 |
23 | Support for global config parameters shared over the whole program. The goal is to easily add parameters into some
24 | nested module without passing it all the way through, for fast prototyping.
25 | Also supports updating Configs returned by config_parser.py.
26 |
27 |
28 | Usage
29 |
30 | ```
31 | from global_config import global_config
32 |
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('-p', action='append', nargs=1)
35 | ...
36 | flags = parser.parse_args()
37 | ...
38 | global_config.add_from_flag(flags.p)
39 |
40 | # in some module
41 |
42 | from global_config import global_config
43 |
44 | # check if loss is set, if not, use the default of 'mse'
45 | loss_fct = global_config.get('loss', default_value='mse')
46 | if loss_fct == 'mse':
47 | loss = MSE
48 | elif loss_fct == 'ce':
49 | loss = CrossEntropy
50 |
51 | # start the training
52 | python train.py -p loss=ce
53 | ```
54 |
55 |
56 | """
57 | from contextlib import contextmanager
58 |
59 |
60 | class _GlobalConfig(object):
61 | def __init__(self, default_unset_value=False, type_classes=None):
62 | """
63 | :param default_unset_value: Value to use if global_config['key'] is called on a key that is not set
64 | :param type_classes: supported type classes that values are tried to convert to, see `_eval_value`
65 | """
66 | if type_classes is None:
67 | type_classes = [int, float]
68 | self.type_classes = type_classes
69 | self.default_unset_value = default_unset_value
70 | self._values = {}
71 | self._used_params = set()
72 |
73 | def add_from_flag(self, param_flag):
74 | """ Add from a list containing key=value, or key (eqivalent to key=True). """
75 | if param_flag is None:
76 | return
77 | for param_spec in param_flag:
78 | if isinstance(param_spec, list):
79 | assert len(param_spec) == 1
80 | param_spec = param_spec[0]
81 | self.add_param_from_spec(param_spec)
82 |
83 | def update_config(self, config):
84 | """ Update a fjcommon._Config returned by fjcommon.config_parser """
85 | for k, v in config.all_params_and_values():
86 | if k in self:
87 | print('Updating config.{} = {}'.format(k, self[k]))
88 | config.set_attr(k, self[k])
89 | self.declare_used(k)
90 |
91 | def add_param_from_spec(self, spec):
92 | if '=' not in spec:
93 | spec = '{}=True'.format(spec)
94 | key, value = spec.split('=')
95 | key = key.strip()
96 | value = self._eval_value(value.strip())
97 | self[key] = value
98 |
99 | def _eval_value(self, value):
100 | # try if `value` is True, False, or None, and return the actual instance
101 | if value in ('True', 'False', 'None'):
102 | return {'True': True,
103 | 'False': False,
104 | 'None': None}[value]
105 |
106 | # try casting to classes in type_classes
107 | for type_cls in self.type_classes:
108 | try:
109 | return type_cls(value)
110 | except ValueError:
111 | continue
112 |
113 | # finally, just interpret as string type
114 | if ' ' in value:
115 | raise ValueError('values are not allowed to contain spaces! {}'.format(value))
116 | if '/' in value or '~' in value:
117 | raise ValueError('values are not allowed to contain "/" or "~"! {}'.format(value))
118 | return value
119 |
120 | def __setitem__(self, key, value):
121 | self._values[key] = value
122 |
123 | def __getitem__(self, key):
124 | return self.get(key, self.default_unset_value)
125 |
126 | def __contains__(self, item):
127 | return item in self._values
128 |
129 | def get(self, key, default_value, incompatible=None):
130 | """
131 | Check if `key` is set
132 | :param default_value: value to return if `key` is not set
133 | :param incompatible: list of keys which are not allowed to bet set if `key` is set
134 | :return: values[key] if key is set, `default_value` otherwise
135 | """
136 | if incompatible and key in self._values:
137 | self._ensure_not_specified(key, incompatible)
138 | self._used_params.add(key)
139 | return self._values.get(key, default_value)
140 |
141 | def declare_used(self, *keys):
142 | """ Hack to mark all keys in `keys` as used, even if global_config[key] was never called """
143 | self._used_params.update(keys)
144 |
145 | def get_unused_params(self):
146 | return [k for k in self._values.keys() if k not in self._used_params]
147 |
148 | def values(self):
149 | return list(self._values_to_spec())
150 |
151 | def values_str(self, joiner=' '):
152 | return joiner.join(self.values())
153 |
154 | def reset(self): # ugly, needed because global...
155 | """ Reset global_config. """
156 | self._values = {}
157 | self._used_params = set()
158 |
159 | @contextmanager
160 | def reset_after(self):
161 | yield
162 | self.reset()
163 |
164 | def _values_to_spec(self):
165 | for k, v in sorted(self._values.items()):
166 | if v is True:
167 | yield k
168 | else:
169 | yield '{}={}'.format(k, v)
170 |
171 | def _ensure_not_specified(self, key, incompatible):
172 | """ Raises ValueError if any of the keys in `incompatible` were specified """
173 | assert isinstance(incompatible, list)
174 | errors = [k for k in incompatible if k in self]
175 | if errors:
176 | raise ValueError(f"Got {key}, incompatible with: {','.join(errors)}")
177 |
178 | def __str__(self):
179 | if len(self._values) == 0:
180 | return 'GlobalConfig()'
181 | return 'GlobalConfig(\n\t{})'.format('\n\t'.join(self.values()))
182 |
183 |
184 | global_config = _GlobalConfig()
185 |
186 |
--------------------------------------------------------------------------------
/src/helpers/logdir_helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import glob
20 | from collections import namedtuple
21 | from datetime import datetime, timedelta
22 |
23 | import fasteners
24 | import re
25 | import os
26 | from os import path
27 |
# strftime/strptime format of the time stamp every log dir name starts with.
_LOG_DATE_FORMAT = "%m%d_%H%M"
# Marks the log dir name component that holds the log date of the restored job.
_RESTORE_PREFIX = 'r@'
30 |
31 |
def create_unique_log_dir(config_rel_paths, log_dir_root, line_breaking_chars_pat=r'[-]',
                          postfix=None, restore_dir=None, strip_ext=None):
    """
    Create a new directory in `log_dir_root`, named after the current time, the configs, and
    optionally the restored job and a postfix. Example name:
        0117_1704 repr@soa3_med_8e*5_deePer_b50_noHM_C16 repr@v2_res_shallow r@0115_1340
    :param config_rel_paths: paths to the configs, relative to the config root dir
    :param log_dir_root: In this directory, all log dirs are stored. Created if needed.
    :param line_breaking_chars_pat: regex; every match in a config path is replaced by "*"
    :param postfix: str or list of str, appended to the returned log dir
    :param restore_dir: if given, expected to be a log dir. the JOB_ID of that will be appended
    :param strip_ext: if given, do not store extension `strip_ext` of config_rel_paths
    :return: path to a newly created directory
    :raises ValueError: if any config path contains "@"
    """
    for config_rel_path in config_rel_paths:
        if '@' in config_rel_path:
            raise ValueError('"@" not allowed in paths, got {}'.format(config_rel_paths))

    if strip_ext:
        assert all(strip_ext in c for c in config_rel_paths)
        config_rel_paths = [c.replace(strip_ext, '') for c in config_rel_paths]

    # path separators become "@", line breaking characters become "*"
    prepped_paths = [re.sub(line_breaking_chars_pat, '*', c.replace(path.sep, '@'))
                     for c in config_rel_paths]
    dir_name_tail = ' '.join(prepped_paths)

    if restore_dir:
        _, restore_job_component = _split_log_dir(restore_dir)
        dir_name_tail += ' {restore_prefix}{job_id}'.format(
            restore_prefix=_RESTORE_PREFIX,
            job_id=log_date_from_log_dir(restore_job_component))

    if postfix:
        postfix_str = ' '.join(postfix) if isinstance(postfix, list) else postfix
        dir_name_tail += ' ' + postfix_str

    return _mkdir_threadsafe_unique(log_dir_root, datetime.now(), dir_name_tail)
66 |
67 |
LogDirComps = namedtuple('LogDirComps', ['config_paths', 'postfix'])


def parse_log_dir(log_dir, configs_dir, base_dirs, append_ext=''):
    """
    Given a log_dir produced by `create_unique_log_dir`, return the full paths of all configs used.
    The log dir has thus the following format
            {now} {netconfig} {probconfig} [r@XXXX_YYYY] [{postfix} {postfix}]

    :param log_dir: the log dir to parse
    :param configs_dir: the root config dir, where all the configs live
    :param base_dirs: Prefixed to the paths of the configs, e.g., ['ae', 'pc']
    :param append_ext: appended to every config path before it is searched on disk
    :return: LogDirComps with all config paths, as well as the postfix (tuple) or None if there is none
    :raises ValueError: if a config cannot be matched to exactly one file on disk
    """
    base_dirs = [path.join(configs_dir, base_dir) for base_dir in base_dirs]
    num_configs = len(base_dirs)
    log_dir = path.basename(log_dir.strip(path.sep))

    comps = log_dir.split(' ')
    assert is_log_date(comps[0]), 'Invalid log_dir: {}'.format(log_dir)
    assert len(comps) > num_configs, 'Expected a base dir for every component, got {} and {}'.format(
        comps, base_dirs)

    first_config_idx = 1  # comps[0] is the log date
    config_components = comps[first_config_idx:first_config_idx + num_configs]
    # bool, used as 0 or 1 below to skip the restore component
    has_restore = any(_RESTORE_PREFIX in c for c in comps)
    postfix = comps[first_config_idx + num_configs + has_restore:]

    def get_real_path(base, prepped_p):
        # e.g., ae_configs/p_glob.cf
        p_glob = path.join(base, prepped_p.replace('@', path.sep)) + append_ext
        # We always only replace one character with *, so filter for those.
        # I.e. lr1e-5 will become lr1e*5, which will match lr1e-5 but also lr1e-4.5
        candidates = [g for g in glob.glob(p_glob) if len(g) == len(p_glob)]
        if len(candidates) != 1:
            raise ValueError('Cannot find config on disk: {} (matches: {})'.format(p_glob, candidates))
        return candidates[0]

    config_paths = tuple(get_real_path(base_dir, comp)
                         for base_dir, comp in zip(base_dirs, config_components))
    return LogDirComps(config_paths=config_paths,
                       postfix=tuple(postfix) if postfix else None)
109 |
110 |
111 | # ------------------------------------------------------------------------------
112 |
113 |
def _split_log_dir(log_dir):
    """
    Split a path at its first component that starts with a log date.

    given
        some/path/to/job/dir/0101_1818 ae_config pc_config/ckpts
    or
        some/path/to/job/dir/0101_1818 ae_config pc_config
    returns
        tuple some/path/to/job/dir, 0101_1818 ae_config pc_config
    """
    job_component = None
    root_comps = []

    for comp in log_dir.split(path.sep):
        if is_log_dir(comp):
            # this component is an actual log dir. stop and return components
            job_component = comp
            break
        root_comps.append(comp)

    assert job_component is not None, 'Invalid log_dir: {}'.format(log_dir)
    return path.sep.join(root_comps), job_component
136 |
137 |
def _mkdir_threadsafe_unique(log_dir_root, log_date, postfix_dir_name):
    """
    Like `_mkdir_unique`, but holds an inter-process lock on the file `lock` in
    `log_dir_root` while doing so. `log_dir_root` is created if needed.
    :return: path of the created directory
    """
    os.makedirs(log_dir_root, exist_ok=True)
    lock_p = os.path.join(log_dir_root, 'lock')
    # Make sure only one process at a time writes into log_dir_root
    with fasteners.InterProcessLock(lock_p):
        log_dir = _mkdir_unique(log_dir_root, log_date, postfix_dir_name)
    return log_dir
143 |
144 |
def _mkdir_unique(log_dir_root, log_date, postfix_dir_name):
    """
    Create `log_dir_root/{log_date} {postfix_dir_name}`. If a log dir with that log date
    already exists, the date is advanced minute by minute until it is unused.
    :return: path of the created directory
    """
    while _log_dir_with_log_date_exists(log_dir_root, log_date):
        print('Log dir starting with {} exists...'.format(log_date.strftime(_LOG_DATE_FORMAT)))
        log_date = log_date + timedelta(minutes=1)

    dir_name = '{log_date_str} {postfix_dir_name}'.format(
        log_date_str=log_date.strftime(_LOG_DATE_FORMAT),
        postfix_dir_name=postfix_dir_name)
    # strip(): no trailing space if postfix_dir_name is empty
    log_dir = path.join(log_dir_root, dir_name.strip())
    os.makedirs(log_dir)
    return log_dir
156 |
157 |
def _log_dir_with_log_date_exists(log_dir_root, log_date):
    """
    :return: True iff some entry of `log_dir_root` starts with `log_date`, formatted
    with _LOG_DATE_FORMAT. Entries that do not start with a log date are ignored.
    """
    wanted = log_date.strftime(_LOG_DATE_FORMAT)
    for entry in os.listdir(log_dir_root):
        try:
            found = log_date_from_log_dir(entry)
        except ValueError:
            continue  # not a log dir
        if found == wanted:
            return True
    return False
167 |
168 |
def log_date_from_log_dir(log_dir):
    """
    Extract {log_date} from LOG_DIR/{log_date} {netconfig} {probconfig}
    :return: the log date, i.e., everything before the first space of the basename of `log_dir`
    :raises ValueError: if that part is not a valid log date
    """
    dir_name = os.path.basename(log_dir)
    candidate = dir_name.partition(' ')[0]
    if is_log_date(candidate):
        return candidate
    raise ValueError('Invalid log dir: {}'.format(log_dir))
175 |
176 |
def is_log_dir(log_dir):
    """ :return: True iff `log_date_from_log_dir` accepts `log_dir`. """
    try:
        log_date_from_log_dir(log_dir)
    except ValueError:
        return False
    return True
183 |
184 |
def is_log_date(possible_log_date):
    """
    :param possible_log_date: str to check
    :return: True iff `possible_log_date` can be parsed with _LOG_DATE_FORMAT
    """
    try:
        # _LOG_DATE_FORMAT has no year, and strptime then assumes 1900, which is not a
        # leap year: log dates created on Feb 29 (0229_HHMM) would be rejected. Parse
        # with an explicit leap year instead, so that every month/day that strftime
        # can produce is accepted.
        datetime.strptime('2000' + possible_log_date, '%Y' + _LOG_DATE_FORMAT)
        return True
    except ValueError:
        return False
191 |
--------------------------------------------------------------------------------
/src/helpers/pad.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | from fjcommon import functools_ext as ft
20 | from torch.nn import functional as F
21 |
22 |
def pad(img, fac, mode='replicate'):
    """
    pad img such that height and width are divisible by fac

    The padding of each dimension is split between both sides; if it is odd, the
    bottom/right side gets one more.
    :param img: 4-dimensional tensor, the last two dimensions are padded
    :param fac: factor that height and width should be divisible by
    :param mode: padding mode, forwarded to F.pad
    :return: tuple (padded img, (padLeft, padRight, padTop, padBottom)).
    NOTE(review): if no padding is needed, the second element is `ft.identity`
    instead of a tuple -- confirm that callers handle both.
    """
    _, _, height, width = img.shape

    def split_padding(extent):
        """ :return: (before, after) padding needed to make `extent` divisible by fac """
        missing = fac - (extent % fac)
        if missing == fac:  # already divisible
            return 0, 0
        before = missing // 2
        return before, missing - before

    pad_top, pad_bottom = split_padding(height)
    pad_left, pad_right = split_padding(width)
    if not (pad_top or pad_bottom or pad_left or pad_right):
        return img, ft.identity

    assert (pad_top + pad_bottom + height) % fac == 0
    assert (pad_left + pad_right + width) % fac == 0

    padding_tuple = (pad_left, pad_right, pad_top, pad_bottom)
    padded = F.pad(img, padding_tuple, mode)
    return padded, padding_tuple
50 |
51 |
def undo_pad(img, padLeft, padRight, padTop, padBottom, target_shape=None):
    """
    Crop the given padding off the last two dimensions of `img`.
    :param img: array/tensor whose last two dimensions are height and width
    :param padLeft, padRight, padTop, padBottom: number of rows/columns to remove per side
    :param target_shape: if given, tuple (h, w) that the result is asserted to have
    :return: the cropped view of `img`
    :raises AssertionError: if target_shape is given and does not match the result
    """
    # the 'or None' makes sure that we don't get 0:0
    img_out = img[..., padTop:(-padBottom or None), padLeft:(-padRight or None)]
    if target_shape:
        h, w = target_shape
        # The message used to reference an undefined name, which turned every
        # failing check into a NameError.
        assert img_out.shape[-2:] == (h, w), (img_out.shape[-2:], (h, w), img.shape,
                                              (padLeft, padRight, padTop, padBottom))
    return img_out
60 |
--------------------------------------------------------------------------------
/src/helpers/paths.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import os
20 | import glob
21 |
22 | from PIL import Image
23 | from torchvision import transforms as transforms
24 |
25 | from helpers import logdir_helpers
26 |
27 | from fjcommon.assertions import assert_exc
28 |
29 | import pytorch_ext as pe
30 |
31 |
# Name of the sub-directory of an experiment dir that `get_ckpts_dir` returns.
CKPTS_DIR_NAME = 'ckpts'
# Not used in this file; presumably the sub-directory for validation outputs -- TODO confirm.
VAL_DIR_NAME = 'val'


# File extensions (lower case, with leading dot) that `has_image_ext` accepts.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'}
37 |
38 |
def get_ckpts_dir(experiment_dir, ensure_exists=True):
    """
    :param experiment_dir: directory of an experiment
    :param ensure_exists: if True, check via assert_exc that the checkpoint dir exists
    :return: experiment_dir/CKPTS_DIR_NAME
    """
    ckpts_p = os.path.join(experiment_dir, CKPTS_DIR_NAME)
    if not ensure_exists:
        return ckpts_p
    assert_exc(os.path.isdir(ckpts_p), 'Not found: {}'.format(ckpts_p))
    return ckpts_p
44 |
45 |
def get_experiment_dir(log_dir, experiment_spec):
    """
    Resolve `experiment_spec` to the directory of an experiment.
    :param log_dir: directory containing the experiment dirs
    :param experiment_spec: if is a logdate, find correct full path in log_dir, otherwise assume
    logdir/experiment_spec exists
    :return experiment dir, no slash at the end. containing /ckpts
    """
    if not logdir_helpers.is_log_date(experiment_spec):
        candidate = os.path.join(log_dir, experiment_spec)
    else:
        # assume that exactly one log_dir/{log_date}* matches
        assert_exc(log_dir is not None, 'Can only infer experiment_dir from log_date if log_dir is not None')
        restore_dir_glob = os.path.join(log_dir, experiment_spec + '*')
        matches = glob.glob(restore_dir_glob)
        assert_exc(len(matches) == 1, 'Expected one match for {}, got {}'.format(
            restore_dir_glob, matches))
        candidate = matches[0]

    experiment_dir = candidate.rstrip(os.path.sep)
    assert_exc(os.path.isdir(experiment_dir), 'Invalid experiment_dir: {}'.format(experiment_dir))
    return experiment_dir
63 |
64 |
def img_name(img_p):
    """ :return: file name of `img_p` without directories and without the last extension. """
    file_name = os.path.basename(img_p)
    stem, _ = os.path.splitext(file_name)
    return stem
67 |
68 |
def has_image_ext(p):
    """ :return: True iff the extension of `p` is in IMG_EXTENSIONS, ignoring case. """
    _, ext = os.path.splitext(p)
    return ext.lower() in IMG_EXTENSIONS
71 |
72 |
--------------------------------------------------------------------------------
/src/helpers/rolling_buffer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import numpy as np
20 | from fjcommon.assertions import assert_exc
21 |
22 | import pytorch_ext as pe
23 |
24 |
class BufferSizeMismatch(Exception):
    """ Passed to assert_exc by RollingBufferHistogram.add for values with an unexpected number of entries. """
27 |
28 |
class RollingBufferHistogram(object):
    """
    Buffer that sets shape to be number of entries of first v it receives.
    Has function `plot` to plot histogram of data.

    Holds up to `buffer_size` values; once full, the oldest values are overwritten.
    """
    def __init__(self, buffer_size, name=None):
        self._name = name if name else 'RollingBufferHistogram'
        self._buffer_size = buffer_size
        self._buffer = None  # created by the first call to `add`
        self._idx = 0  # row that the next `add` writes to
        self._filled_idx = 0  # number of rows that contain data

    def add(self, v):
        """
        Store the flattened `v` in the next row of the buffer.
        :param v: converted with pe.tensor_to_np; must have as many entries as the first value added
        """
        values = pe.tensor_to_np(v)
        num_values = np.prod(values.shape)
        if self._buffer is None:
            print(f'Creating {values.dtype} buffer for {self._name}: {self._buffer_size}x{num_values}')
            self._buffer = np.zeros((self._buffer_size, num_values), dtype=values.dtype)
        assert_exc(self._buffer.shape[1] == num_values,
                   (self._buffer.shape, values.shape, num_values),
                   BufferSizeMismatch)
        self._buffer[self._idx, :] = values.flatten()
        # wrap around when the end of the buffer is reached
        self._idx = (self._idx + 1) % self._buffer_size
        self._filled_idx = min(self._filled_idx + 1, self._buffer_size)

    def get_buffer(self):
        """ :return: the rows of the buffer that were filled so far """
        return self._buffer[:self._filled_idx, :]

    def plot(self, bins='auto', most_mass=0):
        """
        Histogram over all values in the filled part of the buffer.
        :param bins: forwarded to np.histogram
        :param most_mass: forwarded to `_most_mass_indices` to select the returned range
        :return: tuple (bins, counts), counts divided by the number of filled rows
        NOTE(review): the slice ends before the last index found by
        `_most_mass_indices`, i.e., that bin is not returned -- confirm this is intended.
        """
        filled = self._buffer[:self._filled_idx]
        counts, bin_edges = np.histogram(filled, bins)
        counts = counts / self._filled_idx  # normalize it by number of entries
        idx_min, idx_max = _most_mass_indices(counts, most_mass)
        return bin_edges[idx_min:idx_max], counts[idx_min:idx_max]
60 |
61 |
def _most_mass_indices(a, mass=0):
    """
    :param a: 1D array
    :param mass: fraction of sum(a) that an entry must exceed to count
    :return: tuple (first, last) of the indices of the first and the last entry of `a` that are
    strictly greater than sum(a) * mass.
    """
    threshold = np.sum(a) * mass
    # np.nonzero returns one index array per dimension -> unpacking enforces a 1D input
    (above,) = np.nonzero(a > threshold)
    first, last = above[0], above[-1]
    return first, last
67 |
68 |
69 |
--------------------------------------------------------------------------------
/src/helpers/saver.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import os
20 | import time
21 | import re
22 | import shutil
23 |
24 | import pytorch_ext as pe
25 | from os.path import basename
26 | import torch
27 | from torch.optim import optimizer
28 | from fjcommon.no_op import NoOp
29 | from fjcommon import timer
30 | from fjcommon.assertions import assert_exc
31 |
32 |
class _CheckpointTracker(object):
    """
    Finds and names checkpoint files inside one directory.

    Checkpoints are files whose name starts with the prefix of `ckpt_name_fmt` (the part before
    the first '{'). Temporary checkpoints additionally end in `tmp_postfix`.
    out_dir is usually set via set_out_dir.
    """
    def __init__(self, out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp'):
        """
        :param out_dir: if given, passed to set_out_dir
        :param ckpt_name_fmt: file name format, needs a prefix followed by a zero-padded int specifier
        :param tmp_postfix: non-empty string containing '.', appended to temporary checkpoints
        """
        assert len(tmp_postfix)
        assert '.' in tmp_postfix
        width_match = re.search(r'{:0(\d+?)d}', ckpt_name_fmt)
        assert width_match, 'Expected ckpt_name_fmt to have an int specifier such as or {:09d} or {:010d}.'
        num_digits = int(width_match.group(1))
        max_itr = 10 ** num_digits - 1
        if max_itr < 10000000:  # ten million, should be enough
            print(f'Maximum iteration supported: {max_itr}')
        assert os.sep not in ckpt_name_fmt
        prefix, _, _ = ckpt_name_fmt.partition('{')
        self.ckpt_name_fmt = ckpt_name_fmt
        self.ckpt_prefix = prefix
        assert len(self.ckpt_prefix), 'Expected ckpt_name_fmt to start with a prefix before the format part!'
        self.tmp_postfix = tmp_postfix

        self._out_dir = None
        if out_dir is not None:
            self.set_out_dir(out_dir)

    def set_out_dir(self, out_dir):
        """Create `out_dir` if needed and remember it. May only be called once per instance."""
        assert self._out_dir is None
        os.makedirs(out_dir, exist_ok=True)
        self._out_dir = out_dir

    def get_all_ckpts(self):
        """
        :return: Paths of all files in `self._out_dir` starting with the checkpoint prefix, sorted by
        file name (which is ascending in global_step as long as the step is zero-padded).
        """
        names = sorted(os.listdir(self._out_dir))
        return [os.path.join(self._out_dir, name)
                for name in names
                if name.startswith(self.ckpt_prefix)]

    def itr_ckpt(self):
        """Yield tuples (itr, ckpt_p) for all checkpoints, in the order of get_all_ckpts."""
        for path in self.get_all_ckpts():
            itr = self.get_itr_from_ckpt_p(path)
            yield itr, path

    def get_ckpt_for_itr(self, itr):
        """
        Gets ckpt_itrc where itrc <= itr, i.e., the latest ckpt at or before `itr`.
        Special values: itr == -1 -> newest ckpt
        :return: tuple (itrc, ckpt_p)
        """
        ckpts = list(self.itr_ckpt())
        assert_exc(len(ckpts) > 0, 'No ckpts found in {}'.format(self._out_dir))
        if itr == -1:
            return ckpts[-1]
        first_itrc, _ = ckpts[0]
        assert_exc(first_itrc <= itr, 'Earliest ckpt {} is after {}'.format(first_itrc, itr))
        not_after = [entry for entry in ckpts if entry[0] <= itr]
        if not_after:
            return not_after[-1]
        raise ValueError('Unexpected, {}, {}'.format(itr, ckpts))

    def get_latest_ckpt(self):
        """
        :return: Most recent checkpoint. May be a temporary checkpoint.
        """
        all_ckpts = self.get_all_ckpts()
        return all_ckpts[-1]

    def get_lastest_persistent_ckpt(self):
        """
        :return: Most recent persistent checkpoint, i.e., the last checkpoint not ending in tmp_postfix.
        :raises ValueError if there is no such checkpoint
        """
        persistent = [p for p in self.get_all_ckpts() if not p.endswith(self.tmp_postfix)]
        if not persistent:
            raise ValueError('No persistent checkpoints')
        return persistent[-1]

    def _get_out_p(self, global_step, is_tmp):
        """:return: path in `self._out_dir` for the checkpoint of `global_step`, with tmp_postfix if `is_tmp`."""
        file_name = self.ckpt_name_fmt.format(global_step)
        if is_tmp:
            file_name += self.tmp_postfix
        return os.path.join(self._out_dir, file_name)

    def get_itr_from_ckpt_p(self, ckpt_p):
        """:return: int built from all digits in the file name of `ckpt_p` after dropping the prefix."""
        base = os.path.basename(ckpt_p)
        file_name, _ = os.path.splitext(base)
        assert self.ckpt_prefix in file_name
        remainder = file_name.replace(self.ckpt_prefix, '')
        digits = ''.join(filter(str.isdigit, remainder))
        return int(digits)
111 |
112 |
113 |
class Saver(_CheckpointTracker):
    """
    Saves ckpts:
    - ckpt_XXXXXXXX.pt.tmp
    If keep_tmp_last=None:
        Every `keep_every`-th ckpt is renamed to
        - ckpt_XXXXXXXX.pt
        and kept, the intermediate ones are removed. We call this a persistent checkpoint.
    else:
        Let C be the most recent persistent checkpoint.
        In addition to C being kept, the last `keep_tmp_last` temporary checkpoints before C are also kept.
        This means that always `keep_tmp_last` more checkpoints are kept than if keep_tmp_last=None
    """
    def __init__(self,
                 keep_tmp_itr: int, keep_every=10, keep_tmp_last=None,
                 out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp',
                 verbose=False):
        """
        :param keep_every: keep every `keep_every`-th checkpoint, making it a persistent checkpoint
        :param keep_tmp_itr: keep checkpoint every `keep_tmp_itr` iterations.
        :param keep_tmp_last: Also keep the last `keep_tmp_last` temporary checkpoints before a persistent checkpoint.
        :param out_dir: directory to save to, may also be set later via set_out_dir
        :param ckpt_name_fmt: filename, must include a format spec and some prefix before the format
        :param tmp_postfix: non-empty string to append to temporary checkpoints
        :param verbose: if True, print rename and remove info.
        """
        self.keep_every = keep_every
        self.keep_tmp_last = keep_tmp_last
        self.keep_tmp_itr = keep_tmp_itr
        self.ckpts_since_last_permanent = 0
        self.print = print if verbose else NoOp
        self.save_time_acc = timer.TimeAccumulator()
        super(Saver, self).__init__(out_dir, ckpt_name_fmt, tmp_postfix)

    def save(self, modules, global_step, force=False):
        """
        Save iff (force given or global_step % keep_tmp_itr == 0)
        :param modules: dictionary name -> object with a state_dict() method (e.g. nn.Module)
        :param global_step: current step
        :param force: if True, save regardless of global_step
        :return: bool, Whether previous checkpoints were removed
        """
        if not (force or (global_step % self.keep_tmp_itr == 0)):
            return False
        assert self._out_dir is not None
        current_ckpt_p = self._save(modules, global_step)
        self.ckpts_since_last_permanent += 1
        if self.ckpts_since_last_permanent == self.keep_every:
            self._remove_previous(current_ckpt_p)
            self.ckpts_since_last_permanent = 0
            return True
        return False

    def _save(self, modules, global_step):
        """Write the state dicts of all `modules` to a temporary checkpoint and return its path."""
        out_p = self._get_out_p(global_step, is_tmp=True)
        with self.save_time_acc.execute():
            torch.save({key: m.state_dict() for key, m in modules.items()}, out_p)
        return out_p

    def _remove_previous(self, current_ckpt_p):
        """
        Turn the temporary checkpoint `current_ckpt_p` into a persistent one by renaming it, then
        remove all other temporary checkpoints except for the ones protected by `keep_tmp_last`.
        """
        # Only the *suffix* marks a temporary checkpoint. Matching tmp_postfix anywhere in the path would
        # misfire if a directory name contains it (e.g. out_dir='/scratch/.tmp/ckpts'): the rename target
        # would be wrong and persistent checkpoints would be deleted.
        assert current_ckpt_p.endswith(self.tmp_postfix)
        current_ckpt_p_non_tmp = current_ckpt_p[:-len(self.tmp_postfix)]
        self.print('{} -> {}'.format(basename(current_ckpt_p), basename(current_ckpt_p_non_tmp)))
        os.rename(current_ckpt_p, current_ckpt_p_non_tmp)
        all_ckpts = self.get_all_ckpts()
        # +1 because the slice also covers the checkpoint that was just made persistent
        keep_tmp_last = all_ckpts[-(self.keep_tmp_last+1):] if self.keep_tmp_last else []
        for p in all_ckpts:
            if p.endswith(self.tmp_postfix) and p not in keep_tmp_last:
                self.print('Removing {}...'.format(basename(p)))
                os.remove(p)
        self.print('Average save time: {:.3f}s'.format(self.save_time_acc.mean_time_spent()))
182 |
183 |
class Restorer(_CheckpointTracker):
    """Loads state dicts saved in checkpoints back into modules and optimizers."""
    def restore_latest_persistent(self, net):
        """
        Restore from the most recent persistent checkpoint.
        :param net: forwarded to `restore` as `modules`
        :return: see `restore`
        """
        return self.restore(net, self.get_lastest_persistent_ckpt())

    def restore(self, modules, ckpt_p, strict=True, restore_restart=False):
        """
        :param modules: dictionary name -> object with load_state_dict (modules or optimizers). For every
        name, the checkpoint must contain a state dict stored under the same name.
        :param ckpt_p: path of the checkpoint to load
        :param strict: forwarded to load_state_dict of everything that is not an optimizer
        :param restore_restart: if True, optimizers are not restored
        :return: iteration parsed from the file name of `ckpt_p`
        :raises ValueError if an optimizer state cannot be loaded
        :raises RuntimeError if a non-optimizer state cannot be loaded
        """
        print('Restoring {}... (strict={})'.format(ckpt_p, strict))
        map_location = None if pe.CUDA_AVAILABLE else 'cpu'
        # NOTE: torch.load unpickles the file -- only restore checkpoints from trusted sources.
        state_dicts = torch.load(ckpt_p, map_location=map_location)
        # ---
        for key, m in modules.items():
            # optim implements its own load_state_dict which does not have the `strict` keyword...
            if isinstance(m, optimizer.Optimizer):
                if restore_restart:
                    print('Not restoring optimizer, --restore_restart given...')
                    continue
                try:
                    m.load_state_dict(state_dicts[key])
                except ValueError as e:
                    # One formatted message (instead of a tuple of args), chained to the original error.
                    raise ValueError('Error while restoring Optimizer: {}'.format(e)) from e
            else:
                try:
                    m.load_state_dict(state_dicts[key], strict=strict)
                except RuntimeError:  # loading error
                    # Dump the structure of the module to help finding the mismatch, then re-raise.
                    for n, module in sorted(m.named_modules()):
                        print(n, module)
                    raise
        return self.get_itr_from_ckpt_p(ckpt_p)
211 |
212 |
213 |
214 |
--------------------------------------------------------------------------------
/src/helpers/testset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import argparse
20 | import os
21 | import shutil
22 | from functools import total_ordering
23 |
24 | import numpy as np
25 | from fjcommon import os_ext
26 | from fjcommon.assertions import assert_exc
27 |
28 | from helpers.paths import has_image_ext, img_name
29 |
30 |
@total_ordering
class Testset(object):
    """
    Class that holds a reference to paths of images inside a folder, or to a single image.

    Instances are ordered by `self.id` (see __lt__) to enable sorting.
    """
    def __init__(self, root_dir_or_img, max_imgs=None, skip_hidden=False, append_id=None):
        """
        :param root_dir_or_img: Either a directory with images or the path of a single image.
        :param max_imgs: If given, subsample deterministically to only contain max_imgs
        :param skip_hidden: If given, skip images starting with '.'
        :param append_id: If given, append `append_id` to self.id
        :raises ValueError if root_dir is a directory that does not contain images
        :raises FileNotFoundError if root_dir_or_img is neither a directory nor a file
        """
        self.root_dir_or_img = root_dir_or_img

        if os.path.isdir(root_dir_or_img):
            root_dir = root_dir_or_img
            self.name = os.path.basename(root_dir.rstrip('/'))
            self.ps = sorted(p for p in os_ext.listdir_paths(root_dir) if has_image_ext(p))
            if skip_hidden:
                self.ps = self._filter_hidden(self.ps)
            if max_imgs and max_imgs < len(self.ps):
                print('Subsampling to use {} imgs of {}...'.format(max_imgs, self.name))
                self.ps = self._subsample(self.ps, max_imgs)
            assert_exc(len(self.ps) > 0, 'No images found in {}'.format(root_dir), ValueError)
            self.id = '{}_{}'.format(self.name, len(self.ps))
            self._str = 'Testset({}): in {}, {} images'.format(self.name, root_dir, len(self.ps))
        else:
            img = root_dir_or_img
            assert_exc(os.path.isfile(img), 'Does not exist: {}'.format(img), FileNotFoundError)
            self.name = os.path.basename(img)
            self.ps = [img]
            self.id = img
            self._str = 'Testset([{}]): 1 image'.format(self.name)
        if append_id:
            self.id += append_id

    def search_path(self):
        """:return: the `root_dir_or_img` this Testset was created with."""
        return self.root_dir_or_img

    def filter_filenames(self, filter_filenames):
        """
        Only keep paths whose file name without extension is in `filter_filenames`.
        At least one path must remain, this is checked with assert_exc.
        """
        self.ps = [p for p in self.ps
                   if os.path.splitext(os.path.basename(p))[0] in filter_filenames]
        assert_exc(len(self.ps) > 0, 'No files after filtering for {}'.format(filter_filenames))

    def iter_img_names(self):
        """:return: iterator over img_name(p) for all paths p."""
        return map(img_name, self.ps)

    def iter_orig_paths(self):
        """:return: the list of paths."""
        return self.ps

    def __len__(self):
        return len(self.ps)

    def __str__(self):
        return self._str

    def __repr__(self):
        return 'Testset({}): {} paths'.format(self.name, len(self.ps))

    # to enable sorting
    def __lt__(self, other):
        return self.id < other.id

    @staticmethod
    def _subsample(ps, max_imgs):
        """:return: list of `max_imgs` elements of `ps`, picked at evenly spaced indices including both ends."""
        # builtin `int` instead of the alias `np.int`, which was removed in NumPy 1.24
        idxs = np.linspace(0, len(ps) - 1, max_imgs, dtype=int)
        subset = [ps[i] for i in idxs]
        assert len(subset) == max_imgs
        return subset

    @staticmethod
    def _filter_hidden(ps):
        """:return: `ps` without paths whose file name starts with '.'; prints how many were dropped."""
        count_a = len(ps)
        ps = [p for p in ps if not os.path.basename(p).startswith('.')]
        count_b = len(ps)
        if count_b < count_a:
            print(f'NOTE: Filtered {count_a - count_b} hidden file(s).')
        return ps
106 |
107 |
def main():
    """Command line tool: copy a deterministic subset of the images in ROOT_DIR to OUT_DIR."""
    # The text must be passed as `description`; the first positional parameter of ArgumentParser is `prog`.
    parser = argparse.ArgumentParser(description='Copy deterministic subset of images to another directory.')
    parser.add_argument('root_dir')
    parser.add_argument('max_imgs', type=int)
    parser.add_argument('out_dir')
    # A flag like --verbose; without action='store_true', argparse demands a value after --dry.
    parser.add_argument('--dry', action='store_true',
                        help='If given, do not copy anything.')
    parser.add_argument('--verbose', '-v', action='store_true')
    flags = parser.parse_args()
    os.makedirs(flags.out_dir, exist_ok=True)

    t = Testset(flags.root_dir, flags.max_imgs)

    def cp(p1, p2):
        # never overwrite files that are already in out_dir
        if os.path.isfile(p2):
            print('Exists, skipping: {}'.format(p2))
            return
        if flags.verbose:
            print('cp {} -> {}'.format(p1, p2))
        if not flags.dry:
            shutil.copy(p1, p2)

    for img_p in t.iter_orig_paths():
        cp(img_p, os.path.join(flags.out_dir, os.path.basename(img_p)))
--------------------------------------------------------------------------------
/src/import_train_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import argparse
20 | import multiprocessing
21 | import os
22 | import random
23 | import shutil
24 | import time
25 | import warnings
26 | from os.path import join
27 |
28 | import PIL
29 | import numpy as np
30 | import skimage.color
31 | from PIL import Image
32 |
33 | from helpers.paths import IMG_EXTENSIONS
34 |
35 | # TO SPEED THINGS UP: run on CPU cluster! We use task_array for this.
36 | # task_array is not released. It's used by us to batch process on our servers. Feel free to replace with whatever you
37 | # use. Make sure to set NUM_TASKS (number of concurrent processes) and set job_enumerate to a function that takes an
38 | # iterable and only yield elements to be processed by the current process.
try:
    from task_array import NUM_TASKS, job_enumerate
except ImportError:
    # Without task_array, everything is processed by this single task.
    NUM_TASKS = 1
    job_enumerate = enumerate

warnings.filterwarnings("ignore")


# Fixed seed for the random scales drawn in random_resize.
random.seed(123)


# Size of the multiprocessing pool when running as a single task, see Helper.process_all_in.
_NUM_PROCESSES = int(os.environ.get('NUM_PROCESS', 16))
_DEFAULT_MAX_SCALE = 0.8


def get_fn(p_):
    """Return the file name of `p_` without directory and without extension."""
    file_name = os.path.basename(p_)
    return os.path.splitext(file_name)[0]
56 |
57 |
def iter_images(root_dir, num_folder_levels=0):
    """
    Yield paths of all files in `root_dir` with an extension in IMG_EXTENSIONS, sorted by file name.
    :param root_dir: directory to list
    :param num_folder_levels: how many levels of sub-directories to descend into. With 0, directories
    are treated like files, i.e., only yielded if their name has an image extension.
    """
    for fn in sorted(os.listdir(root_dir)):
        full_p = os.path.join(root_dir, fn)
        if num_folder_levels > 0 and os.path.isdir(full_p):
            print('Recursing into', fn)
            yield from iter_images(full_p, num_folder_levels - 1)
            continue
        ext = os.path.splitext(fn)[1]
        if ext.lower() in IMG_EXTENSIONS:
            yield full_p
70 |
71 |
class Helper(object):
    """
    Processes directories of images: every image is either randomly downscaled and saved as PNG to
    `out_dir_clean`, or copied unchanged to `out_dir_discard`. Images whose file name (without
    extension) already exists in one of the two directories are skipped.
    """
    def __init__(self, out_dir_clean, out_dir_discard, min_res: int):
        """
        :param out_dir_clean: directory for the resized images, created if needed
        :param out_dir_discard: directory for the discarded images, created if needed
        :param min_res: forwarded to random_resize_or_discard
        """
        print(f'Creating {out_dir_clean}, {out_dir_discard}...')
        os.makedirs(out_dir_clean, exist_ok=True)
        os.makedirs(out_dir_discard, exist_ok=True)
        self.out_dir_clean = out_dir_clean
        self.out_dir_discard = out_dir_discard

        print('Getting images already processed...', end=" ", flush=True)
        self.images_cleaned = set(map(get_fn, os.listdir(out_dir_clean)))
        self.images_discarded = set(map(get_fn, os.listdir(out_dir_discard)))
        print(f'-> Found {len(self.images_cleaned) + len(self.images_discarded)} images.')

        self.min_res = min_res

    def process_all_in(self, input_dir):
        """Run `process` in a multiprocessing pool on all not-yet-processed images of this job in `input_dir`."""
        images_dl = iter_images(input_dir)  # generator of paths

        # files this job should compress
        files_of_job = [p for _, p in job_enumerate(images_dl)]
        # files that were compressed already by somebody (i.e. this job earlier)
        processed_already = self.images_cleaned | self.images_discarded
        # resulting files to be compressed
        files_of_job = [p for p in files_of_job if get_fn(p) not in processed_already]

        N = len(files_of_job)
        if N == 0:
            print('Everything processed / nothing to process.')
            return

        num_process = 2 if NUM_TASKS > 1 else _NUM_PROCESSES
        print(f'Processing {N} images using {num_process} processes in {NUM_TASKS} tasks...')

        start = time.time()
        with multiprocessing.Pool(processes=num_process) as pool:
            # The results of `process` are not needed here, the loop only drives the progress output.
            for i, _ in enumerate(pool.imap_unordered(self.process, files_of_job)):
                if i > 0 and i % 100 == 0:
                    time_per_img = (time.time() - start) / (i + 1)
                    time_remaining = time_per_img * (N - i)
                    print(f'\r{time_per_img:.2e} s/img | '
                          f'{i / N * 100:.1f}% | '
                          f'{time_remaining / 60:.1f} min remaining', end='', flush=True)

    def process(self, p_in):
        """
        Process the image at `p_in`.
        :return: 1 if the image is (or already was) saved to out_dir_clean, 0 if it is (or already was)
        discarded or cannot be opened.
        """
        fn = get_fn(p_in)
        if fn in self.images_cleaned:
            return 1
        if fn in self.images_discarded:
            return 0
        try:
            im = Image.open(p_in)
        except OSError as e:
            print(f'\n*** Error while opening {p_in}: {e}')
            return 0
        # Close the underlying file when done, so that long running workers do not accumulate open handles.
        with im:
            im_out = random_resize_or_discard(im, self.min_res)
            if im_out is not None:
                p_out = join(self.out_dir_clean, fn + '.png')  # Make sure to use .png!
                im_out.save(p_out)
                return 1
        p_out = join(self.out_dir_discard, os.path.basename(p_in))
        shutil.copy(p_in, p_out)
        return 0
138 |
139 |
def random_resize_or_discard(im, min_res: int):
    """
    Resize `im` with `random_resize`, then check the result with `should_discard`.
    :return: the resized image, or None if it could not be resized or should be discarded.
    """
    im_resized = random_resize(im, min_res)
    if im_resized is None or should_discard(im_resized):
        return None
    return im_resized
148 |
149 |
def random_resize(im, min_res: int, max_scale=_DEFAULT_MAX_SCALE):
    """
    Downscale `im` by a random factor, drawn uniformly between min_res / (shorter side of im)
    and `max_scale`.
    :return: the resized image, or None if min_res / (shorter side) > max_scale, or if resizing
    raises an OSError.
    """
    width, height = im.size
    shorter_side = min(width, height)
    scale_min = min_res / shorter_side
    if scale_min > max_scale:
        # Image is too small to be downscaled by a factor of at most max_scale.
        return None

    scale = random.uniform(scale_min, max_scale)
    new_size = (round(width * scale), round(height * scale))
    try:
        resized = im.resize(new_size, resample=PIL.Image.LANCZOS)
    except OSError as e:  # Happens for corrupted images
        print('*** Caught im.resize error', e)
        return None
    return resized
168 |
169 |
def should_discard(im):
    """
    Return True iff the mode of `im` is not 'RGB', or if its mean saturation is > 0.9,
    or if its mean value is > 0.8 (both measured in HSV, averaged over all pixels).
    """
    # Modes found in train_0:
    # Counter({'RGB': 152326, 'L': 4149, 'CMYK': 66})
    if im.mode != 'RGB':
        return True
    im_hsv = skimage.color.rgb2hsv(np.array(im))
    _, mean_s, mean_v = np.mean(im_hsv, axis=(0, 1))
    if mean_s > 0.9 or mean_v > 0.8:
        return True
    return False
185 |
186 |
def main():
    """Parse the command line and process every given directory with a shared Helper."""
    parser = argparse.ArgumentParser()
    parser.add_argument('base_dir',
                        help='Directory of images, or directory of DIRS.')
    parser.add_argument('dirs', nargs='*',
                        help='If given, must be subdirectories in BASE_DIR. Will be processed. '
                             'If not given, assume BASE_DIR is already a directory of images.')
    parser.add_argument('--out_dir_clean', required=True)
    parser.add_argument('--out_dir_discard', required=True)
    parser.add_argument('--resolution', type=int, default=512,
                        help='Randomly rescale each image to be at least '
                             'RANDOM_SCALE long on the longer side.')

    flags = parser.parse_args()

    if not flags.dirs:
        # No DIRS given: BASE_DIR itself contains the images. Split it into parent and directory
        # name, so that the loop below handles both cases.
        images_dir = flags.base_dir
        flags.dirs = [os.path.basename(images_dir)]
        flags.base_dir = os.path.dirname(images_dir)

    helper = Helper(flags.out_dir_clean, flags.out_dir_discard, flags.resolution)
    num_dirs = len(flags.dirs)
    for i, d in enumerate(flags.dirs):
        print(f'*** {d}: {i}/{num_dirs}')
        helper.process_all_in(join(flags.base_dir, d))

    print('\n\nDONE')  # For cluster logs.


if __name__ == '__main__':
    main()
217 |
--------------------------------------------------------------------------------
/src/import_train_images_v1.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import os
20 | import random
21 | from os.path import join
22 | import PIL
23 | from PIL import Image
24 | import skimage.color
25 | import numpy as np
26 | import argparse
27 | import warnings
28 | # task_array is not released. It's used by us to batch process on our servers. Feel free to replace with whatever you
29 | # use. Make sure to set NUM_TASKS (number of concurrent processes) and set job_enumerate to a function that takes an
30 | # iterable and only yield elements to be processed by the current process.
try:
    from task_array import NUM_TASKS, job_enumerate
except ImportError:
    # Without task_array, everything is processed by this single task.
    NUM_TASKS = 1
    job_enumerate = enumerate


# Upper bound for the scale factor used in `resize`; images that would need a bigger factor are skipped.
MAX_SCALE = 0.95
# Passed as `quality` when saving the resized images as .jpg.
QUALITY = 95

warnings.filterwarnings("ignore")
42 |
43 |
def main():
    """Parse the command line, optionally seed `random`, and run `process` on every DIR in BASE_DIR."""
    parser = argparse.ArgumentParser()
    parser.add_argument('base_dir')
    parser.add_argument('dirs', nargs='+')
    parser.add_argument('--out_dir_clean', required=True)
    parser.add_argument('--out_dir_discard', required=True)

    parser.add_argument('--resolution', '-r', type=str, default='768', help='can be randX_Y')
    parser.add_argument('--seed')
    flags = parser.parse_args()

    seed = flags.seed
    if seed:
        print('Seeding {}'.format(seed))
        random.seed(seed)

    for sub_dir in flags.dirs:
        input_dir = join(flags.base_dir, sub_dir)
        process(input_dir, flags.out_dir_clean, flags.out_dir_discard, flags.resolution)
61 |
62 |
def get_res(res) -> int:
    """
    Turn a resolution spec into an int.
    :param res: either something `int` can convert, which is returned as int, or a string `randX_Y`,
    for which a random integer in [X, Y] is returned.
    :raises ValueError if `res` is neither
    """
    try:
        return int(res)
    except ValueError:
        pass  # not a plain integer, must be of the form randX_Y
    if not res.startswith('rand'):
        raise ValueError('Expected res to be either int or `randX_Y`')
    bounds = res.replace('rand', '')
    low, high = (int(part) for part in bounds.split('_'))
    return random.randint(low, high)
72 |
73 |
def process(input_dir, out_dir_clean, out_dir_discard, res: str):
    """
    Resize all images of this job in `input_dir`. Resized images are saved as .jpg in `out_dir_clean`,
    images that are not resized are saved in `out_dir_discard`. Images that already exist in one of
    the two directories are skipped (but still counted).
    :param res: resolution spec, see `get_res`. For `randX_Y`, a new resolution is drawn for every image.
    """
    os.makedirs(out_dir_clean, exist_ok=True)
    os.makedirs(out_dir_discard, exist_ok=True)

    # Cleaned images are always saved with extension .jpg, so they must be compared by their file name
    # without extension. Otherwise, inputs with other extensions are never recognized as processed.
    cleaned_names = set(os.path.splitext(f)[0] for f in os.listdir(out_dir_clean))
    images_discarded = set(os.listdir(out_dir_discard))

    images_dl = os.listdir(input_dir)
    N = len(images_dl) // NUM_TASKS

    clean = 0
    discarded = 0
    for i, imfile in job_enumerate(images_dl):
        fn, _ = os.path.splitext(imfile)
        if fn in cleaned_names:
            clean += 1
            continue
        if imfile in images_discarded:
            discarded += 1
            continue
        # close the file of the image after it was processed
        with Image.open(join(input_dir, imfile)) as im:
            # Do not assign to `res`: the spec must stay intact so that `randX_Y` is sampled per image.
            im_res = get_res(res)
            im2 = resize_or_discard(im, im_res, should_clean=True)
            if im2 is not None:
                im2.save(join(out_dir_clean, fn + '.jpg'), quality=QUALITY)
                clean += 1
            else:
                im.save(join(out_dir_discard, imfile))
                discarded += 1
        print(f'\r{os.path.basename(input_dir)} -> {os.path.basename(out_dir_clean)} // '
              f'Resized: {clean}/{N}; Discarded: {discarded}/{N}', end='')
    # Done
    print(f'\n{os.path.basename(input_dir)} -> {os.path.basename(out_dir_clean)} // '
          f'Resized: {clean}/{N}; Discarded: {discarded}/{N}')
108 |
109 |
def resize_or_discard(im, res: int, verbose=False, should_clean=True):
    """
    Resize `im` with `resize`. If `should_clean` is given, also check the result with `should_discard`.
    :return: the resized image, or None if it could not be resized or should be discarded.
    """
    resized = resize(im, res, verbose)
    rejected = resized is None or (should_clean and should_discard(resized, verbose))
    return None if rejected else resized
117 |
118 |
def resize(im, res, verbose=False, max_scale=MAX_SCALE):
    """
    Scale `im` such that its longer side becomes `res` long, using bicubic resampling.
    :param max_scale: if given and the needed scale factor is bigger, the image is not resized
    :return: the resized image, or None if the scale factor is too big or resizing raises an OSError.
    """
    width, height = im.size
    longer_side = max(width, height)
    factor = float(res) / longer_side
    if max_scale and factor > max_scale:
        if verbose:
            print('Too big: {}'.format((width, height)))
        return None
    new_size = (round(width * factor), round(height * factor))
    try:
        resized = im.resize(new_size, resample=PIL.Image.BICUBIC)
    except OSError as e:
        print(e)
        return None
    return resized
134 |
135 |
def should_discard(im, verbose=False):
    """
    Return True iff `im`, converted to an array, does not have 3 dimensions with 3 channels in the
    last one, or if its mean saturation is > 0.9, or if its mean value is > 0.8 (both in HSV).
    :param verbose: if given, print the reason for discarding
    """
    def _reject(reason):
        # every reason to discard goes through here, such that `verbose` is only checked once
        if verbose:
            print(reason)
        return True

    im_rgb = np.array(im)
    if im_rgb.ndim != 3 or im_rgb.shape[2] != 3:
        return _reject('Invalid shape: {}'.format(im_rgb.shape))
    im_hsv = skimage.color.rgb2hsv(im_rgb)
    _, s, v = np.mean(im_hsv, axis=(0, 1))
    if s > 0.9:
        return _reject('Invalid s: {}'.format(s))
    if v > 0.8:
        return _reject('Invalid v: {}'.format(v))
    return False
154 |
155 |
def get_hsv(im):
    """:return: tuple (h, s, v), the HSV representation of `im` averaged over the first two axes."""
    hsv = skimage.color.rgb2hsv(np.array(im))
    h, s, v = np.mean(hsv, axis=(0, 1))
    return h, s, v
162 |
163 |
# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()
166 |
--------------------------------------------------------------------------------
/src/l3c.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
import torch.backends.cudnn
# Set the cuDNN flags right away, before anything else is imported: request deterministic
# behavior and switch off benchmark mode.
# NOTE(review): presumably needed so that encoder and decoder compute identical results -- confirm.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
22 |
23 | import pytorch_ext as pe
24 |
25 | import argparse
26 |
27 | from test.multiscale_tester import MultiscaleTester, EncodeError, DecodeError
28 |
29 |
30 | # throws an exception if no backend available!
31 | from torchac import torchac
32 |
33 |
class _FakeFlags(object):
    """Wraps a flags object such that attributes it does not have are returned as None."""
    def __init__(self, flags):
        self.flags = flags

    def __getattr__(self, item):
        # Only called if regular attribute lookup fails: look into the wrapped flags, None if missing.
        return self.flags.__dict__.get(item)
43 |
44 |
def parse_device_flag(flag):
    """
    Select the torchac backend and configure `pe` accordingly.

    For flag == 'auto', 'gpu' is selected if torchac-backend-gpu and CUDA are both available, otherwise
    'cpu' is selected if torchac-backend-cpu is available. Afterwards, pe.CUDA_AVAILABLE is overwritten
    to be True iff 'gpu' was selected, and pe.set_device is called with that value.

    :param flag: one of 'auto', 'gpu', 'cpu'
    :raises ValueError if no backend is available, or if the selected backend cannot be used.
    """
    print(f'Status: '
          f'torchac-backend-gpu available: {torchac.CUDA_SUPPORTED} // '
          f'torchac-backend-cpu available: {torchac.CPU_SUPPORTED} // '
          f'CUDA available: {pe.CUDA_AVAILABLE}')
    if flag == 'auto':
        if torchac.CUDA_SUPPORTED and pe.CUDA_AVAILABLE:
            flag = 'gpu'
        elif torchac.CPU_SUPPORTED:
            flag = 'cpu'
        else:
            raise ValueError('No suitable backend found!')

    if flag == 'cpu' and not torchac.CPU_SUPPORTED:
        raise ValueError('torchac-backend-cpu is not available. Please install.')
    if flag == 'gpu' and not torchac.CUDA_SUPPORTED:
        raise ValueError('torchac-backend-gpu is not available. Please install.')
    if flag == 'gpu' and not pe.CUDA_AVAILABLE:
        raise ValueError('Selected torchac-backend-gpu but CUDA not available!')

    assert flag in ('gpu', 'cpu')

    # Also force CPU if CUDA would be available but the cpu backend was selected.
    pe.CUDA_AVAILABLE = flag == 'gpu'
    pe.set_device(pe.CUDA_AVAILABLE)
    print(f'*** Using torchac-backend-{flag}; did set CUDA_AVAILABLE={pe.CUDA_AVAILABLE} and DEVICE={pe.DEVICE}')
72 |
73 |
def main():
    """
    Command line interface to encode an image to a bitstream (mode `enc`) or to decode a bitstream
    to a PNG (mode `dec`). EncodeError and DecodeError are caught and printed.
    """
    p = argparse.ArgumentParser(description='Encoder/Decoder for L3C')

    p.add_argument('log_dir', help='Directory of experiments.')
    p.add_argument('log_date', help='A log_date, such as 0104_1345.')

    p.add_argument('--device', type=str, choices=['auto', 'gpu', 'cpu'], default='auto',
                   help='Select the device to run this code on, as mentioned in the README section "Selecting torchac". '
                        'If DEVICE=auto, select torchac-backend depending on whether torchac-backend-gpu or -cpu is '
                        'available. See function parse_device_flag for details. '
                        'If DEVICE=gpu or =cpu, force usage of this backend.')

    p.add_argument('--restore_itr', '-i', default=-1, type=int,
                   help='Which iteration to restore. -1 means latest iteration. Default: -1')

    mode = p.add_subparsers(title='mode', dest='mode')

    enc = mode.add_parser('enc', help='Encode image. enc IMG_P OUT_P [--overwrite | -f]\n'
                                      ' IMG_P: Path to an Image, readable by PIL.\n'
                                      ' OUT_P: Path to where to save the bitstream.\n'
                                      ' OVERWRITE: If given, overwrite OUT_P.\n'
                                      'Example:\n'
                                      ' python l3c.py LOG_DIR LOG_DATE enc some/img.jpg out/img.l3c')
    dec = mode.add_parser('dec', help='Decode image. dec IMG_P OUT_P_PNG\n'
                                      ' IMG_P: Path to an L3C-encoded Image, readable by PIL.\n'
                                      ' OUT_P_PNG: Path to where to save the decode image to, as a PNG.\n'
                                      'Example:\n'
                                      ' python l3c.py LOG_DIR LOG_DATE dec out/img.l3c decoded.png')

    enc.add_argument('img_p')
    enc.add_argument('out_p')
    enc.add_argument('--overwrite', '-f', action='store_true')

    dec.add_argument('img_p')
    dec.add_argument('out_p_png')

    flags = p.parse_args()
    if flags.mode is None:
        # Subparsers are optional by default. Without this check, a missing mode would only fail in the
        # decode branch below with an AttributeError, after the model was already set up.
        p.error('Missing mode, expected one of: enc, dec')
    parse_device_flag(flags.device)

    print('Testing {} at {} ---'.format(flags.log_date, flags.restore_itr))
    # _FakeFlags makes all flags that this parser does not define read as None
    tester = MultiscaleTester(flags.log_date, _FakeFlags(flags), flags.restore_itr, l3c=True)

    if flags.mode == 'enc':
        try:
            tester.encode(flags.img_p, flags.out_p, flags.overwrite)
        except EncodeError as e:
            print('*** EncodeError:', e)
    else:
        try:
            tester.decode(flags.img_p, flags.out_p_png)
        except DecodeError as e:
            print('*** DecodeError:', e)


if __name__ == '__main__':
    main()
130 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/modules/__init__.py
--------------------------------------------------------------------------------
/src/modules/edsr.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | # Code adapted from src/model/common.py of this repo:
20 | #
21 | # https://github.com/thstkdgus35/EDSR-PyTorch
22 | #
23 | # Original license:
24 | #
25 | # MIT License
26 | #
27 | # Copyright (c) 2018 Sanghyun Son
28 | #
29 | # Permission is hereby granted, free of charge, to any person obtaining a copy
30 | # of this software and associated documentation files (the "Software"), to deal
31 | # in the Software without restriction, including without limitation the rights
32 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
33 | # copies of the Software, and to permit persons to whom the Software is
34 | # furnished to do so, subject to the following conditions:
35 | #
36 | # The above copyright notice and this permission notice shall be included in all
37 | # copies or substantial portions of the Software.
38 | #
39 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
42 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
43 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
44 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
45 | # SOFTWARE.
46 |
47 | import math
48 |
49 | import torch
50 | import torch.nn as nn
51 |
class MeanShift(nn.Conv2d):
    """1x1 convolution over 3 channels, initialised to a per-channel affine normalisation.

    After construction the layer computes, per channel c,
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.

    :param rgb_range: scalar the mean is multiplied with
    :param rgb_mean: sequence of 3 per-channel means
    :param rgb_std: sequence of 3 per-channel standard deviations
    :param sign: -1 subtracts the (scaled) mean, +1 adds it
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        mean = torch.Tensor(rgb_mean)
        # Diagonal weight: every channel is only divided by its own std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = (sign * rgb_range * mean) / std
        # NOTE(review): this only sets a plain attribute on the module; `weight` and `bias` keep
        # requires_grad=True and therefore stay trainable. Left as is, since freezing them would
        # change the set of trainable parameters of existing models.
        self.requires_grad = False
61 |
62 |
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> act -> conv -> [BN], added to the block input.

    :param conv: factory called as ``conv(n_feats, n_feats, kernel_size, rate=..., bias=...)``
    :param n_feats: number of input and output channels of both convolutions
    :param kernel_size: kernel size handed to `conv`
    :param bias: handed to `conv`
    :param bn: if True, a BatchNorm2d follows each convolution
    :param act: activation module inserted after the first convolution (and its BN)
    :param atrous: if truthy, used as dilation rate of the second convolution; the first
        convolution always uses rate 1
    """

    def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), atrous=False):
        super(ResBlock, self).__init__()
        m = []
        _repr = []
        for i in range(2):
            atrous_rate = 1 if (not atrous or i == 0) else atrous
            m.append(conv(n_feats, n_feats, kernel_size, rate=atrous_rate, bias=bias))
            # Build the description in one piece so that the parenthesis is closed exactly once
            # (the atrous suffix previously carried its own ')', yielding 'Conv(..;A*r))').
            rate_suffix = f';A*{atrous_rate}' if atrous_rate != 1 else ''
            _repr.append(f'Conv({n_feats}x{kernel_size}{rate_suffix})')
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
                _repr.append(f'BN({n_feats})')
            if i == 0:
                m.append(act)
                _repr.append('Act')

        self.body = nn.Sequential(*m)
        self._repr = '/'.join(_repr)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

    def __repr__(self):
        return f'ResBlock({self._repr})'
90 |
91 |
class Upsampler(nn.Sequential):
    """Sequence of conv + PixelShuffle stages that upsamples by `scale`.

    Supported scales are powers of two (realised as repeated x2 stages) and 3 (a single x3 stage);
    any other scale raises NotImplementedError.

    :param conv: factory called as ``conv(n_feats, out_channels, 3, bias)``
    :param scale: integer upsampling factor
    :param n_feats: number of channels before and after every stage
    :param bn: if True, add BatchNorm2d after each PixelShuffle
    :param act: 'relu' or 'prelu' to add that activation to each stage, anything else for none
    :param bias: handed to `conv`
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        def stage(factor):
            # The conv produces factor^2 times the channels, PixelShuffle folds them into space.
            layers = [conv(n_feats, factor * factor * n_feats, 3, bias),
                      nn.PixelShuffle(factor)]
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))
            return layers

        is_power_of_two = (scale & (scale - 1)) == 0
        if is_power_of_two:
            num_stages = int(math.log(scale, 2))
            modules = [layer for _ in range(num_stages) for layer in stage(2)]
        elif scale == 3:
            modules = stage(3)
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*modules)
120 |
121 |
--------------------------------------------------------------------------------
/src/modules/head.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | from torch import nn
20 |
21 | from modules import edsr
22 | from pytorch_ext import default_conv as conv
23 |
24 |
25 |
class RGBHead(nn.Module):
    """ Go from 3 channels (RGB) to Cf channels, also normalize RGB """

    def __init__(self, config_ms):
        super().__init__()
        assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID'
        # With range 0, zero mean and std 128, the MeanShift is initialised to divide by 128.
        normalize = edsr.MeanShift(0, (0., 0., 0.), (128., 128., 128.))
        to_features = Head(config_ms, Cin=3)
        self.head = nn.Sequential(normalize, to_features)
        self._repr = 'MeanShift//Head(C=3)'

    def forward(self, x):
        return self.head(x)

    def __repr__(self):
        return 'RGBHead({})'.format(self._repr)
41 |
42 |
class Head(nn.Module):
    """
    Single convolution mapping Cin channels to Cf channels.
    For L3C, Cin=Cf, and this is the convolution yielding E^{s+1}_in in Fig. 2.
    """

    def __init__(self, config_ms, Cin):
        super().__init__()
        assert 'Subsampling' not in config_ms.enc.cls, 'For Subsampling encoders, head should be ID'
        num_features = config_ms.Cf
        self.head = conv(Cin, num_features, config_ms.kernel_size)
        self._repr = 'Conv({})'.format(num_features)

    def forward(self, x):
        return self.head(x)

    def __repr__(self):
        return 'Head({})'.format(self._repr)
60 |
61 |
62 |
--------------------------------------------------------------------------------
/src/modules/net.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | File for Net, which contains encoder/decoder of MultiscaleNetwork
22 |
23 | """
24 | from collections import namedtuple
25 |
26 | import torch
27 | from torch import nn
28 |
29 | import pytorch_ext as pe
30 | import vis.histogram_plot
31 | import vis.summarizable_module
32 | from dataloaders.images_loader import resize_bicubic_batch
33 | from modules import edsr
34 | from modules.quantizer import Quantizer
35 |
# Encoder output. Fields, in order:
#   bn    NCH'W'
#   bn_q  quantized bn, NCH'W'
#   S     NCH'W', long
#   L     int
#   F     NCfH'W', float, before Q
EncOut = namedtuple('EncOut', 'bn bn_q S L F')

# Decoder output. Fields:
#   F     NCfHW
DecOut = namedtuple('DecOut', 'F')
44 |
45 |
# Module-wide shorthand for the convolution factory defined in pytorch_ext.
conv = pe.default_conv
47 |
48 |
class Net(nn.Module):
    """Container holding the encoder (`.enc`) and decoder (`.dec`) selected by `config_ms`.

    The classes are looked up by the names in `config_ms.enc.cls` and `config_ms.dec.cls`; an
    unknown name raises KeyError. The module is not callable as a whole.
    """

    def __init__(self, config_ms, scale):
        super().__init__()
        self.config_ms = config_ms

        encoders = {
            'EDSRLikeEnc': EDSRLikeEnc,
            'BicubicSubsampling': BicubicDownsamplingEnc,
        }
        enc_cls = encoders[config_ms.enc.cls]
        self.enc = enc_cls(config_ms, scale)

        decoders = {
            'EDSRDec': EDSRDec,
        }
        dec_cls = decoders[config_ms.dec.cls]
        self.dec = dec_cls(config_ms, scale)

    def forward(self, x):
        # Call .enc and .dec directly
        raise NotImplementedError()
63 |
64 |
class BicubicDownsamplingEnc(vis.summarizable_module.SummarizableModule):
    """Parameter-free "encoder" that halves the input with `resize_bicubic_batch`.

    The input is shifted by `rgb_mean`, rounded to uint8, resized by a factor of 0.5 and shifted
    back. The resized uint8 values double as symbols, with L fixed to 256.
    """

    def __init__(self, *_):
        super().__init__()
        # TODO: ugly
        mean_01 = torch.tensor([0.4488, 0.4371, 0.4040], dtype=torch.float32)
        self.rgb_mean = mean_01.reshape(3, 1, 1).mul(255.).to(pe.DEVICE)

    def forward(self, x):
        # back to 0...255
        pixels = (x + self.rgb_mean).clamp(0, 255.).round().type(torch.uint8)
        subsampled = resize_bicubic_batch(pixels, 0.5).to(pe.DEVICE)
        sym = subsampled.long()
        # make sure no gradients back to this point
        out = (subsampled.float() - self.rgb_mean).detach()
        self.summarizer.register_images('train', {'input_subsampled': lambda: sym.type(torch.uint8)})
        return EncOut(out, out, sym, 256, None)
81 |
82 |
def new_levels(L, initial_levels):
    """Return L evenly spaced quantization levels as a 1D tensor without gradient.

    :param L: number of levels
    :param initial_levels: pair (lo, hi) with the first and the last level
    :return: tensor of shape (L,), requires_grad=False
    """
    lo, hi = initial_levels
    # torch.linspace already returns a fresh tensor with requires_grad=False; re-wrapping it in
    # torch.tensor(...) only made a redundant copy and triggers PyTorch's copy-construct warning.
    return torch.linspace(lo, hi, L)
87 |
88 |
class EDSRLikeEnc(vis.summarizable_module.SummarizableModule):
    """Encoder: stride-2 convolution, residual body with skip connection, 1x1 convolution to
    `config_ms.q.C` channels, and a quantizer with `config_ms.q.L` fixed levels."""

    def __init__(self, config_ms, scale):
        super().__init__()

        self.scale = scale
        self.config_ms = config_ms
        self.L = config_ms.q.L
        num_features = config_ms.Cf
        num_q_channels = config_ms.q.C
        kernel_size = config_ms.kernel_size

        # Downsampling
        self.down = conv(num_features, num_features, kernel_size=5, stride=2)

        # Body
        body_layers = [edsr.ResBlock(conv, num_features, kernel_size, act=nn.ReLU(True))
                       for _ in range(config_ms.enc.num_blocks)]
        body_layers.append(conv(num_features, num_features, kernel_size))
        self.body = nn.Sequential(*body_layers)

        # to Quantizer
        to_q_layers = [conv(num_features, num_q_channels, 1)]
        # NOTE(review): evaluated at construction time, before any .eval() call can have happened.
        if self.training:
            # start scale from 1, as 0 is RGB
            histogram_tag = 'histo/enc_{}_after_1x1'.format(scale+1)
            to_q_layers.append(vis.histogram_plot.HistogramPlot(
                'train', histogram_tag, buffer_size=10, num_inputs_to_buffer=1, per_channel=False))
        self.to_q = nn.Sequential(*to_q_layers)

        # We assume q.L levels, evenly distributed between q.levels_range[0] and q.levels_range[1]
        # In theory, the levels could be learned. But in this code, they are assumed to be fixed.
        levels_first, levels_last = config_ms.q.levels_range
        # Wrapping this in a nn.Parameter ensures it is copied to gpu when .to('cuda') is called
        fixed_levels = torch.linspace(levels_first, levels_last, self.L)
        self.levels = nn.Parameter(fixed_levels, requires_grad=False)
        levels_str = ','.join('{:.1f}'.format(level) for level in self.levels)
        self._extra_repr = 'Levels={}'.format(levels_str)
        self.q = Quantizer(self.levels, config_ms.q.sigma)

    def extra_repr(self):
        return self._extra_repr

    def quantize_x(self, x):
        """Return only the hard-quantized values of x."""
        hard = self.q(x)[1]
        return hard

    def forward(self, x):
        """
        :param x: NCHW
        :return: EncOut(soft-quantized, hard-quantized, symbols, L, features before the 1x1 conv)
        """
        down = self.down(x)
        features = self.body(down) + down
        x_soft, x_hard, symbols_hard = self.q(self.to_q(features))
        # TODO(parallel): To support nn.DataParallel, this must be changed, as it not a tensor
        return EncOut(x_soft, x_hard, symbols_hard, self.L, features)
149 |
150 |
class EDSRDec(nn.Module):
    """Decoder: 1x1 convolution from `config_ms.q.C` to Cf channels, residual body with skip
    connection, and an x2 `edsr.Upsampler` tail."""

    def __init__(self, config_ms, scale):
        super().__init__()

        self.scale = scale
        num_features = config_ms.Cf
        kernel_size = config_ms.kernel_size
        after_q_kernel = 1

        self.head = conv(config_ms.q.C, config_ms.Cf, after_q_kernel)

        body_layers = [edsr.ResBlock(conv, num_features, kernel_size, act=nn.ReLU(True))
                       for _ in range(config_ms.dec.num_blocks)]
        body_layers.append(conv(num_features, num_features, kernel_size))
        self.body = nn.Sequential(*body_layers)

        self.tail = edsr.Upsampler(conv, 2, num_features, act=False)

    def forward(self, x, features_to_fuse=None):
        """
        :param x: NCHW
        :param features_to_fuse: optional tensor added to the output of the head convolution
        :return: DecOut wrapping the upsampled features
        """
        head_out = self.head(x)
        if features_to_fuse is not None:
            head_out = head_out + features_to_fuse
        upsampled = self.tail(self.body(head_out) + head_out)
        # TODO(parallel): To support nn.DataParallel, this must be changed, as it not a tensor
        return DecOut(upsampled)
185 |
186 |
187 |
--------------------------------------------------------------------------------
/src/modules/prob_clf.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import torch
20 | from torch import nn, nn as nn
21 |
22 | import pytorch_ext as pe
23 | from criterion.logistic_mixture import non_shared_get_Kp
24 |
# Module-wide shorthand for the convolution factory defined in pytorch_ext.
conv = pe.default_conv
26 |
27 |
28 |
class AtrousProbabilityClassifier(nn.Module):
    """Maps Cf feature channels to Kp output channels with a `StackedAtrousConvs`.

    Kp is obtained from `non_shared_get_Kp(config_ms.prob.K, C)`.
    """

    def __init__(self, config_ms, C=3, atrous_rates_str='1,2,4'):
        super().__init__()

        num_mixtures = config_ms.prob.K
        num_out_channels = non_shared_get_Kp(num_mixtures, C)

        self.atrous = StackedAtrousConvs(atrous_rates_str, config_ms.Cf, num_out_channels,
                                         kernel_size=config_ms.kernel_size)
        self._repr = 'C={}; K={}; Kp={}; rates={}'.format(
            C, num_mixtures, num_out_channels, atrous_rates_str)

    def forward(self, x):
        """
        :param x: NCfHW
        :return: NKpHW
        """
        return self.atrous(x)

    def __repr__(self):
        return 'AtrousProbabilityClassifier({})'.format(self._repr)
49 |
50 |
class StackedAtrousConvs(nn.Module):
    """Parallel dilated convolutions whose outputs are concatenated and fused by a 1x1 conv.

    :param atrous_rates_str: an int, or a comma-separated string of ints such as '1,2,4'
    :param Cin: channels of the input and of every parallel branch
    :param Cout: channels produced by the final 1x1 convolution
    :param bias: bias flag of the final 1x1 convolution
    :param kernel_size: kernel size of the parallel branches
    """

    def __init__(self, atrous_rates_str, Cin, Cout, bias=True, kernel_size=3):
        super().__init__()
        rates = self._parse_atrous_rates_str(atrous_rates_str)
        branches = [conv(Cin, Cin, kernel_size, rate=r) for r in rates]
        self.atrous = nn.ModuleList(branches)
        # The 1x1 conv sees the channel-wise concatenation of all branches.
        self.lin = conv(len(rates) * Cin, Cout, 1, bias=bias)
        self._extra_repr = 'rates={}'.format(rates)

    @staticmethod
    def _parse_atrous_rates_str(atrous_rates_str):
        # expected to either be an int or a comma-separated string 1,2,4
        if not isinstance(atrous_rates_str, int):
            return [int(rate) for rate in atrous_rates_str.split(',')]
        return [atrous_rates_str]

    def extra_repr(self):
        return self._extra_repr

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.atrous]
        return self.lin(torch.cat(branch_outputs, dim=1))
--------------------------------------------------------------------------------
/src/modules/quantizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | Based on our TensorFlow implementation for the CVPR 2018 paper
22 |
23 | "Conditional Probability Models for Deep Image Compression"
24 |
25 | This is a PyTorch implementation of that quantization layer.
26 |
27 | https://github.com/fab-jul/imgcomp-cvpr/blob/master/code/quantizer.py
28 |
29 | """
30 | import torch
31 | import torch.nn.functional as F
32 | from torch import nn
33 |
34 |
# NOTE(review): not referenced in the visible part of this module; presumably a sigma large enough
# to make the soft assignment effectively hard -- confirm against callers.
SIGMA_HARD = 1e7
36 |
37 |
def to_sym(x, x_min, x_max, L):
    """Map values in [x_min, x_max] to integer symbols 0..L-1 on a uniform grid.

    Values outside the range are clamped first; the result is a long tensor.
    """
    bin_size = (x_max - x_min) / (L-1)
    offset_from_min = x.clamp(x_min, x_max) - x_min
    return torch.round(offset_from_min / bin_size).long()
42 |
43 |
def to_bn(S, x_min, x_max, L):
    """Inverse of the uniform symbol grid: map symbols 0..L-1 to floats in [x_min, x_max]."""
    bin_size = (x_max - x_min) / (L-1)
    return S.float() * bin_size + x_min
48 |
49 |
class Quantizer(nn.Module):
    """Quantize the entries of an NCHW tensor to the nearest of a fixed set of scalar levels.

    The forward pass returns the hard-quantized values, while gradients flow through a soft
    (softmax-weighted) assignment to the levels.

    :param levels: 1D tensor of quantization levels
    :param sigma: factor applied to the squared distances before the softmax; larger values make
        the soft assignment closer to the hard one
    """

    def __init__(self, levels, sigma=1.0):
        super(Quantizer, self).__init__()
        assert levels.dim() == 1, 'Expected 1D levels, got {}'.format(levels)
        self.levels = levels
        self.sigma = sigma
        self.L = self.levels.size()[0]

    def __repr__(self):
        return '{}(sigma={})'.format(
            self._get_name(), self.sigma)

    def forward(self, x):
        """
        :param x: NCHW
        :return: (x_soft, x_hard, symbols_hard), each NCHW. x_soft carries the values of x_hard
            but the gradient of the soft assignment; symbols_hard holds the level indices.
        """
        assert x.dim() == 4, 'Expected NCHW, got {}'.format(x.size())
        N, C, H, W = x.shape
        # make x into NCm1, where m=H*W. reshape (instead of view) also accepts non-contiguous
        # inputs, e.g. permuted tensors, and is identical to view for contiguous ones.
        x = x.reshape(N, C, H*W, 1)
        # NCmL, d[..., l] gives distance to l-th level
        d = torch.pow(x - self.levels, 2)
        # NCmL, \sum_l phi_soft[..., l] sums to 1
        phi_soft = F.softmax(-self.sigma * d, dim=-1)
        # - Calculate soft assignments ---
        # NCm, soft assign x to levels
        x_soft = torch.sum(self.levels * phi_soft, dim=-1)
        # NCHW
        x_soft = x_soft.view(N, C, H, W)

        # - Calculate hard assignments ---
        # NCm, symbols_hard[..., i] contains index of symbol to use
        _, symbols_hard = torch.min(d.detach(), dim=-1)
        # NCHW
        symbols_hard = symbols_hard.view(N, C, H, W)
        # NCHW, contains value of symbol to use
        x_hard = self.levels[symbols_hard]

        x_soft.data = x_hard  # assign data, keep gradient
        return x_soft, x_hard, symbols_hard
91 |
--------------------------------------------------------------------------------
/src/prep_openimages.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Download the Open Images train/validation tarballs into DATA_DIR/download, unpack them,
# convert them with import_train_images.py into DATA_DIR/train_oi and DATA_DIR/val_oi, and
# update the image cache DATA_DIR/cache.pkl.
#
# USAGE: prep_openimages.sh DATA_DIR [OUT_DIR]
#   OUT_DIR is where the discard folders are created; it defaults to DATA_DIR.

set -e

if [[ -z $1 ]]; then
    echo "USAGE: $0 DATA_DIR [OUT_DIR]"
    exit 1
fi

DATA_DIR=$(realpath "$1")
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")

if [[ -n $2 ]]; then
    OUT_DIR=$2
else
    OUT_DIR=$DATA_DIR
fi

# Reads tar's verbose output from stdin and prints a progress line every 10 files.
progress () {
    COUNTER=0
    while read LINE; do
        COUNTER=$((COUNTER+1))
        if [[ $((COUNTER % 10)) == 0 ]]; then
            echo -ne "\rExtracting $LINE; Unpacked $COUNTER files."
        fi
    done
    echo ""
}

# printf instead of echo: bash's echo does not interpret "\n" without -e.
printf 'DATA_DIR=%s;\nSCRIPT_DIR=%s;\nSaving to %s\n' "$DATA_DIR" "$SCRIPT_DIR" "$OUT_DIR"

mkdir -pv "$DATA_DIR"

TRAIN_0=train_0
TRAIN_1=train_1
TRAIN_2=train_2
VAL=validation

# Download ----------
DOWNLOAD_DIR=$DATA_DIR/download
mkdir -p "$DOWNLOAD_DIR"
pushd "$DOWNLOAD_DIR"
for DIR in $TRAIN_0 $TRAIN_1 $TRAIN_2 $VAL; do
    TAR=${DIR}.tar.gz
    if [ ! -f "$TAR" ]; then
        echo "Downloading $TAR..."
        aws s3 --no-sign-request cp "s3://open-images-dataset/tar/$TAR" "$TAR"
    else
        echo "Found $TAR..."
    fi
done

for DIR in $TRAIN_0 $TRAIN_1 $TRAIN_2 $VAL; do
    TAR=${DIR}.tar.gz
    if [ -d "$DIR" ]; then
        echo "Found $DIR, not unpacking $TAR..."
        continue
    fi
    if [ ! -f "$TAR" ]; then
        echo "ERROR: Expected $TAR in $DOWNLOAD_DIR"
        exit 1
    fi
    echo "Unpacking $TAR..."
    ( tar xvf "$TAR" | progress ) &
done

# Wait for all unpacking background processes
wait
echo "Unpacked all!"

popd

# Convert ----------
FINAL_TRAIN_DIR=$DATA_DIR/train_oi
FINAL_VAL_DIR=$DATA_DIR/val_oi

DISCARD=$OUT_DIR/discard
DISCARD_VAL=$OUT_DIR/discard_val
pushd "$SCRIPT_DIR"

echo "Importing train..."
# NOTE: this is where you want to employ parallelization on a cluster if it's available.
# See import_train_images.py
python import_train_images.py "$DOWNLOAD_DIR" $TRAIN_0 $TRAIN_1 $TRAIN_2 \
    --out_dir_clean="$FINAL_TRAIN_DIR" \
    --out_dir_discard="$DISCARD" \
    --resolution=512

python import_train_images.py "$DOWNLOAD_DIR" $VAL \
    --out_dir_clean="$FINAL_VAL_DIR" \
    --out_dir_discard="$DISCARD_VAL" \
    --resolution=512

# Update Cache ----------
CACHE_P=$DATA_DIR/cache.pkl
export PYTHONPATH=$(pwd)

echo "Updating cache $CACHE_P..."
python dataloaders/images_loader.py update "$FINAL_TRAIN_DIR" "$CACHE_P" --min_size 128
python dataloaders/images_loader.py update "$FINAL_VAL_DIR" "$CACHE_P" --min_size 128

# The hints below use the absolute paths computed above. Re-resolving "$1" here would be wrong
# for a relative DATA_DIR, because the working directory is SCRIPT_DIR by now.
echo "----------------------------------------"
echo "Done"
echo "To train, you MUST UPDATE configs/dl/oi.cf:"
echo ""
echo "    image_cache_pkl = '$CACHE_P'"
echo "    train_imgs_glob = '$FINAL_TRAIN_DIR'"
echo "    val_glob = '$FINAL_VAL_DIR'"
echo ""
echo "----------------------------------------"
111 |
--------------------------------------------------------------------------------
/src/pytorch_ext.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | General PyTorch related stuff
22 |
23 | """
24 |
25 | import os
26 | from collections import defaultdict
27 |
28 | import numpy as np
29 | import math
30 |
31 | import torch
32 |
33 | from torch import nn as nn
34 | from torch.utils.data import Dataset
35 |
36 | import itertools
37 |
38 |
# Probed once at import time.
CUDA_AVAILABLE = torch.cuda.is_available()

# IGNORE_CUDA = os.environ.get('IGNORE_CUDA', '0') == '1'
# if IGNORE_CUDA:
#     print('*** IGNORE_CUDA=1')

if CUDA_AVAILABLE:
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")


def set_device(cuda_available):
    """Rebind the module-level DEVICE to "cuda:0" (if `cuda_available`) or "cpu".

    This has an effect to all parts reading pytorch_ext.DEVICE *after* it has been set.
    """
    global DEVICE
    device_name = "cuda:0" if cuda_available else "cpu"
    DEVICE = torch.device(device_name)
53 |
54 | # Conv -------------------------------------------------------------------------
55 |
56 |
def default_conv(in_channels, out_channels, kernel_size, bias=True, rate=1, stride=1):
    """Create a Conv2d whose padding keeps the spatial size for stride 1 and odd kernel sizes.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: side length of the square kernel
    :param bias: whether the convolution has a bias
    :param rate: dilation rate
    :param stride: stride of the convolution
    :return: nn.Conv2d
    """
    # A dilated kernel spans rate * (kernel_size - 1) + 1 pixels, so "same" padding is
    # rate * (kernel_size // 2). This equals the previous kernel_size // 2 for rate == 1 and the
    # previous `rate` for kernel_size == 3, and is also size-preserving for other kernel sizes.
    padding = rate * (kernel_size // 2)
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, stride=stride, dilation=rate,
        padding=padding, bias=bias)
62 |
63 |
def initialize_with_filter(conv_or_deconv, f):
    """Replace the weight data of `conv_or_deconv` with `f` and return the module.

    `f` must have exactly the shape of the current weight.
    """
    expected_size, given_size = conv_or_deconv.weight.size(), f.size()
    assert expected_size == given_size, 'Must match: {}, {}'.format(expected_size, given_size)
    conv_or_deconv.weight.data = f
    return conv_or_deconv
69 |
70 |
def initialize_with_id(conv_or_deconv, with_noise=False):
    """Set the weight of a square 1x1 (de)convolution to the identity, optionally perturbed.

    :param conv_or_deconv: module whose weight has shape (n, n, 1, 1)
    :param with_noise: if True, add uniform noise in [-1/sqrt(n), 1/sqrt(n)] to the identity
    """
    weight_shape = conv_or_deconv.weight.shape
    n, Cout, H, W = weight_shape
    assert n == Cout and H == W == 1, 'Invalid shape: {}'.format(weight_shape)
    identity = torch.eye(n).reshape(n, n, 1, 1)
    if with_noise:  # nicked from torch.nn.modules.conv:reset_parameters
        bound = 1. / math.sqrt(n)
        identity.add_(torch.empty_like(identity).uniform_(-bound, bound))

    print(identity[:10, :10, 0, 0])
    initialize_with_filter(conv_or_deconv, identity)
83 |
84 |
85 | # Numpy -----------------------------------------------------------------------
86 |
87 |
def tensor_to_np(t):
    """Return the values of tensor `t` as a numpy array, detached and moved to the CPU."""
    on_cpu = t.detach().cpu()
    return on_cpu.numpy()
90 |
91 |
def histogram(t, L):
    """
    Count how often each of the L symbols occurs.

    A: If t is a list of tensors/np.ndarrays, B is executed for all, yielding len(ts) histograms, which are summed
       per bin. An empty list yields an all-zero histogram.
    B: convert t to numpy (np.ndarrays are used as they are), count bins.
    :param t: tensor/np.ndarray or list of those, each expected to be in [0, L)
    :param L: number of symbols
    :return: length-L array, containing at l the number of values mapping to symbol l
    """
    if isinstance(t, list):
        if not t:
            return np.zeros(L, dtype=np.intp)
        # np.stack needs a real sequence; passing a generator is rejected by current NumPy.
        histograms = np.stack([histogram(t_i, L) for t_i in t], axis=0)  # get array (len(ts) x L)
        return np.sum(histograms, 0)
    assert 0 <= t.min() and t.max() < L, (t.min(), t.max())
    a = t if isinstance(t, np.ndarray) else tensor_to_np(t)
    counts, _ = np.histogram(a, np.arange(L+1))  # +1 because np.histogram takes bin edges, including rightmost edge
    return counts
109 |
110 |
111 | # Gradients --------------------------------------------------------------------
112 |
113 |
def get_total_grad_norm(params, norm_type=2):
    """Return the `norm_type`-norm over the gradients of all `params`, as a Python number.

    Parameters without gradient are skipped.
    """
    # nicked from torch.nn.utils.clip_grad_norm
    with torch.no_grad():
        accumulated = sum(p.grad.data.norm(norm_type).item() ** norm_type
                          for p in params if p.grad is not None)
        return accumulated ** (1. / norm_type)
125 |
126 |
def get_average_grad_norm(params, norm_type=2):
    """
    Average of the per-parameter gradient norms; parameters without gradient are skipped.

    :param params: Assumed to be generator
    :param norm_type: type of the norm taken of every gradient
    :return: 0 if no parameter has a gradient, otherwise a scalar tensor
    """
    # nicked from torch.nn.utils.clip_grad_norm
    with torch.no_grad():
        norms = [p.grad.data.norm(norm_type) for p in params if p.grad is not None]
        if not norms:
            return 0
        return sum(norms) / float(len(norms))
144 |
145 |
146 | # Datasets --------------------------------------------------------------------
147 |
148 |
class TruncatedDataset(Dataset):
    """View on the first `num_elemens` elements of `dataset`.

    :param dataset: underlying dataset; must have at least `num_elemens` elements
    :param num_elemens: length of the truncated dataset
    """

    def __init__(self, dataset, num_elemens):
        assert len(dataset) >= num_elemens, 'Cannot truncate to {}: dataset has {} elements'.format(
            num_elemens, len(dataset))
        self.dataset = dataset
        self.num_elemens = num_elemens

    def __len__(self):
        return self.num_elemens

    def __getitem__(self, item):
        """Return element `item` of the truncated dataset.

        Integer indices are checked against the truncated length, negative ones count from the
        truncated end. The IndexError is what terminates plain iteration (`for x in ds`) after
        `num_elemens` elements instead of running through the whole underlying dataset.
        Non-integer keys are forwarded unchanged.

        :raises IndexError: if an integer index is outside [-num_elemens, num_elemens)
        """
        if isinstance(item, int):
            if not -self.num_elemens <= item < self.num_elemens:
                raise IndexError('Index {} out of range for dataset truncated to {} elements'.format(
                    item, self.num_elemens))
            if item < 0:
                item += self.num_elemens
        return self.dataset[item]
161 |
162 |
163 | # Helpful modules --------------------------------------------------------------
164 |
165 |
class LambdaModule(nn.Module):
    """Module whose forward pass calls the given callable; `name` only shows up in the repr."""

    def __init__(self, forward_lambda, name=''):
        super().__init__()
        self.forward_lambda = forward_lambda
        self.description = f'LambdaModule({name})'

    def forward(self, x):
        return self.forward_lambda(x)

    def __repr__(self):
        return self.description
177 |
178 |
class ChannelToLogitsTranspose(nn.Module):
    """Views an (N, Lout*Cout, H, W) tensor as (N, Lout, Cout, H, W)."""

    def __init__(self, Cout, Lout):
        super().__init__()
        self.Cout = Cout
        self.Lout = Lout

    def __repr__(self):
        return f'ChannelToLogitsTranspose(Cout={self.Cout}, Lout={self.Lout})'

    def forward(self, x):
        N, _, H, W = x.shape
        # unfold channel dimension to (Cout, Lout)
        # this fills up the Cout dimension first!
        return x.view(N, self.Lout, self.Cout, H, W)
194 |
195 |
class LogitsToChannelTranspose(nn.Module):
    """Views an (N, L, C, H, W) tensor as (N, C*L, H, W)."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return 'LogitsToChannelTranspose()'

    def forward(self, x):
        N, L, C, H, W = x.shape
        # fold channel, L dimension back to channel dim
        folded_channels = C * L
        return x.view(N, folded_channels, H, W)
208 |
209 |
def channel_to_logits(x, Cout, Lout):
    """View an (N, Lout*Cout, H, W) tensor as (N, Lout, Cout, H, W)."""
    N, _, H, W = x.shape
    # unfold channel dimension to (Cout, Lout)
    # this fills up the Cout dimension first!
    return x.view(N, Lout, Cout, H, W)
216 |
217 |
def logits_to_channel(x):
    """View an (N, L, C, H, W) tensor as (N, C*L, H, W)."""
    N, L, C, H, W = x.shape
    # fold channel, L dimension back to channel dim
    folded_channels = C * L
    return x.view(N, folded_channels, H, W)
223 |
224 |
class OneHot(nn.Module):
    """
    Module wrapper around `one_hot` with fixed L and Ldim.

    Take long tensor x of some shape (N,d1,d2,...,dN) containing integers in [0, L),
    produces one hot encoding `out` of out_shape (N, d1, ..., L, ..., dN), where out_shape[Ldim] = L, containing
        out[n, i, ..., l, ..., j] == {1 if x[n, i, ..., j] == l
                                      0 otherwise
    """

    def __init__(self, L, Ldim=1):
        super().__init__()
        self.L, self.Ldim = L, Ldim

    def forward(self, x):
        return one_hot(x, L=self.L, Ldim=self.Ldim)
239 |
240 |
def one_hot(x, L, Ldim):
    """ add dim L at Ldim

    :param x: long tensor with values in [0, L)
    :param L: size of the new one-hot dimension
    :param Ldim: position of the new dimension; must be >= 0 or exactly -1 (append at the end)
    :return: float32 tensor on the device of x, 1 where the index along Ldim equals x, else 0
    """
    assert Ldim >= 0 or Ldim == -1, f'Only supporting Ldim >= 0 or Ldim == -1: {Ldim}'
    in_shape = list(x.shape)
    if Ldim == -1:
        out_shape = in_shape + [L]
    else:
        out_shape = in_shape[:Ldim] + [L] + in_shape[Ldim:]
    index = x.unsqueeze(Ldim)  # index must match # dims of out_shape
    assert index.dim() == len(out_shape), (index.shape, out_shape)
    encoded = torch.zeros(*out_shape, dtype=torch.float32, device=x.device)
    encoded.scatter_(Ldim, index, 1)
    return encoded
254 |
255 |
256 | # ------------------------------------------------------------------------------
257 |
258 |
def assert_equal(t1, t2, show_num_wrong=3, names=None, msg=''):
    """Raise AssertionError unless t1 and t2 have the same shape and equal elements.

    :param t1: first tensor
    :param t2: second tensor
    :param show_num_wrong: maximum number of differing positions listed in the error message
    :param names: pair of names used for t1 and t2 in the message, defaults to ('t1', 't2')
    :param msg: text appended to the error message
    """
    if t1.shape != t2.shape:
        raise AssertionError('Different shapes! {} != {}'.format(t1.shape, t2.shape))
    mismatch = t1 != t2
    if not mismatch.any():
        return
    if names is None:
        names = ('t1', 't2')
    mismatch_positions = mismatch.nonzero()
    num_wrong = len(mismatch_positions)
    show_num_wrong = min(show_num_wrong, num_wrong)
    positions_as_tuples = (tuple(position.tolist()) for position in mismatch_positions)
    details = []
    for idx in itertools.islice(positions_as_tuples, show_num_wrong):
        details.append('{}: {}!={}'.format(idx, t1[idx], t2[idx]))
    num_not_shown = num_wrong - show_num_wrong
    summary = '{} != {}: {}, and {}/{} other(s) '.format(
        names[0], names[1], ' // '.join(details), num_not_shown, np.prod(t1.shape))
    raise AssertionError((summary + msg).strip())
276 |
277 |
class BatchSummarizer(object):
    """
    Collects values per key over multiple batches and writes their averages as scalar summaries.
    """
    def __init__(self, writer, global_step):
        """
        :param writer: summary writer; `add_scalar` and `file_writer.flush` are called on it.
        :param global_step: step passed to every `add_scalar` call.
        """
        self.writer = writer
        self.global_step = global_step
        self._values = defaultdict(list)

    def append(self, key, values):
        """ Remember `values` under `key` """
        self._values[key].append(values)

    def output_summaries(self, strip_for_str='val'):
        """
        Write the average of each key to the writer and flush it.

        :param strip_for_str: substring removed from each key (together with surrounding '/') in the returned
        strings.
        :return: list of strings 'key=avg', with avg formatted to 3 decimals.
        """
        def _summarize(key, values):
            mean = sum(values) / len(values)
            self.writer.add_scalar(key, mean, self.global_step)
            short_key = key.replace(strip_for_str, '').strip('/')
            return '{}={:.3f}'.format(short_key, mean)

        summaries = [_summarize(key, values) for key, values in self._values.items()]
        self.writer.file_writer.flush()
        return summaries
299 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | test.py log_dir log_dates images
22 |
23 | test.py log_dir log_dates images --sample=samples
24 |
25 | test.py log_dir log_dates images --write_to_files=l3c_out
26 |
27 | This code uses a cache: If some experiment has already been tested for some iteration and crop and dataset,
28 | we just print that (see TestID in multiscale_tester.py).
29 |
30 | """
31 |
32 | import torch.backends.cudnn
33 | torch.backends.cudnn.benchmark = True
34 |
35 | import argparse
36 | from operator import itemgetter
37 |
38 | from helpers.aligned_printer import AlignedPrinter
39 | from helpers.testset import Testset
40 | from test.multiscale_tester import MultiscaleTester
41 |
42 |
43 |
def main():
    """
    Command line entry point: test the experiments given by LOG_DATES on the given images and print a summary.

    For every log_date and every iteration in --restore_itr, a MultiscaleTester is created and run on all
    testsets. Unless --write_to_files is given, a table of all results is printed at the end, sorted
    according to --sort_output.

    :raises ValueError: if incompatible flags are combined, or if --names does not contain exactly one name per
    log_date.
    """
    p = argparse.ArgumentParser()

    p.add_argument('log_dir', help='Directory of experiments. Will create a new folder, LOG_DIR_test, to save test '
                                   'outputs.')
    p.add_argument('log_dates', help='A comma-separated list, where each entry is a log_date, such as 0104_1345. '
                                     'These experiments will be tested.')
    p.add_argument('images', help='A comma-separated list, where each entry is either a directory with images or '
                                  'the path of a single image. Will test on all these images.')
    p.add_argument('--match_filenames', '-fns', nargs='+', metavar='FILTER',
                   help='If given, remove any images in the folders given by IMAGES that do not match any '
                        'of specified filter.')
    p.add_argument('--max_imgs_per_folder', '-m', type=int, metavar='MAX',
                   help='If given, only use MAX images per folder given in IMAGES. Default: None')
    p.add_argument('--crop', type=int, help='Crop all images to CROP x CROP squares. Default: None')

    p.add_argument('--names', '-n', type=str,
                   help='Comma separated list, if given, must be as long as LOG_DATES. Used for output. If not given, '
                        'will just print LOG_DATES as names.')

    p.add_argument('--overwrite_cache', '-f', action='store_true',
                   help='Ignore cached test outputs, and re-create.')
    p.add_argument('--reset_entire_cache', action='store_true',
                   help='Remove cache.')

    p.add_argument('--restore_itr', '-i', default='-1',
                   help='Which iteration to restore. -1 means latest iteration. Will use closest smaller if exact '
                        'iteration is not found. Default: -1')

    p.add_argument('--recursive', default='0',
                   help='Either an number or "auto". If given, the rgb configs with num_scales == 1 will '
                        'automatically be evaluated recursively (i.e., the RGB baseline). See _parse_recursive_flag '
                        'in multiscale_tester.py. Default: 0')

    p.add_argument('--sample', type=str, metavar='SAMPLE_OUT_DIR',
                   help='Sample from model. Store results in SAMPLE_OUT_DIR.')

    p.add_argument('--write_to_files', type=str, metavar='WRITE_OUT_DIR',
                   help='Write images to files in folder WRITE_OUT_DIR, with arithmetic coder. If given, the cache is '
                        'ignored and no test output is printed. Requires torchac to be installed, see README. Files '
                        'that already exist in WRITE_OUT_DIR are overwritten.')
    p.add_argument('--compare_theory', action='store_true',
                   help='If given with --write_to_files, will compare actual bitrate on disk to theoretical bitrate '
                        'given by cross entropy.')
    p.add_argument('--time_report', type=str, metavar='TIME_REPORT_PATH',
                   help='If given with --write_to_files, write a report of time needed for each component to '
                        'TIME_REPORT_PATH.')

    p.add_argument('--sort_output', '-s', choices=['testset', 'exp', 'itr', 'res'], default='testset',
                   help='How to sort the final summary. Possible values: "testset" to sort by '
                        'name of the testset // "exp" to sort by experiment log_date // "itr" to sort by iteration // '
                        '"res" to sort by result, i.e., show smaller first. Default: testset')

    flags = p.parse_args()

    if flags.compare_theory and not flags.write_to_files:
        raise ValueError('Cannot have --compare_theory without --write_to_files.')
    if flags.write_to_files and flags.sample:
        raise ValueError('Cannot have --write_to_files and --sample.')
    if flags.time_report and not flags.write_to_files:
        raise ValueError('--time_report only valid with --write_to_files.')

    splitter = ',' if ',' in flags.log_dates else '|'  # support tensorboard strings, too
    log_dates = flags.log_dates.split(splitter)

    # Resolve the display names BEFORE running any test, so that a wrong --names fails immediately instead of
    # with a KeyError after all (potentially slow) tests have finished.
    if flags.names:
        # if --names was passed: will print 'name (log_date)'
        names = flags.names.split(splitter)
        if len(names) != len(log_dates):
            raise ValueError(f'--names must contain one name per entry of LOG_DATES, '
                             f'got {len(names)} name(s) for {len(log_dates)} log_date(s).')
        log_date_to_name = {log_date: f'{name} ({log_date})'
                            for log_date, name in zip(log_dates, names)}
    else:
        # --names is not given, i.e., we just output log_date
        log_date_to_name = {log_date: log_date for log_date in log_dates}

    testsets = [Testset(images_dir_or_image.rstrip('/'), flags.max_imgs_per_folder,
                        # Append flags.crop to ID so that it creates unique entry in cache
                        append_id=f'_crop{flags.crop}' if flags.crop else None)
                for images_dir_or_image in flags.images.split(',')]
    if flags.match_filenames:
        for ts in testsets:
            ts.filter_filenames(flags.match_filenames)

    results = []
    for log_date in log_dates:
        for restore_itr in map(int, flags.restore_itr.split(',')):
            print('Testing {} at {} ---'.format(log_date, restore_itr))
            tester = MultiscaleTester(log_date, flags, restore_itr)
            results += tester.test_all(testsets)

    if not flags.write_to_files:
        print('*** Summary:')
        with AlignedPrinter() as a:
            # Position in the (testset, log_date, restore_itr, result) tuples to sort by
            sortby = {'testset': 0, 'exp': 1, 'itr': 2, 'res': 3}[flags.sort_output]
            a.append('Testset', 'Experiment', 'Itr', 'Result')
            for testset, log_date, restore_itr, result in sorted(results, key=itemgetter(sortby)):
                a.append(testset.id, log_date_to_name[log_date], str(restore_itr), result)
138 |
139 |
# Run the tests only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
142 |
--------------------------------------------------------------------------------
/src/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/test/__init__.py
--------------------------------------------------------------------------------
/src/test/cuda_timer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | from contextlib import contextmanager
20 | import numpy as np
21 | import torch
22 | import os
23 | import time
24 | import pytorch_ext as pe
25 | from collections import defaultdict
26 | from collections import namedtuple
27 |
28 |
# True iff the environment variable NO_CUDA_SYNC is set to 1, which disables synchronization in `sync` below.
_NO_CUDA_SYNC_OVERWRITE = int(os.environ.get('NO_CUDA_SYNC', 0)) == 1


# `sync` is a no-op if synchronization was disabled via NO_CUDA_SYNC or if pe.CUDA_AVAILABLE is falsy,
# otherwise it is torch.cuda.synchronize.
if _NO_CUDA_SYNC_OVERWRITE or not pe.CUDA_AVAILABLE:
    sync = lambda: None
else:
    sync = torch.cuda.synchronize
36 |
37 |
38 |
class StackLogger(object):
    """
    Collects log messages under hierarchical names.

    The name of a message is the space-joined stack of active prefixes (see `prefix_scope`) plus the name passed
    to `log`. Outside of `combine`, `self.logs[name]` is a flat list of messages. Inside of `combine`, all messages
    logged under the same name are collected into one `Entry`, which is appended to `self.logs[name]`.

    All context managers restore their state even if the body raises, so that a failure inside a scope does not
    corrupt the logger for later use.
    """
    Entry = namedtuple('Entry', ['fmt_str', 'logs'])
    CombineCtx = namedtuple('CombineCtx', ['prefixes_of_created_entries', 'fmt_str'])

    def __init__(self, default_fmt_str='{}'):
        """
        :param default_fmt_str: format string stored for use by subclasses when formatting non-combined messages.
        """
        self.default_fmt_str = default_fmt_str

        self.logs = defaultdict(list)
        self._order = []      # names in order of first occurrence
        self._prefixes = []   # stack of active prefixes

        self._combine_ctx = None
        self._global_skip = False

    @contextmanager
    def skip(self, flag):
        """ If `flag` is True, ignore all calls to `log` in this context. Afterwards, skipping is always off. """
        self._global_skip = flag
        try:
            yield
        finally:
            self._global_skip = False

    @contextmanager
    def prefix_scope(self, p):
        """ Prepend `p` to the name of everything logged in this context. """
        self._prefixes.append(p)
        try:
            yield
        finally:
            del self._prefixes[-1]

    @contextmanager
    def combine(self, fmt_str):
        """
        Collect all messages logged under the same name in this context into a single Entry with format `fmt_str`.

        :raises ValueError: if already inside a `combine` context.
        """
        if self._combine_ctx is not None:
            raise ValueError('Already in combine!')
        self._combine_ctx = StackLogger.CombineCtx(set(), fmt_str)
        try:
            yield
        finally:
            self._combine_ctx = None

    @contextmanager
    def prefix_scope_combine(self, prefix, fmt_str):
        """ Shortcut for `prefix_scope(prefix)` around `combine(fmt_str)`. """
        with self.prefix_scope(prefix):
            with self.combine(fmt_str):
                yield

    def log(self, name, msg):
        """ Store `msg` under `name`, prefixed by all active prefixes. No-op while skipping. """
        if self._global_skip:
            return
        prefix = ' '.join(self._prefixes + [name])
        if prefix not in self.logs:
            self._order.append(prefix)
        logs = self._get_log_entry_list(prefix)
        logs.append(msg)

    def _get_log_entry_list(self, prefix):
        """ Return the list that the next message for `prefix` must be appended to. """
        if self._combine_ctx is None:
            # Not combining: messages go directly into the flat list of this prefix
            return self.logs[prefix]

        prefixes_of_created_entries, fmt_str = self._combine_ctx
        if prefix not in prefixes_of_created_entries:
            # First message for this prefix in the current combine context: start a new Entry
            prefixes_of_created_entries.add(prefix)
            return self._append_entry(prefix, fmt_str).logs

        assert len(self.logs[prefix]) > 0
        return self.logs[prefix][-1].logs

    def _append_entry(self, prefix, fmt_str):
        """ Append a new, empty Entry to the logs of `prefix` and return it. """
        log_entry = StackLogger.Entry(fmt_str, logs=[])
        self.logs[prefix].append(log_entry)
        return log_entry
105 |
106 |
class StackTimeLogger(StackLogger):
    """
    StackLogger that logs durations, measured with `run`, and formats them as strings.
    """
    def __init__(self, default_fmt_str='{:.5f}'):
        super(StackTimeLogger, self).__init__(default_fmt_str)

    def get_mean_strs(self):
        """
        Yield one string 'name: value(s)' per logged name, in order of first occurrence, containing the mean
        over everything logged under that name.

        For combined entries, the mean is taken element-wise over all entries of that name.
        """
        for prefix in self._order:
            entries = self.logs[prefix]
            first_entry = entries[0]
            if isinstance(first_entry, StackLogger.Entry):
                num_values_per_entry = len(first_entry.logs)
                # NOTE: builtin float instead of np.float, which was removed in NumPy 1.24
                means = np.zeros(num_values_per_entry, dtype=float)
                for e in entries:
                    means += np.array(e.logs)
                means = means / len(entries)
                yield self._to_str(prefix, first_entry.fmt_str, means)
            else:  # entries is just a list
                mean = np.mean(entries)
                yield self._to_str(prefix, self.default_fmt_str, mean)

    def get_last_strs(self):
        """
        Yield one string 'name: value(s)' per logged name, in order of first occurrence, containing what was logged
        last under that name.
        """
        for prefix in self._order:
            entries = self.logs[prefix]
            last_entry = entries[-1]
            if isinstance(last_entry, StackLogger.Entry):
                yield self._to_str(prefix, last_entry.fmt_str, last_entry.logs)
            else:
                yield self._to_str(prefix, self.default_fmt_str, last_entry)

    @contextmanager
    def run(self, name):
        """
        Log the wall-clock duration of the body of this context under `name`. `sync` is called before reading the
        clock, both at the start and at the end. Nothing is logged if the body raises.
        """
        sync()
        start = time.time()
        yield
        sync()
        duration = time.time() - start
        self.log(name, duration)

    @staticmethod
    def _to_str(prefix, fmt_str, values):
        """
        :param prefix: name of the values.
        :param fmt_str: format string. For a single value, formatted with that value. For a list or array,
        formatted once per element with (index, element), and the results are joined with '/'.
        :param values: a single value, or a list/np.ndarray of values.
        :return: 'prefix: formatted values'
        """
        should_iter = isinstance(values, (list, np.ndarray))
        values_str = (fmt_str.format(values) if not should_iter
                      else '/'.join(fmt_str.format(i, v)
                                    for i, v in enumerate(values)))
        return prefix + ': ' + values_str
152 |
153 |
def test_stack_time_logger():
    """
    Manual smoke test for StackTimeLogger: times a few sleeps in nested prefix scopes, twice, then prints the
    raw logs, the order of the names, and the mean/last strings.
    """
    from pprint import pprint

    timer = StackTimeLogger()
    for _ in (1, 2):
        with timer.prefix_scope('foo'), timer.prefix_scope('bar'):
            with timer.run('setup'):
                time.sleep(0.1)
            with timer.combine('c{}: {:.5f}'):
                for step in range(5):
                    with timer.run('run'):
                        time.sleep(step * 0.01)

    pprint(timer.logs)
    pprint(timer._order)
    mean_report = '\n'.join(timer.get_mean_strs())
    print(mean_report)
    print('...')
    last_report = '\n'.join(timer.get_last_strs())
    print(last_report)
171 |
172 |
--------------------------------------------------------------------------------
/src/test/image_saver.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 | import glob
20 | import os
21 |
22 | import torch
23 | from PIL import Image
24 |
25 | from vis.image_summaries import to_image
26 |
27 |
class ImageSaver(object):
    """
    Saves images into a directory and keeps track of the file names handed out.
    """
    def __init__(self, out_dir):
        """
        :param out_dir: directory to save into. Created if it does not exist.
        """
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)
        # file names (not paths) passed to get_save_p so far
        self.saved_fs = []

    def __str__(self):
        return 'ImageSaver({})'.format(self.out_dir)

    def save_img(self, img, filename, convert_to_image=True):
        """
        :param img: image tensor, in {0, ..., 255}
        :param filename: output filename
        :param convert_to_image: if True, call to_image on img, otherwise assume this has already been done.
        :return: path the image was saved to
        """
        if convert_to_image:
            img = to_image(img.type(torch.uint8))
        out_p = self.get_save_p(filename)
        Image.fromarray(img).save(out_p)
        return out_p

    def get_save_p(self, file_name):
        """
        Record `file_name` in self.saved_fs.

        :return: path of `file_name` inside of out_dir
        """
        out_p = os.path.join(self.out_dir, file_name)
        self.saved_fs.append(file_name)
        return out_p

    def file_starting_with_exists(self, prefix):
        """
        :param prefix: literal start of a file name. Glob special characters in it (and in out_dir) are escaped,
        so that, e.g., 'img[1]' matches 'img[1].png'.
        :return: True iff out_dir contains at least one entry whose name starts with `prefix`
        """
        check_p = glob.escape(os.path.join(self.out_dir, prefix)) + '*'
        matches = glob.glob(check_p)  # glob only once, used for both the output and the result
        print(check_p, matches)
        return len(matches) > 0
--------------------------------------------------------------------------------
/src/torchac/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | NOTE: Needs PyTorch 1.0 or newer, as the C++ code relies on that API!
22 |
23 | Depending on the environment variable COMPILE_CUDA, compiles the torchac_backend with or
24 | without support for CUDA, into a module called torchac_backend_gpu or torchac_backend_cpu.
25 |
26 | COMPILE_CUDA = auto is equal to yes if one of the supported combinations of nvcc and gcc is found (see
27 | _supported_compilers_available).
28 | COMPILE_CUDA = force means compile with CUDA, even if it is not one of the supported combinations
29 | COMPILE_CUDA = no means no CUDA.
30 |
31 | The only difference between the CUDA and non-CUDA versions is: With CUDA, _get_uint16_cdf from torchac is done with a
32 | simple/non-optimized CUDA kernel (torchac_kernel.cu), which has one benefit: we can directly write into shared memory!
33 | This saves an expensive copying step from GPU to CPU.
34 |
35 | Flags read by this script:
36 | COMPILE_CUDA=[auto|force|no]
37 |
38 | """
39 |
40 | import sys
41 | import re
42 | import subprocess
43 | from setuptools import setup
44 | from distutils.version import LooseVersion
45 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension
46 | import os
47 |
48 |
# Name of the folder next to this file that contains the C++/CUDA sources. Also the base name of the compiled
# module, which is MODULE_BASE_NAME + '_gpu' or MODULE_BASE_NAME + '_cpu' (see get_extension).
MODULE_BASE_NAME = 'torchac_backend'
50 |
51 |
def prefixed(prefix, l):
    """
    Join `prefix` with every element of `l` and make sure that all resulting paths are existing files.

    :param prefix: directory to prepend.
    :param l: iterable of file names.
    :return: list of joined paths, in the order of `l`.
    :raises FileNotFoundError: with the first joined path that is not a file.
    """
    paths = [os.path.join(prefix, name) for name in l]
    first_missing = next((path for path in paths if not os.path.isfile(path)), None)
    if first_missing is not None:
        raise FileNotFoundError(first_missing)
    return paths
58 |
59 |
def compile_ext(cuda_support):
    """
    Build the extension returned by get_extension(cuda_support) by calling setuptools' `setup`.

    :param cuda_support: whether to compile with CUDA support.
    """
    print('Compiling, cuda_support={}'.format(cuda_support))
    extension = get_extension(cuda_support)

    # NOTE(review): extra_compile_args is passed to setup() here, not to the extension itself -- TODO confirm
    # that it has the intended effect.
    setup_kwargs = dict(
        name=extension.name,
        version='1.0.0',
        ext_modules=[extension],
        extra_compile_args=['-mmacosx-version-min=10.9'],
        cmdclass={'build_ext': BuildExtension})
    setup(**setup_kwargs)
69 |
70 |
def get_extension(cuda_support):
    """
    Create the extension object describing how to compile the backend.

    :param cuda_support: if True, return a CUDAExtension called MODULE_BASE_NAME + '_gpu', built from torchac.cpp
    and torchac_kernel.cu with the macro COMPILE_CUDA defined. Otherwise, return a CppExtension called
    MODULE_BASE_NAME + '_cpu', built from torchac.cpp only.
    :raises ValueError: if the folder with the backend sources does not exist next to this file.
    :raises FileNotFoundError: if one of the source files is missing (see `prefixed`).
    """
    # The sources live in a folder next to this file
    here = os.path.dirname(os.path.realpath(__file__))
    backend_dir = os.path.join(here, MODULE_BASE_NAME)
    if not os.path.isdir(backend_dir):
        raise ValueError('Did not find backend foler: {}'.format(backend_dir))

    if not cuda_support:
        return CppExtension(
            MODULE_BASE_NAME + '_cpu',
            prefixed(backend_dir, ['torchac.cpp']))

    # An untested nvcc only results in a warning, we still try to compile
    nvcc_ok, nvcc_version = supported_nvcc_available()
    if not nvcc_ok:
        print(_bold_warn_str('***WARN') + ': Found untested nvcc {}'.format(nvcc_version))

    return CUDAExtension(
        MODULE_BASE_NAME + '_gpu',
        prefixed(backend_dir, ['torchac.cpp', 'torchac_kernel.cu']),
        define_macros=[('COMPILE_CUDA', '1')])
91 |
92 |
93 | # TODO:
94 | # Add further supported version as specified in readme
95 |
96 |
def _supported_compilers_available():
    """
    :return: True iff both a supported gcc and a supported nvcc were found. nvcc is only checked if gcc is
    supported. To see an up-to-date list of tested combinations of GCC and NVCC, see the README.
    """
    gcc_ok, _ = _supported_gcc_available()
    if not gcc_ok:
        return gcc_ok
    nvcc_ok, _ = supported_nvcc_available()
    return nvcc_ok
102 |
103 |
def _supported_gcc_available():
    """
    Check whether the installed gcc has a supported version, i.e., 5.0 <= version < 6.0.

    :return: tuple (supported, version), where version is the version string reported by `gcc -v`. If gcc is not
    installed, returns (False, 'gcc unavailable!'), like supported_nvcc_available does for nvcc.
    """
    v = _get_version(['gcc', '-v'], r'version (.*?)\s+')
    if v is None:
        # _get_version returns None if the command is not found; LooseVersion(None) cannot be compared.
        return False, 'gcc unavailable!'
    return LooseVersion('6.0') > LooseVersion(v) >= LooseVersion('5.0'), v
107 |
108 |
def supported_nvcc_available():
    """
    Check whether the installed nvcc has a supported version, i.e., version >= 9.0.

    :return: tuple (supported, version), where version is the release reported by `nvcc -V`. If nvcc is not
    installed, returns (False, 'nvcc unavailable!').
    """
    version = _get_version(['nvcc', '-V'], 'release (.*?),')
    if version is not None:
        return LooseVersion(version) >= LooseVersion('9.0'), version
    return False, 'nvcc unavailable!'
114 |
115 |
116 | def _get_version(cmd, regex):
117 | try:
118 | otp = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode()
119 | if len(otp.strip()) == 0:
120 | raise ValueError('No output')
121 | m = re.search(regex, otp)
122 | if not m:
123 | raise ValueError('Regex does not match output:\n{}'.format(otp))
124 | return m.group(1)
125 | except FileNotFoundError:
126 | return None
127 |
128 |
129 | def _bold_warn_str(s):
130 | return '\x1b[91m\x1b[1m' + s + '\x1b[0m'
131 |
132 |
def _assert_torch_version_sufficient():
    """ Print an error and exit with status 1 unless the installed PyTorch has version >= 1.0. """
    import torch
    sufficient = LooseVersion(torch.__version__) >= LooseVersion('1.0')
    if not sufficient:
        print(_bold_warn_str('Error:'), 'Need PyTorch version >= 1.0, found {}'.format(torch.__version__))
        sys.exit(1)
139 |
140 |
def main():
    """
    Compile the backend, with or without CUDA support, depending on the environment variable COMPILE_CUDA:
    'auto' uses CUDA iff supported compilers are found, 'force' always uses CUDA, 'no' (the default) never does.

    :raises ValueError: if COMPILE_CUDA has any other value.
    """
    _assert_torch_version_sufficient()

    cuda_flag = os.environ.get('COMPILE_CUDA', 'no')
    if cuda_flag not in ('auto', 'force', 'no'):
        raise ValueError('COMPILE_CUDA must be in (auto, force, no), got {}'.format(cuda_flag))

    if cuda_flag == 'auto':
        cuda_support = _supported_compilers_available()
        print('Found CUDA supported:', cuda_support)
    else:
        cuda_support = cuda_flag == 'force'

    compile_ext(cuda_support)
157 |
158 |
# Compile only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
161 |
--------------------------------------------------------------------------------
/src/torchac/torchac.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | # TODO some comments needed about [..., -1] == 0
21 |
22 | import torch
23 |
24 |
25 | # torchac can be built with and without CUDA support.
26 | # Here, we try to import both torchac_backend_gpu and torchac_backend_cpu.
27 | # If both fail, an exception is thrown here already.
28 | #
29 | # The right version is then picked in the functions below.
30 | #
31 | # NOTE:
# Without a clean build, multiple versions might be installed. You may use `python setup.py clean --all` to prevent this.
33 | # But it should not be an issue.
34 |
35 |
# Exceptions raised by the two import attempts below. Only used by the commented-out debug output further down.
import_errors = []


# Try the CUDA build of the backend.
try:
    import torchac_backend_gpu
    CUDA_SUPPORTED = True
except ImportError as e:
    CUDA_SUPPORTED = False
    import_errors.append(e)

# Try the CPU build of the backend.
try:
    import torchac_backend_cpu
    CPU_SUPPORTED = True
except ImportError as e:
    CPU_SUPPORTED = False
    import_errors.append(e)


imported_at_least_one = CUDA_SUPPORTED or CPU_SUPPORTED


# if import_errors:
#     import_errors_str = '\n'.join(map(str, import_errors))
#     print(f'*** Import errors:\n{import_errors_str}')


# Importing torchac fails right away if no backend at all is installed.
if not imported_at_least_one:
    raise ImportError('*** Failed to import any torchac_backend! Make sure to install torchac with torchac/setup.py. '
                      'See the README for details.')


# Backend used by encode_cdf and decode_cdf; the CPU build is preferred if both are installed.
any_backend = torchac_backend_cpu if CPU_SUPPORTED else torchac_backend_gpu
68 |
69 |
70 | # print(f'*** torchac: GPU support: {CUDA_SUPPORTED} // CPU support: {CPU_SUPPORTED}')
71 |
72 |
def _get_gpu_backend():
    """
    :return: the module torchac_backend_gpu.
    :raises ValueError: if torchac_backend_gpu could not be imported.
    """
    if CUDA_SUPPORTED:
        return torchac_backend_gpu
    raise ValueError('Got CUDA tensor, but torchac_backend_gpu is not available. '
                     'Compile torchac with CUDA support, or use CPU mode (see README).')
78 |
79 |
def _get_cpu_backend():
    """
    :return: the module torchac_backend_cpu.
    :raises ValueError: if torchac_backend_cpu could not be imported.
    """
    if CPU_SUPPORTED:
        return torchac_backend_cpu
    raise ValueError('Got CPU tensor, but torchac_backend_cpu is not available. '
                     'Compile torchac without CUDA support, or use GPU mode (see README).')
85 |
86 |
def encode_cdf(cdf, sym):
    """
    Encode `sym` given the CDF, using whichever backend is available.

    :param cdf: CDF as 1HWLp, as int16, on CPU!
    :param sym: the symbols to encode, as int16, on CPU
    :return: byte-string, encoding `sym`
    :raises ValueError: if `cdf` or `sym` is a CUDA tensor.
    """
    any_on_gpu = cdf.is_cuda or sym.is_cuda
    if any_on_gpu:
        raise ValueError('CDF and symbols must be on CPU for `encode_cdf`')
    # Both backends define encode_cdf, so it does not matter which one was imported.
    encoder = any_backend.encode_cdf
    return encoder(cdf, sym)
97 |
98 |
def decode_cdf(cdf, input_string):
    """
    Decode symbols from `input_string` given the CDF, using whichever backend is available.

    :param cdf: CDF as 1HWLp, as int16, on CPU
    :param input_string: byte-string, encoding some symbols `sym`.
    :return: decoded `sym`.
    :raises ValueError: if `cdf` is a CUDA tensor.
    """
    if cdf.is_cuda:
        raise ValueError('CDF must be on CPU for `decode_cdf`')
    # Both backends define decode_cdf, so it does not matter which one was imported.
    decoder = any_backend.decode_cdf
    return decoder(cdf, input_string)
109 |
110 |
def encode_logistic_mixture(
        targets, means, log_scales, logit_probs_softmax,  # CDF
        sym):
    """
    Encode `sym` with the CDF of a mixture of logistics described by the first four arguments.

    NOTE: This function uses either the CUDA or CPU backend, depending on the device of the input tensors.
    NOTE: targets, means, log_scales, logit_probs_softmax must all be on the same device (CPU or GPU)
    In the following, we use
        Lp: Lp = L+1, where L = number of symbols.
        K: number of mixtures
    :param targets: values of symbols, tensor of length Lp, float32
    :param means: means of mixtures, tensor of shape 1KHW, float32
    :param log_scales: log(scales) of mixtures, tensor of shape 1KHW, float32
    :param logit_probs_softmax: weights of the mixtures (PI), tensor of shape 1KHW, float32
    :param sym: the symbols to encode. MUST be on CPU!!
    :return: byte-string, encoding `sym`.
    :raises ValueError: if the first four arguments are not all on CPU or all on GPU, or if `sym` is on GPU.
    """
    mixture_params = (targets, means, log_scales, logit_probs_softmax)
    if len({param.is_cuda for param in mixture_params}) != 1:
        raise ValueError('targets, means, log_scales, logit_probs_softmax must all be on the same device! Got '
                         f'{targets.device}, {means.device}, {log_scales.device}, {logit_probs_softmax.device}.')
    if sym.is_cuda:
        raise ValueError('sym must be on CPU!')

    if not targets.is_cuda:
        # CPU: calculate the CDF in PyTorch, then encode with it
        cdf = _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales)
        return encode_cdf(cdf, sym)

    # GPU: the backend calculates the CDF itself
    return _get_gpu_backend().encode_logistic_mixture(
        targets, means, log_scales, logit_probs_softmax, sym)
139 |
140 |
def decode_logistic_mixture(
        targets, means, log_scales, logit_probs_softmax,  # CDF
        input_string):
    """
    Decode symbols from `input_string` with the CDF of a mixture of logistics described by the first four
    arguments.

    NOTE: This function uses either the CUDA or CPU backend, depending on the device of the input tensors.
    NOTE: targets, means, log_scales, logit_probs_softmax must all be on the same device (CPU or GPU)
    In the following, we use
        Lp: Lp = L+1, where L = number of symbols.
        K: number of mixtures
    :param targets: values of symbols, tensor of length Lp, float32
    :param means: means of mixtures, tensor of shape 1KHW, float32
    :param log_scales: log(scales) of mixtures, tensor of shape 1KHW, float32
    :param logit_probs_softmax: weights of the mixtures (PI), tensor of shape 1KHW, float32
    :param input_string: byte-string, encoding some symbols `sym`.
    :return: decoded `sym`.
    :raises ValueError: if the first four arguments are not all on CPU or all on GPU.
    """
    mixture_params = (targets, means, log_scales, logit_probs_softmax)
    if len({param.is_cuda for param in mixture_params}) != 1:
        raise ValueError('targets, means, log_scales, logit_probs_softmax must all be on the same device! Got '
                         f'{targets.device}, {means.device}, {log_scales.device}, {logit_probs_softmax.device}.')

    if not targets.is_cuda:
        # CPU: calculate the CDF in PyTorch, then decode with it
        cdf = _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales)
        return decode_cdf(cdf, input_string)

    # GPU: the backend calculates the CDF itself
    return _get_gpu_backend().decode_logistic_mixture(
        targets, means, log_scales, logit_probs_softmax, input_string)
167 |
168 |
169 | # ------------------------------------------------------------------------------
170 |
# The following code is invoked when the CDF is not on GPU, and we cannot use torchac/torchac_kernel.cu.
# It basically replicates that kernel in pure PyTorch.
173 |
def _get_uint16_cdf(logit_probs_softmax, targets, means, log_scales):
    """
    Calculate the CDF of the mixture as floats, renormalize and cast it with precision 16, and move it to CPU.

    :return: result of _renorm_cast_cdf_ on the weighted CDF, on CPU.
    """
    weighted_cdf = _get_C_cur_weighted(logit_probs_softmax, targets, means, log_scales)
    return _renorm_cast_cdf_(weighted_cdf, precision=16).cpu()
179 |
180 |
def _get_C_cur_weighted(logit_probs_softmax_c, targets, means_c, log_scales_c):
    """
    CDF of the mixture: the per-mixture CDFs from _get_C_cur, weighted with `logit_probs_softmax_c` and summed
    over dimension 1.

    :param logit_probs_softmax_c: weights of the mixtures, NKHW
    :param targets: Lp floats
    :param means_c: NKHW
    :param log_scales_c: NKHW
    :return: NHWL
    """
    per_mixture_cdf = _get_C_cur(targets, means_c, log_scales_c)  # NKHWL
    weights = logit_probs_softmax_c.unsqueeze(-1)  # NKHW1
    return (per_mixture_cdf * weights).sum(1)  # NHWL
185 |
186 |
187 | def _get_C_cur(targets, means_c, log_scales_c): # NKHWL
188 | """
189 | :param targets: Lp floats
190 | :param means_c: NKHW
191 | :param log_scales_c: NKHW
192 | :return:
193 | """
194 | # NKHW1
195 | inv_stdv = torch.exp(-log_scales_c).unsqueeze(-1)
196 | # NKHWL'
197 | centered_targets = (targets - means_c.unsqueeze(-1))
198 | # NKHWL'
199 | cdf = centered_targets.mul(inv_stdv).sigmoid() # sigma' * (x - mu)
200 | return cdf
201 |
202 |
203 | def _renorm_cast_cdf_(cdf, precision):
204 | Lp = cdf.shape[-1]
205 | finals = 1 # NHW1
206 | # RENORMALIZATION_FACTOR in cuda
207 | f = torch.tensor(2, dtype=torch.float32, device=cdf.device).pow_(precision)
208 | cdf = cdf.mul((f - (Lp - 1)) / finals) # TODO
209 | cdf = cdf.round()
210 | cdf = cdf.to(dtype=torch.int16, non_blocking=True)
211 | r = torch.arange(Lp, dtype=torch.int16, device=cdf.device)
212 | cdf.add_(r)
213 | return cdf
214 |
--------------------------------------------------------------------------------
/src/torchac/torchac_backend/torchac_kernel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * COPYRIGHT 2019 ETH Zurich
3 | */
4 | #include
5 |
6 | #include
7 | #include
8 | #include
9 |
10 |
// Type of one entry of the quantized CDF.
using cdf_t = uint16_t;
// Number of bits of the quantized CDF.
const int PRECISION = 16;
// 2 << (PRECISION - 1) == 2^PRECISION. In renorm(), float CDF values are scaled with this minus (Lp - 1).
const int RENORMALIZATION_FACTOR = 2 << (PRECISION - 1);
14 |
15 | namespace {
// Logistic sigmoid, 1 / (1 + exp(-a)).
// NOTE(review): the literals 1.0 are doubles, so the addition and division are presumably done in double
// precision before the result is converted to float. Changing them to 1.0f may change results in the last
// bits -- TODO confirm whether that matters for compatibility before touching this.
__device__ __forceinline__ float sigmoidf (float a) {
    return 1.0 / (1.0 + expf (-a));
}
19 |
20 | __device__ __forceinline__ cdf_t renorm(float cdf, const int Lp, const int l) {
21 | cdf *= (RENORMALIZATION_FACTOR - (Lp - 1));
22 | cdf_t cdf_int = static_cast(lrintf(cdf) + l);
23 | return cdf_int;
24 | }
25 |
__global__ void calculate_cdf_kernel(
        const int N, const int Lp, const int K,
        const float* __restrict__ targets,  // Lp length vector
        const float* __restrict__ means,
        const float* __restrict__ log_scales,
        const float* __restrict__ logit_probs_softmax,
        cdf_t* __restrict__ cdf_mem /* out */) {
    /**
     * Calculates the quantized CDF of the mixture for all N positions and all Lp targets.
     *
     * Uses a grid-stride loop over the N*Lp outputs, so every output is calculated regardless of how many
     * threads the kernel is launched with.
     *
     * means, log_scales, logit_probs_softmax:
     *      each is indexed as a KxN matrix, x[k][n] = x_mem[k*N + n]
     *      (presumably a 1KHW tensor with N = H*W -- see the checks in the caller)
     * cdf_mem:
     *      an array of length N * Lp, representing a NxLp matrix, where
     *      cdf[n][l] = cdf_mem[n*Lp + l]
     *
     * Code:
     *      for n, l in range(N) x range(Lp)
     *          target = targets[l]
     *          cdf_n_l = 0;
     *          for k in range(K)
     *              log_scale = log_scales[k][n]
     *              mean = means[k][n]
     *              logit_prob = logit_probs_softmax[k][n]
     *              inv_stdv = exp(-log_scale)
     *              centered_target = target - mean
     *              cdf_n_l += logit_prob * sigmoid(centered_target * inv_stdv)
     *          cdf[n][l] = renorm(cdf_n_l, Lp, l)
     */
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i_1d = index; i_1d < N * Lp; i_1d += stride) {
        // Row (position) and column (target) of this output in the NxLp matrix
        const int n = i_1d / Lp;
        const int l = i_1d % Lp;

        const float target = targets[l];
        float cdf_n_l_float = 0;  // initialize

        // Weighted sum over the K mixtures
        for (int k = 0; k < K; ++k) {
            const float log_scale = log_scales[k * N + n];
            const float mean = means[k * N + n];
            const float logit_prob = logit_probs_softmax[k * N + n];
            const float inv_stdv = expf(-log_scale);
            const float centered_target = target - mean;
            cdf_n_l_float += logit_prob * sigmoidf(centered_target * inv_stdv);
        }

        const int cdf_n_l_idx = i_1d;
        cdf_mem[cdf_n_l_idx] = renorm(cdf_n_l_float, Lp, l);
    }
}
77 | }
78 |
79 |
// Allocate managed memory for an N x Lp matrix of cdf_t.
//
//   N:  number of rows (locations).
//   Lp: number of CDF entries per row.
//
// Returns a pointer to release with free_cdf(), or nullptr when N * Lp == 0.
// Raises (via AT_CHECK) if the allocation fails, instead of handing back an
// uninitialized pointer.
cdf_t* malloc_cdf(const int N, const int Lp) {
    cdf_t* cdf_mem = nullptr;
    // Widen before multiplying so that N * Lp cannot overflow int.
    const size_t num_bytes = static_cast<size_t>(N) * static_cast<size_t>(Lp) * sizeof(cdf_t);
    if (num_bytes == 0) {
        // cudaMallocManaged rejects size 0; cudaFree(nullptr) is a valid no-op.
        return nullptr;
    }
    const cudaError_t err = cudaMallocManaged(&cdf_mem, num_bytes);
    AT_CHECK(err == cudaSuccess,
             "cudaMallocManaged failed for ", num_bytes, " bytes: ", cudaGetErrorString(err));
    return cdf_mem;
}
85 |
86 |
// Release cdf_mem with cudaFree. Intended counterpart of malloc_cdf(), which allocates with
// cudaMallocManaged. The return code of cudaFree is not inspected.
void free_cdf(cdf_t* cdf_mem) {
    cudaFree(cdf_mem);
}
90 |
91 |
// Render any value that supports operator<< on a std::ostream as a std::string.
// Used by the CHECK_* macros to print tensor sizes in error messages.
template <typename T>
std::string to_string(const T& object) {
    std::ostringstream ss;
    ss << object;
    return ss.str();
}
98 |
// Require x to be a 4D tensor of shape 1 x K x H x W.
// #x is stringified outside of the literal so that the message names the offending argument;
// AT_CHECK concatenates its message arguments, it does not expand printf placeholders.
#define CHECK_1KHW(K, x) AT_CHECK(x.sizes().size() == 4 && x.sizes()[0] == 1 && x.sizes()[1] == K, \
    #x " must be 4D with shape 1xKxHxW, got ", to_string(x.sizes()))

// Require x to be contiguous and to live on a CUDA device.
#define CHECK_CONTIGUOUS_AND_CUDA(x) AT_CHECK(x.is_contiguous() && x.is_cuda(), \
    #x " must be contiguous and on GPU, got is_contiguous=", x.is_contiguous(), \
    " and is_cuda=", x.is_cuda())
104 |
// Host entry point: validate the inputs and fill cdf_mem with the quantized mixture CDF.
//
//   targets:             float tensor with at least Lp elements (one value per CDF entry).
//   means, log_scales,
//   logit_probs_softmax: float tensors of identical shape 1 x K x H x W.
//   cdf_mem:             output buffer with room for N_cdf * Lp entries.
//   K:                   number of mixture components.
//   Lp:                  number of CDF entries per location.
//   N_cdf:               number of locations the caller allocated for; must equal H * W.
//
// All tensors must be contiguous and on the GPU. Raises (via AT_CHECK) on invalid inputs and
// on CUDA launch or execution errors. Blocks until the kernel has finished.
void calculate_cdf(
        const at::Tensor& targets,
        const at::Tensor& means,
        const at::Tensor& log_scales,
        const at::Tensor& logit_probs_softmax,
        cdf_t * cdf_mem,
        const int K, const int Lp, const int N_cdf) {

    CHECK_1KHW(K, means);
    CHECK_1KHW(K, log_scales);
    CHECK_1KHW(K, logit_probs_softmax);

    CHECK_CONTIGUOUS_AND_CUDA(targets);
    CHECK_CONTIGUOUS_AND_CUDA(means);
    CHECK_CONTIGUOUS_AND_CUDA(log_scales);
    CHECK_CONTIGUOUS_AND_CUDA(logit_probs_softmax);

    AT_CHECK(means.sizes() == log_scales.sizes() &&
             log_scales.sizes() == logit_probs_softmax.sizes(),
             "means, log_scales and logit_probs_softmax must have identical sizes");

    // The kernel reads targets[l] for every l in [0, Lp).
    AT_CHECK(targets.numel() >= Lp,
             "targets must hold at least Lp = ", Lp, " elements, got ", targets.numel());

    const auto param_sizes = means.sizes();
    const auto N = param_sizes[2] * param_sizes[3];  // H * W
    AT_CHECK(N == N_cdf, "H * W = ", N, " != N_cdf = ", N_cdf);

    if (N == 0 || Lp <= 0) {
        // Nothing to compute; launching a grid with zero blocks would be an error.
        return;
    }

    const int blockSize = 1024;
    const int numBlocks = static_cast<int>((N * Lp + blockSize - 1) / blockSize);

    calculate_cdf_kernel<<<numBlocks, blockSize>>>(
        static_cast<int>(N), Lp, K,
        targets.data<float>(),
        means.data<float>(),
        log_scales.data<float>(),
        logit_probs_softmax.data<float>(),
        cdf_mem);

    // An invalid launch configuration is only reported through cudaGetLastError.
    const cudaError_t launch_err = cudaGetLastError();
    AT_CHECK(launch_err == cudaSuccess,
             "calculate_cdf_kernel launch failed: ", cudaGetErrorString(launch_err));

    // Wait for GPU to finish before accessing on host
    const cudaError_t sync_err = cudaDeviceSynchronize();
    AT_CHECK(sync_err == cudaSuccess,
             "calculate_cdf_kernel execution failed: ", cudaGetErrorString(sync_err));
}
143 |
144 |
145 |
146 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 | """
19 | import torch
20 |
21 | # seed at least the random number generators.
22 | # doesn't guarantee full reproducability: https://pytorch.org/docs/stable/notes/randomness.html
23 | torch.manual_seed(0)
24 |
25 | # ---
26 |
27 | import argparse
28 | import sys
29 |
30 | import torch.backends.cudnn
31 | from fjcommon import no_op
32 |
33 | import pytorch_ext as pe
34 | from helpers.config_checker import DEFAULT_CONFIG_DIR, ConfigsRepo
35 | from helpers.global_config import global_config
36 | from helpers.saver import Saver
37 | from train.multiscale_trainer import MultiscaleTrainer
38 | from train.train_restorer import TrainRestorer
39 | from train.trainer import LogConfig
40 |
41 | torch.backends.cudnn.benchmark = True
42 |
def _print_debug_info():
    """Print a banner showing the active device and the PyTorch version."""
    banner = '*' * 80
    print(banner)
    print(f'DEVICE == {pe.DEVICE} // PyTorch v{torch.__version__}')
    print(banner)
47 |
48 |
def main(args, configs_dir=DEFAULT_CONFIG_DIR):
    """
    Parse the command line, build a MultiscaleTrainer and either train or run one debug step.

    :param args: list of command line arguments, without the program name
    :param configs_dir: directory passed to ConfigsRepo to check that both configs are available
    """
    p = argparse.ArgumentParser()

    p.add_argument('ms_config_p', help='Path to a multiscale config, see README')
    p.add_argument('dl_config_p', help='Path to a dataloader config, see README')
    # nargs='?' makes this positional optional; without it argparse never uses default='logs'.
    p.add_argument('log_dir_root', nargs='?', default='logs',
                   help='All outputs (checkpoints, tensorboard) will be saved here.')
    p.add_argument('--temporary', '-t', action='store_true',
                   help='If given, outputs are actually saved in ${LOG_DIR_ROOT}_TMP.')
    p.add_argument('--log_train', '-ltrain', type=int, default=100,
                   help='Interval of train output.')
    p.add_argument('--log_train_heavy', '-ltrainh', type=int, default=5, metavar='LOG_HEAVY_FAC',
                   help='Every LOG_HEAVY_FAC-th time that i %% LOG_TRAIN is 0, also output heavy logs.')
    p.add_argument('--log_val', '-lval', type=int, default=500,
                   help='Interval of validation output.')

    p.add_argument('-p', action='append', nargs=1,
                   help='Specify global_config parameters, see README')

    p.add_argument('--restore', type=str, metavar='RESTORE_DIR',
                   help='Path to the log_dir of the model to restore. If a log_date ('
                        'MMDD_HHmm) is given, the model is assumed to be in LOG_DIR_ROOT.')
    p.add_argument('--restore_continue', action='store_true',
                   help='If given, continue in RESTORE_DIR instead of starting in a new folder.')
    p.add_argument('--restore_restart', action='store_true',
                   help='If given, start from iteration 0, instead of the iteration of RESTORE_DIR. '
                        'Means that the model in RESTORE_DIR is used as pretrained model')
    p.add_argument('--restore_itr', '-i', type=int, default=-1,
                   help='Which iteration to restore. -1 means latest iteration. Will use closest smaller if exact '
                        'iteration is not found. Only valid with --restore. Default: -1')
    p.add_argument('--restore_strict', type=str, help='y|n', choices=['y', 'n'], default='y')

    p.add_argument('--num_workers', '-W', type=int, default=8,
                   help='Number of workers used for DataLoader')

    p.add_argument('--saver_keep_tmp_itr', '-si', type=int, default=250)
    p.add_argument('--saver_keep_every', '-sk', type=int, default=10)
    p.add_argument('--saver_keep_tmp_last', '-skt', type=int, default=3)
    p.add_argument('--no_saver', action='store_true',
                   help='If given, no checkpoints are stored.')

    p.add_argument('--debug', action='store_true')

    flags = p.parse_args(args)

    _print_debug_info()

    # Debug runs always write to the temporary log dir.
    if flags.debug:
        flags.temporary = True

    global_config.add_from_flag(flags.p)
    print(global_config)

    ConfigsRepo(configs_dir).check_configs_available(flags.ms_config_p, flags.dl_config_p)

    # With --no_saver, a no-op object stands in for the Saver so that no checkpoints are written.
    saver = (Saver(flags.saver_keep_tmp_itr, flags.saver_keep_every, flags.saver_keep_tmp_last,
                   verbose=True)
             if not flags.no_saver
             else no_op.NoOp())

    restorer = TrainRestorer.from_flags(flags.restore, flags.log_dir_root, flags.restore_continue, flags.restore_itr,
                                        flags.restore_restart, flags.restore_strict)

    trainer = MultiscaleTrainer(flags.ms_config_p, flags.dl_config_p,
                                flags.log_dir_root + ('_TMP' if flags.temporary else ''),
                                LogConfig(flags.log_train, flags.log_val, flags.log_train_heavy),
                                flags.num_workers,
                                saver=saver, restorer=restorer)
    if not flags.debug:
        trainer.train()
    else:
        trainer.debug()
121 |
# Script entry point: pass the command line arguments (without the program name) to main().
if __name__ == '__main__':
    main(sys.argv[1:])
124 |
--------------------------------------------------------------------------------
/src/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/train/__init__.py
--------------------------------------------------------------------------------
/src/train/lr_schedule.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 | """
19 | import numpy as np
20 |
21 | from fjcommon.assertions import assert_exc
22 |
23 | from vis.figure_plotter import PlotToArray
24 |
25 |
26 | SPEC_SEP = '_'
27 |
28 |
def from_spec(s, initial_lr, optims, epoch_len):
    """
    Build a learning rate schedule from a spec string. Fields are joined by SPEC_SEP.

    grammar: one of
        none
        exp FAC (iITR|eEPOCH)
        cos lrmax lrmin (iITR|eEPOCH)  # time to finish

    Example:
        exp_0.5_e8_warm_20_0.75_e1

    :param s: the spec string
    :param initial_lr: initial learning rate, only used by `exp` schedules
    :param optims: optimizers whose learning rate the schedule sets
    :param epoch_len: iterations per epoch, needed to convert eEPOCH intervals
    :return: an object with an `update(i)` method
    """
    if s == 'none':
        return ConstantLRSchedule()

    schedule_kind, remainder = s.split(SPEC_SEP, 1)

    if schedule_kind == 'exp':
        return _parse_exp_spec(remainder, optims, initial_lr, epoch_len)

    if schedule_kind == 'cos':
        lrmax, lrmin, duration = remainder.split(SPEC_SEP)
        # first character selects the unit: i = iterations, e = epochs
        unit, duration = duration[0], duration[1:]
        assert_exc(unit in ('i', 'e'), 'Invalid spec: {}'.format(remainder))
        T_itr = int(duration) if unit == 'i' else None
        T_epoch = float(duration) if unit == 'e' else None
        return CosineDecayLRSchedule(optims, float(lrmax), float(lrmin), T_itr, T_epoch, epoch_len)

    # unknown schedule kinds fail like a missing dictionary key
    raise KeyError(schedule_kind)
56 |
57 |
58 | # ------------------------------------------------------------------------------
59 |
60 |
class LRSchedule(object):
    """
    Base class of learning rate schedules. Subclasses implement `_get_lr`; `update` writes the
    resulting value to every parameter group of every optimizer.
    """
    def __init__(self, optims):
        """ :param optims: iterable of optimizers, each exposing `param_groups` """
        self.optims = optims
        self._current_lr = None

    def _get_lr(self, i):
        """ :returns the learning rate for iteration `i` """
        raise NotImplementedError()

    def update(self, i):
        """ Set the learning rate for iteration `i`. Writes only if the value changed. """
        new_lr = self._get_lr(i)
        if new_lr != self._current_lr:
            for optimizer in self.optims:
                for group in optimizer.param_groups:
                    group['lr'] = new_lr
            self._current_lr = new_lr
77 |
78 |
class ConstantLRSchedule(object):
    """ Schedule that never changes the learning rate. """
    def update(self, i):
        """ Accept the iteration index and do nothing. """
        return None
82 |
83 |
84 |
class ExponentialDecayLRSchedule(LRSchedule):
    """
    Step-wise exponential decay: lr = initial * decay_fac ** (elapsed // decay_every_itr),
    where `elapsed` counts iterations since the last warm restart.
    """
    def __init__(self, optims, initial, decay_fac,
                 decay_interval_itr=None, decay_interval_epoch=None, epoch_len=None,
                 warm_restart=None,
                 warm_restart_schedule=None):
        """
        :param optims: optimizers to update
        :param initial: learning rate before any decay
        :param decay_fac: factor applied once per decay interval
        :param decay_interval_itr: decay interval in iterations; exclusive with decay_interval_epoch
        :param decay_interval_epoch: decay interval in epochs, needs epoch_len
        :param epoch_len: iterations per epoch
        :param warm_restart: restart period; converted from epochs to iterations when the decay
            interval is given in epochs
        :param warm_restart_schedule: schedule whose parameters are adopted at a warm restart
        """
        super().__init__(optims)
        has_itr = decay_interval_itr is not None
        has_epoch = decay_interval_epoch is not None
        assert_exc(has_itr ^ has_epoch, 'Need either iter or epoch')
        if decay_interval_epoch:
            assert epoch_len is not None
            decay_interval_itr = int(decay_interval_epoch * epoch_len)
            if warm_restart:
                warm_restart = int(warm_restart * epoch_len)

        self.initial = initial
        self.decay_fac = decay_fac
        self.decay_every_itr = decay_interval_itr

        self.warm_restart_itr = warm_restart
        self.warm_restart_schedule = warm_restart_schedule

        self.last_warm_restart = 0

    def _get_lr(self, i):
        """ :returns the learning rate for iteration `i`, performing a warm restart if one is due """
        restart_due = (i > 0 and self.warm_restart_itr
                       and (i - self.last_warm_restart) % self.warm_restart_itr == 0)
        if restart_due:
            if i != self.last_warm_restart:
                self._warm_restart()
            self.last_warm_restart = i
        elapsed = i - self.last_warm_restart
        return self.initial * (self.decay_fac ** (elapsed // self.decay_every_itr))

    def _warm_restart(self):
        """ Adopt the parameters of the follow-up schedule, if there is one. """
        print('WARM restart')
        follow_up = self.warm_restart_schedule
        if follow_up:
            self.initial = follow_up.initial
            self.decay_fac = follow_up.decay_fac
            self.decay_every_itr = follow_up.decay_every_itr
            self.warm_restart_itr = follow_up.warm_restart_itr
            self.warm_restart_schedule = follow_up.warm_restart_schedule
123 |
124 |
class CosineDecayLRSchedule(LRSchedule):
    """
    Cosine decay from lrmax towards lrmin over a period of Ti iterations. The position within
    the period is `i % Ti`, so the schedule starts over at lrmax every Ti iterations.
    """
    def __init__(self, optims, lrmax, lrmin, T_itr, T_epoch, epoch_len):
        """
        :param optims: optimizers to update
        :param lrmax: learning rate at the start of a period
        :param lrmin: learning rate approached at the end of a period
        :param T_itr: period in iterations, or None to derive it from T_epoch
        :param T_epoch: period in epochs, only read if T_itr is None
        :param epoch_len: iterations per epoch
        """
        super().__init__(optims)
        self.lrmax = lrmax
        self.lrmin = lrmin
        if T_itr is None:
            assert epoch_len is not None
            T_itr = int(T_epoch * epoch_len)
        self.Ti = T_itr
        self.epoch_len = epoch_len

    def _get_lr(self, i):
        """ :returns lrmin + (lrmax - lrmin) * cos(pi * (i % Ti) / (2 * Ti)) """
        phase = (i % self.Ti) / (2 * self.Ti)
        amplitude = self.lrmax - self.lrmin
        return self.lrmin + amplitude * np.cos(np.pi * phase)
139 |
140 |
def _parse_exp_spec(s, optims, initial_lr, epoch_len):
    """
    Parse the part of an `exp` spec that follows the schedule kind.

    Accepted forms (fields joined by SPEC_SEP):
        FAC (iITR|eEPOCH)
        FAC (iITR|eEPOCH) warm WARM_START WARM_FAC (iITR|eEPOCH)

    :returns an ExponentialDecayLRSchedule; for the second form it carries a nested schedule
    built from WARM_FAC and the last interval.
    """
    fields = s.split(SPEC_SEP)
    if len(fields) > 3:
        fac, interval, warm, warm_start, warm_fac, warm_interval = fields
        assert warm == 'warm'
        warm_start = int(warm_start)
        warm_spec = SPEC_SEP.join([warm_fac, warm_interval])
        warm_schedule = _parse_exp_spec(warm_spec, optims, initial_lr, epoch_len)
    else:
        fac, interval = fields
        warm_start = None
        warm_schedule = None

    # first character selects the unit: i = iterations, e = epochs
    unit = interval[0]
    amount = interval[1:]
    assert_exc(unit in ('i', 'e'), 'Invalid spec: {}'.format(s))
    decay_interval_itr = int(amount) if unit == 'i' else None
    decay_interval_epoch = float(amount) if unit == 'e' else None

    return ExponentialDecayLRSchedule(
            optims, initial_lr, float(fac), decay_interval_itr, decay_interval_epoch, epoch_len,
            warm_restart=warm_start, warm_restart_schedule=warm_schedule)
157 |
158 |
--------------------------------------------------------------------------------
/src/train/train_restorer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 | """
19 | import os
20 |
21 | from helpers import logdir_helpers
22 | from helpers.paths import get_ckpts_dir, get_experiment_dir, CKPTS_DIR_NAME
23 | from helpers.saver import Restorer
24 |
25 |
class TrainRestorer(Restorer):
    """ Restorer that also stores how training should resume from the restored checkpoint. """

    @staticmethod
    def from_flags(restore_dir, log_dir, restore_continue, restore_itr,
                   restart_at_zero=False, strict='y'):
        """
        Create a TrainRestorer from command line flags.
        :param restore_dir: experiment to restore, or None if nothing should be restored
        :param log_dir: root directory used to resolve restore_dir
        :param restore_continue: stored on the instance
        :param restore_itr: iteration to restore
        :param restart_at_zero: stored on the instance
        :param strict: 'y' or 'n', any other value raises KeyError
        :return: a TrainRestorer, or None if restore_dir is None
        """
        if restore_dir is None:
            return None
        strict_flag = {'y': True, 'n': False}[strict]
        experiment_dir = get_experiment_dir(log_dir, restore_dir)
        ckpts_dir = get_ckpts_dir(experiment_dir)
        return TrainRestorer(ckpts_dir, restore_continue, restore_itr, restart_at_zero, strict_flag)

    def __init__(self, out_dir, restore_continue=False, restore_itr=-1, restart_at_zero=False, strict=True,
                 ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp'):
        """
        :param out_dir: checkpoint directory, must end in CKPTS_DIR_NAME
        :param restore_continue: stored as self.restore_continue
        :param restore_itr: iteration passed to get_ckpt_for_itr when restoring
        :param restart_at_zero: forwarded as `restore_restart` when restoring
        :param strict: forwarded to `restore`
        :param ckpt_name_fmt: forwarded to Restorer
        :param tmp_postfix: forwarded to Restorer
        """
        assert out_dir.rstrip(os.path.sep).endswith(CKPTS_DIR_NAME), out_dir
        super().__init__(out_dir, ckpt_name_fmt, tmp_postfix)
        self.restore_continue = restore_continue
        self.restore_itr = restore_itr
        self.restart_at_zero = restart_at_zero
        self.strict = strict

    def restore_desired_ckpt(self, modules):
        """ Restore `modules` from the checkpoint selected by self.restore_itr. """
        found_itr, ckpt_path = self.get_ckpt_for_itr(self.restore_itr)
        print('Restoring {}...'.format(found_itr))
        return self.restore(modules, ckpt_path, self.strict, restore_restart=self.restart_at_zero)

    def get_log_dir(self):
        """ :returns the parent of the checkpoint directory, which must be a log dir """
        parent_dir = os.path.dirname(self._out_dir)
        assert logdir_helpers.is_log_dir(parent_dir)
        return parent_dir
61 |
--------------------------------------------------------------------------------
/src/train/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | General Trainer class, subclassed by MultiscaleTrainer
22 |
23 | """
24 |
25 | from collections import namedtuple
26 |
27 | import torch
28 |
29 | import torchvision
30 | from fjcommon import timer
31 | from fjcommon.no_op import NoOp
32 |
33 | from helpers import logdir_helpers
34 | import vis.safe_summary_writer
35 | from helpers.saver import Saver
36 | from helpers.global_config import global_config
37 | import itertools
38 |
39 | from train.train_restorer import TrainRestorer
40 |
41 |
# Logging intervals used by Trainer.train:
#   log_train:       train_step is asked to log when i % log_train == 0
#   log_val:         validation runs when i % log_val == 0
#   log_train_heavy: heavy logs are requested when (i / log_train_heavy) % log_train == 0
LogConfig = namedtuple('LogConfig', ['log_train', 'log_val', 'log_train_heavy'])
43 |
44 |
class TimedIterator(object):
    """ Wraps an iterable and accumulates, in `self.t`, the time spent fetching each element. """
    def __init__(self, it):
        self.t = timer.TimeAccumulator()
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        # only the fetch from the wrapped iterator is timed
        with self.t.execute():
            item = next(self.it)
        return item
56 |
57 |
58 | # TODO: allow "restart last epoch" or sth
class TrainingSetIterator(object):
    """ Iterates over the training set, optionally skipping ahead to iteration `skip_to_itr`. """
    def __init__(self, skip_to_itr, dl_train):
        self.skip_to_itr = skip_to_itr
        self.dl_train = dl_train
        self.epoch_len = len(self.dl_train)

    def epochs_to_skip(self):
        """ :returns tuple (number of full epochs to skip, number of batches to skip after that) """
        if not self.skip_to_itr:
            return 0, 0
        return divmod(self.skip_to_itr, self.epoch_len)

    def iterator(self, epoch):
        """ :returns an iterator over tuples (itr, batch) """
        skip_epochs, skip_batches = self.epochs_to_skip()
        if epoch < skip_epochs:
            print('Skipping epoch {}'.format(epoch))
            return []  # nothing to iterate

        first_itr = epoch * len(self.dl_train)
        # from here on epoch >= skip_epochs, so only a partially skipped epoch needs special care
        if epoch > skip_epochs or skip_batches == 0:
            return enumerate(self.dl_train, first_itr)

        # first epoch that is not skipped entirely: consume the first `skip_batches` batches
        batches = iter(self.dl_train)
        for dropped in range(skip_batches):
            print('\rDropping batch {: 10d}...'.format(dropped), end='')
            if not global_config.get('drop_batches', False):
                # the batch is loaded and thrown away
                next(batches)
        print(' -- dropped {} batches'.format(skip_batches))
        return enumerate(batches, first_itr + skip_batches)
90 |
91 |
class AbortTrainingException(Exception):
    """ Raised to stop training; Trainer.train catches it, prints it and returns. """
94 |
95 |
class Trainer(object):
    """
    Generic training loop. Subclasses implement `modules_to_save`, `train_step` and
    `validation_loop`, and may override `prepare_for_epoch`.
    """
    def __init__(self, dl_train, dl_val, optims, net, sw: vis.safe_summary_writer.SafeSummaryWriter,
                 max_epochs, log_config: LogConfig, saver: Saver=None, skip_to_itr=None):
        """
        :param dl_train: training data loader, iterated once per epoch
        :param dl_val: validation data loader, stored for subclasses
        :param optims: list of optimizers; their gradients are zeroed before every train step
        :param net: network, switched between eval and train mode around validation
        :param sw: summary writer, used for filter images
        :param max_epochs: number of epochs to train; a falsy value trains indefinitely
        :param log_config: logging intervals
        :param saver: object whose `save(modules, i)` is called after every train step
        :param skip_to_itr: iteration to skip to, or None to start at 0
        """
        assert isinstance(optims, list)

        self.dl_train = dl_train
        self.dl_val = dl_val
        self.optims = optims
        self.net = net
        self.sw = sw
        self.max_epochs = max_epochs
        self.log_config = log_config
        # NOTE(review): train.py instantiates no_op.NoOp() while the NoOp name itself is used
        # here -- confirm that `NoOp.save(...)` works without an instance.
        self.saver = saver if saver is not None else NoOp

        self.skip_to_itr = skip_to_itr

    def continue_from(self, ckpt_dir):
        """ No-op hook. """
        pass

    def train(self):
        """
        Run the training loop until max_epochs is reached or an AbortTrainingException is raised.
        """
        log_train, log_val, log_train_heavy = self.log_config

        dl_train_it = TrainingSetIterator(self.skip_to_itr, self.dl_train)

        _print_unused_global_config()

        try:
            for epoch in (range(self.max_epochs) if self.max_epochs else itertools.count()):
                self.print_epoch_sep(epoch)
                self.prepare_for_epoch(epoch)
                t = TimedIterator(dl_train_it.iterator(epoch))
                for i, img_batch in t:
                    for o in self.optims:
                        o.zero_grad()
                    should_log = (i > 0 and i % log_train == 0)
                    # true iff i is a positive multiple of log_train * log_train_heavy
                    should_log_heavy = (i > 0 and (i / log_train_heavy) % log_train == 0)
                    self.train_step(i, img_batch,
                                    log=should_log,
                                    log_heavy=should_log_heavy,
                                    load_time='[{:.2e} s/batch load]'.format(t.t.mean_time_spent()) if should_log else None)
                    self.saver.save(self.modules_to_save(), i)

                    if i > 0 and i % log_val == 0:
                        self._eval(i)
        except AbortTrainingException as e:
            print('Caught {}'.format(e))
            return

    def _eval(self, i):
        """ Run validation_loop(i) in eval mode and without gradients, then return to train mode. """
        self.net.eval()
        with torch.no_grad():
            self.validation_loop(i)
        self.net.train()

    def debug(self):
        """ Run a single train step with all logging enabled, followed by one validation pass. """
        print('Debug ---')
        _print_unused_global_config()
        self.prepare_for_epoch(0)
        self.train_step(0, next(iter(self.dl_train)),
                        log=True, log_heavy=True, load_time=0)
        self._eval(101)

    def print_epoch_sep(self, epoch):
        """ Print a separator announcing `epoch`. """
        print('-' * 80)
        print(' EPOCH {}'.format(epoch))
        print('-' * 80)

    def modules_to_save(self):
        """ used to save and restore. Should return a dictionary module_name -> nn.Module """
        raise NotImplementedError()

    def train_step(self, i, img_batch, log, log_heavy, load_time=None):
        """ Perform one optimization step on `img_batch` for iteration `i`. """
        raise NotImplementedError()

    def validation_loop(self, i):
        """ Run validation for iteration `i`. Called in eval mode under torch.no_grad(). """
        raise NotImplementedError()

    def prepare_for_epoch(self, epoch):
        """ Hook called at the start of every epoch. No-op by default. """
        pass

    def add_filter_summaray(self, tag, p, global_step):
        """
        Add an image summary for parameter `p`. Biases (1D) are expanded to a CxCx1x1 tensor
        first. Only 1x1 kernels are visualized; parameters that are not 4D are skipped, with an
        INFO message at global_step 0.
        """
        if len(p.shape) == 1:  # bias
            C, = p.shape
            p = p.reshape(C, 1).expand(C, C).reshape(C, C, 1, 1)

        try:
            _, _, H, W = p.shape
        except ValueError:
            if global_step == 0:
                print('INFO: Cannot unpack {} ({})'.format(p.shape, tag))
            return

        if H == W == 1:  # 1x1 conv
            p = p[:, :, 0, 0]
            filter_vis = torchvision.utils.make_grid(p, normalize=True)
            self.sw.add_image(tag, filter_vis, global_step)

    @staticmethod
    def update_lrs(epoch, optims_lrs_facs_interval):
        """ Deprecated, always raises DeprecationWarning. Use lr_schedule.py instead. """
        raise DeprecationWarning('use lr_schedule.py')

    @staticmethod
    def exp_lr_scheduler(optimizer, epoch, init_lr, decay_fac=0.1, interval_epochs=7):
        """ Deprecated, always raises DeprecationWarning. Use lr_schedule.py instead. """
        raise DeprecationWarning('use lr_schedule.py')

    def get_lrs(self):
        """ Yield the current learning rate of every parameter group of every optimizer. """
        for optim in self.optims:
            for param_group in optim.param_groups:
                yield param_group['lr']

    def maybe_restore(self, restorer: TrainRestorer):
        """
        Restore the modules returned by modules_to_save() if a restorer is given.
        :return: skip_to_itr -- None without restorer, 0 if restorer.restart_at_zero, otherwise
        the value returned by the restorer
        """
        if restorer is None:
            return None  # start from 0
        restore_itr = restorer.restore_desired_ckpt(self.modules_to_save())  # TODO: allow arbitrary ckpts
        if restorer.restart_at_zero:
            return 0
        return restore_itr

    @staticmethod
    def get_log_dir(log_dir_root, rel_paths, restorer, strip_ext='.cf'):
        """
        :return: the log dir of `restorer` if it is set to continue, otherwise a newly created
        unique log dir in `log_dir_root`
        """
        if not restorer or not restorer.restore_continue:
            log_dir = logdir_helpers.create_unique_log_dir(
                    rel_paths, log_dir_root, strip_ext=strip_ext, postfix=global_config.values())
            print('Created {}...'.format(log_dir))
        else:
            log_dir = restorer.get_log_dir()
            print('Using {}...'.format(log_dir))
        return log_dir
236 |
237 |
def _print_unused_global_config(ignore=None):
    """
    For safety, reject parameters that were passed with -p but never used during construction
    of graph: raises ValueError listing them, unless they are contained in `ignore`.
    """
    ignored = ignore if ignore else []
    unused = [name for name in global_config.get_unused_params() if name not in ignored]
    if not unused:
        return
    raise ValueError('Unused params:\n- ' + '\n- '.join(unused))
245 |
246 |
--------------------------------------------------------------------------------
/src/vis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fab-jul/L3C-PyTorch/469d43b74583976895923138145e0bf4436e5dc9/src/vis/__init__.py
--------------------------------------------------------------------------------
/src/vis/figure_plotter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 | """
19 | import torch
20 |
21 | from sys import platform
22 | import matplotlib as mpl
23 | if platform != 'darwin':
24 | mpl.use('Agg') # No display
25 | import matplotlib.pyplot as plt
26 | import matplotlib.backends.backend_agg as plt_backend_agg
27 | import numpy as np
28 |
29 |
class PlotToArray(object):
    """
    Render a matplotlib figure into an array.

    p = PlotToArray()
    plt = p.prepare()
    # add plot to plt
    im = p.get_numpy()  # CHW
    """
    def __init__(self):
        self.fig = None

    def prepare(self):
        """ Create a new figure with dpi=100, remember it, and return the pyplot module. """
        new_figure = plt.figure(dpi=100)
        self.fig = new_figure
        return plt

    def get_numpy(self):
        """ :returns the figure created by prepare() as CHW array """
        assert self.fig is not None
        rendered = figure_to_image(self.fig)
        return rendered

    def get_tensor(self):
        """ :returns the figure created by prepare() as CHW tensor """
        as_array = self.get_numpy()
        return torch.from_numpy(as_array)
50 |
51 |
def figure_to_image(figures, close=True):
    """
    Render one figure, or a list of figures, to RGB.
    :param figures: a figure or a list of figures
    :param close: forwarded to _render_to_rgb
    :return: a single rendered image, or the stacked images if a list was given
    """
    if isinstance(figures, list):
        return np.stack([_render_to_rgb(single, close) for single in figures])
    return _render_to_rgb(figures, close)
59 |
def _render_to_rgb(figure, close):
    """
    Draw `figure` on an Agg canvas and return its pixels as CHW uint8 array without alpha.
    :param close: if True, close the figure after rendering
    """
    agg_canvas = plt_backend_agg.FigureCanvasAgg(figure)
    agg_canvas.draw()
    flat_rgba = np.frombuffer(agg_canvas.buffer_rgba(), dtype=np.uint8)
    width, height = figure.canvas.get_width_height()
    # drop the alpha channel, then move channels to the front
    rgb_hwc = flat_rgba.reshape([height, width, 4])[..., :3]
    rgb_chw = np.moveaxis(rgb_hwc, source=2, destination=0)
    if close:
        plt.close(figure)
    return rgb_chw
70 |
--------------------------------------------------------------------------------
/src/vis/grid.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 | """
19 | import functools
20 |
21 | import torch
22 | from fjcommon import functools_ext as ft
23 | from torch.nn import functional as F
24 |
25 |
26 | # TODO: document thoroughly
def prep_for_grid(x, pad_to=None, channelwise=False, insert_empty_indices=None):
    """
    Convert a tensor, or a list/tuple of tensors, into a flat list of CHW tensors.

    For a tensor: HW and NHW inputs are expanded to NCHW, then only the first batch element is
    kept. If `pad_to` is given, all four sides are padded by (pad_to - width) // 2.

    For a list/tuple: every element is converted as above and the results are concatenated.
    If `pad_to` is None, it is set to the maximum width of the elements.

    :param x: tensor (HW, NHW or NCHW) or list/tuple of such
    :param pad_to: target width, or None
    :param channelwise: if True, return one 1xHxW tensor per channel instead of one CHW tensor;
        not allowed for NHW inputs
    :param insert_empty_indices: list of indices at which an all-zero tensor shaped like x[0]
        is inserted; requires `x` to be a list. The list passed by the caller is not modified.
    :return: list of tensors
    """
    if insert_empty_indices is not None:
        assert isinstance(insert_empty_indices, list)
        assert isinstance(x, list), 'Need list for insert_empty_indices, got {}'.format(x)
    if isinstance(x, (tuple, list)):
        if insert_empty_indices:
            # work on a copy so that the caller's list does not grow
            x = list(x)
            some_x = x[0]
            # insert from the back so that earlier insertions do not shift later indices
            for idx in sorted(insert_empty_indices, reverse=True):
                x.insert(idx, torch.zeros_like(some_x))
        if pad_to is None:  # we are given a list of tensors, they must be padded!
            pad_to = max(el.size()[-1] for el in x)  # maximum width
        return [prepared
                for el in x
                for prepared in prep_for_grid(el, pad_to=pad_to, channelwise=channelwise,
                                              insert_empty_indices=None)]

    if x.dim() == 2:  # HW
        x = x.unsqueeze(0)  # NHW
    if x.dim() == 3:  # NHW
        assert not channelwise
        x = x.unsqueeze(1)  # NCHW
    assert x.dim() == 4, "Expected NCHW"
    x = x[0, ...]  # now: CHW
    if pad_to:
        w = x.size()[-1]
        pad = (pad_to - w) // 2
        if pad:
            x = F.pad(x, (pad, pad, pad, pad))
    if channelwise:
        return [c.unsqueeze(0) for c in torch.unbind(x, 0)]
    return [x]
58 |
--------------------------------------------------------------------------------
/src/vis/histogram_plot.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see .
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | Class to create histograms of tensors.
22 |
23 | """
24 |
25 | from helpers import rolling_buffer
26 | from vis.summarizable_module import SummarizableModule
27 |
28 |
class HistogramPlot(SummarizableModule):
    """ Pass-through module that buffers its inputs during training and plots their histogram. """
    def __init__(self, prefix, name, buffer_size, num_inputs_to_buffer=1, per_channel=False, most_mass=5e-5):
        """
        :param prefix: prefix under which the figure is registered with the summarizer
        :param name: Name in TensorBoard
        :param buffer_size: buffer size
        :param num_inputs_to_buffer: x[:num_inputs_to_buffer, ...] will be stored only
        :param per_channel: if True, create a histo per channel
        :param most_mass: forwarded to the buffer's `plot`
        """
        super().__init__()
        self.prefix = prefix
        self.name = name
        self.buffer_size = buffer_size
        self.num_inputs_to_buffer = num_inputs_to_buffer
        self.per_channel = per_channel
        self.most_mass = most_mass
        self.buffers = None
        self.num_chan = None  # non-None if self.per_channel and forward has been called at least once
        self.figure_creator = {name: self._plot}

    def forward(self, x):
        """
        :param x: Tensor
        :returns: x
        """
        if not self.training:  # only during training
            return x
        if self.buffers is None:  # created lazily, on the first training input
            self.buffers = self._new_buffers(x.detach())
        batch_limit = self.num_inputs_to_buffer
        if not self.per_channel:
            self.buffers.add(x[:batch_limit, ...].detach())
        else:
            for chan in range(self.num_chan):
                self.buffers[chan].add(x[:batch_limit, chan, ...].detach())
        # register for plotting
        self.summarizer.register_figures(self.prefix, self.figure_creator)
        return x

    def _plot(self, plt):
        """ Called when summarizer decides to plot. """
        for buf in self._iter_over_buffers():
            xs, ys = buf.plot(bins=128, most_mass=self.most_mass)
            plt.plot(xs, ys)

    def _iter_over_buffers(self):
        """ Yield every buffer: the single one, or one per channel. """
        if self.per_channel:
            yield from self.buffers
        else:  # self.buffers just a RollingBuffer instance
            yield self.buffers

    def _new_buffers(self, x):
        """ :returns one buffer, or a list with one buffer per channel of `x` if per_channel """
        if self.per_channel:
            self.num_chan = x.shape[1]
            return [rolling_buffer.RollingBufferHistogram(self.buffer_size, self.name)
                    for _ in range(self.num_chan)]
        return rolling_buffer.RollingBufferHistogram(self.buffer_size, self.name)
--------------------------------------------------------------------------------
/src/vis/histogram_plotter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 | """
19 | import numpy as np
20 |
21 |
# Total width available at each bar position; plot_histogram divides it equally among the plotted series.
_WIDTH = 0.8
23 |
24 |
25 | # TODO: rename to something with bar plot
26 |
27 |
def plot_histogram(datas, plt):
    """
    Draw one bar series per entry of `datas` next to each other and add a legend.

    :param datas: sequence of (name, values) pairs, each one is handed to _plot_histogram
    :param plt: object to draw on, must provide what _plot_histogram needs plus `legend()`
    """
    num_series = len(datas)
    for idx, series in enumerate(datas):
        # computed inside the loop on purpose: no division happens for empty `datas`
        bar_width = _WIDTH / num_series
        # shifting by half the number of series centers the group of bars
        shift = (idx - num_series / 2) * bar_width
        _plot_histogram(series, plt, bar_width, shift)
    plt.legend()
34 |
35 |
36 | def _plot_histogram(data, plt, width, offset):
37 | name, values = data
38 | plt.bar(np.arange(len(values)) + offset, values,
39 | width=width,
40 | label=name, align='edge')
41 |
42 |
def _test():
    """ Manual smoke test: draws three example series into a new figure and shows it. """
    import matplotlib.pyplot as plt

    fig = plt.figure()
    outs = [900, 20, 0, 0, 100, 1000]
    plot_histogram([('gt', [1000, 10, 33, 500, 600, 700]),
                    ('outs', outs),
                    ('ups', 0.5 * np.array(outs))],
                   plt)
    fig.show()


if __name__ == '__main__':
    _test()
57 |
--------------------------------------------------------------------------------
/src/vis/image_summaries.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 | """
19 | import numpy as np
20 | import pytorch_ext as pe
21 | from fjcommon.assertions import assert_exc
22 |
23 |
def imshow(img):
    """
    Show an image in a matplotlib window. Only meant for local visualizations.

    :param img: uint8 np.ndarray with 3 dimensions. Expected as HW3; if the first dimension is 3, the array is
    treated as 3HW and transposed to HW3 before it is shown.
    """
    import matplotlib.pyplot as plt
    assert isinstance(img, np.ndarray)
    assert img.dtype == np.uint8
    assert img.ndim == 3
    channels_first = img.shape[0] == 3
    if channels_first:
        img = np.transpose(img, (1, 2, 0))
    assert img.shape[2] == 3, img.shape
    plt.imshow(img)
    plt.show()
35 |
36 |
def to_image(t):
    """
    Convert an image-like tensor or array to a HW3 uint8 array.

    :param t: tensor or np.ndarray, may be of shape NCHW / CHW with C=1 or 3 / HW, dtype float32 or uint8.
    If float32: must be in [0, 1]. If NCHW: only the first entry of the batch is used.
    :return: HW3 uint8 np.ndarray
    """
    arr = t if isinstance(t, np.ndarray) else pe.tensor_to_np(t)

    # - step 1: bring arr to 3 dimensions, CHW
    if arr.ndim == 4:
        arr = arr[0, ...]  # drop batch dimension, only the first image is kept
    elif arr.ndim == 2:
        arr = np.expand_dims(arr, 0)  # HW -> 1HW
    assert_exc(arr.ndim == 3, 'Invalid shape: {}'.format(arr.shape))

    # - step 2: bring arr to uint8
    if arr.dtype != np.uint8:
        assert_exc(arr.dtype == np.float32, 'Expected either uint8 or float32, got {}'.format(arr.dtype))
        _check_range(arr, 0, 1)
        arr = (arr * 255.).astype(np.uint8)

    # - step 3: bring arr to HW3
    num_channels = arr.shape[0]
    if num_channels not in (1, 3):
        raise ValueError('Expected CHW, got {}'.format(arr.shape))
    if num_channels == 1:
        gray = arr[0, :, :]
        arr = np.stack([gray, gray, gray], -1)  # replicate the single channel
    else:
        arr = np.transpose(arr, (1, 2, 0))
    assert_exc(arr.ndim == 3 and arr.shape[2] == 3, str(arr.shape))
    return arr
68 |
69 |
def _check_range(a, lo, hi):
    """ Check via assert_exc that all values of array `a` are in [lo, hi]. """
    smallest = np.min(a)
    largest = np.max(a)
    in_range = smallest >= lo and largest <= hi
    assert_exc(in_range, 'Invalid range: [{}, {}]. Expected: [{}, {}]'.format(smallest, largest, lo, hi))
73 |
74 |
75 |
--------------------------------------------------------------------------------
/src/vis/safe_summary_writer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 | """
19 | from contextlib import contextmanager
20 |
21 | from tensorboardX import SummaryWriter
22 | from tensorboardX.summary import Summary, _clean_tag, make_image as make_image_summary
23 |
24 | from vis.figure_plotter import PlotToArray
25 | from vis.image_summaries import to_image
26 |
27 |
class SafeSummaryWriter(SummaryWriter):
    """ SummaryWriter whose images are converted with `to_image` before they are written. """

    def get_events_file_path(self):
        """ :return: the file prefix of the underlying events writer, or None if it cannot be determined. """
        try:
            path = self.file_writer.event_writer._ev_writer._file_prefix
        except AttributeError:
            print('Cannot get events file name...')
            path = None
        return path

    @staticmethod
    def pre(prefix, tag):
        """ Join prefix and tag with exactly one '/' in between. prefix must not start with '/'. """
        assert prefix[0] != '/'
        return '/'.join((prefix.rstrip('/'), tag.lstrip('/')))

    def add_image(self, tag, img_tensor, global_step=None, **kwargs):
        """
        Add image img_tensor to summary.
        img_tensor can be np.ndarray or torch.Tensor, 1HW or 3HW or HW
        If img_tensor is uint8:
            add it like it is
        If img_tensor is float32:
            check that it is in [0, 1] and add it
        :param **kwargs: accepted but not used
        """
        shape = img_tensor.shape
        if len(shape) == 2:  # HW -> 1HW
            img_tensor = img_tensor.reshape(1, *shape)
        summary = SafeSummaryWriter._to_image_summary_safe(tag, img_tensor)
        self.file_writer.add_summary(summary, global_step)

    @contextmanager
    def add_figure_ctx(self, tag, global_step=None):
        """
        Context manager that yields a plt to draw on. When the with-block is left, the drawing is converted to an
        array and added as image `tag`.
        """
        plotter = PlotToArray()
        yield plotter.prepare()  # user draws
        self.add_image(tag, plotter.get_numpy(), global_step)

    @staticmethod
    def _to_image_summary_safe(tag, tensor):
        """ Build an image Summary for `tensor`, which is passed through `to_image` first. """
        clean_tag = _clean_tag(tag)
        image = make_image_summary(to_image(tensor))
        return Summary(value=[Summary.Value(tag=clean_tag, image=image)])
69 |
70 |
--------------------------------------------------------------------------------
/src/vis/summarizable_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2019, ETH Zurich
3 |
4 | This file is part of L3C-PyTorch.
5 |
6 | L3C-PyTorch is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | any later version.
10 |
11 | L3C-PyTorch is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with L3C-PyTorch. If not, see https://www.gnu.org/licenses/.
18 |
19 | --------------------------------------------------------------------------------
20 |
21 | Contains the neat SummarizableModule class. It's a replacement for nn.Module. If in a tree of nn.Modules, the root
22 | is a SummarizableModule and some leaves are also, you can call `register_summarizer` on the root, and this will add
23 | an instance of `Summarizer` to every SummarizableModule in the tree.
24 | Then, the module can call stuff like:
25 |
26 | self.summarizer.register_scalars('train', {'lr': self.lr})
27 |
28 | def _plot(plt):
29 | plt.plot(x, y)
30 |
31 | self.summarizer.register_figures('val', {'niceplot': _plot})
32 |
33 | """
34 | from contextlib import contextmanager
35 |
36 | import torch
37 | from fjcommon.assertions import assert_exc
38 | from fjcommon.no_op import NoOp
39 | from torch import nn as nn
40 |
41 |
42 | class _GlobalStepDependable(object):
43 | def __init__(self):
44 | self.global_step = None
45 | self.enabled_prefix = None
46 |
47 | def enable(self, prefix, global_step):
48 | """ Enable logging of prefix """
49 | assert_exc(isinstance(prefix, str), 'prefix must be str, got {}'.format(prefix))
50 | assert_exc(prefix[-1] != '/')
51 | self.enabled_prefix = prefix
52 | self.global_step = global_step
53 |
54 | def disable(self):
55 | self.enabled_prefix = None
56 |
57 | @contextmanager
58 | def maybe_enable(self, prefix, flag, global_step):
59 | if flag:
60 | self.enable(prefix, global_step)
61 | yield
62 | self.disable()
63 |
64 |
def normalize_to_0_1(t):
    """
    Shift and scale t such that its minimum becomes 0 and its maximum becomes (almost) 1. The 1e-5 in the
    denominator avoids a division by zero for constant t.
    """
    lo, hi = t.min(), t.max()
    return (t - lo) / (hi - lo + 1e-5)
67 |
68 |
class Summarizer(_GlobalStepDependable):
    """
    Forwards scalars, figures and images to the summary writer `sw`, but only while a prefix is enabled and the
    prefix given by the caller matches it (or is 'auto').
    """

    def __init__(self, sw):
        super().__init__()
        self.sw = sw

    def _active_prefix(self, prefix):
        """
        :return: the prefix to log under, or None if nothing should be logged, i.e., if no prefix is enabled or if
        `prefix` is neither 'auto' nor the enabled prefix.
        """
        enabled = self.enabled_prefix
        if enabled is None:
            return None
        if prefix == 'auto':
            return enabled
        return prefix if prefix == enabled else None

    def register_scalars(self, prefix, values):
        """
        :param prefix: Prefix to use in TensorBoard
        :param values: A dictionary of name -> value, where value can be callable (useful if it is expensive)
        """
        target = self._active_prefix(prefix)
        if target is None:
            return
        for name, value in values.items():
            self.sw.add_scalar(target + '/' + name, _convert_if_callable(value), self.global_step)

    def register_figures(self, prefix, creators):
        """
        :param prefix: Prefix to use in TensorBoard
        :param creators: plot_name -> (plt -> None)
        """
        target = self._active_prefix(prefix)
        if target is None:
            return
        for name, creator in creators.items():
            with self.sw.add_figure_ctx(target + '/' + name, self.global_step) as plt:
                creator(plt)

    def register_images(self, prefix, imgs, normalize=False):  # , early_only=False):
        """
        :param prefix: Prefix to use in TensorBoard
        :param imgs: A dictionary of name -> img, where img can be callable (useful if it is expensive)
        :param normalize: If given, will normalize imgs to [0,1]
        """
        target = self._active_prefix(prefix)
        if target is None:
            return
        for name, img in imgs.items():
            img = _convert_if_callable(img)
            if normalize:
                img = normalize_to_0_1(img)
            self.sw.add_image(target + '/' + name, img, self.global_step)
117 |
118 |
119 | def _convert_if_callable(v):
120 | if hasattr(v, '__call__'): # python 3 only
121 | return v()
122 | return v
123 |
124 |
class SummarizableModule(nn.Module):
    """
    nn.Module with a `summarizer` attribute. The attribute is `NoOp` until `register_summarizer` is called on this
    module or on a SummarizableModule containing it.
    """

    def __init__(self):
        super().__init__()
        self.summarizer = NoOp

    def forward(self, *input):
        """ To be implemented by subclasses. """
        raise NotImplementedError

    def register_summarizer(self, summarizer: Summarizer):
        """ Set `summarizer` on every SummarizableModule in the module tree rooted at self (self included). """
        for module in iter_modules_of_class(self, SummarizableModule):
            module.summarizer = summarizer
136 |
137 |
def iter_modules_of_class(root_module: nn.Module, cls):
    """
    Yield every module returned by root_module.modules() that is an instance of cls.

    Helpful for extending nn.Module. How to use:
    1. define new nn.Module subclass with some new instance methods, cls
    2. make your root module inherit from cls
    3. make some leaf module inherit from cls
    """
    yield from (module for module in root_module.modules() if isinstance(module, cls))
148 |
149 |
150 | # Tests ------------------------------------------------------------------------
151 |
152 |
def test_submodules():
    """ Checks that iter_modules_of_class finds instances of a class at every nesting level of a module tree. """

    class _T(nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = []

        def register_foo(self, f):
            self.foo.append(f)

        def get_all(self):
            for sub_t in iter_modules_of_class(self, _T):
                yield from sub_t.foo

    class _SomethingWithTs(nn.Module):
        def __init__(self):
            super().__init__()
            self.a_t = _T()

    class _M(_T):  # _T number 1: the root itself
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 2, 3)
            self.t = _T()  # _T number 2
            self.t.register_foo(1)
            self.list = nn.ModuleList([
                nn.Conv2d(1, 2, 3),
                _T(),  # _T number 3
            ])
            inner = _T()
            self.seq = nn.Sequential(
                nn.Conv2d(1, 2, 3),
                inner,  # _T number 4
                _SomethingWithTs(),  # contains _T number 5
            )
            inner.register_foo(2)

    root = _M()
    all_ts = list(iter_modules_of_class(root, _T))
    assert len(all_ts) == 5, all_ts
    assert list(root.get_all()) == [1, 2]
191 |
--------------------------------------------------------------------------------