├── .gitignore ├── LICENSE ├── MindSpore_version └── rotlayers.py ├── README.md ├── backbones ├── adgcl.py ├── attention_mil.py ├── ddi_gin.py └── resnet_imagenet.py └── pooling ├── __init__.py ├── baselines.py ├── rotlayers.py ├── rotpooling.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 S-Data Science Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MindSpore_version/rotlayers.py: -------------------------------------------------------------------------------- 1 | import mindspore 2 | import mindspore.nn as nn 3 | 4 | 5 | def uot_sinkhorn(x: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 6 | a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, mask: mindspore.Tensor, 7 | num: int = 4, eps: float = 1e-8) -> mindspore.Tensor: 8 | """ 9 | Solving regularized optimal transport via Sinkhorn-scaling 10 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 11 | :param p0: (B, 1, D), the marginal prior of dimensions 12 | :param q0: (B, N, 1), the marginal prior of samples 13 | :param a1: (num, ), the weight of the entropic term 14 | :param a2: (num, ), the weight of the KL term of p0 15 | :param a3: (num, ), the weight of the KL term of q0 16 | :param mask: (B, N, 1) a masking tensor 17 | :param num: the number of outer iterations 18 | :param eps: the epsilon to avoid numerical instability 19 | :return: 20 | t: (B, N, D), the optimal transport matrix 21 | """ 22 | t = (q0 * p0) * mask # (B, N, D) 23 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 24 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 25 | tau = 0.0 26 | cost = (-x - tau * mindspore.ops.log(t + eps)) * mask # (B, N, D) 27 | a = mindspore.ops.zeros_like(p0) # (B, 1, D) 28 | b = mindspore.ops.zeros_like(q0) # (B, N, 1) 29 | a11 = a1[0] + tau 30 | y = -cost / a11 # (B, N, D) 31 | for k in range(num): 32 | n = min([k, a1.shape[0] - 1]) 33 | a11 = a1[n] + tau 34 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 35 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 36 | log_p = mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) + ymax # (B, 1, D) 37 | log_q = mindspore.ops.logsumexp(y, axis=2, keep_dims=True) * mask # (B, N, 1) 38 | a = a2[n] / (a2[n] + a11) * (a / a11 + log_p0 - log_p) 39 | b = a3[n] / (a3[n] + a11) * (b / a11 + log_q0 - log_q) 40 | y = (-cost / a11 + a + b) * mask 41 | t = mindspore.ops.exp(y) * mask 42 | return t 43 | 44 | 45 | def rot_sinkhorn(x: mindspore.Tensor, c1: mindspore.Tensor, c2: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 46 | a0: mindspore.Tensor, a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, mask: mindspore.Tensor, 47 | num: int = 4, inner: int = 5, eps: float = 1e-8) -> mindspore.Tensor: 48 | """ 49 | Solving regularized optimal transport via Sinkhorn-scaling 50 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 51 | :param c1: (B, D, D), a matrix with size D x D 52 | :param c2: (B, N, N), a matrix with size N x N 53 | :param p0: (B, 1, D), the marginal prior of dimensions 54 | :param q0: (B, N, 1), the marginal prior of samples 55 | :param a0: (num, ), the weight of the GW term 56 | :param a1: (num, ), the weight of the entropic term 57 | :param a2: (num, ), the weight of the KL term of p0 58 | :param a3: (num, ), the weight of the KL term of q0 59 | :param mask: (B, N, 1) a masking tensor 60 | :param num: the number of outer iterations 61 | :param inner: the number of inner Sinkhorn iterations 62 | :param eps: the epsilon to avoid numerical instability 63 | :return: 64 | t: (B, N, D), the optimal transport matrix 65 | """ 66 | t = (q0 * p0) * mask # (B, N, D) 67 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 68 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 
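# Reading aid for the iterations that follow (comments only; behavior is unchanged): the outer loop (over num)
# linearizes the Gromov-Wasserstein term by rebuilding the cost as -(x + a0[n] * c2 @ t @ c1) - tau * log(t),
# and the inner loop (over inner) applies masked Sinkhorn-style updates to the dual variables a (dimension
# marginal, relaxed toward p0 with weight a2) and b (sample marginal, relaxed toward q0 with weight a3),
# after which t is recovered by re-exponentiating y under the mask.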
69 | tau = 1.0 70 | for m in range(num): 71 | n = min([m, a1.shape[0]-1]) 72 | a11 = a1[n] + tau 73 | tmp1 = mindspore.ops.matmul(c2, t) # (B, N, D) 74 | tmp2 = mindspore.ops.matmul(tmp1, c1) # (B, N, D) 75 | cost = (-x - a0[n] * tmp2 - tau * mindspore.ops.log(t + eps)) * mask # (B, N, D) 76 | a = mindspore.ops.zeros_like(p0) # (B, 1, D) 77 | b = mindspore.ops.zeros_like(q0) # (B, N, 1) 78 | y = -cost / a11 # (B, N, D) 79 | for k in range(inner): 80 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 81 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 82 | log_p = mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) + ymax # (B, 1, D) 83 | log_q = mindspore.ops.logsumexp(y, axis=2, keep_dims=True) * mask # (B, N, 1) 84 | a = a2[n] / (a2[n] + a11) * (a / a11 + log_p0 - log_p) 85 | b = a3[n] / (a3[n] + a11) * (b / a11 + log_q0 - log_q) 86 | y = (-cost / a11 + a + b) * mask 87 | t = mindspore.ops.exp(y) * mask 88 | return t 89 | 90 | 91 | def uot_badmm(x: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 92 | a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, rho: mindspore.Tensor, 93 | mask: mindspore.Tensor, num: int = 4, eps: float = 1e-8) -> mindspore.Tensor: 94 | """ 95 | Solving regularized optimal transport via Bregman ADMM algorithm (entropic regularizer) 96 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 97 | :param p0: (B, 1, D), the marginal prior of dimensions 98 | :param q0: (B, N, 1), the marginal prior of samples 99 | :param a1: (num, ), the weight of the entropic term 100 | :param a2: (num, ), the weight of the KL term of p0 101 | :param a3: (num, ), the weight of the KL term of q0 102 | :param rho: (num, ), the learning rate of ADMM 103 | :param mask: (B, N, 1) a masking tensor 104 | :param num: the number of Bregman ADMM iterations 105 | :param eps: the epsilon to avoid numerical instability 106 | :return: 107 | t: (N, D), the optimal transport matrix 108 | """ 109 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 110 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 111 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 112 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 113 | log_mu = mindspore.ops.log(p0) # (B, 1, D) 114 | log_eta = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 115 | z = mindspore.ops.zeros_like(log_t) # (B, N, D) 116 | z1 = mindspore.ops.zeros_like(p0) # (B, 1, D) 117 | z2 = mindspore.ops.zeros_like(q0) # (B, N, 1) 118 | for k in range(num): 119 | n = min([k, a1.shape[0] - 1]) 120 | # update logP 121 | y = ((x - z) / rho[n] + log_s) # (B, N, D) 122 | log_t = mask * (log_eta - mindspore.ops.logsumexp(y, axis=2, keep_dims=True)) + y # (B, N, D) 123 | # update logS 124 | y = (z + rho[n] * log_t) / (a1[n] + rho[n]) # (B, N, D) 125 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 126 | ymax, _ = mindspore.ops.max(ymin- mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 127 | # (B, N, D) 128 | log_s = mask * ( 129 | log_mu - mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) - ymax) + y 130 | # update dual variables 131 | t = mindspore.ops.exp(log_t) * mask 132 | s = mindspore.ops.exp(log_s) * mask 133 | z = z + rho[n] * (t - s) 134 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 135 | log_mu = y - mindspore.ops.logsumexp(y, axis=2, keep_dims=True) # (B, 1, D) 136 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) # (B, N, 1) 137 | ymin, _ = 
mindspore.ops.min(y, axis=1, keepdims=True) 138 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 139 | log_eta = (y - mindspore.ops.log( 140 | mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) - ymax) * mask # (B, N, 1) 141 | # update dual variables 142 | z1 = z1 + rho[n] * (mindspore.ops.exp(log_mu) - mindspore.ops.sum(s, dim=1, keepdim=True)) # (B, 1, D) 143 | z2 = z2 + rho[n] * (mindspore.ops.exp(log_eta) * mask - mindspore.ops.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 144 | return mindspore.ops.exp(log_t) * mask 145 | 146 | 147 | def rot_badmm(x: mindspore.Tensor, c1: mindspore.Tensor, c2: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 148 | a0: mindspore.Tensor, a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, rho: mindspore.Tensor, 149 | mask: mindspore.Tensor, num: int = 4, eps: float = 1e-8) -> mindspore.Tensor: 150 | """ 151 | Solving regularized optimal transport via Bregman ADMM algorithm (entropic regularizer) 152 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 153 | :param c1: (B, D, D), a matrix with size D x D 154 | :param c2: (B, N, N), a matrix with size N x N 155 | :param p0: (B, 1, D), the marginal prior of dimensions 156 | :param q0: (B, N, 1), the marginal prior of samples 157 | :param a0: (num, ), the weight of the GW term 158 | :param a1: (num, ), the weight of the entropic term 159 | :param a2: (num, ), the weight of the KL term of p0 160 | :param a3: (num, ), the weight of the KL term of q0 161 | :param rho: (num, ), the learning rate of ADMM 162 | :param mask: (B, N, 1) a masking tensor 163 | :param num: the number of Bregman ADMM iterations 164 | :param eps: the epsilon to avoid numerical instability 165 | :return: 166 | t: (N, D), the optimal transport matrix 167 | """ 168 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 169 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 170 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 171 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 172 | log_mu = mindspore.ops.log(p0) # (B, 1, D) 173 | log_eta = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 174 | z = mindspore.ops.zeros_like(log_t) # (B, N, D) 175 | z1 = mindspore.ops.zeros_like(p0) # (B, 1, D) 176 | z2 = mindspore.ops.zeros_like(q0) # (B, N, 1) 177 | for k in range(num): 178 | n = min([k, a1.shape[0] - 1]) 179 | # update logP 180 | tmp1 = mindspore.ops.matmul(c2, mindspore.ops.exp(log_s) * mask) 181 | tmp2 = mindspore.ops.matmul(tmp1, c1) 182 | y = (x + a0[n] * tmp2 - z) / rho[n] + log_s # (B, N, D) 183 | log_t = mask * (log_eta - mindspore.ops.logsumexp(y, axis=2, keep_dims=True)) + y # (B, N, D) 184 | # update logS 185 | tmp1 = mindspore.ops.matmul(c2, mindspore.ops.exp(log_t) * mask) 186 | tmp2 = mindspore.ops.matmul(tmp1, c1) 187 | y = (z + a0[n] * tmp2 + rho[n] * log_t) / (a1[n] + rho[n]) # (B, N, D) 188 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 189 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 190 | # (B, N, D) 191 | log_s = mask * ( 192 | log_mu - mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) - ymax) + y 193 | # update dual variables 194 | t = mindspore.ops.exp(log_t) * mask 195 | s = mindspore.ops.exp(log_s) * mask 196 | z = z + rho[n] * (t - s) 197 | # update log_mu 198 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 199 | log_mu = y - mindspore.ops.logsumexp(y, axis=2, keep_dims=True) # (B, 1, D) 200 | # update 
log_eta 201 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 202 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 203 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 204 | log_eta = (y - mindspore.ops.log( 205 | mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) - ymax) * mask # (B, N, 1) 206 | # update dual variables 207 | z1 = z1 + rho[n] * (mindspore.ops.exp(log_mu) - mindspore.ops.sum(s, dim=1, keepdim=True)) # (B, 1, D) 208 | z2 = z2 + rho[n] * (mindspore.ops.exp(log_eta) * mask - mindspore.ops.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 209 | return mindspore.ops.exp(log_t) * mask 210 | 211 | 212 | def uot_badmm2(x: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 213 | a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, rho: mindspore.Tensor, 214 | mask: mindspore.Tensor, num: int = 4, eps: float = 1e-8) -> mindspore.Tensor: 215 | """ 216 | Solving regularized optimal transport via Bregman ADMM algorithm (quadratic regularizer) 217 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 218 | :param p0: (B, 1, D), the marginal prior of dimensions 219 | :param q0: (B, N, 1), the marginal prior of samples 220 | :param a1: (num, ), the weight of the entropic term 221 | :param a2: (num, ), the weight of the KL term of p0 222 | :param a3: (num, ), the weight of the KL term of q0 223 | :param rho: (num, ), the learning rate of ADMM 224 | :param mask: (B, N, 1) a masking tensor 225 | :param num: the number of Bregman ADMM iterations 226 | :param eps: the epsilon to avoid numerical instability 227 | :param rho: the learning rate of ADMM 228 | :return: 229 | t: (N, D), the optimal transport matrix 230 | """ 231 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 232 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 233 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 234 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 235 | log_mu = mindspore.ops.log(p0) # (B, 1, D) 236 | log_eta = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 237 | z = mindspore.ops.zeros_like(log_t) # (B, N, D) 238 | z1 = mindspore.ops.zeros_like(p0) # (B, 1, D) 239 | z2 = mindspore.ops.zeros_like(q0) # (B, N, 1) 240 | for k in range(num): 241 | n = min([k, a1.shape[0] - 1]) 242 | # update logP 243 | y = (x - a1[n] * mindspore.ops.exp(log_s) * mask - z) / rho[n] + log_s # (B, N, D) 244 | log_t = mask * (log_eta - mindspore.ops.logsumexp(y, axis=2, keep_dims=True)) + y # (B, N, D) 245 | # update logS 246 | y = (z - a1[n] * mindspore.ops.exp(log_t) * mask) / rho[n] + log_t # (B, N, D) 247 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 248 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 249 | # (B, N, D) 250 | log_s = mask * ( 251 | log_mu - mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 252 | # update dual variables 253 | t = mindspore.ops.exp(log_t) * mask 254 | s = mindspore.ops.exp(log_s) * mask 255 | z = z + rho[n] * (t - s) 256 | # update log_mu 257 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 258 | log_mu = y - mindspore.ops.logsumexp(y, axis=2, keep_dims=True) # (B, 1, D) 259 | # update log_eta 260 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 261 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 262 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, 
keepdims=True) # (B, 1, D) 263 | log_eta = (y - mindspore.ops.log( 264 | mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask)) - ymax) * mask # (B, N, 1) 265 | # update dual variables 266 | z1 = z1 + rho[n] * (mindspore.ops.exp(log_mu) - mindspore.ops.sum(s, dim=1, keepdim=True)) # (B, 1, D) 267 | z2 = z2 + rho[n] * (mindspore.ops.exp(log_eta) * mask - mindspore.ops.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 268 | return mindspore.ops.exp(log_t) * mask 269 | 270 | 271 | def rot_badmm2(x: mindspore.Tensor, c1: mindspore.Tensor, c2: mindspore.Tensor, p0: mindspore.Tensor, q0: mindspore.Tensor, 272 | a0: mindspore.Tensor, a1: mindspore.Tensor, a2: mindspore.Tensor, a3: mindspore.Tensor, rho: mindspore.Tensor, 273 | mask: mindspore.Tensor, num: int = 4, eps: float = 1e-8) -> mindspore.Tensor: 274 | """ 275 | Solving regularized optimal transport via Bregman ADMM algorithm (quadratic regularizer) 276 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 277 | :param c1: (B, D, D), a matrix with size D x D 278 | :param c2: (B, N, N), a matrix with size N x N 279 | :param p0: (B, 1, D), the marginal prior of dimensions 280 | :param q0: (B, N, 1), the marginal prior of samples 281 | :param a0: (num, ), the weight of the GW term 282 | :param a1: (num, ), the weight of the entropic term 283 | :param a2: (num, ), the weight of the KL term of p0 284 | :param a3: (num, ), the weight of the KL term of q0 285 | :param rho: (num, ), the weight of the ADMM term 286 | :param mask: (B, N, 1) a masking tensor 287 | :param num: the number of Bregman ADMM iterations 288 | :param eps: the epsilon to avoid numerical instability 289 | :param rho: the learning rate of ADMM 290 | :return: 291 | t: (N, D), the optimal transport matrix 292 | """ 293 | log_p0 = mindspore.ops.log(p0) # (B, 1, D) 294 | log_q0 = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 295 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 296 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 297 | log_mu = mindspore.ops.log(p0) # (B, 1, D) 298 | log_eta = mindspore.ops.log(q0 + eps) * mask # (B, N, 1) 299 | z = mindspore.ops.zeros_like(log_t) # (B, N, D) 300 | z1 = mindspore.ops.zeros_like(p0) # (B, 1, D) 301 | z2 = mindspore.ops.zeros_like(q0) # (B, N, 1) 302 | for k in range(num): 303 | n = min([k, a1.shape[0] - 1]) 304 | # update logP 305 | tmp1 = mindspore.ops.matmul(c2, mindspore.ops.exp(log_s) * mask) 306 | tmp2 = mindspore.ops.matmul(tmp1, c1) 307 | y = (x + a0[n] * tmp2 - a1[n] * mindspore.ops.exp(log_s) * mask - z) / rho[n] + log_s # (B, N, D) 308 | log_t = mask * (log_eta - mindspore.ops.logsumexp(y, axis=2, keep_dims=True)) + y 309 | # update logS 310 | tmp1 = mindspore.ops.matmul(c2, mindspore.ops.exp(log_t) * mask) 311 | tmp2 = mindspore.ops.matmul(tmp1, c1) 312 | y = (z + a0[n] * tmp2 - a1[n] * mindspore.ops.exp(log_t) * mask) / rho[n] + log_t # (B, N, D) 313 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 314 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 315 | # (B, N, D) 316 | log_s = mask * ( 317 | log_mu - mindspore.ops.log(mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 318 | # update dual variables 319 | t = mindspore.ops.exp(log_t) * mask 320 | s = mindspore.ops.exp(log_s) * mask 321 | z = z + rho[n] * (t - s) 322 | # update log_mu 323 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 324 | log_mu = y - mindspore.ops.logsumexp(y, axis=2, keep_dims=True) # (B, 1, D) 325 
| # update log_eta 326 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 327 | ymin, _ = mindspore.ops.min(y, axis=1, keepdims=True) 328 | ymax, _ = mindspore.ops.max(ymin - mask * ymin + y, axis=1, keepdims=True) # (B, 1, D) 329 | log_eta = (y - mindspore.ops.log( 330 | mindspore.ops.sum(mindspore.ops.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) * mask # (B, N, 1) 331 | # update dual variables 332 | z1 = z1 + rho[n] * (mindspore.ops.exp(log_mu) - mindspore.ops.sum(s, dim=1, keepdim=True)) # (B, 1, D) 333 | z2 = z2 + rho[n] * (mindspore.ops.exp(log_eta) * mask - mindspore.ops.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 334 | return mindspore.ops.exp(log_t) * mask 335 | 336 | 337 | class ROT(nn.Cell): 338 | """ 339 | Neural network layer to implement regularized optimal transport. 340 | 341 | Parameters: 342 | ----------- 343 | :param num: int, the number of iterations 344 | :param eps: float, default: 1.0e-8 345 | The epsilon avoiding numerical instability 346 | :param f_method: str, default: 'badmm-e' 347 | The feed-forward method, badmm-e, badmm-q, or sinkhorn 348 | """ 349 | 350 | def __init__(self, num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e'): 351 | super(ROT, self).__init__() 352 | self.num = num 353 | self.eps = eps 354 | self.f_method = f_method 355 | 356 | def construct(self, x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask): 357 | """ 358 | Solving regularized OT problem 359 | """ 360 | if self.f_method == 'badmm-e': 361 | t = rot_badmm(x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask, self.num, self.eps) 362 | elif self.f_method == 'badmm-q': 363 | t = rot_badmm2(x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask, self.num, self.eps) 364 | else: 365 | t = rot_sinkhorn(x, c1, c2, p0, q0, a0, a1, a2, a3, mask, self.num, inner=0, eps=self.eps) 366 | return t 367 | 368 | 369 | class UOT(nn.Cell): 370 | """ 371 | Neural network layer to implement unbalanced optimal transport. 372 | 373 | Parameters: 374 | ----------- 375 | :param num: int, the number of iterations 376 | :param eps: float, default: 1.0e-8 377 | The epsilon avoiding numerical instability 378 | :param f_method: str, default: 'badmm-e' 379 | The feed-forward method, badmm-e, badmm-q or sinkhorn 380 | """ 381 | 382 | def __init__(self, num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e'): 383 | super(UOT, self).__init__() 384 | self.num = num 385 | self.eps = eps 386 | self.f_method = f_method 387 | 388 | def construct(self, x, p0, q0, a1, a2, a3, rho, mask): 389 | """ 390 | Solving regularized OT problem 391 | """ 392 | if self.f_method == 'badmm-e': 393 | t = uot_badmm(x, p0, q0, a1, a2, a3, rho, mask, self.num, self.eps) 394 | elif self.f_method == 'badmm-q': 395 | t = uot_badmm2(x, p0, q0, a1, a2, a3, rho, mask, self.num, self.eps) 396 | else: 397 | t = uot_sinkhorn(x, p0, q0, a1, a2, a3, mask, self.num, eps=self.eps) 398 | return t 399 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ROT-Pooling 2 | * Regularized Optimal Transport Layers for Generalized Global Pooling Operations [https://ieeexplore.ieee.org/document/10247589]. 3 | * The work is an extension of "Revisiting Global Pooling through the Lens of Optimal Transport" [https://arxiv.org/pdf/2201.09191.pdf]. 
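The ROTP layers are exposed through the `pooling` package and can replace a standard global pooling operator in a backbone. The snippet below is a minimal sketch of how `backbones/adgcl.py` instantiates and calls them; the feature dimension, bag sizes, and hyper-parameter values are illustrative assumptions, not recommended settings.

```
import torch
import pooling as Pooling

x = torch.randn(10, 32)                  # 10 instances (e.g., graph nodes) with 32-d features
batch = torch.tensor([0] * 6 + [1] * 4)  # assignment of instances to two bags/graphs

# uot_pooling corresponds to ROTP(a_0=0); Pooling.ROTPooling is constructed with the same arguments
pool = Pooling.UOTPooling(dim=32, num=4, rho=None, same_para=False, p0='fixed', q0='fixed',
                          eps=1e-18, a1=None, a2=None, a3=None, f_method='badmm-e')
z = pool(x, batch)                       # one pooled 32-d embedding per bag
```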
4 | 5 | ## Dependencies 6 | 7 | * [PyTorch] - Version: 1.10.0 8 | * [PyTorch Geometric] - Version: 2.0.3 9 | 10 | ## Training & Evaluation 11 | 12 | 13 | * attention_mil 14 | 15 | ``` 16 | python attention_mil.py --DS 'datasets/messidor' --pooling_layer 'uot_pooling' --f_method 'sinkhorn' --num 4 17 | ``` 18 | 19 | 20 | * adgcl 21 | 22 | ``` 23 | python adgcl.py --DS 'IMDB-BINARY' --pooling_layer 'rot_pooling' --f_method 'badmm-e' --num 4 24 | ``` 25 | 26 | * ddi 27 | 28 | ``` 29 | python ddi_gin.py --DS 'fears' --pooling_layer 'uot_pooling' --f_method 'badmm-e' --num 4 30 | ``` 31 | 32 | * resnet-imagenet 33 | 34 | 35 | The parameter settings follow the reference implementation at https://github.com/pytorch/examples/tree/main/imagenet 36 | 37 | ``` 38 | python resnet_imagenet.py --f_method 'badmm-e' --num 4 39 | ``` 40 | 41 | ## Parameters 42 | 43 | 44 | ```DS``` is the dataset. 45 | 46 | ```pooling_layer``` is the pooling layer chosen for the backbone, including add_pooling, mean_pooling, max_pooling, deepset, 47 | mix_pooling, gated_pooling, set_set, attention_pooling, gated_attention_pooling, dynamic_pooling, GeneralizedNormPooling, 48 | SAGPooling, ASAPooling, OTK, SWE, WEGL, uot_pooling, rot_pooling. uot_pooling corresponds to "ROTP(a_0=0)" and rot_pooling corresponds to 49 | "ROTP(learned a_0)" in the paper. 50 | 51 | ```f_method``` can be ```badmm-e```, ```badmm-q```, or ```sinkhorn```. 52 | 53 | ```num``` is the number K of feed-forward steps. The default value is 4. 54 | 55 | ## Citation 56 | 57 | If our work helps you, please cite it: 58 | ``` 59 | @ARTICLE{10247589, 60 | author={Xu, Hongteng and Cheng, Minjie}, 61 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 62 | title={Regularized Optimal Transport Layers for Generalized Global Pooling Operations}, 63 | year={2023}, 64 | volume={}, 65 | number={}, 66 | pages={1-18}, 67 | doi={10.1109/TPAMI.2023.3314661}} 68 | ``` 69 | 70 | -------------------------------------------------------------------------------- /backbones/adgcl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import torch 4 | import json 5 | from sklearn.svm import LinearSVC, SVC 6 | from torch_geometric.data import DataLoader 7 | from torch_geometric.transforms import Compose 8 | from torch_scatter import scatter 9 | import os 10 | import os.path as osp 11 | import shutil 12 | from sklearn.metrics import accuracy_score 13 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 14 | from torch_geometric.io import read_tu_data 15 | import argparse 16 | from sklearn.model_selection import GridSearchCV, KFold 17 | from sklearn.model_selection import train_test_split 18 | from sklearn.multioutput import MultiOutputClassifier 19 | from sklearn.pipeline import make_pipeline 20 | from sklearn.preprocessing import StandardScaler 21 | from torch_geometric.loader import DataLoader 22 | import numpy as np 23 | import torch.nn.functional as F 24 | from typing import Callable, Union 25 | from torch import Tensor 26 | from torch_geometric.nn.conv import MessagePassing 27 | from torch_geometric.typing import OptPairTensor, Adj, Size 28 | from torch.nn import Sequential, Linear, ReLU 29 | from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool, SAGPooling, ASAPooling 30 | import sys 31 | sys.path.append("..") 32 | import pooling as Pooling 33 | import warnings 34 | warnings.filterwarnings("ignore") 35 | import time 36 | 37 | def arg_parse(): 38
| parser = argparse.ArgumentParser(description="AD-GCL TU") 39 | # MUTAG DD PROTEINS NCI1 40 | # COLLAB REDDIT-BINARY REDDIT-MULTI-5K IMDB-BINARY,IMDB-MULTI 41 | # NCI109 PTC_MR 42 | parser.add_argument("--DS", type=str, default="MUTAG", help="Dataset") 43 | parser.add_argument("--model_lr", type=float, default=0.001, help="Model Learning rate.") 44 | parser.add_argument("--view_lr", type=float, default=0.001, help="View Learning rate.") 45 | parser.add_argument("--num_gc_layers", type=int, default=5, help="Number of GNN layers before pooling",) 46 | parser.add_argument("--pooling_type", type=str, default="standard", help="GNN Pooling Type Standard/Layerwise",) 47 | parser.add_argument("--emb_dim", type=int, default=32, help="embedding dimension") 48 | parser.add_argument("--mlp_edge_model_dim", type=int, default=64, help="embedding dimension") 49 | parser.add_argument("--batch_size", type=int, default=32, help="batch size") 50 | parser.add_argument("--drop_ratio", type=float, default=0.5, help="Dropout Ratio / Probability") 51 | parser.add_argument("--epochs", type=int, default=20, help="Train Epochs") 52 | parser.add_argument("--reg_lambda", type=float, default=5.0, help="View Learner Edge Perturb Regularization Strength",) 53 | parser.add_argument("--eval_interval", type=int, default=1, help="eval epochs interval") 54 | parser.add_argument("--downstream_classifier",type=str,default="linear",help="Downstream classifier is linear or non-linear",) 55 | parser.add_argument("--seed", type=int, default=0) 56 | # for uot-pooling 57 | parser.add_argument("--a1", type=float, default=None) 58 | parser.add_argument("--a2", type=float, default=None) 59 | parser.add_argument("--a3", type=float, default=None) 60 | parser.add_argument("--rho", type=float, default=None) 61 | parser.add_argument("--same_para", type=bool, default=False) 62 | parser.add_argument("--h", type=int, default=32) 63 | parser.add_argument("--num", type=int, default=4) 64 | parser.add_argument('--p0', type=str, default='fixed') 65 | parser.add_argument('--q0', type=str, default='fixed') 66 | parser.add_argument("--f_method", type=str, default="badmm-e") 67 | parser.add_argument("--eps", type=float, default=1e-18) 68 | parser.add_argument("--pooling_layer", help="pooling_layer", default="rot_pooling", type=str) 69 | return parser.parse_args() 70 | 71 | class TUDataset(InMemoryDataset): 72 | 73 | url = "https://www.chrsmrrs.com/graphkerneldatasets" 74 | 75 | def __init__( 76 | self, 77 | root, 78 | name, 79 | transform=None, 80 | pre_transform=None, 81 | pre_filter=None, 82 | use_node_attr=False, 83 | use_edge_attr=False, 84 | cleaned=False, 85 | ): 86 | self.name = name 87 | self.cleaned = cleaned 88 | self.num_tasks = 1 89 | self.task_type = "classification" 90 | self.eval_metric = "accuracy" 91 | super(TUDataset, self).__init__(root, transform, pre_transform, pre_filter) 92 | self.data, self.slices = torch.load(self.processed_paths[0]) 93 | if self.data.x is not None and not use_node_attr: 94 | num_node_attributes = self.num_node_attributes 95 | self.data.x = self.data.x[:, num_node_attributes:] 96 | if self.data.edge_attr is not None and not use_edge_attr: 97 | num_edge_attributes = self.num_edge_attributes 98 | self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:] 99 | 100 | @property 101 | def raw_dir(self): 102 | name = "raw{}".format("_cleaned" if self.cleaned else "") 103 | return osp.join(self.root, self.name, name) 104 | 105 | @property 106 | def processed_dir(self): 107 | name = 
"processed{}".format("_cleaned" if self.cleaned else "") 108 | return osp.join(self.root, self.name, name) 109 | 110 | @property 111 | def num_node_labels(self): 112 | if self.data.x is None: 113 | return 0 114 | for i in range(self.data.x.size(1)): 115 | x = self.data.x[:, i:] 116 | if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all(): 117 | return self.data.x.size(1) - i 118 | return 0 119 | 120 | @property 121 | def num_node_attributes(self): 122 | if self.data.x is None: 123 | return 0 124 | return self.data.x.size(1) - self.num_node_labels 125 | 126 | @property 127 | def num_edge_labels(self): 128 | if self.data.edge_attr is None: 129 | return 0 130 | for i in range(self.data.edge_attr.size(1)): 131 | if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0): 132 | return self.data.edge_attr.size(1) - i 133 | return 0 134 | 135 | @property 136 | def num_edge_attributes(self): 137 | if self.data.edge_attr is None: 138 | return 0 139 | return self.data.edge_attr.size(1) - self.num_edge_labels 140 | 141 | @property 142 | def raw_file_names(self): 143 | names = ["A", "graph_indicator"] 144 | return ["{}_{}.txt".format(self.name, name) for name in names] 145 | 146 | @property 147 | def processed_file_names(self): 148 | return "data.pt" 149 | 150 | def download(self): 151 | url = self.cleaned_url if self.cleaned else self.url 152 | folder = osp.join(self.root, self.name) 153 | path = download_url("{}/{}.zip".format(url, self.name), folder) 154 | extract_zip(path, folder) 155 | os.unlink(path) 156 | shutil.rmtree(self.raw_dir) 157 | os.rename(osp.join(folder, self.name), self.raw_dir) 158 | 159 | def process(self): 160 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 161 | 162 | if self.pre_filter is not None: 163 | data_list = [self.get(idx) for idx in range(len(self))] 164 | data_list = [data for data in data_list if self.pre_filter(data)] 165 | self.data, self.slices = self.collate(data_list) 166 | 167 | if self.pre_transform is not None: 168 | data_list = [self.get(idx) for idx in range(len(self))] 169 | data_list = [self.pre_transform(data) for data in data_list] 170 | self.data, self.slices = self.collate(data_list) 171 | 172 | torch.save((self.data, self.slices), self.processed_paths[0]) 173 | 174 | def __repr__(self): 175 | return "{}({})".format(self.name, len(self)) 176 | 177 | 178 | class TUEvaluator: 179 | def __init__(self): 180 | self.num_tasks = 1 181 | self.eval_metric = "accuracy" 182 | 183 | def _parse_and_check_input(self, input_dict): 184 | if self.eval_metric == "accuracy": 185 | if not "y_true" in input_dict: 186 | raise RuntimeError("Missing key of y_true") 187 | if not "y_pred" in input_dict: 188 | raise RuntimeError("Missing key of y_pred") 189 | 190 | y_true, y_pred = input_dict["y_true"], input_dict["y_pred"] 191 | 192 | """ 193 | y_true: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 194 | y_pred: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 195 | """ 196 | 197 | # converting to torch.Tensor to numpy on cpu 198 | if torch is not None and isinstance(y_true, torch.Tensor): 199 | y_true = y_true.detach().cpu().numpy() 200 | 201 | if torch is not None and isinstance(y_pred, torch.Tensor): 202 | y_pred = y_pred.detach().cpu().numpy() 203 | 204 | 205 | if not (isinstance(y_true, np.ndarray) and isinstance(y_true, np.ndarray)): 206 | raise RuntimeError("Arguments to Evaluator need to be either numpy ndarray or torch tensor") 207 | 208 | if not y_true.shape == y_pred.shape: 209 | raise RuntimeError("Shape of 
y_true and y_pred must be the same") 210 | 211 | if not y_true.ndim == 2: 212 | raise RuntimeError( 213 | "y_true and y_pred mush to 2-dim arrray, {}-dim array given".format( y_true.ndim) 214 | ) 215 | 216 | if not y_true.shape[1] == self.num_tasks: 217 | raise RuntimeError( 218 | "Number of tasks should be {} but {} given".format(self.num_tasks, y_true.shape[1]) 219 | ) 220 | 221 | return y_true, y_pred 222 | else: 223 | raise ValueError("Undefined eval metric %s " % self.eval_metric) 224 | 225 | def _eval_accuracy(self, y_true, y_pred): 226 | """ 227 | compute Accuracy score averaged across tasks 228 | """ 229 | acc_list = [] 230 | 231 | for i in range(y_true.shape[1]): 232 | is_labeled = y_true[:, i] == y_true[:, i] 233 | acc = accuracy_score(y_true[is_labeled], y_pred[is_labeled]) 234 | acc_list.append(acc) 235 | 236 | return {"accuracy": sum(acc_list) / len(acc_list)} 237 | 238 | def eval(self, input_dict): 239 | y_true, y_pred = self._parse_and_check_input(input_dict) 240 | return self._eval_accuracy(y_true, y_pred) 241 | 242 | 243 | def initialize_edge_weight(data): 244 | data.edge_weight = torch.ones(data.edge_index.shape[1], dtype=torch.float) 245 | return data 246 | 247 | 248 | def initialize_node_features(data): 249 | num_nodes = int(data.edge_index.max()) + 1 250 | data.x = torch.ones((num_nodes, 1)) 251 | return data 252 | 253 | 254 | def set_tu_dataset_y_shape(data): 255 | num_tasks = 1 256 | data.y = data.y.unsqueeze(num_tasks) 257 | return data 258 | 259 | 260 | class ViewLearner(torch.nn.Module): 261 | def __init__(self, encoder, mlp_edge_model_dim=64): 262 | super(ViewLearner, self).__init__() 263 | 264 | self.encoder = encoder 265 | self.input_dim = self.encoder.out_node_dim 266 | 267 | self.mlp_edge_model = Sequential( 268 | Linear(self.input_dim * 2, mlp_edge_model_dim), 269 | ReLU(), 270 | Linear(mlp_edge_model_dim, 1), 271 | ) 272 | self.init_emb() 273 | 274 | def init_emb(self): 275 | for m in self.modules(): 276 | if isinstance(m, Linear): 277 | torch.nn.init.xavier_uniform_(m.weight.data) 278 | if m.bias is not None: 279 | m.bias.data.fill_(0.0) 280 | 281 | def forward(self, batch, x, edge_index, edge_attr): 282 | 283 | _, node_emb = self.encoder(batch, x, edge_index, edge_attr) 284 | 285 | src, dst = edge_index[0], edge_index[1] 286 | emb_src = node_emb[src] 287 | emb_dst = node_emb[dst] 288 | 289 | edge_emb = torch.cat([emb_src, emb_dst], 1) 290 | edge_logits = self.mlp_edge_model(edge_emb) 291 | 292 | return edge_logits 293 | 294 | 295 | def get_emb_y(loader, encoder, device, dtype="numpy", is_rand_label=False): 296 | x, y = encoder.get_embeddings(loader, device, is_rand_label) 297 | if dtype == "numpy": 298 | return x, y 299 | elif dtype == "torch": 300 | return torch.from_numpy(x).to(device), torch.from_numpy(y).to(device) 301 | else: 302 | raise NotImplementedError 303 | 304 | 305 | class EmbeddingEvaluation: 306 | def __init__( 307 | self, 308 | base_classifier, 309 | evaluator, 310 | task_type, 311 | num_tasks, 312 | device, 313 | params_dict=None, 314 | param_search=False, 315 | is_rand_label=False, 316 | ): 317 | self.is_rand_label = is_rand_label 318 | self.base_classifier = base_classifier 319 | self.evaluator = evaluator 320 | self.eval_metric = evaluator.eval_metric 321 | self.task_type = task_type 322 | self.num_tasks = num_tasks 323 | self.device = device 324 | self.param_search = param_search 325 | self.params_dict = params_dict 326 | if self.eval_metric == "rmse": 327 | self.gscv_scoring_name = "neg_root_mean_squared_error" 328 | elif 
self.eval_metric == "mae": 329 | self.gscv_scoring_name = "neg_mean_absolute_error" 330 | elif self.eval_metric == "rocauc": 331 | self.gscv_scoring_name = "roc_auc" 332 | elif self.eval_metric == "accuracy": 333 | self.gscv_scoring_name = "accuracy" 334 | else: 335 | raise ValueError( 336 | "Undefined grid search scoring for metric %s " % self.eval_metric 337 | ) 338 | 339 | self.classifier = None 340 | 341 | def scorer(self, y_true, y_raw): 342 | input_dict = {"y_true": y_true, "y_pred": y_raw} 343 | score = self.evaluator.eval(input_dict)[self.eval_metric] 344 | return score 345 | 346 | def ee_binary_classification( 347 | self, train_emb, train_y, val_emb, val_y, test_emb, test_y 348 | ): 349 | # param_search = False 350 | if self.param_search: 351 | params_dict = {"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]} 352 | self.classifier = make_pipeline( 353 | StandardScaler(), 354 | GridSearchCV( 355 | self.base_classifier, 356 | params_dict, 357 | cv=5, 358 | scoring=self.gscv_scoring_name, 359 | n_jobs=16, 360 | verbose=0, 361 | ), 362 | ) 363 | else: 364 | self.classifier = make_pipeline(StandardScaler(), self.base_classifier) 365 | 366 | self.classifier.fit(train_emb, np.squeeze(train_y)) 367 | 368 | if self.eval_metric == "accuracy": 369 | train_raw = self.classifier.predict(train_emb) 370 | val_raw = self.classifier.predict(val_emb) 371 | test_raw = self.classifier.predict(test_emb) 372 | else: 373 | train_raw = self.classifier.predict_proba(train_emb)[:, 1] 374 | val_raw = self.classifier.predict_proba(val_emb)[:, 1] 375 | test_raw = self.classifier.predict_proba(test_emb)[:, 1] 376 | 377 | return ( 378 | np.expand_dims(train_raw, axis=1), 379 | np.expand_dims(val_raw, axis=1), 380 | np.expand_dims(test_raw, axis=1), 381 | ) 382 | 383 | def ee_multioutput_binary_classification(self, train_emb, train_y, val_emb, val_y, test_emb, test_y): 384 | 385 | self.classifier = make_pipeline(StandardScaler(), MultiOutputClassifier(self.base_classifier, n_jobs=-1)) 386 | 387 | if np.isnan(train_y).any(): 388 | print("Has NaNs ... 
ignoring them") 389 | train_y = np.nan_to_num(train_y) 390 | self.classifier.fit(train_emb, train_y) 391 | 392 | train_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(train_emb)]) 393 | val_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(val_emb)]) 394 | test_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(test_emb)]) 395 | 396 | return train_raw, val_raw, test_raw 397 | 398 | def ee_regression(self, train_emb, train_y, val_emb, val_y, test_emb, test_y): 399 | if self.param_search: 400 | params_dict = {"alpha": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5]} 401 | self.classifier = GridSearchCV( 402 | self.base_classifier, 403 | params_dict, 404 | cv=5, 405 | scoring=self.gscv_scoring_name, 406 | n_jobs=16, 407 | verbose=0, 408 | ) 409 | else: 410 | self.classifier = self.base_classifier 411 | 412 | self.classifier.fit(train_emb, np.squeeze(train_y)) 413 | 414 | train_raw = self.classifier.predict(train_emb) 415 | val_raw = self.classifier.predict(val_emb) 416 | test_raw = self.classifier.predict(test_emb) 417 | 418 | return ( 419 | np.expand_dims(train_raw, axis=1), 420 | np.expand_dims(val_raw, axis=1), 421 | np.expand_dims(test_raw, axis=1), 422 | ) 423 | 424 | def embedding_evaluation(self, encoder, train_loader, valid_loader, test_loader): 425 | encoder.eval() 426 | train_emb, train_y = get_emb_y(train_loader, encoder, self.device, is_rand_label=self.is_rand_label) 427 | val_emb, val_y = get_emb_y(valid_loader, encoder, self.device, is_rand_label=self.is_rand_label) 428 | test_emb, test_y = get_emb_y(test_loader, encoder, self.device, is_rand_label=self.is_rand_label) 429 | if "classification" in self.task_type: 430 | if self.num_tasks == 1: 431 | train_raw, val_raw, test_raw = self.ee_binary_classification(train_emb, train_y, val_emb, val_y, test_emb, test_y) 432 | elif self.num_tasks > 1: 433 | (train_raw, val_raw, test_raw,) = self.ee_multioutput_binary_classification(train_emb, train_y, val_emb, val_y, test_emb, test_y) 434 | else: 435 | raise NotImplementedError 436 | else: 437 | if self.num_tasks == 1: 438 | train_raw, val_raw, test_raw = self.ee_regression(train_emb, train_y, val_emb, val_y, test_emb, test_y) 439 | else: 440 | raise NotImplementedError 441 | 442 | train_score = self.scorer(train_y, train_raw) 443 | 444 | val_score = self.scorer(val_y, val_raw) 445 | 446 | test_score = self.scorer(test_y, test_raw) 447 | 448 | return train_score, val_score, test_score 449 | #我改了这里的128 为4 450 | def kf_embedding_evaluation(self, encoder, dataset, folds=10, batch_size=128): 451 | kf_train = [] 452 | kf_val = [] 453 | kf_test = [] 454 | 455 | kf = KFold(n_splits=folds, shuffle=True, random_state=None) 456 | for k_id, (train_val_index, test_index) in enumerate(kf.split(dataset)): 457 | test_dataset = [dataset[int(i)] for i in list(test_index)] 458 | train_index, val_index = train_test_split(train_val_index, test_size=0.2, random_state=None) 459 | 460 | train_dataset = [dataset[int(i)] for i in list(train_index)] 461 | val_dataset = [dataset[int(i)] for i in list(val_index)] 462 | 463 | train_loader = DataLoader(train_dataset, batch_size=batch_size) 464 | valid_loader = DataLoader(val_dataset, batch_size=batch_size) 465 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 466 | 467 | train_score, val_score, test_score = self.embedding_evaluation(encoder, train_loader, valid_loader, test_loader) 468 | 469 | kf_train.append(train_score) 470 | kf_val.append(val_score) 471 | 
kf_test.append(test_score) 472 | 473 | return ( 474 | np.array(kf_train).mean(), 475 | np.array(kf_val).mean(), 476 | np.array(kf_test).mean(), 477 | ) 478 | 479 | 480 | def reset(nn): 481 | def _reset(item): 482 | if hasattr(item, "reset_parameters"): 483 | item.reset_parameters() 484 | 485 | if nn is not None: 486 | if hasattr(nn, "children") and len(list(nn.children())) > 0: 487 | for item in nn.children(): 488 | _reset(item) 489 | else: 490 | _reset(nn) 491 | 492 | 493 | class WGINConv(MessagePassing): 494 | def __init__( 495 | self, nn: Callable, eps: float = 0.0, train_eps: bool = False, **kwargs 496 | ): 497 | kwargs.setdefault("aggr", "add") 498 | super(WGINConv, self).__init__(**kwargs) 499 | self.nn = nn 500 | self.initial_eps = eps 501 | if train_eps: 502 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 503 | else: 504 | self.register_buffer("eps", torch.Tensor([eps])) 505 | self.reset_parameters() 506 | 507 | def reset_parameters(self): 508 | reset(self.nn) 509 | self.eps.data.fill_(self.initial_eps) 510 | 511 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight=None, size: Size = None,) -> Tensor: 512 | """""" 513 | if isinstance(x, Tensor): 514 | x: OptPairTensor = (x, x) 515 | 516 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) 517 | 518 | x_r = x[1] 519 | if x_r is not None: 520 | out += (1 + self.eps) * x_r 521 | 522 | return self.nn(out) 523 | 524 | def message(self, x_j: Tensor, edge_weight) -> Tensor: 525 | return x_j if edge_weight is None else x_j * edge_weight.view(-1, 1) 526 | 527 | def __repr__(self): 528 | return "{}(nn={})".format(self.__class__.__name__, self.nn) 529 | 530 | 531 | 532 | class TUEncoder(torch.nn.Module): 533 | def __init__(self, num_dataset_features, pooling_layer, f_method, a1, a2, a3, rho, same_para, p0, q0, num, eps, h, emb_dim=300, num_gc_layers=5, drop_ratio=0.0, 534 | pooling_type="standard", 535 | is_infograph=False, 536 | ): 537 | super(TUEncoder, self).__init__() 538 | self.a1 = a1 539 | self.a2 = a2 540 | self.a3 = a3 541 | self.f_method = f_method 542 | self.rho = rho 543 | self.same_para = same_para 544 | self.p0 = p0 545 | self.q0 = q0 546 | self.num = num 547 | self.eps = eps 548 | self.h = h 549 | self.rho = rho 550 | self.pooling_type = pooling_type 551 | self.emb_dim = emb_dim 552 | self.num_gc_layers = num_gc_layers 553 | self.drop_ratio = drop_ratio 554 | self.is_infograph = is_infograph 555 | self.pooling_layer = (pooling_layer,) 556 | self.out_node_dim = self.emb_dim 557 | if self.pooling_type == "standard": 558 | self.out_graph_dim = self.emb_dim 559 | elif self.pooling_type == "layerwise": 560 | self.out_graph_dim = self.emb_dim * self.num_gc_layers 561 | else: 562 | raise NotImplementedError 563 | 564 | self.convs = torch.nn.ModuleList() 565 | self.bns = torch.nn.ModuleList() 566 | 567 | for i in range(num_gc_layers): 568 | 569 | if i: 570 | nn = Sequential(Linear(emb_dim, emb_dim), ReLU(), Linear(emb_dim, emb_dim)) 571 | else: 572 | nn = Sequential( 573 | Linear(num_dataset_features, emb_dim), 574 | ReLU(), 575 | Linear(emb_dim, emb_dim), 576 | ) 577 | conv = WGINConv(nn) 578 | bn = torch.nn.BatchNorm1d(emb_dim) 579 | 580 | self.convs.append(conv) 581 | self.bns.append(bn) 582 | self.pooling_layer = pooling_layer 583 | feature_pooling = emb_dim 584 | if self.pooling_layer == "mix_pooling": 585 | self.pooling = Pooling.MixedPooling() 586 | if self.pooling_layer == "gated_pooling": 587 | self.pooling = Pooling.GatedPooling(feature_pooling) 588 | if 
self.pooling_layer == "set_set": 589 | self.pooling = Pooling.Set2Set(feature_pooling, 2, 1) 590 | self.dense = torch.nn.Linear(feature_pooling * 2, 32) 591 | if self.pooling_layer == "attention_pooling": 592 | self.pooling = Pooling.AttentionPooling(feature_pooling, 32) 593 | if self.pooling_layer == "gated_attention_pooling": 594 | self.pooling = Pooling.GatedAttentionPooling(feature_pooling, 32) 595 | if self.pooling_layer == "dynamic_pooling": 596 | self.pooling = Pooling.DynamicPooling(feature_pooling, 3) 597 | if self.pooling_layer == "uot_pooling": 598 | self.pooling = Pooling.UOTPooling(dim=feature_pooling, num=num, rho = rho, same_para=same_para, p0=p0, q0=q0, eps=eps, a1=a1, a2=a2, a3=a3, f_method=args.f_method) 599 | if self.pooling_layer == "rot_pooling": 600 | self.pooling = Pooling.ROTPooling(dim=feature_pooling, num=num, rho = rho, same_para=same_para, p0=p0, q0=q0, eps=eps, a1=a1, a2=a2, a3=a3, f_method=args.f_method) 601 | if self.pooling_layer == "deepset": 602 | self.pooling = Pooling.DeepSet(feature_pooling, 32) 603 | if self.pooling_layer == 'GeneralizedNormPooling': 604 | self.pooling = Pooling.GeneralizedNormPooling(feature_pooling) 605 | if self.pooling_layer == 'SAGPooling': 606 | self.pooling = SAGPooling(feature_pooling) 607 | if self.pooling_layer == 'ASAPooling': 608 | self.pooling = ASAPooling(feature_pooling) 609 | 610 | def forward(self, batch, x, edge_index, edge_attr=None, edge_weight=None): 611 | xs = [] 612 | for i in range(self.num_gc_layers): 613 | x = self.convs[i](x, edge_index, edge_weight) 614 | x = self.bns[i](x) 615 | if i == self.num_gc_layers - 1: 616 | x = F.dropout(x, self.drop_ratio, training=self.training) 617 | else: 618 | x = F.dropout(F.relu(x), self.drop_ratio, training=self.training) 619 | xs.append(x) 620 | # compute graph embedding using pooling 621 | if self.pooling_type == "standard": 622 | if self.pooling_layer == "add_pooling": 623 | xpool = global_add_pool(x, batch) 624 | elif self.pooling_layer == "mean_pooling": 625 | xpool = global_mean_pool(x, batch) 626 | xpool = xpool 627 | elif self.pooling_layer == "max_pooling": 628 | xpool = global_max_pool(x, batch) 629 | elif self.pooling_layer == "set_set": 630 | torch.backends.cudnn.enabled = False 631 | xpool = self.pooling(x, batch) 632 | xpool = self.dense(xpool) 633 | elif self.pooling_layer == 'SAGPooling': 634 | xpool, _, _, batch, _, _ = self.pooling(x, edge_index, batch=batch) 635 | xpool = global_add_pool(xpool, batch) 636 | elif self.pooling_layer == 'ASAPooling': 637 | xpool, _, _, batch, _ = self.pooling(x, edge_index, batch=batch) 638 | xpool = global_add_pool(xpool, batch) 639 | else: 640 | xpool = self.pooling(x, batch) 641 | #logs_temp = np.array(xpool.cpu()) 642 | #print(xpool) 643 | return xpool, x 644 | 645 | elif self.pooling_type == "layerwise": 646 | if self.pooling_layer == "add_pooling": 647 | xpool = [global_add_pool(x, batch) for x in xs] 648 | elif self.pooling_layer == "mean_pooling": 649 | xpool = [global_mean_pool(x, batch) for x in xs] 650 | elif self.pooling_layer == "max_pooling": 651 | xpool = [global_max_pool(x, batch) for x in xs] 652 | # for set_set 653 | elif self.pooling_layer == "set_set": 654 | torch.backends.cudnn.enabled = False 655 | xpool = [self.pooling(x, batch) for x in xs] 656 | xpool = [self.dense(x2) for x2 in xpool] 657 | else: 658 | xpool = [self.pooling(x, batch) for x in xs] 659 | xpool = torch.cat(xpool, 1) 660 | if self.is_infograph: 661 | return xpool, torch.cat(xs, 1) 662 | else: 663 | return xpool, x 664 | else: 665 | 
raise NotImplementedError 666 | 667 | def get_embeddings(self, loader, device, is_rand_label=False): 668 | ret = [] 669 | y = [] 670 | with torch.no_grad(): 671 | for data in loader: 672 | if isinstance(data, list): 673 | data = data[0].to(device) 674 | data = data.to(device) 675 | batch, x, edge_index = data.batch, data.x, data.edge_index 676 | edge_weight = data.edge_weight if hasattr(data, "edge_weight") else None 677 | 678 | if x is None: 679 | x = torch.ones((batch.shape[0], 1)).to(device) 680 | x, _ = self.forward(batch, x, edge_index, edge_weight) 681 | 682 | ret.append(x.cpu().numpy()) 683 | if is_rand_label: 684 | y.append(data.rand_label.cpu().numpy()) 685 | else: 686 | y.append(data.y.cpu().numpy()) 687 | ret = np.concatenate(ret, 0) 688 | y = np.concatenate(y, 0) 689 | return ret, y 690 | 691 | 692 | class GInfoMinMax(torch.nn.Module): 693 | def __init__(self, encoder, proj_hidden_dim=300): 694 | super(GInfoMinMax, self).__init__() 695 | 696 | self.encoder = encoder 697 | self.input_proj_dim = self.encoder.out_graph_dim 698 | 699 | self.proj_head = Sequential( 700 | Linear(self.input_proj_dim, proj_hidden_dim), 701 | ReLU(inplace=True), 702 | Linear(proj_hidden_dim, proj_hidden_dim), 703 | ) 704 | 705 | self.init_emb() 706 | 707 | def init_emb(self): 708 | for m in self.modules(): 709 | if isinstance(m, Linear): 710 | torch.nn.init.xavier_uniform_(m.weight.data) 711 | if m.bias is not None: 712 | m.bias.data.fill_(0.0) 713 | 714 | def forward(self, batch, x, edge_index, edge_attr, edge_weight=None): 715 | 716 | z, node_emb = self.encoder(batch, x, edge_index, edge_attr, edge_weight) 717 | z = self.proj_head(z) 718 | return z, node_emb 719 | 720 | @staticmethod 721 | def calc_loss(x, x_aug, temperature=0.2, sym=True): 722 | 723 | batch_size, _ = x.size() 724 | x_abs = x.norm(dim=1) 725 | x_aug_abs = x_aug.norm(dim=1) 726 | 727 | sim_matrix = torch.einsum("ik,jk->ij", x, x_aug) / torch.einsum("i,j->ij", x_abs, x_aug_abs) 728 | sim_matrix = torch.exp(sim_matrix / temperature) 729 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 730 | if sym: 731 | 732 | loss_0 = pos_sim / (sim_matrix.sum(dim=0) - pos_sim) 733 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 734 | 735 | loss_0 = -torch.log(loss_0).mean() 736 | loss_1 = -torch.log(loss_1).mean() 737 | loss = (loss_0 + loss_1) / 2.0 738 | else: 739 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 740 | loss_1 = -torch.log(loss_1).mean() 741 | return loss_1 742 | 743 | return loss 744 | 745 | 746 | def setup_seed(seed): 747 | torch.manual_seed(seed) 748 | torch.cuda.manual_seed_all(seed) 749 | torch.backends.cudnn.deterministic = True 750 | torch.backends.cudnn.benchmark = False 751 | torch.backends.cudnn.enabled = False 752 | np.random.seed(seed) 753 | random.seed(seed) 754 | os.environ['PYTHONHASHSEED'] = str(seed) 755 | # torch.use_deterministic_algorithms(True) 756 | 757 | 758 | def run(args): 759 | logging.basicConfig( 760 | level=logging.INFO, 761 | format="%(asctime)s - %(message)s", 762 | datefmt="%d-%b-%y %H:%M:%S", 763 | ) 764 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 765 | logging.info("Using Device: %s" % device) 766 | logging.info("Seed: %d" % args.seed) 767 | logging.info(args) 768 | setup_seed(args.seed) 769 | 770 | evaluator = TUEvaluator() 771 | my_transforms = Compose([initialize_node_features, initialize_edge_weight, set_tu_dataset_y_shape]) 772 | dataset = TUDataset("./original_datasets/", args.DS, transform=my_transforms) 773 | dataloader = DataLoader( 774 | 
dataset, 775 | batch_size=args.batch_size, 776 | shuffle=True, 777 | num_workers=16, 778 | pin_memory=True, 779 | ) 780 | model = GInfoMinMax( 781 | TUEncoder( 782 | num_dataset_features=1, 783 | pooling_layer=args.pooling_layer, 784 | a1=args.a1, 785 | a2=args.a2, 786 | a3=args.a3, 787 | f_method=args.f_method, 788 | rho=args.rho, 789 | same_para=args.same_para, 790 | p0=args.p0, 791 | q0=args.q0, 792 | num=args.num, 793 | eps=args.eps, 794 | h=args.h, 795 | emb_dim=args.emb_dim, 796 | num_gc_layers=args.num_gc_layers, 797 | drop_ratio=args.drop_ratio, 798 | pooling_type=args.pooling_type, 799 | ), 800 | args.emb_dim, 801 | ).to(device) 802 | model_optimizer = torch.optim.Adam(model.parameters(), lr=args.model_lr) 803 | view_learner = ViewLearner( 804 | TUEncoder( 805 | num_dataset_features=1, 806 | pooling_layer=args.pooling_layer, 807 | a1=args.a1, 808 | a2=args.a2, 809 | a3=args.a3, 810 | f_method=args.f_method, 811 | rho=args.rho, 812 | same_para=args.same_para, 813 | p0=args.p0, 814 | q0=args.q0, 815 | num=args.num, 816 | eps=args.eps, 817 | h=args.h, 818 | emb_dim=args.emb_dim, 819 | num_gc_layers=args.num_gc_layers, 820 | drop_ratio=args.drop_ratio, 821 | pooling_type=args.pooling_type, 822 | ), 823 | mlp_edge_model_dim=args.mlp_edge_model_dim, 824 | ).to(device) 825 | view_optimizer = torch.optim.Adam(view_learner.parameters(), lr=args.view_lr) 826 | if args.downstream_classifier == "linear": 827 | ee = EmbeddingEvaluation( 828 | LinearSVC(dual=False, fit_intercept=True), 829 | evaluator, 830 | dataset.task_type, 831 | dataset.num_tasks, 832 | device, 833 | param_search=False, 834 | ) 835 | else: 836 | ee = EmbeddingEvaluation( 837 | SVC(), 838 | evaluator, 839 | dataset.task_type, 840 | dataset.num_tasks, 841 | device, 842 | param_search=False, 843 | ) 844 | model.eval() 845 | train_score, val_score, test_score = ee.kf_embedding_evaluation( 846 | model.encoder, dataset 847 | ) 848 | logging.info( 849 | "Before training Embedding Eval Scores: Train: {} Val: {} Test: {}".format(train_score, val_score, test_score) 850 | ) 851 | 852 | model_losses = [] 853 | view_losses = [] 854 | view_regs = [] 855 | valid_curve = [] 856 | test_curve = [] 857 | train_curve = [] 858 | 859 | accuracies = {"val": [], "test": []} 860 | for epoch in range(1, args.epochs + 1): 861 | begin_time = time.time() 862 | model_loss_all = 0 863 | view_loss_all = 0 864 | reg_all = 0 865 | for batch in dataloader: 866 | batch = batch.to(device) 867 | 868 | view_learner.train() 869 | view_learner.zero_grad() 870 | model.eval() 871 | 872 | x, _ = model(batch.batch, batch.x, batch.edge_index, None, None) 873 | 874 | edge_logits = view_learner(batch.batch, batch.x, batch.edge_index, None) 875 | 876 | temperature = 1.0 877 | bias = 0.0 + 0.0001 878 | eps = (bias - (1 - bias)) * torch.rand(edge_logits.size()) + (1 - bias) 879 | gate_inputs = torch.log(eps) - torch.log(1 - eps) 880 | gate_inputs = gate_inputs.to(device) 881 | gate_inputs = (gate_inputs + edge_logits) / temperature 882 | batch_aug_edge_weight = torch.sigmoid(gate_inputs).squeeze() 883 | 884 | x_aug, _ = model( 885 | batch.batch, batch.x, batch.edge_index, None, batch_aug_edge_weight 886 | ) 887 | 888 | # regularization 889 | row, col = batch.edge_index 890 | edge_batch = batch.batch[row] 891 | edge_drop_out_prob = 1 - batch_aug_edge_weight 892 | 893 | uni, edge_batch_num = edge_batch.unique(return_counts=True) 894 | sum_pe = scatter(edge_drop_out_prob, edge_batch, reduce="sum") 895 | 896 | reg = [] 897 | for b_id in range(args.batch_size): 898 | if b_id in 
uni: 899 | num_edges = edge_batch_num[uni.tolist().index(b_id)] 900 | reg.append(sum_pe[b_id] / num_edges) 901 | else: 902 | # means no edges in that graph. So don't include. 903 | pass 904 | reg = torch.stack(reg) 905 | reg = reg.mean() 906 | 907 | view_loss = model.calc_loss(x, x_aug) - (args.reg_lambda * reg) 908 | view_loss_all += view_loss.item() * batch.num_graphs 909 | reg_all += reg.item() 910 | # gradient ascent formulation 911 | (-view_loss).backward() 912 | view_optimizer.step() 913 | 914 | # train (model) to minimize contrastive loss 915 | model.train() 916 | view_learner.eval() 917 | model.zero_grad() 918 | 919 | x, _ = model(batch.batch, batch.x, batch.edge_index, None, None) 920 | edge_logits = view_learner(batch.batch, batch.x, batch.edge_index, None) 921 | 922 | temperature = 1.0 923 | bias = 0.0 + 0.0001 924 | eps = (bias - (1 - bias)) * torch.rand(edge_logits.size()) + (1 - bias) 925 | gate_inputs = torch.log(eps) - torch.log(1 - eps) 926 | gate_inputs = gate_inputs.to(device) 927 | gate_inputs = (gate_inputs + edge_logits) / temperature 928 | batch_aug_edge_weight = torch.sigmoid(gate_inputs).squeeze().detach() 929 | 930 | x_aug, _ = model( 931 | batch.batch, batch.x, batch.edge_index, None, batch_aug_edge_weight 932 | ) 933 | 934 | model_loss = model.calc_loss(x, x_aug) 935 | model_loss_all += model_loss.item() * batch.num_graphs 936 | # standard gradient descent formulation 937 | model_loss.backward() 938 | model_optimizer.step() 939 | 940 | end_time = time.time() 941 | print("-------------time------------") 942 | print(end_time - begin_time) 943 | fin_model_loss = model_loss_all / len(dataloader) 944 | fin_view_loss = view_loss_all / len(dataloader) 945 | fin_reg = reg_all / len(dataloader) 946 | 947 | logging.info( 948 | "Epoch {}, Model Loss {}, View Loss {}, Reg {}".format(epoch, fin_model_loss, fin_view_loss, fin_reg) 949 | ) 950 | model_losses.append(fin_model_loss) 951 | view_losses.append(fin_view_loss) 952 | view_regs.append(fin_reg) 953 | if epoch % args.eval_interval == 0: 954 | model.eval() 955 | 956 | train_score, val_score, test_score = ee.kf_embedding_evaluation(model.encoder, dataset) 957 | 958 | logging.info( 959 | "Metric: {} Train: {} Val: {} Test: {}".format(evaluator.eval_metric, train_score, val_score, test_score) 960 | ) 961 | 962 | train_curve.append(train_score) 963 | valid_curve.append(val_score) 964 | test_curve.append(test_score) 965 | accuracies["val"].append(val_score) 966 | accuracies["test"].append(test_score) 967 | if "classification" in dataset.task_type: 968 | best_val_epoch = np.argmax(np.array(valid_curve)) 969 | best_train = max(train_curve) 970 | else: 971 | best_val_epoch = np.argmin(np.array(valid_curve)) 972 | best_train = min(train_curve) 973 | 974 | logging.info("FinishedTraining!") 975 | logging.info("BestEpoch: {}".format(best_val_epoch)) 976 | logging.info("BestTrainScore: {}".format(best_train)) 977 | logging.info("BestValidationScore: {}".format(valid_curve[best_val_epoch])) 978 | logging.info("FinalTestScore: {}".format(test_curve[best_val_epoch])) 979 | return valid_curve[best_val_epoch], test_curve[best_val_epoch] 980 | 981 | 982 | if __name__ == "__main__": 983 | 984 | args = arg_parse() 985 | run(args) 986 | -------------------------------------------------------------------------------- /backbones/attention_mil.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import scipy.io as sio 6 | from 
sklearn.model_selection import KFold 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch_geometric.data import Data 12 | from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool 13 | from torch_geometric.loader import DataLoader 14 | 15 | import random 16 | import os 17 | import sys 18 | sys.path.append("..") 19 | import pooling as Pooling 20 | 21 | 22 | def arg_parse(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--DS", 26 | help="dataset to train on, like musk1 or fox", 27 | default="datasets/Biocreative/process", 28 | type=str, 29 | ) 30 | parser.add_argument( 31 | "--epochs", 32 | type=int, 33 | default=50, 34 | metavar="N", 35 | help="number of epochs to train (default: 20)", 36 | ) 37 | parser.add_argument( 38 | "--lr", 39 | type=float, 40 | default=0.0005, 41 | metavar="LR", 42 | help="learning rate (default: 0.0005)", 43 | ) 44 | parser.add_argument( 45 | "--weight_decay", 46 | type=float, 47 | default=0.005, 48 | help="weight_decay (default: 0.005)", 49 | ) 50 | parser.add_argument( 51 | "--batch_size", 52 | type=int, 53 | default=128, 54 | help="batch size (default: 128)", 55 | ) 56 | parser.add_argument( 57 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" 58 | ) 59 | parser.add_argument( 60 | "--no-cuda", action="store_true", default=False, help="disables CUDA training" 61 | ) 62 | # for uot-pooling 63 | parser.add_argument("--a1", type=float, default=None) 64 | parser.add_argument("--a2", type=float, default=None) 65 | parser.add_argument("--a3", type=float, default=None) 66 | parser.add_argument("--rho", type=float, default=None) 67 | parser.add_argument("--h", type=int, default=64) 68 | parser.add_argument("--k", type=int, default=4) 69 | parser.add_argument("--u1", type=str, default="fixed") 70 | parser.add_argument("--u2", type=str, default="fixed") 71 | parser.add_argument("--pooling_layer", help="score_pooling", default="uot_pooling", type=str) 72 | parser.add_argument("--f_method", type=str, default="badmm-e") 73 | return parser.parse_args() 74 | 75 | 76 | def load_mil_data_mat( 77 | dataset, n_folds, batch_size: int, normalize: bool = True, split: float = 0.75, seed: int = 1): 78 | data = sio.loadmat(dataset + ".mat") 79 | instances = data["data"] 80 | bags = [] 81 | labels = [] 82 | for i in range(len(instances)): 83 | bag = torch.from_numpy(instances[i][0]).float()[:, 0:-1] 84 | label = instances[i][1][0, 0] 85 | bags.append(bag) 86 | if label < 0: 87 | labels.append(0) 88 | else: 89 | labels.append(label) 90 | labels = torch.Tensor(labels).long() 91 | 92 | if normalize: 93 | all_instances = torch.cat(bags, dim=0) 94 | avg_instance = torch.mean(all_instances, dim=0, keepdim=True) 95 | std_instance = torch.std(all_instances, dim=0, keepdim=True) 96 | for i in range(len(bags)): 97 | bags[i] = (bags[i] - avg_instance) / (std_instance + 1e-6) 98 | bags = bags 99 | bags_fea = [] 100 | for i in range(len(bags)): 101 | bags_fea.append(Data(x=bags[i], y=labels[i])) 102 | 103 | kf = KFold(n_splits=n_folds, shuffle=True, random_state=None) 104 | dataloaders = [] 105 | for train_idx, test_idx in kf.split(bags_fea): 106 | dataloader = {} 107 | dataloader["train"] = DataLoader([bags_fea[ibag] for ibag in train_idx], batch_size=batch_size, shuffle=True) 108 | dataloader["test"] = DataLoader([bags_fea[ibag] for ibag in test_idx], batch_size=batch_size, shuffle=False) 109 | dataloaders.append(dataloader) 110 | return dataloaders 111 | 112 | 113 | def 
get_dim(dataset): 114 | data = sio.loadmat(dataset + ".mat") 115 | ins_fea = data["data"][0, 0] 116 | length = len(ins_fea[0]) - 1 117 | return length 118 | 119 | 120 | class Net(nn.Module): 121 | def __init__(self, dim, pooling_layer, a1, a2, a3, rho, u1, u2, k, h): 122 | super(Net, self).__init__() 123 | self.dim = dim 124 | self.pooling_layer = pooling_layer 125 | self.L = 64 126 | self.D = 64 127 | self.K = 1 128 | self.a1 = a1 129 | self.a2 = a2 130 | self.a3 = a3 131 | self.rho = rho 132 | self.u1 = u1 133 | self.u2 = u2 134 | self.k = k 135 | self.h = h 136 | self.rho = rho 137 | self.feature_extractor_part1 = nn.Sequential( 138 | nn.Linear(self.dim, 256), 139 | nn.ReLU(), 140 | nn.Dropout(), 141 | nn.Linear(256, 128), 142 | nn.ReLU(), 143 | nn.Dropout(), 144 | nn.Linear(128, 64), 145 | nn.ReLU(), 146 | nn.Dropout(), 147 | ) 148 | if self.pooling_layer == "mix_pooling": 149 | self.pooling = Pooling.MixedPooling() 150 | if self.pooling_layer == "gated_pooling": 151 | self.pooling = Pooling.GatedPooling(64) 152 | if self.pooling_layer == "set_set": 153 | self.pooling = Pooling.Set2Set(64, 2, 1) 154 | self.dense4 = torch.nn.Linear(64 * 2, 64) 155 | if self.pooling_layer == "attention_pooling": 156 | self.pooling = Pooling.AttentionPooling(64, 64) 157 | if self.pooling_layer == "gated_attention_pooling": 158 | self.pooling = Pooling.GatedAttentionPooling(64, 64) 159 | if self.pooling_layer == "dynamic_pooling": 160 | self.pooling = Pooling.DynamicPooling(64, 3) 161 | if self.pooling_layer == 'GeneralizedNormPooling': 162 | self.pooling = Pooling.GeneralizedNormPooling(d=64) 163 | if self.pooling_layer == "uot_pooling": 164 | self.pooling = Pooling.UOTPooling(dim=64, num=k, a1=a1, a2=a2, a3=a3, p0=u1, q0=u2, f_method=args.f_method) 165 | if self.pooling_layer == "rot_pooling": 166 | self.pooling = Pooling.ROTPooling(dim=64, num=k, a1=a1, a2=a2, a3=a3, p0=u1, q0=u2, f_method=args.f_method) 167 | if self.pooling_layer == "deepset": 168 | self.pooling = Pooling.DeepSet(64, 64) 169 | self.classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid()) 170 | 171 | def forward(self, x, batch): 172 | H = self.feature_extractor_part1(x) 173 | if self.pooling_layer == "add_pooling": 174 | M = global_add_pool(H, batch) 175 | elif self.pooling_layer == "mean_pooling": 176 | M = global_mean_pool(H, batch) 177 | elif self.pooling_layer == "max_pooling": 178 | M = global_max_pool(H, batch) 179 | # for set_set 180 | elif self.pooling_layer == "set_set": 181 | torch.backends.cudnn.enabled = False 182 | M = self.pooling(H, batch) 183 | M = self.dense4(M) 184 | else: 185 | M = self.pooling(H, batch) 186 | Y_prob = self.classifier(M) 187 | Y_hat = torch.ge(Y_prob, 0.5).float() 188 | 189 | return Y_prob, Y_hat 190 | 191 | def calculate_classification_error(self, X, batch, Y): 192 | Y = Y.float().unsqueeze(1) 193 | _, Y_hat = self.forward(X, batch) 194 | error = 1.0 - Y_hat.eq(Y).cpu().float() 195 | acc_num = Y_hat.eq(Y).cpu().float().sum().item() 196 | error = torch.sum(error) 197 | return error, Y_hat, acc_num 198 | 199 | def calculate_objective(self, X, batch, Y): 200 | Y = Y.float().unsqueeze(1) 201 | Y_prob, _ = self.forward(X, batch) 202 | Y_prob = torch.clamp(Y_prob, min=1e-5, max=1.0 - 1e-5) 203 | neg_log_likelihood = -1.0 * ( 204 | Y * torch.log(Y_prob) + (1.0 - Y) * torch.log(1.0 - Y_prob) 205 | ) # negative log bernoulli 206 | neg_log_likelihood = torch.sum(neg_log_likelihood) 207 | return neg_log_likelihood 208 | 209 | 210 | def train(model, optimizer, train_bags): 211 | model.train() 212 | train_loss 
= 0.0 213 | train_error = 0.0 214 | train_len = 0 215 | for idx, data_a in enumerate(train_bags): 216 | data = data_a.x 217 | bag_label = data_a.y 218 | train_len += len(bag_label) 219 | batch = data_a.batch 220 | if args.cuda: 221 | data, bag_label, batch = data.cuda(), bag_label.cuda(), batch.cuda() 222 | # reset gradients 223 | optimizer.zero_grad() 224 | # calculate loss and metrics 225 | loss= model.calculate_objective(data, batch, bag_label) 226 | train_loss += loss.item() 227 | error, _, acc_num = model.calculate_classification_error(data, batch, bag_label) 228 | train_error += error 229 | # backward pass 230 | loss.backward() 231 | # step 232 | optimizer.step() 233 | 234 | # calculate loss and error for epoch 235 | train_loss /= train_len 236 | train_error /= train_len 237 | return train_error, train_loss 238 | 239 | 240 | def acc_test(model, test_bags): 241 | model.eval() 242 | test_loss = 0.0 243 | test_error = 0.0 244 | num_bags = 0 245 | num_corrects = 0 246 | test_len = 0 247 | for batch_idx, data_a in enumerate(test_bags): 248 | data = data_a.x 249 | bag_label = data_a.y 250 | test_len += len(bag_label) 251 | batch = data_a.batch 252 | num_bags += bag_label.shape[0] 253 | if args.cuda: 254 | data, bag_label, batch = data.cuda(), bag_label.cuda(), batch.cuda() 255 | with torch.no_grad(): 256 | loss = model.calculate_objective(data, batch, bag_label) 257 | test_loss += loss.item() 258 | error, predicted_label, acc_num = model.calculate_classification_error(data, batch, bag_label) 259 | test_error += error 260 | num_corrects += acc_num 261 | 262 | test_error /= test_len 263 | test_loss /= test_len 264 | test_acc = num_corrects / num_bags 265 | return test_error, test_loss, test_acc 266 | 267 | 268 | def mil_Net(dataloader): 269 | dim = get_dim(args.DS) 270 | pooling_layer = args.pooling_layer 271 | a1 = args.a1 272 | a2 = args.a2 273 | a3 = args.a3 274 | rho = args.rho 275 | u1 = args.u1 276 | u2 = args.u2 277 | k = args.k 278 | h = args.h 279 | rho = rho 280 | print("Init Model") 281 | model = Net(dim, pooling_layer, a1, a2, a3, rho, u1, u2, k, h) 282 | if args.cuda: 283 | model.cuda() 284 | optimizer = optim.Adam( 285 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 286 | ) 287 | train_bags = dataloader["train"] 288 | test_bags = dataloader["test"] 289 | print(len(train_bags)) 290 | print(len(test_bags)) 291 | t1 = time.time() 292 | test_accs = [] 293 | test_loss_error = [] 294 | for epoch in range(1, args.epochs + 1): 295 | train_error, train_loss = train(model, optimizer, train_bags) 296 | test_error, test_loss, test_acc = acc_test(model, test_bags) 297 | test_accs.append(test_acc) 298 | test_loss_error.append(test_loss + test_error) 299 | print( 300 | "epoch=", 301 | epoch, 302 | " train_error= {:.3f}".format(train_error), 303 | " train_loss= {:.3f}".format(train_loss), 304 | " test_error= {:.3f}".format(test_error), 305 | " test_loss={:.3f}".format(test_loss), 306 | " test_acc= {:.3f}".format(test_acc), 307 | ) 308 | t2 = time.time() 309 | print("run time:", (t2 - t1) / 60.0, "min") 310 | index_epoch = np.argmin(test_loss_error) 311 | print("test_acc={:.3f}".format(test_accs[index_epoch])) 312 | return test_accs[index_epoch] 313 | 314 | 315 | if __name__ == "__main__": 316 | args = arg_parse() 317 | args.cuda = not args.no_cuda and torch.cuda.is_available() 318 | torch.manual_seed(args.seed) 319 | if args.cuda: 320 | torch.cuda.manual_seed(args.seed) 321 | print("\nGPU is ON!") 322 | print("Load Train and Test Set") 323 | loader_kwargs = {"num_workers": 1, 
"pin_memory": True} if args.cuda else {} 324 | log_dir = (args.DS).split("/")[-1] 325 | 326 | run = 1 327 | n_folds = 5 328 | seed = [0] 329 | acc = np.zeros((run, n_folds), dtype=float) 330 | for irun in range(run): 331 | os.environ['PYTHONHASHSEED'] = str(seed[irun]) 332 | torch.manual_seed(seed[irun]) 333 | torch.cuda.manual_seed(seed[irun]) 334 | torch.cuda.manual_seed_all(seed[irun]) 335 | torch.backends.cudnn.deterministic = True 336 | torch.backends.cudnn.benchmark = False 337 | torch.backends.cudnn.enabled = False 338 | np.random.seed(seed[irun]) 339 | random.seed(seed[irun]) 340 | 341 | dataloaders = load_mil_data_mat(dataset=args.DS, n_folds=n_folds, normalize=True, batch_size=args.batch_size) 342 | for ifold in range(n_folds): 343 | print("run=", irun, " fold=", ifold) 344 | acc[irun][ifold] = mil_Net(dataloaders[ifold]) 345 | print("mi-net mean accuracy = ", np.mean(acc)) 346 | print("std = ", np.std(acc)) -------------------------------------------------------------------------------- /backbones/ddi_gin.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import random 3 | import util 4 | import numpy as np 5 | from torch_geometric.data import DataLoader 6 | from torch.nn import Linear, Sequential, BatchNorm1d, ReLU 7 | from torch_geometric.nn import ( 8 | GCNConv, 9 | GINConv, 10 | global_add_pool, 11 | global_max_pool, 12 | global_mean_pool, 13 | SAGPooling, 14 | ASAPooling, 15 | ) 16 | from sklearn.metrics import roc_auc_score 17 | from dataprocess_fears import * 18 | import argparse 19 | import os 20 | import warnings 21 | import logging 22 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 23 | 24 | logging.getLogger().setLevel(logging.INFO) 25 | warnings.filterwarnings("ignore") 26 | import pooling as Pooling 27 | 28 | 29 | def arg_parse(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--batch_size", type=int, default=128, help="batch size") 32 | parser.add_argument("--seed", type=int, default=1, help="seed") 33 | parser.add_argument("--epoch", type=int, default=40, help="epoch") 34 | parser.add_argument("--dataset", type=str, default="fears", help="dataset") 35 | # for uot-pooling 36 | parser.add_argument("--a1", type=float, default=None) 37 | parser.add_argument("--a2", type=float, default=None) 38 | parser.add_argument("--a3", type=float, default=None) 39 | parser.add_argument("--rho", type=float, default=None) 40 | parser.add_argument("--same_para", type=bool, default=False) 41 | parser.add_argument("--h", type=int, default=32) 42 | parser.add_argument("--num", type=int, default=4) 43 | parser.add_argument("--p0", type=str, default="fixed") 44 | parser.add_argument("--q0", type=str, default="fixed") 45 | parser.add_argument("--f_method", type=str, default="sinkhorn") 46 | parser.add_argument("--eps", type=float, default=1e-18) 47 | parser.add_argument( 48 | "--pooling_layer", help="pooling_layer", default="uot_pooling", type=str 49 | ) 50 | return parser.parse_args() 51 | 52 | 53 | def setup_seed(seed): 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed_all(seed) 56 | torch.backends.cudnn.deterministic = True 57 | torch.backends.cudnn.benchmark = False 58 | torch.backends.cudnn.enabled = False 59 | np.random.seed(seed) 60 | random.seed(seed) 61 | os.environ["PYTHONHASHSEED"] = str(seed) 62 | # torch.use_deterministic_algorithms(True) 63 | 64 | class Net(torch.nn.Module): 65 | def __init__( 66 | self, 67 | pooling_layer, 68 | a1, 69 | a2, 70 | a3, 71 | rho, 72 | p0, 73 | q0, 74 | 
num, 75 | eps, 76 | same_para, 77 | f_method, 78 | ): 79 | super(Net, self).__init__() 80 | self.a1 = a1 81 | self.a2 = a2 82 | self.a3 = a3 83 | self.f_method = f_method 84 | self.rho = rho 85 | self.same_para = same_para 86 | self.p0 = p0 87 | self.q0 = q0 88 | self.num = num 89 | self.eps = eps 90 | self.rho = rho 91 | self.pooling_layer = pooling_layer 92 | dim_h = 64 93 | #gin 94 | self.conv1 = GINConv( 95 | Sequential(Linear(30, dim_h), 96 | BatchNorm1d(dim_h), ReLU(), 97 | Linear(dim_h, dim_h), ReLU())) 98 | self.conv2 = GINConv( 99 | Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), 100 | Linear(dim_h, dim_h), ReLU())) 101 | self.conv3 = GINConv( 102 | Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), 103 | Linear(dim_h, dim_h), ReLU())) 104 | 105 | self.lin1 = torch.nn.Linear(64, 128) 106 | self.lin2 = torch.nn.Linear(128, 1) 107 | 108 | # for pooling1 109 | feature_pooling = 64 110 | if self.pooling_layer == "mix_pooling": 111 | self.pooling = Pooling.MixedPooling() 112 | if self.pooling_layer == "gated_pooling": 113 | self.pooling = Pooling.GatedPooling(feature_pooling) 114 | if self.pooling_layer == "set_set": 115 | self.pooling = Pooling.Set2Set(feature_pooling, 2, 1) 116 | self.dense = torch.nn.Linear(feature_pooling * 2, 64) 117 | if self.pooling_layer == "attention_pooling": 118 | self.pooling = Pooling.AttentionPooling(feature_pooling, 32) 119 | if self.pooling_layer == "gated_attention_pooling": 120 | self.pooling = Pooling.GatedAttentionPooling(feature_pooling, 32) 121 | if self.pooling_layer == "dynamic_pooling": 122 | self.pooling = Pooling.DynamicPooling(feature_pooling, 3) 123 | if self.pooling_layer == "uot_pooling": 124 | self.pooling = Pooling.UOTPooling( 125 | dim=feature_pooling, 126 | num=num, 127 | rho=rho, 128 | same_para=same_para, 129 | p0=p0, 130 | q0=q0, 131 | eps=eps, 132 | a1=a1, 133 | a2=a2, 134 | a3=a3, 135 | f_method=f_method, 136 | ) 137 | if self.pooling_layer == "rot_pooling": 138 | self.pooling = Pooling.ROTPooling( 139 | dim=feature_pooling, 140 | num=num, 141 | rho=rho, 142 | same_para=same_para, 143 | p0=p0, 144 | q0=q0, 145 | eps=eps, 146 | a1=a1, 147 | a2=a2, 148 | a3=a3, 149 | f_method=f_method, 150 | ) 151 | if self.pooling_layer == "deepset": 152 | self.pooling = Pooling.DeepSet(feature_pooling, 32) 153 | if self.pooling_layer == "GeneralizedNormPooling": 154 | self.pooling = Pooling.GeneralizedNormPooling(feature_pooling) 155 | if self.pooling_layer == "SAGPooling": 156 | self.pooling = SAGPooling(feature_pooling) 157 | if self.pooling_layer == "ASAPooling": 158 | self.pooling = ASAPooling(feature_pooling) 159 | 160 | # for pooling2 161 | feature_pooling2 = 128 162 | if self.pooling_layer == "mix_pooling": 163 | self.pooling2 = Pooling.MixedPooling() 164 | if self.pooling_layer == "gated_pooling": 165 | self.pooling2 = Pooling.GatedPooling(feature_pooling2) 166 | if self.pooling_layer == "set_set": 167 | self.pooling2 = Pooling.Set2Set(feature_pooling2, 2, 1) 168 | self.dense2 = torch.nn.Linear(feature_pooling2 * 2, 128) 169 | if self.pooling_layer == "attention_pooling": 170 | self.pooling2 = Pooling.AttentionPooling(feature_pooling2, 32) 171 | if self.pooling_layer == "gated_attention_pooling": 172 | self.pooling2 = Pooling.GatedAttentionPooling(feature_pooling2, 32) 173 | if self.pooling_layer == "dynamic_pooling": 174 | self.pooling2 = Pooling.DynamicPooling(feature_pooling2, 3) 175 | if self.pooling_layer == "uot_pooling": 176 | self.pooling2 = Pooling.UOTPooling( 177 | dim=feature_pooling, 178 | num=num, 179 | 
rho=rho, 180 | same_para=same_para, 181 | p0=p0, 182 | q0=q0, 183 | eps=eps, 184 | a1=a1, 185 | a2=a2, 186 | a3=a3, 187 | f_method=f_method, 188 | ) 189 | if self.pooling_layer == "rot_pooling": 190 | self.pooling2 = Pooling.ROTPooling( 191 | dim=feature_pooling, 192 | num=num, 193 | rho=rho, 194 | same_para=same_para, 195 | p0=p0, 196 | q0=q0, 197 | eps=eps, 198 | a1=a1, 199 | a2=a2, 200 | a3=a3, 201 | f_method=args.f_method, 202 | ) 203 | if self.pooling_layer == "deepset": 204 | self.pooling2 = Pooling.DeepSet(feature_pooling2, 32) 205 | if self.pooling_layer == "GeneralizedNormPooling": 206 | self.pooling2 = Pooling.GeneralizedNormPooling(feature_pooling2) 207 | if self.pooling_layer == "SAGPooling": 208 | self.pooling2 = SAGPooling(feature_pooling2) 209 | if self.pooling_layer == "ASAPooling": 210 | self.pooling2 = ASAPooling(feature_pooling2) 211 | 212 | def forward(self, data, device): 213 | data = data.to(device) 214 | x, edge_index, pos_raw, batch, num = ( 215 | data.x, 216 | data.edge_index, 217 | data.pos, 218 | data.batch, 219 | data.num, 220 | ) 221 | nodes_orders, batch_orders = org_batch(num) 222 | nodes_orders = torch.tensor(nodes_orders).to(device) 223 | batch_orders = torch.tensor(batch_orders).to(device) 224 | x = self.conv1(x, edge_index) 225 | x = self.conv2(x, edge_index) 226 | x = self.conv3(x, edge_index) 227 | # pooling1 228 | if self.pooling_layer == "add_pooling": 229 | x = global_add_pool(x, nodes_orders) 230 | elif self.pooling_layer == "mean_pooling": 231 | x = global_mean_pool(x, nodes_orders) 232 | elif self.pooling_layer == "max_pooling": 233 | x = global_max_pool(x, nodes_orders) 234 | # for set_set 235 | elif self.pooling_layer == "set_set": 236 | torch.backends.cudnn.enabled = False 237 | x = self.pooling(x, nodes_orders) 238 | x = self.dense(x) 239 | elif self.pooling_layer == 'SAGPooling': 240 | x, _, _, batch_graph, _, _ = self.pooling(x, edge_index, batch=nodes_orders) 241 | x = global_add_pool(x, batch_graph) 242 | elif self.pooling_layer == 'ASAPooling': 243 | x, _, _, batch_graph, _ = self.pooling(x, edge_index, batch=nodes_orders) 244 | x = global_add_pool(x, batch_graph) 245 | # 246 | else: 247 | x = self.pooling(x, nodes_orders) 248 | 249 | # pooling2 250 | x_len = 128 251 | x = self.lin1(x) 252 | x = F.relu(x) 253 | if self.pooling_layer == "add_pooling": 254 | x = global_add_pool(x, batch_orders) 255 | elif self.pooling_layer == "mean_pooling": 256 | x = global_mean_pool(x, batch_orders) 257 | elif self.pooling_layer == "max_pooling": 258 | x = global_max_pool(x, batch_orders) 259 | # for set_set 260 | elif self.pooling_layer == "set_set": 261 | torch.backends.cudnn.enabled = False 262 | x = self.pooling2(x, batch_orders) 263 | x = self.dense2(x) 264 | elif self.pooling_layer == 'SAGPooling': 265 | # pooling2 266 | x = global_add_pool(x, batch_orders) 267 | elif self.pooling_layer == 'ASAPooling': 268 | x = global_add_pool(x, batch_orders) 269 | # 270 | else: 271 | x = self.pooling2(x, batch_orders) 272 | x = torch.sigmoid(self.lin2(x)).squeeze(1) 273 | return x 274 | 275 | 276 | def org_batch(num): 277 | """ 278 | # generate nodes_orders and 279 | :param num: [[2,1][3][4,2,3]],[2,1] means a combination which has two graphs, the nodes of those graphs are 2 and 1. 280 | :return: 281 | nodes_orders: [0,0,1,2,2,2,3,3,3,3,4,4,5,5,5] responses to the nodes in the graphs of drug combinations. 282 | batch_orders: [0,0,1,2,2,2,3,3] responses to the graphs in drug combinations. 
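A small illustrative case following the same grouping logic: num = [[2, 1], [3]] gives nodes_orders = [0, 0, 1, 2, 2, 2] (one graph id per node) and batch_orders = [0, 0, 1] (one combination id per graph).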
283 | """ 284 | 285 | nodes_order = [] 286 | batch_order = [] 287 | for i in range(len(num)): 288 | batch_order.append(len(num[i])) 289 | for j in range(len(num[i])): 290 | nodes_order.append(num[i][j]) 291 | nodes_orders = [] 292 | batch_orders = [] 293 | num = 0 294 | for i in range(len(batch_order)): 295 | for j in range(batch_order[i]): 296 | batch_orders.append(num) 297 | num += 1 298 | num_2 = 0 299 | for i in range(len(nodes_order)): 300 | for j in range(nodes_order[i]): 301 | nodes_orders.append(num_2) 302 | num_2 += 1 303 | return nodes_orders, batch_orders 304 | 305 | 306 | def train(model, optimizer, crit, train_loader, train_dataset, device): 307 | model.train() 308 | loss_all = 0 309 | for data in train_loader: 310 | optimizer.zero_grad() 311 | data_model = data 312 | output = model(data_model, device).squeeze() 313 | label = data.y.float().squeeze().to(device) 314 | loss = crit(output, label) 315 | loss.backward() 316 | loss_all += data_model.num_graphs * loss.item() 317 | optimizer.step() 318 | return loss_all / len(train_dataset) 319 | 320 | 321 | def evaluate(model, crit, loader, dataset, device): 322 | model.eval() 323 | predictions = [] 324 | labels = [] 325 | loss_all = 0 326 | with torch.no_grad(): 327 | for data in loader: 328 | data_model = data 329 | pred = model(data_model, device).squeeze() 330 | label = data.y.float().squeeze().to(device) 331 | loss = crit(pred, label) 332 | pred = pred.detach().cpu().numpy() 333 | label = label.detach().cpu().numpy() 334 | # print(pred) 335 | predictions.append(pred) 336 | labels.append(label) 337 | loss_all += data_model.num_graphs * loss.item() 338 | predictions = np.hstack(predictions) 339 | labels = np.hstack(labels) 340 | val_loss = loss_all / len(dataset) 341 | predictions_tensor = torch.from_numpy(predictions) 342 | predictions_tensor = (predictions_tensor > 0.5).float() 343 | labels_tensor = torch.from_numpy(labels) 344 | acc, pre, rec, F1, auc = util.evaluation(predictions_tensor, labels_tensor) 345 | return val_loss, roc_auc_score(labels, predictions), acc, pre, rec, F1, auc 346 | 347 | 348 | def run(args,seed): 349 | print("begin") 350 | setup_seed(seed) 351 | dataset = FearsGraphDataset(root="fears", name="fears") 352 | dataset = dataset.data 353 | random.shuffle(dataset) 354 | split_len = len(dataset) 355 | train_set_len = int(split_len * 0.6) 356 | valid_set_len = int(split_len * 0.2) 357 | train_dataset = dataset[:train_set_len] 358 | val_dataset = dataset[train_set_len : valid_set_len + train_set_len] 359 | test_dataset = dataset[valid_set_len + train_set_len :] 360 | print(len(dataset)) 361 | print(len(train_dataset)) 362 | print(len(val_dataset)) 363 | print(len(test_dataset)) 364 | batch_size = args.batch_size 365 | train_loader = DataLoader(train_dataset, batch_size=batch_size) 366 | val_loader = DataLoader(val_dataset, batch_size=batch_size) 367 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 368 | print("Data processing completed") 369 | 370 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 371 | 372 | model = Net( 373 | pooling_layer=args.pooling_layer, 374 | a1=args.a1, 375 | a2=args.a2, 376 | a3=args.a3, 377 | rho=args.rho, 378 | p0=args.p0, 379 | q0=args.q0, 380 | num=args.num, 381 | eps=args.eps, 382 | same_para=args.same_para, 383 | f_method=args.f_method, 384 | ).to(device) 385 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 386 | crit = torch.nn.BCELoss() 387 | val_acc_all = [] 388 | test_metric_all = [] 389 | 390 | for epoch in range(args.epoch): 
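train_loss = train(model, optimizer, crit, train_loader, train_dataset, device)  # presumably intended here: fit for one epoch before the per-epoch evaluation below (the train() helper defined above is otherwise never called)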
391 | val_val_loss, val_acc, val_acc2, val_pre, val_rec, val_F1, val_auc = evaluate( 392 | model, crit, val_loader, val_dataset, device 393 | ) 394 | (test_val_loss, 395 | test_acc, 396 | test_acc2, 397 | test_pre, 398 | test_rec, 399 | test_F1, 400 | test_auc, 401 | ) = evaluate(model, crit, test_loader, test_dataset, device) 402 | val_acc_all.append(val_acc2) 403 | test_metric_all.append( 404 | [test_acc, test_acc2, test_pre, test_rec, test_F1, test_auc] 405 | ) 406 | 407 | best_val_epoch = np.argmax(np.array(val_acc_all)) 408 | best_test_result = test_metric_all[best_val_epoch] 409 | return best_test_result 410 | 411 | 412 | if __name__ == "__main__": 413 | 414 | args = arg_parse() 415 | results=[] 416 | for i in [0,1,2,3,4]: 417 | results.append(run(args,i)) 418 | result = np.mean(np.array(results), axis=0) 419 | std = np.std(np.array(results), axis=0) 420 | print(result) 421 | print(std) 422 | -------------------------------------------------------------------------------- /backbones/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import shutil 5 | import random 6 | import warnings 7 | import argparse 8 | import numpy as np 9 | import torch.optim 10 | import torch.nn as nn 11 | import torch.utils.data 12 | import torch.nn.parallel 13 | import pooling as Pooling 14 | import torch.distributed as dist 15 | import torch.multiprocessing as mp 16 | import torchvision.models as models 17 | import torch.utils.data.distributed 18 | import torch.backends.cudnn as cudnn 19 | import torchvision.datasets as datasets 20 | import torchvision.transforms as transforms 21 | 22 | from enum import Enum 23 | from torch.optim.lr_scheduler import StepLR 24 | 25 | model_names = sorted(name for name in models.__dict__ 26 | if name.islower() and not name.startswith("__") 27 | and callable(models.__dict__[name])) 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 30 | parser.add_argument('-data', metavar='DIR', default='imagenet_last/imagenet', help='path to dataset (default: imagenet)') 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture: ' + 32 | ' | '.join(model_names) + ' (default: resnet18)') 33 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run') 35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256), this is the total ' 37 | 'batch size of all GPUs on the current node when ' 'using Data Parallel or Distributed Data Parallel') 38 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') 39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 40 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', 41 | dest='weight_decay') 42 | parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)') 43 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: 
none)') 44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') 45 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') 46 | parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training') 47 | parser.add_argument('--rank', default=0, type=int, help='node rank for distributed training') 48 | parser.add_argument('--dist-url', default='tcp:// ', type=str, help='url used to set up distributed training') 49 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 50 | parser.add_argument('--seed', default=1, type=int, help='seed for initializing training. ') 51 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 52 | parser.add_argument('--multiprocessing-distributed', action='store_true', help='Use multi-processing distributed training to launch ' 53 | 'N processes per node, which has N GPUs. This is the ' 54 | 'fastest way to use PyTorch for either single node or ' 55 | 'multi node data parallel training') 56 | parser.add_argument('--flag', default='k2_0922', type=str, help='flag') 57 | parser.add_argument('--k', default='2', type=int, help='flag') 58 | parser.add_argument('--f_method', default='badmm-e', type=str, help='badmm-e') 59 | 60 | best_acc1 = 0 61 | 62 | 63 | def main(): 64 | args = parser.parse_args() 65 | 66 | if args.seed is not None: 67 | random.seed(args.seed) 68 | torch.manual_seed(args.seed) 69 | cudnn.deterministic = True 70 | warnings.warn('You have chosen to seed training. ' 71 | 'This will turn on the CUDNN deterministic setting, ' 72 | 'which can slow down your training considerably! ' 73 | 'You may see unexpected behavior when restarting ' 74 | 'from checkpoints.') 75 | 76 | if args.gpu is not None: 77 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 78 | 'disable data parallelism.') 79 | 80 | if args.dist_url == "env://" and args.world_size == -1: 81 | args.world_size = int(os.environ["WORLD_SIZE"]) 82 | 83 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 84 | 85 | ngpus_per_node = torch.cuda.device_count() 86 | if args.multiprocessing_distributed: 87 | # Since we have ngpus_per_node processes per node, the total world_size 88 | # needs to be adjusted accordingly 89 | args.world_size = ngpus_per_node * args.world_size 90 | # Use torch.multiprocessing.spawn to launch distributed processes: the 91 | # main_worker process function 92 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 93 | else: 94 | # Simply call main_worker function 95 | main_worker(args.gpu, ngpus_per_node, args) 96 | 97 | 98 | def main_worker(gpu, ngpus_per_node, args): 99 | global best_acc1 100 | args.gpu = gpu 101 | 102 | if args.gpu is not None: 103 | print("Use GPU: {} for training".format(args.gpu)) 104 | 105 | if args.distributed: 106 | if args.dist_url == "env://" and args.rank == -1: 107 | args.rank = int(os.environ["RANK"]) 108 | if args.multiprocessing_distributed: 109 | # For multiprocessing distributed training, rank needs to be the 110 | # global rank among all the processes 111 | args.rank = args.rank * ngpus_per_node + gpu 112 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 113 | world_size=args.world_size, rank=args.rank) 114 | # create model 115 | args.pretrained = False 116 | if args.pretrained: 117 | print("=> using pre-trained model '{}'".format(args.arch)) 118 | model = models.__dict__[args.arch](pretrained=True) 119 | else: 120 | print("model") 121 | print("=> creating model '{}'".format(args.arch)) 122 | model = models.__dict__[args.arch]() 123 | dim = 512 124 | model = replace_pooling(model, k=args.k, dim=dim, f_method=args.f_method) 125 | 126 | if not torch.cuda.is_available(): 127 | print('using CPU, this will be slow') 128 | elif args.distributed: 129 | # For multiprocessing distributed, DistributedDataParallel constructor 130 | # should always set the single device scope, otherwise, 131 | # DistributedDataParallel will use all available devices. 132 | if args.gpu is not None: 133 | torch.cuda.set_device(args.gpu) 134 | model.cuda(args.gpu) 135 | # When using a single GPU per process and per 136 | # DistributedDataParallel, we need to divide the batch size 137 | # ourselves based on the total number of GPUs of the current node. 
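# e.g., with the defaults --batch-size 256 and --workers 4 on a node with 8 GPUs, each process ends up with batch_size = 32 and workers = 1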
138 | args.batch_size = int(args.batch_size / ngpus_per_node) 139 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 140 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 141 | else: 142 | model.cuda() 143 | # DistributedDataParallel will divide and allocate batch_size to all 144 | # available GPUs if device_ids are not set 145 | model = torch.nn.parallel.DistributedDataParallel(model) 146 | elif args.gpu is not None: 147 | torch.cuda.set_device(args.gpu) 148 | model = model.cuda(args.gpu) 149 | else: 150 | # DataParallel will divide and allocate batch_size to all available GPUs 151 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 152 | model.features = torch.nn.DataParallel(model.features) 153 | model.cuda() 154 | else: 155 | model = torch.nn.DataParallel(model).cuda() 156 | 157 | # define loss function (criterion), optimizer, and learning rate scheduler 158 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 159 | 160 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 161 | momentum=args.momentum, 162 | weight_decay=args.weight_decay) 163 | 164 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 165 | scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 166 | 167 | # optionally resume from a checkpoint 168 | if args.resume: 169 | if os.path.isfile(args.resume): 170 | print("=> loading checkpoint '{}'".format(args.resume)) 171 | if args.gpu is None: 172 | checkpoint = torch.load(args.resume) 173 | else: 174 | # Map model to be loaded to specified single gpu. 175 | loc = 'cuda:{}'.format(args.gpu) 176 | checkpoint = torch.load(args.resume, map_location=loc) 177 | args.start_epoch = checkpoint['epoch'] 178 | best_acc1 = checkpoint['best_acc1'] 179 | if args.gpu is not None: 180 | # best_acc1 may be from a checkpoint from a different GPU 181 | best_acc1 = best_acc1.to(args.gpu) 182 | model.load_state_dict(checkpoint['state_dict']) 183 | optimizer.load_state_dict(checkpoint['optimizer']) 184 | scheduler.load_state_dict(checkpoint['scheduler']) 185 | print("=> loaded checkpoint '{}' (epoch {})" 186 | .format(args.resume, checkpoint['epoch'])) 187 | else: 188 | print("=> no checkpoint found at '{}'".format(args.resume)) 189 | 190 | cudnn.benchmark = True 191 | 192 | # Data loading code 193 | traindir = os.path.join(args.data, 'train') 194 | valdir = os.path.join(args.data, 'val') 195 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 196 | std=[0.229, 0.224, 0.225]) 197 | 198 | train_dataset = datasets.ImageFolder( 199 | traindir, 200 | transforms.Compose([ 201 | transforms.RandomResizedCrop(224), 202 | transforms.RandomHorizontalFlip(), 203 | transforms.ToTensor(), 204 | normalize, 205 | ])) 206 | if args.distributed: 207 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 208 | else: 209 | train_sampler = None 210 | 211 | train_loader = torch.utils.data.DataLoader( 212 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 213 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 214 | 215 | val_loader = torch.utils.data.DataLoader( 216 | datasets.ImageFolder(valdir, transforms.Compose([ 217 | transforms.Resize(256), 218 | transforms.CenterCrop(224), 219 | transforms.ToTensor(), 220 | normalize, 221 | ])), 222 | batch_size=args.batch_size, shuffle=False, 223 | num_workers=args.workers, pin_memory=True) 224 | if args.evaluate: 225 | validate(val_loader, model, criterion, args) 226 | return 227 | 228 | for 
epoch in range(args.start_epoch, args.epochs): 229 | if args.distributed: 230 | train_sampler.set_epoch(epoch) 231 | 232 | # train for one epoch 233 | train(train_loader, model, criterion, optimizer, epoch, args) 234 | 235 | # evaluate on validation set 236 | acc1 = validate(val_loader, model, criterion, args) 237 | 238 | scheduler.step() 239 | 240 | # remember best acc@1 and save checkpoint 241 | is_best = acc1 > best_acc1 242 | best_acc1 = max(acc1, best_acc1) 243 | 244 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 245 | and args.rank % ngpus_per_node == 0): 246 | save_checkpoint({ 247 | 'epoch': epoch + 1, 248 | 'arch': args.arch, 249 | 'state_dict': model.state_dict(), 250 | 'best_acc1': best_acc1, 251 | 'optimizer': optimizer.state_dict(), 252 | 'scheduler': scheduler.state_dict() 253 | }, is_best, args.flag) 254 | 255 | 256 | def train(train_loader, model, criterion, optimizer, epoch, args): 257 | batch_time = AverageMeter('Time', ':6.3f') 258 | data_time = AverageMeter('Data', ':6.3f') 259 | losses = AverageMeter('Loss', ':.4e') 260 | top1 = AverageMeter('Acc@1', ':6.2f') 261 | top5 = AverageMeter('Acc@5', ':6.2f') 262 | progress = ProgressMeter( 263 | len(train_loader), 264 | [batch_time, data_time, losses, top1, top5], 265 | prefix="Epoch: [{}]".format(epoch)) 266 | 267 | # switch to train mode 268 | model.train() 269 | 270 | end = time.time() 271 | for i, (images, target) in enumerate(train_loader): 272 | # measure data loading time 273 | data_time.update(time.time() - end) 274 | 275 | if args.gpu is not None: 276 | images = images.cuda(args.gpu, non_blocking=True) 277 | if torch.cuda.is_available(): 278 | target = target.cuda(args.gpu, non_blocking=True) 279 | 280 | # compute output 281 | output = model(images) 282 | loss = criterion(output, target) 283 | 284 | # measure accuracy and record loss 285 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 286 | losses.update(loss.item(), images.size(0)) 287 | top1.update(acc1[0], images.size(0)) 288 | top5.update(acc5[0], images.size(0)) 289 | 290 | # compute gradient and do SGD step 291 | optimizer.zero_grad() 292 | loss.backward() 293 | optimizer.step() 294 | 295 | # measure elapsed time 296 | batch_time.update(time.time() - end) 297 | end = time.time() 298 | 299 | if i % args.print_freq == 0: 300 | progress.display(i) 301 | 302 | 303 | def validate(val_loader, model, criterion, args): 304 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 305 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 306 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 307 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 308 | progress = ProgressMeter( 309 | len(val_loader), 310 | [batch_time, losses, top1, top5], 311 | prefix='Test: ') 312 | 313 | # switch to evaluate mode 314 | model.eval() 315 | 316 | with torch.no_grad(): 317 | end = time.time() 318 | for i, (images, target) in enumerate(val_loader): 319 | if args.gpu is not None: 320 | images = images.cuda(args.gpu, non_blocking=True) 321 | if torch.cuda.is_available(): 322 | target = target.cuda(args.gpu, non_blocking=True) 323 | 324 | # compute output 325 | output = model(images) 326 | loss = criterion(output, target) 327 | 328 | # measure accuracy and record loss 329 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 330 | losses.update(loss.item(), images.size(0)) 331 | top1.update(acc1[0], images.size(0)) 332 | top5.update(acc5[0], images.size(0)) 333 | 334 | # measure elapsed time 335 | batch_time.update(time.time() - end) 336 | end 
= time.time() 337 | 338 | if i % args.print_freq == 0: 339 | progress.display(i) 340 | 341 | progress.display_summary() 342 | 343 | return top1.avg 344 | 345 | 346 | def save_checkpoint(state, is_best, flag, filename='checkpoint.pth.tar'): 347 | filename = flag+"-"+filename 348 | torch.save(state, filename) 349 | if is_best: 350 | shutil.copyfile(filename, flag+"-"+'model_best.pth.tar') 351 | 352 | 353 | class Summary(Enum): 354 | NONE = 0 355 | AVERAGE = 1 356 | SUM = 2 357 | COUNT = 3 358 | 359 | 360 | class AverageMeter(object): 361 | """Computes and stores the average and current value""" 362 | 363 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 364 | self.name = name 365 | self.fmt = fmt 366 | self.summary_type = summary_type 367 | self.reset() 368 | 369 | def reset(self): 370 | self.val = 0 371 | self.avg = 0 372 | self.sum = 0 373 | self.count = 0 374 | 375 | def update(self, val, n=1): 376 | self.val = val 377 | self.sum += val * n 378 | self.count += n 379 | self.avg = self.sum / self.count 380 | 381 | def __str__(self): 382 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 383 | return fmtstr.format(**self.__dict__) 384 | 385 | def summary(self): 386 | fmtstr = '' 387 | if self.summary_type is Summary.NONE: 388 | fmtstr = '' 389 | elif self.summary_type is Summary.AVERAGE: 390 | fmtstr = '{name} {avg:.3f}' 391 | elif self.summary_type is Summary.SUM: 392 | fmtstr = '{name} {sum:.3f}' 393 | elif self.summary_type is Summary.COUNT: 394 | fmtstr = '{name} {count:.3f}' 395 | else: 396 | raise ValueError('invalid summary type %r' % self.summary_type) 397 | 398 | return fmtstr.format(**self.__dict__) 399 | 400 | 401 | class ProgressMeter(object): 402 | def __init__(self, num_batches, meters, prefix=""): 403 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 404 | self.meters = meters 405 | self.prefix = prefix 406 | 407 | def display(self, batch): 408 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 409 | entries += [str(meter) for meter in self.meters] 410 | print('\t'.join(entries)) 411 | 412 | def display_summary(self): 413 | entries = [" *"] 414 | entries += [meter.summary() for meter in self.meters] 415 | print(' '.join(entries)) 416 | 417 | def _get_batch_fmtstr(self, num_batches): 418 | num_digits = len(str(num_batches // 1)) 419 | fmt = '{:' + str(num_digits) + 'd}' 420 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 421 | 422 | 423 | def accuracy(output, target, topk=(1,)): 424 | """Computes the accuracy over the k top predictions for the specified values of k""" 425 | with torch.no_grad(): 426 | maxk = max(topk) 427 | batch_size = target.size(0) 428 | 429 | _, pred = output.topk(maxk, 1, True, True) 430 | pred = pred.t() 431 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 432 | 433 | res = [] 434 | for k in topk: 435 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 436 | res.append(correct_k.mul_(100.0 / batch_size)) 437 | return res 438 | 439 | 440 | class UOTPoolingImg(nn.Module): 441 | def __init__(self, d: int = 512, k: int = 4, rho: float = None, 442 | a1: float = None, a2: float = None, a3: float = None, f_method: str='sinkhorn'): 443 | super(UOTPoolingImg, self).__init__() 444 | self.pooling2d = Pooling.UOTPooling(dim=d, num=k, rho=rho, a1=a1, a2=a2, a3=a3, f_method=f_method) 445 | 446 | def forward(self, x: torch.Tensor): 447 | """ 448 | :param x: (B, C, W, H) 449 | :return: 450 | y: (B, C, 1, 1) 451 | """ 452 | b, c, w, h = x.shape 453 | #print(b, c, w, h) 454 | x = 
torch.reshape(x, (b, c, w * h)) # b x c x (wh) 455 | x = x.permute(0, 2, 1) # b x (wh) x c 456 | x = torch.reshape(x, (b * w * h, c)) # (bwh) * c 457 | batch = torch.LongTensor(np.zeros((b * w * h, ))).to(x.device) 458 | for i in range(b): 459 | batch[i*(w*h):(i+1)*(w*h)] = i 460 | tmp = self.pooling2d(x, batch) 461 | #print('Done') 462 | return tmp.view(b, c, 1, 1) 463 | 464 | def replace_pooling(net, k, dim: int,f_method): 465 | for name, layer in net.named_modules(): 466 | if name == 'avgpool': 467 | net._modules[name] = UOTPoolingImg(k=k, d=dim, f_method=f_method) 468 | return net 469 | 470 | if __name__ == '__main__': 471 | main() 472 | -------------------------------------------------------------------------------- /pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from .baselines import * 2 | from .rotlayers import * 3 | from .rotpooling import * 4 | from .utils import * -------------------------------------------------------------------------------- /pooling/baselines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool 4 | from torch_geometric.utils import softmax 5 | from torch_scatter import scatter_add 6 | 7 | """ 8 | AveragePooling (UOT, with some alpha's) 9 | MaxPooling (UOT, with some alpha's) 10 | GeneralizedNormPooling 11 | DeepSet (UOT, with mixed add- and max-pooling passing through learnable parameters) 12 | MixedPooling (UOT-barycenter, with nonparametric weights) 13 | GatedPooling (UOT-barycenter, with parametric weights) 14 | Set2Set 15 | AttentionPooling (UOT, with one side constraint + parametric transport) 16 | GatedAttentionPooling (UOT, with one side constraint + parametric transport) 17 | DynamicPooling (UOT, with one side constraint + parametric transport) 18 | """ 19 | 20 | 21 | class GeneralizedNormPooling(nn.Module): 22 | def __init__(self, d: int): 23 | super(GeneralizedNormPooling, self).__init__() 24 | self.ps = nn.Parameter(torch.randn(2)) 25 | self.qs = nn.Parameter(torch.randn(2)) 26 | self.softplus = nn.Softplus() 27 | self.thres = nn.Threshold(threshold=-50, value=50, inplace=False) 28 | self.tanh = nn.Tanh() 29 | self.linear = nn.Linear(in_features=d, out_features=d, bias=True) 30 | self.epsilon = 1e-6 31 | 32 | def forward(self, x, batch): 33 | """ 34 | :param x: (N, D) matrix 35 | :param batch: (N,) each element in {0, B-1} 36 | :return: 37 | z: (B, d) matrix 38 | """ 39 | nums = torch.ones_like(x[:, 0]) 40 | nums = global_add_pool(nums, batch=batch).unsqueeze(1) # (B, 1) 41 | 42 | ps = -self.thres(-self.softplus(self.ps)) 43 | qs = self.tanh(self.qs) 44 | 45 | x = torch.abs(x) + self.epsilon 46 | d = x.shape[0] 47 | d1 = int(d/2) 48 | 49 | x_pos = torch.exp(ps[0] * torch.log(x[:, :d1])) 50 | gnp_pos = torch.exp(torch.log(global_add_pool(x_pos, batch=batch)) / ps[0]) / (nums ** qs[0]) 51 | 52 | x_neg = x[:, d1:] 53 | gnp_neg = (global_add_pool(x_neg ** ps[1], batch=batch) ** (1 / ps[1])) / (nums ** qs[0]) 54 | gnp = self.linear(torch.cat((gnp_pos, gnp_neg), dim=1)) 55 | return gnp 56 | 57 | 58 | class DeepSet(nn.Module): 59 | """ 60 | The mixed permutation-invariant structure 61 | in Zaheer, Manzil, et al. "Deep Sets." 62 | Proceedings of the 31st International Conference on Neural Information Processing Systems. 2017. 
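Given a bag of features {x_i}, the layer computes rho(alpha * sum-pool(x) + (1 - alpha) * max-pool(x)), where alpha = sigmoid(a) for a learnable scalar a and rho is the four-layer MLP regressor defined in __init__ below.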
63 | """ 64 | 65 | def __init__(self, d: int, h: int): 66 | """ 67 | :param d: the dimension of input samples 68 | :param h: the dimension of hidden representations 69 | """ 70 | super(DeepSet, self).__init__() 71 | self.alpha = nn.Parameter(torch.randn(1)) 72 | self.sigmoid = nn.Sigmoid() 73 | self.regressor = nn.Sequential( 74 | nn.Linear(d, h), 75 | nn.ELU(inplace=True), 76 | nn.Linear(h, h), 77 | nn.ELU(inplace=True), 78 | nn.Linear(h, h), 79 | nn.ELU(inplace=True), 80 | nn.Linear(h, d), 81 | ) 82 | 83 | def forward(self, x, batch): 84 | """ 85 | :param x: (N, D) matrix 86 | :param batch: (N,) each element in {0, B-1} 87 | :return: 88 | """ 89 | alpha = self.sigmoid(self.alpha) 90 | return self.regressor(alpha * global_add_pool(x, batch) + (1 - alpha) * global_max_pool(x, batch)) 91 | 92 | 93 | def global_softmax_pooling(x: torch.Tensor, alpha: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: 94 | """ 95 | Apply softmax to a batch of bags 96 | :param x: (n, d) matrix 97 | :param alpha: (n, 1) matrix 98 | :param batch: (n,) vector, each element in {0, B-1} 99 | :return: 100 | z: (B, d) matrix 101 | """ 102 | alpha = softmax(alpha, batch) # (n, 1) 103 | return global_add_pool(x * alpha, batch) 104 | 105 | 106 | class MixedPooling(nn.Module): 107 | """ 108 | The mixed pooling structure 109 | in Lee, Chen-Yu, Patrick W. Gallagher, and Zhuowen Tu. 110 | "Generalizing pooling functions in convolutional neural networks: Mixed, gated, and tree." 111 | Artificial intelligence and statistics. PMLR, 2016. 112 | """ 113 | 114 | def __init__(self): 115 | super(MixedPooling, self).__init__() 116 | self.alpha = nn.Parameter(torch.randn(1)) 117 | self.sigmoid = nn.Sigmoid() 118 | 119 | def forward(self, x, batch): 120 | """ 121 | :param x: (N, D) matrix 122 | :param batch: (N,) each element in {0, B-1} 123 | :return: 124 | """ 125 | alpha = self.sigmoid(self.alpha) 126 | return alpha * global_mean_pool(x, batch) + (1 - alpha) * global_max_pool(x, batch) 127 | 128 | 129 | class GatedPooling(nn.Module): 130 | """ 131 | The gated pooling structure 132 | in Lee, Chen-Yu, Patrick W. Gallagher, and Zhuowen Tu. 133 | "Generalizing pooling functions in convolutional neural networks: Mixed, gated, and tree." 134 | Artificial intelligence and statistics. PMLR, 2016. 135 | """ 136 | 137 | def __init__(self, dim: int): 138 | super(GatedPooling, self).__init__() 139 | self.linear = nn.Linear(in_features=dim, out_features=1, bias=False) 140 | self.sigmoid = nn.Sigmoid() 141 | 142 | def forward(self, x, batch): 143 | """ 144 | :param x: (N, D) matrix 145 | :param batch: (N,) each element in {0, B-1} 146 | :return: 147 | """ 148 | alpha = self.linear(x) # (N, 1) 149 | alpha = self.sigmoid(global_add_pool(alpha, batch)) 150 | return alpha * global_mean_pool(x, batch) + (1 - alpha) * global_max_pool(x, batch) 151 | 152 | 153 | class Set2Set(nn.Module): 154 | r"""The global pooling operator based on iterative content-based attention 155 | from the `"Order Matters: Sequence to sequence for sets" 156 | `_ paper 157 | .. math:: 158 | \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) 159 | \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) 160 | \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i 161 | \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, 162 | where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice 163 | the dimensionality as the input. 164 | Args: 165 | in_channels (int): Size of each input sample. 166 | processing_steps (int): Number of iterations :math:`T`. 
167 | num_layers (int, optional): Number of recurrent layers, *.e.g*, setting 168 | :obj:`num_layers=2` would mean stacking two LSTMs together to form 169 | a stacked LSTM, with the second LSTM taking in outputs of the first 170 | LSTM and computing the final results. (default: :obj:`1`) 171 | """ 172 | 173 | def __init__(self, in_channels, processing_steps, num_layers=1): 174 | super(Set2Set, self).__init__() 175 | 176 | self.in_channels = in_channels 177 | self.out_channels = 2 * in_channels 178 | self.processing_steps = processing_steps 179 | self.num_layers = num_layers 180 | 181 | self.lstm = torch.nn.LSTM(self.out_channels, self.in_channels, 182 | num_layers) 183 | 184 | self.reset_parameters() 185 | 186 | def reset_parameters(self): 187 | self.lstm.reset_parameters() 188 | 189 | def forward(self, x, batch): 190 | """""" 191 | batch_size = batch.max().item() + 1 192 | 193 | h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)), 194 | x.new_zeros((self.num_layers, batch_size, self.in_channels))) 195 | q_star = x.new_zeros(batch_size, self.out_channels) 196 | 197 | for i in range(self.processing_steps): 198 | q, h = self.lstm(q_star.unsqueeze(0), h) 199 | q = q.view(batch_size, self.in_channels) 200 | e = (x * q[batch]).sum(dim=-1, keepdim=True) 201 | a = softmax(e, batch, num_nodes=batch_size) 202 | r = scatter_add(a * x, batch, dim=0, dim_size=batch_size) 203 | q_star = torch.cat([q, r], dim=-1) 204 | 205 | return q_star 206 | 207 | def __repr__(self): 208 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 209 | self.out_channels) 210 | 211 | 212 | class AttentionPooling(nn.Module): 213 | def __init__(self, d: int, h: int): 214 | super(AttentionPooling, self).__init__() 215 | self.d = d 216 | self.h = h 217 | self.linear_r1 = nn.Linear(in_features=self.d, out_features=self.h, bias=False) 218 | self.linear_r2 = nn.Linear(in_features=self.h, out_features=1, bias=False) 219 | self.tanh = nn.Tanh() 220 | self.softmax1 = nn.Softmax(dim=1) 221 | 222 | def forward(self, x, batch): 223 | """ 224 | Implement attention pooling in 225 | Ilse, Maximilian, Jakub Tomczak, and Max Welling. 226 | "Attention-based deep multiple instance learning." 227 | International conference on machine learning. PMLR, 2018. 228 | :param x: (n, d) features 229 | :param batch: (n,) in {0, B-1} 230 | :return: 231 | z: (B, d) features 232 | """ 233 | alpha = self.linear_r2(self.tanh(self.linear_r1(x))) # (n, 1) 234 | return global_softmax_pooling(x, alpha, batch) 235 | 236 | 237 | class GatedAttentionPooling(nn.Module): 238 | def __init__(self, d: int, h: int): 239 | super(GatedAttentionPooling, self).__init__() 240 | self.d = d 241 | self.h = h 242 | self.linear_r1 = nn.Linear(in_features=self.d, out_features=self.h, bias=False) 243 | self.linear_r2 = nn.Linear(in_features=self.d, out_features=self.h, bias=False) 244 | self.linear_r3 = nn.Linear(in_features=self.h, out_features=1, bias=False) 245 | self.tanh = nn.Tanh() 246 | self.softmax1 = nn.Softmax(dim=1) 247 | 248 | def forward(self, x, batch): 249 | """ 250 | Implement gated attention pooling in 251 | Ilse, Maximilian, Jakub Tomczak, and Max Welling. 252 | "Attention-based deep multiple instance learning." 253 | International conference on machine learning. PMLR, 2018. 
254 | :param x: (n, d) features 255 | :param batch: (n,) in {0, B-1} 256 | :return: 257 | z: (B, d) features 258 | """ 259 | ux = self.tanh(self.linear_r1(x)) # (n, h) 260 | vx = self.softmax1(self.linear_r2(x)) # (n, h) 261 | alpha = self.linear_r3(ux * vx) 262 | return global_softmax_pooling(x, alpha, batch) 263 | 264 | 265 | class DynamicPooling(nn.Module): 266 | def __init__(self, d: int, k: int = 3): 267 | super(DynamicPooling, self).__init__() 268 | self.d = d 269 | self.k = k 270 | self.softmax1 = nn.Softmax(dim=1) 271 | 272 | def forward(self, x, batch): 273 | """ 274 | Implement the dynamic pooling layer in 275 | Yan, Yongluan, Xinggang Wang, Xiaojie Guo, Jiemin Fang, Wenyu Liu, and Junzhou Huang. 276 | "Deep multi-instance learning with dynamic pooling." 277 | In Asian Conference on Machine Learning, pp. 662-677. PMLR, 2018. 278 | :param x: (n, d) features 279 | :param batch: (n,) in {0, B-1} 280 | :return: 281 | z: (B, d) features 282 | """ 283 | batch_size = int(batch.max().item() + 1) 284 | alpha = torch.zeros_like(x[:, 0]).unsqueeze(1) # (n, 1) 285 | for _ in range(self.k): 286 | z = global_softmax_pooling(x, alpha, batch) # (B, d) 287 | energy = torch.sum(z ** 2, dim=1, keepdim=True) # (B, 1) 288 | z_squashed = torch.sqrt(energy) / (1 + energy) * z # (B, d) 289 | for i in range(batch_size): 290 | alpha[batch == i, :] = alpha[batch == i, :] + torch.mm(x[batch == i, :], z_squashed[i, :].unsqueeze(1)) 291 | return global_softmax_pooling(x, alpha, batch) 292 | -------------------------------------------------------------------------------- /pooling/rotlayers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def uot_sinkhorn(x: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 6 | a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, mask: torch.Tensor, 7 | num: int = 4, eps: float = 1e-8) -> torch.Tensor: 8 | """ 9 | Solving regularized optimal transport via Sinkhorn-scaling 10 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 11 | :param p0: (B, 1, D), the marginal prior of dimensions 12 | :param q0: (B, N, 1), the marginal prior of samples 13 | :param a1: (num, ), the weight of the entropic term 14 | :param a2: (num, ), the weight of the KL term of p0 15 | :param a3: (num, ), the weight of the KL term of q0 16 | :param mask: (B, N, 1) a masking tensor 17 | :param num: the number of outer iterations 18 | :param eps: the epsilon to avoid numerical instability 19 | :return: 20 | t: (B, N, D), the optimal transport matrix 21 | """ 22 | t = (q0 * p0) * mask # (B, N, D) 23 | log_p0 = torch.log(p0) # (B, 1, D) 24 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 25 | tau = 0.0 26 | cost = (-x - tau * torch.log(t + eps)) * mask # (B, N, D) 27 | a = torch.zeros_like(p0) # (B, 1, D) 28 | b = torch.zeros_like(q0) # (B, N, 1) 29 | a11 = a1[0] + tau 30 | y = -cost / a11 # (B, N, D) 31 | for k in range(num): 32 | n = min([k, a1.shape[0] - 1]) 33 | a11 = a1[n] + tau 34 | ymin, _ = torch.min(y, dim=1, keepdim=True) 35 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 36 | log_p = torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) + ymax # (B, 1, D) 37 | log_q = torch.logsumexp(y, dim=2, keepdim=True) * mask # (B, N, 1) 38 | a = a2[n] / (a2[n] + a11) * (a / a11 + log_p0 - log_p) 39 | b = a3[n] / (a3[n] + a11) * (b / a11 + log_q0 - log_q) 40 | y = (-cost / a11 + a + b) * mask 41 | t = torch.exp(y) * mask 42 | return t 43 
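# A minimal usage sketch for the masked Sinkhorn solver above; the names and values below are
# illustrative placeholders (uniform priors over samples and dimensions, an all-ones mask, i.e. no
# padding), so treat this as a shape/usage check rather than a prescribed configuration.
def _demo_uot_sinkhorn(batch_size: int = 2, n: int = 5, d: int = 3, num: int = 4) -> torch.Tensor:
    x = torch.randn(batch_size, n, d)      # scores for N samples x D dimensions per set
    p0 = torch.ones(batch_size, 1, d) / d  # uniform prior over feature dimensions
    q0 = torch.ones(batch_size, n, 1) / n  # uniform prior over samples
    mask = torch.ones(batch_size, n, 1)    # 1 = valid sample, 0 = padded entry
    a1 = 0.1 * torch.ones(num)             # entropic weight, one value per outer iteration
    a2 = torch.ones(num)                   # weight of the KL term of p0
    a3 = torch.ones(num)                   # weight of the KL term of q0
    return uot_sinkhorn(x, p0, q0, a1, a2, a3, mask, num=num)  # (batch_size, n, d) transport plan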
| 44 | 45 | def rot_sinkhorn(x: torch.Tensor, c1: torch.Tensor, c2: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 46 | a0: torch.Tensor, a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, mask: torch.Tensor, 47 | num: int = 4, inner: int = 5, eps: float = 1e-8) -> torch.Tensor: 48 | """ 49 | Solving regularized optimal transport via Sinkhorn-scaling 50 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 51 | :param c1: (B, D, D), a matrix with size D x D 52 | :param c2: (B, N, N), a matrix with size N x N 53 | :param p0: (B, 1, D), the marginal prior of dimensions 54 | :param q0: (B, N, 1), the marginal prior of samples 55 | :param a0: (num, ), the weight of the GW term 56 | :param a1: (num, ), the weight of the entropic term 57 | :param a2: (num, ), the weight of the KL term of p0 58 | :param a3: (num, ), the weight of the KL term of q0 59 | :param mask: (B, N, 1) a masking tensor 60 | :param num: the number of outer iterations 61 | :param inner: the number of inner Sinkhorn iterations 62 | :param eps: the epsilon to avoid numerical instability 63 | :return: 64 | t: (B, N, D), the optimal transport matrix 65 | """ 66 | t = (q0 * p0) * mask # (B, N, D) 67 | log_p0 = torch.log(p0) # (B, 1, D) 68 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 69 | tau = 1.0 70 | for m in range(num): 71 | n = min([m, a1.shape[0]-1]) 72 | a11 = a1[n] + tau 73 | tmp1 = torch.matmul(c2, t) # (B, N, D) 74 | tmp2 = torch.matmul(tmp1, c1) # (B, N, D) 75 | cost = (-x - a0[n] * tmp2 - tau * torch.log(t + eps)) * mask # (B, N, D) 76 | a = torch.zeros_like(p0) # (B, 1, D) 77 | b = torch.zeros_like(q0) # (B, N, 1) 78 | y = -cost / a11 # (B, N, D) 79 | for k in range(inner): 80 | ymin, _ = torch.min(y, dim=1, keepdim=True) 81 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 82 | log_p = torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) + ymax # (B, 1, D) 83 | log_q = torch.logsumexp(y, dim=2, keepdim=True) * mask # (B, N, 1) 84 | a = a2[n] / (a2[n] + a11) * (a / a11 + log_p0 - log_p) 85 | b = a3[n] / (a3[n] + a11) * (b / a11 + log_q0 - log_q) 86 | y = (-cost / a11 + a + b) * mask 87 | t = torch.exp(y) * mask 88 | return t 89 | 90 | 91 | def uot_badmm(x: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 92 | a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, rho: torch.Tensor, 93 | mask: torch.Tensor, num: int = 4, eps: float = 1e-8) -> torch.Tensor: 94 | """ 95 | Solving regularized optimal transport via Bregman ADMM algorithm (entropic regularizer) 96 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 97 | :param p0: (B, 1, D), the marginal prior of dimensions 98 | :param q0: (B, N, 1), the marginal prior of samples 99 | :param a1: (num, ), the weight of the entropic term 100 | :param a2: (num, ), the weight of the KL term of p0 101 | :param a3: (num, ), the weight of the KL term of q0 102 | :param rho: (num, ), the learning rate of ADMM 103 | :param mask: (B, N, 1) a masking tensor 104 | :param num: the number of Bregman ADMM iterations 105 | :param eps: the epsilon to avoid numerical instability 106 | :return: 107 | t: (N, D), the optimal transport matrix 108 | """ 109 | log_p0 = torch.log(p0) # (B, 1, D) 110 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 111 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 112 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 113 | log_mu = torch.log(p0) # (B, 1, D) 114 | log_eta = torch.log(q0 + eps) * mask # (B, N, 1) 115 | z = torch.zeros_like(log_t) # (B, 
N, D) 116 | z1 = torch.zeros_like(p0) # (B, 1, D) 117 | z2 = torch.zeros_like(q0) # (B, N, 1) 118 | for k in range(num): 119 | n = min([k, a1.shape[0] - 1]) 120 | # update logP 121 | y = ((x - z) / rho[n] + log_s) # (B, N, D) 122 | log_t = mask * (log_eta - torch.logsumexp(y, dim=2, keepdim=True)) + y # (B, N, D) 123 | # update logS 124 | y = (z + rho[n] * log_t) / (a1[n] + rho[n]) # (B, N, D) 125 | ymin, _ = torch.min(y, dim=1, keepdim=True) 126 | ymax, _ = torch.max(ymin- mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 127 | # (B, N, D) 128 | log_s = mask * ( 129 | log_mu - torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 130 | # update dual variables 131 | t = torch.exp(log_t) * mask 132 | s = torch.exp(log_s) * mask 133 | z = z + rho[n] * (t - s) 134 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 135 | log_mu = y - torch.logsumexp(y, dim=2, keepdim=True) # (B, 1, D) 136 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) # (B, N, 1) 137 | ymin, _ = torch.min(y, dim=1, keepdim=True) 138 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 139 | log_eta = (y - torch.log( 140 | torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) * mask # (B, N, 1) 141 | # update dual variables 142 | z1 = z1 + rho[n] * (torch.exp(log_mu) - torch.sum(s, dim=1, keepdim=True)) # (B, 1, D) 143 | z2 = z2 + rho[n] * (torch.exp(log_eta) * mask - torch.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 144 | return torch.exp(log_t) * mask 145 | 146 | 147 | def rot_badmm(x: torch.Tensor, c1: torch.Tensor, c2: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 148 | a0: torch.Tensor, a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, rho: torch.Tensor, 149 | mask: torch.Tensor, num: int = 4, eps: float = 1e-8) -> torch.Tensor: 150 | """ 151 | Solving regularized optimal transport via Bregman ADMM algorithm (entropic regularizer) 152 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 153 | :param c1: (B, D, D), a matrix with size D x D 154 | :param c2: (B, N, N), a matrix with size N x N 155 | :param p0: (B, 1, D), the marginal prior of dimensions 156 | :param q0: (B, N, 1), the marginal prior of samples 157 | :param a0: (num, ), the weight of the GW term 158 | :param a1: (num, ), the weight of the entropic term 159 | :param a2: (num, ), the weight of the KL term of p0 160 | :param a3: (num, ), the weight of the KL term of q0 161 | :param rho: (num, ), the learning rate of ADMM 162 | :param mask: (B, N, 1) a masking tensor 163 | :param num: the number of Bregman ADMM iterations 164 | :param eps: the epsilon to avoid numerical instability 165 | :return: 166 | t: (N, D), the optimal transport matrix 167 | """ 168 | log_p0 = torch.log(p0) # (B, 1, D) 169 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 170 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 171 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 172 | log_mu = torch.log(p0) # (B, 1, D) 173 | log_eta = torch.log(q0 + eps) * mask # (B, N, 1) 174 | z = torch.zeros_like(log_t) # (B, N, D) 175 | z1 = torch.zeros_like(p0) # (B, 1, D) 176 | z2 = torch.zeros_like(q0) # (B, N, 1) 177 | for k in range(num): 178 | n = min([k, a1.shape[0] - 1]) 179 | # update logP 180 | tmp1 = torch.matmul(c2, torch.exp(log_s) * mask) 181 | tmp2 = torch.matmul(tmp1, c1) 182 | y = (x + a0[n] * tmp2 - z) / rho[n] + log_s # (B, N, D) 183 | log_t = mask * (log_eta - torch.logsumexp(y, dim=2, keepdim=True)) + y # (B, N, D) 184 | 
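# Note: Bregman ADMM alternates two coupled copies of the plan. log_t above is normalized over the
# feature dimension (dim=2) against log_eta, while log_s below is normalized over the sample
# dimension (dim=1) against log_mu; the ymin/ymax shift only makes the masked log-sum-exp stable.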
# update logS 185 | tmp1 = torch.matmul(c2, torch.exp(log_t) * mask) 186 | tmp2 = torch.matmul(tmp1, c1) 187 | y = (z + a0[n] * tmp2 + rho[n] * log_t) / (a1[n] + rho[n]) # (B, N, D) 188 | ymin, _ = torch.min(y, dim=1, keepdim=True) 189 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 190 | # (B, N, D) 191 | log_s = mask * ( 192 | log_mu - torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 193 | # update dual variables 194 | t = torch.exp(log_t) * mask 195 | s = torch.exp(log_s) * mask 196 | z = z + rho[n] * (t - s) 197 | # update log_mu 198 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 199 | log_mu = y - torch.logsumexp(y, dim=2, keepdim=True) # (B, 1, D) 200 | # update log_eta 201 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 202 | ymin, _ = torch.min(y, dim=1, keepdim=True) 203 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 204 | log_eta = (y - torch.log( 205 | torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) * mask # (B, N, 1) 206 | # update dual variables 207 | z1 = z1 + rho[n] * (torch.exp(log_mu) - torch.sum(s, dim=1, keepdim=True)) # (B, 1, D) 208 | z2 = z2 + rho[n] * (torch.exp(log_eta) * mask - torch.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 209 | return torch.exp(log_t) * mask 210 | 211 | 212 | def uot_badmm2(x: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 213 | a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, rho: torch.Tensor, 214 | mask: torch.Tensor, num: int = 4, eps: float = 1e-8) -> torch.Tensor: 215 | """ 216 | Solving regularized optimal transport via Bregman ADMM algorithm (quadratic regularizer) 217 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 218 | :param p0: (B, 1, D), the marginal prior of dimensions 219 | :param q0: (B, N, 1), the marginal prior of samples 220 | :param a1: (num, ), the weight of the quadratic regularizer 221 | :param a2: (num, ), the weight of the KL term of p0 222 | :param a3: (num, ), the weight of the KL term of q0 223 | :param rho: (num, ), the learning rate of ADMM 224 | :param mask: (B, N, 1) a masking tensor 225 | :param num: the number of Bregman ADMM iterations 226 | :param eps: the epsilon to avoid numerical instability 227 | Note: if num exceeds the length of a1/a2/a3/rho, their last entries are reused. 228 | :return: 229 | t: (B, N, D), the optimal transport matrix 230 | """ 231 | log_p0 = torch.log(p0) # (B, 1, D) 232 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 233 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 234 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 235 | log_mu = torch.log(p0) # (B, 1, D) 236 | log_eta = torch.log(q0 + eps) * mask # (B, N, 1) 237 | z = torch.zeros_like(log_t) # (B, N, D) 238 | z1 = torch.zeros_like(p0) # (B, 1, D) 239 | z2 = torch.zeros_like(q0) # (B, N, 1) 240 | for k in range(num): 241 | n = min([k, a1.shape[0] - 1]) 242 | # update logP 243 | y = (x - a1[n] * torch.exp(log_s) * mask - z) / rho[n] + log_s # (B, N, D) 244 | log_t = mask * (log_eta - torch.logsumexp(y, dim=2, keepdim=True)) + y # (B, N, D) 245 | # update logS 246 | y = (z - a1[n] * torch.exp(log_t) * mask) / rho[n] + log_t # (B, N, D) 247 | ymin, _ = torch.min(y, dim=1, keepdim=True) 248 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 249 | # (B, N, D) 250 | log_s = mask * ( 251 | log_mu - torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 252 | # update dual variables
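# z enforces agreement between the two plan copies T and S; z1 and z2 (updated below) tie the
# learned marginals mu and eta to the per-dimension mass of S and the per-sample mass of T.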
253 | t = torch.exp(log_t) * mask 254 | s = torch.exp(log_s) * mask 255 | z = z + rho[n] * (t - s) 256 | # update log_mu 257 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 258 | log_mu = y - torch.logsumexp(y, dim=2, keepdim=True) # (B, 1, D) 259 | # update log_eta 260 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 261 | ymin, _ = torch.min(y, dim=1, keepdim=True) 262 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 263 | log_eta = (y - torch.log( 264 | torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) * mask # (B, N, 1) 265 | # update dual variables 266 | z1 = z1 + rho[n] * (torch.exp(log_mu) - torch.sum(s, dim=1, keepdim=True)) # (B, 1, D) 267 | z2 = z2 + rho[n] * (torch.exp(log_eta) * mask - torch.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 268 | return torch.exp(log_t) * mask 269 | 270 | 271 | def rot_badmm2(x: torch.Tensor, c1: torch.Tensor, c2: torch.Tensor, p0: torch.Tensor, q0: torch.Tensor, 272 | a0: torch.Tensor, a1: torch.Tensor, a2: torch.Tensor, a3: torch.Tensor, rho: torch.Tensor, 273 | mask: torch.Tensor, num: int = 4, eps: float = 1e-8) -> torch.Tensor: 274 | """ 275 | Solving regularized optimal transport via Bregman ADMM algorithm (quadratic regularizer) 276 | :param x: (B, N, D), a matrix with N samples and each sample is D-dimensional 277 | :param c1: (B, D, D), a matrix with size D x D 278 | :param c2: (B, N, N), a matrix with size N x N 279 | :param p0: (B, 1, D), the marginal prior of dimensions 280 | :param q0: (B, N, 1), the marginal prior of samples 281 | :param a0: (num, ), the weight of the GW term 282 | :param a1: (num, ), the weight of the quadratic regularizer 283 | :param a2: (num, ), the weight of the KL term of p0 284 | :param a3: (num, ), the weight of the KL term of q0 285 | :param rho: (num, ), the learning rate of ADMM 286 | :param mask: (B, N, 1) a masking tensor 287 | :param num: the number of Bregman ADMM iterations 288 | :param eps: the epsilon to avoid numerical instability 289 | Note: if num exceeds the length of a0/a1/a2/a3/rho, their last entries are reused. 290 | :return: 291 | t: (B, N, D), the optimal transport matrix 292 | """ 293 | log_p0 = torch.log(p0) # (B, 1, D) 294 | log_q0 = torch.log(q0 + eps) * mask # (B, N, 1) 295 | log_t = (log_q0 + log_p0) * mask # (B, N, D) 296 | log_s = (log_q0 + log_p0) * mask # (B, N, D) 297 | log_mu = torch.log(p0) # (B, 1, D) 298 | log_eta = torch.log(q0 + eps) * mask # (B, N, 1) 299 | z = torch.zeros_like(log_t) # (B, N, D) 300 | z1 = torch.zeros_like(p0) # (B, 1, D) 301 | z2 = torch.zeros_like(q0) # (B, N, 1) 302 | for k in range(num): 303 | n = min([k, a1.shape[0] - 1]) 304 | # update logP 305 | tmp1 = torch.matmul(c2, torch.exp(log_s) * mask) 306 | tmp2 = torch.matmul(tmp1, c1) 307 | y = (x + a0[n] * tmp2 - a1[n] * torch.exp(log_s) * mask - z) / rho[n] + log_s # (B, N, D) 308 | log_t = mask * (log_eta - torch.logsumexp(y, dim=2, keepdim=True)) + y 309 | # update logS 310 | tmp1 = torch.matmul(c2, torch.exp(log_t) * mask) 311 | tmp2 = torch.matmul(tmp1, c1) 312 | y = (z + a0[n] * tmp2 - a1[n] * torch.exp(log_t) * mask) / rho[n] + log_t # (B, N, D) 313 | ymin, _ = torch.min(y, dim=1, keepdim=True) 314 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 315 | # (B, N, D) 316 | log_s = mask * ( 317 | log_mu - torch.log(torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) + y 318 | # update dual variables 319 | t = torch.exp(log_t) * mask 320 | s = torch.exp(log_s) * mask 321 | z =
z + rho[n] * (t - s) 322 | # update log_mu 323 | y = (rho[n] * log_mu + a2[n] * log_p0 - z1) / (rho[n] + a2[n]) # (B, 1, D) 324 | log_mu = y - torch.logsumexp(y, dim=2, keepdim=True) # (B, 1, D) 325 | # update log_eta 326 | y = ((rho[n] * log_eta + a3[n] * log_q0 - z2) / (rho[n] + a3[n])) * mask # (B, N, 1) 327 | ymin, _ = torch.min(y, dim=1, keepdim=True) 328 | ymax, _ = torch.max(ymin - mask * ymin + y, dim=1, keepdim=True) # (B, 1, D) 329 | log_eta = (y - torch.log( 330 | torch.sum(torch.exp((y - ymax) * mask) * mask, dim=1, keepdim=True)) - ymax) * mask # (B, N, 1) 331 | # update dual variables 332 | z1 = z1 + rho[n] * (torch.exp(log_mu) - torch.sum(s, dim=1, keepdim=True)) # (B, 1, D) 333 | z2 = z2 + rho[n] * (torch.exp(log_eta) * mask - torch.sum(t, dim=2, keepdim=True)) * mask # (B, N, 1) 334 | return torch.exp(log_t) * mask 335 | 336 | 337 | class ROT(nn.Module): 338 | """ 339 | Neural network layer to implement regularized optimal transport. 340 | 341 | Parameters: 342 | ----------- 343 | :param num: int, the number of iterations 344 | :param eps: float, default: 1.0e-8 345 | The epsilon avoiding numerical instability 346 | :param f_method: str, default: 'badmm-e' 347 | The feed-forward method, badmm-e, badmm-q, or sinkhorn 348 | """ 349 | 350 | def __init__(self, num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e'): 351 | super(ROT, self).__init__() 352 | self.num = num 353 | self.eps = eps 354 | self.f_method = f_method 355 | 356 | def forward(self, x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask): 357 | """ 358 | Solving regularized OT problem 359 | """ 360 | if self.f_method == 'badmm-e': 361 | t = rot_badmm(x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask, self.num, self.eps) 362 | elif self.f_method == 'badmm-q': 363 | t = rot_badmm2(x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask, self.num, self.eps) 364 | else: 365 | t = rot_sinkhorn(x, c1, c2, p0, q0, a0, a1, a2, a3, mask, self.num, inner=0, eps=self.eps) 366 | return t 367 | 368 | 369 | class UOT(nn.Module): 370 | """ 371 | Neural network layer to implement unbalanced optimal transport. 
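Unlike ROT, this layer has no Gromov-Wasserstein term: its forward pass drops the structure matrices c1 and c2 and the weight a0 that ROT requires.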
372 | 373 | Parameters: 374 | ----------- 375 | :param num: int, the number of iterations 376 | :param eps: float, default: 1.0e-8 377 | The epsilon avoiding numerical instability 378 | :param f_method: str, default: 'badmm-e' 379 | The feed-forward method, badmm-e, badmm-q or sinkhorn 380 | """ 381 | 382 | def __init__(self, num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e'): 383 | super(UOT, self).__init__() 384 | self.num = num 385 | self.eps = eps 386 | self.f_method = f_method 387 | 388 | def forward(self, x, p0, q0, a1, a2, a3, rho, mask): 389 | """ 390 | Solving regularized OT problem 391 | """ 392 | if self.f_method == 'badmm-e': 393 | t = uot_badmm(x, p0, q0, a1, a2, a3, rho, mask, self.num, self.eps) 394 | elif self.f_method == 'badmm-q': 395 | t = uot_badmm2(x, p0, q0, a1, a2, a3, rho, mask, self.num, self.eps) 396 | else: 397 | t = uot_sinkhorn(x, p0, q0, a1, a2, a3, mask, self.num, eps=self.eps) 398 | return t 399 | -------------------------------------------------------------------------------- /pooling/rotpooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.utils import softmax, to_dense_batch 4 | from pooling.rotlayers import ROT, UOT 5 | 6 | 7 | def to_sparse_batch(x: torch.Tensor, mask: torch.Tensor = None): 8 | """ 9 | Transform data with size (B, Nmax, D) to (sum_b N_b, D) 10 | :param x: a tensor with size (B, Nmax, D) 11 | :param mask: a tensor with size (B, Nmax) 12 | :return: 13 | x: with size (sum_b N_b, D) 14 | batch: with size (sum_b N_b,) 15 | """ 16 | bs, n_max, d = x.shape 17 | # get num of nodes and reshape x 18 | num_nodes_graphs = torch.zeros_like(x[:, 0, 0], dtype=torch.int64).fill_(n_max) 19 | x = x.reshape(-1, d) # total_nodes * d 20 | # apply mask 21 | if mask is not None: 22 | # get number nodes per graph 23 | num_nodes_graphs = mask.sum(dim=1) # bs 24 | # mask x 25 | x = x[mask.reshape(-1)] # total_nodes * d 26 | # set up batch 27 | batch = torch.repeat_interleave(input=torch.arange(bs, device=x.device), repeats=num_nodes_graphs) 28 | return x, batch 29 | 30 | 31 | class ROTPooling(nn.Module): 32 | def __init__(self, dim: int, a0: float = None, a1: float = None, a2: float = None, a3: float = None, 33 | rho: float = None, num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e', 34 | p0: str = 'fixed', q0: str = 'fixed', same_para: bool = False): 35 | super(ROTPooling, self).__init__() 36 | self.dim = dim 37 | self.eps = eps 38 | self.num = num 39 | self.f_method = f_method 40 | self.p0 = p0 41 | self.q0 = q0 42 | self.same_para = same_para 43 | 44 | if rho is None: 45 | if self.same_para: 46 | self.rho = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 47 | else: 48 | self.rho = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 49 | else: 50 | if self.same_para: 51 | self.rho = rho * torch.ones(1) 52 | else: 53 | self.rho = rho * torch.ones(self.num) 54 | 55 | if a0 is None: 56 | if self.same_para: 57 | self.a0 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 58 | else: 59 | print("----------------------------------------") 60 | self.a0 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 61 | else: 62 | if self.same_para: 63 | self.a0 = a0 * torch.ones(1) 64 | else: 65 | self.a0 = a0 * torch.ones(self.num) 66 | 67 | if a1 is None: 68 | if self.same_para: 69 | self.a1 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 70 | else: 71 | self.a1 = nn.Parameter(0.1 * torch.randn(self.num), 
requires_grad=True) 72 | else: 73 | if self.same_para: 74 | self.a1 = a1 * torch.ones(1) 75 | else: 76 | print("------better with gradient--") 77 | #self.a1 = a1 * torch.ones(self.num) 78 | self.a1 = nn.Parameter(a1 * torch.ones(self.num), requires_grad=True) 79 | 80 | if a2 is None: 81 | if self.same_para: 82 | self.a2 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 83 | else: 84 | self.a2 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 85 | else: 86 | if self.same_para: 87 | self.a2 = a2 * torch.ones(1) 88 | else: 89 | self.a2 = a2 * torch.ones(self.num) 90 | 91 | if a3 is None: 92 | if self.same_para: 93 | self.a3 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 94 | else: 95 | self.a3 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 96 | else: 97 | if self.same_para: 98 | self.a3 = a3 * torch.ones(1) 99 | else: 100 | self.a3 = a3 * torch.ones(self.num) 101 | 102 | self.linear_r1 = nn.Linear(in_features=self.dim, out_features=2 * self.dim, bias=False) 103 | self.linear_r2 = nn.Linear(in_features=2 * self.dim, out_features=1, bias=False) 104 | self.tanh = nn.Tanh() 105 | 106 | self.softplus = nn.Softplus() 107 | self.sigmoid = nn.Sigmoid() 108 | self.softmax1 = nn.Softmax(dim=1) 109 | self.rot = ROT(num=self.num, eps=self.eps, f_method=self.f_method) 110 | 111 | def forward(self, x, batch): 112 | """ 113 | The feed-forward function of ROT-Pooling 114 | :param x: (sum_b N_b, D) samples 115 | :param batch: (sum_b N_b, ) the index of sets in the batch 116 | :return: 117 | a pooling result with size (K, D) 118 | """ 119 | if self.q0 != 'fixed': 120 | q0 = softmax(self.linear_r2(self.tanh(self.linear_r1(x))), batch) # (N, 1) 121 | else: 122 | q0 = softmax(torch.zeros_like(x[:, 0].unsqueeze(1)), batch) # (N, 1) 123 | x, mask = to_dense_batch(x, batch) # (B, Nmax, D) 124 | mask = mask.unsqueeze(2) # (B, Nmax, 1) 125 | q0, _ = to_dense_batch(q0, batch) # (B, Nmax, 1) 126 | p0 = (torch.ones_like(x[:, 0, :]) / self.dim).unsqueeze(1) # (B, 1, D) 127 | c1 = torch.matmul(x.permute(0, 2, 1), x) / x.shape[1] # (B, D, D) 128 | c2 = torch.matmul(x, x.permute(0, 2, 1)) / x.shape[2] # (B, Nmax, Nmax) 129 | # 20221004 130 | c1 = c1 / torch.max(c1) 131 | c2 = c2 / torch.max(c2) 132 | rho = self.softplus(self.rho) 133 | if rho.shape[0] == 1: 134 | rho = rho.repeat(self.num) 135 | a0 = self.softplus(self.a0) 136 | if a0.shape[0] == 1: 137 | a0 = a0.repeat(self.num) 138 | 139 | a1 = self.softplus(self.a1) 140 | if a1.shape[0] == 1: 141 | a1 = a1.repeat(self.num) 142 | a2 = self.softplus(self.a2) 143 | if a2.shape[0] == 1: 144 | a2 = a2.repeat(self.num) 145 | a3 = self.softplus(self.a3) 146 | if a3.shape[0] == 1: 147 | a3 = a3.repeat(self.num) 148 | trans = self.rot(x, c1, c2, p0, q0, a0, a1, a2, a3, rho, mask) # (B, Nmax, D) 149 | frot = self.dim * x * trans * mask # (B, Nmax, D) 150 | return torch.sum(frot, dim=1, keepdim=False) # (B, D) 151 | 152 | 153 | class UOTPooling(nn.Module): 154 | def __init__(self, dim: int, a1: float = None, a2: float = None, a3: float = None, rho: float = None, 155 | num: int = 4, eps: float = 1e-8, f_method: str = 'badmm-e', 156 | p0: str = 'fixed', q0: str = 'fixed', same_para: bool = False): 157 | super(UOTPooling, self).__init__() 158 | self.dim = dim 159 | self.eps = eps 160 | self.num = num 161 | self.f_method = f_method 162 | self.p0 = p0 163 | self.q0 = q0 164 | self.same_para = same_para 165 | 166 | if rho is None: 167 | if self.same_para: 168 | self.rho = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 169 | else: 170 
| self.rho = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 171 | else: 172 | if self.same_para: 173 | self.rho = rho * torch.ones(1) 174 | else: 175 | self.rho = rho * torch.ones(self.num) 176 | 177 | if a1 is None: 178 | if self.same_para: 179 | self.a1 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 180 | else: 181 | self.a1 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 182 | else: 183 | if self.same_para: 184 | self.a1 = a1 * torch.ones(1) 185 | else: 186 | self.a1 = a1 * torch.ones(self.num) 187 | 188 | if a2 is None: 189 | if self.same_para: 190 | self.a2 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 191 | else: 192 | self.a2 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 193 | else: 194 | if self.same_para: 195 | self.a2 = a2 * torch.ones(1) 196 | else: 197 | self.a2 = a2 * torch.ones(self.num) 198 | 199 | if a3 is None: 200 | if self.same_para: 201 | self.a3 = nn.Parameter(0.1 * torch.randn(1), requires_grad=True) 202 | else: 203 | self.a3 = nn.Parameter(0.1 * torch.randn(self.num), requires_grad=True) 204 | else: 205 | if self.same_para: 206 | self.a3 = a3 * torch.ones(1) 207 | else: 208 | self.a3 = a3 * torch.ones(self.num) 209 | 210 | self.linear_r1 = nn.Linear(in_features=self.dim, out_features=2 * self.dim, bias=False) 211 | self.linear_r2 = nn.Linear(in_features=2 * self.dim, out_features=1, bias=False) 212 | self.tanh = nn.Tanh() 213 | 214 | self.softplus = nn.Softplus() 215 | self.sigmoid = nn.Sigmoid() 216 | self.softmax1 = nn.Softmax(dim=1) 217 | self.uot = UOT(num=self.num, eps=self.eps, f_method=self.f_method) 218 | 219 | def forward(self, x, batch): 220 | """ 221 | The feed-forward function of UOT-Pooling 222 | :param x: (sum_b N_b, D) samples 223 | :param batch: (sum_b N_b, ) the index of sets in the batch 224 | :return: 225 | a pooling result with size (B, D) 226 | """ 227 | if self.q0 != 'fixed': 228 | q0 = softmax(self.linear_r2(self.tanh(self.linear_r1(x))), batch) # (N, 1) 229 | else: 230 | q0 = softmax(torch.zeros_like(x[:, 0].unsqueeze(1)), batch) # (N, 1) 231 | x, mask = to_dense_batch(x, batch) # (B, Nmax, D) 232 | mask = mask.unsqueeze(2) # (B, Nmax, 1) 233 | q0, _ = to_dense_batch(q0, batch) # (B, Nmax, 1) 234 | p0 = (torch.ones_like(x[:, 0, :]) / self.dim).unsqueeze(1) # (B, 1, D) 235 | rho = self.softplus(self.rho) 236 | if rho.shape[0] == 1: 237 | rho = rho.repeat(self.num) 238 | a1 = self.softplus(self.a1) 239 | if a1.shape[0] == 1: 240 | a1 = a1.repeat(self.num) 241 | a2 = self.softplus(self.a2) 242 | if a2.shape[0] == 1: 243 | a2 = a2.repeat(self.num) 244 | a3 = self.softplus(self.a3) 245 | if a3.shape[0] == 1: 246 | a3 = a3.repeat(self.num) 247 | trans = self.uot(x, p0, q0, a1, a2, a3, rho, mask) # (B, Nmax, D) 248 | frot = self.dim * x * trans * mask # (B, Nmax, D) 249 | return torch.sum(frot, dim=1, keepdim=False) # (B, D) 250 | -------------------------------------------------------------------------------- /pooling/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import matplotlib 5 | from typing import List, Union 6 | matplotlib.rcParams['pdf.fonttype'] = 42 7 | matplotlib.rcParams['ps.fonttype'] = 42 8 | plt.rcParams.update({'font.size': 12}) 9 | 10 | 11 | def visualize_map(data: np.ndarray, dst: str, ticklabels: List[str], xlabel: str = None, ylabel: str = None): 12 | plt.figure(figsize=(6, 5)) 13 | ax = sns.heatmap(data, 14 | linewidth=1, 15 |
vmin=0, 16 | vmax=100, 17 | square=True, 18 | annot=True, 19 | cmap="summer", 20 | xticklabels=ticklabels, 21 | yticklabels=ticklabels) 22 | ax.tick_params(top=True, bottom=False, 23 | labeltop=True, labelbottom=False) 24 | if ylabel is not None: 25 | plt.ylabel(ylabel, fontdict={'size': 28}) 26 | if xlabel is not None: 27 | plt.xlabel(xlabel, fontdict={'size': 28}) 28 | plt.tight_layout() 29 | plt.savefig(dst) 30 | plt.close() 31 | 32 | 33 | def visualize_data(x: np.ndarray, dst: str): 34 | plt.figure(figsize=(20, 5)) 35 | sns.heatmap(x, linewidth=1, square=True, cmap="YlGnBu") 36 | plt.axis('off') 37 | plt.tight_layout() 38 | plt.savefig(dst) 39 | plt.close() 40 | 41 | 42 | def visualize_pooling(data_list: List[np.ndarray], dst: str, vmin: float = None, vmax: float = None): 43 | dim, num = data_list[0].shape 44 | tmp = None 45 | for i in range(len(data_list)): 46 | if i == 0: 47 | tmp = data_list[i] 48 | else: 49 | tmp = np.concatenate((tmp, np.zeros((1, num)), data_list[i]), axis=0) 50 | mask = np.zeros_like(tmp) 51 | if len(data_list) > 1: 52 | for i in range(len(data_list) - 1): 53 | mask[dim + i * (dim + 1), :] = True 54 | 55 | plt.figure(figsize=(5, 6)) 56 | if vmin is not None and vmax is not None: 57 | sns.heatmap(tmp, mask=mask, linewidth=1, square=True, cmap="YlGnBu", vmin=vmin, vmax=vmax) 58 | else: 59 | sns.heatmap(tmp, mask=mask, linewidth=1, square=True, cmap="YlGnBu") 60 | plt.axis('off') 61 | plt.tight_layout() 62 | plt.savefig(dst) 63 | plt.close() 64 | 65 | 66 | def visualize_errorbar_curve(xs: Union[np.ndarray, List], ms: np.ndarray, vs: np.ndarray, colors: List[str], 67 | labels: List[str], xlabel: str, ylabel: str, dst: str): 68 | plt.figure(figsize=(5, 5)) 69 | for i in range(len(colors)): 70 | plt.plot(xs, ms[:, i], 'o-', label=labels[i], color=colors[i]) 71 | plt.fill_between(xs, ms[:, i] - vs[:, i], ms[:, i] + vs[:, i], color=colors[i], alpha=0.2) 72 | plt.legend(loc='upper left') 73 | plt.xlabel(xlabel) 74 | plt.ylabel(ylabel) 75 | plt.grid() 76 | plt.tight_layout() 77 | plt.savefig(dst) 78 | plt.close() 79 | --------------------------------------------------------------------------------
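For orientation, a minimal usage sketch of the pooling layers defined in pooling/rotpooling.py follows. It assumes torch and torch_geometric are installed and relies only on the signatures shown above; the feature dimension, batch layout, and solver settings are illustrative placeholders, not values used in the original experiments.

import torch
from pooling.rotpooling import ROTPooling, UOTPooling

# two sets in one mini-batch: 3 samples in the first set, 5 in the second, each 16-dimensional
x = torch.randn(8, 16)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])

uot_pool = UOTPooling(dim=16, num=4, f_method='badmm-e')  # unbalanced OT pooling, learnable weights
rot_pool = ROTPooling(dim=16, num=4, f_method='badmm-e')  # additionally uses the structure (GW) terms

z_uot = uot_pool(x, batch)  # (2, 16), one pooled vector per set
z_rot = rot_pool(x, batch)  # (2, 16)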