├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── entropy_reconstruction.py └── requirements.txt /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2023 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 
31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Role of Entropy and Reconstruction in Multi-View Self-Supervised Learning 2 | 3 | This software project accompanies the ICML 2023 research paper, [The Role of Entropy and Reconstruction in Multi-View Self-Supervised Learning](https://openreview.net/forum?id=YJ3ytyemn1). It contains the code to compute the entropy and reconstruction terms used to train the `ER` models. 4 | 5 | ```bibtex 6 | @InProceedings{ 7 | rodriguez2023er, 8 | title={The Role of Entropy and Reconstruction in Multi-View Self-Supervised Learning}, 9 | author={Borja Rodriguez-Galvez and Arno Blaas and Pau Rodriguez and Adam Golinski and Xavier Suau and Jason Ramapuram and Dan Busbridge and Luca Zappella}, 10 | year={2023}, 11 | booktitle={ICML}, 12 | } 13 | ``` 14 | 15 | ## Abstract 16 | 17 | The mechanisms behind the success of multi-view self-supervised learning (MVSSL) are not yet fully understood. Contrastive MVSSL methods have been studied through the lens of InfoNCE, a lower bound of the Mutual Information (MI). However, the relation between other MVSSL methods and MI remains unclear. We consider a different lower bound on the MI consisting of an entropy and a reconstruction term (ER), and analyze the main MVSSL families through its lens. 
Through this ER bound, we show that clustering-based methods such as DeepCluster and SwAV maximize the MI. We also re-interpret the mechanisms of distillation-based approaches such as BYOL and DINO, showing that they explicitly maximize the reconstruction term and implicitly encourage a stable entropy, and we confirm this empirically. We show that replacing the objectives of common MVSSL methods with this ER bound achieves competitive performance, while making them stable when training with smaller batch sizes or smaller exponential moving average (EMA) coefficients. 18 | 19 | 20 | ## Documentation 21 | ### Install dependencies 22 | `pip install -r requirements.txt` 23 | 24 | ### Getting Started 25 | To verify that the code is working as expected, run: 26 | 27 | `python entropy_reconstruction.py` 28 | 29 | It should return the following: 30 | 31 | ``` 32 | Continuous entropy 33 | Entropy sphere plug-in estimator: -249.30848693847656 34 | Entropy sphere Joe's estimator: -249.30833435058594 35 | Discrete entropy 36 | Max entropy 6.907755278982137 37 | Entropy on uniform sample: 6.907612323760986 38 | Entropy on uniform sample (high temp): 6.695544242858887 39 | Entropy on one_hot vector: -1.1920928244535389e-07 40 | Continuous reconstruction 41 | Reconstruction error: -248.65525817871094 42 | Discrete reconstruction 43 | Reconstruction error: 6.100684642791748 44 | ``` 45 | 46 | ### Computing entropy 47 | ```python 48 | def entropy( 49 | embeddings: torch.Tensor, 50 | kappa: float = 10, 51 | support: str = "sphere", 52 | reduction: str = "expectation", 53 | ) -> torch.Tensor: 54 | """Computes the entropy from a tensor of embeddings 55 | 56 | :param embeddings: tensor containing a batch of embeddings 57 | :type embeddings: torch.Tensor 58 | :param kappa: von Mises-Fisher kappa (https://en.wikipedia.org/wiki/Von_Mises-Fisher_distribution), defaults to 10 59 | :type kappa: float, optional 60 | :param support: support of the random variables. 
Sphere or discrete, defaults to "sphere" 61 | :type support: str, optional 62 | :param reduction: "average" for Joe's estimator and "expectation" for the plug-in estimator (see Section 4.1), defaults to "expectation" 63 | :type reduction: str, optional 64 | :return: entropy value 65 | :rtype: torch.Tensor 66 | """ 67 | ``` 68 | 69 | ### Computing the reconstruction loss 70 | ```python 71 | def reconstruction( 72 | projection1: torch.Tensor, 73 | projection2: torch.Tensor, 74 | kappa: float = 10, 75 | support: str = "sphere", 76 | ) -> torch.Tensor: 77 | """Reconstruction error from ER 78 | 79 | :param projection1: projection of augmentation1 80 | :type projection1: torch.Tensor 81 | :param projection2: projection of augmentation2 82 | :type projection2: torch.Tensor 83 | :param kappa: von Mises-Fisher kappa, defaults to 10 84 | :type kappa: float, optional 85 | :param support: support of the random variables, defaults to "sphere" 86 | :type support: str, optional 87 | :return: reconstruction error 88 | :rtype: torch.Tensor 89 | """ 90 | ``` 91 | -------------------------------------------------------------------------------- /entropy_reconstruction.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
4 | # 5 | import math 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from scipy.special import iv 10 | 11 | eps = 1e-7 12 | 13 | 14 | def entropy( 15 | embeddings: torch.Tensor, 16 | kappa: float = 10, 17 | support: str = "sphere", 18 | reduction: str = "expectation", 19 | ) -> torch.Tensor: 20 | """Computes the entropy from a tensor of embeddings 21 | 22 | :param embeddings: tensor containing a batch of embeddings 23 | :type embeddings: torch.Tensor 24 | :param kappa: von Mises-Fisher kappa (https://en.wikipedia.org/wiki/Von_Mises-Fisher_distribution), defaults to 10 25 | :type kappa: float, optional 26 | :param support: support of the random variables. Sphere or discrete, defaults to "sphere" 27 | :type support: str, optional 28 | :param reduction: "average" for Joe's estimator and "expectation" for the plug-in estimator (see Section 4.1), defaults to "expectation" 29 | :type reduction: str, optional 30 | :return: entropy value 31 | :rtype: torch.Tensor 32 | """ 33 | k = embeddings.shape[0] 34 | d = embeddings.shape[1] 35 | 36 | if support == "sphere": 37 | # If the support is in the sphere, the received random variable is Z 38 | # and it belongs to S^{d-1} 39 | csim = kappa * torch.matmul(embeddings, embeddings.T) 40 | const = ( 41 | -math.log(kappa) * (d * 0.5 - 1) 42 | + 0.5 * d * math.log(2 * math.pi) 43 | + math.log(iv(0.5 * d - 1, kappa) + 1e-7) 44 | + math.log(k) 45 | ) 46 | if reduction == "average": 47 | entropy = -torch.logsumexp(csim, dim=-1) + const # -> log(p).sum 48 | entropy = entropy.mean() 49 | elif reduction == "expectation": 50 | logp = -torch.logsumexp(csim, dim=-1) + const 51 | entropy = F.softmax(-logp, dim=-1) * logp 52 | entropy = entropy.sum() 53 | else: 54 | raise NotImplementedError(f"Reduction type {reduction} not implemented") 55 | elif support == "discrete": 56 | # If the support is discrete, the received random variable is W and 57 | # it belongs to [d] 58 | embeddings_mean = embeddings.mean(0) 59 | if reduction == 
"expectation": 60 | entropy = -(embeddings_mean * torch.log(embeddings_mean + eps)).sum() 61 | elif reduction == "average": 62 | entropy = -torch.log(embeddings_mean + eps).mean() 63 | else: 64 | raise NotImplementedError(f"Reduction type {reduction} not implemented") 65 | else: 66 | raise NotImplementedError(f"Support type {support} not implemented") 67 | 68 | return entropy 69 | 70 | 71 | def reconstruction( 72 | projection1: torch.Tensor, 73 | projection2: torch.Tensor, 74 | kappa: float = 10, 75 | support: str = "sphere", 76 | ) -> torch.Tensor: 77 | """Reconstruction error from ER 78 | 79 | :param projection1: projection of augmentation1 80 | :type projection1: torch.Tensor 81 | :param projection2: projection of augmentation2 82 | :type projection2: torch.Tensor 83 | :param kappa: von Mises-Fisher kappa (https://en.wikipedia.org/wiki/Von_Mises-Fisher_distribution), defaults to 10 84 | :type kappa: float, optional 85 | :param support: support of the random variables, defaults to "sphere" 86 | :type support: str, optional 87 | :return: reconstruction error 88 | :rtype: torch.Tensor 89 | """ 90 | d = projection1.shape[1] 91 | 92 | if support == "sphere": 93 | # If the support is in the sphere, the reconstruction is done with a 94 | # von Mises--Fisher distribution 95 | 96 | const = ( 97 | -(0.5 * d - 1) * math.log(kappa) 98 | + 0.5 * d * math.log(2 * math.pi) 99 | + math.log(iv(0.5 * d - 1, kappa) + 1e-7) 100 | ) 101 | csim = kappa * torch.sum(projection1 * projection2, dim=-1) 102 | rec = -csim + const 103 | 104 | elif support == "discrete": 105 | # If the support is discrete, the received random variable is W and 106 | # it belongs to [d] 107 | rec1 = -torch.sum(projection1 * torch.log(projection2 + eps), dim=-1) 108 | rec2 = -torch.sum(projection2 * torch.log(projection1 + eps), dim=-1) 109 | rec = 0.5 * (rec1 + rec2) 110 | else: raise NotImplementedError(f"Support type {support} not implemented") 111 | 112 | rec_mean = rec.mean() 113 | return rec_mean 114 | 115 | if __name__ == "__main__": 116 | # Some checks and usage examples 117 
| embeddings = torch.randn(1000, 1000) 118 | embeddings = F.normalize(embeddings, 2, 1) 119 | print("Continuous entropy") 120 | print( 121 | "Entropy sphere plug-in estimator:", 122 | float(entropy(embeddings, support="sphere", reduction="expectation")), 123 | ) 124 | print( 125 | "Entropy sphere Joe's estimator:", 126 | float(entropy(embeddings, support="sphere", reduction="average")), 127 | ) 128 | print("Discrete entropy") 129 | print("Max entropy", math.log(1000)) 130 | embeddings = F.softmax(torch.rand(1000, 1000), -1) 131 | entropy_uniform = entropy( 132 | embeddings, kappa=1, support="discrete", reduction="expectation" 133 | ) 134 | print( 135 | "Entropy on uniform sample:", 136 | float(entropy_uniform), 137 | ) 138 | embeddings = F.softmax(1000 * torch.rand(1000, 1000), -1) 139 | entropy_high_temp = entropy( 140 | embeddings, kappa=1, support="discrete", reduction="expectation" 141 | ) 142 | print( 143 | "Entropy on uniform sample (high temp):", 144 | float(entropy_high_temp), 145 | ) 146 | assert entropy_high_temp < entropy_uniform 147 | embeddings = torch.zeros_like(embeddings) 148 | embeddings[:, 0] = 1000 149 | embeddings = F.softmax(embeddings, -1) 150 | entropy_one_hot = entropy( 151 | embeddings, kappa=1, support="discrete", reduction="expectation" 152 | ) 153 | assert entropy_one_hot < entropy_uniform 154 | print("Entropy on one_hot vector:", float(entropy_one_hot)) 155 | print("Continuous reconstruction") 156 | embeddings = torch.rand(1000, 1000) 157 | projection1 = embeddings + 1 * torch.randn(1000, 1000) 158 | projection1 = F.normalize(projection1, 2, 1) 159 | projection2 = embeddings + 1 * torch.randn(1000, 1000) 160 | projection2 = F.normalize(projection2, 2, 1) 161 | assert reconstruction(projection1, projection1) < reconstruction( 162 | projection1, projection2 163 | ) 164 | print("Reconstruction error:", float(reconstruction(projection1, projection2))) 165 | print("Discrete reconstruction") 166 | embeddings = torch.rand(1000, 1000) 167 | 
projection1 = 10 * embeddings + torch.randn(1000, 1000) 168 | projection1 = F.softmax(projection1, -1) 169 | projection2 = 10 * embeddings + torch.randn(1000, 1000) 170 | projection2 = F.softmax(projection2, -1) 171 | print( 172 | "Reconstruction error:", 173 | float(reconstruction(projection1, projection2, support="discrete")), 174 | ) 175 | assert reconstruction(projection1, projection1) < reconstruction( 176 | projection1, projection2 177 | ) 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | scipy --------------------------------------------------------------------------------
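Usage sketch (not part of the repository): the snippet below mirrors the `support="discrete"` branches of `entropy` and `reconstruction` above in two small standalone helpers and combines them into an ER-style objective. The helper names, the stand-in softmax inputs, and the sign convention of the loss (minimize reconstruction error minus entropy) are assumptions made for illustration, not the exact training code used in the paper.

```python
import math

import torch
import torch.nn.functional as F

eps = 1e-7  # same numerical guard as in entropy_reconstruction.py


def entropy_discrete(p: torch.Tensor) -> torch.Tensor:
    # Plug-in estimator of H(W): entropy of the batch-averaged softmax
    # outputs, mirroring the support="discrete", reduction="expectation"
    # branch of entropy() above.
    p_mean = p.mean(0)
    return -(p_mean * torch.log(p_mean + eps)).sum()


def reconstruction_discrete(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Symmetrized cross-entropy between the two views' softmax outputs,
    # mirroring the support="discrete" branch of reconstruction() above.
    rec1 = -torch.sum(p1 * torch.log(p2 + eps), dim=-1)
    rec2 = -torch.sum(p2 * torch.log(p1 + eps), dim=-1)
    return (0.5 * (rec1 + rec2)).mean()


# Stand-ins for two augmented views of a batch passed through a
# projector followed by a softmax (64 samples, 32 classes).
torch.manual_seed(0)
logits = torch.randn(64, 32)
p1 = F.softmax(logits + 0.1 * torch.randn(64, 32), dim=-1)
p2 = F.softmax(logits + 0.1 * torch.randn(64, 32), dim=-1)

# Hypothetical ER-style loss: encourage high entropy of the assignments
# while keeping the two views' reconstruction error low.
loss = reconstruction_discrete(p1, p2) - entropy_discrete(p1)
print(float(loss))
```

A training loop would backpropagate through `loss` exactly as with any other MVSSL objective; the discrete entropy is bounded by `log(d)` (here `log(32)`), which makes it easy to sanity-check.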