├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── colabs
│   └── OpticalFlow-Inference.ipynb
├── examples
│   ├── davis
│   │   ├── frame_0000.png
│   │   ├── frame_0001.png
│   │   ├── output_pwc_it_ft.png
│   │   └── output_pwc_it_pre.png
│   ├── examples_pwc_it_ft.png
│   ├── middlebury_dogdance
│   │   ├── frame_0007.png
│   │   ├── frame_0008.png
│   │   ├── output_pwc_it_ft.png
│   │   └── output_pwc_it_pre.png
│   └── teaser.png
├── requirements.txt
└── src
    └── dataset_lib
        └── augmentations
            ├── __init__.py
            ├── aug_params.py
            ├── augmentations.py
            ├── color_aug.py
            ├── crop_aug.py
            ├── image_aug.py
            ├── pwc_augmentation.py
            ├── raft_augmentation.py
            ├── simple_augmentation.py
            └── spatial_aug.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
7 |
8 | [Disentangling Architecture and Training for Optical Flow](https://arxiv.org/pdf/2203.10712v1.pdf)
9 | [Deqing Sun](https://deqings.github.io/)\*<sup>T</sup>, [Charles Herrmann](https://scholar.google.com/citations?user=LQvi5XAAAAAJ&hl=en)\*, [Fitsum Reda](https://fitsumreda.github.io/), [Michael Rubinstein](http://people.csail.mit.edu/mrub/), [David Fleet](https://www.cs.toronto.edu/~fleet/), [William T. Freeman](https://billf.mit.edu/)
10 | Google Research
11 | In ECCV 2022. * denotes equal technical contribution, T denotes project lead.
12 |
13 | 
14 | Left: Large improvements with newly trained PWC-Net, IRR-PWC and RAFT
15 | (originally published results in blue; results of our newly trained models in
16 | red). The newly trained RAFT is more accurate than all published methods on
17 | KITTI 2015 at the time of writing.
18 |
19 | Right: Visual comparison on a Davis sequence between the original [43] and our
20 | newly trained PWC-Net and RAFT shows improved flow details, e.g. the hole
21 | between the cart and the person at the back. The newly trained PWC-Net
22 | recovers the hole between the cart and the front person better than RAFT.
23 |
24 | [AutoFlow: Learning a Better Training Set for Optical Flow](https://arxiv.org/pdf/2104.14544.pdf)
75 | Example
76 | of retrained PWC-Net from our most recent submission (run in this inference
77 | colab)
78 |
79 | ## Training
80 |
81 | The src/ directory currently contains the augmentation module (the Improved
82 | Training result uses the "pwc" augmentation from the augmentations module; a
83 | usage sketch appears above). The full train loop will be posted soon.
84 |
85 | ## Citation
86 |
87 | If you find this useful in your work, please acknowledge it appropriately by
88 | citing:
89 |
90 | ```
91 | @article{sun2022disentangling,
92 |   title={Disentangling Architecture and Training for Optical Flow},
93 |   author={Sun, Deqing and Herrmann, Charles and Reda, Fitsum and Rubinstein, Michael and Fleet, David and Freeman, William T},
94 |   journal={arXiv preprint arXiv:2203.10712},
95 |   year={2022}
96 | }
97 | ```
98 |
99 | ```
100 | @inproceedings{sun2021autoflow,
101 |   title={Autoflow: Learning a better training set for optical flow},
102 |   author={Sun, Deqing and Vlasic, Daniel and Herrmann, Charles and Jampani, Varun and Krainin, Michael and Chang, Huiwen and Zabih, Ramin and Freeman, William T and Liu, Ce},
103 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
104 |   pages={10093--10102},
105 |   year={2021}
106 | }
107 | ```
108 |
109 | Contact: Deqing Sun (deqingsun @ google.com) and Charles Herrmann (irwinherrmann
110 | @ google.com)
111 |
112 | ## Coding style
113 |
114 | * 2 spaces for indentation
115 | * 80 character line length
116 | * PEP8 formatting
117 |
118 | ## Disclaimer
119 |
120 | This is not an officially supported Google product.
121 |
--------------------------------------------------------------------------------
/examples/davis/frame_0000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/davis/frame_0000.png
--------------------------------------------------------------------------------
/examples/davis/frame_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/davis/frame_0001.png
--------------------------------------------------------------------------------
/examples/davis/output_pwc_it_ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/davis/output_pwc_it_ft.png
--------------------------------------------------------------------------------
/examples/davis/output_pwc_it_pre.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/davis/output_pwc_it_pre.png
--------------------------------------------------------------------------------
/examples/examples_pwc_it_ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/examples_pwc_it_ft.png
--------------------------------------------------------------------------------
/examples/middlebury_dogdance/frame_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/middlebury_dogdance/frame_0007.png
--------------------------------------------------------------------------------
/examples/middlebury_dogdance/frame_0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/middlebury_dogdance/frame_0008.png
--------------------------------------------------------------------------------
/examples/middlebury_dogdance/output_pwc_it_ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/middlebury_dogdance/output_pwc_it_ft.png
--------------------------------------------------------------------------------
/examples/middlebury_dogdance/output_pwc_it_pre.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/middlebury_dogdance/output_pwc_it_pre.png
--------------------------------------------------------------------------------
/examples/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/opticalflow-autoflow/037086cb49ba869fd7ff42e3f1b388eae92d4e34/examples/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Docker base image: `gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest`
2 | tensorflow==2.8.2  # The latest should include tensorflow-gpu
3 | tensorflow-addons==0.17.0
4 | mediapy==1.0.3
5 | scikit-image==0.19.1
6 | jupyter
7 | opencv-python
8 |
--------------------------------------------------------------------------------
/src/dataset_lib/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/src/dataset_lib/augmentations/aug_params.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import ml_collections
17 |
18 |
19 | def get_params(name):
20 |   aug_params = ml_collections.ConfigDict()
21 |
22 |   aug_params.name = name
23 |
24 |   """Parameters controlling data augmentation."""
25 |   aug_params.crop_height = 320
26 |   aug_params.crop_width = 448
27 |   aug_params.eval_crop_height = 320
28 |   aug_params.eval_crop_width = 768
29 |
30 |   aug_params.noise_std_range = 0.06  # range for sampling std of additive noise
31 |   aug_params.crop_range_delta = 0.03  # range of relative translation of image 2
32 |   aug_params.flow_interpolation = "BILINEAR"  # "NEAREST"
33 |
34 |   # control params
35 |   aug_params.is_schedule_coeff = True  # schedule aug coeff for image 2
36 |   aug_params.schedule_coeff = 1.0
37 |   aug_params.is_channel_swapping = False  # True: random swapping color channels
38 |   aug_params.is_augment_colors = True
39 |   aug_params.is_augment_spatial = True
40 |   aug_params.disable_ground_truth = False  # True: set ground truth to invalid for semi-supervised training
41 |   aug_params.black = False  # True: allow out-of-boundary cropping (Chairs)
42 |   aug_params.prob_hard_sample = 1.0  # probability that we use the hard sample technique, see line 87 in https://github.com/gengshan-y/VCN/blob/master/dataloader/robloader.py
43 |   aug_params.is_random_erasing = False
44 |
45 |   # spatial params
46 |   aug_params.min_scale = 0.2
47 |   aug_params.max_scale = 1.0
48 |   aug_params.vflip_prob = 0.0
49 |   aug_params.rot1 = 0.4
50 |   aug_params.squeeze1 = 0.3
51 |   aug_params.scale1 = 0.3
52 |   aug_params.tran1 = 0.4
53 |   aug_params.scale2 = 0.1
54 |   aug_params.lmult_factor = 1.
55 |   aug_params.sat_factor = 1.
56 |   aug_params.col_factor = 1.
57 |   aug_params.ladd_factor = 1.
58 |   aug_params.col_rot_factor = 1.
59 |   return aug_params
60 |
--------------------------------------------------------------------------------
/src/dataset_lib/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | import functools 17 | 18 | from augmentations import pwc_augmentation 19 | from augmentations import raft_augmentation 20 | from augmentations import simple_augmentation 21 | 22 | 23 | def no_aug(element, aug_params): 24 | del aug_params 25 | return element 26 | 27 | 28 | ALL_AUGMENTATIONS = { 29 | 'raft': raft_augmentation.apply, 30 | 'pwc': pwc_augmentation.apply, 31 | 'crop': simple_augmentation.apply_crop, 32 | 'resize': simple_augmentation.apply_resize, 33 | 'none': no_aug, 34 | } 35 | 36 | 37 | def get_augmentation_fn(aug_params): 38 | aug_name = aug_params.name 39 | if aug_name not in ALL_AUGMENTATIONS.keys(): 40 | raise NotImplementedError( 41 | 'Unrecognized augmentation: {}'.format(aug_name)) 42 | aug_fn = ALL_AUGMENTATIONS[aug_name] 43 | return functools.partial(aug_fn, aug_params=aug_params) 44 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/color_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import numpy as np 17 | 18 | import tensorflow as tf 19 | 20 | 21 | class PCAAug(object): 22 | """Chromatic Eigen Augmentation, translated from VCN 23 | 24 | https://github.com/gengshan-y/VCN, which translates 25 | https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 26 | """ 27 | 28 | def __init__(self, 29 | lmult_pow=[0.4, 0, -0.2], 30 | lmult_mult=[0.4, 0, 0], 31 | lmult_add=[0.03, 0, 0], 32 | sat_pow=[0.4, 0, 0], 33 | sat_mult=[0.5, 0, -0.3], 34 | sat_add=[0.03, 0, 0], 35 | col_pow=[0.4, 0, 0], 36 | col_mult=[0.2, 0, 0], 37 | col_add=[0.02, 0, 0], 38 | ladd_pow=[0.4, 0, 0], 39 | ladd_mult=[0.4, 0, 0], 40 | ladd_add=[0.04, 0, 0], 41 | col_rotate=[1., 0, 0], 42 | schedule_coeff=1): 43 | # no mean 44 | self.pow_nomean = [1, 1, 1] 45 | self.add_nomean = [0, 0, 0] 46 | self.mult_nomean = [1, 1, 1] 47 | self.pow_withmean = [1, 1, 1] 48 | self.add_withmean = [0, 0, 0] 49 | self.mult_withmean = [1, 1, 1] 50 | self.lmult_pow = 1 51 | self.lmult_mult = 1 52 | self.lmult_add = 0 53 | self.col_angle = 0 54 | if ladd_pow is not None: 55 | self.pow_nomean[0] = tf.exp( 56 | tf.random.normal([], ladd_pow[2], ladd_pow[0])) 57 | if col_pow is not None: 58 | self.pow_nomean[1] = tf.exp(tf.random.normal([], col_pow[2], col_pow[0])) 59 | self.pow_nomean[2] = tf.exp(tf.random.normal([], col_pow[2], col_pow[0])) 60 | 61 | if ladd_add is not None: 62 | self.add_nomean[0] = tf.random.normal([], ladd_add[2], ladd_add[0]) 63 | if col_add is not None: 64 | self.add_nomean[1] = tf.random.normal([], col_add[2], col_add[0]) 65 | self.add_nomean[2] = tf.random.normal([], col_add[2], col_add[0]) 66 | 67 | if ladd_mult is not None: 68 | self.mult_nomean[0] = tf.exp( 69 | 
tf.random.normal([], ladd_mult[2], ladd_mult[0])) 70 | if col_mult is not None: 71 | self.mult_nomean[1] = tf.exp( 72 | tf.random.normal([], col_mult[2], col_mult[0])) 73 | self.mult_nomean[2] = tf.exp( 74 | tf.random.normal([], col_mult[2], col_mult[0])) 75 | 76 | # with mean 77 | if sat_pow is not None: 78 | self.pow_withmean[1] = tf.exp( 79 | tf.random.uniform([], sat_pow[2] - sat_pow[0], 80 | sat_pow[2] + sat_pow[0])) 81 | self.pow_withmean[2] = self.pow_withmean[1] 82 | if sat_add is not None: 83 | self.add_withmean[1] = tf.random.uniform([], sat_add[2] - sat_add[0], 84 | sat_add[2] + sat_add[0]) 85 | self.add_withmean[2] = self.add_withmean[1] 86 | if sat_mult is not None: 87 | self.mult_withmean[1] = tf.exp( 88 | tf.random.uniform([], sat_mult[2] - sat_mult[0], 89 | sat_mult[2] + sat_mult[0])) 90 | self.mult_withmean[2] = self.mult_withmean[1] 91 | 92 | if lmult_pow is not None: 93 | self.lmult_pow = tf.exp( 94 | tf.random.uniform([], lmult_pow[2] - lmult_pow[0], 95 | lmult_pow[2] + lmult_pow[0])) 96 | if lmult_mult is not None: 97 | self.lmult_mult = tf.exp( 98 | tf.random.uniform([], lmult_mult[2] - lmult_mult[0], 99 | lmult_mult[2] + lmult_mult[0])) 100 | if lmult_add is not None: 101 | self.lmult_add = tf.random.uniform([], lmult_add[2] - lmult_add[0], 102 | lmult_add[2] + lmult_add[0]) 103 | if col_rotate is not None: 104 | self.col_angle = tf.random.uniform([], col_rotate[2] - col_rotate[0], 105 | col_rotate[2] + col_rotate[0]) 106 | 107 | # eigen vectors 108 | self.eigvec = tf.transpose( 109 | tf.reshape([0.51, 0.56, 0.65, 0.79, 0.01, -0.62, 0.35, -0.83, 0.44], 110 | [3, 3])) 111 | 112 | def __call__(self, inputs, target): 113 | inputs = tf.stack([ 114 | self.pca_image(inputs[0, :, :, :]), 115 | self.pca_image(inputs[1, :, :, :]) 116 | ]) 117 | return inputs, target 118 | 119 | def _apply_eig_nomean(self, eig, c, max_abs_eig): 120 | """tf.cond true_fn for eig_nomean.""" 121 | eig_result = eig[:, :, c] / max_abs_eig[c] 122 | a = tf.pow(tf.abs(eig_result), self.pow_nomean[c]) 123 | b = (tf.to_float(eig_result > 0) - 0.5) * 2 124 | eig_result = tf.multiply(a, b) 125 | eig_result = eig_result + self.add_nomean[c] 126 | eig_result = eig_result * self.mult_nomean[c] 127 | return eig_result 128 | 129 | def _apply_eig_withmean(self, eig, max_abs_eig): 130 | """tf.cond true_fn.""" 131 | a = tf.pow(tf.abs(eig[:, :, 0]), self.pow_withmean[0]) 132 | b = (tf.to_float(eig[:, :, 0] > 0) - 0.5) * 2 133 | eig0 = tf.multiply(a, b) 134 | eig0 += self.add_withmean[0] 135 | eig0 *= self.mult_withmean[0] 136 | return eig0 137 | 138 | def _apply_color_angle(self, eig): 139 | """tf.cond true_fn.""" 140 | temp1 = tf.math.cos(self.col_angle) * eig[:, :, 1] - tf.math.sin( 141 | self.col_angle) * eig[:, :, 2] 142 | temp2 = tf.math.sin(self.col_angle) * eig[:, :, 1] + tf.math.cos( 143 | self.col_angle) * eig[:, :, 2] 144 | return tf.stack([eig[:, :, 0], temp1, temp2], -1) 145 | 146 | def _apply_final_step(self, eig, l1, max_abs_eig, max_l): 147 | """tf.cond true_fn.""" 148 | l = tf.sqrt(eig[:, :, 0] * eig[:, :, 0] + eig[:, :, 1] * eig[:, :, 1] + 149 | eig[:, :, 2] * eig[:, :, 2] + 1e-9) 150 | l1 = tf.pow(l1, self.lmult_pow) 151 | l1 = tf.clip_by_value(l1 + self.lmult_add, 0, np.inf) 152 | l1 = l1 * self.lmult_mult 153 | l1 = l1 * max_l 154 | lmask = tf.to_float(l > 1e-2) 155 | eig = eig * tf.expand_dims((1 - lmask), -1) + tf.multiply( 156 | tf.divide(eig, tf.expand_dims(l, -1)), tf.expand_dims( 157 | l1, -1)) * tf.expand_dims(lmask, -1) 158 | eig_list = [] 159 | for c in range(3): 160 | tmp = eig[:, :, 
c] * (1 - lmask) + tf.clip_by_value( 161 | eig[:, :, c], -np.inf, max_abs_eig[c]) * lmask 162 | eig_list.append(tmp) 163 | eig = tf.stack(eig_list, -1) 164 | return eig 165 | 166 | def pca_image(self, rgb): 167 | eig = tf.matmul(rgb, self.eigvec) 168 | 169 | eig = tf.matmul(rgb, self.eigvec) 170 | mean_rgb = tf.reduce_mean(rgb, [0, 1]) 171 | 172 | max_abs_eig = tf.reduce_max(tf.abs(eig), [0, 1]) 173 | max_l = tf.norm(max_abs_eig) 174 | mean_eig = tf.linalg.matvec(self.eigvec, mean_rgb, transpose_a=True) 175 | # no-mean stuff 176 | eig -= tf.expand_dims(tf.expand_dims(mean_eig, 0), 0) 177 | 178 | mean_eig_list = [] 179 | eig_list = [] 180 | for c in range(3): 181 | is_apply = tf.greater(max_abs_eig[c], 1e-2) 182 | mean_eig0 = tf.cond(is_apply, lambda: mean_eig[c] / max_abs_eig[c], 183 | lambda: mean_eig[c]) 184 | eig0 = tf.cond(is_apply, 185 | lambda: self._apply_eig_nomean(eig, c, max_abs_eig), 186 | lambda: eig[:, :, c]) 187 | mean_eig_list.append(mean_eig0) 188 | eig_list.append(eig0) 189 | mean_eig = tf.stack(mean_eig_list) 190 | eig = tf.stack(eig_list, -1) 191 | 192 | eig += tf.expand_dims(tf.expand_dims(mean_eig, 0), 0) # match:-) 193 | 194 | # withmean stuff 195 | is_apply = tf.greater(max_abs_eig[0], 1e-2) 196 | eig0 = tf.cond(is_apply, lambda: self._apply_eig_withmean(eig, max_abs_eig), 197 | lambda: eig[:, :, 0]) 198 | eig = tf.stack([eig0, eig[:, :, 1], eig[:, :, 2]], -1) 199 | 200 | s = tf.sqrt(eig[:, :, 1] * eig[:, :, 1] + eig[:, :, 2] * eig[:, :, 2] + 201 | 1e-9) 202 | smask = tf.to_float(s > 1e-2) 203 | s1 = tf.pow(s, self.pow_withmean[1]) 204 | s1 = tf.clip_by_value(s1 + self.add_withmean[1], 0, np.inf) 205 | s1 = s1 * self.mult_withmean[1] 206 | s1 = s1 * smask + s * (1 - smask) 207 | 208 | # color angle 209 | is_apply = tf.math.not_equal(self.col_angle, 0) 210 | eig = tf.cond(is_apply, lambda: self._apply_color_angle(eig), lambda: eig) 211 | 212 | # to origin magnitude 213 | eig_list = [] 214 | for c in range(3): 215 | is_apply = tf.greater(max_abs_eig[c], 1e-2) 216 | tmp = tf.cond(is_apply, lambda: eig[:, :, c] * max_abs_eig[c], 217 | lambda: eig[:, :, c]) 218 | eig_list.append(tmp) 219 | 220 | is_apply = tf.greater(max_l, 1e-2) 221 | tmp = tf.to_float( 222 | tf.sqrt( 223 | tf.multiply(eig_list[0], eig_list[0]) + 224 | tf.multiply(eig_list[1], eig_list[1]) + 225 | tf.multiply(eig_list[2], eig_list[2])) / max_l + 1e-9) 226 | l1 = tf.cond(is_apply, lambda: tmp, lambda: tmp * 0) 227 | 228 | eig_list[1] = eig_list[1] * (1 - smask) + eig_list[1] / s * s1 * smask 229 | eig_list[2] = eig_list[2] * (1 - smask) + eig_list[2] / s * s1 * smask 230 | eig = tf.stack(eig_list, -1) 231 | 232 | is_apply = tf.greater(max_l, 1e-2) 233 | eig = tf.cond(is_apply, 234 | lambda: self._apply_final_step(eig, l1, max_abs_eig, max_l), 235 | lambda: eig) 236 | eig = tf.clip_by_value(tf.matmul(eig, tf.transpose(self.eigvec)), 0, 1) 237 | return eig 238 | 239 | 240 | def apply(images, aug_params): 241 | """Augments the input images by applying random color transformations. 242 | 243 | Args: 244 | images: A tensor of size [2, height, width, 3] representing the two RGB 245 | input images with range [-1, 1]. 246 | aug_params: An instance of AugmentationParams to control the color 247 | transformations. 248 | 249 | Returns: 250 | output: A tensor of size [2, height, width, 3] holding the augmented image. 251 | """ 252 | # PCA agumentation 253 | # Convert to [0,1] from [-1,1] 254 | images = (images + 1.)/2. 
255 | pcaaug = PCAAug( 256 | lmult_pow=[0.4 * aug_params.lmult_factor, 0, -0.2], 257 | lmult_mult=[0.4 * aug_params.lmult_factor, 0, 0], 258 | lmult_add=[0.03 * aug_params.lmult_factor, 0, 0], 259 | sat_pow=[0.4 * aug_params.sat_factor, 0, 0], 260 | sat_mult=[0.5 * aug_params.sat_factor, 0, -0.3], 261 | sat_add=[0.03 * aug_params.sat_factor, 0, 0], 262 | col_pow=[0.4 * aug_params.col_factor, 0, 0], 263 | col_mult=[0.2 * aug_params.col_factor, 0, 0], 264 | col_add=[0.02 * aug_params.col_factor, 0, 0], 265 | ladd_pow=[ 266 | 0.4 * aug_params.ladd_factor, 267 | 0, 268 | 0, 269 | ], 270 | ladd_mult=[0.4 * aug_params.ladd_factor, 0, 0], 271 | ladd_add=[0.04 * aug_params.ladd_factor, 0, 0], 272 | col_rotate=[1. * aug_params.col_rot_factor, 0, 0], 273 | schedule_coeff=1) 274 | images, _ = pcaaug(images, []) 275 | 276 | # Chromatic augmentation applied to image 2 277 | image1 = images[1, :, :, :] 278 | mean_in = tf.reduce_sum(image1, -1) 279 | color = tf.math.exp( 280 | tf.random.normal([3], 0., 0.02 * aug_params.schedule_coeff)) 281 | image1 = image1 * color 282 | brightness_coeff = tf.divide(mean_in, tf.reduce_sum(image1, -1) + 0.01) 283 | image1 = tf.math.multiply(image1, tf.expand_dims(brightness_coeff, -1)) 284 | image1 = tf.clip_by_value(image1, 0., 1.) 285 | # Gamma 286 | gamma = tf.exp( 287 | tf.random.normal([], 0., 0.02 * aug_params.schedule_coeff)) 288 | image1 = tf.pow(image1, gamma) 289 | # Brightness 290 | image1 += tf.random.normal([], 0, 0.02 * aug_params.schedule_coeff) 291 | # Contrast 292 | image1 = 0.5 + (image1 - 0.5) * tf.exp( 293 | tf.random.normal([], 0, 0.02 * aug_params.schedule_coeff)) 294 | image1 = tf.clip_by_value(image1, 0., 1.) 295 | images = tf.stack([images[0, :, :, :], image1]) 296 | 297 | # Add noise 298 | noise_std = tf.random.uniform([], 0., aug_params.noise_std_range) 299 | noise = tf.random.normal(tf.shape(images), 0, noise_std) 300 | images += noise 301 | 302 | # Clip to [0,1] & scale to [-1,1] 303 | images = tf.clip_by_value(images, 0., 1.) * 2. 
- 1 304 | 305 | if aug_params.is_channel_swapping: 306 | channel_permutation = tf.constant([[0, 1, 2], [0, 2, 1], [1, 0, 2], 307 | [1, 2, 0], [2, 0, 1], [2, 1, 0]]) 308 | rand_i = tf.random_uniform([], minval=0, maxval=6, dtype=tf.int32) 309 | perm = channel_permutation[rand_i] 310 | images = tf.stack([ 311 | images[:, :, :, perm[0]], images[:, :, :, perm[1]], images[:, :, :, 312 | perm[2]] 313 | ], axis=-1) 314 | 315 | if aug_params.is_random_erasing: 316 | image0, image1 = eraser_transform(images[0, :, :, :], images[1, :, :, :], 317 | bounds=[50, 100], eraser_aug_prob=0.5) 318 | images = tf.stack([image0, image1]) 319 | # images = tf.stack([images[0, :, :, :], random_erasing(images[1, :, :, :])]) 320 | 321 | return images 322 | 323 | def eraser_transform(img1, img2, bounds, eraser_aug_prob=0.5): 324 | ht, wd, _ = tf.unstack(tf.shape(img1)) 325 | pred = tf.random.uniform([]) < eraser_aug_prob 326 | def true_fn(img1, img2): 327 | mean_color = tf.reduce_mean(tf.reshape(img2, (-1, 3)), axis=0) 328 | mean_color = tf.expand_dims(tf.expand_dims(mean_color, axis=0), axis=0) 329 | def body(var_img, mean_color): 330 | x0 = tf.random.uniform([], 0, wd, dtype=tf.int32) 331 | y0 = tf.random.uniform([], 0, ht, dtype=tf.int32) 332 | dx = tf.random.uniform([], bounds[0], bounds[1], dtype=tf.int32) 333 | dy = tf.random.uniform([], bounds[0], bounds[1], dtype=tf.int32) 334 | x = tf.range(wd) 335 | x_mask = (x0 <= x) & (x < x0+dx) 336 | y = tf.range(ht) 337 | y_mask = (y0 <= y) & (y < y0+dy) 338 | mask = x_mask & y_mask[:, tf.newaxis] 339 | mask = tf.cast(mask[:, :, tf.newaxis], img1.dtype) 340 | mean_slice = tf.tile(mean_color, multiples=[ht, wd, 1]) 341 | result = var_img * (1 - mask) + mean_slice * mask 342 | return result 343 | max_num = tf.random.uniform([], 1, 3, dtype=tf.int32) 344 | img2 = body(img2, mean_color) 345 | img2 = tf.cond(2 <= max_num, lambda: body(img2, mean_color), lambda: img2) 346 | return img1, img2 347 | def false_fn(img1, img2): 348 | return img1, img2 349 | 350 | return tf.cond(pred, lambda: true_fn(img1, img2), 351 | lambda: false_fn(img1, img2)) 352 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/crop_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import math 17 | import tensorflow as tf 18 | 19 | from augmentations import image_aug 20 | 21 | def crop_to_box(images, forward_flow, y, x, height, width): 22 | """Applies a crop to images and forward_flow.""" 23 | images = tf.image.crop_to_bounding_box(images, y, x, height, width) 24 | forward_flow = tf.image.crop_to_bounding_box(forward_flow, y, x, height, 25 | width) 26 | return images, forward_flow 27 | 28 | 29 | def top_left_crop(element, crop_height, crop_width): 30 | """Crop validation data to be multiples of 64.""" 31 | # crop_height = tf.math.floordiv(element.images.shape[1], 64) * 64 32 | # crop_width = tf.math.floordiv(element.images.shape[2], 64) * 64 33 | 34 | inputs = element['inputs'][:, 0:crop_height, 0:crop_width, :] 35 | label = element['label'][0:crop_height, 0:crop_width, :] 36 | return {'inputs': inputs, 'label': label} 37 | 38 | 39 | def sample_cropping_centers(image_height, image_width, image_stretch_y, 40 | image_stretch_x, rotation_degrees, 41 | image_stretch_y2, image_stretch_x2, 42 | rotation_degrees2, augmentation_params): 43 | """Sample crop_centers for two images.""" 44 | rotated_box_height, rotated_box_width = image_aug.rotated_box_size( 45 | rotation_degrees, augmentation_params) 46 | rotated_box_height2, rotated_box_width2 = image_aug.rotated_box_size( 47 | rotation_degrees2, augmentation_params) 48 | 49 | stretched_image_height = image_stretch_y * tf.cast(image_height, tf.float32) 50 | stretched_image_width = image_stretch_x * tf.cast(image_width, tf.float32) 51 | stretched_image_height2 = image_stretch_y2 * tf.cast(image_height, tf.float32) 52 | stretched_image_width2 = image_stretch_x2 * tf.cast(image_width, tf.float32) 53 | 54 | y_min = tf.maximum(rotated_box_height / 2, rotated_box_height2 / 2) 55 | y_max = tf.minimum(stretched_image_height - rotated_box_height / 2, 56 | stretched_image_height2 - rotated_box_height2 / 2) 57 | 58 | x_min = tf.maximum(rotated_box_width / 2, rotated_box_width2 / 2) 59 | x_max = tf.minimum(stretched_image_width - rotated_box_width / 2, 60 | stretched_image_width2 - rotated_box_width2 / 2) 61 | 62 | center_y = tf.random_uniform([], y_min, y_max) 63 | center_x = tf.random_uniform([], x_min, x_max) 64 | 65 | # Sample crop center of second image conditioned that of first image. 
66 | delta_y = augmentation_params.crop_range_delta * augmentation_params.crop_height * augmentation_params.schedule_coeff 67 | y2_min = tf.maximum(rotated_box_height2 / 2, center_y - delta_y) 68 | y2_max = tf.minimum(stretched_image_height2 - rotated_box_height2 / 2, 69 | center_y + delta_y) 70 | 71 | delta_x = augmentation_params.crop_range_delta * augmentation_params.crop_width * augmentation_params.schedule_coeff 72 | x2_min = tf.maximum(rotated_box_width2 / 2, center_x - delta_x) 73 | x2_max = tf.minimum(stretched_image_width2 - rotated_box_width2 / 2, 74 | center_x + delta_x) 75 | center_y2 = tf.random_uniform([], y2_min, y2_max) 76 | center_x2 = tf.random_uniform([], x2_min, x2_max) 77 | 78 | return center_y / image_stretch_y, center_x / image_stretch_x, center_y2 / image_stretch_y2, center_x2 / image_stretch_x2 79 | 80 | 81 | def compose_cropping_transformation(stretch_factor_y, stretch_factor_x, 82 | crop_center_y, crop_center_x, 83 | rotation_degrees, crop_height, crop_width): 84 | """Composes stretching, rotation, and cropping into one 3x3 transformation.""" 85 | # Transforms coordinates from the output space to a "centered" output space 86 | # (i.e., relative to the center of the output). 87 | centering_matrix = tf.stack([(1, 0, -0.5 * (crop_width - 1)), 88 | (0, 1, -0.5 * (crop_height - 1)), (0, 0, 1)], 89 | axis=0) 90 | 91 | # Performs a rotation by |rotation_degrees|. 92 | cos_value = tf.math.cos(rotation_degrees * math.pi / 180) 93 | sin_value = tf.math.sin(rotation_degrees * math.pi / 180) 94 | rotation_matrix = tf.stack([(cos_value, -sin_value, 0), 95 | (sin_value, cos_value, 0), (0, 0, 1)], 96 | axis=0) 97 | 98 | # Performs translation to account for the requested cropping location. Note 99 | # that this translates to the location in the stretched image. 100 | translation_matrix = tf.stack([(1, 0, stretch_factor_x * crop_center_x), 101 | (0, 1, stretch_factor_y * crop_center_y), 102 | (0, 0, 1)], 103 | axis=0) 104 | 105 | # Scales from stretched coordinates to un-stretched coordinates. 106 | scaling_matrix = tf.stack([(1 / stretch_factor_x, 0, 0), 107 | (0, 1 / stretch_factor_y, 0), (0, 0, 1)], 108 | axis=0) 109 | 110 | # Compose our various transformations into one overall transformation. The 111 | # transformation T should be such that T * output_coord = input_coord. The 112 | # action of our composed transformation on output_coord is equivalent to: 113 | # 1) Centering (e.g., mapping x=0 to roughly -width/2). 114 | # 2) Applying the rotation. Due to (1), this rotates about the center rather 115 | # than the top-left. 116 | # 3) Translate to the location of (crop_center_x, crop_center_y) in the 117 | # stretched version of the image. 118 | # 4) Apply inverse stretching to get to original coordinates from |image|. 119 | transform = centering_matrix 120 | transform = tf.matmul(rotation_matrix, transform) 121 | transform = tf.matmul(translation_matrix, transform) 122 | transform = tf.matmul(scaling_matrix, transform) 123 | return transform 124 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/image_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains various data augmentation utilities for PWC-Net.""" 17 | import math 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def apply_horizontal_flip(images, forward_flow): 23 | """Applies a horizontal flip to images and forward_flow.""" 24 | images = tf.image.flip_left_right(images) 25 | forward_flow = tf.image.flip_left_right(forward_flow) 26 | 27 | # Invert the horizontal component of flow. This is because an object moving to 28 | # the right in the input images (positive horizontal flow) will be moving to 29 | # the left in the flipped images (negative horizontal flow) and vice-versa. 30 | flow_scale_factors = tf.constant([-1, 1], dtype=tf.float32) 31 | forward_flow = forward_flow * flow_scale_factors 32 | return images, forward_flow 33 | 34 | 35 | def random_vertical_flip(img1, img2, flow, prob=0.1): 36 | pred = tf.random.uniform([]) < prob 37 | def true_fn(img1, img2, flow): 38 | img1 = tf.image.flip_up_down(img1) 39 | img2 = tf.image.flip_up_down(img2) 40 | flow = tf.image.flip_up_down(flow) * [1.0, -1.0] 41 | return img1, img2, flow 42 | def false_fn(img1, img2, flow): 43 | return img1, img2, flow 44 | return tf.cond(pred, lambda: true_fn(img1, img2, flow), 45 | lambda: false_fn(img1, img2, flow)) 46 | 47 | 48 | def rotated_box_size(rotation_degrees, augmentation_params): 49 | """Returns the bounding box size after accounting for its rotation.""" 50 | box_diagonal = math.hypot(augmentation_params.crop_width, 51 | augmentation_params.crop_height) 52 | diagonal_angle = math.atan2(augmentation_params.crop_height, 53 | augmentation_params.crop_width) 54 | absolute_rotation_radians = tf.math.abs(rotation_degrees * math.pi / 180) 55 | rotated_height = box_diagonal * tf.sin(diagonal_angle + 56 | absolute_rotation_radians) 57 | rotated_width = box_diagonal * tf.cos(diagonal_angle - 58 | absolute_rotation_radians) 59 | return rotated_height, rotated_width 60 | 61 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/pwc_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from augmentations import color_aug 19 | from augmentations import crop_aug 20 | from augmentations import spatial_aug 21 | 22 | 23 | FLOW_SCALE_FACTOR = 20.0 24 | 25 | INVALID_FLOW_VALUE = 1e9 26 | 27 | 28 | def apply(element, aug_params): 29 | images, forward_flow = element['inputs'], element['label'] 30 | 31 | is_hard_augment = tf.greater_equal(aug_params.prob_hard_sample, 32 | tf.random.uniform([], 0, 1)) 33 | is_hard_augment = tf.cast(is_hard_augment, dtype=tf.bool) 34 | 35 | images, forward_flow = tf.cond(tf.cast(aug_params.is_augment_spatial, dtype=tf.bool), 36 | lambda: spatial_aug.apply(images, forward_flow, aug_params, is_hard_augment), 37 | lambda: no_spatial_op(images, forward_flow, aug_params, is_hard_augment)) 38 | 39 | return tf.cond( 40 | tf.logical_and(is_hard_augment, aug_params.is_augment_colors), 41 | lambda: chromatic_aug(images, forward_flow, aug_params), 42 | lambda: no_op(images, forward_flow)) 43 | 44 | 45 | def no_spatial_op(images, forward_flow, augmentation_params, is_hard_augment): 46 | _, height, width, _ = tf.unstack(tf.shape(images), num=4) 47 | 48 | crop_start_y = tf.random.uniform([], 49 | minval=0, 50 | maxval=height - 51 | augmentation_params.crop_height + 1, 52 | dtype=tf.dtypes.int32) 53 | crop_start_x = tf.random.uniform([], 54 | minval=0, 55 | maxval=width - 56 | augmentation_params.crop_width + 1, 57 | dtype=tf.dtypes.int32) 58 | images, forward_flow = crop_aug.crop_to_box(images, 59 | forward_flow, 60 | crop_start_y, 61 | crop_start_x, 62 | augmentation_params.crop_height, 63 | augmentation_params.crop_width) 64 | 65 | if augmentation_params.disable_ground_truth: 66 | # Set GT to invalid values for semi-supervised training 67 | forward_flow = tf.ones(forward_flow.get_shape())*INVALID_FLOW_VALUE 68 | 69 | return images, forward_flow 70 | 71 | 72 | def chromatic_aug(images, forward_flow, augmentation_params): 73 | # Should be faster than color->spatial and have same effect 74 | images = color_aug.apply(images, augmentation_params) 75 | return {'inputs': images, 'label': forward_flow} 76 | 77 | 78 | def no_op(images, forward_flow): 79 | return {'inputs': images, 'label': forward_flow} 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/raft_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | CLIP_MAX = 1e3 19 | DEFAULT_ERASER_BOUNDS = (50, 100) 20 | 21 | 22 | class Augment(tf.keras.layers.Layer): 23 | """Augment object for RAFT""" 24 | 25 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5): 26 | super(Augment, self).__init__() 27 | self.crop_size = crop_size 28 | 29 | self.brightness = (0.6, 1.4) 30 | self.contrast = (0.6, 1.4) 31 | self.saturation = (0.6, 1.4) 32 | self.hue = 0.5 / 3.14 33 | 34 | self.asymmetric_color_aug_prob = 0.2 35 | self.spatial_aug_prob = 0.8 36 | self.eraser_aug_prob = 0.5 37 | 38 | self.min_scale = min_scale 39 | self.max_scale = max_scale 40 | self.max_stretch = 0.2 41 | self.stretch_prob = 0.8 42 | self.margin = 20 43 | 44 | def augment_color(self, images): 45 | brightness_scale = tf.random.uniform([], 46 | self.brightness[0], 47 | self.brightness[1], 48 | dtype=tf.float32) 49 | images = images * brightness_scale 50 | # images = tf.clip_by_value(images, 0, 1) # float limits 51 | images = tf.image.random_contrast(images, self.contrast[0], 52 | self.contrast[1]) 53 | # images = tf.clip_by_value(images, 0, 1) # float limits 54 | images = tf.image.random_saturation(images, self.saturation[0], 55 | self.saturation[1]) 56 | # images = tf.clip_by_value(images, 0, 1) # float limits 57 | images = tf.image.random_hue(images, self.hue) 58 | images = tf.clip_by_value(images, 0, 1) # float limits 59 | return images 60 | 61 | def color_transform(self, img1, img2): 62 | pred = tf.random.uniform([]) < self.asymmetric_color_aug_prob 63 | def true_fn(img1, img2): 64 | img1 = self.augment_color(img1) 65 | img2 = self.augment_color(img2) 66 | return [img1, img2] 67 | def false_fn(img1, img2): 68 | imgs = tf.concat((img1, img2), axis=0) 69 | imgs = self.augment_color(imgs) 70 | return tf.split(imgs, num_or_size_splits=2) 71 | 72 | return tf.cond(pred, lambda: true_fn(img1, img2), 73 | lambda: false_fn(img1, img2)) 74 | 75 | def eraser_transform(self, img1, img2, bounds=DEFAULT_ERASER_BOUNDS): 76 | ht, wd, _ = tf.unstack(tf.shape(img1), num=3) 77 | pred = tf.random.uniform([]) < self.eraser_aug_prob 78 | def true_fn(img1, img2): 79 | mean_color = tf.reduce_mean(tf.reshape(img2, (-1, 3)), axis=0) 80 | mean_color = tf.expand_dims(tf.expand_dims(mean_color, axis=0), axis=0) 81 | def body(var_img, mean_color): 82 | x0 = tf.random.uniform([], 0, wd, dtype=tf.int32) 83 | y0 = tf.random.uniform([], 0, ht, dtype=tf.int32) 84 | dx = tf.random.uniform([], bounds[0], bounds[1], dtype=tf.int32) 85 | dy = tf.random.uniform([], bounds[0], bounds[1], dtype=tf.int32) 86 | x = tf.range(wd) 87 | x_mask = (x0 <= x) & (x < x0+dx) 88 | y = tf.range(ht) 89 | y_mask = (y0 <= y) & (y < y0+dy) 90 | mask = x_mask & y_mask[:, tf.newaxis] 91 | mask = tf.cast(mask[:, :, tf.newaxis], img1.dtype) 92 | mean_slice = tf.tile(mean_color, multiples=[ht, wd, 1]) 93 | result = var_img * (1 - mask) + mean_slice * mask 94 | return result 95 | max_num = tf.random.uniform([], 1, 3, dtype=tf.int32) 96 | img2 = body(img2, mean_color) 97 | img2 = tf.cond(2 <= max_num, lambda: body(img2, mean_color), lambda: img2) 98 | return img1, img2 99 | def false_fn(img1, img2): 100 | return img1, img2 101 | 102 | return tf.cond(pred, lambda: true_fn(img1, img2), 103 | lambda: false_fn(img1, img2)) 104 | 105 | def random_vertical_flip(self, img1, img2, flow, prob=0.1): 106 | pred = tf.random.uniform([]) < prob 107 | def true_fn(img1, img2, flow): 108 | img1 = tf.image.flip_up_down(img1) 
109 | img2 = tf.image.flip_up_down(img2) 110 | flow = tf.image.flip_up_down(flow) * [1.0, -1.0] 111 | return img1, img2, flow 112 | def false_fn(img1, img2, flow): 113 | return img1, img2, flow 114 | return tf.cond(pred, 115 | lambda: true_fn(img1, img2, flow), 116 | lambda: false_fn(img1, img2, flow)) 117 | 118 | def random_horizontal_flip(self, img1, img2, flow, prob=0.5): 119 | pred = tf.random.uniform([]) < prob 120 | def true_fn(img1, img2, flow): 121 | img1 = tf.image.flip_left_right(img1) 122 | img2 = tf.image.flip_left_right(img2) 123 | flow = tf.image.flip_left_right(flow) * [-1.0, 1.0] 124 | return img1, img2, flow 125 | def false_fn(img1, img2, flow): 126 | return img1, img2, flow 127 | return tf.cond(pred, 128 | lambda: true_fn(img1, img2, flow), 129 | lambda: false_fn(img1, img2, flow)) 130 | 131 | def random_scale(self, img1, img2, flow, scale_x, scale_y): 132 | pred = tf.random.uniform([]) < self.spatial_aug_prob 133 | ht, wd, _ = tf.unstack(tf.shape(img1), num=3) 134 | def true_fn(img1, img2, flow, scale_x, scale_y): 135 | # rescale the images 136 | new_ht = scale_x * tf.cast(ht, dtype=tf.float32) 137 | new_wd = scale_y * tf.cast(wd, dtype=tf.float32) 138 | new_shape = tf.cast(tf.concat([new_ht, new_wd], axis=0), dtype=tf.int32) 139 | img1 = tf.compat.v1.image.resize( 140 | img1, 141 | new_shape, 142 | tf.compat.v1.image.ResizeMethod.BILINEAR, 143 | align_corners=True) 144 | img2 = tf.compat.v1.image.resize( 145 | img2, 146 | new_shape, 147 | tf.compat.v1.image.ResizeMethod.BILINEAR, 148 | align_corners=True) 149 | flow = tf.compat.v1.image.resize( 150 | flow, 151 | new_shape, 152 | tf.compat.v1.image.ResizeMethod.BILINEAR, 153 | align_corners=True) 154 | 155 | flow = flow * tf.expand_dims( 156 | tf.expand_dims(tf.concat([scale_x, scale_y], axis=0), axis=0), axis=0) 157 | return img1, img2, flow 158 | 159 | def false_fn(img1, img2, flow): 160 | return img1, img2, flow 161 | return tf.cond(pred, 162 | lambda: true_fn(img1, img2, flow, scale_x, scale_y), 163 | lambda: false_fn(img1, img2, flow)) 164 | 165 | def spatial_transform(self, img1, img2, flow): 166 | # randomly sample scale 167 | ht, wd, _ = tf.unstack(tf.shape(img1), num=3) 168 | min_scale = tf.math.maximum( 169 | (self.crop_size[0] + 1) / ht, 170 | (self.crop_size[1] + 1) / wd) 171 | 172 | max_scale = self.max_scale 173 | min_scale = tf.math.maximum(min_scale, self.min_scale) 174 | 175 | scale = 2 ** tf.random.uniform([], self.min_scale, self.max_scale) 176 | scale_x = scale 177 | scale_y = scale 178 | pred = tf.random.uniform([]) < self.stretch_prob 179 | def true_fn(scale_x, scale_y): 180 | scale_x *= 2 ** tf.random.uniform([], -self.max_stretch, self.max_stretch) 181 | scale_y *= 2 ** tf.random.uniform([], -self.max_stretch, self.max_stretch) 182 | return tf.stack((scale_x, scale_y), axis=0) 183 | def false_fn(scale_x, scale_y): 184 | return tf.stack((scale_x, scale_y), axis=0) 185 | scales = tf.cond(pred, 186 | lambda: true_fn(scale_x, scale_y), 187 | lambda: false_fn(scale_x, scale_y)) 188 | scale_x, scale_y = tf.split(scales, num_or_size_splits=2) 189 | 190 | clip_max = tf.cast(CLIP_MAX, dtype=tf.float32) 191 | min_scale = tf.cast(min_scale, dtype=tf.float32) 192 | scale_x = tf.clip_by_value(scale_x, min_scale, clip_max) 193 | scale_y = tf.clip_by_value(scale_y, min_scale, clip_max) 194 | 195 | img1, img2, flow = self.random_scale(img1, img2, flow, scale_x, scale_y) 196 | 197 | # random flips 198 | img1, img2, flow = self.random_horizontal_flip(img1, img2, flow, prob=0.5) 199 | img1, img2, flow = 
self.random_vertical_flip(img1, img2, flow, prob=0.1) 200 | 201 | # clip_by_value 202 | ht, wd, _ = tf.unstack(tf.shape(img1), num=3) 203 | y0 = tf.random.uniform([], 204 | -self.margin, 205 | ht - self.crop_size[0] + self.margin, 206 | dtype=tf.int32) 207 | x0 = tf.random.uniform([], 208 | -self.margin, 209 | wd - self.crop_size[1] + self.margin, 210 | dtype=tf.int32) 211 | 212 | y0 = tf.clip_by_value(y0, 0, ht - self.crop_size[0]) 213 | x0 = tf.clip_by_value(x0, 0, wd - self.crop_size[1]) 214 | 215 | # crop 216 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]:, :] 217 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]:, :] 218 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]:, :] 219 | 220 | return img1, img2, flow 221 | 222 | def __call__(self, images, flow): 223 | images = (images + 1) / 2.0 # switch from [-1,1] to [0,1] 224 | 225 | img1, img2 = tf.unstack(images, num=2) 226 | img1, img2 = self.color_transform(img1, img2) 227 | img1, img2 = self.eraser_transform(img1, img2) 228 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 229 | images = tf.stack((img1, img2), axis=0) 230 | images = tf.ensure_shape(images, 231 | (2, self.crop_size[0], self.crop_size[1], 3)) 232 | flow = tf.ensure_shape(flow, (self.crop_size[0], self.crop_size[1], 2)) 233 | 234 | images = 2 * images - 1 # switch from [0,1] to [-1,1] 235 | 236 | return images, flow 237 | 238 | 239 | def apply(element, aug_params): 240 | crop_size = (aug_params.crop_height, aug_params.crop_width) 241 | min_scale = aug_params.min_scale 242 | max_scale = aug_params.max_scale 243 | aug = Augment(crop_size=crop_size, min_scale=min_scale, max_scale=max_scale) 244 | images, flow = aug(element['inputs'], element['label']) 245 | return {'inputs': images, 'label':flow} 246 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/simple_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import tensorflow as tf 17 | 18 | from augmentations import crop_aug 19 | 20 | 21 | def apply_crop(element, aug_params): 22 | crop_height = aug_params.crop_height 23 | crop_width = aug_params.crop_width 24 | return crop_aug.top_left_crop(element, crop_height, crop_width) 25 | 26 | 27 | def compute_upsample_flow(flow, size): 28 | upsampled_flow = tf.compat.v1.image.resize( 29 | flow, size, tf.compat.v1.image.ResizeMethod.BILINEAR, align_corners=True) 30 | upsampled_x = upsampled_flow[:, :, 0] * tf.cast( 31 | size[1], dtype=tf.float32) / tf.cast( 32 | tf.shape(flow)[1], dtype=tf.float32) 33 | upsampled_y = upsampled_flow[:, :, 1] * tf.cast( 34 | size[0], dtype=tf.float32) / tf.cast( 35 | tf.shape(flow)[0], dtype=tf.float32) 36 | return tf.stack((upsampled_x, upsampled_y), axis=-1) 37 | 38 | 39 | def apply_resize(element, aug_params): 40 | _, height, width, _ = tf.unstack(tf.shape(element['inputs'])) 41 | divisor = 64 42 | adapt_height = tf.to_int32( 43 | tf.math.ceil(height / divisor) * divisor) 44 | adapt_width = tf.to_int32(tf.math.ceil(width / divisor) * divisor) 45 | 46 | images = tf.compat.v1.image.resize( 47 | element['inputs'], 48 | [adapt_height, adapt_width], 49 | tf.compat.v1.image.ResizeMethod.BILINEAR, 50 | align_corners=True) 51 | flows = compute_upsample_flow(element['label'], (adapt_height, adapt_width)) 52 | return {'inputs': images, 'label': flows} 53 | -------------------------------------------------------------------------------- /src/dataset_lib/augmentations/spatial_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | 18 | import tensorflow as tf 19 | 20 | from tensorflow.contrib import image as contrib_image 21 | 22 | from augmentations import crop_aug 23 | from augmentations import image_aug 24 | 25 | 26 | MAX_SAMPLES = 50 # max number of rejection samples 27 | INVALID_THRESHOLD = 1e3 28 | INVALID_FLOW_VALUE = 1e9 29 | FLOW_SCALE_FACTOR = 20 30 | 31 | 32 | class SpatialAug(object): 33 | """Sample spatial agumentation parameters, borrowed from VCN's pytorch implementation of FlowNet/PWC-Net.""" 34 | 35 | def __init__(self, 36 | crop, 37 | scale=None, 38 | rot=None, 39 | trans=None, 40 | squeeze=None, 41 | schedule_coeff=1, 42 | order=1, 43 | black=False): 44 | self.crop = crop 45 | self.scale = scale 46 | self.rot = rot 47 | self.trans = trans 48 | self.squeeze = squeeze 49 | self.t_0 = tf.zeros(1) 50 | self.t_1 = tf.zeros(1) 51 | self.t_2 = tf.zeros(1) 52 | self.t_3 = tf.zeros(1) 53 | self.t_4 = tf.zeros(1) 54 | self.t_5 = tf.zeros(1) 55 | self.schedule_coeff = schedule_coeff 56 | self.order = order 57 | self.black = black 58 | 59 | def to_identity(self): 60 | """Identity transformation.""" 61 | self.t_0 = tf.constant([1], dtype=tf.float32) 62 | self.t_1 = tf.constant([0], dtype=tf.float32) 63 | self.t_2 = tf.constant([0], dtype=tf.float32) 64 | self.t_3 = tf.constant([1], dtype=tf.float32) 65 | self.t_4 = tf.constant([0], dtype=tf.float32) 66 | self.t_5 = tf.constant([0], dtype=tf.float32) 67 | 68 | def left_multiply(self, u0, u1, u2, u3, u4, u5): 69 | """Composite transformations.""" 70 | result_0 = self.t_0 * u0 + self.t_1 * u2 71 | result_1 = self.t_0 * u1 + self.t_1 * u3 72 | 73 | result_2 = self.t_2 * u0 + self.t_3 * u2 74 | result_3 = self.t_2 * u1 + self.t_3 * u3 75 | 76 | result_4 = self.t_4 * u0 + self.t_5 * u2 + u4 77 | result_5 = self.t_4 * u1 + self.t_5 * u3 + u5 78 | 79 | self.t_0 = result_0 80 | self.t_1 = result_1 81 | self.t_2 = result_2 82 | self.t_3 = result_3 83 | self.t_4 = result_4 84 | self.t_5 = result_5 85 | 86 | def inverse(self): 87 | """Compute inverse transformation.""" 88 | a = self.t_0 89 | c = self.t_2 90 | e = self.t_4 91 | b = self.t_1 92 | d = self.t_3 93 | f = self.t_5 94 | 95 | denom = a * d - b * c 96 | 97 | result_0 = d / denom 98 | result_1 = -b / denom 99 | result_2 = -c / denom 100 | result_3 = a / denom 101 | result_4 = (c * f - d * e) / denom 102 | result_5 = (b * e - a * f) / denom 103 | 104 | return tf.stack( 105 | [result_0, result_1, result_2, result_3, result_4, result_5]) 106 | 107 | def grid_transform(self, meshgrid, t, normalize=True, gridsize=None): 108 | """Transform grid according to transformation.""" 109 | if gridsize is None: 110 | h, w = meshgrid[0].shape 111 | else: 112 | h, w = gridsize 113 | 114 | vgrid = tf.stack([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4]), 115 | (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])], 2) 116 | if normalize: 117 | # normlize for pytorch style [-1, 1] 118 | vgridx = 2.0 * vgrid[:, :, 0] / tf.math.maximum(w - 1, 1) - 1.0 119 | vgridy = 2.0 * vgrid[:, :, 1] / tf.math.maximum(h - 1, 1) - 1.0 120 | return tf.stack([vgridx, vgridy], 2) 121 | else: 122 | return vgrid 123 | 124 | def __call__(self, h, w): 125 | th, tw = self.crop 126 | # meshgrid = np.meshgrid(range(th), range(tw))[::-1] 127 | cornergrid = np.meshgrid([0, th - 1], [0, tw - 1])[::-1] 128 | 129 | def cond(out, not_found, i, max_iters): 130 | del out 131 | 132 | return tf.math.logical_and(not_found, tf.less(i, max_iters)) 133 | 134 
| def body(out, not_found, i, max_iters): 135 | del not_found 136 | 137 | # Compute transformation for first image. 138 | self.to_identity() 139 | # Center. 140 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th) 141 | scale0 = 1 142 | scale1 = 1 143 | squeeze0 = 1 144 | squeeze1 = 1 145 | # Sample rotation. 146 | if self.rot is None: 147 | rot0 = 0.0 148 | rot1 = 0.0 149 | else: 150 | rot0 = tf.random.uniform([], minval=-self.rot[0], maxval=self.rot[0]) 151 | rot1 = tf.random.uniform( 152 | [], 153 | minval=-self.rot[1] * self.schedule_coeff, 154 | maxval=self.rot[1] * self.schedule_coeff) + rot0 155 | self.left_multiply( 156 | tf.math.cos(rot0), tf.math.sin(rot0), -tf.math.sin(rot0), 157 | tf.math.cos(rot0), 0, 0) 158 | 159 | # Sample scale & squeeze. 160 | if self.squeeze is None: 161 | squeeze0 = 1.0 162 | squeeze1 = 1.0 163 | else: 164 | squeeze0 = tf.math.exp( 165 | tf.random.uniform([], 166 | minval=-self.squeeze[0], 167 | maxval=self.squeeze[0])) 168 | squeeze1 = tf.math.exp( 169 | tf.random.uniform( 170 | [], 171 | minval=-self.squeeze[1] * self.schedule_coeff, 172 | maxval=self.squeeze[1] * self.schedule_coeff)) * squeeze0 173 | 174 | if self.scale is None: 175 | scale0 = 1.0 176 | scale1 = 1.0 177 | else: 178 | scale0 = tf.math.exp( 179 | tf.random.uniform([], 180 | minval=self.scale[2] - self.scale[0], 181 | maxval=self.scale[2] + self.scale[0])) 182 | scale1 = tf.math.exp( 183 | tf.random.uniform( 184 | [], 185 | minval=-self.scale[1] * self.schedule_coeff, 186 | maxval=self.scale[1] * self.schedule_coeff)) * scale0 187 | 188 | self.left_multiply(1.0 / (scale0 * squeeze0), 0, 0, 189 | 1.0 / (scale0 / squeeze0), 0, 0) 190 | 191 | # Sample translation. 192 | if self.trans is None: 193 | trans0 = [0.0, 0.0] 194 | trans1 = [0.0, 0.0] 195 | else: 196 | trans0 = tf.random.uniform([2], 197 | minval=-self.trans[0], 198 | maxval=self.trans[0]) 199 | trans1 = tf.random.uniform( 200 | [2], 201 | minval=-self.trans[1] * self.schedule_coeff, 202 | maxval=self.trans[1] * self.schedule_coeff) + trans0 203 | 204 | self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th) 205 | 206 | self.left_multiply(1, 0, 0, 1, .5 * float(w), .5 * float(h)) 207 | transmat0_0 = tf.identity(self.t_0) 208 | transmat0_1 = tf.identity(self.t_1) 209 | transmat0_2 = tf.identity(self.t_2) 210 | transmat0_3 = tf.identity(self.t_3) 211 | transmat0_4 = tf.identity(self.t_4) 212 | transmat0_5 = tf.identity(self.t_5) 213 | transmat0 = [ 214 | transmat0_0, transmat0_1, transmat0_2, transmat0_3, transmat0_4, 215 | transmat0_5 216 | ] 217 | 218 | # Compute transformation for second image. 
      self.to_identity()
      self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th)
      if self.rot is not None:
        self.left_multiply(
            tf.math.cos(rot1), tf.math.sin(rot1), -tf.math.sin(rot1),
            tf.math.cos(rot1), 0, 0)
      if self.trans is not None:
        self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
      self.left_multiply(1.0 / (scale1 * squeeze1), 0, 0,
                         1.0 / (scale1 / squeeze1), 0, 0)
      self.left_multiply(1, 0, 0, 1, .5 * float(w), .5 * float(h))
      transmat1 = [
          tf.identity(self.t_0),
          tf.identity(self.t_1),
          tf.identity(self.t_2),
          tf.identity(self.t_3),
          tf.identity(self.t_4),
          tf.identity(self.t_5),
      ]

      # Count crop corners that fall outside the source image (outside
      # [-1, 1] after normalization) under either transform.
      sum_val0 = tf.math.reduce_sum(
          tf.to_float(
              tf.math.abs(
                  self.grid_transform(
                      cornergrid, transmat0,
                      gridsize=[float(h), float(w)])) > 1))
      sum_val1 = tf.math.reduce_sum(
          tf.to_float(
              tf.math.abs(
                  self.grid_transform(
                      cornergrid, transmat1,
                      gridsize=[float(h), float(w)])) > 1))
      bool_val = tf.logical_or(
          tf.math.equal((sum_val0 + sum_val1), 0), self.black)

      # Report rotations in degrees (np.pi rather than 3.14, for an exact
      # conversion) and the per-axis stretch factors of both frames.
      out = (
          (rot0 * 180 / np.pi),
          (scale0 * squeeze0),
          (scale0 / squeeze0),
          (rot1 * 180 / np.pi),
          (scale1 * squeeze1),
          (scale1 / squeeze1),
      )

      return [out, tf.math.logical_not(bool_val), tf.add(i, 1), max_iters]

    identity_val = (tf.constant(0.), tf.constant(1.), tf.constant(1.),
                    tf.constant(0.), tf.constant(1.), tf.constant(1.))
    # Start with True so the loop body runs at least once.
    not_found = tf.ones([], dtype=tf.bool)
    ret_val, not_found, _, _ = tf.while_loop(
        cond, body, [identity_val, not_found, 0, MAX_SAMPLES])

    # Fall back to the identity transform if no valid sample was found.
    return tf.cond(not_found, lambda: identity_val, lambda: ret_val)
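

# A minimal usage sketch of SpatialAug (illustrative only): the crop size and
# parameter ranges below are hypothetical placeholders, not tuned training
# settings. It mirrors how hard_augment() below constructs the sampler.
def _example_spatial_aug_usage():
  """Draws spatial parameters for a hypothetical 384x512 image pair."""
  spa = SpatialAug(
      crop=[320, 448],          # [crop_height, crop_width].
      scale=[0.2, 0.03, 0.0],   # [range, relative range, mean], log-space.
      rot=[0.2, 0.03],          # [range, relative range], radians.
      trans=[0.2, 0.03],        # [range, relative range], fraction of crop.
      squeeze=[0.1, 0.0])       # [range, relative range], log-space.
  # Returns (rotation in degrees, two per-axis stretch factors) for each of
  # the two frames, falling back to the identity transform when no in-bounds
  # sample is found within MAX_SAMPLES draws.
  return spa(384, 512)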
def apply(images, forward_flow, augmentation_params, is_hard_augment):
  """Augments the inputs by applying random spatial transformations.

  Args:
    images: A tensor of size [2, height, width, 3] representing the two images.
    forward_flow: A tensor of size [height, width, 2] representing the flow
      from the first image to the second.
    augmentation_params: An AugmentationParams controlling the augmentations
      to be performed.
    is_hard_augment: A boolean tensor; if true, the harder augmentation
      (rotation, scale, squeeze and translation) is sampled instead of
      translation-only augmentation.

  Returns:
    (images, forward_flow) after any spatial transformations.
  """
  images = tf.convert_to_tensor(images)
  crop_height = augmentation_params.crop_height
  crop_width = augmentation_params.crop_width

  # Sample a valid spatial transform first, then re-sample the crop centers.
  _, input_image_height, input_image_width, _ = tf.unstack(
      tf.shape(images), num=4)

  # Sample rotations and stretches for each image.
  (rotation_degrees, stretch_factor_y, stretch_factor_x, rotation_degrees2,
   stretch_factor_y2, stretch_factor_x2) = (
       _rejection_sample_spatial_aug_parameters(input_image_height,
                                                input_image_width,
                                                augmentation_params,
                                                is_hard_augment))

  # Sample crop centers for each image.
  (crop_center_y, crop_center_x, crop_center_y2,
   crop_center_x2) = crop_aug.sample_cropping_centers(
       input_image_height, input_image_width, stretch_factor_y,
       stretch_factor_x, rotation_degrees, stretch_factor_y2,
       stretch_factor_x2, rotation_degrees2, augmentation_params)

  # Transform the first image.
  crop_center_y += 1
  crop_center_x += 1
  transform = crop_aug.compose_cropping_transformation(
      stretch_factor_y, stretch_factor_x, crop_center_y, crop_center_x,
      rotation_degrees, crop_height, crop_width)

  # contrib_image expects 8 projective transform parameters, row-major.
  transform = tf.reshape(transform, [-1])[:8]
  output_shape = tf.stack([crop_height, crop_width])
  aug_image = contrib_image.transform(
      images[0, :, :, :],
      transform,
      interpolation="BILINEAR",
      output_shape=output_shape)

  # Transform the flow.
  aug_flow = contrib_image.transform(
      forward_flow,
      transform,
      interpolation=augmentation_params.flow_interpolation,
      output_shape=output_shape)

  # Warp an all-ones image to detect out-of-boundary pixels.
  all_ones = tf.ones(tf.shape(forward_flow))
  aug_all_ones = contrib_image.transform(
      all_ones,
      transform,
      interpolation=augmentation_params.flow_interpolation,
      output_shape=output_shape)
  # Mark invalid pixels (extreme value or out-of-boundary).
  invalid_mask = tf.logical_or(
      tf.abs(aug_flow) > tf.to_float(INVALID_THRESHOLD), aug_all_ones < 1.0)
  invalid_flow = tf.ones(aug_flow.get_shape()) * INVALID_FLOW_VALUE

  # Transform the second image.
  crop_center_y2 += 1
  crop_center_x2 += 1
  transform2 = crop_aug.compose_cropping_transformation(
      stretch_factor_y2, stretch_factor_x2, crop_center_y2, crop_center_x2,
      rotation_degrees2, crop_height, crop_width)

  # Compute the inverse transform for augmenting the flow.
  transform2_inv = tf.linalg.inv(transform2)
  transform2_inv = tf.reshape(transform2_inv, [-1])[:8]
  transform2 = tf.reshape(transform2, [-1])[:8]
  aug_image2 = contrib_image.transform(
      images[1, :, :, :],
      transform2,
      interpolation="BILINEAR",
      output_shape=output_shape)

  # Assemble the augmented image pair.
  images = tf.stack([aug_image, aug_image2])

  # Compute the augmented optical flow.
  # Compute positions in the transformed first image.
  x, y = tf.meshgrid(tf.range(crop_width), tf.range(crop_height))
  x = tf.to_float(x)
  y = tf.to_float(y)
  # Map to coordinates of the first image.
  x0 = x * transform[0] + y * transform[1] + transform[2]
  y0 = x * transform[3] + y * transform[4] + transform[5]
  # Map to coordinates of the second image.
  x1 = x0 + aug_flow[:, :, 0] * FLOW_SCALE_FACTOR
  y1 = y0 + aug_flow[:, :, 1] * FLOW_SCALE_FACTOR
  # Map to coordinates of the augmented second image.
  x11 = x1 * transform2_inv[0] + y1 * transform2_inv[1] + transform2_inv[2]
  y11 = x1 * transform2_inv[3] + y1 * transform2_inv[4] + transform2_inv[5]
  # Compute the flow for the augmented image pair & scale for training.
  forward_flow = tf.stack([x11 - x, y11 - y], -1) / FLOW_SCALE_FACTOR
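  # In matrix form: if A1 maps crop pixels to image-1 coordinates and A2 maps
  # crop pixels to image-2 coordinates (both in contrib_image.transform's
  # output-to-input convention), then for crop pixel p the new flow satisfies
  #   s * new_flow(p) = inv(A2)(A1(p) + s * flow(A1(p))) - p,
  # with s = FLOW_SCALE_FACTOR; this is the x11/y11 chain computed above.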
  # Re-mark invalid flow.
  forward_flow = tf.where(invalid_mask, invalid_flow, forward_flow)

  # Apply a horizontal flip with 50% probability.
  should_flip = tf.less(tf.random.uniform([]), 0.5)
  # pyformat: disable
  images, forward_flow = tf.cond(
      should_flip,
      lambda: image_aug.apply_horizontal_flip(images, forward_flow),
      lambda: (images, forward_flow))
  # pyformat: enable

  image0, image1, forward_flow = image_aug.random_vertical_flip(
      images[0, :, :, :], images[1, :, :, :], forward_flow,
      augmentation_params.vflip_prob)
  images = tf.stack([image0, image1])

  # Flipping can change the sign of the (extreme-valued) invalid flow; reset
  # those pixels to the same extreme value.
  invalid_mask = tf.abs(forward_flow) > tf.to_float(INVALID_THRESHOLD)
  invalid_flow = tf.ones(forward_flow.get_shape()) * INVALID_FLOW_VALUE
  forward_flow = tf.where(invalid_mask, invalid_flow, forward_flow)

  if augmentation_params.disable_ground_truth:
    # Set the ground truth to invalid values for semi-supervised training.
    forward_flow = tf.ones(forward_flow.get_shape()) * INVALID_FLOW_VALUE

  return images, forward_flow


def _rejection_sample_spatial_aug_parameters(input_image_height,
                                             input_image_width,
                                             augmentation_params,
                                             is_hard_augment):
  """Rejection-samples rotation and scaling factors."""
  th = augmentation_params.crop_height
  tw = augmentation_params.crop_width

  def hard_augment():
    spa = SpatialAug(
        [th, tw],
        scale=[augmentation_params.scale1, 0.03, augmentation_params.scale2],
        rot=[augmentation_params.rot1, 0.03],
        trans=[augmentation_params.tran1, 0.03],
        squeeze=[augmentation_params.squeeze1, 0.],
        black=augmentation_params.black)
    return spa(input_image_height, input_image_width)

  def easy_augment():
    spa = SpatialAug([th, tw],
                     trans=[0.4, 0.03],
                     black=augmentation_params.black)
    return spa(input_image_height, input_image_width)

  (rotation_degrees, stretch_factor_x, stretch_factor_y, rotation_degrees2,
   stretch_factor_x2, stretch_factor_y2) = tf.cond(is_hard_augment,
                                                   hard_augment, easy_augment)

  # Schedule parameters for the second image: interpolate the log-scale and
  # log-squeeze between the two sampled transforms.
  s1 = tf.math.log(tf.sqrt(stretch_factor_x * stretch_factor_y + 1e-9) +
                   1e-9)  # scale
  z1 = tf.math.log(tf.sqrt(stretch_factor_x / stretch_factor_y + 1e-9) +
                   1e-9)  # squeeze

  s2 = tf.math.log(
      tf.sqrt(stretch_factor_x2 * stretch_factor_y2 + 1e-9) + 1e-9)  # scale
  z2 = tf.math.log(
      tf.sqrt(stretch_factor_x2 / stretch_factor_y2 + 1e-9) + 1e-9)  # squeeze

  s2 = (tf.to_float(s1) +
        tf.to_float(s2 - s1) * augmentation_params.schedule_coeff)
  z2 = (tf.to_float(z1) +
        tf.to_float(z2 - z1) * augmentation_params.schedule_coeff)

  stretch_factor_x2 = tf.exp(s2 + z2)
  stretch_factor_y2 = tf.exp(s2 - z2)

  return (rotation_degrees, stretch_factor_y, stretch_factor_x,
          rotation_degrees2, stretch_factor_y2, stretch_factor_x2)

--------------------------------------------------------------------------------