├── LICENSE ├── README.md ├── tcc ├── alignment.py ├── losses.py └── stochastic_alignment.py ├── tcc_tf ├── alignment.py ├── deterministic_alignment.py ├── losses.py └── stochastic_alignment.py └── test_tcc.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Cycle-Consistency Learning (https://sites.google.com/view/temporal-cycle-consistency/home)
2 | 
3 | This is a PyTorch implementation of the TCC loss from [the original repo](https://github.com/google-research/google-research/tree/3002480d94c443da491f194dfdb6358dbc2a4500/tcc). The loss was introduced in the CVPR 2019 paper Temporal Cycle-Consistency
4 | Learning (https://arxiv.org/abs/1904.07846).
5 | 
6 | ## Usage
7 | 
8 | ```test_tcc.ipynb``` gives usage examples for both the TensorFlow and PyTorch versions; please refer to it for more detail. Note that, to allow a fair comparison, the original TensorFlow version is kept in the ```tcc_tf``` folder and is credited to [the original repo](https://github.com/google-research/google-research/tree/3002480d94c443da491f194dfdb6358dbc2a4500/tcc).
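A minimal PyTorch example, mirroring what ```test_tcc.ipynb``` does (the TensorFlow version in ```tcc_tf``` exposes the same ```compute_alignment_loss``` interface):

```python
import numpy as np
import torch

from tcc.alignment import compute_alignment_loss

# Sequential embeddings of shape [batch_size, num_steps, embedding_dim].
emb_np = np.random.rand(16, 198, 256).astype(np.float32)
embs = torch.from_numpy(emb_np)

# Only the stochastic matching path with the (default) classification loss is
# implemented in this PyTorch port; see the note in test_tcc.ipynb.
loss = compute_alignment_loss(embs, batch_size=16, stochastic_matching=True)
print(loss)
```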
9 | 
10 | ## Reference
11 | 
12 | - TensorFlow repo (tcc): https://github.com/google-research/google-research.git
13 | 
14 | - paper:
15 | 
16 | ```bibtex
17 | @InProceedings{Dwibedi_2019_CVPR,
18 | author = {Dwibedi, Debidatta and Aytar, Yusuf and Tompson, Jonathan and Sermanet, Pierre and Zisserman, Andrew},
19 | title = {Temporal Cycle-Consistency Learning},
20 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
21 | month = {June},
22 | year = {2019},
23 | }
24 | ```
25 | 
--------------------------------------------------------------------------------
/tcc/alignment.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | # from .deterministic_alignment import compute_deterministic_alignment_loss
9 | from .stochastic_alignment import compute_stochastic_alignment_loss
10 | 
11 | 
12 | def compute_alignment_loss(embs,
13 |                            batch_size,
14 |                            steps=None,
15 |                            seq_lens=None,
16 |                            stochastic_matching=False,
17 |                            normalize_embeddings=False,
18 |                            loss_type='classification',
19 |                            similarity_type='l2',
20 |                            num_cycles=20,
21 |                            cycle_length=2,
22 |                            temperature=0.1,
23 |                            label_smoothing=0.1,
24 |                            variance_lambda=0.001,
25 |                            huber_delta=0.1,
26 |                            normalize_indices=True):
27 | 
28 |   # Get the number of timesteps in the sequence embeddings.
29 |   num_steps = embs.size(1)
30 |   # print(num_steps)
31 | 
32 |   # If steps has not been provided assume sampling has been done uniformly.
33 |   if steps is None:
34 |     steps = torch.arange(0, num_steps).unsqueeze(0).repeat([batch_size, 1])
35 | 
36 |   # print(steps.size())
37 | 
38 |   # If seq_lens has not been provided assume it is equal to the size of the
39 |   # time axis in the embeddings.
40 |   if seq_lens is None:
41 |     seq_lens = torch.tensor(num_steps).unsqueeze(0).repeat([batch_size]).int()
42 | 
43 |   # print(seq_lens)
44 | 
45 |   # Check that batch_size is consistent with embs, steps, etc.
46 |   assert batch_size == embs.size(0)
47 |   assert num_steps == steps.size(1)
48 |   assert batch_size == steps.size(0)
49 | 
50 |   if normalize_embeddings:
51 |     embs = F.normalize(embs, dim=-1, p=2)
52 | 
53 |   if stochastic_matching:
54 |     loss = compute_stochastic_alignment_loss(
55 |         embs=embs,
56 |         steps=steps,
57 |         seq_lens=seq_lens,
58 |         num_steps=num_steps,
59 |         batch_size=batch_size,
60 |         loss_type=loss_type,
61 |         similarity_type=similarity_type,
62 |         num_cycles=num_cycles,
63 |         cycle_length=cycle_length,
64 |         temperature=temperature,
65 |         label_smoothing=label_smoothing,
66 |         variance_lambda=variance_lambda,
67 |         huber_delta=huber_delta,
68 |         normalize_indices=normalize_indices)
69 |   else:
70 |     raise NotImplementedError
71 |     # loss = compute_deterministic_alignment_loss(
72 |     #     embs=embs,
73 |     #     steps=steps,
74 |     #     seq_lens=seq_lens,
75 |     #     num_steps=num_steps,
76 |     #     batch_size=batch_size,
77 |     #     loss_type=loss_type,
78 |     #     similarity_type=similarity_type,
79 |     #     temperature=temperature,
80 |     #     label_smoothing=label_smoothing,
81 |     #     variance_lambda=variance_lambda,
82 |     #     huber_delta=huber_delta,
83 |     #     normalize_indices=normalize_indices)
84 | 
85 |   return loss
86 | 
87 | 
--------------------------------------------------------------------------------
/tcc/losses.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Loss functions imposing the cycle-consistency constraints."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import torch
23 | import torch.nn.functional as F
24 | 
25 | 
26 | def classification_loss(logits, labels, label_smoothing):
27 |   """Loss function based on classifying the correct indices.
28 |   In the paper, this is called Cycle-back Classification.
29 |   Args:
30 |     logits: Tensor, Pre-softmax scores used for classification loss. These are
31 |       similarity scores after cycling back to the starting sequence.
32 |     labels: Tensor, One hot labels containing the ground truth. The index where
33 |       the cycle started is 1.
34 |     label_smoothing: Float, label smoothing factor which can be used to
35 |       determine how hard the alignment should be.
36 |   Returns:
37 |     loss: Tensor, A scalar classification loss calculated using standard softmax
38 |       cross-entropy loss.
39 |   """
40 |   # Stop gradients from the generated labels and apply label smoothing, matching the TF version.
41 |   labels = labels.detach() * (1.0 - label_smoothing) + label_smoothing / labels.size(-1)
42 |   return -torch.mean(torch.sum(labels * F.log_softmax(logits, dim=1), dim=1), dim=0)
43 | 
44 | 
45 | def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type,
46 |                     normalize_indices, variance_lambda, huber_delta):
47 |   """Loss function based on regressing to the correct indices.
48 |   In the paper, this is called Cycle-back Regression. There are 3 variants
49 |   of this loss:
50 |   i) regression_mse: MSE of the predicted indices and ground truth indices.
51 |   ii) regression_mse_var: MSE of the predicted indices that takes into account
52 |   the variance of the similarities. This is important when the rate at which
53 |   sequences go through different phases changes a lot. The variance scaling
54 |   allows dynamic weighting of the MSE loss based on the similarities.
55 |   iii) regression_huber: Huber loss between the predicted indices and ground
56 |   truth indices.
57 |   Args:
58 |     logits: Tensor, Pre-softmax similarity scores after cycling back to the
59 |       starting sequence.
60 |     labels: Tensor, One hot labels containing the ground truth. The index where
61 |       the cycle started is 1.
62 |     num_steps: Integer, Number of steps in the sequence embeddings.
63 |     steps: Tensor, step indices/frame indices of the embeddings of the shape
64 |       [N, T] where N is the batch size, T is the number of the timesteps.
65 |     seq_lens: Tensor, Lengths of the sequences from which the sampling was done.
66 |       This can provide additional temporal information to the alignment loss.
67 |     loss_type: String, This specifies the kind of regression loss function.
68 |       Currently supported loss functions: regression_mse, regression_mse_var,
69 |       regression_huber.
70 |     normalize_indices: Boolean, If True, normalizes indices by sequence lengths.
71 |       Useful for ensuring numerical instabilities don't arise as sequence
72 |       indices can be large numbers.
73 |     variance_lambda: Float, Weight of the variance of the similarity
74 |       predictions while cycling back. If this is high then the low variance
75 |       similarities are preferred by the loss while making this term low results
76 |       in high variance of the similarities (more uniform/random matching).
77 |     huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.
78 |   Returns:
79 |     loss: Tensor, A scalar loss calculated using a variant of regression.
80 |   """
81 |   # Only the classification loss has been ported to PyTorch; see
82 |   # tcc_tf/losses.py for the reference TensorFlow regression implementation.
83 |   raise NotImplementedError('regression_loss (%s) is not implemented in the PyTorch port.' % loss_type)
84 | 
--------------------------------------------------------------------------------
/tcc/stochastic_alignment.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Stochastic alignment between sampled cycles in the sequences in a batch."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import torch
23 | import torch.nn.functional as F
24 | 
25 | from .losses import classification_loss
26 | from .losses import regression_loss
27 | 
28 | def _align_single_cycle(cycle, embs, cycle_length, num_steps,
29 |                         similarity_type, temperature):
30 |   # Choose a random starting frame in the first sequence of the cycle.
31 |   n_idx = (torch.rand(1)*num_steps).long()[0]
32 |   # n_idx = torch.tensor(8).long()
33 | 
34 |   # Create one-hot labels marking the chosen frame.
35 |   onehot_labels = torch.eye(num_steps)[n_idx]
36 | 
37 |   # Choose query feats for first frame.
38 | query_feats = embs[cycle[0], n_idx:n_idx + 1] 39 | num_channels = query_feats.size(-1) 40 | for c in range(1, cycle_length + 1): 41 | candidate_feats = embs[cycle[c]] 42 | if similarity_type == 'l2': 43 | mean_squared_distance = torch.sum((query_feats.repeat([num_steps, 1]) - 44 | candidate_feats) ** 2, dim=1) 45 | similarity = -mean_squared_distance 46 | elif similarity_type == 'cosine': 47 | similarity = torch.squeeze(torch.matmul(candidate_feats, query_feats.transpose(0, 1))) 48 | else: 49 | raise ValueError('similarity_type can either be l2 or cosine.') 50 | 51 | similarity /= float(num_channels) 52 | similarity /= temperature 53 | 54 | beta = F.softmax(similarity, dim=0).unsqueeze(1).repeat([1, num_channels]) 55 | query_feats = torch.sum(beta * candidate_feats, dim=0, keepdim=True) 56 | 57 | return similarity.unsqueeze(0), onehot_labels.unsqueeze(0) 58 | 59 | def _align(cycles, embs, num_steps, num_cycles, cycle_length, 60 | similarity_type, temperature): 61 | """Align by finding cycles in embs.""" 62 | logits_list = [] 63 | labels_list = [] 64 | for i in range(num_cycles): 65 | logits, labels = _align_single_cycle(cycles[i], 66 | embs, 67 | cycle_length, 68 | num_steps, 69 | similarity_type, 70 | temperature) 71 | logits_list.append(logits) 72 | labels_list.append(labels) 73 | 74 | logits = torch.cat(logits_list, dim=0) 75 | labels = torch.cat(labels_list, dim=0) 76 | 77 | return logits, labels 78 | 79 | def gen_cycles(num_cycles, batch_size, cycle_length=2): 80 | """Generates cycles for alignment. 81 | Generates a batch of indices to cycle over. For example setting num_cycles=2, 82 | batch_size=5, cycle_length=3 might return something like this: 83 | cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which 84 | the loss will be calculated. The first cycle starts at sequence 0 of the 85 | batch, then we find a matching step in sequence 3 of that batch, then we 86 | find matching step in sequence 4 and finally come back to sequence 0, 87 | completing a cycle. 88 | Args: 89 | num_cycles: Integer, Number of cycles that will be matched in one pass. 90 | batch_size: Integer, Number of sequences in one batch. 91 | cycle_length: Integer, Length of the cycles. If we are matching between 92 | 2 sequences (cycle_length=2), we get cycles that look like [0,1,0]. 93 | This means that we go from sequence 0 to sequence 1 then back to sequence 94 | 0. A cycle length of 3 might look like [0, 1, 2, 0]. 95 | Returns: 96 | cycles: Tensor, Batch indices denoting cycles that will be used for 97 | calculating the alignment loss. 
98 | """ 99 | sorted_idxes = torch.arange(batch_size).unsqueeze(0).repeat([num_cycles, 1]) 100 | sorted_idxes = sorted_idxes.view([batch_size, num_cycles]) 101 | cycles = sorted_idxes[torch.randperm(len(sorted_idxes))].view([num_cycles, batch_size]) 102 | cycles = cycles[:, :cycle_length] 103 | cycles = torch.cat([cycles, cycles[:, 0:1]], dim=1) 104 | 105 | return cycles 106 | 107 | 108 | def compute_stochastic_alignment_loss(embs, 109 | steps, 110 | seq_lens, 111 | num_steps, 112 | batch_size, 113 | loss_type, 114 | similarity_type, 115 | num_cycles, 116 | cycle_length, 117 | temperature, 118 | label_smoothing, 119 | variance_lambda, 120 | huber_delta, 121 | normalize_indices): 122 | 123 | cycles = gen_cycles(num_cycles, batch_size, cycle_length) 124 | logits, labels = _align(cycles, embs, num_steps, num_cycles, cycle_length, 125 | similarity_type, temperature) 126 | 127 | if loss_type == 'classification': 128 | loss = classification_loss(logits, labels, label_smoothing) 129 | # elif 'regression' in loss_type: 130 | # steps = tf.gather(steps, cycles[:, 0]) 131 | # seq_lens = tf.gather(seq_lens, cycles[:, 0]) 132 | # loss = regression_loss(logits, labels, num_steps, steps, seq_lens, 133 | # loss_type, normalize_indices, variance_lambda, 134 | # huber_delta) 135 | else: 136 | raise ValueError('Unidentified loss type %s. Currently supported loss ' 137 | 'types are: regression_mse, regression_huber, ' 138 | 'classification .' 139 | % loss_type) 140 | return loss 141 | 142 | 143 | -------------------------------------------------------------------------------- /tcc_tf/alignment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Variants of the cycle-consistency loss described in TCC paper. 17 | 18 | The Temporal Cycle-Consistency (TCC) Learning paper 19 | (https://arxiv.org/pdf/1904.07846.pdf) describes a loss that enables learning 20 | of self-supervised representations from sequences of embeddings that are good 21 | at temporally fine-grained tasks like phase classification, video alignment etc. 22 | 23 | These losses impose cycle-consistency constraints between sequences of 24 | embeddings. Another interpretation of the cycle-consistency constraints is 25 | that of mutual nearest-nieghbors. This means if state A in sequence 1 is the 26 | nearest neighbor of state B in sequence 2 then it must also follow that B is the 27 | nearest neighbor of A. We found that imposing this constraint on a dataset of 28 | related sequences (like videos of people pitching a baseball) allows us to learn 29 | generally useful visual representations. 30 | 31 | This code allows the user to apply the loss while giving them the freedom to 32 | choose the right encoder for their dataset/task. 
One advice for choosing an 33 | encoder is to ensure that the encoder does not solve the mutual neighbor finding 34 | task in a trivial fashion. For example, if one uses an LSTM or Transformer with 35 | positional encodings, the matching between sequences may be done trivially by 36 | counting the frame index with the encoder rather than learning good features. 37 | """ 38 | 39 | from __future__ import absolute_import 40 | from __future__ import division 41 | from __future__ import print_function 42 | 43 | import tensorflow.compat.v2 as tf 44 | 45 | from .deterministic_alignment import compute_deterministic_alignment_loss 46 | from .stochastic_alignment import compute_stochastic_alignment_loss 47 | 48 | 49 | def compute_alignment_loss(embs, 50 | batch_size, 51 | steps=None, 52 | seq_lens=None, 53 | stochastic_matching=False, 54 | normalize_embeddings=False, 55 | loss_type='classification', 56 | similarity_type='l2', 57 | num_cycles=20, 58 | cycle_length=2, 59 | temperature=0.1, 60 | label_smoothing=0.1, 61 | variance_lambda=0.001, 62 | huber_delta=0.1, 63 | normalize_indices=True): 64 | """Computes alignment loss between sequences of embeddings. 65 | 66 | This function is a wrapper around different variants of the alignment loss 67 | described deterministic_alignment.py and stochastic_alignment.py files. The 68 | structure of the library is as follows: 69 | i) loss_fns.py - Defines the different loss functions. 70 | ii) deterministic_alignment.py - Performs the alignment between sequences by 71 | deterministically sampling all steps of the sequences. 72 | iii) stochastic_alignment.py - Performs the alignment between sequences by 73 | stochasticallty sub-sampling a fixed number of steps from the sequences. 74 | 75 | There are four major hparams that need to be tuned while applying the loss: 76 | i) Should the loss be applied with L2 normalization on the embeddings or 77 | without it? 78 | ii) Should we perform stochastic alignment of sequences? This means should we 79 | use all the steps of the embedding or only choose a random subset for 80 | alignment? 81 | iii) Should we apply cycle-consistency constraints using a classification loss 82 | or a regression loss? (Section 3 in paper) 83 | iv) Should the similarity metric be based on an L2 distance or cosine 84 | similarity? 85 | 86 | Other hparams that can be used to control how hard/soft we want the alignment 87 | between different sequences to be: 88 | i) temperature (all losses) 89 | ii) label_smoothing (classification) 90 | iii) variance_lambda (regression_mse_var) 91 | iv) huber_delta (regression_huber) 92 | Each of these params are used in their respective loss types (in brackets) and 93 | allow the application of the cycle-consistency constraints in a controllable 94 | manner but they do so in very different ways. Please refer to paper for more 95 | details. 96 | 97 | The default hparams work well for frame embeddings of videos of humans 98 | performing actions. Other datasets might need different values of hparams. 99 | 100 | 101 | Args: 102 | embs: Tensor, sequential embeddings of the shape [N, T, D] where N is the 103 | batch size, T is the number of timesteps in the sequence, D is the size of 104 | the embeddings. 105 | batch_size: Integer, Size of the batch. 106 | steps: Tensor, step indices/frame indices of the embeddings of the shape 107 | [N, T] where N is the batch size, T is the number of the timesteps. 
108 | If this is set to None, then we assume that the sampling was done in a 109 | uniform way and use tf.range(num_steps) as the steps. 110 | seq_lens: Tensor, Lengths of the sequences from which the sampling was done. 111 | This can provide additional information to the alignment loss. This is 112 | different from num_steps which is just the number of steps that have been 113 | sampled from the entire sequence. 114 | stochastic_matching: Boolean, Should the used for matching be sampled 115 | stochastically or deterministically? Deterministic is better for TPU. 116 | Stochastic is better for adding more randomness to the training process 117 | and handling long sequences. 118 | normalize_embeddings: Boolean, Should the embeddings be normalized or not? 119 | Default is to use raw embeddings. Be careful if you are normalizing the 120 | embeddings before calling this function. 121 | loss_type: String, This specifies the kind of loss function to use. 122 | Currently supported loss functions: classification, regression_mse, 123 | regression_mse_var, regression_huber. 124 | similarity_type: String, Currently supported similarity metrics: l2, cosine. 125 | num_cycles: Integer, number of cycles to match while aligning 126 | stochastically. Only used in the stochastic version. 127 | cycle_length: Integer, Lengths of the cycle to use for matching. Only used 128 | in the stochastic version. By default, this is set to 2. 129 | temperature: Float, temperature scaling used to scale the similarity 130 | distributions calculated using the softmax function. 131 | label_smoothing: Float, Label smoothing argument used in 132 | tf.keras.losses.categorical_crossentropy function and described in this 133 | paper https://arxiv.org/pdf/1701.06548.pdf. 134 | variance_lambda: Float, Weight of the variance of the similarity 135 | predictions while cycling back. If this is high then the low variance 136 | similarities are preferred by the loss while making this term low results 137 | in high variance of the similarities (more uniform/random matching). 138 | huber_delta: float, Huber delta described in tf.keras.losses.huber_loss. 139 | normalize_indices: Boolean, If True, normalizes indices by sequence lengths. 140 | Useful for ensuring numerical instabilities doesn't arise as sequence 141 | indices can be large numbers. 142 | 143 | Returns: 144 | loss: Tensor, Scalar loss tensor that imposes the chosen variant of the 145 | cycle-consistency loss. 146 | """ 147 | 148 | ############################################################################## 149 | # Checking inputs and setting defaults. 150 | ############################################################################## 151 | 152 | # Get the number of timestemps in the sequence embeddings. 153 | num_steps = tf.shape(embs)[1] 154 | 155 | # If steps has not been provided assume sampling has been done uniformly. 156 | if steps is None: 157 | steps = tf.tile(tf.expand_dims(tf.range(num_steps), axis=0), 158 | [batch_size, 1]) 159 | 160 | # If seq_lens has not been provided assume is equal to the size of the 161 | # time axis in the emebeddings. 162 | if seq_lens is None: 163 | seq_lens = tf.tile(tf.expand_dims(num_steps, 0), [batch_size]) 164 | 165 | if not tf.executing_eagerly(): 166 | # Check if batch size embs is consistent with provided batch size. 167 | with tf.control_dependencies([tf.assert_equal(batch_size, 168 | tf.shape(embs)[0])]): 169 | embs = tf.identity(embs) 170 | # Check if number of timesteps in embs is consistent with provided steps. 
171 | with tf.control_dependencies([tf.assert_equal(num_steps, 172 | tf.shape(steps)[1]), 173 | tf.assert_equal(batch_size, 174 | tf.shape(steps)[0])]): 175 | steps = tf.identity(steps) 176 | else: 177 | tf.assert_equal(batch_size, tf.shape(steps)[0]) 178 | tf.assert_equal(num_steps, tf.shape(steps)[1]) 179 | tf.assert_equal(batch_size, tf.shape(embs)[0]) 180 | 181 | ############################################################################## 182 | # Perform alignment and return loss. 183 | ############################################################################## 184 | 185 | if normalize_embeddings: 186 | embs = tf.nn.l2_normalize(embs, axis=-1) 187 | 188 | if stochastic_matching: 189 | loss = compute_stochastic_alignment_loss( 190 | embs=embs, 191 | steps=steps, 192 | seq_lens=seq_lens, 193 | num_steps=num_steps, 194 | batch_size=batch_size, 195 | loss_type=loss_type, 196 | similarity_type=similarity_type, 197 | num_cycles=num_cycles, 198 | cycle_length=cycle_length, 199 | temperature=temperature, 200 | label_smoothing=label_smoothing, 201 | variance_lambda=variance_lambda, 202 | huber_delta=huber_delta, 203 | normalize_indices=normalize_indices) 204 | else: 205 | loss = compute_deterministic_alignment_loss( 206 | embs=embs, 207 | steps=steps, 208 | seq_lens=seq_lens, 209 | num_steps=num_steps, 210 | batch_size=batch_size, 211 | loss_type=loss_type, 212 | similarity_type=similarity_type, 213 | temperature=temperature, 214 | label_smoothing=label_smoothing, 215 | variance_lambda=variance_lambda, 216 | huber_delta=huber_delta, 217 | normalize_indices=normalize_indices) 218 | 219 | return loss -------------------------------------------------------------------------------- /tcc_tf/deterministic_alignment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Deterministic alignment between all pairs of sequences in a batch.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow.compat.v2 as tf 23 | 24 | from .losses import classification_loss 25 | from .losses import regression_loss 26 | 27 | 28 | def pairwise_l2_distance(embs1, embs2): 29 | """Computes pairwise distances between all rows of embs1 and embs2.""" 30 | norm1 = tf.reduce_sum(tf.square(embs1), 1) 31 | norm1 = tf.reshape(norm1, [-1, 1]) 32 | norm2 = tf.reduce_sum(tf.square(embs2), 1) 33 | norm2 = tf.reshape(norm2, [1, -1]) 34 | 35 | # Max to ensure matmul doesn't produce anything negative due to floating 36 | # point approximations. 37 | dist = tf.maximum( 38 | norm1 + norm2 - 2.0 * tf.matmul(embs1, embs2, False, True), 0.0) 39 | 40 | return dist 41 | 42 | 43 | def get_scaled_similarity(embs1, embs2, similarity_type, temperature): 44 | """Returns similarity between each all rows of embs1 and all rows of embs2. 
45 | The similarity is scaled by the number of channels/embedding size and 46 | temperature. 47 | Args: 48 | embs1: Tensor, Embeddings of the shape [M, D] where M is the number of 49 | embeddings and D is the embedding size. 50 | embs2: Tensor, Embeddings of the shape [N, D] where N is the number of 51 | embeddings and D is the embedding size. 52 | similarity_type: String, Either one of 'l2' or 'cosine'. 53 | temperature: Float, Temperature used in scaling logits before softmax. 54 | Returns: 55 | similarity: Tensor, [M, N] tensor denoting similarity between embs1 and 56 | embs2. 57 | """ 58 | channels = tf.cast(tf.shape(embs1)[1], tf.float32) 59 | # Go for embs1 to embs2. 60 | if similarity_type == 'cosine': 61 | similarity = tf.matmul(embs1, embs2, transpose_b=True) 62 | elif similarity_type == 'l2': 63 | similarity = -1.0 * pairwise_l2_distance(embs1, embs2) 64 | else: 65 | raise ValueError('similarity_type can either be l2 or cosine.') 66 | 67 | # Scale the distance by number of channels. This normalization helps with 68 | # optimization. 69 | similarity /= channels 70 | # Scale the distance by a temperature that helps with how soft/hard the 71 | # alignment should be. 72 | similarity /= temperature 73 | 74 | return similarity 75 | 76 | 77 | def align_pair_of_sequences(embs1, 78 | embs2, 79 | similarity_type, 80 | temperature): 81 | """Align a given pair embedding sequences. 82 | Args: 83 | embs1: Tensor, Embeddings of the shape [M, D] where M is the number of 84 | embeddings and D is the embedding size. 85 | embs2: Tensor, Embeddings of the shape [N, D] where N is the number of 86 | embeddings and D is the embedding size. 87 | similarity_type: String, Either one of 'l2' or 'cosine'. 88 | temperature: Float, Temperature used in scaling logits before softmax. 89 | Returns: 90 | logits: Tensor, Pre-softmax similarity scores after cycling back to the 91 | starting sequence. 92 | labels: Tensor, One hot labels containing the ground truth. The index where 93 | the cycle started is 1. 94 | """ 95 | max_num_steps = tf.shape(embs1)[0] 96 | 97 | # Find distances between embs1 and embs2. 98 | sim_12 = get_scaled_similarity(embs1, embs2, similarity_type, temperature) 99 | # Softmax the distance. 100 | softmaxed_sim_12 = tf.nn.softmax(sim_12, axis=1) 101 | 102 | # Calculate soft-nearest neighbors. 103 | nn_embs = tf.matmul(softmaxed_sim_12, embs2) 104 | 105 | # Find distances between nn_embs and embs1. 106 | sim_21 = get_scaled_similarity(nn_embs, embs1, similarity_type, temperature) 107 | 108 | logits = sim_21 109 | labels = tf.one_hot(tf.range(max_num_steps), max_num_steps) 110 | 111 | return logits, labels 112 | 113 | 114 | def compute_deterministic_alignment_loss(embs, 115 | steps, 116 | seq_lens, 117 | num_steps, 118 | batch_size, 119 | loss_type, 120 | similarity_type, 121 | temperature, 122 | label_smoothing, 123 | variance_lambda, 124 | huber_delta, 125 | normalize_indices): 126 | """Compute cycle-consistency loss for all steps in each sequence. 127 | This aligns each pair of videos in the batch except with itself. 128 | When aligning it also matters which video is the starting video. So for N 129 | videos in the batch, we have N * (N-1) alignments happening. 130 | For example, a batch of size 3 has 6 pairs of sequence alignments. 131 | Args: 132 | embs: Tensor, sequential embeddings of the shape [N, T, D] where N is the 133 | batch size, T is the number of timesteps in the sequence, D is the size 134 | of the embeddings. 
135 | steps: Tensor, step indices/frame indices of the embeddings of the shape 136 | [N, T] where N is the batch size, T is the number of the timesteps. 137 | seq_lens: Tensor, Lengths of the sequences from which the sampling was 138 | done. This can provide additional information to the alignment loss. 139 | num_steps: Integer/Tensor, Number of timesteps in the embeddings. 140 | batch_size: Integer, Size of the batch. 141 | loss_type: String, This specifies the kind of loss function to use. 142 | Currently supported loss functions: 'classification', 'regression_mse', 143 | 'regression_mse_var', 'regression_huber'. 144 | similarity_type: String, Currently supported similarity metrics: 'l2' , 145 | 'cosine' . 146 | temperature: Float, temperature scaling used to scale the similarity 147 | distributions calculated using the softmax function. 148 | label_smoothing: Float, Label smoothing argument used in 149 | tf.keras.losses.categorical_crossentropy function and described in this 150 | paper https://arxiv.org/pdf/1701.06548.pdf. 151 | variance_lambda: Float, Weight of the variance of the similarity 152 | predictions while cycling back. If this is high then the low variance 153 | similarities are preferred by the loss while making this term low 154 | results in high variance of the similarities (more uniform/random 155 | matching). 156 | huber_delta: float, Huber delta described in tf.keras.losses.huber_loss. 157 | normalize_indices: Boolean, If True, normalizes indices by sequence 158 | lengths. Useful for ensuring numerical instabilities doesn't arise as 159 | sequence indices can be large numbers. 160 | Returns: 161 | loss: Tensor, Scalar loss tensor that imposes the chosen variant of the 162 | cycle-consistency loss. 163 | """ 164 | labels_list = [] 165 | logits_list = [] 166 | steps_list = [] 167 | seq_lens_list = [] 168 | 169 | for i in range(batch_size): 170 | for j in range(batch_size): 171 | # We do not align the sequence with itself. 172 | if i != j: 173 | logits, labels = align_pair_of_sequences(embs[i], 174 | embs[j], 175 | similarity_type, 176 | temperature) 177 | logits_list.append(logits) 178 | labels_list.append(labels) 179 | steps_list.append(tf.tile(steps[i:i+1], [num_steps, 1])) 180 | seq_lens_list.append(tf.tile(seq_lens[i:i+1], [num_steps])) 181 | 182 | logits = tf.concat(logits_list, axis=0) 183 | labels = tf.concat(labels_list, axis=0) 184 | steps = tf.concat(steps_list, axis=0) 185 | seq_lens = tf.concat(seq_lens_list, axis=0) 186 | 187 | if loss_type == 'classification': 188 | loss = classification_loss(logits, labels, label_smoothing) 189 | elif 'regression' in loss_type: 190 | 191 | loss = regression_loss(logits, labels, num_steps, steps, seq_lens, 192 | loss_type, normalize_indices, variance_lambda, 193 | huber_delta) 194 | else: 195 | raise ValueError('Unidentified loss_type %s. Currently supported loss ' 196 | 'types are: regression_mse, regression_huber, ' 197 | 'classification.' % loss_type) 198 | 199 | return loss 200 | -------------------------------------------------------------------------------- /tcc_tf/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Loss functions imposing the cycle-consistency constraints.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | def classification_loss(logits, labels, label_smoothing): 26 | """Loss function based on classifying the correct indices. 27 | In the paper, this is called Cycle-back Classification. 28 | Args: 29 | logits: Tensor, Pre-softmax scores used for classification loss. These are 30 | similarity scores after cycling back to the starting sequence. 31 | labels: Tensor, One hot labels containing the ground truth. The index where 32 | the cycle started is 1. 33 | label_smoothing: Float, label smoothing factor which can be used to 34 | determine how hard the alignment should be. 35 | Returns: 36 | loss: Tensor, A scalar classification loss calculated using standard softmax 37 | cross-entropy loss. 38 | """ 39 | # Just to be safe, we stop gradients from labels as we are generating labels. 40 | labels = tf.stop_gradient(labels) 41 | return tf.reduce_mean(tf.keras.losses.categorical_crossentropy( 42 | y_true=labels, y_pred=logits, from_logits=True, 43 | label_smoothing=label_smoothing)) 44 | 45 | 46 | def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type, 47 | normalize_indices, variance_lambda, huber_delta): 48 | """Loss function based on regressing to the correct indices. 49 | In the paper, this is called Cycle-back Regression. There are 3 variants 50 | of this loss: 51 | i) regression_mse: MSE of the predicted indices and ground truth indices. 52 | ii) regression_mse_var: MSE of the predicted indices that takes into account 53 | the variance of the similarities. This is important when the rate at which 54 | sequences go through different phases changes a lot. The variance scaling 55 | allows dynamic weighting of the MSE loss based on the similarities. 56 | iii) regression_huber: Huber loss between the predicted indices and ground 57 | truth indices. 58 | Args: 59 | logits: Tensor, Pre-softmax similarity scores after cycling back to the 60 | starting sequence. 61 | labels: Tensor, One hot labels containing the ground truth. The index where 62 | the cycle started is 1. 63 | num_steps: Integer, Number of steps in the sequence embeddings. 64 | steps: Tensor, step indices/frame indices of the embeddings of the shape 65 | [N, T] where N is the batch size, T is the number of the timesteps. 66 | seq_lens: Tensor, Lengths of the sequences from which the sampling was done. 67 | This can provide additional temporal information to the alignment loss. 68 | loss_type: String, This specifies the kind of regression loss function. 69 | Currently supported loss functions: regression_mse, regression_mse_var, 70 | regression_huber. 71 | normalize_indices: Boolean, If True, normalizes indices by sequence lengths. 72 | Useful for ensuring numerical instabilities don't arise as sequence 73 | indices can be large numbers. 
74 | variance_lambda: Float, Weight of the variance of the similarity 75 | predictions while cycling back. If this is high then the low variance 76 | similarities are preferred by the loss while making this term low results 77 | in high variance of the similarities (more uniform/random matching). 78 | huber_delta: float, Huber delta described in tf.keras.losses.huber_loss. 79 | Returns: 80 | loss: Tensor, A scalar loss calculated using a variant of regression. 81 | """ 82 | # Just to be safe, we stop gradients from labels as we are generating labels. 83 | labels = tf.stop_gradient(labels) 84 | steps = tf.stop_gradient(steps) 85 | 86 | if normalize_indices: 87 | float_seq_lens = tf.cast(seq_lens, tf.float32) 88 | tile_seq_lens = tf.tile( 89 | tf.expand_dims(float_seq_lens, axis=1), [1, num_steps]) 90 | steps = tf.cast(steps, tf.float32) / tile_seq_lens 91 | else: 92 | steps = tf.cast(steps, tf.float32) 93 | 94 | beta = tf.nn.softmax(logits) 95 | true_time = tf.reduce_sum(steps * labels, axis=1) 96 | pred_time = tf.reduce_sum(steps * beta, axis=1) 97 | 98 | if loss_type in ['regression_mse', 'regression_mse_var']: 99 | if 'var' in loss_type: 100 | # Variance aware regression. 101 | pred_time_tiled = tf.tile(tf.expand_dims(pred_time, axis=1), 102 | [1, num_steps]) 103 | 104 | pred_time_variance = tf.reduce_sum( 105 | tf.square(steps - pred_time_tiled) * beta, axis=1) 106 | 107 | # Using log of variance as it is numerically stabler. 108 | pred_time_log_var = tf.math.log(pred_time_variance) 109 | squared_error = tf.square(true_time - pred_time) 110 | return tf.reduce_mean(tf.math.exp(-pred_time_log_var) * squared_error 111 | + variance_lambda * pred_time_log_var) 112 | 113 | else: 114 | return tf.reduce_mean( 115 | tf.keras.losses.mean_squared_error(y_true=true_time, 116 | y_pred=pred_time)) 117 | elif loss_type == 'regression_huber': 118 | return tf.reduce_mean(tf.keras.losses.huber_loss( 119 | y_true=true_time, y_pred=pred_time, 120 | delta=huber_delta)) 121 | else: 122 | raise ValueError('Unsupported regression loss %s. Supported losses are: ' 123 | 'regression_mse, regresstion_mse_var and regression_huber.' 124 | % loss_type) 125 | -------------------------------------------------------------------------------- /tcc_tf/stochastic_alignment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Stochastic alignment between sampled cycles in the sequences in a batch.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow.compat.v2 as tf 23 | 24 | from .losses import classification_loss 25 | from .losses import regression_loss 26 | 27 | 28 | def _align_single_cycle(cycle, embs, cycle_length, num_steps, 29 | similarity_type, temperature): 30 | """Takes a single cycle and returns logits (simialrity scores) and labels.""" 31 | # Choose random frame. 32 | n_idx = tf.random.uniform((), minval=0, maxval=num_steps, dtype=tf.int32) 33 | # n_idx = 8 34 | # Create labels 35 | onehot_labels = tf.one_hot(n_idx, num_steps) 36 | 37 | # Choose query feats for first frame. 38 | query_feats = embs[cycle[0], n_idx:n_idx+1] 39 | 40 | num_channels = tf.shape(query_feats)[-1] 41 | for c in range(1, cycle_length+1): 42 | candidate_feats = embs[cycle[c]] 43 | 44 | if similarity_type == 'l2': 45 | # Find L2 distance. 46 | mean_squared_distance = tf.reduce_sum( 47 | tf.math.squared_difference(tf.tile(query_feats, [num_steps, 1]), 48 | candidate_feats), axis=1) 49 | # Convert L2 distance to similarity. 50 | similarity = -mean_squared_distance 51 | 52 | elif similarity_type == 'cosine': 53 | # Dot product of embeddings. 54 | similarity = tf.squeeze(tf.matmul(candidate_feats, query_feats, 55 | transpose_b=True)) 56 | else: 57 | raise ValueError('similarity_type can either be l2 or cosine.') 58 | 59 | # Scale the distance by number of channels. This normalization helps with 60 | # optimization. 61 | similarity /= tf.cast(num_channels, tf.float32) 62 | # Scale the distance by a temperature that helps with how soft/hard the 63 | # alignment should be. 64 | similarity /= temperature 65 | 66 | beta = tf.nn.softmax(similarity) 67 | beta = tf.expand_dims(beta, axis=1) 68 | beta = tf.tile(beta, [1, num_channels]) 69 | 70 | # Find weighted nearest neighbour. 71 | query_feats = tf.reduce_sum(beta * candidate_feats, 72 | axis=0, keepdims=True) 73 | 74 | return similarity, onehot_labels 75 | 76 | 77 | def _align(cycles, embs, num_steps, num_cycles, cycle_length, 78 | similarity_type, temperature): 79 | """Align by finding cycles in embs.""" 80 | logits_list = [] 81 | labels_list = [] 82 | for i in range(num_cycles): 83 | logits, labels = _align_single_cycle(cycles[i], 84 | embs, 85 | cycle_length, 86 | num_steps, 87 | similarity_type, 88 | temperature) 89 | logits_list.append(logits) 90 | labels_list.append(labels) 91 | 92 | logits = tf.stack(logits_list) 93 | labels = tf.stack(labels_list) 94 | 95 | return logits, labels 96 | 97 | 98 | def gen_cycles(num_cycles, batch_size, cycle_length=2): 99 | """Generates cycles for alignment. 100 | Generates a batch of indices to cycle over. For example setting num_cycles=2, 101 | batch_size=5, cycle_length=3 might return something like this: 102 | cycles = [[0, 3, 4, 0], [1, 2, 0, 3]]. This means we have 2 cycles for which 103 | the loss will be calculated. The first cycle starts at sequence 0 of the 104 | batch, then we find a matching step in sequence 3 of that batch, then we 105 | find matching step in sequence 4 and finally come back to sequence 0, 106 | completing a cycle. 107 | Args: 108 | num_cycles: Integer, Number of cycles that will be matched in one pass. 109 | batch_size: Integer, Number of sequences in one batch. 110 | cycle_length: Integer, Length of the cycles. 
If we are matching between 111 | 2 sequences (cycle_length=2), we get cycles that look like [0,1,0]. 112 | This means that we go from sequence 0 to sequence 1 then back to sequence 113 | 0. A cycle length of 3 might look like [0, 1, 2, 0]. 114 | Returns: 115 | cycles: Tensor, Batch indices denoting cycles that will be used for 116 | calculating the alignment loss. 117 | """ 118 | sorted_idxes = tf.tile(tf.expand_dims(tf.range(batch_size), 0), 119 | [num_cycles, 1]) 120 | sorted_idxes = tf.reshape(sorted_idxes, [batch_size, num_cycles]) 121 | cycles = tf.reshape(tf.random.shuffle(sorted_idxes), 122 | [num_cycles, batch_size]) 123 | cycles = cycles[:, :cycle_length] 124 | # Append the first index at the end to create cycle. 125 | cycles = tf.concat([cycles, cycles[:, 0:1]], axis=1) 126 | return cycles 127 | 128 | 129 | def compute_stochastic_alignment_loss(embs, 130 | steps, 131 | seq_lens, 132 | num_steps, 133 | batch_size, 134 | loss_type, 135 | similarity_type, 136 | num_cycles, 137 | cycle_length, 138 | temperature, 139 | label_smoothing, 140 | variance_lambda, 141 | huber_delta, 142 | normalize_indices): 143 | """Compute cycle-consistency loss by stochastically sampling cycles. 144 | Args: 145 | embs: Tensor, sequential embeddings of the shape [N, T, D] where N is the 146 | batch size, T is the number of timesteps in the sequence, D is the size of 147 | the embeddings. 148 | steps: Tensor, step indices/frame indices of the embeddings of the shape 149 | [N, T] where N is the batch size, T is the number of the timesteps. 150 | seq_lens: Tensor, Lengths of the sequences from which the sampling was done. 151 | This can provide additional information to the alignment loss. 152 | num_steps: Integer/Tensor, Number of timesteps in the embeddings. 153 | batch_size: Integer/Tensor, Batch size. 154 | loss_type: String, This specifies the kind of loss function to use. 155 | Currently supported loss functions: 'classification', 'regression_mse', 156 | 'regression_mse_var', 'regression_huber'. 157 | similarity_type: String, Currently supported similarity metrics: 'l2', 158 | 'cosine'. 159 | num_cycles: Integer, number of cycles to match while aligning 160 | stochastically. Only used in the stochastic version. 161 | cycle_length: Integer, Lengths of the cycle to use for matching. Only used 162 | in the stochastic version. By default, this is set to 2. 163 | temperature: Float, temperature scaling used to scale the similarity 164 | distributions calculated using the softmax function. 165 | label_smoothing: Float, Label smoothing argument used in 166 | tf.keras.losses.categorical_crossentropy function and described in this 167 | paper https://arxiv.org/pdf/1701.06548.pdf. 168 | variance_lambda: Float, Weight of the variance of the similarity 169 | predictions while cycling back. If this is high then the low variance 170 | similarities are preferred by the loss while making this term low results 171 | in high variance of the similarities (more uniform/random matching). 172 | huber_delta: float, Huber delta described in tf.keras.losses.huber_loss. 173 | normalize_indices: Boolean, If True, normalizes indices by sequence lengths. 174 | Useful for ensuring numerical instabilities doesn't arise as sequence 175 | indices can be large numbers. 176 | Returns: 177 | loss: Tensor, Scalar loss tensor that imposes the chosen variant of the 178 | cycle-consistency loss. 179 | """ 180 | # Generate cycles. 
181 | cycles = gen_cycles(num_cycles, batch_size, cycle_length) 182 | 183 | logits, labels = _align(cycles, embs, num_steps, num_cycles, cycle_length, 184 | similarity_type, temperature) 185 | 186 | if loss_type == 'classification': 187 | loss = classification_loss(logits, labels, label_smoothing) 188 | elif 'regression' in loss_type: 189 | steps = tf.gather(steps, cycles[:, 0]) 190 | seq_lens = tf.gather(seq_lens, cycles[:, 0]) 191 | loss = regression_loss(logits, labels, num_steps, steps, seq_lens, 192 | loss_type, normalize_indices, variance_lambda, 193 | huber_delta) 194 | else: 195 | raise ValueError('Unidentified loss type %s. Currently supported loss ' 196 | 'types are: regression_mse, regression_huber, ' 197 | 'classification .' 198 | % loss_type) 199 | return loss 200 | -------------------------------------------------------------------------------- /test_tcc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## This script is used to test tcc loss both in tensorflow version and pytorch version" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "405722.8\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "# First, prepare input to the loss\n", 25 | "import numpy as np\n", 26 | "np.random.seed(0)\n", 27 | "batch_size = 16\n", 28 | "emb_np = np.random.rand(16,198,256).astype(np.float32)\n", 29 | "print(np.sum(emb_np))" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "- Tensorflow" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "tf.Tensor(5.2371726, shape=(), dtype=float32)\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "from tcc_tf.alignment import compute_alignment_loss\n", 54 | "import tensorflow.compat.v2 as tf\n", 55 | "\n", 56 | "embs = tf.convert_to_tensor(emb_np)\n", 57 | "loss = compute_alignment_loss(embs, batch_size, stochastic_matching=True)\n", 58 | "print(loss)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "- Pytorch" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "tensor(5.2306)\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "from tcc.alignment import compute_alignment_loss\n", 83 | "import torch\n", 84 | "\n", 85 | "embs = torch.from_numpy(emb_np)\n", 86 | "loss = compute_alignment_loss(embs, batch_size, stochastic_matching=True)\n", 87 | "print(loss)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Note: This repo only provide pytorch version of stochastic_alignment and classification_loss, due to time limit. The reason of results are slightly different is the randomness in the code, 1) gen_cycles(), the cycle generation are random; 2) _align_single_cycle(), the frame choosing is random. It has been verified, if these two parts are determined, the results are exactly the same." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.8.5" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 4 126 | } 127 | --------------------------------------------------------------------------------
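As the note in ```test_tcc.ipynb``` states, only the stochastic alignment and the classification loss are ported to PyTorch; ```tcc/losses.py``` raises for the regression variants. For completeness, below is a minimal, untested sketch of what a PyTorch Cycle-back Regression loss could look like, mirroring ```tcc_tf/losses.py```. It is not part of the repository, and ```torch.nn.functional.huber_loss``` assumes a reasonably recent PyTorch release:

```python
import torch
import torch.nn.functional as F


def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type,
                    normalize_indices, variance_lambda, huber_delta):
  """Cycle-back Regression (sketch mirroring tcc_tf/losses.py, not from this repo)."""
  # num_steps is kept only for signature compatibility; broadcasting makes it
  # unnecessary here.
  # Labels and steps are generated, so keep them out of the autograd graph.
  labels = labels.detach().float()
  steps = steps.detach().float()

  if normalize_indices:
    # Normalize frame indices by sequence lengths to keep the targets small.
    steps = steps / seq_lens.detach().float().unsqueeze(1)

  # Softmax over the cycled-back similarities gives a distribution over frames.
  beta = F.softmax(logits, dim=1)
  # Ground-truth index vs. expected (soft) index after cycling back.
  true_time = torch.sum(steps * labels, dim=1)
  pred_time = torch.sum(steps * beta, dim=1)

  if loss_type in ('regression_mse', 'regression_mse_var'):
    if 'var' in loss_type:
      # Variance-aware MSE: errors are down-weighted when the predicted
      # distribution over frames has high variance.
      pred_time_variance = torch.sum(
          ((steps - pred_time.unsqueeze(1)) ** 2) * beta, dim=1)
      pred_time_log_var = torch.log(pred_time_variance)
      squared_error = (true_time - pred_time) ** 2
      return torch.mean(torch.exp(-pred_time_log_var) * squared_error
                        + variance_lambda * pred_time_log_var)
    return F.mse_loss(pred_time, true_time)
  elif loss_type == 'regression_huber':
    return F.huber_loss(pred_time, true_time, delta=huber_delta)
  raise ValueError('Unsupported regression loss %s.' % loss_type)
```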