├── .gitignore
├── 1_max_margin_contrastive_loss_vs_baseline.ipynb
├── 2_comparing_contrastive_losses.ipynb
├── Main_contrast_loss-regression.ipynb
├── Plot_learning_curves.ipynb
├── README.ipynb
├── README.md
├── lars_optimizer.py
├── losses.py
├── main.py
├── main_ce_baseline.py
├── model.py
├── requirements.txt
├── supcontrast.py
└── test_supcontrast_loss.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .ipynb_checkpoints
3 | refs/
4 | venv/
5 | logs/
6 | runs/
7 | figs/
8 | img/
9 | .DS_Store
10 |
--------------------------------------------------------------------------------
/Plot_learning_curves.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 11,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "%matplotlib inline\n",
13 | "import seaborn as sns\n",
14 | "sns.set_context('talk')"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 4,
20 | "metadata": {},
21 | "outputs": [
22 | {
23 | "name": "stdout",
24 | "output_type": "stream",
25 | "text": [
26 | "run-baseline_fashion_mnist_20200427-210140_test-tag-accuracy.csv\r\n",
27 | "run-baseline_mnist_20200427-205525_test-tag-accuracy.csv\r\n",
28 | "run-contrast_loss_model_fashion_mnist_20200427-210420_test-tag-accuracy.csv\r\n",
29 | "run-contrast_loss_model_mnist_20200427-205809_test-tag-accuracy.csv\r\n"
30 | ]
31 | }
32 | ],
33 | "source": [
34 | "!ls runs/"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 7,
40 | "metadata": {},
41 | "outputs": [
42 | {
43 | "data": {
44 | "text/html": [
45 | "
\n",
46 | "\n",
59 | "
\n",
60 | " \n",
61 | " \n",
62 | " | \n",
63 | " Wall time | \n",
64 | " Step | \n",
65 | " Value | \n",
66 | " model | \n",
67 | "
\n",
68 | " \n",
69 | " \n",
70 | " \n",
71 | " 0 | \n",
72 | " 1.588035e+09 | \n",
73 | " 0 | \n",
74 | " 0.9456 | \n",
75 | " MLP | \n",
76 | "
\n",
77 | " \n",
78 | " 1 | \n",
79 | " 1.588035e+09 | \n",
80 | " 1 | \n",
81 | " 0.9506 | \n",
82 | " MLP | \n",
83 | "
\n",
84 | " \n",
85 | " 2 | \n",
86 | " 1.588035e+09 | \n",
87 | " 2 | \n",
88 | " 0.9616 | \n",
89 | " MLP | \n",
90 | "
\n",
91 | " \n",
92 | " 3 | \n",
93 | " 1.588035e+09 | \n",
94 | " 3 | \n",
95 | " 0.9685 | \n",
96 | " MLP | \n",
97 | "
\n",
98 | " \n",
99 | " 4 | \n",
100 | " 1.588035e+09 | \n",
101 | " 4 | \n",
102 | " 0.9634 | \n",
103 | " MLP | \n",
104 | "
\n",
105 | " \n",
106 | "
\n",
107 | "
"
108 | ],
109 | "text/plain": [
110 | " Wall time Step Value model\n",
111 | "0 1.588035e+09 0 0.9456 MLP\n",
112 | "1 1.588035e+09 1 0.9506 MLP\n",
113 | "2 1.588035e+09 2 0.9616 MLP\n",
114 | "3 1.588035e+09 3 0.9685 MLP\n",
115 | "4 1.588035e+09 4 0.9634 MLP"
116 | ]
117 | },
118 | "execution_count": 7,
119 | "metadata": {},
120 | "output_type": "execute_result"
121 | }
122 | ],
123 | "source": [
124 | "df1 = pd.read_csv('runs/run-baseline_mnist_20200427-205525_test-tag-accuracy.csv')\n",
125 | "df1.head()"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 8,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "df2 = pd.read_csv('runs/run-contrast_loss_model_mnist_20200427-205809_test-tag-accuracy.csv')"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 14,
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "data": {
144 | "image/png": "\n",
145 | "text/plain": [
146 | ""
147 | ]
148 | },
149 | "metadata": {
150 | "needs_background": "light"
151 | },
152 | "output_type": "display_data"
153 | }
154 | ],
155 | "source": [
156 | "fig, ax = plt.subplots()\n",
157 | "ax.plot(df1['Step'], df1['Value'], label='MLP baseline')\n",
158 | "ax.plot(df2['Step'], df2['Value'], label='Contrastive')\n",
159 | "\n",
160 | "ax.set(xlabel='Epoch', ylabel='Accuracy (Test set)', title='MNIST dataset');\n",
161 | "ax.legend();\n",
162 | "fig.savefig('figs/mnist_test_acc_curves.png')"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 15,
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "df1 = pd.read_csv('runs/run-baseline_fashion_mnist_20200427-210140_test-tag-accuracy.csv')\n",
172 | "df2 = pd.read_csv('runs/run-contrast_loss_model_fashion_mnist_20200427-210420_test-tag-accuracy.csv')"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 16,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "image/png": "\n",
183 | "text/plain": [
184 | ""
185 | ]
186 | },
187 | "metadata": {
188 | "needs_background": "light"
189 | },
190 | "output_type": "display_data"
191 | }
192 | ],
193 | "source": [
194 | "fig, ax = plt.subplots()\n",
195 | "ax.plot(df1['Step'], df1['Value'], label='MLP baseline')\n",
196 | "ax.plot(df2['Step'], df2['Value'], label='Contrastive')\n",
197 | "\n",
198 | "ax.set(xlabel='Epoch', ylabel='Accuracy (Test set)', title='Fashion MNIST dataset');\n",
199 | "ax.legend();\n",
200 | "fig.savefig('figs/fashion_mnist_test_acc_curves.png')"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": []
209 | }
210 | ],
211 | "metadata": {
212 | "kernelspec": {
213 | "display_name": "venv",
214 | "language": "python",
215 | "name": "venv"
216 | },
217 | "language_info": {
218 | "codemirror_mode": {
219 | "name": "ipython",
220 | "version": 3
221 | },
222 | "file_extension": ".py",
223 | "mimetype": "text/x-python",
224 | "name": "python",
225 | "nbconvert_exporter": "python",
226 | "pygments_lexer": "ipython3",
227 | "version": "3.7.1"
228 | }
229 | },
230 | "nbformat": 4,
231 | "nbformat_minor": 4
232 | }
233 |
--------------------------------------------------------------------------------
/README.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Preliminary\n",
8 | "\n",
9 | "Let $\\mathbf{x}$ be the input feature vector and $y$ be its label. Let $f(\\cdot)$ be an encoder network mapping the input space to the latent space and $\\mathbf{z} = f(\\mathbf{x})$ be the latent vector. \n"
10 | ]
11 | },
12 | {
13 | "attachments": {},
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "## Types of contrastive loss functions\n",
18 | "\n",
19 | "### 1. Max margin contrastive loss (Hadsell et al. 2006)\n",
20 | "\n",
21 | "$$ \\mathcal{L}(\\mathbf{z_i}, \\mathbf{z_j}) = \n",
22 | "\\mathbb{1}_{y_i=y_j} \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 + \n",
23 | "\\mathbb{1}_{y_i \\neq y_j} \\max(0, m - \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert_2)^2$$\n",
24 | "\n",
25 | ", where $m > 0$ is a margin. The margin imposes a lower bound on the distance between a pair of samples with different labels. \n",
26 | "\n",
27 | "### 2. Triplet loss (Weinberger et al. 2006)\n",
28 | "\n",
29 | "Triplet loss operates on a triplet of vectors whose labels follow $y_i = y_j$ and $y_i \\neq y_k$. That is to say, two of the three ($\\mathbf{z_i}$ and $\\mathbf{z_j}$) share the same label and the third vector $\\mathbf{z_k}$ has a different label. In the triplet learning literature, they are termed anchor, positive, and negative, respectively. Triplet loss is defined as:\n",
30 | "\n",
31 | "$$ \\mathcal{L}(\\mathbf{z_i}, \\mathbf{z_j}, \\mathbf{z_k}) = \n",
32 | "\\max(0, \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 - \n",
33 | " \\left\\lVert \\mathbf{z_i} - \\mathbf{z_k} \\right\\rVert^2_2 + m)\n",
34 | "$$\n",
35 | ", where $m$ again is the margin parameter, which requires the anchor-negative distance to exceed the anchor-positive distance by at least $m$. The intuition for this loss function is to push negative samples outside of the neighborhood by a margin while keeping positive samples within the neighborhood. Graphically:\n",
36 | "\n",
37 | "\n",
38 | "\n",
39 | "#### Triplet mining\n",
40 | "\n",
41 | "Based on the definition of the triplet loss, a triplet may have the following three scenarios before any training: \n",
42 | "- **easy**: triplets with a loss of 0 because the negative is already more than a margin away from the anchor than the positive, i.e. $ \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 + m < \n",
43 | " \\left\\lVert \\mathbf{z_i} - \\mathbf{z_k} \\right\\rVert^2_2 $\n",
44 | "- **hard**: triplets where the negative is closer to the anchor than the positive, i.e. $ \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 >\n",
45 | " \\left\\lVert \\mathbf{z_i} - \\mathbf{z_k} \\right\\rVert^2_2$ \n",
46 | "- **semi-hard**: triplets where the negative lies in the margin, i.e. $ \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 <\n",
47 | " \\left\\lVert \\mathbf{z_i} - \\mathbf{z_k} \\right\\rVert^2_2 < \\left\\lVert \\mathbf{z_i} - \\mathbf{z_j} \\right\\rVert^2_2 + m$\n",
48 | "\n",
49 | "In the FaceNet (Schroff et al. 2015) paper, which uses triplet loss to learn embeddings for faces, the authors argued that triplet mining is crucial for model performance and convergence. They also found that the hardest triplets led to bad local minima early in training, specifically resulting in a collapsed model, whereas semi-hard triplets yield more stable results and faster convergence.\n",
50 | "\n",
51 | "\n",
52 | "\n",
53 | "### 3. Multi-class N-pair loss (Sohn 2016)\n",
54 | "\n",
55 | "Multi-class N-pair loss is a generalization of triplet loss allowing joint comparison among multiple negative samples. When applied on a pair of positive samples $\\mathbf{z_i}$ and $\\mathbf{z_j}$ sharing the same label ($y_i = y_j$) from a mini-batch with $2N$ samples, it is computed as:\n",
56 | "\n",
57 | "$$ \\mathcal{L}(\\mathbf{z_i}, \\mathbf{z_j}) = \n",
58 | "\\log(1+\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k} - \\mathbf{z_i} \\mathbf{z_j})})\n",
59 | "$$\n",
60 | ", where $\\mathbf{z_i} \\mathbf{z_j}$ is the cosine similarity between the two vectors. \n",
61 | "\n"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "With some algebraic manipulation, multi-class N-pair loss can be written as the following:\n",
69 | "\n",
70 | "\\begin{equation}\n",
71 | "\\begin{split}\n",
72 | "\\mathcal{L}(\\mathbf{z_i}, \\mathbf{z_j}) & = \\log(1+\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k} - \\mathbf{z_i} \\mathbf{z_j})}) \\\\\n",
73 | " & = -\\log \\frac{1}{1+\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k} - \\mathbf{z_i} \\mathbf{z_j})}} \\\\\n",
74 | " & = -\\log \\frac{1}{1+\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\frac{\\exp(\\mathbf{z_i} \\mathbf{z_k})}{\\exp(\\mathbf{z_i} \\mathbf{z_j})}}} \\\\\n",
75 | " & = -\\log \\frac{\\exp(\\mathbf{z_i} \\mathbf{z_j})}{\\exp(\\mathbf{z_i} \\mathbf{z_j}) + \\sum_{k=1}^{2N}\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k})}\n",
76 | "\\end{split}\n",
77 | "\\end{equation}\n"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "### 4. Supervised NT-Xent loss (Khosla et al. 2020)\n",
85 | "\n",
86 | "- Self-supervised NT-xent loss (Chen et al. 2020 in SimCLR paper) \n",
87 | "\n",
88 | "NT-Xent was coined by Chen et al. 2020 and is short for normalized temperature-scaled cross entropy loss. It is a modification of the multi-class N-pair loss with the addition of a temperature parameter ($\\tau$).\n",
89 | "\n",
90 | "$$\n",
91 | "\\mathcal{L}(\\mathbf{z_i}, \\mathbf{z_j}) = \n",
92 | "-\\log \\frac{\\exp(\\mathbf{z_i} \\mathbf{z_j} / \\tau)}{\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k} / \\tau)}}\n",
93 | "$$\n",
94 | "\n",
95 | "- Supervised NT-xent loss\n",
96 | "\n",
97 | "$$\n",
98 | "\\mathcal{L}(\\mathbf{z_i}) = \n",
99 | "\\frac{-1}{2N_{y_i}-1} \\sum_{j=1}^{2N} \\mathbb{1}_{j \\neq i} \\mathbb{1}_{y_i = y_j} \\log \\frac{\\exp(\\mathbf{z_i} \\mathbf{z_j} / \\tau)}{\\sum_{k=1}^{2N}{\\mathbb{1}_{k \\neq i} \\exp(\\mathbf{z_i} \\mathbf{z_k} / \\tau)}}\n",
100 | "$$\n",
101 | "\n"
102 | ]
103 | },
104 | {
105 | "attachments": {},
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": []
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "# References\n",
115 | "- [Hadsell, R., Chopra, S., & LeCun, Y. (2006, June). Dimensionality reduction by learning an invariant mapping.](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf) In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06) (Vol. 2, pp. 1735-1742). IEEE.\n",
116 | "- [Weinberger, K. Q., Blitzer, J., & Saul, L. K. (2006). Distance metric learning for large margin nearest neighbor classification.](https://papers.nips.cc/paper/2795-distance-metric-learning-for-large-margin-nearest-neighbor-classification.pdf) In Advances in neural information processing systems (pp. 1473-1480).\n",
117 | "- [Schroff, F., Kalenichenko, D., & Philbin, J. (2015). Facenet: A unified embedding for face recognition and clustering.](https://arxiv.org/abs/1503.03832) In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 815-823).\n",
118 | "- [Sohn, K. (2016). Improved deep metric learning with multi-class n-pair loss objective.](https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective) In Advances in neural information processing systems (pp. 1857-1865).\n",
119 | "- [Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A simple framework for contrastive learning of visual representations.](https://arxiv.org/pdf/2002.05709.pdf) arXiv preprint arXiv:2002.05709.\n",
120 | "- [Khosla, P., Teterwak, P., Wang, C., Sarna, A., Tian, Y., Isola, P., ... & Krishnan, D. (2020). Supervised Contrastive Learning.](https://arxiv.org/pdf/2004.11362.pdf) arXiv preprint arXiv:2004.11362."
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": []
129 | }
130 | ],
131 | "metadata": {
132 | "kernelspec": {
133 | "display_name": "venv",
134 | "language": "python",
135 | "name": "venv"
136 | },
137 | "language_info": {
138 | "codemirror_mode": {
139 | "name": "ipython",
140 | "version": 3
141 | },
142 | "file_extension": ".py",
143 | "mimetype": "text/x-python",
144 | "name": "python",
145 | "nbconvert_exporter": "python",
146 | "pygments_lexer": "ipython3",
147 | "version": "3.7.1"
148 | }
149 | },
150 | "nbformat": 4,
151 | "nbformat_minor": 4
152 | }
153 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive loss functions
2 |
3 | Experiments with different contrastive loss functions to see if they help supervised learning.
4 |
5 | For detailed reviews and intuitions, please check out these posts:
6 | - [Contrastive loss for supervised classification](https://towardsdatascience.com/contrastive-loss-for-supervised-classification-224ae35692e7)
7 | - [Contrasting contrastive loss functions](https://medium.com/@wangzc921/contrasting-contrastive-loss-functions-3c13ca5f055e)
8 |
--------------------------------------------------------------------------------
/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions and classes related to optimization (weight updates)."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import re
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | EETA_DEFAULT = 0.001
27 |
28 |
class LARSOptimizer(tf.train.Optimizer):
    """Layer-wise Adaptive Rate Scaling for large batch training.

    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    The optimizer keeps one momentum slot per parameter and, for parameters
    that are not excluded, rescales the learning rate by a per-parameter
    "trust ratio" `eeta * ||w|| / ||g||`.
    """

    def __init__(self,
                 learning_rate,
                 momentum=0.9,
                 use_nesterov=False,
                 weight_decay=0.0,
                 exclude_from_weight_decay=None,
                 exclude_from_layer_adaptation=None,
                 classic_momentum=True,
                 eeta=EETA_DEFAULT,
                 name="LARSOptimizer"):
        """Constructs a LARSOptimizer.

        Args:
          learning_rate: A `float` for learning rate.
          momentum: A `float` for momentum.
          use_nesterov: A 'Boolean' for whether to use nesterov momentum.
          weight_decay: A `float` for weight decay.
          exclude_from_weight_decay: A list of `string` for variable screening, if
              any of the string appears in a variable's name, the variable will be
              excluded for computing weight decay. For example, one could specify
              the list like ['batch_normalization', 'bias'] to exclude BN and bias
              from weight decay. Each string is matched with `re.search`, so it
              is interpreted as a regular expression.
          exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
              for layer adaptation. If it is None, it will be defaulted the same as
              exclude_from_weight_decay.
          classic_momentum: A `boolean` for whether to use classic (or popular)
              momentum. The learning rate is applied during momentum update in
              classic momentum, but after momentum for popular momentum.
          eeta: A `float` for scaling of learning rate when computing trust ratio.
          name: The name for the scope.
        """
        super(LARSOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
        # arg is None.
        # NOTE: this is a truthiness test, so an empty list also falls back to
        # exclude_from_weight_decay.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Build the op that applies one LARS update to every variable.

        Args:
          grads_and_vars: iterable of (gradient, variable) pairs; pairs where
              either element is None are skipped.
          global_step: variable to increment; when None the default global
              step is fetched (or created).
          name: name for the returned grouped op.

        Returns:
          A grouped op (`tf.group`) containing the parameter, momentum and
          global-step assignments.
        """
        if global_step is None:
            global_step = tf.train.get_or_create_global_step()
        new_global_step = global_step + 1

        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = param.op.name

            # Momentum accumulator for this parameter, zero-initialized.
            v = tf.get_variable(
                name=param_name + "/Momentum",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            # Weight decay is folded into the gradient.
            if self._use_weight_decay(param_name):
                grad += self.weight_decay * param

            if self.classic_momentum:
                # Classic momentum: the (trust-ratio scaled) learning rate is
                # applied to the gradient before it enters the momentum buffer.
                trust_ratio = 1.0
                if self._do_layer_adaptation(param_name):
                    w_norm = tf.norm(param, ord=2)
                    g_norm = tf.norm(grad, ord=2)
                    # eeta * ||w|| / ||g|| when both norms are positive,
                    # otherwise fall back to 1.0.
                    trust_ratio = tf.where(
                        tf.greater(w_norm, 0), tf.where(
                            tf.greater(g_norm, 0), (self.eeta *
                                                    w_norm / g_norm),
                            1.0),
                        1.0)
                scaled_lr = self.learning_rate * trust_ratio

                next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
                if self.use_nesterov:
                    update = tf.multiply(
                        self.momentum, next_v) + scaled_lr * grad
                else:
                    update = next_v
                next_param = param - update
            else:
                # "Popular" momentum: the buffer accumulates raw gradients and
                # the trust ratio is computed from the resulting update.
                next_v = tf.multiply(self.momentum, v) + grad
                if self.use_nesterov:
                    update = tf.multiply(self.momentum, next_v) + grad
                else:
                    update = next_v

                trust_ratio = 1.0
                if self._do_layer_adaptation(param_name):
                    w_norm = tf.norm(param, ord=2)
                    v_norm = tf.norm(update, ord=2)
                    trust_ratio = tf.where(
                        tf.greater(w_norm, 0), tf.where(
                            tf.greater(v_norm, 0), (self.eeta *
                                                    w_norm / v_norm),
                            1.0),
                        1.0)
                scaled_lr = trust_ratio * self.learning_rate
                next_param = param - scaled_lr * update

            # A global-step assignment is appended once per variable; each one
            # assigns the same `new_global_step` tensor computed above.
            assignments.extend(
                [param.assign(next_param),
                 v.assign(next_v),
                 global_step.assign(new_global_step)])
        return tf.group(*assignments, name=name)

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        # A zero (falsy) weight_decay disables decay for every parameter.
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True
168 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_addons as tfa
4 |
5 |
def pdist_euclidean(A):
    '''Compute the pairwise Euclidean distance matrix between the rows of A.

    Uses ||a_i - a_j||^2 = ||a_i||^2 - 2 a_i.a_j + ||a_j||^2 so the batch does
    not have to be replicated:
    https://stackoverflow.com/questions/37009647/compute-pairwise-distance-in-a-batch-without-replicating-tensor-in-tensorflow

    Args:
        A: 2-D floating point tensor; each row is one vector.
    Returns:
        Square tensor D where D[i, j] is the Euclidean distance between
        row i and row j of A.
    '''
    # squared L2 norm of every row, as a column vector
    sq_norms = tf.reshape(tf.reduce_sum(A * A, 1), [-1, 1])
    sq_dists = sq_norms - 2 * tf.matmul(A, tf.transpose(A)) + \
        tf.transpose(sq_norms)
    # Floating point rounding can leave tiny negative values where the true
    # squared distance is 0 (e.g. on the diagonal); sqrt of those would be
    # NaN. Clamp at 0 instead of adding an epsilon so that exact zeros are
    # preserved for callers that test `distance == 0`.
    return tf.sqrt(tf.maximum(sq_dists, 0.0))
15 |
16 |
def square_to_vec(D):
    '''Convert a square-form pdist matrix to condensed vector form.

    Args:
        D: square tensor whose first dimension is statically known.
    Returns:
        1-D tensor with the n * (n - 1) // 2 entries strictly above the
        diagonal of D, in row-major order (empty when n < 2).
    '''
    n = D.shape[0]
    rows, cols = np.triu_indices(n, k=1)
    # Build the (num_pairs, 2) index array in NumPy rather than as a Python
    # list of tuples: it avoids creating O(n^2) Python objects on every call
    # and keeps an integer dtype even when there are no pairs at all.
    pair_idx = np.stack([rows, cols], axis=1)
    return tf.gather_nd(D, pair_idx)
24 |
25 |
def get_contrast_batch_labels(y):
    '''
    Build pairwise contrast labels for a batch: a pair gets 1 when both
    samples carry the same label and 0 otherwise.
    y: tensor with shape: (batch_size, )
    returns:
        int32 tensor with shape: (batch_size * (batch_size-1) // 2, )
    '''
    # treat every label as a 1-d point so equal labels are at distance 0
    labels_col = tf.reshape(tf.cast(y, tf.float32), [-1, 1])
    label_dists = square_to_vec(pdist_euclidean(labels_col))
    same_label = label_dists == 0
    return tf.cast(same_label, tf.int32)
38 |
39 |
def get_contrast_batch_labels_regression(y):
    '''
    Placeholder for the regression counterpart of
    `get_contrast_batch_labels`; not implemented yet.
    y: tensor with shape: (batch_size, )
    returns (intended):
        tensor with shape: (batch_size * (batch_size-1) // 2, )
    raises:
        NotImplementedError: always.
    '''
    raise NotImplementedError()
48 |
49 |
def max_margin_contrastive_loss(z, y, margin=1.0, metric='euclidean'):
    '''
    Wrapper for the maximum margin contrastive loss (Hadsell et al. 2006)
    `tfa.losses.contrastive_loss`
    Args:
        z: hidden vector of shape [bsz, n_features].
        y: ground truth of shape [bsz].
        margin: margin passed through to `tfa.losses.contrastive_loss`.
        metric: one of ('euclidean', 'cosine')
    Returns:
        scalar tensor, the mean contrastive loss over all pairs in the batch.
    Raises:
        ValueError: if `metric` is not one of the supported values.
    '''
    # fail early with a clear message instead of an UnboundLocalError below
    if metric not in ('euclidean', 'cosine'):
        raise ValueError(
            "Unsupported metric '{}', choose from "
            "('euclidean', 'cosine')".format(metric))

    # compute pair-wise distance matrix
    if metric == 'euclidean':
        D = pdist_euclidean(z)
    else:
        # 1 - dot product; this is the cosine distance only if the rows of z
        # are l2-normalized (assumed, not checked here)
        D = 1 - tf.matmul(z, z, transpose_a=False, transpose_b=True)
    # convert squareform matrix to vector form
    d_vec = square_to_vec(D)
    # make contrastive labels
    y_contrasts = get_contrast_batch_labels(y)
    loss = tfa.losses.contrastive_loss(y_contrasts, d_vec, margin=margin)
    # exploding/vanishing gradients on large batch?
    return tf.reduce_mean(loss)
71 |
72 |
def multiclass_npairs_loss(z, y):
    '''
    Multiclass N-pair loss (Sohn 2016), delegating to
    `tfa.losses.npairs_loss`.
    Args:
        z: hidden vector of shape [bsz, n_features].
        y: ground truth of shape [bsz].
    Returns:
        the loss value computed by `tfa.losses.npairs_loss`.
    '''
    # pairwise dot products z_i . z_j, used as the similarity matrix
    similarity = tf.linalg.matmul(z, z, transpose_b=True)
    return tfa.losses.npairs_loss(y, similarity)
85 |
86 |
def triplet_loss(z, y, margin=1.0, kind='hard'):
    '''
    Wrapper for the triplet losses
    `tfa.losses.triplet_hard_loss` and `tfa.losses.triplet_semihard_loss`
    Args:
        z: hidden vector of shape [bsz, n_features], assumes it is l2-normalized.
        y: ground truth of shape [bsz].
        margin: margin passed through to the underlying loss.
        kind: one of ('hard', 'soft', 'semihard'). 'hard' and 'soft' both use
            `triplet_hard_loss` (with `soft=False` / `soft=True`), 'semihard'
            uses `triplet_semihard_loss`.
    Returns:
        the loss value computed by the selected tfa loss.
    Raises:
        ValueError: if `kind` is not one of the supported values.
    '''
    # fail early with a clear message instead of an UnboundLocalError below
    if kind not in ('hard', 'soft', 'semihard'):
        raise ValueError(
            "Unsupported triplet kind '{}', choose from "
            "('hard', 'soft', 'semihard')".format(kind))

    if kind == 'semihard':
        return tfa.losses.triplet_semihard_loss(y, z, margin=margin)
    return tfa.losses.triplet_hard_loss(
        y, z, margin=margin, soft=(kind == 'soft'))
102 |
103 |
def supervised_nt_xent_loss(z, y, temperature=0.5, base_temperature=0.07):
    '''
    Supervised normalized temperature-scaled cross entropy loss.
    A variant of Multi-class N-pair Loss from (Sohn 2016)
    Later used in SimCLR (Chen et al. 2020, Khosla et al. 2020).
    Implementation modified from:
        - https://github.com/google-research/simclr/blob/master/objective.py
        - https://github.com/HobbitLong/SupContrast/blob/master/losses.py
    Args:
        z: hidden vector of shape [bsz, n_features].
        y: ground truth of shape [bsz].
        temperature: divisor applied to the pairwise dot products.
        base_temperature: the loss is scaled by temperature / base_temperature.
    Returns:
        scalar tensor: the loss averaged over the anchors that have at least
        one positive (another sample with the same label) in the batch, or 0
        when no anchor has a positive.
    '''
    batch_size = tf.shape(z)[0]
    y = tf.expand_dims(y, -1)

    # mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
    # has the same class as sample i.
    mask = tf.cast(tf.equal(y, tf.transpose(y)), tf.float32)
    anchor_dot_contrast = tf.divide(
        tf.matmul(z, tf.transpose(z)),
        temperature
    )
    # subtract the row-wise max for numerical stability
    logits_max = tf.reduce_max(anchor_dot_contrast, axis=1, keepdims=True)
    logits = anchor_dot_contrast - logits_max
    # exclude every sample's comparison with itself
    logits_mask = tf.ones_like(mask) - tf.eye(batch_size)
    mask = mask * logits_mask
    # compute log_prob
    exp_logits = tf.exp(logits) * logits_mask
    log_prob = logits - \
        tf.math.log(tf.reduce_sum(exp_logits, axis=1, keepdims=True))

    # mean of log-likelihood over the positives of each anchor. An anchor
    # whose class appears only once in the batch has no positives
    # (mask_sum == 0); divide_no_nan yields 0 for it instead of NaN.
    mask_sum = tf.reduce_sum(mask, axis=1)
    mean_log_prob_pos = tf.math.divide_no_nan(
        tf.reduce_sum(mask * log_prob, axis=1), mask_sum)

    # average over the anchors that do have positives; if there are none the
    # result is 0 rather than the NaN produced by a mean over an empty tensor
    n_valid_anchors = tf.reduce_sum(tf.cast(mask_sum > 0, tf.float32))
    mean_over_anchors = tf.math.divide_no_nan(
        tf.reduce_sum(mean_log_prob_pos), n_valid_anchors)

    # loss
    loss = -(temperature / base_temperature) * mean_over_anchors
    return loss
151 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''
2 | Script to run various two-stage supervised contrastive loss functions on
3 | MNIST or Fashion MNIST data.
4 |
5 | Author: Zichen Wang (wangzc921@gmail.com)
6 | '''
7 | import argparse
8 | import datetime
9 | import numpy as np
10 | import tensorflow as tf
11 | import tensorflow_addons as tfa
12 | import pandas as pd
13 | from sklearn.decomposition import PCA
14 | import matplotlib.pyplot as plt
15 | import seaborn as sns
16 |
17 | from model import *
18 | import losses
19 |
# Seed the NumPy and TensorFlow random number generators at import time.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Keys are the accepted values of the --loss option (main() rejects anything
# else); values are the human-readable names of the losses.
LOSS_NAMES = {
    'max_margin': 'Max margin contrastive',
    'npairs': 'Multiclass N-pairs',
    'sup_nt_xent': 'Supervised NT-Xent',
    'triplet-hard': 'Triplet hard',
    'triplet-semihard': 'Triplet semihard',
    'triplet-soft': 'Triplet soft'
}
32 |
33 |
def parse_option(argv=None):
    '''
    Parse the command line options of the two-stage training script.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads the arguments from `sys.argv[1:]`.
    Returns:
        argparse.Namespace with the parsed options.

    --optimizer, --metric and --data are restricted with `choices`, so an
    unsupported value is rejected by argparse with a usage message instead of
    surfacing later as an undefined-variable error.
    '''
    parser = argparse.ArgumentParser('arguments for two-stage training ')
    # training params
    parser.add_argument('--batch_size_1', type=int, default=512,
                        help='batch size for stage 1 pretraining'
                        )
    parser.add_argument('--batch_size_2', type=int, default=32,
                        help='batch size for stage 2 training'
                        )
    parser.add_argument('--lr_1', type=float, default=0.5,
                        help='learning rate for stage 1 pretraining'
                        )
    parser.add_argument('--lr_2', type=float, default=0.001,
                        help='learning rate for stage 2 training'
                        )
    parser.add_argument('--epoch', type=int, default=20,
                        help='Number of epochs for training in stage1, the same number of epochs will be applied on stage2')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=('adam', 'lars', 'sgd'),
                        help='Optimizer to use, choose from ("adam", "lars", "sgd")'
                        )
    # loss functions
    # --loss is validated by the caller against the supported loss names
    parser.add_argument('--loss', type=str, default='max_margin',
                        help='Loss function used for stage 1, choose from ("max_margin", "npairs", "sup_nt_xent", "triplet-hard", "triplet-semihard", "triplet-soft")')
    parser.add_argument('--margin', type=float, default=1.0,
                        help='margin for tfa.losses.contrastive_loss. will only be used when --loss=max_margin')
    parser.add_argument('--metric', type=str, default='euclidean',
                        choices=('euclidean', 'cosine'),
                        help='distance metrics for tfa.losses.contrastive_loss, choose from ("euclidean", "cosine"). will only be used when --loss=max_margin')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for sup_nt_xent loss. will only be used when --loss=sup_nt_xent')
    parser.add_argument('--base_temperature', type=float, default=0.07,
                        help='base_temperature for sup_nt_xent loss. will only be used when --loss=sup_nt_xent')
    # dataset params
    parser.add_argument('--data', type=str, default='mnist',
                        choices=('mnist', 'fashion_mnist'),
                        help='Dataset to choose from ("mnist", "fashion_mnist")'
                        )
    parser.add_argument('--n_data_train', type=int, default=60000,
                        help='number of data points used for training both stage 1 and 2'
                        )

    # model architecture
    parser.add_argument('--projection_dim', type=int, default=128,
                        help='output tensor dimension from projector'
                        )
    parser.add_argument('--activation', type=str, default='leaky_relu',
                        help='activation function between hidden layers'
                        )

    # output options
    parser.add_argument('--write_summary', action='store_true',
                        help='write summary for tensorboard'
                        )
    parser.add_argument('--draw_figures', action='store_true',
                        help='produce figures for the projections'
                        )

    args = parser.parse_args(argv)
    return args
91 |
92 |
def main():
    '''Run the two-stage contrastive experiment on MNIST / Fashion MNIST.

    Stage 1 trains the encoder and projector with the contrastive loss
    selected by ``--loss``. Stage 2 keeps the encoder fixed (only the
    classifier variables receive gradient updates) and trains a softmax
    classifier on the encoder output.

    Raises:
        ValueError: if the requested loss, optimizer or dataset is not
            supported.
    '''
    import os  # function-scope import: only needed to create the figure dir

    args = parse_option()
    print(args)

    # check args up front so that a typo fails with a clear message instead
    # of a NameError further down
    if args.loss not in LOSS_NAMES:
        raise ValueError('Unsupported loss function type {}'.format(args.loss))

    if args.data == 'mnist':
        dataset = tf.keras.datasets.mnist
    elif args.data == 'fashion_mnist':
        dataset = tf.keras.datasets.fashion_mnist
    else:
        raise ValueError('Unsupported dataset {}'.format(args.data))

    if args.optimizer == 'adam':
        optimizer1 = tf.keras.optimizers.Adam(learning_rate=args.lr_1)
    elif args.optimizer == 'lars':
        from lars_optimizer import LARSOptimizer
        # not compatible with tf2
        optimizer1 = LARSOptimizer(
            args.lr_1,
            exclude_from_weight_decay=['batch_normalization', 'bias'])
    elif args.optimizer == 'sgd':
        optimizer1 = tfa.optimizers.SGDW(learning_rate=args.lr_1,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    else:
        raise ValueError('Unsupported optimizer {}'.format(args.optimizer))
    optimizer2 = tf.keras.optimizers.Adam(learning_rate=args.lr_2)

    model_name = '{}_model-bs_{}-lr_{}'.format(
        args.loss, args.batch_size_1, args.lr_1)

    # 0. Load data
    print('Loading {} data...'.format(args.data))
    (x_train, y_train), (x_test, y_test) = dataset.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    x_train = x_train.reshape(-1, 28*28).astype(np.float32)
    x_test = x_test.reshape(-1, 28*28).astype(np.float32)
    print(x_train.shape, x_test.shape)

    # simulate low data regime for training
    n_train = x_train.shape[0]
    shuffle_idx = np.arange(n_train)
    np.random.shuffle(shuffle_idx)

    x_train = x_train[shuffle_idx][:args.n_data_train]
    y_train = y_train[shuffle_idx][:args.n_data_train]
    print('Training dataset shapes after slicing:')
    print(x_train.shape, y_train.shape)

    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(5000).batch(args.batch_size_1)

    train_ds2 = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(5000).batch(args.batch_size_2)

    test_ds = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(args.batch_size_1)

    # 1. Stage 1: train encoder + projector with the selected contrastive loss
    encoder = Encoder(normalize=True, activation=args.activation)
    projector = Projector(args.projection_dim,
                          normalize=True, activation=args.activation)

    if args.loss == 'max_margin':
        def loss_func(z, y):
            return losses.max_margin_contrastive_loss(
                z, y, margin=args.margin, metric=args.metric)
    elif args.loss == 'npairs':
        loss_func = losses.multiclass_npairs_loss
    elif args.loss == 'sup_nt_xent':
        def loss_func(z, y):
            return losses.supervised_nt_xent_loss(
                z, y, temperature=args.temperature,
                base_temperature=args.base_temperature)
    elif args.loss.startswith('triplet'):
        # loss names of the form 'triplet-<kind>'
        triplet_kind = args.loss.split('-')[1]

        def loss_func(z, y):
            return losses.triplet_loss(
                z, y, kind=triplet_kind, margin=args.margin)
    else:
        # a name listed in LOSS_NAMES without a branch above
        raise ValueError('Unsupported loss function type {}'.format(args.loss))

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    test_loss = tf.keras.metrics.Mean(name='test_loss')

    # For debugging, the steps below can be run eagerly with
    # tf.config.experimental_run_functions_eagerly(True)
    @tf.function
    def train_step_stage1(x, y):
        '''Train step for the contrastive loss.

        x: data tensor, shape: (batch_size, data_dim)
        y: data labels, shape: (batch_size, )
        '''
        with tf.GradientTape() as tape:
            r = encoder(x, training=True)
            z = projector(r, training=True)
            loss = loss_func(z, y)

        # collected after the forward pass, once the layers have been built
        variables = (encoder.trainable_variables
                     + projector.trainable_variables)
        gradients = tape.gradient(loss, variables)
        optimizer1.apply_gradients(zip(gradients, variables))
        train_loss(loss)

    @tf.function
    def test_step_stage1(x, y):
        r = encoder(x, training=False)
        z = projector(r, training=False)
        t_loss = loss_func(z, y)
        test_loss(t_loss)

    print('Stage 1 training ...')
    for epoch in range(args.epoch):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        test_loss.reset_states()

        for x, y in train_ds:
            train_step_stage1(x, y)

        for x_te, y_te in test_ds:
            test_step_stage1(x_te, y_te)

        template = 'Epoch {}, Loss: {}, Test Loss: {}'
        print(template.format(epoch + 1,
                              train_loss.result(),
                              test_loss.result()))

    if args.draw_figures:
        # 'figs/' is not guaranteed to exist; savefig does not create it
        os.makedirs('figs', exist_ok=True)

        # projecting data with the trained encoder, projector
        x_tr_proj = projector(encoder(x_train))
        x_te_proj = projector(encoder(x_test))
        # convert tensor to np.array
        x_tr_proj = x_tr_proj.numpy()
        x_te_proj = x_te_proj.numpy()
        print(x_tr_proj.shape, x_te_proj.shape)

        # check learned embedding using PCA
        pca = PCA(n_components=2)
        pca.fit(x_tr_proj)
        x_te_proj_pca = pca.transform(x_te_proj)

        x_te_proj_pca_df = pd.DataFrame(x_te_proj_pca, columns=['PC1', 'PC2'])
        x_te_proj_pca_df['label'] = y_test
        # PCA scatter plot (x/y passed by keyword: positional data arguments
        # are not accepted by newer seaborn releases)
        fig, ax = plt.subplots()
        ax = sns.scatterplot(x='PC1', y='PC2',
                             data=x_te_proj_pca_df,
                             palette='tab10',
                             hue='label',
                             linewidth=0,
                             alpha=0.6,
                             ax=ax)

        # shrink the axes to leave room for the legend on the right
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        title = 'Data: {}\nEmbedding: {}\nbatch size: {}; LR: {}'.format(
            args.data, LOSS_NAMES[args.loss], args.batch_size_1, args.lr_1)
        ax.set_title(title)
        fig.savefig(
            'figs/PCA_plot_{}_{}_embed.png'.format(args.data, model_name))

        # density plot for PCA
        g = sns.jointplot(x='PC1', y='PC2', data=x_te_proj_pca_df,
                          kind="hex")
        plt.subplots_adjust(top=0.95)
        g.fig.suptitle(title)

        g.savefig(
            'figs/Joint_PCA_plot_{}_{}_embed.png'.format(args.data, model_name))

    # Stage 2: freeze the learned representations and then learn a classifier
    # on a linear layer using a softmax loss
    softmax = SoftmaxPred()

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_ACC')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_ACC')

    # SoftmaxPred already applies a softmax activation, so its outputs are
    # probabilities, not logits.
    cce_loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False)

    @tf.function
    def train_step(x, y):
        '''Train step for the 2nd stage.

        x: data tensor, shape: (batch_size, data_dim)
        y: data labels, shape: (batch_size, )
        '''
        with tf.GradientTape() as tape:
            r = encoder(x, training=False)
            y_preds = softmax(r, training=True)
            loss = cce_loss_obj(y, y_preds)

        # freeze the encoder, only train the softmax layer
        gradients = tape.gradient(loss,
                                  softmax.trainable_variables)
        optimizer2.apply_gradients(zip(gradients,
                                       softmax.trainable_variables))
        train_loss(loss)
        train_acc(y, y_preds)

    @tf.function
    def test_step(x, y):
        r = encoder(x, training=False)
        y_preds = softmax(r, training=False)
        t_loss = cce_loss_obj(y, y_preds)
        test_loss(t_loss)
        test_acc(y, y_preds)

    if args.write_summary:
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = 'logs/{}/{}/{}/train'.format(
            model_name, args.data, current_time)
        test_log_dir = 'logs/{}/{}/{}/test'.format(
            model_name, args.data, current_time)
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    print('Stage 2 training ...')
    for epoch in range(args.epoch):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        train_acc.reset_states()
        test_loss.reset_states()
        test_acc.reset_states()

        for x, y in train_ds2:
            train_step(x, y)

        if args.write_summary:
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', train_acc.result(), step=epoch)

        for x_te, y_te in test_ds:
            test_step(x_te, y_te)

        if args.write_summary:
            with test_summary_writer.as_default():
                tf.summary.scalar('loss', test_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', test_acc.result(), step=epoch)

        template = 'Epoch {}, Loss: {}, Acc: {}, Test Loss: {}, Test Acc: {}'
        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_acc.result() * 100,
                              test_loss.result(),
                              test_acc.result() * 100))


if __name__ == '__main__':
    main()
342 |
--------------------------------------------------------------------------------
/main_ce_baseline.py:
--------------------------------------------------------------------------------
1 | '''
2 | Script to run baseline MLP with cross-entropy loss on
3 | MNIST or Fashion MNIST data.
4 |
5 | Author: Zichen Wang (wangzc921@gmail.com)
6 | '''
7 | import argparse
8 | import datetime
9 | import numpy as np
10 | import tensorflow as tf
11 | import tensorflow_addons as tfa
12 | import pandas as pd
13 | from sklearn.decomposition import PCA
14 | import matplotlib.pyplot as plt
15 | import seaborn as sns
16 |
17 | from model import *
18 |
# Seed NumPy's and TensorFlow's global random generators with a fixed value.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
22 |
23 |
def parse_option(argv=None):
    '''Parse the command-line options of the baseline MLP script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``, as before.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    '''
    parser = argparse.ArgumentParser('arguments for training baseline MLP')
    # training params
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size training'
                        )
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate training'
                        )
    parser.add_argument('--epoch', type=int, default=20,
                        help='Number of epochs for training')

    # dataset params
    # `choices` rejects unknown datasets at parse time; the training code
    # only knows how to load these two.
    parser.add_argument('--data', type=str, default='mnist',
                        choices=('mnist', 'fashion_mnist'),
                        help='Dataset to choose from ("mnist", "fashion_mnist")'
                        )
    parser.add_argument('--n_data_train', type=int, default=60000,
                        help='number of data points used for training'
                        )
    # model architecture
    parser.add_argument('--projection_dim', type=int, default=128,
                        help='output tensor dimension from projector '
                             '(not used by the baseline MLP)'
                        )
    parser.add_argument('--activation', type=str, default='leaky_relu',
                        help='activation function between hidden layers'
                        )
    # output options
    parser.add_argument('--write_summary', action='store_true',
                        help='write summary for tensorboard'
                        )
    parser.add_argument('--draw_figures', action='store_true',
                        help='produce figures for the projections'
                        )

    return parser.parse_args(argv)
60 |
61 |
def main():
    '''Train the baseline MLP with a cross-entropy loss and report accuracy.

    Loads MNIST or Fashion MNIST, optionally sub-samples the training set
    (``--n_data_train``), trains the MLP for ``--epoch`` epochs and, when
    ``--draw_figures`` is given, saves PCA plots of the last hidden layer
    under ``figs/``.

    Raises:
        ValueError: if ``--data`` names an unsupported dataset.
    '''
    import os  # function-scope import: only needed to create the figure dir

    args = parse_option()
    print(args)

    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    # 0. Load data
    if args.data == 'mnist':
        dataset = tf.keras.datasets.mnist
    elif args.data == 'fashion_mnist':
        dataset = tf.keras.datasets.fashion_mnist
    else:
        # fail with a clear message instead of a NameError below
        raise ValueError('Unsupported dataset {}'.format(args.data))
    print('Loading {} data...'.format(args.data))
    (x_train, y_train), (x_test, y_test) = dataset.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    x_train = x_train.reshape(-1, 28*28).astype(np.float32)
    x_test = x_test.reshape(-1, 28*28).astype(np.float32)
    print(x_train.shape, x_test.shape)

    # simulate low data regime for training
    n_train = x_train.shape[0]
    shuffle_idx = np.arange(n_train)
    np.random.shuffle(shuffle_idx)

    x_train = x_train[shuffle_idx][:args.n_data_train]
    y_train = y_train[shuffle_idx][:args.n_data_train]
    print('Training dataset shapes after slicing:')
    print(x_train.shape, y_train.shape)

    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(5000).batch(args.batch_size)

    test_ds = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(args.batch_size)

    # 1. the baseline MLP model
    mlp = MLP(normalize=True, activation=args.activation)
    # The MLP output layer already applies a softmax activation, so its
    # outputs are probabilities, not logits.
    cce_loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_ACC')

    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_ACC')

    @tf.function
    def train_step_baseline(x, y):
        with tf.GradientTape() as tape:
            y_preds = mlp(x, training=True)
            loss = cce_loss_obj(y, y_preds)

        gradients = tape.gradient(loss,
                                  mlp.trainable_variables)
        optimizer.apply_gradients(zip(gradients,
                                      mlp.trainable_variables))

        train_loss(loss)
        train_acc(y, y_preds)

    @tf.function
    def test_step_baseline(x, y):
        y_preds = mlp(x, training=False)
        t_loss = cce_loss_obj(y, y_preds)
        test_loss(t_loss)
        test_acc(y, y_preds)

    model_name = 'baseline'
    if args.write_summary:
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = 'logs/%s/%s/%s/train' % (
            model_name, args.data, current_time)
        test_log_dir = 'logs/%s/%s/%s/test' % (
            model_name, args.data, current_time)
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    for epoch in range(args.epoch):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        train_acc.reset_states()
        test_loss.reset_states()
        test_acc.reset_states()

        for x, y in train_ds:
            train_step_baseline(x, y)

        if args.write_summary:
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', train_acc.result(), step=epoch)

        for x_te, y_te in test_ds:
            test_step_baseline(x_te, y_te)

        if args.write_summary:
            with test_summary_writer.as_default():
                tf.summary.scalar('loss', test_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', test_acc.result(), step=epoch)

        template = 'Epoch {}, Loss: {}, Acc: {}, Test Loss: {}, Test Acc: {}'
        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_acc.result() * 100,
                              test_loss.result(),
                              test_acc.result() * 100))

    # 2. Check learned embedding
    if args.draw_figures:
        # 'figs/' is not guaranteed to exist; savefig does not create it
        os.makedirs('figs', exist_ok=True)

        # get the projections from the last hidden layer before output;
        # only computed here because nothing else consumes them
        x_tr_proj = mlp.get_last_hidden(x_train)
        x_te_proj = mlp.get_last_hidden(x_test)
        # convert tensor to np.array
        x_tr_proj = x_tr_proj.numpy()
        x_te_proj = x_te_proj.numpy()
        print(x_tr_proj.shape, x_te_proj.shape)

        # do PCA for the projected data
        pca = PCA(n_components=2)
        pca.fit(x_tr_proj)
        x_te_proj_pca = pca.transform(x_te_proj)

        x_te_proj_pca_df = pd.DataFrame(x_te_proj_pca, columns=['PC1', 'PC2'])
        x_te_proj_pca_df['label'] = y_test
        # PCA scatter plot (x/y passed by keyword: positional data arguments
        # are not accepted by newer seaborn releases)
        fig, ax = plt.subplots()
        ax = sns.scatterplot(x='PC1', y='PC2',
                             data=x_te_proj_pca_df,
                             palette='tab10',
                             hue='label',
                             linewidth=0,
                             alpha=0.6,
                             ax=ax)

        # shrink the axes to leave room for the legend on the right
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        title = 'Data: %s; Embedding: MLP' % args.data
        ax.set_title(title)
        fig.savefig('figs/PCA_plot_%s_MLP_last_layer.png' % args.data)
        # density plot for PCA
        g = sns.jointplot(x='PC1', y='PC2', data=x_te_proj_pca_df,
                          kind="hex")
        plt.subplots_adjust(top=0.95)
        g.fig.suptitle(title)
        g.savefig('figs/Joint_PCA_plot_%s_MLP_last_layer.png' % args.data)


if __name__ == '__main__':
    main()
211 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
class UnitNormLayer(tf.keras.layers.Layer):
    '''Normalize vectors (euclidean norm) in batch to unit hypersphere.

    Each row of the input (axis 1) is scaled to unit L2 norm. The scaling
    uses ``tf.math.l2_normalize``, which clamps the squared norm from below
    by a small epsilon, so an all-zero row yields zeros instead of NaN.
    '''

    def __init__(self):
        super(UnitNormLayer, self).__init__()

    def call(self, input_tensor):
        # Plain `x / norm(x)` divides by zero for all-zero rows.
        return tf.math.l2_normalize(input_tensor, axis=1)
14 |
15 |
class DenseLeakyReluLayer(tf.keras.layers.Layer):
    '''Dense layer with `n` units whose output is passed through LeakyReLU.

    `alpha` is forwarded to `tf.keras.layers.LeakyReLU`.
    '''

    def __init__(self, n, alpha=0.3):
        super().__init__()
        # linear transform; the non-linearity is applied separately below
        self.dense = tf.keras.layers.Dense(units=n, activation=None)
        self.lrelu = tf.keras.layers.LeakyReLU(alpha=alpha)

    def call(self, input_tensor):
        return self.lrelu(self.dense(input_tensor))
28 |
29 |
class Encoder(tf.keras.Model):
    '''Encoder network E(·): maps an input x to a representation r = E(x).

    Two hidden layers of 256 units each; with `normalize=True` the output
    is passed through `UnitNormLayer`.
    '''

    def __init__(self, normalize=True, activation='relu'):
        super().__init__(name='')

        def make_hidden():
            # 'leaky_relu' needs the dedicated composite layer; any other
            # value is handed to Dense as its activation
            if activation == 'leaky_relu':
                return DenseLeakyReluLayer(256)
            return tf.keras.layers.Dense(256, activation=activation)

        self.hidden1 = make_hidden()
        self.hidden2 = make_hidden()

        self.normalize = normalize
        if normalize:
            self.norm = UnitNormLayer()

    def call(self, input_tensor, training=False):
        first = self.hidden1(input_tensor, training=training)
        second = self.hidden2(first, training=training)
        if not self.normalize:
            return second
        return self.norm(second)
53 |
54 |
class Projector(tf.keras.Model):
    '''
    A projection network, P(·), which maps the normalized representation vector r into a vector z = P(r) ∈ R^{DP}
    suitable for computation of the contrastive loss.

    Args:
        n: dimension DP of the projected output z.
        normalize: if True, the output is passed through `UnitNormLayer`.
        activation: 'leaky_relu' selects `DenseLeakyReluLayer`; any other
            value is passed to `tf.keras.layers.Dense` as its activation.
    '''

    def __init__(self, n, normalize=True, activation='relu'):
        super(Projector, self).__init__(name='')
        # Hidden layer keeps 256 units; the second layer produces the
        # requested projection dimension `n` (it was previously hard-coded
        # to 256, which silently ignored `n`).
        if activation == 'leaky_relu':
            self.dense = DenseLeakyReluLayer(256)
            self.dense2 = DenseLeakyReluLayer(n)
        else:
            self.dense = tf.keras.layers.Dense(256, activation=activation)
            self.dense2 = tf.keras.layers.Dense(n, activation=activation)

        self.normalize = normalize
        if self.normalize:
            self.norm = UnitNormLayer()

    def call(self, input_tensor, training=False):
        x = self.dense(input_tensor, training=training)
        x = self.dense2(x, training=training)
        if self.normalize:
            x = self.norm(x)
        return x
80 |
81 |
class SoftmaxPred(tf.keras.Model):
    '''Stage-2 classifier: one dense layer with softmax activation, applied
    on top of the Encoder output.
    '''

    def __init__(self, num_classes=10):
        super().__init__(name='')
        # outputs are softmax probabilities over `num_classes`
        self.dense = tf.keras.layers.Dense(
            units=num_classes, activation='softmax')

    def call(self, input_tensor, training=False):
        probabilities = self.dense(input_tensor, training=training)
        return probabilities
92 |
93 |
class MLP(tf.keras.Model):
    '''Baseline MLP: two 256-unit hidden layers (same layout as Encoder)
    followed by a softmax output, or a single linear unit when `regress`.
    '''

    def __init__(self, num_classes=10, normalize=True, regress=False, activation='relu'):
        super().__init__(name='')

        def make_hidden():
            if activation == 'leaky_relu':
                return DenseLeakyReluLayer(256)
            return tf.keras.layers.Dense(256, activation=activation)

        self.hidden1 = make_hidden()
        self.hidden2 = make_hidden()

        self.normalize = normalize
        if normalize:
            self.norm = UnitNormLayer()

        if regress:
            # regression head: one linear output
            self.output_layer = tf.keras.layers.Dense(1)
        else:
            # classification head: softmax over the classes
            self.output_layer = tf.keras.layers.Dense(
                num_classes, activation='softmax')

    def _hidden_forward(self, input_tensor, training):
        '''Run the two hidden layers and the optional unit normalization.'''
        hidden = self.hidden1(input_tensor, training=training)
        hidden = self.hidden2(hidden, training=training)
        return self.norm(hidden) if self.normalize else hidden

    def call(self, input_tensor, training=False):
        hidden = self._hidden_forward(input_tensor, training)
        return self.output_layer(hidden, training=training)

    def get_last_hidden(self, input_tensor):
        '''Return the last hidden representation (before the output layer),
        computed with training=False.
        '''
        return self._hidden_forward(input_tensor, False)
131 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | appnope==0.1.0
3 | astor==0.8.1
4 | attrs==19.3.0
5 | backcall==0.1.0
6 | bleach==3.1.4
7 | cachetools==4.1.0
8 | certifi==2020.4.5.1
9 | chardet==3.0.4
10 | cycler==0.10.0
11 | decorator==4.4.2
12 | defusedxml==0.6.0
13 | entrypoints==0.3
14 | gast==0.2.2
15 | google-auth==1.14.1
16 | google-auth-oauthlib==0.4.1
17 | google-pasta==0.2.0
18 | grpcio==1.28.1
19 | h5py==2.10.0
20 | idna==2.9
21 | importlib-metadata==1.6.0
22 | ipykernel==5.2.1
23 | ipython==7.13.0
24 | ipython-genutils==0.2.0
25 | ipywidgets==7.5.1
26 | jedi==0.17.0
27 | Jinja2==2.11.2
28 | joblib==0.14.1
29 | jsonschema==3.2.0
30 | jupyter==1.0.0
31 | jupyter-client==6.1.3
32 | jupyter-console==6.1.0
33 | jupyter-core==4.6.3
34 | Keras-Applications==1.0.8
35 | Keras-Preprocessing==1.1.0
36 | kiwisolver==1.2.0
37 | Markdown==3.2.1
38 | MarkupSafe==1.1.1
39 | matplotlib==3.2.1
40 | mistune==0.8.4
41 | nbconvert==5.6.1
42 | nbformat==5.0.6
43 | notebook==6.1.5
44 | numpy==1.18.3
45 | oauthlib==3.1.0
46 | opt-einsum==3.2.1
47 | pandas==1.0.3
48 | pandocfilters==1.4.2
49 | parso==0.7.0
50 | pexpect==4.8.0
51 | pickleshare==0.7.5
52 | prometheus-client==0.7.1
53 | prompt-toolkit==3.0.5
54 | protobuf==3.11.3
55 | ptyprocess==0.6.0
56 | pyasn1==0.4.8
57 | pyasn1-modules==0.2.8
58 | Pygments==2.6.1
59 | pyparsing==2.4.7
60 | pyrsistent==0.16.0
61 | python-dateutil==2.8.1
62 | pytz==2019.3
63 | pyzmq==19.0.0
64 | qtconsole==4.7.3
65 | QtPy==1.9.0
66 | requests==2.23.0
67 | requests-oauthlib==1.3.0
68 | rsa==4.0
69 | scikit-learn==0.22.2.post1
70 | scipy==1.4.1
71 | seaborn==0.10.0
72 | Send2Trash==1.5.0
73 | six==1.14.0
74 | tensorboard==2.1.1
75 | tensorflow==2.3.1
76 | tensorflow-addons==0.9.1
77 | tensorflow-estimator==2.1.0
78 | termcolor==1.1.0
79 | terminado==0.8.3
80 | testpath==0.4.4
81 | tornado==6.0.4
82 | traitlets==4.3.3
83 | typeguard==2.7.1
84 | urllib3==1.25.9
85 | wcwidth==0.1.9
86 | webencodings==0.5.1
87 | Werkzeug==1.0.1
88 | widgetsnbextension==3.5.1
89 | wrapt==1.12.1
90 | zipp==3.1.0
91 |
--------------------------------------------------------------------------------
/supcontrast.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/HobbitLong/SupContrast/blob/master/losses.py
2 | """
3 | Author: Yonglong Tian (yonglong@mit.edu)
4 | Date: May 07, 2020
5 | """
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. The result is NaN when some anchor has no
            positive other than itself, because the mean over positives
            then divides by zero.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        # Build every helper tensor on the exact device that holds
        # `features`; a bare torch.device('cuda') resolves to the *current*
        # CUDA device, which may differ from the one the features live on.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            # flatten all trailing feature dimensions into one
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # unsupervised case: only views of the same sample are positives
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # stack the views along the batch axis: [bsz * n_views, dim]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # (0/0 -> NaN for anchors without positives; left as is because the
        # comparison test in this repository expects that behavior)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
101 |
--------------------------------------------------------------------------------
/test_supcontrast_loss.py:
--------------------------------------------------------------------------------
1 | from losses import supervised_nt_xent_loss
2 | from supcontrast import SupConLoss
3 | import torch
4 | import tensorflow as tf
5 | import unittest
6 | import numpy as np
7 | np.random.seed(42)
8 |
9 |
class TestSupContrastLoss(unittest.TestCase):
    '''Check that the TensorFlow implementation of the supervised
    contrastive loss yields the same values as the Torch implementation.
    '''

    def setUp(self):
        # A local generator makes every test see the same data, regardless
        # of the order in which the tests run.
        rng = np.random.RandomState(42)

        self.batch_size = 128
        self.X = self._random_unit_rows(rng, self.batch_size)
        self.y = rng.choice(np.arange(10), self.batch_size, replace=True)

        # Very small batch with a class (label 9) that has only one example.
        # The labels are spelled out so that this precondition of the
        # small-batch test does not depend on a random draw.
        self.y_s = np.array([9, 0, 0, 1, 1, 2, 2, 2])
        self.batch_size_s = len(self.y_s)
        self.X_s = self._random_unit_rows(rng, self.batch_size_s)

        self.temperature = 0.5
        self.base_temperature = 0.07

    @staticmethod
    def _random_unit_rows(rng, n_rows, dim=128):
        '''Return an (n_rows, dim) float32 array of unit-norm random rows.'''
        X = rng.randn(n_rows, dim)
        X /= np.linalg.norm(X, axis=1).reshape(-1, 1)
        return X.astype(np.float32)

    def _tf_loss(self, X, y):
        '''Loss from the TensorFlow implementation.'''
        return supervised_nt_xent_loss(tf.constant(X),
                                       tf.constant(y),
                                       temperature=self.temperature,
                                       base_temperature=self.base_temperature
                                       )

    def _torch_loss(self, X, y):
        '''Loss from the reference Torch implementation (one view per sample).'''
        scl = SupConLoss(temperature=self.temperature,
                         base_temperature=self.base_temperature
                         )
        return scl.forward(features=torch.Tensor(X.reshape(len(X), 1, -1)),
                           labels=torch.Tensor(y)
                           )

    def test_nt_xent_loss_equals_sup_con_loss(self):
        l1 = self._tf_loss(self.X, self.y)
        l2 = self._torch_loss(self.X, self.y)
        print('\nLosses from normal batch size={}:'.format(self.batch_size))
        print('l1 = {}'.format(l1.numpy()))
        print('l2 = {}'.format(l2.numpy()))
        # same tolerances as np.allclose, but reports the values on failure
        np.testing.assert_allclose(l1.numpy(), l2.numpy(),
                                   rtol=1e-5, atol=1e-8)

    def test_nt_xent_loss_and_sup_con_loss_small_batch(self):
        # on very small batch, the SupConLoss would return NaN
        # whereas supervised_nt_xent_loss will ignore those classes
        l1 = self._tf_loss(self.X_s, self.y_s)
        l2 = self._torch_loss(self.X_s, self.y_s)
        print('\nLosses from small batch size={}:'.format(self.batch_size_s))
        print('l1 = {}'.format(l1.numpy()))
        print('l2 = {}'.format(l2.numpy()))
        self.assertTrue(np.isfinite(l1.numpy()))
        self.assertTrue(np.isnan(l2.numpy()))


if __name__ == "__main__":
    unittest.main()
75 |
--------------------------------------------------------------------------------