├── .DS_Store ├── ICLR24_LieBN_PPT.pdf ├── ICLR24_LieBN_Poster.pdf ├── LieBN ├── Geometry │ ├── Base │ │ ├── LieGroups.py │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── LieGroups.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ ├── Correlation │ │ ├── CorMatrices.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── CorMatrices.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cor_functions.cpython-38.pyc │ │ │ └── ppbcm_functionals.cpython-38.pyc │ │ ├── cor_functions.py │ │ └── ppbcm_functionals.py │ ├── Rotations │ │ ├── RotMatrices.py │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── RotMatrices.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ ├── SPD │ │ ├── SPDMatrices.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── SPDMatrices.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── sym_functional.cpython-38.pyc │ │ └── sym_functional.py │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-38.pyc ├── LieBNBase.py ├── LieBNCor.py ├── LieBNRot.py ├── LieBNSPD.py ├── LieBN_illustration.png ├── README.md ├── __init__.py ├── __pycache__ │ ├── LieBNBase.cpython-38.pyc │ ├── LieBNCor.cpython-38.pyc │ ├── LieBNRot.cpython-38.pyc │ ├── LieBNSPD.cpython-38.pyc │ └── __init__.cpython-38.pyc ├── demos │ ├── liebncor.py │ ├── liebnrot.py │ └── liebnspd.py └── sum_metrics.png ├── LieBN_SPDNet ├── Network │ ├── Get_Model.py │ ├── SPDNetRBN.py │ └── __init__.py ├── SPDNetLieBN.py ├── conf │ ├── LieBN.yaml │ ├── dataset │ │ ├── HDM05.yaml │ │ └── RADAR.yaml │ └── nnet │ │ ├── SPDNet.yaml │ │ ├── SPDNetBN.yaml │ │ └── SPDNetLieBN.yaml ├── cplx │ ├── functional.py │ └── nn.py ├── experiments.sh └── spd │ ├── DataLoader │ ├── FPHA_Loader.py │ ├── HDM05_Loader.py │ └── Radar_Loader.py │ ├── LieBN.py │ ├── __init__.py │ ├── functional.py │ ├── nn.py │ ├── sym_functional.py │ ├── training_script.py │ └── utils.py ├── LieBN_TSMNet ├── LieBN_utilities │ ├── Training.py │ ├── __init__.py │ └── utils.py ├── TSMNet-LieBN.py ├── conf │ ├── LieBN.yaml │ ├── dataset │ │ └── hinss2021.yaml │ ├── evaluation │ │ ├── inter-session+uda.yaml │ │ └── inter-subject+uda.yaml │ ├── nnet │ │ ├── tsmnet.yaml │ │ ├── tsmnet_LieBN.yaml │ │ └── tsmnet_spddsmbn.yaml │ └── preprocessing │ │ └── bb4-36Hz.yaml ├── datasetio │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ └── eeg │ │ ├── __init__.py │ │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ │ └── moabb │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── base.cpython-310.pyc │ │ ├── hinss2021.cpython-310.pyc │ │ └── stieger2021.cpython-310.pyc │ │ ├── base.py │ │ ├── hinss2021.py │ │ └── stieger2021.py ├── experiments_Hinss21.sh ├── library │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ │ ├── hydra │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ └── __init__.cpython-310.pyc │ │ ├── moabb │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ └── __init__.cpython-310.pyc │ │ ├── pyriemann │ │ └── __init__.py │ │ └── torch │ │ ├── __init__.py │ │ └── __pycache__ │ │ └── __init__.cpython-310.pyc └── spdnets │ ├── BaseBatchNorm.py │ ├── LieBNImpl.py │ ├── Liebatchnorm.py │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── batchnorm.cpython-310.pyc │ ├── functionals.cpython-310.pyc │ ├── manifolds.cpython-310.pyc │ └── modules.cpython-310.pyc │ ├── batchnorm.py │ ├── functionals.py │ ├── manifolds.py │ ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── 
__init__.cpython-310.pyc │ │ ├── base.cpython-310.pyc │ │ ├── dann.cpython-310.pyc │ │ ├── eegnet.cpython-310.pyc │ │ ├── shconvnet.cpython-310.pyc │ │ └── tsmnet.cpython-310.pyc │ ├── base.py │ ├── dann.py │ ├── eegnet.py │ ├── shconvnet.py │ └── tsmnet.py │ ├── modules.py │ └── utils │ └── skorch │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── logging.cpython-310.pyc │ └── network.cpython-310.pyc │ ├── logging.py │ └── network.py ├── Readme.md └── environment.yaml /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/.DS_Store -------------------------------------------------------------------------------- /ICLR24_LieBN_PPT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/ICLR24_LieBN_PPT.pdf -------------------------------------------------------------------------------- /ICLR24_LieBN_Poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/ICLR24_LieBN_Poster.pdf -------------------------------------------------------------------------------- /LieBN/Geometry/Base/LieGroups.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lie group computations under LieBN: 3 | @inproceedings{chen2024liebn, 4 | title={A Lie Group Approach to Riemannian Batch Normalization}, 5 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe}, 6 | booktitle={The Twelfth International Conference on Learning Representations}, 7 | year={2024}, 8 | url={https://openreview.net/forum?id=okYdj8Ysru} 9 | } 10 | """ 11 | import torch as th 12 | import torch.nn as nn 13 | 14 | from abc import ABC 15 | 16 | 17 | class LieGroup(nn.Module): 18 | """LieBN for [...,n,n]: Following SPDNetBN, we use .detach() for mean and variance calculation 19 | """ 20 | def __init__(self,is_detach=True): 21 | super().__init__() 22 | self.is_detach=is_detach 23 | 24 | def geodesic(self,P,Q,t): 25 | '''The geodesic from P to Q at t''' 26 | raise NotImplementedError 27 | 28 | #--- Methods required in LieBN --- 29 | def deformation(self,X): 30 | raise NotImplementedError 31 | 32 | def inv_deformation(self,X): 33 | raise NotImplementedError 34 | def dist2Isquare(self,X): 35 | """Squared geodesic distance to the identity element. 36 | It should keep all the dims, so that cal_geom_var only eliminates the batch dim.
37 | This will support [bs,...,n,n] data 38 | """ 39 | raise NotImplementedError 40 | 41 | def cal_geom_mean_(self,X,batchdim=[0]): 42 | raise NotImplementedError 43 | 44 | def cal_geom_mean(self,X,batchdim=[0]): 45 | batch = X.detach() if self.is_detach else X 46 | return self.cal_geom_mean_(batch,batchdim) 47 | 48 | def cal_geom_var(self, X,batchdim=[0]): 49 | """Fréchet variance""" 50 | batch = X.detach() if self.is_detach else X 51 | return self.dist2Isquare(batch).mean(dim=batchdim) 52 | 53 | def translation(self, X, P, is_inverse): 54 | raise NotImplementedError 55 | 56 | def scaling(self, X, factor): 57 | raise NotImplementedError 58 | 59 | def __repr__(self): 60 | attributes = [] 61 | for key, value in self.__dict__.items(): 62 | if key.startswith("_"): 63 | continue 64 | if isinstance(value, th.Tensor): 65 | attributes.append(f"{key}.shape={tuple(value.shape)}") # only display tensor shapes 66 | else: 67 | attributes.append(f"{key}={value}") 68 | 69 | # handle variables registered via register_buffer 70 | for name, buffer in self.named_buffers(recurse=False): # only print the current class's buffers 71 | attributes.append(f"{name}={buffer.item() if buffer.numel() == 1 else tuple(buffer.shape)}") 72 | 73 | return f"{self.__class__.__name__}({', '.join(attributes)})" 74 | 75 | class PullbackEuclideanMetric(ABC): 76 | """The Euclidean computation in the co-domain of a pullback Euclidean metric: 77 | Ziheng Chen et al., Adaptive Log-Euclidean Metrics for SPD Matrix Learning 78 | For a subclass of this class, we only need to implement 79 | dist2Isquare, deformation, and inv_deformation 80 | """ 81 | # def __init__(self, is_detach=True): 82 | # super().__init__(is_detach=is_detach) 83 | 84 | def cal_geom_mean_(self, batch,batchdim=[0]): 85 | return batch.mean(dim=batchdim) 86 | 87 | def translation(self, X, P, is_inverse): 88 | X_new = X - P if is_inverse else X + P 89 | return X_new 90 | 91 | def scaling(self, X, factor): 92 | return X * factor 93 | 94 | def geodesic(self,P,Q,t): 95 | return (1 - t) * P + t * Q -------------------------------------------------------------------------------- /LieBN/Geometry/Base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__init__.py -------------------------------------------------------------------------------- /LieBN/Geometry/Base/__pycache__/LieGroups.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__pycache__/LieGroups.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Base/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/__init__.py: -------------------------------------------------------------------------------- 1 | from .CorMatrices import Correlation,\ 2 | CorEuclideanCholeskyMetric,CorLogEuclideanCholeskyMetric,\ 3 | CorOffLogMetric,CorLogScaledMetric -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/__pycache__/CorMatrices.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/CorMatrices.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/__pycache__/cor_functions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/cor_functions.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/__pycache__/ppbcm_functionals.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/ppbcm_functionals.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Correlation/ppbcm_functionals.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | @th.jit.script 4 | def hemisphere_to_poincare(x): 5 | """ 6 | Transform points from the open hemisphere HS^n to the Poincaré ball P^n. 7 | 8 | Parameters: 9 | - x: torch.Tensor of shape [..., n+1], where the last dimension 10 | represents the coordinates (x^T, x_{n+1}). 11 | 12 | Returns: 13 | - y: torch.Tensor of shape [..., n], transformed coordinates in the Poincaré ball. 
14 | """ 15 | x_T, x_n1 = x[..., :-1], x[..., -1] # Split x into (x_T, x_n+1) 16 | y = x_T / (1 + x_n1.unsqueeze(-1)) # Apply the transformation: y = x_T / (1 + x_n1) 17 | return y 18 | 19 | @th.jit.script 20 | def poincare_to_hemisphere(y): 21 | """Map from Poincaré ball to open hemisphere.""" 22 | norm_y = y.norm(p=2,dim=-1, keepdim=True) ** 2 23 | factor = 1 / (1 + norm_y) 24 | mapped = th.cat((2 * y * factor, (1 - norm_y) * factor), dim=-1) 25 | return mapped 26 | 27 | 28 | -------------------------------------------------------------------------------- /LieBN/Geometry/Rotations/RotMatrices.py: -------------------------------------------------------------------------------- 1 | """ 2 | SPD computations under LieBN: 3 | @inproceedings{chen2024liebn, 4 | title={A Lie Group Approach to Riemannian Batch Normalization}, 5 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe}, 6 | booktitle={The Twelfth International Conference on Learning Representations}, 7 | year={2024}, 8 | url={https://openreview.net/forum?id=okYdj8Ysru} 9 | } 10 | """ 11 | import torch as th 12 | from ..Base.LieGroups import LieGroup 13 | 14 | from geoopt.manifolds import Manifold 15 | from pytorch3d.transforms import matrix_to_axis_angle 16 | 17 | class RotMatrices(LieGroup,Manifold): 18 | """ 19 | Computation for SO3 data with size of [...,n,n] 20 | Following manopt, we use the Lie algebra representation for the tangent spaces 21 | """ 22 | # __scaling__ = Manifold.__scaling__.copy() 23 | name = "Rotation" 24 | ndim = 2 25 | reversible = False 26 | def __init__(self,eps=1e-5,is_detach=True,karcher_steps=1): 27 | super().__init__(is_detach=is_detach) 28 | self.eps=eps; 29 | self.register_buffer('I', th.eye(3)) 30 | self.karcher_steps=karcher_steps 31 | 32 | # === Generating random matrices === 33 | 34 | def random(self, *shape): 35 | """Generate random 3D rotation matrices of shape [..., 3, 3]""" 36 | A = th.rand(shape) # Generate a random matrix 37 | u, _, v = th.linalg.svd(A) # Perform SVD 38 | 39 | # Ensure determinant is +1 40 | det_u = u.det()[..., None] # Reshape to match the last two dimensions 41 | u[..., :, -1] *= th.sign(det_u) # Flip last column where needed 42 | # u_new= u[..., :, -1] * th.sign(det_u) # Flip last column where needed 43 | 44 | # Verify that the output is a valid SO(3) matrix 45 | result, _ = self._check_point_on_manifold(u) 46 | if not result: 47 | raise ValueError("SO(3) init value error") 48 | 49 | return u 50 | 51 | def rand_skrew_sym(self,n,f): 52 | A=th.rand(n,f,3,3) 53 | return A-A.transpose(-1,-2) 54 | 55 | # === For geoopt === 56 | 57 | def quat_axis2log_axis(self,axis_quat): 58 | """"convert quaternion axis into matrix log axis 59 | https://github.com/facebookresearch/pytorch3d/issues/188 60 | """ 61 | log_axis_result = axis_quat.clone() 62 | angle_in_2pi = log_axis_result.norm(p=2, dim=-1, keepdim=True) 63 | mask = angle_in_2pi > th.pi 64 | tmp = 1 - 2 * th.pi / angle_in_2pi 65 | new_values = tmp * log_axis_result 66 | results = th.where(mask, new_values, log_axis_result) 67 | return results 68 | 69 | def matrix2euler_axis(self,R): 70 | quat_vec = matrix_to_axis_angle(R) 71 | log_axis = self.quat_axis2log_axis(quat_vec) 72 | euler_axis = log_axis.div(log_axis.norm(dim=-1,keepdim=True)) 73 | return euler_axis 74 | 75 | def mLog(self, R): 76 | """ 77 | Note that Exp(\alpha A)=Exp(A), with \alpha = \frac{\theta-2\pi}{\theta} 78 | So, for a single rotation matrices, quat_axis or log_axis does not affect self.mExp. 
79 | """ 80 | vec = matrix_to_axis_angle(R) 81 | log_vec = self.quat_axis2log_axis(vec) 82 | skew_symmetric = self.vec2skrew(log_vec) 83 | return skew_symmetric 84 | 85 | def vec2skrew(self,vec): 86 | # skew_symmetric = th.zeros_like(vec).unsqueeze(-1).expand(*vec.shape, 3).contiguous() 87 | skew_symmetric = th.zeros(*vec.shape, 3,dtype=vec.dtype,device=vec.device) 88 | skew_symmetric[..., 0, 1] = -vec[..., 2] 89 | skew_symmetric[..., 1, 0] = vec[..., 2] 90 | skew_symmetric[..., 0, 2] = vec[..., 1] 91 | skew_symmetric[..., 2, 0] = -vec[..., 1] 92 | skew_symmetric[..., 1, 2] = -vec[..., 0] 93 | skew_symmetric[..., 2, 1] = vec[..., 0] 94 | return skew_symmetric 95 | def mExp(self, S): 96 | """Computing matrix exponential for skrew symmetric matrices""" 97 | a, b, c = S[..., 0, 1], S[..., 0, 2], S[..., 1, 2] 98 | theta = th.sqrt(a ** 2 + b ** 2 + c ** 2).unsqueeze(-1).unsqueeze(-1) 99 | 100 | S_normalized = S / theta 101 | S_norm_squared = S_normalized.matmul(S_normalized) 102 | sin_theta = th.sin(theta) 103 | cos_theta = th.cos(theta) 104 | tmp_S = self.I + sin_theta * S_normalized + (1 - cos_theta) * S_norm_squared 105 | 106 | S_new = th.where(theta < self.eps, S-S.detach()+self.I, tmp_S) # S+I to ensure autograd 107 | 108 | return S_new 109 | 110 | def transp(self, x, y, v): 111 | return v 112 | 113 | def inner(self, x, u, v, keepdim=False): 114 | if v is None: 115 | v = u 116 | return th.sum(u * v, dim=[-2, -1], keepdim=keepdim) 117 | 118 | def projx(self, x): 119 | u, s, vt = th.linalg.svd(x) 120 | ones = th.ones_like(s)[..., :-1] 121 | signs = th.sign(th.det(th.matmul(u, vt))).unsqueeze(-1) 122 | flip = th.cat([ones, signs], dim=-1) 123 | result = u.matmul(th.diag_embed(flip)).matmul(vt) 124 | return result 125 | 126 | def proju(self,X, H): 127 | k = self.multiskew(H) 128 | return k 129 | 130 | def egrad2rgrad(self, x, u): 131 | """Map the Euclidean gradient :math:`u` in the ambient space on the tangent 132 | space at :math:`x`. 
133 | """ 134 | k = self.multiskew(x.transpose(-1, -2).matmul(u)) 135 | return k 136 | 137 | def retr(self, X,U): 138 | Y = X + X.matmul(U) 139 | Q, R = th.linalg.qr(Y) 140 | New = th.matmul(Q, th.diag_embed(th.sign(th.sign(th.diagonal(R, dim1=-2, dim2=-1)) + 0.5))) 141 | return New 142 | 143 | def multiskew(self,A): 144 | return 0.5 * (A - A.transpose(-1,-2)) 145 | 146 | def logmap(self,R,S): 147 | """ Return skew-symmetric matrices """ 148 | return self.mLog(R.transpose(-1,-2).matmul(S)) 149 | 150 | def expmap(self,R,V): 151 | """ V contains the skew-symmetric matrices """ 152 | return R.matmul(self.mExp(V)) 153 | 154 | def geodesic(self,R,S,t): 155 | """ The geodesic connecting R and S """ 156 | vector = self.logmap(R,S) 157 | X_new = R.matmul(self.mExp(t*vector)) 158 | return X_new 159 | 160 | def trace(self,m): 161 | """Compute the trace of m with shape [...,n,n]""" 162 | return th.einsum("...ii->...", m) 163 | 164 | def _check_point_on_manifold( 165 | self, x: th.Tensor, *, atol=1e-5, rtol=1e-8 166 | ): 167 | 168 | if x.shape[-1] != 3 or x.shape[-2] != 3: 169 | raise ValueError("Input matrices must be 3x3.") 170 | 171 | # Check orthogonality 172 | is_orthogonal = th.allclose(x @ x.transpose(-1, -2),self.I, atol=atol,rtol=rtol) 173 | 174 | # Check determinant 175 | det = th.det(x) 176 | is_det_one = th.allclose(det, th.tensor(1.0, device=x.device,dtype=x.dtype), atol=atol,rtol=rtol) 177 | 178 | # Combine both conditions 179 | is_SO3 = is_orthogonal & is_det_one 180 | 181 | return is_SO3, None 182 | 183 | def _check_vector_on_tangent( 184 | self, x: th.Tensor, u: th.Tensor, *, atol=1e-5, rtol=1e-8 185 | ): 186 | """Check whether u is skew-symmetric""" 187 | diff = u + u.transpose(-1,-2) 188 | ok = th.allclose(diff, th.zeros_like(diff), atol=atol, rtol=rtol) 189 | return ok, None 190 | 191 | def is_not_equal(self, a, b, eps=0.01): 192 | """ Return the indices where a and b are not equal""" 193 | return th.nonzero(th.abs(a - b) > eps) 194 | 195 | # === Computing angles === 196 | 197 | def cal_roc_angel_batch(self,r,epsilon=1e-4): 198 | """ 199 | Following the Matlab implementation, we set derivative=0 for tr near -1 or near 3. 200 | Besides, there could be cases where tr \in [-1-eps, 3+eps] due to numerical error. 201 | We treat the cases beyond [-1,3] as -1 or 3, where the derivative is 0.
202 | return: 203 | tr <= -1+epsilon, theta=pi (derivative is 0) 204 | tr >= 3-epsilon, theta=0 (derivative is 0) 205 | """ 206 | assert epsilon >= 0, "Epsilon must be non-negative" 207 | 208 | mtrc = self.trace(r) 209 | 210 | maskpi = (mtrc + 1) <= epsilon # tr <= -1 + epsilon 211 | # mtrcpi = -mtrc * maskpi * np.pi # this differs from the Matlab implementation, as its derivative is -pi 212 | mtrcpi = maskpi * th.pi 213 | maskacos = ((mtrc + 1) > epsilon) * ((3- mtrc) > epsilon) # -1+epsilon < tr < 3 - epsilon 214 | 215 | mtrcacos = th.acos((mtrc * maskacos - 1) / 2) * maskacos # -1+epsilon < tr < 3 - epsilon, use the acos 216 | results = mtrcpi + mtrcacos # for tr <= -1 + epsilon and tr >= 3 - epsilon, the derivative = 0 217 | return results 218 | 219 | # === For LieBN === 220 | 221 | def cal_geom_mean_(self, X, batchdim=[0,1]): 222 | """Karcher flow for the Fréchet mean""" 223 | init_point = X[tuple(0 for _ in batchdim)] # Dynamically selects first elements along batchdim 224 | for ith in range(self.karcher_steps): 225 | tan_data = self.mLog(init_point.transpose(-1,-2).matmul(X)) 226 | tan_mean = tan_data.mean(dim=batchdim) 227 | condition = tan_mean.norm(dim=(-1, -2)) 228 | # print(f'{ith+1}: {condition}') 229 | if th.all(condition<=1e-4): 230 | # th.sum(condition>=0.1) 231 | # print('early stop') 232 | break 233 | init_point = init_point.matmul(self.mExp(tan_mean)) 234 | return init_point 235 | 236 | def scaling(self, X,shift,running_var=None,batchdim=[0,1]): 237 | """Scaling by shift; the Fréchet variance is computed on the fly when running_var is None""" 238 | Log_X = self.mLog(X) 239 | if running_var is None: 240 | Log_X_norm_square = Log_X.norm(p='fro', dim=(-2, -1), keepdim=True).square() 241 | var = th.mean(Log_X_norm_square, dim=batchdim) 242 | var = var.detach() if self.is_detach else var 243 | else: 244 | var = running_var 245 | factor = shift / (var + self.eps).sqrt() 246 | scale_Log = factor * Log_X 247 | X_scaled = self.mExp(scale_Log) 248 | return X_scaled,var 249 | 250 | def translation(self, X, P, is_inverse, is_left): 251 | """Translation by P""" 252 | if is_left: 253 | X_new = P.transpose(-1, -2).matmul(X) if is_inverse else P.matmul(X) 254 | else: 255 | X_new = X.matmul(P.transpose(-1, -2)) if is_inverse else X.matmul(P) 256 | return X_new 257 | 258 | -------------------------------------------------------------------------------- /LieBN/Geometry/Rotations/__init__.py: -------------------------------------------------------------------------------- 1 | from .RotMatrices import RotMatrices -------------------------------------------------------------------------------- /LieBN/Geometry/Rotations/__pycache__/RotMatrices.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Rotations/__pycache__/RotMatrices.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/Rotations/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Rotations/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/SPD/__init__.py: -------------------------------------------------------------------------------- 1 | from .SPDMatrices import SPDMatrices,\ 2 | SPDLogEuclideanMetric,SPDAdaptiveLogEuclideanMetric,\ 3 |
SPDLogCholeskyMetric,SPDAffineInvariantMetric,\ 4 | SPDCholeskyRightInvariantMetric -------------------------------------------------------------------------------- /LieBN/Geometry/SPD/__pycache__/SPDMatrices.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/SPDMatrices.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/SPD/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/SPD/__pycache__/sym_functional.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/sym_functional.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/Geometry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/__init__.py -------------------------------------------------------------------------------- /LieBN/Geometry/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/LieBNBase.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024. 6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning. TIP 2024. 7 | 8 | Copyright (C) 2024 Ziheng Chen 9 | All rights reserved. 10 | """ 11 | 12 | import torch as th 13 | import torch.nn as nn 14 | 15 | class LieBNBase(nn.Module): 16 | """ 17 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 18 | Input X: (batchdim,...,n,n) input matrices 19 | Output X_new: (batchdim,...,n,n) batch-normalized matrices 20 | arguments: 21 | shape: excluding the batch dim 22 | batchdim: the first k dims are batch dims, such as [0],[0,1]... 23 | is_detach: whether to use detach when calculating Fréchet statistics.
24 | parameters: 25 | self.weight: [...,n,n] manifold parameter 26 | self.shift: [...,1,1] scalar dispersion 27 | """ 28 | def __init__(self,shape, batchdim=[0], momentum=0.1,is_detach=True): 29 | super().__init__() 30 | self.shape=shape; self.momentum = momentum; self.batchdim=batchdim 31 | self.eps = 1e-5; 32 | self.manifold = None; 33 | self.is_detach=is_detach 34 | 35 | # Handle channel vs non-channel case 36 | if len(self.shape) > 2: 37 | # --- running statistics --- 38 | self.register_buffer("running_mean", th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1)) 39 | self.register_buffer("running_var", th.ones(*shape[:-2], 1, 1)) 40 | # --- parameters --- 41 | self.shift = nn.Parameter(th.ones(*shape[:-2], 1, 1)) 42 | else: 43 | # --- running statistics --- 44 | self.register_buffer("running_mean", th.eye(self.shape[-1])) 45 | self.register_buffer("running_var", th.ones(1)) 46 | # --- parameters --- 47 | self.shift = nn.Parameter(th.ones(1)) 48 | 49 | # biasing weight should be set specifically for each manifold 50 | # self.set_weight() 51 | 52 | def updating_running_statistics(self,batch_mean,batch_var): 53 | self.running_mean.data = self.manifold.geodesic(self.running_mean, batch_mean,self.momentum) 54 | self.running_var.data = (1 - self.momentum) * self.running_var + batch_var * self.momentum 55 | 56 | def set_weight(self): 57 | raise NotImplementedError 58 | 59 | def get_manifold(self): 60 | raise NotImplementedError 61 | 62 | def __repr__(self): 63 | attributes = [] 64 | for key, value in self.__dict__.items(): 65 | if key.startswith("_"): 66 | continue 67 | if isinstance(value, th.Tensor): 68 | attributes.append(f"{key}.shape={tuple(value.shape)}") # only show Tensor shape 69 | else: 70 | attributes.append(f"{key}={value}") 71 | 72 | # register_buffer 73 | for name, buffer in self.named_buffers(recurse=False): 74 | attributes.append(f"{name}={buffer.item() if buffer.numel() == 1 else tuple(buffer.shape)}") 75 | 76 | return f"{self.__class__.__name__}({', '.join(attributes)}) \n {self.manifold}" 77 | -------------------------------------------------------------------------------- /LieBN/LieBNCor.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024. 6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning. TIP 2024. 7 | 8 | Copyright (C) 2024 Ziheng Chen 9 | All rights reserved. 10 | """ 11 | 12 | import torch as th 13 | import geoopt 14 | from geoopt.manifolds import PoincareBall 15 | 16 | from .Geometry.Correlation.CorMatrices import CorEuclideanCholeskyMetric,CorLogEuclideanCholeskyMetric,CorOffLogMetric,CorLogScaledMetric 17 | from .Geometry.Correlation.ppbcm_functionals import poincare_to_hemisphere 18 | 19 | from . import LieBNBase 20 | 21 | class LieBNCor(LieBNBase): 22 | """ Implemented metrics: ECM,LECM,OLM,LSM; all are bi-invariant metrics with Abelian groups. 23 | param_mode: trivial,Riem 24 | 25 | Note: 26 | - Our experiments indicate that Riemannian optimization is better than trivialization for the correlation biasing parameters; 27 | - We do not use detach when calculating Fréchet statistics, as our experiments show it is better than detach. 28 | - is_inv_deformation: when it is False, this might help to simplify computations (for certain following layers).
29 | """ 30 | 31 | def __init__(self,shape ,batchdim=[0],momentum=0.1,is_detach=False, 32 | metric: str ='ECM',alpha=1.,param_1=0.,param_2=0.,max_iter=100, 33 | is_inv_deformation=True,param_mode='Riem'): 34 | super().__init__(shape, batchdim=batchdim, momentum=momentum,is_detach=is_detach) 35 | self.metric = metric; self.alpha = alpha;self.param_1 = param_1;self.param_2 = param_2; 36 | self.max_iter = max_iter;self.is_inv_deformation=is_inv_deformation 37 | self.param_mode=param_mode 38 | self.get_manifold() 39 | self.set_weight() 40 | 41 | def set_weight(self): 42 | """Initializes (n-1) separate Poincare vectors of increasing dimensions (1 to n-1) using the Poincare Ball manifold.""" 43 | # use poincare updates 44 | if self.param_mode=='Riem': 45 | self.weight = th.nn.ParameterList([ 46 | geoopt.ManifoldParameter( 47 | th.zeros(*self.shape[:-2], i) 48 | if len(self.shape) > 2 else th.zeros(i), # Handle channel vs non-channel case 49 | manifold=PoincareBall(c=1.0, learnable=False) 50 | ) 51 | for i in range(1, self.shape[-1]) # Each vector has i dimensions 52 | ]) 53 | elif self.param_mode=='trivial': 54 | # use identical metric updates 55 | self.weight = th.nn.Parameter(th.randn(self.shape)) 56 | 57 | def poincare_to_correlation(self,weight): 58 | """Converts stored Poincare vectors into a correlation matrix without extra memory storage.""" 59 | # Initialize lower-triangular Cholesky factor L 60 | L = th.zeros(self.shape, dtype=weight[0].dtype, device=weight[0].device) 61 | L[...,0, 0] = 1 # Set (0,0) explicitly to 1 62 | 63 | # Directly compute Hemisphere representations and assign to L 64 | for i in range(1, self.shape[-1]): 65 | L[...,i, :i+1] = poincare_to_hemisphere(weight[i - 1]) # No extra storage, direct assignment 66 | return L @ L.transpose(-1,-2) 67 | 68 | def Euc2Codomain(self,weight): 69 | if self.metric in ['ECM','LECM']: 70 | weight_codomain = weight.tril(-1) 71 | elif self.metric == 'OLM': 72 | weight_codomain = weight.tril(-1) + weight.tril(-1).transpose(-1,-2) 73 | elif self.metric == 'LSM': 74 | n = self.shape[-1] 75 | weight_codomain = th.zeros_like(weight) 76 | weight_codomain[..., :n - 1, :n - 1] = weight.tril(-1)[..., 1:n , :n-1] 77 | # **Symmetrization**: Ensure the upper-left part of Row0 is a symmetric matrix 78 | weight_codomain = weight_codomain.tril() + weight_codomain.tril(-1).transpose(-1, -2) 79 | # **Step 2**: Compute the last column so that the sum of each row is 0 80 | weight_codomain[..., :-1, -1] = -th.sum(weight_codomain[..., :-1, :-1], dim=-1) 81 | # **Step 3**: Compute the `[n, n]` element so that the last row can be recovered symmetrically 82 | weight_codomain[..., -1, -1] = -th.sum(weight_codomain[..., :-1, -1], dim=-1) 83 | # **Step 4**: Restore the last row to be equal to the transpose of the last column 84 | weight_codomain[..., -1, :-1] = weight_codomain[..., :-1, -1] 85 | return weight_codomain 86 | 87 | def forward(self,X): 88 | #deformation 89 | X_deformed = self.manifold.deformation(X) 90 | if self.param_mode == 'Riem': 91 | weight = self.manifold.deformation(self.poincare_to_correlation(self.weight)) 92 | elif self.param_mode == 'trivial': 93 | weight = self.Euc2Codomain(self.weight) 94 | 95 | if(self.training): 96 | # centering 97 | batch_mean = self.manifold.cal_geom_mean(X_deformed,batchdim=self.batchdim) 98 | X_centered = self.manifold.translation(X_deformed,batch_mean,is_inverse=True) 99 | 100 | # scaling and shifting 101 | # Note that as the mean is calculated in closed-form, cal_geom_var (by dist(X_i,I)) is accurate 102 | 
batch_var = self.manifold.cal_geom_var(X_centered,batchdim=self.batchdim) 103 | factor = self.shift / (batch_var + self.eps).sqrt() 104 | X_scaled = self.manifold.scaling(X_centered, factor) 105 | self.updating_running_statistics(batch_mean, batch_var) 106 | else: 107 | # centering, scaling and shifting 108 | X_centered = self.manifold.translation(X_deformed, self.running_mean, is_inverse=True) 109 | factor = self.shift / (self.running_var + self.eps).sqrt() 110 | X_scaled = self.manifold.scaling(X_centered, factor) 111 | #biasing 112 | X_normalized = self.manifold.translation(X_scaled, weight, is_inverse=False) 113 | # inv_deformation 114 | if self.is_inv_deformation: 115 | X_new = self.manifold.inv_deformation(X_normalized) 116 | else: 117 | X_new = X_normalized 118 | return X_new 119 | 120 | def get_manifold(self): 121 | # ECM,LECM,OLM,LSM 122 | classes = { 123 | "ECM": CorEuclideanCholeskyMetric, 124 | "LECM": CorLogEuclideanCholeskyMetric, 125 | "OLM": CorOffLogMetric, 126 | "LSM": CorLogScaledMetric, 127 | } 128 | 129 | if self.metric == 'OLM': 130 | self.manifold = classes[self.metric](n=self.shape[-1], 131 | alpha=self.alpha,beta=self.param_1, gamma=self.param_2, 132 | max_iter=self.max_iter) 133 | elif self.metric == 'LSM': 134 | self.manifold = classes[self.metric](n=self.shape[-1], 135 | alpha=self.alpha, delta=self.param_1,zeta=self.param_2, 136 | max_iter=self.max_iter) 137 | elif self.metric in ['ECM','LECM']: 138 | self.manifold = classes[self.metric](n=self.shape[-1]) 139 | else: 140 | raise NotImplementedError 141 | 142 | self.manifold.is_detach = self.is_detach 143 | 144 | -------------------------------------------------------------------------------- /LieBN/LieBNRot.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024. 6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024. 7 | 8 | Copyright (C) 2024 Ziheng Chen 9 | All rights reserved. 10 | """ 11 | 12 | import torch as th 13 | import geoopt 14 | from .Geometry.Rotations import RotMatrices 15 | from . import LieBNBase 16 | 17 | class LieBNRot(LieBNBase): 18 | """ LieBN on the 3 x 3 rotations under the bi-invariant metric. 19 | is_left: True for left translation for centering and biasing. 20 | is_detach in RotMatrices is set to True by default. 
21 | """ 22 | def __init__(self,shape ,batchdim=[0,1],momentum=0.1,is_detach=True, 23 | karcher_steps=1,is_left=True): 24 | super().__init__(shape,batchdim,momentum,is_detach) 25 | self.karcher_steps = karcher_steps;self.is_left=is_left 26 | self.get_manifold() 27 | self.set_weight() 28 | 29 | def set_weight(self): 30 | if len(self.shape) > 2: 31 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1), 32 | manifold=RotMatrices()) 33 | else: 34 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]),manifold=RotMatrices()) 35 | 36 | def forward(self,S): 37 | if(self.training): 38 | batch_mean = self.manifold.cal_geom_mean(S,batchdim=self.batchdim) 39 | X_centered = self.manifold.translation(S,batch_mean,is_inverse=True,is_left=self.is_left) 40 | X_scaled,var = self.manifold.scaling(X_centered,shift=self.shift,batchdim=self.batchdim) 41 | self.updating_running_statistics(batch_mean, var) 42 | else: 43 | X_centered = self.manifold.translation(S, self.running_mean, is_inverse=True,is_left=self.is_left) 44 | X_scaled,_ = self.manifold.scaling(X_centered,shift=self.shift,running_var=self.running_var) 45 | X_normalized = self.manifold.translation(X_scaled, self.weight, is_inverse=False,is_left=self.is_left) 46 | 47 | return X_normalized 48 | 49 | def get_manifold(self): 50 | self.manifold = RotMatrices(karcher_steps=self.karcher_steps) 51 | self.manifold.is_detach = self.is_detach 52 | 53 | -------------------------------------------------------------------------------- /LieBN/LieBNSPD.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024. 6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning. TIP 2024. 7 | 8 | Copyright (C) 2024 Ziheng Chen 9 | All rights reserved. 10 | """ 11 | 12 | import torch as th 13 | import geoopt 14 | from geoopt.manifolds import SymmetricPositiveDefinite 15 | 16 | from .Geometry.SPD.SPDMatrices import SPDLogEuclideanMetric,SPDAdaptiveLogEuclideanMetric,SPDLogCholeskyMetric,\ 17 | SPDAffineInvariantMetric,SPDCholeskyRightInvariantMetric,tril_param_metric,bi_param_metric,single_param_metric 18 | 19 | from . import LieBNBase 20 | 21 | class LieBNSPD(LieBNBase): 22 | """ Arguments: 23 | metric: LEM, ALEM, LCM, AIM, CRIM, where the last one is right-invariant 24 | karcher_steps=1: for AIM and CRIM 25 | power=1: power deformation for AIM, CRIM, and LCM 26 | (alpha, beta)=(1,0): O(n)-invariant inner product in AIM, CRIM, LEM, and ALEM. 27 | 28 | Note: 29 | - The SPD biasing parameter is optimized by AIM-based geoopt. 30 | - For ALEM, the parameter A is in the SPDAdaptiveLogEuclideanMetric. 31 | - There are two ways to calculate the variance: 1 via d(P_i, M); 2 via d(\bar{P}_i, I). 32 | If the Fréchet mean is exact, as under LEM, LCM, and ALEM, these two ways are equivalent. 33 | For AIM and CRIM, as we set karcher_steps=1, the 2nd is an approximation. We adopt the 2nd one for efficiency. 34 | If one lets the Karcher flow converge, 1 is equivalent to 2.
35 | - Following SPDNetBN, we use detach when calculating Fréchet statistics 36 | """ 37 | 38 | def __init__(self,shape ,batchdim=[0],momentum=0.1,is_detach=True, 39 | metric: str ='AIM',power=1.,alpha=1.0,beta=0.,karcher_steps=1,): 40 | super().__init__(shape,batchdim,momentum,is_detach) 41 | self.metric = metric; self.power = power;self.alpha = alpha;self.beta = beta; 42 | self.karcher_steps = karcher_steps 43 | self.get_manifold() 44 | self.set_weight() 45 | 46 | def set_weight(self): 47 | if len(self.shape) > 2: 48 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1), 49 | manifold=SymmetricPositiveDefinite()) 50 | else: 51 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]), 52 | manifold=SymmetricPositiveDefinite()) 53 | 54 | def forward(self,X): 55 | #deformation 56 | X_deformed = self.manifold.deformation(X) 57 | weight = self.manifold.deformation(self.weight) 58 | 59 | if(self.training): 60 | # centering 61 | batch_mean = self.manifold.cal_geom_mean(X_deformed,batchdim=self.batchdim) 62 | X_centered = self.manifold.translation(X_deformed,batch_mean,is_inverse=True) 63 | 64 | # scaling and shifting 65 | # As centering is an isometry, the batch variance equals that of the centered data (w.r.t. the identity element). 66 | # This is more efficient than calculating the original batch variance. 67 | # Note that if the mean is calculated approximately, e.g., by the Karcher flow, cal_geom_var is also approximate. 68 | # One can also try the original batch variance. 69 | batch_var = self.manifold.cal_geom_var(X_centered,batchdim=self.batchdim) 70 | factor = self.shift / (batch_var + self.eps).sqrt() 71 | X_scaled = self.manifold.scaling(X_centered, factor) 72 | self.updating_running_statistics(batch_mean, batch_var) 73 | 74 | else: 75 | # centering, scaling and shifting 76 | X_centered = self.manifold.translation(X_deformed, self.running_mean, is_inverse=True) 77 | factor = self.shift / (self.running_var + self.eps).sqrt() 78 | X_scaled = self.manifold.scaling(X_centered, factor) 79 | #biasing 80 | X_normalized = self.manifold.translation(X_scaled, weight, is_inverse=False) 81 | # inv_deformation 82 | X_new = self.manifold.inv_deformation(X_normalized) 83 | 84 | return X_new 85 | 86 | def get_manifold(self): 87 | classes = { 88 | "LEM": SPDLogEuclideanMetric, 89 | "ALEM": SPDAdaptiveLogEuclideanMetric, 90 | "LCM": SPDLogCholeskyMetric, 91 | "AIM": SPDAffineInvariantMetric, 92 | "CRIM": SPDCholeskyRightInvariantMetric, 93 | } 94 | n=self.shape[-1] 95 | if self.metric in tril_param_metric: 96 | self.manifold = classes[self.metric](n=n, power=self.power,alpha=self.alpha, beta=self.beta, 97 | karcher_steps=self.karcher_steps) 98 | elif self.metric in bi_param_metric: 99 | self.manifold = classes[self.metric](n=n, alpha=self.alpha, beta=self.beta) 100 | elif self.metric in single_param_metric: 101 | self.manifold = classes[self.metric](n=n, power=self.power) 102 | else: 103 | raise NotImplementedError 104 | 105 | self.manifold.is_detach=self.is_detach 106 | 107 | -------------------------------------------------------------------------------- /LieBN/LieBN_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/LieBN_illustration.png -------------------------------------------------------------------------------- /LieBN/README.md: -------------------------------------------------------------------------------- 1 |
[arXiv](https://arxiv.org/abs/2403.11261) 2 | [OpenReview](https://openreview.net/forum?id=okYdj8Ysru) 3 | [PDF](https://openreview.net/pdf?id=okYdj8Ysru) 4 | 5 | # Lie Group Batch Normalization 6 | 7 |
8 | ![LieBN illustration](LieBN_illustration.png)
9 | Figure 1: Visualization of LieBN on the SPD, rotation, and correlation manifolds.
10 | 
11 | 12 | ## Introduction 13 | This is the official implementation for our ICLR 2024 publications: 14 | - *A Lie Group Approach to Riemannian Batch Normalization* 15 | [[OpenReview](https://openreview.net/forum?id=okYdj8Ysru)] 16 | 17 | - *LieBN: Batch Normalization over Lie Groups* (Conference extension, in submission) 18 | 19 | If you find this project helpful, please consider citing us as follows: 20 | 21 | ```bib 22 | @inproceedings{chen2024liebn, 23 | title={A Lie Group Approach to Riemannian Batch Normalization}, 24 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe}, 25 | booktitle={ICLR}, 26 | year={2024}, 27 | } 28 | ``` 29 | 30 | If you have any problem, please contact me via ziheng_ch@163.com. 31 | 32 | ## Implementations 33 | This source code contains LieBN on the following manifolds: 34 | - SPD manifolds: Log-Euclidean Metric (LEM), Affine-Invariant Metric (AIM), Log-Cholesky Metric (LCM), and our proposed Cholesky Right Invariant Metric (CRIM); 35 | - Rotation groups: the canonical bi-invariant metric; 36 | - Full-rank correlation manifolds: Euclidean-Cholesky Metric (ECM), Log-Euclidean-Cholesky Metric (LECM), Off-Log Metric (OLM), and Log-Scaled Metric (LSM). 37 | 38 |
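For a quick start, here is a minimal usage sketch of the SPD layer, condensed from `./demos/liebnspd.py` (the shapes and the `LCM` metric choice are only illustrative):

```py
import torch as th
import geoopt
from LieBN import LieBNSPD
from LieBN.Geometry.SPD import SPDMatrices

bs, c, n = 4, 2, 5                                      # batch size, channels, matrix size
P = SPDMatrices(n=n).random(bs, c, n, n).to(th.double)  # random SPD inputs

# metric can be LEM, ALEM, LCM, AIM, or CRIM
liebn = LieBNSPD(shape=[c, n, n], metric="LCM", batchdim=[0]).to(th.double)

X_train = liebn(P)      # training mode: batch statistics are computed and the running ones updated
liebn.eval()
with th.no_grad():
    X_test = liebn(P)   # testing mode: running statistics are used instead

# The biasing weight is a manifold parameter, so use a Riemannian optimizer, e.g.:
optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3)
```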
39 | 
40 | ![Summary of invariant metrics](sum_metrics.png)
41 | Figure 2: Summary of Invariant Metrics.
42 | 
43 | 
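The correlation (and rotation) layers follow the same interface; a corresponding sketch for correlation matrices, condensed from `./demos/liebncor.py` (again with illustrative shapes):

```py
import torch as th
from LieBN import LieBNCor
from LieBN.Geometry.Correlation import Correlation

bs, c, n = 4, 2, 5                                      # batch size, channels, matrix size
P = Correlation(n=n).random(bs, c, n, n).to(th.double)  # random full-rank correlation matrices

# metric can be ECM, LECM, OLM, or LSM; is_detach defaults to False here (see the notes below)
liebn = LieBNCor(shape=[c, n, n], metric="OLM", batchdim=[0]).to(th.double)
X_new = liebn(P)                                        # batch-normalized correlation matrices
```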
44 | 45 | **Notes:** 46 | - By default, LieBN uses `.detach()` (`is_detach=True`) when computing Fréchet statistics on SPD and rotation matrices, following prior work (SPDNetBN). 47 | - For correlation matrices, however, LieBN sets `is_detach=False` by default, as our experiments suggest this yields better results. 48 | 49 | ## Requirements 50 | Requirements: `torch`, `geoopt`, and `pytorch3d`. 51 | 52 | ## Demos 53 | Demos of LieBN on different geometries can be found in `./demos`. 54 | 55 | -------------------------------------------------------------------------------- /LieBN/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024. 6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning. TIP 2024. 7 | 8 | Copyright (C) 2024 Ziheng Chen 9 | All rights reserved. 10 | """ 11 | 12 | from .LieBNBase import LieBNBase 13 | from .LieBNCor import LieBNCor 14 | from .LieBNSPD import LieBNSPD 15 | from .LieBNRot import LieBNRot -------------------------------------------------------------------------------- /LieBN/__pycache__/LieBNBase.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNBase.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/__pycache__/LieBNCor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNCor.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/__pycache__/LieBNRot.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNRot.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/__pycache__/LieBNSPD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNSPD.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LieBN/demos/liebncor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo for LieBN on Correlation Matrices 3 | 4 | @author: Ziheng Chen 5 | Please cite: 6 | 7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. 8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 9 | 10 | Copyright (C) 2024 Ziheng Chen 11 | All rights reserved.
12 | """ 13 | 14 | import torch as th 15 | import geoopt 16 | from LieBN import LieBNCor 17 | from LieBN.Geometry.Correlation import Correlation 18 | 19 | # Set the random seed for reproducibility 20 | SEED = 42 21 | th.manual_seed(SEED) 22 | 23 | 24 | # Helper function to generate random correlation matrices 25 | def random_correlation_matrix(bs, c, n): 26 | corr_module = Correlation(n) 27 | return corr_module.random_cor([bs, c, n, n]).to(th.double) 28 | 29 | 30 | # Set dimensions 31 | bs, c, n = 4, 2, 5 # Batch size, channels, matrix size 32 | shape = [c, n, n] 33 | manifold = Correlation(n=n) 34 | 35 | # Instantiate correlation metric module and generate random input 36 | P = manifold.random(bs, c, n,n).requires_grad_(True).to(th.double) # Input correlation matrices 37 | target = manifold.random(bs, c, n,n).to(th.double) # Target correlation matrices for loss computation 38 | 39 | # Instantiate LieBN for Correlation Manifold (ECM, LECM, OLM, LSM) 40 | liebn = LieBNCor(shape=shape, metric='OLM',batchdim=[0]).to(th.double) 41 | 42 | print("\nLieBNCor Layer Initialized:", liebn) 43 | print("Manifold:", liebn.manifold) 44 | 45 | print("\n=== LieBNCor Parameters ===") 46 | for name, param in liebn.named_parameters(): 47 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}") 48 | 49 | # Define optimizer and loss function 50 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3) 51 | criterion = th.nn.MSELoss() 52 | 53 | # Print initial loss 54 | with th.no_grad(): 55 | initial_loss = criterion(liebn(P), target).item() 56 | print(f"\nInitial Loss: {initial_loss:.6f}") 57 | 58 | # Training loop 59 | num_epochs = 5 # Set higher if needed 60 | for i in range(num_epochs): 61 | liebn.train() # Set to training mode 62 | optimizer.zero_grad() # Clear previous gradients 63 | output = liebn(P) 64 | # Compute loss 65 | loss = criterion(output, target) 66 | 67 | # Backpropagation 68 | loss.backward() 69 | 70 | # Gradient Norm Check (Optional) 71 | grad_norm = th.nn.utils.clip_grad_norm_(liebn.parameters(), max_norm=1.0) 72 | 73 | optimizer.step() 74 | 75 | # Print loss every iteration 76 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f} | Grad Norm: {grad_norm:.6f}") 77 | 78 | print("\nProcessed Correlation Matrices in Training Mode:", output.shape) 79 | print(f"Final Training Loss: {loss.item():.6f}") 80 | 81 | # Evaluation mode 82 | liebn.eval() 83 | with th.no_grad(): 84 | test_output = liebn(P) 85 | 86 | print("\nProcessed Correlation Matrices in Testing Mode:", test_output.shape) 87 | 88 | print("\nTraining and Evaluation completed successfully! 🚀") 89 | -------------------------------------------------------------------------------- /LieBN/demos/liebnrot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo for LieBN on SO(3) matrices 3 | 4 | @author: Ziheng Chen 5 | Please cite: 6 | 7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. 8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 9 | 10 | Copyright (C) 2024 Ziheng Chen 11 | All rights reserved. 
12 | """ 13 | 14 | import torch as th 15 | import geoopt 16 | from LieBN import LieBNRot 17 | from LieBN.Geometry.Rotations import RotMatrices 18 | 19 | # Set the random seed for reproducibility 20 | SEED = 42 21 | th.manual_seed(SEED) 22 | 23 | # Batch size, frame, space (SO(3) are 3x3 matrices) 24 | # In LieNet, the shape is [bs,f,s,3,3] 25 | bs, f, s = 5, 2, 3 26 | manifold = RotMatrices() 27 | 28 | # Generate random SO(3) matrices as inputs and targets 29 | input_rot = manifold.random(bs,f, s,3,3).to(th.double).requires_grad_(True).to(th.double) # Input SO(3) matrices 30 | target_rot = manifold.random(bs,f, s,3,3).to(th.double).to(th.double) # Target matrices for loss computation 31 | 32 | # Instantiate LieBN for rotation matrices 33 | # is_left=False/True 34 | # liebn = LieBNRot(shape=[s,3,3], batchdim=[0, 1], is_left=False,karcher_steps=100).to(th.double) 35 | liebn = LieBNRot(shape=[s,3,3], batchdim=[0,1], is_left=False).to(th.double) 36 | 37 | print("\nLieBNRot Layer Initialized:", liebn) 38 | print("Manifold:", liebn.manifold) 39 | 40 | print("\n=== LieBNRot Parameters ===") 41 | for name, param in liebn.named_parameters(): 42 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}") 43 | 44 | # Define optimizer and loss function 45 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3) 46 | criterion = th.nn.MSELoss() 47 | 48 | # Training loop 49 | num_epochs = 2 # Set higher if needed 50 | for i in range(num_epochs): 51 | liebn.train() # Set to training mode 52 | optimizer.zero_grad() # Clear previous gradients 53 | 54 | # Forward pass 55 | output_rot = liebn(input_rot) 56 | 57 | # Compute loss 58 | loss = criterion(output_rot, target_rot) 59 | 60 | # Backpropagation 61 | loss.backward() 62 | optimizer.step() 63 | 64 | # Print loss every iteration 65 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f}") 66 | 67 | print("\nProcessed SO(3) Matrices in Training Mode:", output_rot.shape) 68 | print(f"Final Training Loss: {loss.item():.6f}") 69 | 70 | # Evaluation mode 71 | liebn.eval() 72 | with th.no_grad(): 73 | test_output = liebn(input_rot) 74 | 75 | print("\nProcessed SO(3) Matrices in Testing Mode:", test_output.shape) 76 | 77 | print("\nTraining and Evaluation completed successfully! 🚀") 78 | -------------------------------------------------------------------------------- /LieBN/demos/liebnspd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo for LieBN on SPD Matrices 3 | 4 | @author: Ziheng Chen 5 | Please cite: 6 | 7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. 8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 9 | 10 | Copyright (C) 2024 Ziheng Chen 11 | All rights reserved. 
12 | """ 13 | 14 | import torch as th 15 | import geoopt 16 | from LieBN import LieBNSPD 17 | from LieBN.Geometry.SPD import SPDMatrices 18 | 19 | # Set the random seed for reproducibility 20 | SEED = 42 21 | th.manual_seed(SEED) 22 | 23 | 24 | # Batch size, channels, SPD matrix size 25 | # In SPDNet and SPDNetBN, the shape is [bs,1,n,n] 26 | bs, c, n = 4, 2, 5 27 | 28 | # Generate input SPD matrices and target SPD matrices 29 | manifold=SPDMatrices(n=n) 30 | P = manifold.random(bs, c, n,n).requires_grad_(True).to(th.double) # Input SPD matrices 31 | target = manifold.random(bs, c, n,n).to(th.double) # Target SPD matrices for loss computation 32 | 33 | # LEM,ALEM,LCM,AIM,CRIM 34 | liebn = LieBNSPD(shape=[c, n, n], metric="LCM", batchdim=[0]).to(th.double) 35 | 36 | print("\nLieBNSPD Layer Initialized:", liebn) 37 | 38 | print("\n=== LieBNSPD Parameters ===") 39 | for name, param in liebn.named_parameters(): 40 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}") 41 | 42 | # Define optimizer (Adam) and loss function (MSE) 43 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3) 44 | criterion = th.nn.MSELoss() 45 | 46 | # Print initial loss 47 | with th.no_grad(): 48 | initial_loss = criterion(liebn(P), target).item() 49 | print(f"\nInitial Loss: {initial_loss:.6f}") 50 | 51 | # Training loop 52 | num_epochs = 2 # Increase if needed 53 | for i in range(num_epochs): 54 | liebn.train() # Set to training mode 55 | optimizer.zero_grad() # Clear previous gradients 56 | 57 | # Forward pass 58 | output = liebn(P) 59 | 60 | # Compute loss 61 | loss = criterion(output, target) 62 | 63 | # Backpropagation 64 | loss.backward() 65 | 66 | # Gradient Norm Check (Optional) 67 | grad_norm = th.nn.utils.clip_grad_norm_(liebn.parameters(), max_norm=1.0) 68 | 69 | optimizer.step() 70 | 71 | # Print loss every iteration 72 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f} | Grad Norm: {grad_norm:.6f}") 73 | 74 | print("\nProcessed SPD Matrices in Training Mode:", output.shape) 75 | print(f"Final Training Loss: {loss.item():.6f}") 76 | 77 | # Evaluation mode 78 | liebn.eval() 79 | with th.no_grad(): 80 | test_output = liebn(P) 81 | 82 | print("\nProcessed SPD Matrices in Testing Mode:", test_output.shape) 83 | 84 | print("\nTraining and Evaluation completed successfully! 
🚀") 85 | -------------------------------------------------------------------------------- /LieBN/sum_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/sum_metrics.png -------------------------------------------------------------------------------- /LieBN_SPDNet/Network/Get_Model.py: -------------------------------------------------------------------------------- 1 | from Network.SPDNetRBN import SPDNetLieBN,SPDNet 2 | 3 | 4 | def get_model(args): 5 | if args.model_type in args.total_BN_model_types: 6 | model = SPDNetLieBN(args) 7 | elif args.model_type=='SPDNet': 8 | model = SPDNet(args) 9 | else: 10 | raise Exception('unknown model {} or metric {}'.format(args.model_type,args.metric)) 11 | return model -------------------------------------------------------------------------------- /LieBN_SPDNet/Network/SPDNetRBN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import cplx.nn as nn_cplx 3 | import spd.nn as nn_spd 4 | from spd.LieBN import LieBatchNormSPD 5 | 6 | class SPDNetLieBN(nn.Module): 7 | def __init__(self,args): 8 | super(__class__, self).__init__() 9 | dims = [int(dim) for dim in args.architecture] 10 | self.feature = [] 11 | if args.dataset == 'RADAR': 12 | self.feature.append(nn_cplx.SplitSignal_cplx(2, 20, 10)) 13 | self.feature.append(nn_cplx.CovPool_cplx()) 14 | self.feature.append(nn_spd.ReEig()) 15 | 16 | for i in range(len(dims) - 2): 17 | self.feature.append(nn_spd.BiMap(1,1,dims[i], dims[i + 1])) 18 | self.feature.append(SPDBN(dims[i + 1],args)) 19 | self.feature.append(nn_spd.ReEig()) 20 | 21 | self.feature.append(nn_spd.BiMap(1,1,dims[-2], dims[-1])) 22 | self.feature.append(SPDBN(dims[-1],args)) 23 | self.feature = nn.Sequential(*self.feature) 24 | 25 | self.classifier = LogEigMLR(dims[-1]**2,args.class_num) 26 | 27 | def forward(self, x): 28 | x_spd = self.feature(x) 29 | y = self.classifier(x_spd) 30 | return y 31 | 32 | # LogEig MLR 33 | class LogEigMLR(nn.Module): 34 | def __init__(self, input_dim, classnum): 35 | super(__class__, self).__init__() 36 | self.logeig = nn_spd.LogEig() 37 | self.linear = nn.Linear(input_dim, classnum).double() 38 | 39 | def forward(self, x): 40 | x_vec = self.logeig(x).view(x.shape[0], -1) 41 | y = self.linear(x_vec) 42 | return y 43 | 44 | class SPDBN(nn.Module): 45 | def __init__(self, n,args,ddevice='cpu'): 46 | super(__class__, self).__init__() 47 | if args.BN_type == 'brooks': 48 | self.BN = nn_spd.BatchNormSPD(n,args.momentum) 49 | elif args.BN_type == 'LieBN': 50 | self.BN = LieBatchNormSPD(n, 51 | metric=args.metric, 52 | theta=args.theta, alpha=args.alpha, beta=args.beta,momentum=args.momentum) 53 | else: 54 | raise Exception('unknown BN {}'.format(args.BN_type)) 55 | 56 | def forward(self, x): 57 | x_spd = self.BN(x) 58 | return x_spd 59 | 60 | class SPDNet(nn.Module): 61 | def __init__(self,args): 62 | super(__class__, self).__init__() 63 | dims = [int(dim) for dim in args.architecture] 64 | self.feature = [] 65 | if args.dataset == 'RADAR': 66 | self.feature.append(nn_cplx.SplitSignal_cplx(2, 20, 10)) 67 | self.feature.append(nn_cplx.CovPool_cplx()) 68 | self.feature.append(nn_spd.ReEig()) 69 | 70 | for i in range(len(dims) - 2): 71 | self.feature.append(nn_spd.BiMap(1,1,dims[i], dims[i + 1])) 72 | self.feature.append(nn_spd.ReEig()) 73 | 74 | self.feature.append(nn_spd.BiMap(1,1,dims[-2], dims[-1])) 75 | 
self.feature = nn.Sequential(*self.feature) 76 | 77 | self.classifier = LogEigMLR(dims[-1]**2,args.class_num) 78 | 79 | def forward(self, x): 80 | x_spd = self.feature(x) 81 | y = self.classifier(x_spd) 82 | return y 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /LieBN_SPDNet/Network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_SPDNet/Network/__init__.py -------------------------------------------------------------------------------- /LieBN_SPDNet/SPDNetLieBN.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 6 | 7 | Copyright (C) 2024 Ziheng Chen 8 | All rights reserved. 9 | """ 10 | 11 | import hydra 12 | from omegaconf import DictConfig 13 | 14 | from spd.training_script import training 15 | 16 | class Args: 17 | """ a Struct Class """ 18 | pass 19 | args=Args() 20 | args.config_name='LieBN.yaml' 21 | 22 | 23 | args.total_BN_model_types=['SPDNetLieBN', 'SPDNetBN'] 24 | args.total_LieBN_model_types=['SPDNetLieBN'] 25 | 26 | @hydra.main(config_path='./conf/', config_name=args.config_name, version_base='1.1') 27 | def main(cfg: DictConfig): 28 | training(cfg,args) 29 | 30 | if __name__ == "__main__": 31 | main() -------------------------------------------------------------------------------- /LieBN_SPDNet/conf/LieBN.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - nnet: SPDNetLieBN # SPDNet SPDNetLieBN SPDNetBN 4 | - dataset: RADAR # HDM05, RADAR, FPHA 5 | - override hydra/launcher: joblib 6 | fit: 7 | epochs: 200 8 | batch_size: 30 9 | threadnum: 1 10 | is_writer: True 11 | cycle: 1 12 | seed: 1024 13 | is_save: True 14 | 15 | hydra: 16 | run: 17 | dir: ./outputs/${dataset.name} 18 | sweep: 19 | dir: ./outputs/${dataset.name} 20 | subdir: '.' 21 | launcher: 22 | n_jobs: -1 23 | job_logging: 24 | handlers: 25 | file: 26 | class: logging.FileHandler 27 | filename: default.log 28 | -------------------------------------------------------------------------------- /LieBN_SPDNet/conf/dataset/HDM05.yaml: -------------------------------------------------------------------------------- 1 | name: HDM05 2 | architecture: [93,30] 3 | path: /data #change this to your data folder 4 | -------------------------------------------------------------------------------- /LieBN_SPDNet/conf/dataset/RADAR.yaml: -------------------------------------------------------------------------------- 1 | name: RADAR 2 | architecture: [20,16,12] 3 | path: /data #change this to your data folder -------------------------------------------------------------------------------- /LieBN_SPDNet/conf/nnet/SPDNet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: SPDNet 3 | optimizer: 4 | mode: AMSGRAD 5 | lr: 5e-3 6 | weight_decay: 0. 
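Note on configuration: SPDNetLieBN.py above is a thin Hydra entry point; conf/LieBN.yaml selects one nnet group and one dataset group at launch, and the nnet variants form a defaults chain (SPDNetBN.yaml extends SPDNet.yaml, and SPDNetLieBN.yaml extends SPDNetBN.yaml, as listed next). Below is a minimal sketch — not part of the repository — of composing the same merged config programmatically, reusing the version_base='1.1' that the entry scripts pass to @hydra.main:

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the experiment config without launching a Hydra app; the overrides
# mirror the command lines in experiments.sh (e.g. `nnet=SPDNetLieBN`).
with initialize(config_path="conf", version_base="1.1"):
    cfg = compose(config_name="LieBN.yaml",
                  overrides=["nnet=SPDNetLieBN", "dataset=RADAR"])

print(OmegaConf.to_yaml(cfg.nnet.model))
# In the merged result, model_type/BN_type/momentum come from the SPDNetBN
# layer of the chain, and metric/theta/alpha/beta from SPDNetLieBN.yaml.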
-------------------------------------------------------------------------------- /LieBN_SPDNet/conf/nnet/SPDNetBN.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - SPDNet 3 | - _self_ 4 | model: 5 | model_type: SPDNetBN 6 | BN_type: brooks 7 | momentum: 0.1 -------------------------------------------------------------------------------- /LieBN_SPDNet/conf/nnet/SPDNetLieBN.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - SPDNetBN 3 | - _self_ 4 | model: 5 | model_type: SPDNetLieBN 6 | BN_type: LieBN 7 | metric: AIM # AIM,LEM,LCM 8 | theta: 1.0 9 | alpha: 1.0 10 | beta: 0.0 -------------------------------------------------------------------------------- /LieBN_SPDNet/cplx/functional.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | def conv_cplx(X,conv_re,conv_im): 4 | if(isinstance(X,list)): 5 | XX=[x.split(x.shape[0]//2,0) for x in X] 6 | tmp=X[0].split(X[0].shape[0]//2,0) 7 | tmpp=conv_re(tmp[0][None,...])-conv_im(tmp[1][None,...]) 8 | P_Re=[conv_re(xx[0][None,...])-conv_im(xx[1][None,...]) for xx in XX] 9 | P_Im=[conv_im(xx[0][None,...])+conv_re(xx[1][None,...]) for xx in XX] 10 | P=[th.cat((p_Re,p_Im),0) for p_Re,p_Im in zip(P_Re,P_Im)] 11 | else: 12 | XX=X.split(X.shape[1]//2,1) 13 | P_Re=conv_re(XX[0])-conv_im(XX[1]) 14 | P_Im=conv_im(XX[0])+conv_re(XX[1]) 15 | P=th.cat((P_Re,P_Im),1) 16 | return P 17 | 18 | def split_signal_cplx(X,conv_re,conv_im): 19 | ''' 20 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 21 | ''' 22 | XX=X.split(X.shape[1]//2,1) 23 | P_Re=conv_re(XX[0]) 24 | P_Im=conv_re(XX[1]) 25 | P=th.cat((P_Re[:,None,:,:],P_Im[:,None,:,:]),1) 26 | return P 27 | 28 | def roll(X): 29 | return th.cat((X[:,:,X.shape[-2]//2:,:],X[:,:,:X.shape[-2]//2,:]),2) 30 | 31 | def decibel(X): 32 | # n_fft=X.shape[-2]//2 33 | # X_Re=X[:,:n_fft,:] 34 | # X_Im=X[:,n_fft:,:] 35 | if(isinstance(X,list)): 36 | absX=AbsolutSquared()(X) 37 | X_db=[10*th.log(x) for x in absX] 38 | else: 39 | X_db=10*th.log(absolut_squared(X)) 40 | return X_db 41 | 42 | def absolut_squared(X): 43 | ''' 44 | Inputs a (N,2*C,H,W) complex tensor 45 | Outputs a (N,C,H,W) real tensor containing the input's squared module 46 | ''' 47 | XX=X.split(X.shape[1]//2,1) 48 | return (XX[0]**2+XX[1]**2) 49 | 50 | def oneD2twoD_cplx(x): 51 | ''' 52 | Inputs a 3D complex tensor (N,2*n,T) 53 | Outputs a 4D complex tensor (N,2,n,T) 54 | ''' 55 | return x.view(x.shape[0],2,-1,x.shape[-1]) 56 | 57 | def batch_norm2d_cplx(X,running_mean,running_var,gamma11,gamma12,gamma22,bias,momentum,training): 58 | N,C,H,W=X.shape 59 | XX=list(X.split(C//2,1)) 60 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 61 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 62 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW) 63 | if(training): 64 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,) 65 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2) 66 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,) 67 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,) 68 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre) 69 | with th.no_grad(): 70 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2) 71 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2) 72 | # 
cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:] 73 | cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre) 74 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None])) 75 | else: 76 | # running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:] 77 | running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0]) 78 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None])) 79 | weight=utils.twobytwo_covmat_from_coeffs(gamma11,gamma22,gamma12) 80 | Z=weight.matmul(Y)+bias[:,:,None] 81 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1) 82 | 83 | def batch_norm2d_cplx_spd(X,running_mean,running_var,weight,bias,momentum,training): 84 | N,C,H,W=X.shape 85 | XX=list(X.split(C//2,1)) 86 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 87 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 88 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW) 89 | if(training): 90 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,) 91 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2) 92 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,) 93 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,) 94 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre) 95 | with th.no_grad(): 96 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2) 97 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2) 98 | cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:] 99 | # cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre) 100 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None])) 101 | else: 102 | running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:] 103 | # running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0]) 104 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None])) 105 | weight=utils.twobytwo_covmat_from_coeffs(_gamma11,_gamma22,_gamma12) ##################### CHANGE 106 | Z=weight.matmul(Y)+bias[:,:,None] 107 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1) 108 | 109 | 110 | def cov_pool_cplx(f,reg_mode='mle',N_estimates=None): 111 | """ 112 | Input f: Temporal n-dimensionnal complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T) 113 | Output X: Complex covariance matrix of size (batch_size,2,n,n) 114 | """ 115 | N,_,n,T=f.shape 116 | ff=f.split(f.shape[1]//2,1) 117 | f_re=ff[0]; f_im=ff[1] 118 | if(N_estimates is not None): 119 | f_re=f_re.split(self._Nestimates,-1) 120 | if(f_re[-1].shape[-1]!=self._Nestimates): 121 | f_re=th.cat(f_re[:-1]+(th.cat((f_re[-1],th.zeros(N,1,n,self._Nestimates-f_re[-1].shape[-1])),-1),),1) 122 | else: 123 | f_re=th.cat(f_re,1) 124 | f_im=f_im.split(self._Nestimates,-1) 125 | if(f_im[-1].shape[-1]!=self._Nestimates): 126 | f_im=th.cat(f_im[:-1]+(th.cat((f_im[-1],th.zeros(N,1,n,self._Nestimates-f_im[-1].shape[-1])),-1),),1) 127 | else: 128 | f_im=th.cat(f_im,1) 129 | f_re-=f_re.mean(-1,True); f_im-=f_im.mean(-1,True) 130 | f_re=f_re.double(); f_im=f_im.double() 131 | X_Re=((f_re.matmul(f_re.transpose(-1,-2))+f_im.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1)) 132 | X_Im=((f_im.matmul(f_re.transpose(-1,-2))-f_re.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1)) 133 | if(reg_mode=='mle'): 134 | pass 135 | elif(self._reg_mode=='add_id'): 136 | X_Re=RegulEig(1e-6)(X_Re) 137 | X_Im=RegulEig(1e-6)(X_Im) 138 | 
elif(self._reg_mode=='adjust_eig'): 139 | X_Re=AdjustEig(0.75)(X_Re) 140 | X_Im=AdjustEig(0.75)(X_Im) 141 | X=(X_Re+X_Im)/2 ############## later, do cat for HPD 142 | # X=th.cat((X_Re,X_Im),1) #for real complex matrices 143 | return X 144 | -------------------------------------------------------------------------------- /LieBN_SPDNet/cplx/nn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import numpy.random 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from . import functional 7 | 8 | class Conv_cplx(nn.Module): 9 | ''' 10 | Interface complex conv layer 11 | ''' 12 | def forward(self,X): 13 | return functional.conv_cplx(X,self._conv_Re,self._conv_Im) 14 | 15 | class Conv1d_cplx(Conv_cplx): 16 | ''' 17 | 1D complex conv layer 18 | Inputs a 3D Tensor (batch_size,2*C,T) 19 | C is the number of channels, 2*C=in_channels the effective number of channels for handling complex data 20 | Contains two real-valued conv layers 21 | Output is of shape (batch_size,out_channels,T) (complex channels is out_channels//2) 22 | ''' 23 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False): 24 | super(__class__,self).__init__() 25 | self._conv_Re=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias) 26 | self._conv_Im=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias) 27 | 28 | class FFT(Conv1d_cplx): 29 | ''' 30 | 1D complex conv layer, where the weights are the Fourier atoms 31 | ''' 32 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False): 33 | super(__class__, self).__init__(in_channels,out_channels,kernel_size,stride,bias) 34 | atoms=signal_utils.fourier_atoms(out_channels,kernel_size).conj() 35 | atoms_Re=th.from_numpy(utils.cplx2bichan(atoms)[:,0,:][:,None,:]).float() 36 | atoms_Im=th.from_numpy(utils.cplx2bichan(atoms)[:,1,:][:,None,:]).float() 37 | self._conv_Re.weight.data=nn.Parameter(atoms_Re) 38 | self._conv_Im.weight.data=nn.Parameter(atoms_Im) 39 | for param in list(self.parameters()): 40 | param.requires_grad=False 41 | 42 | class SplitSignal_cplx(Conv1d_cplx): 43 | ''' 44 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 45 | Input is still (batch_size,2,T) 46 | Output is 4-D complex signal (batch_size,2,window_size,T') instead of (batch_size,2*window_size,T') 47 | ''' 48 | def __init__(self,in_channels,window_size,hop_length): 49 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False) 50 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:] 51 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:] 52 | for param in list(self._conv_Re.parameters()): 53 | param.requires_grad=False 54 | for param in list(self._conv_Im.parameters()): 55 | param.requires_grad=False 56 | def forward(self,X): 57 | return functional.split_signal_cplx(X,self._conv_Re,self._conv_Im) 58 | 59 | class SplitSignalBlock_cplx(Conv1d_cplx): 60 | ''' 61 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 62 | Input is now L blocks of 3-D complex signals (batch_size,L,2*C_in,T), which are acted on independently by the conv layer 63 | Output is L blocks of 3-D complex signals (batch_size,L,2*C_out,window_size,T') 64 | ''' 65 | def __init__(self,in_channels,window_size,hop_length): 66 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False) 67 | # 
self._conv_Re.weight.data=th.ones_like(self._conv_Re.weight.data) 68 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:] 69 | # self._conv_Im.weight.data=th.ones_like(self._conv_Im.weight.data) 70 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:] 71 | for param in list(self._conv_Re.parameters()): 72 | param.requires_grad=False 73 | for param in list(self._conv_Im.parameters()): 74 | param.requires_grad=False 75 | def forward(self,X): 76 | XX=[functional.split_signal_cplx(X[:,i,:,:],self._conv_Re,self._conv_Im)[:,None,:,:,:] 77 | for i in range(X.shape[1])] 78 | return th.cat(XX,1) 79 | 80 | class Conv2d_cplx(Conv_cplx): 81 | ''' 82 | 2D complex conv layer 83 | Inputs a 4D Tensor (N,2*C,H,W) 84 | C is the number of channels, 2*C the effective number of channels for handling complex data 85 | Contains two real-valued conv layers 86 | ''' 87 | def __init__(self,in_channels,out_channels,kernel_size,stride=(1,1),bias=True): 88 | super(Conv2d_cplx,self).__init__() 89 | self._conv_Re=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias) 90 | self._conv_Im=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias) 91 | sigma=2./(in_channels//2+out_channels//2) 92 | ampli=numpy.random.rayleigh(sigma,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1])) 93 | phase=numpy.random.uniform(-np.pi,np.pi,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1])) 94 | self._conv_Re.weight.data=nn.Parameter(th.Tensor(ampli*np.cos(phase))) 95 | self._conv_Im.weight.data=nn.Parameter(th.Tensor(ampli*np.sin(phase))) 96 | 97 | class ReLU_cplx(nn.ReLU): 98 | pass 99 | # def forward(self,X): 100 | # return F.relu(X) 101 | 102 | class Roll(nn.Module): 103 | def forward(self,X): 104 | return functional.roll(X) 105 | 106 | class Decibel(nn.Module): 107 | ''' 108 | Inputs a (N,2*C,H,W) complex tensor 109 | Outputs a (N,C,H,W) real tensor containing the input's decibel amplitude 110 | ''' 111 | def forward(self,X): 112 | return functional.decibel(X) 113 | 114 | class AbsolutSquared(nn.Module): 115 | ''' 116 | Inputs a (N,2*C,H,W) complex tensor 117 | Outputs a (N,C,H,W) real tensor containing the input's squared module 118 | ''' 119 | def forward(self,X): 120 | return functional.absolut_squared(X) 121 | 122 | class oneD2twoD_cplx(nn.Module): 123 | ''' 124 | Inputs a 3D complex tensor (N,2*n,T) 125 | Outputs a 4D complex tensor (N,2,n,T) 126 | ''' 127 | def forward(self,x): 128 | return functional.oneD2twoD_cplx(x) 129 | 130 | class MaxPool2d_cplx(nn.MaxPool2d): 131 | pass 132 | # def __init__(self,kernel_size,stride=None,padding=0,dilation=1,return_indices=False,ceil_mode=False): 133 | # super(MaxPool2d_cplx, self).__init__(kernel_size,stride,padding,dilation,return_indices,ceil_mode) 134 | # self._abs=AbsolutSquared() 135 | # def forward(self,X): 136 | # N,C,H,W=X.shape 137 | # X_abs=self._abs(X) 138 | # _,idx=F.max_pool2d(X_abs,self.kernel_size,self.stride,self.padding,self.dilation,self.ceil_mode,return_indices=True) 139 | # XX=X.split(C//2,1) 140 | # Y_re=th.gather(XX[0].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape) 141 | # Y_im=th.gather(XX[1].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape) 142 | # Y=th.cat((Y_re,Y_im),1) 143 | # return Y 144 | 145 | class BatchNorm2d_cplx(nn.BatchNorm2d): 146 | ''' 147 | Inputs a (N,2*C,H,W) complex tensor 148 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor 149 | ''' 150 | # def __init__(self,in_channels): 151 | # super(BatchNorm2d_cplx,self).__init__(in_channels) 152 | # 
self.momentum=0.1 153 | # self.running_mean=th.zeros(in_channels//2,2) 154 | # self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5) 155 | # self._gamma11=nn.Parameter(th.ones(in_channels//2)/2**.5) 156 | # self._gamma22=nn.Parameter(th.ones(in_channels//2)/2**.5) 157 | # self._gamma12=nn.Parameter(th.zeros(in_channels//2)) 158 | # self.bias=nn.Parameter(th.zeros(in_channels//2,2)) 159 | # def forward(self,X): 160 | # return functional.batch_norm2d_cplx(X,self.running_mean,self.running_var, 161 | # self._gamma11,self._gamma12,self._gamma22,self.bias,self.momentum,self.training) 162 | pass 163 | 164 | class BatchNorm2d_cplxSPD(nn.BatchNorm2d): 165 | ''' 166 | Inputs a (N,2*C,H,W) complex tensor 167 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor 168 | ''' 169 | def __init__(self,in_channels): 170 | super(BatchNorm2d_cplxSPD,self).__init__(in_channels) 171 | self.momentum=0.1 172 | self.running_mean=th.zeros(in_channels//2,2) 173 | self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5) 174 | self.weight_=nn.ParameterList([functional_spd.SPDParameter(th.eye(2)/(2**.5)) for _ in range(in_channels//2)]) 175 | self.bias=nn.Parameter(th.zeros(in_channels//2,2)) 176 | def forward(self,X): 177 | return functional.batch_norm2d_cplx_spd(X,self.running_mean,self.running_var, 178 | self.weight_,self.bias,self.momentum,self.training) 179 | 180 | class BatchNorm2d(nn.BatchNorm2d): 181 | pass 182 | # def __init__(self,in_channels): 183 | # super(BatchNorm2d,self).__init__(in_channels) 184 | # self.momentum=0.1 185 | # self.running_mean=th.zeros(in_channels) 186 | # self.running_var=th.zeros(in_channels) 187 | # self.weight=nn.Parameter(th.ones(in_channels)) 188 | # self.bias=nn.Parameter(th.zeros(in_channels)) 189 | # def forward(self,X): 190 | # N,C,H,W=X.shape 191 | # X_vec=X.transpose(0,1).contiguous().view(C,N*H*W) 192 | # if(self.training): 193 | # mu=X_vec.mean(1); var=X_vec.var(1) 194 | # with th.no_grad(): 195 | # self.running_mean=(1-self.momentum)*self.running_mean+self.momentum*mu 196 | # self.running_var=(1-self.momentum)*self.running_var+self.momentum*var 197 | # Y=(X_vec-mu.view(-1,1))/(var.view(-1,1)**.5+self.eps) 198 | # else: 199 | # Y=(X_vec-self.running_mean.view(-1,1))/(self.running_var.view(-1,1)**.5+self.eps) 200 | # Z=self.weight.view(-1,1)*Y+self.bias.view(-1,1) 201 | # return Z.view(C,N,H,W).transpose(0,1) 202 | # 203 | # def forward(self, x): 204 | # self._check_input_dim(x) 205 | # y = x.transpose(0,1) 206 | # return_shape = y.shape 207 | # y = y.contiguous().view(x.size(1), -1) 208 | # mu = y.mean(dim=1) 209 | # sigma2 = y.var(dim=1) 210 | # if self.training is not True: 211 | # y = y - self.running_mean.view(-1, 1) 212 | # y = y / (self.running_var.view(-1, 1)**.5 + self.eps) 213 | # else: 214 | # if self.track_running_stats is True: 215 | # with torch.no_grad(): 216 | # self.running_mean = (1-self.momentum)*self.running_mean + self.momentum*mu 217 | # self.running_var = (1-self.momentum)*self.running_var + self.momentum*sigma2 218 | # y = y - mu.view(-1,1) 219 | # y = y / (sigma2.view(-1,1)**.5 + self.eps) 220 | # 221 | # y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1) 222 | # return y.view(return_shape).transpose(0,1) 223 | 224 | class CovPool_cplx(nn.Module): 225 | """ 226 | Input f: Temporal n-dimensionnal complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T) 227 | Output X: Complex covariance matrix of size (batch_size,2,n,n) 228 | """ 229 | def 
__init__(self,reg_mode='mle',N_estimates=None): 230 | super(__class__,self).__init__() 231 | self._reg_mode=reg_mode; self.N_estimates=N_estimates 232 | def forward(self,f): 233 | return functional.cov_pool_cplx(f,self._reg_mode,self.N_estimates) 234 | 235 | class CovPoolBlock_cplx(nn.Module): 236 | """ 237 | Input f: L blocks of temporal n-dimensionnal complex feature map of length T of shape (batch_size,L,2,n,T) 238 | Output X: L blocks of complex covariance matrix of size (batch_size,L,2,n,n) 239 | """ 240 | def __init__(self,reg_mode='mle',N_estimates=None): 241 | super(__class__,self).__init__() 242 | self._reg_mode=reg_mode; self.N_estimates=N_estimates 243 | def forward(self,f): 244 | XX=[functional.cov_pool_cplx(f[:,i,:,:,:],self._reg_mode,self.N_estimates)[:,None,:,:,:] 245 | for i in range(f.shape[1])] 246 | return th.cat(XX,1) -------------------------------------------------------------------------------- /LieBN_SPDNet/experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | epochs=200 4 | datasets=RADAR,HDM05 5 | 6 | ### Experiments on SPDNet,SPDNetBN 7 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=$datasets nnet=SPDNetBN,SPDNet fit.epochs=$epochs 8 | 9 | ### Experiments on SPDNetLieBN-LEM 10 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=$datasets nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LEM 11 | 12 | ### Experiments on SPDNetLieBN-AIM 13 | #RADAR 14 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=RADAR nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=AIM nnet.model.theta=1. 15 | #HDM05 16 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=HDM05 nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=AIM nnet.model.theta=1,1.5 17 | 18 | ### Experiments on SPDNetLieBN-LCM 19 | #RADAR 20 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=RADAR nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LCM nnet.model.theta=1.,-0.5 21 | ##HDM05 22 | [ $? 
-eq 0 ] && python SPDNetLieBN.py -m dataset=HDM05 nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LCM nnet.model.theta=1,0.5 -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/DataLoader/FPHA_Loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import torch as th 6 | import random 7 | 8 | from torch.utils import data 9 | 10 | class DatasetSPD(data.Dataset): 11 | def __init__(self, path, names): 12 | self._path = path 13 | self._names = names 14 | 15 | def __len__(self): 16 | return len(self._names) 17 | 18 | def __getitem__(self, item): 19 | x = np.load(self._path + self._names[item])[None, :, :].real 20 | x = th.from_numpy(x).double() 21 | y = int(self._names[item].split('.')[0].split('_')[-1]) 22 | y = th.from_numpy(np.array(y)).long() 23 | # return x.to(device),y.to(device) 24 | return x, y 25 | 26 | 27 | class DataLoaderFPHA: 28 | def __init__(self, data_path, batch_size): 29 | path_train, path_test = data_path + 'train/', data_path + 'val/' 30 | for filenames in os.walk(path_train): 31 | names_train = sorted(filenames[2]) 32 | for filenames in os.walk(path_test): 33 | names_test = sorted(filenames[2]) 34 | self._train_generator = data.DataLoader(DatasetSPD(path_train, names_train), batch_size=batch_size, 35 | shuffle='True') 36 | self._test_generator = data.DataLoader(DatasetSPD(path_test, names_test), batch_size=batch_size, 37 | shuffle='False') -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/DataLoader/HDM05_Loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import torch as th 6 | import random 7 | 8 | from torch.utils import data 9 | 10 | device = 'cpu' 11 | class DatasetHDM05(data.Dataset): 12 | def __init__(self, path, names): 13 | self._path = path 14 | self._names = names 15 | 16 | def __len__(self): 17 | return len(self._names) 18 | 19 | def __getitem__(self, item): 20 | x = np.load(self._path + self._names[item])[None, :, :].real 21 | x = th.from_numpy(x).double() 22 | y = int(self._names[item].split('.')[0].split('_')[-1]) 23 | y = th.from_numpy(np.array(y)).long() 24 | return x.to(device), y.to(device) 25 | 26 | 27 | class DataLoaderHDM05: 28 | def __init__(self, data_path, pval, batch_size): 29 | for filenames in os.walk(data_path): 30 | names = sorted(filenames[2]) 31 | random.Random(1024).shuffle(names) 32 | N_test = int(pval * len(names)) 33 | train_set = DatasetHDM05(data_path, names[N_test:]) 34 | test_set = DatasetHDM05(data_path, names[:N_test]) 35 | self._train_generator = data.DataLoader(train_set, batch_size=batch_size, shuffle='True') 36 | self._test_generator = data.DataLoader(test_set, batch_size=batch_size, shuffle='False') -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/DataLoader/Radar_Loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import torch as th 6 | import random 7 | 8 | from torch.utils import data 9 | 10 | pval=0.25 #validation percentage 11 | ptest=0.25 #test percentage 12 | # th.cuda.device('cpu') 13 | 14 | class DatasetRadar(data.Dataset): 15 | def __init__(self, path, names): 16 | self._path = path 17 | self._names = names 18 | def __len__(self): 19 | return len(self._names) 20 | def 
__getitem__(self, item): 21 | x=np.load(self._path+self._names[item]) 22 | x=np.concatenate((x.real[:,None],x.imag[:,None]),axis=1).T 23 | x=th.from_numpy(x) 24 | y=int(self._names[item].split('.')[0].split('_')[-1]) 25 | y=th.from_numpy(np.array(y)) 26 | return x.float(),y.long() 27 | class DataLoaderRadar: 28 | def __init__(self,data_path,pval,batch_size): 29 | for filenames in os.walk(data_path): 30 | names=sorted(filenames[2]) 31 | random.Random().shuffle(names) 32 | N_val=int(pval*len(names)) 33 | N_test=int(ptest*len(names)) 34 | N_train=(len(names)-N_test-N_val) 35 | train_set=DatasetRadar(data_path,names[N_val+N_test:int(N_train)+N_test+N_val]) 36 | test_set=DatasetRadar(data_path,names[:N_test]) 37 | val_set=DatasetRadar(data_path,names[N_test:N_test+N_val]) 38 | self._train_generator=data.DataLoader(train_set,batch_size=batch_size,shuffle='True') 39 | self._test_generator=data.DataLoader(test_set,batch_size=batch_size,shuffle='False') 40 | self._val_generator=data.DataLoader(val_set,batch_size=batch_size,shuffle='False') -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/LieBN.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 6 | 7 | Copyright (C) 2024 Ziheng Chen 8 | All rights reserved. 9 | """ 10 | 11 | import torch as th 12 | import torch.nn as nn 13 | from . import functional 14 | import geoopt 15 | from geoopt.manifolds import SymmetricPositiveDefinite 16 | 17 | from spd import sym_functional 18 | 19 | dtype=th.double 20 | device=th.device('cpu') 21 | 22 | class LieBatchNormSPD(nn.Module): 23 | """ 24 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. 
ICLR 2024 25 | Input X: (...,n,n) SPD matrices 26 | Output X_new: (...,n,n) batch-normalized matrices 27 | Parameters: 28 | self.weight: (n,n) SPD parameter 29 | self.shift: scalar dispersion 30 | """ 31 | def __init__(self,n: int ,metric: str ='AIM',theta: float =1.,alpha: float=1.0,beta: float=0.,ddevice='cpu',momentum=0.1): 32 | super(__class__,self).__init__() 33 | self.momentum=momentum;self.n=n; self.ddevice=ddevice; 34 | self.metric = metric;self.theta=th.tensor(theta,dtype=dtype);self.alpha=th.tensor(alpha,dtype=dtype);self.beta=th.tensor(beta,dtype=dtype) 35 | self.running_mean=th.eye(n,dtype=dtype) 36 | self.weight = geoopt.ManifoldParameter(th.eye(n, n, dtype=dtype, device=ddevice), 37 | manifold=SymmetricPositiveDefinite()) 38 | self.eps=1e-05; 39 | self.running_var=th.ones(1,dtype=dtype); 40 | self.shift = nn.Parameter(th.ones(1, dtype=dtype)) 41 | 42 | if metric != "AIM" and metric != "LCM" and metric != "LEM" : 43 | raise Exception('unknown metric {}'.format(metric)) 44 | 45 | def forward(self,X): 46 | #deformation 47 | X_deformed = self.deformation(X) 48 | weight = self.deformation(self.weight) 49 | 50 | if(self.training): 51 | # centering 52 | batch_mean = self.cal_geom_mean(X_deformed) 53 | X_centered = self.Left_translation(X_deformed,batch_mean,is_inverse=True) 54 | # scaling and shifting 55 | batch_var = self.cal_geom_var(X_centered) 56 | X_scaled = self.scaling(X_centered, batch_var, self.shift) 57 | self.updating_running_statistics(batch_mean, batch_var) 58 | 59 | else: 60 | # centering, scaling and shifting 61 | X_centered = self.Left_translation(X_deformed, self.running_mean, is_inverse=True) 62 | X_scaled = self.scaling(X_centered, self.running_var, self.shift) 63 | #biasing 64 | X_normalized = self.Left_translation(X_scaled, weight, is_inverse=False) 65 | # inv_deformation 66 | X_new = self.inv_deformation(X_normalized) 67 | 68 | return X_new 69 | def alpha_beta_dist(self, X): 70 | """"computing the O(n)-invariant Euclidean distance to the identity (element)""" 71 | if self.beta==0.: 72 | dist = self.alpha * th.linalg.matrix_norm(X).square() 73 | else: 74 | item1 = th.linalg.matrix_norm(X) 75 | item2 = functional.trace(X) 76 | dist = self.alpha * item1.square() + self.beta * item2.square() 77 | return dist 78 | 79 | def spd_power(self,X): 80 | if self.theta == 1.: 81 | X_power = X 82 | else: 83 | X_power = sym_functional.sym_powm.apply(X, self.theta) 84 | return X_power 85 | 86 | def inv_power(self,X): 87 | if self.theta == 1.: 88 | X_power = X 89 | else: 90 | X_power = sym_functional.sym_powm.apply(X, 1/self.theta) 91 | return X_power 92 | 93 | def deformation(self,X): 94 | 95 | if self.metric=='AIM': 96 | X_deformed = self.spd_power(X) 97 | elif self.metric == 'LEM': 98 | X_deformed = sym_functional.sym_logm.apply(X) 99 | elif self.metric == 'LCM': 100 | X_power = self.spd_power(X) 101 | L = th.linalg.cholesky(X_power) 102 | diag_part = th.diag_embed(th.log(th.diagonal(L, dim1=-2, dim2=-1))) 103 | X_deformed = L.tril(-1) + diag_part 104 | 105 | return X_deformed 106 | 107 | def inv_deformation(self,X): 108 | if self.metric=='AIM': 109 | X_inv_deformed = self.inv_power(X) 110 | elif self.metric == 'LEM': 111 | X_inv_deformed = sym_functional.sym_expm.apply(X) 112 | elif self.metric == 'LCM': 113 | Cho = X.tril(-1) + th.diag_embed(th.exp(th.diagonal(X, dim1=-2, dim2=-1))) 114 | spd = Cho.matmul(Cho.transpose(-1,-2)) 115 | X_inv_deformed = self.inv_power(spd) 116 | return X_inv_deformed 117 | def cal_geom_mean(self, X): 118 | """Frechet mean""" 119 | if 
self.metric == 'AIM': 120 | mean = self.BaryGeom(X.detach()) 121 | elif self.metric == 'LEM' or self.metric == 'LCM': 122 | mean = X.detach().mean(dim=0, keepdim=True) 123 | 124 | return mean 125 | def cal_geom_var(self, X): 126 | """Frechet variance""" 127 | spd = X.detach() 128 | if self.metric == 'AIM': 129 | dists = self.alpha * th.linalg.matrix_norm(sym_functional.sym_logm.apply(spd)).square() + self.beta * th.logdet(spd).square() 130 | var = dists.mean() 131 | 132 | elif self.metric == 'LEM' or self.metric == 'LCM': 133 | dists = self.alpha_beta_dist(spd) 134 | var = dists.mean() 135 | 136 | if self.metric == 'AIM' or self.metric == 'LCM': 137 | var_final = var * (1 / (self.theta ** 2)) 138 | else: 139 | var_final=var 140 | return var_final.unsqueeze(dim=0) 141 | def Left_translation(self, X, P, is_inverse): 142 | """Left translation by P""" 143 | if self.metric == 'AIM': 144 | L = th.linalg.cholesky(P); 145 | if is_inverse: 146 | tmp = th.linalg.solve(L, X) 147 | X_new = th.linalg.solve(L.transpose(-1, -2), tmp, left=False) 148 | else: 149 | X_new = L.matmul(X).matmul(L.transpose(-1, -2)) 150 | 151 | elif self.metric == 'LEM' or self.metric == 'LCM': 152 | X_new = X - P if is_inverse else X + P 153 | 154 | return X_new 155 | def scaling(self, X, var,scale): 156 | """Scaling by variance""" 157 | factor = scale / (var+self.eps).sqrt() 158 | if self.metric == 'AIM': 159 | X_new = sym_functional.sym_powm.apply(X, factor) 160 | elif self.metric == 'LEM' or self.metric == 'LCM': 161 | X_new = X * factor 162 | 163 | return X_new 164 | def updating_running_statistics(self,batch_mean,batch_var=None): 165 | """updating running mean""" 166 | with th.no_grad(): 167 | # updating mean 168 | if self.metric == 'AIM': 169 | self.running_mean.data = functional.geodesic(self.running_mean, batch_mean,self.momentum) 170 | elif self.metric == 'LEM' or self.metric == 'LCM': 171 | self.running_mean.data = (1-self.momentum) * self.running_mean+ batch_mean * self.momentum 172 | # updating var 173 | self.running_var.data = (1 - self.momentum) * self.running_var + batch_var * self.momentum 174 | 175 | def BaryGeom(self,X,karcher_steps=1,batchdim=0): 176 | ''' 177 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow 178 | Input x is a batch of SPD matrices (...,n,n) to average 179 | Output is (n,n) Riemannian mean 180 | ''' 181 | batch_mean = X.mean(dim=batchdim,keepdim=True) 182 | for _ in range(karcher_steps): 183 | bm_sq, bm_invsq = sym_functional.sym_invsqrtm2.apply(batch_mean) 184 | XT = sym_functional.sym_logm.apply(bm_invsq @ X @ bm_invsq) 185 | GT = XT.mean(dim=batchdim,keepdim=True) 186 | batch_mean = bm_sq @ sym_functional.sym_expm.apply(GT) @ bm_sq 187 | return batch_mean.squeeze() -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_SPDNet/spd/__init__.py -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/nn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | from . 
import functional 4 | import geoopt 5 | from geoopt.manifolds import SymmetricPositiveDefinite 6 | 7 | 8 | dtype=th.double 9 | # dtype=th.float32 10 | device=th.device('cpu') 11 | 12 | class BiMap(nn.Module): 13 | """ 14 | Input X: (batch_size,hi) SPDDR matrices of size (ni,ni) 15 | Output P: (batch_size,ho) of bilinearly mapped matrices of size (no,no) 16 | Stiefel parameter of size (ho,hi,ni,no) 17 | """ 18 | def __init__(self,ho,hi,ni,no,device = th.device('cpu'),dtype=dtype,optimizer='geoopt'): 19 | super(BiMap, self).__init__() 20 | if optimizer == 'geoopt': 21 | self._W = geoopt.ManifoldParameter(th.empty(ho, hi, ni, no, dtype=dtype, device=device), manifold=geoopt.manifolds.Stiefel()) 22 | elif optimizer == 'homemade': 23 | self._W=functional.StiefelParameter(th.empty(ho,hi,ni,no,dtype=dtype,device=device)) 24 | else: 25 | raise Exception('unknown param_mode {}'.format(optimizer)) 26 | self._ho=ho; self._hi=hi; self._ni=ni; self._no=no 27 | functional.init_bimap_parameter(self._W) 28 | def forward(self,X): 29 | return functional.bimap_channels(X,self._W) 30 | 31 | class LogEig(nn.Module): 32 | """ 33 | Input P: (batch_size,h) SPDDR matrices of size (n,n) 34 | Output X: (batch_size,h) of log eigenvalues matrices of size (n,n) 35 | """ 36 | def forward(self,P): 37 | return functional.LogEig.apply(P) 38 | 39 | class SqmEig(nn.Module): 40 | """ 41 | Input P: (batch_size,h) SPDDR matrices of size (n,n) 42 | Output X: (batch_size,h) of sqrt eigenvalues matrices of size (n,n) 43 | """ 44 | def forward(self,P): 45 | return functional.SqmEig.apply(P) 46 | 47 | class ReEig(nn.Module): 48 | """ 49 | Input P: (batch_size,h) SPDDR matrices of size (n,n) 50 | Output X: (batch_size,h) of rectified eigenvalues matrices of size (n,n) 51 | """ 52 | def forward(self,P): 53 | return functional.ReEig.apply(P) 54 | 55 | class BaryGeom(nn.Module): 56 | ''' 57 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow 58 | Input x is a batch of SPDDR matrices (batch_size,1,n,n) to average 59 | Output is (n,n) Riemannian mean 60 | ''' 61 | def forward(self,x): 62 | return functional.BaryGeom(x) 63 | 64 | class BatchNormSPD(nn.Module): 65 | """ 66 | Input X: (N,h) SPD matrices of size (n,n) with h channels and batch size N 67 | Output P: (N,h) batch-normalized matrices 68 | SPD parameter of size (n,n) 69 | """ 70 | def __init__(self,n,momentum=0.1): 71 | super(__class__,self).__init__() 72 | self.momentum=momentum; 73 | self.running_mean=th.eye(n,dtype=dtype) ################################ 74 | self.weight = geoopt.ManifoldParameter(th.eye(n, n, dtype=dtype), 75 | manifold=SymmetricPositiveDefinite()) 76 | def forward(self,X): 77 | N,h,n,n=X.shape 78 | X_batched=X.permute(2,3,0,1).contiguous().view(n,n,N*h,1).permute(2,3,0,1).contiguous() 79 | if(self.training): 80 | mean=functional.BaryGeom(X_batched) 81 | with th.no_grad(): 82 | self.running_mean.data=functional.geodesic(self.running_mean,mean,self.momentum) 83 | X_centered=functional.CongrG(X_batched,mean,'neg') 84 | else: 85 | X_centered=functional.CongrG(X_batched,self.running_mean,'neg') 86 | X_normalized=functional.CongrG(X_centered,self.weight,'pos') 87 | # num_x = X.data.numpy() 88 | # num_X_centered = X_centered.data.numpy() 89 | # num_mean = mean.data.numpy() 90 | X_new = X_normalized.permute(2,3,0,1).contiguous().view(n,n,N,h).permute(2,3,0,1).contiguous() 91 | return X_new 92 | 93 | class CovPool(nn.Module): 94 | """ 95 | Input f: Temporal n-dimensionnal feature map of length T (T=1 for a unitary signal) 
(batch_size,n,T) 96 | Output X: Covariance matrix of size (batch_size,1,n,n) 97 | """ 98 | def __init__(self,reg_mode='mle'): 99 | super(__class__,self).__init__() 100 | self._reg_mode=reg_mode 101 | def forward(self,f): 102 | return functional.cov_pool(f,self._reg_mode) 103 | 104 | class CovPoolBlock(nn.Module): 105 | """ 106 | Input f: L blocks of temporal n-dimensionnal feature map of length T (T=1 for a unitary signal) (batch_size,L,n,T) 107 | Output X: L covariance matrices, shape (batch_size,L,1,n,n) 108 | """ 109 | def __init__(self,reg_mode='mle'): 110 | super(__class__,self).__init__() 111 | self._reg_mode=reg_mode 112 | def forward(self,f): 113 | ff=[functional.cov_pool(f[:,i,:,:],self._reg_mode)[:,None,:,:,:] for i in range(f.shape[1])] 114 | return th.cat(ff,1) 115 | 116 | class CovPoolMean(nn.Module): 117 | """ 118 | Input f: Temporal n-dimensionnal feature map of length T (T=1 for a unitary signal) (batch_size,n,T) 119 | Output X: Covariance matrix of size (batch_size,1,n,n) 120 | """ 121 | def __init__(self,reg_mode='mle'): 122 | super(__class__,self).__init__() 123 | self._reg_mode=reg_mode 124 | def forward(self,f): 125 | return functional.cov_pool_mu(f,self._reg_mode) -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/training_script.py: -------------------------------------------------------------------------------- 1 | from spd.utils import get_dataset_settings 2 | from Network import Get_Model 3 | import spd.utils as nn_spd_utils 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | import os 8 | import datetime 9 | import time 10 | import logging 11 | 12 | import torch as th 13 | import torch.nn as nn 14 | import numpy as np 15 | import fcntl 16 | 17 | def training(cfg,args): 18 | args=nn_spd_utils.parse_cfg(args,cfg) 19 | 20 | #set logger 21 | logger = logging.getLogger(args.modelname) 22 | logger.setLevel(logging.INFO) 23 | args.logger = logger 24 | logger.info('begin model {} on dataset: {}'.format(args.modelname,args.dataset)) 25 | 26 | #set seed and threadnum 27 | nn_spd_utils.set_seed_thread(args.seed,args.threadnum) 28 | 29 | # set dataset, model and optimizer 30 | args.class_num, args.DataLoader = get_dataset_settings(args) 31 | model = Get_Model.get_model(args) 32 | loss_fn = nn.CrossEntropyLoss() 33 | args.loss_fn = loss_fn.cuda() 34 | args.opti = nn_spd_utils.optimzer(model.parameters(), lr=args.lr, mode=args.optimizer,weight_decay=args.weight_decay) 35 | # begin training 36 | val_acc = training_loop(model,args) 37 | 38 | return val_acc 39 | 40 | def training_loop(model, args): 41 | #setting tensorboard 42 | if args.is_writer: 43 | args.writer_path = os.path.join('./tensorboard_logs/',f"{args.modelname}") 44 | args.logger.info('writer path {}'.format(args.writer_path)) 45 | args.writer = SummaryWriter(args.writer_path) 46 | 47 | acc_val = [];loss_val = [];acc_train = [];loss_train = [];training_time=[] 48 | logger = args.logger 49 | # training loop 50 | for epoch in range(0, args.epochs): 51 | # train one epoch 52 | start = time.time() 53 | temp_loss_train, temp_acc_train = [], [] 54 | model.train() 55 | for local_batch, local_labels in args.DataLoader._train_generator: 56 | args.opti.zero_grad() 57 | out = model(local_batch) 58 | l = args.loss_fn(out, local_labels) 59 | acc, loss = (out.argmax(1) == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 60 | temp_loss_train.append(loss) 61 | temp_acc_train.append(acc) 62 | l.backward() 63 | args.opti.step() 64 | end = time.time() 
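# (end - start) measures one epoch of wall-clock training time; the script
# later averages training_time[-10:] for the reported final result.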
65 | training_time.append(end-start) 66 | acc_train.append(np.asarray(temp_acc_train).mean() * 100) 67 | loss_train.append(np.asarray(temp_loss_train).mean()) 68 | 69 | # validation 70 | acc_val_list = [];loss_val_list = [];y_true, y_pred = [], [] 71 | model.eval() 72 | with th.no_grad(): 73 | for local_batch, local_labels in args.DataLoader._test_generator: 74 | out = model(local_batch) 75 | l = args.loss_fn(out, local_labels) 76 | predicted_labels = out.argmax(1) 77 | y_true.extend(list(local_labels.cpu().numpy())); 78 | y_pred.extend(list(predicted_labels.cpu().numpy())) 79 | acc, loss = (predicted_labels == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 80 | acc_val_list.append(acc) 81 | loss_val_list.append(loss) 82 | loss_val.append(np.asarray(loss_val_list).mean()) 83 | acc_val.append(np.asarray(acc_val_list).mean() * 100) 84 | 85 | if args.is_writer: 86 | args.writer.add_scalar('Loss/val', loss_val[epoch], epoch) 87 | args.writer.add_scalar('Accuracy/val', acc_val[epoch], epoch) 88 | args.writer.add_scalar('Loss/train', loss_train[epoch], epoch) 89 | args.writer.add_scalar('Accuracy/train', acc_train[epoch], epoch) 90 | 91 | if epoch % args.cycle == 0: 92 | logger.info( 93 | 'Time: {:.2f}, Val acc: {:.2f}, loss: {:.2f} at epoch {:d}/{:d} '.format( 94 | training_time[epoch], acc_val[epoch], loss_val[epoch], epoch + 1, args.epochs)) 95 | 96 | if args.is_save: 97 | average_time = np.asarray(training_time[-10:]).mean() 98 | final_val_acc = acc_val[-1] 99 | final_results = f'Final validation accuracy : {final_val_acc:.2f}% with average time: {average_time:.2f}' 100 | final_results_path = os.path.join(os.getcwd(), 'final_results_' + args.dataset) 101 | logger.info(f"results file path: {final_results_path}, and saving the results") 102 | write_final_results(final_results_path, args.modelname + '- ' + final_results) 103 | torch_results_dir = './torch_resutls' 104 | if not os.path.exists(torch_results_dir): 105 | os.makedirs(torch_results_dir) 106 | th.save({ 107 | 'acc_val': acc_val, 108 | }, os.path.join(torch_results_dir,args.modelname.rsplit('-',1)[0])) 109 | 110 | if args.is_writer: 111 | args.writer.close() 112 | return acc_val 113 | 114 | def write_final_results(file_path,message): 115 | # Create a file lock 116 | with open(file_path, "a") as file: 117 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock 118 | 119 | # Write the message to the file 120 | file.write(message + "\n") 121 | 122 | fcntl.flock(file.fileno(), fcntl.LOCK_UN) -------------------------------------------------------------------------------- /LieBN_SPDNet/spd/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import time 4 | import datetime 5 | 6 | import random 7 | import geoopt 8 | import numpy as np 9 | import torch as th 10 | 11 | from spd.DataLoader.FPHA_Loader import DataLoaderFPHA 12 | from spd.DataLoader.HDM05_Loader import DataLoaderHDM05 13 | from spd.DataLoader.Radar_Loader import DataLoaderRadar 14 | 15 | def get_model_name(args): 16 | if args.model_type == 'SPDNet': 17 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}' 18 | elif args.model_type == 'SPDNetBN': 19 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{args.model_type}-{args.optimizer}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}' 20 | elif args.model_type in args.total_LieBN_model_types: 21 
| if args.model_type=='SPDNetLieBN_RS': 22 | model_type = args.model_type + '_init_RS'if args.init_by_RS else args.model_type 23 | else: 24 | model_type = args.model_type 25 | if args.metric == 'AIM' or args.metric == 'LEM': 26 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{model_type}-{args.optimizer}-{args.architecture}-{args.metric}-({args.theta},{args.alpha},{args.beta:.4f})-{datetime.datetime.now().strftime("%H_%M")}' 27 | elif args.metric== 'LCM': 28 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{model_type}-{args.optimizer}-{args.architecture}-{args.metric}-({args.theta})-{datetime.datetime.now().strftime("%H_%M")}' 29 | else: 30 | raise Exception('unknown metric {} or model'.format(args.metric,args.model_type)) 31 | return name 32 | 33 | def get_dataset_settings(args): 34 | if args.dataset=='FPHA': 35 | class_num = 45 36 | DataLoader = DataLoaderFPHA(args.path,args.batchsize) 37 | elif args.dataset=='HDM05': 38 | class_num = 117 39 | pval = 0.5 40 | DataLoader = DataLoaderHDM05(args.path, pval, args.batchsize) 41 | elif args.dataset== 'RADAR' : 42 | class_num = 3 43 | pval = 0.25 44 | DataLoader = DataLoaderRadar(args.path,pval,args.batchsize) 45 | else: 46 | raise Exception('unknown dataset {}'.format(args.dataset)) 47 | return class_num,DataLoader 48 | 49 | def set_seed_thread(seed,threadnum): 50 | th.set_num_threads(threadnum) 51 | seed = seed 52 | random.seed(seed) 53 | # th.cuda.set_device(args.gpu) 54 | np.random.seed(seed) 55 | th.manual_seed(seed) 56 | th.cuda.manual_seed(seed) 57 | 58 | def optimzer(parameters,lr,mode='AMSGRAD',weight_decay=0.): 59 | if mode=='ADAM': 60 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,weight_decay=weight_decay) 61 | elif mode=='SGD': 62 | optim = geoopt.optim.RiemannianSGD(parameters, lr=lr,weight_decay=weight_decay) 63 | elif mode=='AMSGRAD': 64 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,amsgrad=True,weight_decay=weight_decay) 65 | else: 66 | raise Exception('unknown optimizer {}'.format(mode)) 67 | return optim 68 | 69 | def training_loop(model, data_loader, opti, loss_fn,writer, args,model_path, begin_epoch): 70 | acc_val = [];loss_val = [];acc_train = [];loss_train = [] 71 | # training loop 72 | for epoch in range(begin_epoch, args.epochs): 73 | # train one epoch 74 | start = time.time() 75 | temp_loss_train, temp_acc_train = [], [] 76 | model.train() 77 | for local_batch, local_labels in data_loader._train_generator: 78 | opti.zero_grad() 79 | out = model(local_batch) 80 | l = loss_fn(out, local_labels) 81 | acc, loss = (out.argmax(1) == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 82 | temp_loss_train.append(loss) 83 | temp_acc_train.append(acc) 84 | l.backward() 85 | opti.step() 86 | if args.is_gpu: 87 | th.cuda.synchronize() 88 | end = time.time() 89 | acc_train.append(np.asarray(temp_acc_train).mean() * 100) 90 | loss_train.append(np.asarray(temp_loss_train).mean()) 91 | 92 | # validation 93 | acc_val_list = [];loss_val_list = [];y_true, y_pred = [], [] 94 | model.eval() 95 | with th.no_grad(): 96 | for local_batch, local_labels in data_loader._test_generator: 97 | out = model(local_batch) 98 | l = loss_fn(out, local_labels) 99 | predicted_labels = out.argmax(1) 100 | y_true.extend(list(local_labels.cpu().numpy())); 101 | y_pred.extend(list(predicted_labels.cpu().numpy())) 102 | acc, loss = (predicted_labels == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 103 | acc_val_list.append(acc) 104 | 
loss_val_list.append(loss) 105 | loss_val.append(np.asarray(loss_val_list).mean()) 106 | acc_val.append(np.asarray(acc_val_list).mean() * 100) 107 | if args.is_writer: 108 | writer.add_scalar('Loss/val', loss_val[epoch], epoch) 109 | writer.add_scalar('Accuracy/val', acc_val[epoch], epoch) 110 | writer.add_scalar('Loss/train', loss_train[epoch], epoch) 111 | writer.add_scalar('Accuracy/train', acc_train[epoch], epoch) 112 | print( 113 | '{}: time: {:.2f}, Val acc: {:.2f}, loss: {:.2f}, at epoch {:d}/{:d} '.format( 114 | args.modelname,end - start, acc_val[epoch], loss_val[epoch], epoch + 1, args.epochs)) 115 | if epoch + 1 == args.epochs and args.is_save: 116 | th.save({ 117 | 'epoch': epoch, 118 | 'model_state_dict': model.state_dict(), 119 | 'lr': args.lr, 120 | 'acc_val': acc_val, 121 | 'acc_train': acc_train, 122 | 'loss_val': loss_val, 123 | 'loss_train': loss_train 124 | }, model_path + '-' + str(epoch)) 125 | print('{}: Final validation accuracy: {}%'.format(args.modelname,acc_val[-1])) 126 | if args.is_writer: 127 | writer.close() 128 | return acc_val 129 | 130 | def del_file(path): 131 | ls = os.listdir(path) 132 | for i in ls: 133 | c_path = os.path.join(path, i) 134 | if os.path.isdir(c_path): 135 | del_file(c_path) 136 | else: 137 | os.remove(c_path) 138 | 139 | def resuming_writer(begin_epoch, writer,loss_val,loss_train,acc_val,acc_train): 140 | for epoch in range(begin_epoch): 141 | writer.add_scalar('Loss/val', loss_val[epoch], epoch) 142 | writer.add_scalar('Accuracy/val', acc_val[epoch], epoch) 143 | writer.add_scalar('Loss/train', loss_train[epoch], epoch) 144 | writer.add_scalar('Accuracy/train', acc_train[epoch], epoch) 145 | 146 | def parse_cfg(args,cfg): 147 | # setting args from cfg 148 | 149 | args.seed = cfg.fit.seed 150 | args.model_type = cfg.nnet.model.model_type 151 | args.is_save = cfg.fit.is_save 152 | 153 | if args.model_type in args.total_BN_model_types: 154 | args.BN_type = cfg.nnet.model.BN_type 155 | args.momentum = cfg.nnet.model.momentum 156 | if args.model_type in args.total_LieBN_model_types: 157 | args.metric = cfg.nnet.model.metric 158 | args.theta = cfg.nnet.model.theta 159 | args.alpha = cfg.nnet.model.alpha 160 | args.beta = eval(cfg.nnet.model.beta) if isinstance(cfg.nnet.model.beta, str) else cfg.nnet.model.beta 161 | 162 | args.dataset = cfg.dataset.name 163 | args.architecture = cfg.dataset.architecture 164 | args.path = cfg.dataset.path 165 | 166 | args.optimizer = cfg.nnet.optimizer.mode 167 | args.lr = cfg.nnet.optimizer.lr 168 | args.weight_decay = cfg.nnet.optimizer.weight_decay 169 | 170 | args.epochs = cfg.fit.epochs 171 | args.batchsize = cfg.fit.batch_size 172 | 173 | args.threadnum = cfg.fit.threadnum 174 | args.is_writer = cfg.fit.is_writer 175 | args.cycle = cfg.fit.cycle 176 | 177 | # get model name 178 | args.modelname = get_model_name(args) 179 | 180 | return args -------------------------------------------------------------------------------- /LieBN_TSMNet/LieBN_utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/LieBN_utilities/__init__.py -------------------------------------------------------------------------------- /LieBN_TSMNet/LieBN_utilities/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import fcntl 3 | import random 4 | 5 | import numpy as np 6 | import torch as th 7 | 8 | def 
get_model_name(args): 9 | bias = 'bias' if args.learn_mean else 'non_bias' 10 | 11 | if args.model_type=='TSMNet+SPDDSMBN' or args.model_type=='TSMNet': 12 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{datetime.datetime.now().strftime("%H_%M")}' 13 | elif args.model_type == 'TSMNet+LieBN': 14 | if args.metric == 'AIM' or args.metric == 'LEM': 15 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{args.metric}-({args.theta},{args.alpha},{args.beta})-{datetime.datetime.now().strftime("%H_%M")}' 16 | elif args.metric== 'LCM': 17 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{args.metric}-({args.theta})-{datetime.datetime.now().strftime("%H_%M")}' 18 | else: 19 | raise Exception('unknown metric {} or model'.format(args.metric,args.model_type)) 20 | return name 21 | 22 | def write_final_results(file_path,message): 23 | # Create a file lock 24 | with open(file_path, "a") as file: 25 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock 26 | # Write the message to the file 27 | file.write(message + "\n") 28 | fcntl.flock(file.fileno(), fcntl.LOCK_UN) # Release the lock 29 | 30 | def set_seed_thread(seed,threadnum): 31 | th.set_num_threads(threadnum) 32 | seed = seed 33 | random.seed(seed) 34 | # th.cuda.set_device(args.gpu) 35 | np.random.seed(seed) 36 | th.manual_seed(seed) 37 | th.cuda.manual_seed(seed) -------------------------------------------------------------------------------- /LieBN_TSMNet/TSMNet-LieBN.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 6 | 7 | Copyright (C) 2024 Ziheng Chen 8 | All rights reserved. 
9 | """ 10 | 11 | import hydra 12 | from omegaconf import DictConfig 13 | 14 | import warnings 15 | from sklearn.exceptions import FitFailedWarning, ConvergenceWarning 16 | 17 | from library.utils.hydra import hydra_helpers 18 | from LieBN_utilities.Training import training 19 | 20 | warnings.filterwarnings("ignore", category=FitFailedWarning) 21 | warnings.filterwarnings("ignore", category=ConvergenceWarning) 22 | warnings.filterwarnings("ignore", category=FutureWarning) 23 | warnings.filterwarnings("ignore", category=UserWarning) 24 | warnings.filterwarnings("ignore", category=RuntimeWarning) 25 | 26 | class Args: 27 | """ a Struct Class """ 28 | pass 29 | args=Args() 30 | args.config_name='LieBN.yaml' 31 | 32 | @hydra_helpers 33 | @hydra.main(config_path='./conf/', config_name=args.config_name, version_base='1.1') 34 | # @hydra.main(config_path='./conf/', config_name=args.config_name) 35 | def main(cfg: DictConfig): 36 | training(cfg,args) 37 | 38 | if __name__ == '__main__': 39 | 40 | main() 41 | -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/LieBN.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dataset: hinss2021 4 | - evaluation: inter-session+uda 5 | - preprocessing: bb4-36Hz 6 | - nnet: tsmnet_spddsmbn 7 | - override hydra/launcher: joblib 8 | fit: 9 | stratified: True 10 | epochs: 50 11 | batch_size_train: 50 12 | domains_per_batch: 5 13 | batch_size_test: -1 14 | validation_size: 0.2 #0.1 # float <1 for fraction; int for specific number 15 | test_size: 0.05 # percent of groups/domains used for testing 16 | 17 | saving_model: 18 | is_save_model: False 19 | is_save_results_each_fold: True 20 | 21 | score: balanced_accuracy # sklearn scores 22 | device: GPU 23 | threadnum: 1 24 | data_dir: /data 25 | is_debug: False 26 | seed: 42 27 | is_timing: False 28 | 29 | hydra: 30 | run: 31 | dir: outputs/${dataset.name}/${evaluation.strategy} 32 | sweep: 33 | dir: outputs/${dataset.name}/${evaluation.strategy} 34 | subdir: '.' 35 | launcher: 36 | n_jobs: -1 37 | job_logging: 38 | handlers: 39 | file: 40 | class: logging.FileHandler 41 | filename: default.log 42 | # job: 43 | # num_threads: 2 44 | -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/dataset/hinss2021.yaml: -------------------------------------------------------------------------------- 1 | # name of the dataset 2 | name: Hinss2021 3 | # python type (and parameters) 4 | type: 5 | _target_: datasetio.eeg.moabb.Hinss2021 6 | 7 | classes: ["easy", "medium", "difficult"] 8 | # channel selection 9 | channels: ['FP1', 'FP2', 'FPz', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'F2', 'F4', 'F6', 'F8', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C3', 'C4', 'CPz', 'PO3', 'PO4', 'POz', 'Oz' ] 10 | # if 'null' or not defined, all available channels will be used 11 | # resampling 12 | # if 'null' or not defined, the datasets sampling frequency will be used 13 | # resample: 250 # Hz 14 | ## epoching (relative to TASK CUE onset, as defined in the dataset) 15 | tmin: 0.0 16 | tmax: 1.996 17 | -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/evaluation/inter-session+uda.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - inter-session 3 | - _self_ 4 | adapt: 5 | name: uda 6 | nadapt_domain: 1. 
# int -> absolute number of observations per CLASS -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/evaluation/inter-subject+uda.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - inter-subject 3 | - _self_ 4 | adapt: 5 | name: uda 6 | nadapt_domain: 1. # int -> absolute number of observations per CLASS -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/nnet/tsmnet.yaml: -------------------------------------------------------------------------------- 1 | name: TSMNet 2 | inputtype: ${torchdtype:float32} 3 | model: 4 | _target_: spdnets.models.TSMNet 5 | temporal_filters: 4 6 | spatial_filters: 40 7 | subspacedims: 20 8 | learn_mean: False # Following SPDDSMBN, we set bias=I 9 | bnorm: null 10 | bnorm_dispersion: null 11 | optimizer: 12 | _target_: geoopt.optim.RiemannianAdam 13 | amsgrad: True 14 | weight_decay: 1e-4 15 | lr: 1e-3 16 | param_groups: 17 | - 18 | - 'spdnet.*.W' 19 | - weight_decay: 0 20 | scheduler: 21 | _target_: spdnets.batchnorm.DummyScheduler -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/nnet/tsmnet_LieBN.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tsmnet 3 | - _self_ 4 | name: TSMNet+LieBN 5 | model: 6 | bnorm: LieBN 7 | bnorm_dispersion: SCALAR 8 | metric: AIM 9 | theta: 1.0 10 | alpha: 1.0 11 | beta: 0.0 12 | scheduler: 13 | _target_: spdnets.Liebatchnorm.MomentumBatchNormScheduler 14 | epochs: ${sub:${fit.epochs},10} 15 | bs: ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}} 16 | bs0: ${fit.batch_size_train} 17 | tau0: 0.85 18 | 19 | optimizer: 20 | param_groups: 21 | - 22 | - 'spd*.mean' 23 | - weight_decay: 0 24 | - 25 | - 'spdnet.*.W' 26 | - weight_decay: 0 -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/nnet/tsmnet_spddsmbn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tsmnet 3 | - _self_ 4 | name: TSMNet+SPDDSMBN 5 | model: 6 | bnorm: spddsbn 7 | bnorm_dispersion: SCALAR 8 | scheduler: 9 | _target_: spdnets.batchnorm.MomentumBatchNormScheduler 10 | epochs: ${sub:${fit.epochs},10} 11 | bs: ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}} 12 | bs0: ${fit.batch_size_train} 13 | tau0: 0.85 14 | 15 | optimizer: 16 | param_groups: 17 | - 18 | - 'spd*.mean' 19 | - weight_decay: 0 20 | - 21 | - 'spdnet.*.W' 22 | - weight_decay: 0 -------------------------------------------------------------------------------- /LieBN_TSMNet/conf/preprocessing/bb4-36Hz.yaml: -------------------------------------------------------------------------------- 1 | bb4-36Hz: 2 | _target_: library.utils.moabb.CachedMotorImagery 3 | fmin: 4 # Hz 4 | fmax: 36 # Hz 5 | events: ${dataset.classes} 6 | channels: ${oc.select:dataset.channels} 7 | resample: ${oc.select:dataset.resample} 8 | tmin: ${oc.select:dataset.tmin, 0.0} 9 | tmax: ${oc.select:dataset.tmax} 10 | -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/__init__.py -------------------------------------------------------------------------------- 
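A note on the ${sub:...} and ${rdiv:...} expressions in the two scheduler configs above (tsmnet_LieBN.yaml and tsmnet_spddsmbn.yaml): they are custom OmegaConf resolvers registered in library/utils/hydra/__init__.py (shown later in this listing). As a minimal sketch, assuming the defaults from conf/LieBN.yaml (epochs: 50, batch_size_train: 50, domains_per_batch: 5), the MomentumBatchNormScheduler arguments resolve as follows; the values below are taken from the config files, not hard-coded anywhere in the repository:

# Illustrative sketch of the resolver arithmetic used in the scheduler configs.
epochs, batch_size_train, domains_per_batch = 50, 50, 5  # assumed from conf/LieBN.yaml

scheduler_kwargs = dict(
    epochs=epochs - 10,                        # ${sub:${fit.epochs},10}        -> 40
    bs=batch_size_train / domains_per_batch,   # ${rdiv:...}                    -> 10.0 observations per domain
    bs0=batch_size_train,                      # ${fit.batch_size_train}        -> 50
    tau0=0.85,
)
print(scheduler_kwargs)  # {'epochs': 40, 'bs': 10.0, 'bs0': 50, 'tau0': 0.85}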
/LieBN_TSMNet/datasetio/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/__init__.py -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/__init__.py: -------------------------------------------------------------------------------- 1 | from moabb.datasets.bnci import BNCI2014001 as moabbBNCI2014001 2 | from moabb.datasets.bnci import BNCI2015001 as moabbBNCI2015001 3 | from moabb.datasets import Lee2019_MI as moabbLee2019 4 | 5 | from .base import PreprocessedDataset, CachableDatase 6 | from .stieger2021 import Stieger2021 7 | from .hinss2021 import Hinss2021 8 | 9 | class BNCI2014001(moabbBNCI2014001, CachableDatase): 10 | 11 | def _get_single_subject_data(self, subject): 12 | """return data for a single subject""" 13 | sessions = super()._get_single_subject_data(subject=subject) 14 | map = dict(session_T=1,session_E=2) 15 | sessions = dict([(map[k],v) for k, v in sessions.items()]) 16 | return sessions 17 | 18 | class BNCI2015001(moabbBNCI2015001, CachableDatase): 19 | 20 | def _get_single_subject_data(self, subject): 21 | """return data for a single subject""" 22 | sessions = super()._get_single_subject_data(subject=subject) 23 | map = dict(session_A=1,session_B=2,session_C=3) 24 | sessions = dict([(map[k],v) for k, v in sessions.items()]) 25 | return sessions 26 | 27 | class Lee2019(moabbLee2019, CachableDatase): 28 | 29 | def _get_single_subject_data(self, subject): 30 | """return data for a single subject""" 31 | sessions = super()._get_single_subject_data(subject=subject) 32 | map = dict(session_1=1,session_2=2) 33 | sessions = dict([(map[k],v) for k, v in sessions.items()]) 34 | return sessions 35 | -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- 
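The three dataset wrappers in datasetio/eeg/moabb/__init__.py above all follow one pattern: subclass the upstream moabb dataset together with CachableDatase, then remap the string session keys returned by moabb to integer session ids. A hypothetical wrapper for another two-session moabb dataset would look like this (SomeMoabbDataset and its session names are placeholders, not part of the repository):

# Hypothetical example of the wrapper pattern; SomeMoabbDataset is assumed,
# not a dataset shipped with this repository.
class SomeDataset(SomeMoabbDataset, CachableDatase):

    def _get_single_subject_data(self, subject):
        """return data for a single subject, with session names mapped to integer ids"""
        sessions = super()._get_single_subject_data(subject=subject)
        map = dict(session_A=1, session_B=2)
        return dict([(map[k], v) for k, v in sessions.items()])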
/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/hinss2021.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/hinss2021.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/stieger2021.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/stieger2021.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import mne 3 | import json 4 | 5 | from moabb.datasets.base import BaseDataset 6 | 7 | class CachableDatase(BaseDataset): 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | def __repr__(self) -> str: 13 | return json.dumps({self.__class__.__name__: self.__dict__}) 14 | 15 | class PreprocessedDataset(CachableDatase): 16 | 17 | def __init__(self, *args, channels : Optional[list] = None, srate : Optional[int] = None, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.channels = channels 20 | self.srate = srate 21 | 22 | def preprocess(self, raw): 23 | 24 | # find the events, first check stim_channels 25 | if len(mne.pick_types(raw.info, stim=True)) > 0: 26 | events = mne.find_events(raw, shortest_event=0, verbose=False) 27 | else: 28 | events = None # the dataset already uses annotations 29 | 30 | # optional resampling 31 | if self.srate is not None: 32 | ret = raw.resample(self.srate, events=events) 33 | raw, events = (ret, events) if events is None else (ret[0], ret[1]) 34 | 35 | # convert optional events to annotations (before we discard the stim channels) 36 | if events is not None: 37 | rev_event_it = dict(zip(self.event_id.values(), self.event_id.keys())) 38 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=rev_event_it) 39 | raw.set_annotations(annot) 40 | 41 | # pick subset of all channels 42 | if self.channels is not None: 43 | raw.pick_channels(self.channels) 44 | else: 45 | raw.pick_types(eeg=True) 46 | 47 | return raw 48 | 49 | def __repr__(self) -> str: 50 | return json.dumps({self.__class__.__name__: self.__dict__}) 51 | -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/hinss2021.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mne 3 | import numpy as np 4 | import pooch 5 | import logging 6 | import requests 7 | import json 8 | import subprocess 9 | import re 10 | import glob 11 | 12 | from scipy.io import loadmat 13 | import moabb.datasets.download as dl 14 | 15 | from .base import PreprocessedDataset 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | 21 | def doi_to_url(doi, api_url = lambda x : f"https://doi.org/api/handles/{x}?type=URL"): 22 | 23 | url = None 24 | headers = {"Content-Type": "application/json"} 25 | response_data = dl.fs_issue_request("GET", api_url(doi), headers=headers) 26 | 27 | if 'values' in response_data: 28 | candidates = [ val['data']['value'] for val in response_data['values'] if 'data' in val and 
isinstance(val['data'], dict) and 'value' in val['data']] 29 | url = candidates[0] if len(candidates) > 0 else None 30 | 31 | return url 32 | 33 | 34 | 35 | def url_get_json(url: str): 36 | 37 | headers = {"Content-Type": "application/json"} 38 | response = dl.fs_issue_request("GET", url, headers=headers) 39 | return response 40 | 41 | 42 | 43 | class Hinss2021(PreprocessedDataset): 44 | 45 | ZENODO_JSON_API_URL = lambda x : f"https://zenodo.org/api/{x}" 46 | 47 | TASK_TO_EVENTID = dict(RS='rest', MATBeasy='easy', MATBmed='medium', MATBdiff='difficult') 48 | 49 | def __init__(self, interval = [0, 2], channels = None, srate = None): 50 | super().__init__( 51 | subjects=list(range(1, 15+1)), 52 | sessions_per_subject=2, 53 | events=dict(easy=1, medium=2, difficult=3, rest=4), 54 | code="Hinss2021", 55 | interval=interval, 56 | paradigm="imagery", 57 | doi="10.5281/zenodo.4917217", 58 | channels=channels, 59 | srate=srate 60 | ) 61 | 62 | 63 | 64 | def preprocess(self, raw): 65 | # interpolate channels marked as bad 66 | if len(raw.info['bads']) > 0: 67 | raw.interpolate_bads() 68 | return super().preprocess(raw) 69 | 70 | def data_path( 71 | self, subject, path=None, force_update=False, update_path=None, verbose=None 72 | ): 73 | if subject not in self.subject_list: 74 | raise ValueError("Invalid subject number") 75 | 76 | key_dest = f"MNE-{self.code:s}-data" 77 | path = os.path.join(dl.get_dataset_path(self.code, path), key_dest) 78 | 79 | url = doi_to_url(self.doi) 80 | if url is None: 81 | raise ValueError("Could not find zenodo id based on dataset DOI!") 82 | 83 | zenodoid = url.split('/')[-1] 84 | 85 | metadata = url_get_json(Hinss2021.ZENODO_JSON_API_URL(f"records/{zenodoid}")) 86 | 87 | fnames = [] 88 | for record in metadata['files']: 89 | 90 | fname = record['key'] 91 | fpath = os.path.join(path, fname) 92 | 93 | 94 | # metadata 95 | # if record['type'] != 'zip' and not os.path.exists(fpath): # subject data 96 | # pooch.retrieve(record['links']['self'], record['checksum'], fname, path, downloader=pooch.HTTPDownloader(progressbar=True)) 97 | # subject specific data 98 | if record['type'] == 'zip' and fname == f"P{subject:02d}.zip": 99 | if not os.path.exists(fpath): 100 | files = pooch.retrieve(record['links']['self'], record['checksum'], fname, path, 101 | processor=pooch.Unzip(), 102 | downloader=pooch.HTTPDownloader(progressbar=True)) 103 | 104 | # load the data 105 | tasks = list(Hinss2021.TASK_TO_EVENTID.keys()) 106 | taskpattern = '(' + '|'.join(tasks) + ')' 107 | pattern = f'{fpath}.unzip/P{subject:02d}/S?/eeg/alldata_*.set' 108 | candidates = glob.glob(pattern, recursive=True) 109 | fnames += [c for c in candidates if re.search(f'.*{taskpattern}.set', c)] 110 | 111 | return fnames 112 | 113 | 114 | def _get_single_subject_data(self, subject): 115 | fnames = self.data_path(subject) 116 | 117 | subject_data = {} 118 | for fn in fnames: 119 | meta = re.search('alldata_sbj(?P<subject>\d\d)_sess(?P<session>\d)_((?P<task>\w+))', 120 | os.path.basename(fn))  # only the 'session' group is read below 121 | sid = int(meta['session']) 122 | 123 | if sid not in range(1,self.n_sessions+1): 124 | continue 125 | 126 | epochs = mne.io.read_epochs_eeglab(fn, verbose=False) 127 | assert(len(epochs.event_id) == 1) 128 | event_id = Hinss2021.TASK_TO_EVENTID[list(epochs.event_id.keys())[0]] 129 | epochs.event_id = {event_id : self.event_id[event_id]} 130 | epochs.events[:,2] = epochs.event_id[event_id] 131 | 132 | # convert to a continuous raw object with correct annotations 133 | continuous_data =
np.swapaxes(epochs.get_data(),0,1).reshape((len(epochs.info['chs']),-1)) 134 | raw = mne.io.RawArray(data=continuous_data, info=epochs.info, verbose=False, first_samp=1) 135 | # XXX use standard electrode layout rather than invidividual positions 136 | # raw.set_montage(epochs.get_montage()) 137 | raw.set_montage('standard_1005') 138 | events = epochs.events.copy() 139 | evt_desc = dict(zip(epochs.event_id.values(),epochs.event_id.keys())) 140 | 141 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=evt_desc, first_samp=1) 142 | 143 | raw.set_annotations(annot) 144 | 145 | if sid in subject_data: 146 | subject_data[sid][0].append(raw) 147 | else: 148 | subject_data[sid] = {0 : raw} 149 | 150 | # discard boundary annotations 151 | keep = [i for i, desc in enumerate(subject_data[sid][0].annotations.description) if desc in self.event_id] 152 | subject_data[sid][0].set_annotations(subject_data[sid][0].annotations[keep]) 153 | 154 | return subject_data 155 | -------------------------------------------------------------------------------- /LieBN_TSMNet/datasetio/eeg/moabb/stieger2021.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mne 3 | import numpy as np 4 | import pooch 5 | import logging 6 | 7 | from scipy.io import loadmat 8 | import moabb.datasets.download as dl 9 | 10 | from .base import PreprocessedDataset 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | class Stieger2021(PreprocessedDataset): 15 | 16 | BASE_URL = "https://ndownloader.figshare.com/files/" 17 | 18 | def __init__(self, interval = [0, 3], channels = None, srate = None, sessions = None): 19 | super().__init__( 20 | subjects=list(range(1, 63)), 21 | sessions_per_subject=11, 22 | events=dict(right_hand=1, left_hand=2, both_hand=3, rest=4), 23 | code="Stieger2021", 24 | interval=interval, 25 | paradigm="imagery", 26 | doi="10.1038/s41597-021-00883-1", 27 | channels=channels, 28 | srate=srate 29 | ) 30 | 31 | self.sessions = sessions 32 | self.figshare_id = 13123148 # id on figshare 33 | 34 | assert(interval[0] >= 0.00) # the interval has to start after the cue onset 35 | 36 | 37 | def preprocess(self, raw): 38 | # interpolate channels marked as bad 39 | if len(raw.info['bads']) > 0: 40 | raw.interpolate_bads() 41 | return super().preprocess(raw) 42 | 43 | def data_path( 44 | self, subject, path=None, force_update=False, update_path=None, verbose=None 45 | ): 46 | if subject not in self.subject_list: 47 | raise (ValueError("Invalid subject number")) 48 | 49 | key_dest = f"MNE-{self.code:s}-data" 50 | path = os.path.join(dl.get_dataset_path(self.code, path), key_dest) 51 | 52 | filelist = dl.fs_get_file_list(self.figshare_id) 53 | reg = dl.fs_get_file_hash(filelist) 54 | fsn = dl.fs_get_file_id(filelist) 55 | 56 | spath = [] 57 | for f in fsn.keys(): 58 | if ".mat" not in f: 59 | continue 60 | sbj = int(f.split('_')[0][1:]) 61 | ses = int(f.split('_')[-1].split('.')[0]) 62 | if sbj == subject and ses in self.sessions: 63 | fpath = os.path.join(path, f) 64 | if not os.path.exists(fpath): 65 | pooch.retrieve(Stieger2021.BASE_URL + fsn[f], reg[fsn[f]], f, path, downloader=pooch.HTTPDownloader(progressbar=True)) 66 | spath.append(fpath) 67 | return spath 68 | 69 | def _get_single_subject_data(self, subject): 70 | fname = self.data_path(subject) 71 | 72 | subject_data = {} 73 | 74 | for fn in fname: 75 | 76 | session = int(os.path.basename(fn).split('_')[2].split('.')[0]) 77 | 78 | if self.sessions is not None: 79 | if session not in 
set(self.sessions): 80 | continue 81 | 82 | container = loadmat( 83 | fn, 84 | squeeze_me=True, 85 | struct_as_record=False, 86 | verify_compressed_data_integrity=False, 87 | )['BCI'] 88 | 89 | srate = container.SRATE 90 | 91 | eeg_ch_names = container.chaninfo.label.tolist() 92 | # adjust naming convention 93 | eeg_ch_names = [ch.replace('Z', 'z').replace('FP','Fp') for ch in eeg_ch_names] 94 | # extract all standard EEG channels 95 | montage = mne.channels.make_standard_montage("standard_1005") 96 | channel_mask = np.isin(eeg_ch_names, montage.ch_names) 97 | ch_names = [ch for ch, found in zip(eeg_ch_names, channel_mask) if found] + ["Stim"] 98 | ch_types = ["eeg"] * channel_mask.sum() + ["stim"] 99 | 100 | X_flat = [] 101 | stim_flat = [] 102 | for i in range(container.data.shape[0]): 103 | x = container.data[i][channel_mask,:] 104 | y = container.TrialData[i].targetnumber 105 | stim = np.zeros_like(container.time[i]) 106 | if container.TrialData[i].artifact == 0 and (container.TrialData[i].triallength + 2) > self.interval[1]: 107 | assert(container.time[i][2*srate] == 0) # this should be the cue time-point 108 | stim[2*srate] = y 109 | X_flat.append(x) 110 | stim_flat.append(stim[None,:]) 111 | 112 | X_flat = np.concatenate(X_flat, axis=1) 113 | stim_flat = np.concatenate(stim_flat, axis=1) 114 | 115 | p_keep = np.flatnonzero(stim_flat).shape[0]/container.data.shape[0] 116 | 117 | message = f'Record {subject}/{session} (subject/session): rejecting {(1 - p_keep)*100:.0f}% of the trials.' 118 | if p_keep < 0.5: 119 | log.warning(message) 120 | else: 121 | log.info(message) 122 | 123 | eeg_data = np.concatenate([X_flat * 1e-6, stim_flat], axis=0) 124 | 125 | info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=srate) 126 | raw = mne.io.RawArray(data=eeg_data, info=info, verbose=False) 127 | raw.set_montage(montage) 128 | if isinstance(container.chaninfo.noisechan, int): 129 | badchanidxs = [container.chaninfo.noisechan] 130 | elif isinstance(container.chaninfo.noisechan, np.ndarray): 131 | badchanidxs = container.chaninfo.noisechan 132 | else: 133 | badchanidxs = [] 134 | 135 | for idx in badchanidxs: 136 | # badchan = [eeg_ch_names[idx-1] for idx in badchanidxs] 137 | used_channels = ch_names if self.channels is None else self.channels 138 | if eeg_ch_names[idx-1] in used_channels: 139 | raw.info['bads'].append(eeg_ch_names[idx-1]) 140 | 141 | if len(raw.info['bads']) > 0: 142 | log.info(f'Record {subject}/{session} (subject/session): bad channels that will be interpolated: {raw.info["bads"]}') 143 | 144 | # subject_data[session] = {"run_0": self._common_prep(raw)}\ 145 | subject_data[session] = {"run_0": self.preprocess(raw)} 146 | return subject_data 147 | -------------------------------------------------------------------------------- /LieBN_TSMNet/experiments_Hinss21.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_dir=/data #change this to your data folder 3 | 4 | ### Experiments on SPDDSMBN 5 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda,inter-session+uda nnet=tsmnet_spddsmbn 6 | ### Experiments on DSMLieBN under standard metrics 7 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda,inter-session+uda nnet=tsmnet_LieBN nnet.model.metric=LEM,LCM,AIM 8 | 9 | ### Experiments on DSMLieBN under 0.5-LCM for inter-session 10 | [ $? 
-eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-session+uda nnet=tsmnet_LieBN nnet.model.metric=LCM nnet.model.theta=0.5 11 | ### Experiments on DSMLieBN under -0.5-AIM for inter-subject 12 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda nnet=tsmnet_LieBN nnet.model.metric=AIM nnet.model.theta=-0.5 -------------------------------------------------------------------------------- /LieBN_TSMNet/library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/__init__.py -------------------------------------------------------------------------------- /LieBN_TSMNet/library/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/__init__.py -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/hydra/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | from hydra.core.hydra_config import HydraConfig 4 | from omegaconf import OmegaConf, DictConfig 5 | import torch 6 | 7 | from sklearn.pipeline import Pipeline 8 | 9 | def hydra_helpers(func): 10 | def inner(*args, **kwargs): 11 | # setup helpers 12 | 13 | # omega conf helpers 14 | OmegaConf.register_new_resolver("len", lambda x:len(x), replace=True) 15 | OmegaConf.register_new_resolver("add", lambda x,y:x+y, replace=True) 16 | OmegaConf.register_new_resolver("sub", lambda x,y:x-y, replace=True) 17 | OmegaConf.register_new_resolver("mul", lambda x,y:x*y, replace=True) 18 | OmegaConf.register_new_resolver("rdiv", lambda x,y:x/y, replace=True) 19 | 20 | STR2TORCHDTYPE = { 21 | 'float32': torch.float32, 22 | 'float64': torch.float64, 23 | 'double': torch.double, 24 | } 25 | OmegaConf.register_new_resolver("torchdtype", lambda x:STR2TORCHDTYPE[x], replace=True) 26 | # if func is not None and list(kwargs.keys())[0] !='args': 27 | if func is not None: 28 | # func(args = kwargs['args']) 29 | func(*args, **kwargs) 30 | return inner 31 | 32 | 33 | def make_sklearn_pipeline(steps_config) -> Pipeline: 34 | 35 | steps = [] 36 | for step_config in steps_config: 37 | 38 | # retrieve the name and parameter dictionary of the current steps 39 | step_name, step_transform = next(iter(step_config.items())) 40 | # instantiate the pipeline step, and append to the list of steps 41 | if isinstance(step_transform, DictConfig): 42 | pipeline_step = (step_name, 
hydra.utils.instantiate(step_transform, _convert_='partial')) 43 | else: 44 | pipeline_step = (step_name, step_transform) 45 | steps.append(pipeline_step) 46 | 47 | return Pipeline(steps) 48 | -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/hydra/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/hydra/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/moabb/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import json 4 | import os 5 | import hashlib 6 | import numpy as np 7 | import pandas as pd 8 | import mne 9 | 10 | from sklearn.base import BaseEstimator 11 | from moabb.paradigms.base import BaseParadigm 12 | from moabb.paradigms.motor_imagery import FilterBankMotorImagery, MotorImagery 13 | from mne import get_config, set_config 14 | from mne.datasets.utils import _get_path 15 | from mne.io import read_info, write_info 16 | from skorch import dataset 17 | 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | class CachedParadigm(BaseParadigm): 22 | 23 | def _get_string_rep(self, obj): 24 | if issubclass(type(obj), BaseEstimator): 25 | str_repr = repr(obj.get_params()) 26 | else: 27 | str_repr = repr(obj) 28 | str_no_addresses = re.sub("0x[a-z0-9]*", "0x__", str_repr) 29 | return str_no_addresses.replace("\n", "") 30 | 31 | def _get_rep(self, dataset): 32 | return self._get_string_rep(dataset) + '\n' + self._get_string_rep(self) 33 | 34 | def _get_cache_dir(self, rep): 35 | if get_config("MNEDATASET_TMP_DIR") is None: 36 | set_config("MNEDATASET_TMP_DIR", os.path.join(os.path.expanduser("~"), "mne_data")) 37 | base_dir = _get_path(None, "MNEDATASET_TMP_DIR", "preprocessed") 38 | 39 | digest = hashlib.sha1(rep.encode("utf8")).hexdigest() 40 | 41 | cache_dir = os.path.join( 42 | base_dir, 43 | "preprocessed", 44 | digest 45 | ) 46 | return cache_dir 47 | 48 | 49 | def process_raw(self, raw, dataset, return_epochs=False,return_raws=False): 50 | # get events id 51 | event_id = self.used_events(dataset) 52 | 53 | # find the events, first check stim_channels then annotations 54 | stim_channels = mne.utils._get_stim_channel(None, raw.info, 55 | raise_error=False) 56 | if len(stim_channels) > 0: 57 | events = mne.find_events(raw, shortest_event=0, verbose=False) 58 | else: 59 | events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False) 60 | 61 | # picks channels 62 | if self.channels is None: 63 | picks = mne.pick_types(raw.info, eeg=True, stim=False) 64 | else: 65 | picks = mne.pick_types(raw.info, stim=False, include=self.channels) 66 | 67 | # pick events, based on event_id 68 | try: 69 | events = mne.pick_events(events, include=list(event_id.values())) 70 | except RuntimeError: 71 | # skip raw if no event found 72 | return 73 | 74 | # get interval 75 | tmin = self.tmin + dataset.interval[0] 76 | if self.tmax is None: 77 | tmax = dataset.interval[1] 78 | else: 79 | tmax = self.tmax + dataset.interval[0] 80 | 81 | X = [] 82 | for bandpass in self.filters: 83 | fmin, fmax = bandpass 84 | # filter data 85 | if fmin is None and fmax is None: 86 | raw_f = raw 87 | else: 88 | raw_f = raw.copy().filter(fmin, fmax, method='iir', 89 | picks=picks, verbose=False) 
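# At this point raw_f holds the band-passed signal for the current
# filter-bank entry. With the bb4-36Hz preprocessing config above,
# self.filters contains the single band (4, 36) Hz, so this loop runs once
# and X stays a 3D array (trials x channels x samples); with k bands the
# per-band arrays are stacked into a 4D array (trials x channels x samples x k),
# see the len(self.filters) == 1 branch below.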
90 | # epoch data 91 | epochs = mne.Epochs(raw_f, events, event_id=event_id, 92 | tmin=tmin, tmax=tmax, proj=False, 93 | baseline=None, preload=True, 94 | verbose=False, picks=picks, 95 | event_repeated='drop', 96 | on_missing='ignore') 97 | if self.resample is not None: 98 | epochs = epochs.resample(self.resample) 99 | # rescale to work with uV 100 | if return_epochs: 101 | X.append(epochs) 102 | else: 103 | X.append(dataset.unit_factor * epochs.get_data()) 104 | 105 | inv_events = {k: v for v, k in event_id.items()} 106 | labels = np.array([inv_events[e] for e in epochs.events[:, -1]]) 107 | 108 | # if only one band, return a 3D array, otherwise return a 4D 109 | if len(self.filters) == 1: 110 | X = X[0] 111 | else: 112 | X = np.array(X).transpose((1, 2, 3, 0)) 113 | 114 | metadata = pd.DataFrame(index=range(len(labels))) 115 | return X, labels, metadata 116 | 117 | 118 | def get_data(self, dataset, subjects=None, return_epochs=False): 119 | 120 | if return_epochs: 121 | raise ValueError("Only return_epochs=False is supported.") 122 | 123 | rep = self._get_rep(dataset) 124 | cache_dir = self._get_cache_dir(rep) 125 | os.makedirs(cache_dir, exist_ok=True) 126 | 127 | X = [] if return_epochs else np.array([]) 128 | labels = [] 129 | metadata = pd.Series([]) 130 | 131 | if subjects is None: 132 | subjects = dataset.subject_list 133 | 134 | if not os.path.isfile(os.path.join(cache_dir, 'repr.json')): 135 | with open(os.path.join(cache_dir, 'repr.json'), 'w+') as f: 136 | f.write(self._get_rep(dataset)) 137 | 138 | for subject in subjects: 139 | if not os.path.isfile(os.path.join(cache_dir, f'{subject}.npy')): 140 | # compute 141 | x, lbs, meta = super().get_data(dataset, [subject], return_epochs) 142 | np.save(os.path.join(cache_dir, f'{subject}.npy'), x) 143 | meta['label'] = lbs 144 | meta.to_csv(os.path.join(cache_dir, f'{subject}.csv'), index=False) 145 | log.info(f'saved cached data in directory {cache_dir}') 146 | 147 | # load from cache 148 | log.info(f'loading cached data from directory {cache_dir}') 149 | x = np.load(os.path.join(cache_dir, f'{subject}.npy'), mmap_mode ='r') 150 | meta = pd.read_csv(os.path.join(cache_dir, f'{subject}.csv')) 151 | lbs = meta['label'].tolist() 152 | 153 | if return_epochs: 154 | X.append(x) 155 | else: 156 | X = np.append(X, x, axis=0) if len(X) else x 157 | labels = np.append(labels, lbs, axis=0) 158 | metadata = pd.concat([metadata, meta], ignore_index=True) 159 | 160 | return X, labels, metadata 161 | 162 | def get_info(self, dataset): 163 | # check if the info has been saved 164 | rep = self._get_rep(dataset) 165 | cache_dir = self._get_cache_dir(rep) 166 | os.makedirs(cache_dir, exist_ok=True) 167 | info_file = os.path.join(cache_dir, f'raw-info.fif') 168 | if not os.path.isfile(info_file): 169 | x, _, _ = super().get_data(dataset, [dataset.subject_list[0]], True) 170 | info = x.info 171 | write_info(info_file, info) 172 | log.info(f'saved cached info in directory {cache_dir}') 173 | else: 174 | log.info(f'loading cached info from directory {cache_dir}') 175 | info = read_info(info_file) 176 | return info 177 | 178 | def __repr__(self) -> str: 179 | return json.dumps({self.__class__.__name__: self.__dict__}) 180 | 181 | 182 | class CachedMotorImagery(CachedParadigm, MotorImagery): 183 | 184 | def __init__(self, **kwargs): 185 | n_classes = len(kwargs['events']) 186 | super().__init__(n_classes=n_classes, **kwargs) 187 | 188 | 189 | class CachedFilterBankMotorImagery(CachedParadigm, FilterBankMotorImagery): 190 | 191 | def __init__(self, 
**kwargs): 192 | n_classes = len(kwargs['events']) 193 | super().__init__(n_classes=n_classes, **kwargs) 194 | 195 | -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/moabb/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/moabb/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/pyriemann/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import eigvalsh 3 | from pyriemann.tangentspace import TangentSpace 4 | 5 | def squared_airm(A, B): 6 | return np.square(np.log(eigvalsh(A, B))).sum() 7 | 8 | def airm(A,B): 9 | return np.sqrt(squared_airm(A,B)) 10 | 11 | def geom_mean(As): 12 | ts = TangentSpace() 13 | ts.fit(As) 14 | return ts.reference_ 15 | 16 | def tsm(As): 17 | ts = TangentSpace() 18 | ts.fit(As) 19 | return ts.transform(As) -------------------------------------------------------------------------------- /LieBN_TSMNet/library/utils/torch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/torch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/BaseBatchNorm.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Enum 3 | import torch.nn as nn 4 | 5 | class BatchNormDispersion(Enum): 6 | NONE = 'mean' 7 | SCALAR = 'scalar' 8 | VECTOR = 'vector' 9 | class BatchNormTestStatsMode(Enum): 10 | BUFFER = 'buffer' 11 | REFIT = 'refit' 12 | ADAPT = 'adapt' 13 | class BatchNormTestStatsInterface: 14 | def set_test_stats_mode(self, mode: BatchNormTestStatsMode): 15 | pass 16 | 17 | class BaseBatchNorm(nn.Module, BatchNormTestStatsInterface): 18 | def __init__(self, eta=1.0, eta_test=0.1, test_stats_mode: BatchNormTestStatsMode = BatchNormTestStatsMode.BUFFER): 19 | super().__init__() 20 | self.eta = eta 21 | self.eta_test = eta_test 22 | self.test_stats_mode = test_stats_mode 23 | 24 | def set_test_stats_mode(self, mode: BatchNormTestStatsMode): 25 | self.test_stats_mode = mode 26 | 27 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/LieBNImpl.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziheng Chen 3 | Please cite the paper below if you use the code: 4 | 5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024 6 | 7 | Copyright (C) 2024 Ziheng Chen 8 | All rights reserved. 9 | """ 10 | 11 | from builtins import NotImplementedError 12 | from typing import Tuple 13 | import torch as th 14 | import torch.nn as nn 15 | 16 | from geoopt.tensor import ManifoldParameter, ManifoldTensor 17 | from .manifolds import SymmetricPositiveDefinite 18 | from . import functionals 19 | from spdnets.BaseBatchNorm import BaseBatchNorm,BatchNormDispersion,BatchNormTestStatsMode 20 | 21 | 22 | class SPDLieBatchNormImpl(BaseBatchNorm): 23 | def __init__(self, shape: Tuple[int, ...] 
or th.Size, batchdim: int, 24 | eta=1., eta_test=0.1, 25 | karcher_steps: int = 1, learn_mean=False, learn_std=True, 26 | dispersion: BatchNormDispersion = BatchNormDispersion.SCALAR, 27 | eps=1e-5, mean=None, std=None, 28 | metric='AIM',theta: float =1.,alpha: float=1.0,beta: float=0., 29 | **kwargs): 30 | super().__init__(eta, eta_test) 31 | if metric != "AIM" and metric != "LCM" and metric != "LEM" : 32 | raise Exception('unknown metric {}'.format(metric)) 33 | self.metric = metric;self.theta = th.tensor(theta);self.alpha = th.tensor(alpha);self.beta = th.tensor(beta) 34 | self.identity = th.eye(shape[-1]) 35 | 36 | # the last two dimensions are used for SPD manifold 37 | assert (shape[-1] == shape[-2]) 38 | 39 | if dispersion == BatchNormDispersion.VECTOR: 40 | raise NotImplementedError() 41 | 42 | self.dispersion = dispersion 43 | self.learn_mean = learn_mean 44 | self.learn_std = learn_std 45 | self.batchdim = batchdim 46 | self.karcher_steps = karcher_steps 47 | self.eps = eps 48 | 49 | init_mean = th.diag_embed(th.ones(shape[:-1], **kwargs)) 50 | init_var = th.ones((*shape[:-2], 1), **kwargs) 51 | 52 | self.register_buffer('running_mean', ManifoldTensor(init_mean, 53 | manifold=SymmetricPositiveDefinite())) 54 | self.register_buffer('running_var', init_var) 55 | self.register_buffer('running_mean_test', ManifoldTensor(init_mean, 56 | manifold=SymmetricPositiveDefinite())) 57 | self.register_buffer('running_var_test', init_var) 58 | 59 | if mean is not None: 60 | self.mean = mean 61 | else: 62 | if self.learn_mean: 63 | self.mean = ManifoldParameter(init_mean.clone(), manifold=SymmetricPositiveDefinite()) 64 | else: 65 | self.mean = ManifoldTensor(init_mean.clone(), manifold=SymmetricPositiveDefinite()) 66 | 67 | if self.dispersion is not BatchNormDispersion.NONE: 68 | if std is not None: 69 | self.std = std 70 | else: 71 | if self.learn_std: 72 | self.std = nn.parameter.Parameter(init_var.clone()) 73 | else: 74 | self.std = init_var.clone() 75 | 76 | @th.no_grad() 77 | def initrunningstats(self, X): 78 | self.running_mean.data = self.cal_geom_mean(X) 79 | self.running_mean_test.data = self.running_mean.data.clone() 80 | 81 | if self.dispersion is BatchNormDispersion.SCALAR: 82 | self.running_var.data = self.cal_geom_var(X, self.running_mean) 83 | self.running_var_test.data = self.running_var.data.clone() 84 | 85 | def forward(self, X): 86 | X_deformed = self.deformation(X); 87 | if self.learn_mean: 88 | weight = self.deformation(self.mean) 89 | 90 | if self.training: 91 | # compute batch mean 92 | batch_mean = self.cal_geom_mean(X_deformed) 93 | # update the running_mean, running_mean_test 94 | self.updating_running_means(batch_mean) 95 | # compute batch vars w.r.t. 
running_mean and running_mean_test, update the running_var and running_var_test 96 | if self.dispersion is not BatchNormDispersion.NONE: 97 | batch_var = self.cal_geom_var(X_deformed,self.running_mean) 98 | batch_var_test = self.cal_geom_var(X_deformed, self.running_mean_test) 99 | self.updating_running_var(batch_var,batch_var_test) 100 | else: 101 | if self.test_stats_mode == BatchNormTestStatsMode.BUFFER: 102 | pass # nothing to do: use the ones in the buffer 103 | elif self.test_stats_mode == BatchNormTestStatsMode.REFIT: 104 | self.initrunningstats(X_deformed) 105 | elif self.test_stats_mode == BatchNormTestStatsMode.ADAPT: 106 | raise NotImplementedError() 107 | 108 | rm = self.running_mean if self.training else self.running_mean_test 109 | if self.dispersion is BatchNormDispersion.SCALAR: 110 | rv = self.running_var if self.training else self.running_var_test 111 | 112 | # subtracting mean 113 | X_centered = self.Left_translation(X_deformed,rm,True) 114 | # scaling and shifting 115 | if self.dispersion is BatchNormDispersion.SCALAR: 116 | X_scaled = self.scaling(X_centered, rv,self.std) 117 | # biasing 118 | X_normalized = self.Left_translation(X_scaled, weight,False) if self.learn_mean else X_scaled 119 | # inv of deformation 120 | X_new = self.inv_deformation(X_normalized) 121 | 122 | return X_new 123 | 124 | def spd_power(self,X): 125 | if self.theta == 1.: 126 | X_power = X 127 | else: 128 | X_power = functionals.sym_powm.apply(X, self.theta) 129 | return X_power 130 | 131 | def inv_power(self,X): 132 | if self.theta == 1.: 133 | X_power = X 134 | else: 135 | X_power = functionals.sym_powm.apply(X, 1/self.theta) 136 | return X_power 137 | def deformation(self,X): 138 | 139 | if self.metric=='AIM': 140 | X_deformed = self.spd_power(X) 141 | elif self.metric == 'LEM': 142 | X_deformed = functionals.sym_logm.apply(X) 143 | elif self.metric == 'LCM': 144 | X_power = self.spd_power(X) 145 | L = th.linalg.cholesky(X_power) 146 | diag_part = th.diag_embed(th.log(th.diagonal(L, dim1=-2, dim2=-1))) 147 | X_deformed = L.tril(-1) + diag_part 148 | 149 | return X_deformed 150 | 151 | def inv_deformation(self,X): 152 | if self.metric=='AIM': 153 | X_inv_deformed = self.inv_power(X) 154 | elif self.metric == 'LEM': 155 | X_inv_deformed = functionals.sym_expm.apply(X) 156 | elif self.metric == 'LCM': 157 | Cho = X.tril(-1) + th.diag_embed(th.exp(th.diagonal(X, dim1=-2, dim2=-1))) 158 | spd = Cho.matmul(Cho.transpose(-1,-2)) 159 | X_inv_deformed = self.inv_power(spd) 160 | return X_inv_deformed 161 | 162 | def cal_geom_mean(self, X): 163 | """Frechet mean""" 164 | if self.metric == 'AIM': 165 | mean = self.KF_AIM(X.detach()) 166 | elif self.metric == 'LEM' or self.metric == 'LCM': 167 | mean = X.detach().mean(dim=self.batchdim, keepdim=True) 168 | 169 | return mean 170 | def cal_geom_var(self, X, rm): 171 | """Frechet variance w.r.t. 
rm""" 172 | spd = X.detach() 173 | if self.metric == 'AIM': 174 | rm_invsq = functionals.sym_invsqrtm.apply(rm) 175 | if self.beta == 0.: 176 | dists = self.alpha * th.linalg.matrix_norm( 177 | functionals.sym_logm.apply(rm_invsq @ spd @ rm_invsq)).square() 178 | else: 179 | dists = self.alpha * th.linalg.matrix_norm(functionals.sym_logm.apply(rm_invsq @ spd @ rm_invsq)).square()\ 180 | + self.beta * th.logdet(th.linalg.solve(rm,spd)).square() 181 | 182 | elif self.metric == 'LEM' or self.metric == 'LCM': 183 | tmp = spd - rm 184 | if self.beta == 0.: 185 | dists = self.alpha * th.linalg.matrix_norm(tmp).square() 186 | else: 187 | item1 = th.linalg.matrix_norm(tmp) 188 | item2 = functionals.trace(tmp) 189 | dists = self.alpha * item1.square() + self.beta * item2.square() 190 | 191 | var = dists.mean(dim=self.batchdim, keepdim=True) 192 | if self.metric == 'AIM' or self.metric == 'LCM': 193 | var_final = var* (1/(self.theta**2)) 194 | else: 195 | var_final = var 196 | return var_final.unsqueeze(dim=0) 197 | def Left_translation(self, X, P, is_inverse): 198 | """Left translation by P""" 199 | if self.metric == 'AIM': 200 | L = th.linalg.cholesky(P); 201 | if is_inverse: 202 | tmp = th.linalg.solve(L, X) 203 | X_new = th.linalg.solve(L.transpose(-1, -2), tmp, left=False) 204 | else: 205 | X_new = L.matmul(X).matmul(L.transpose(-1, -2)) 206 | 207 | elif self.metric == 'LEM' or self.metric == 'LCM': 208 | X_new = X - P if is_inverse else X + P 209 | 210 | return X_new 211 | def scaling(self, X, var,scale): 212 | """Scaling by variance""" 213 | factor = scale / (var+self.eps).sqrt() 214 | if self.metric == 'AIM': 215 | X_new = functionals.sym_powm.apply(X, factor) 216 | elif self.metric == 'LEM' or self.metric == 'LCM': 217 | X_new = X * factor 218 | 219 | return X_new 220 | def updating_running_means(self,batch_mean): 221 | """updating running means""" 222 | with th.no_grad(): 223 | if self.metric == 'AIM': 224 | self.running_mean.data = functionals.spd_2point_interpolation(self.running_mean, batch_mean,self.eta) 225 | self.running_mean_test.data = functionals.spd_2point_interpolation(self.running_mean_test, batch_mean,self.eta_test) 226 | elif self.metric == 'LEM' or self.metric == 'LCM': 227 | self.running_mean.data = (1-self.eta) * self.running_mean+ batch_mean * self.eta 228 | self.running_mean_test.data = (1 - self.eta_test) * self.running_mean_test + batch_mean * self.eta_test 229 | 230 | def updating_running_var(self, batch_var, batch_var_test): 231 | """updating running vars""" 232 | with th.no_grad(): 233 | if self.dispersion is BatchNormDispersion.SCALAR: 234 | self.running_var = (1. - self.eta) * self.running_var + self.eta * batch_var 235 | self.running_var_test = (1. 
- self.eta_test) * self.running_var_test + self.eta_test * batch_var_test 236 | 237 | def KF_AIM(self,X,karcher_steps=1): 238 | ''' 239 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow 240 | Input x is a batch of SPD matrices (batch_size,1,n,n) to average 241 | Output is (n,n) Riemannian mean 242 | ''' 243 | batch_mean = X.mean(dim=self.batchdim, keepdim=True) 244 | for _ in range(karcher_steps): 245 | bm_sq, bm_invsq = functionals.sym_invsqrtm2.apply(batch_mean) 246 | XT = functionals.sym_logm.apply(bm_invsq @ X @ bm_invsq) 247 | GT = XT.mean(dim=self.batchdim, keepdim=True) 248 | batch_mean = bm_sq @ functionals.sym_expm.apply(GT) @ bm_sq 249 | return batch_mean 250 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__init__.py -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__pycache__/batchnorm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/batchnorm.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__pycache__/functionals.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/functionals.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__pycache__/manifolds.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/manifolds.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/manifolds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Optional, Tuple 3 | from geoopt.manifolds import Manifold 4 | from . 
import functionals 5 | 6 | __all__ = ["SymmetricPositiveDefinite"] 7 | 8 | class SymmetricPositiveDefinite(Manifold): 9 | """ 10 | Subclass of the SymmetricPositiveDefinite manifold using the 11 | affine invariant Riemannian metric (AIRM) as default metric 12 | """ 13 | 14 | __scaling__ = Manifold.__scaling__.copy() 15 | name = "SymmetricPositiveDefinite" 16 | ndim = 2 17 | reversible = False 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def dist(self, x: torch.Tensor, y: torch.Tensor, keepdim) -> torch.Tensor: 23 | """ 24 | Computes the affine invariant Riemannian metric (AIM) 25 | """ 26 | inv_sqrt_x = functionals.sym_invsqrtm.apply(x) 27 | return torch.norm( 28 | functionals.sym_logm.apply(inv_sqrt_x @ y @ inv_sqrt_x), 29 | dim=[-1, -2], 30 | keepdim=keepdim, 31 | ) 32 | 33 | def _check_point_on_manifold( 34 | self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5 35 | ) -> Union[Tuple[bool, Optional[str]], bool]: 36 | ok = torch.allclose(x, x.transpose(-1, -2), atol=atol, rtol=rtol) 37 | if not ok: 38 | return False, "`x != x.transpose` with atol={}, rtol={}".format(atol, rtol) 39 | e = torch.linalg.eigvalsh(x) 40 | ok = (e > -atol).min() 41 | if not ok: 42 | return False, "eigenvalues of x are not all greater than 0." 43 | return True, None 44 | 45 | def _check_vector_on_tangent( 46 | self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5 47 | ) -> Union[Tuple[bool, Optional[str]], bool]: 48 | ok = torch.allclose(u, u.transpose(-1, -2), atol=atol, rtol=rtol) 49 | if not ok: 50 | return False, "`u != u.transpose` with atol={}, rtol={}".format(atol, rtol) 51 | return True, None 52 | 53 | def projx(self, x: torch.Tensor) -> torch.Tensor: 54 | symx = functionals.ensure_sym(x) 55 | return functionals.sym_abseig.apply(symx) 56 | 57 | def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 58 | return functionals.ensure_sym(u) 59 | 60 | def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 61 | return x @ self.proju(x, u) @ x 62 | 63 | def inner(self, x: torch.Tensor, u: torch.Tensor, v: Optional[torch.Tensor], keepdim) -> torch.Tensor: 64 | if v is None: 65 | v = u 66 | inv_x = functionals.sym_invm.apply(x) 67 | ret = torch.diagonal(inv_x @ u @ inv_x @ v, dim1=-2, dim2=-1).sum(-1) 68 | if keepdim: 69 | return torch.unsqueeze(torch.unsqueeze(ret, -1), -1) 70 | return ret 71 | 72 | def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 73 | inv_x = functionals.sym_invm.apply(x) 74 | return functionals.ensure_sym(x + u + 0.5 * u @ inv_x @ u) 75 | # return self.expmap(x, u) 76 | 77 | def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 78 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x) 79 | return sqrt_x @ functionals.sym_expm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x 80 | 81 | def logmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 82 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x) 83 | return sqrt_x @ functionals.sym_logm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x 84 | 85 | def extra_repr(self) -> str: 86 | return "default_metric=AIM" 87 | 88 | def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 89 | 90 | xinvy = torch.linalg.solve(x.double(),y.double()) 91 | s, U = torch.linalg.eig(xinvy.transpose(-2,-1)) 92 | s = s.real 93 | U = U.real 94 | 95 | Ut = U.transpose(-2,-1) 96 | Esqm = torch.linalg.solve(Ut, torch.diag_embed(s.sqrt()) @ Ut).transpose(-2,-1).to(y.dtype) 97 | 98 | return Esqm @ v @ Esqm.transpose(-1,-2) 99 | 100 | def random(self, *size, 
dtype=None, device=None, **kwargs) -> torch.Tensor: 101 | tens = torch.randn(*size, dtype=dtype, device=device, **kwargs) 102 | tens = functionals.ensure_sym(tens) 103 | tens = functionals.sym_expm.apply(tens) 104 | return tens 105 | 106 | def barycenter(self, X : torch.Tensor, steps : int = 1, dim = 0) -> torch.Tensor: 107 | """ 108 | Compute several steps of the Kracher flow algorithm to estimate the 109 | Barycenter on the manifold. 110 | """ 111 | return functionals.spd_mean_kracher_flow(X, None, maxiter=steps, dim=dim, return_dist=False) 112 | 113 | def geodesic(self, A : torch.Tensor, B : torch.Tensor, t : torch.Tensor) -> torch.Tensor: 114 | """ 115 | Compute geodesic between two SPD tensors A and B and return 116 | point on the geodesic at length t \in [0,1] 117 | if t = 0, then A is returned 118 | if t = 1, then B is returned 119 | """ 120 | Asq, Ainvsq = functionals.sym_invsqrtm2.apply(A) 121 | return Asq @ functionals.sym_powm.apply(Ainvsq @ B @ Ainvsq, t) @ Asq 122 | 123 | def transp_via_identity(self, X : torch.Tensor, A : torch.Tensor, B : torch.Tensor) -> torch.Tensor: 124 | """ 125 | Parallel transport of the tensors in X around A to the identity matrix I 126 | Parallel transport from around the identity matrix to the new center (tensor B) 127 | """ 128 | Ainvsq = functionals.sym_invsqrtm.apply(A) 129 | Bsq = functionals.sym_sqrtm.apply(B) 130 | return Bsq @ (Ainvsq @ X @ Ainvsq) @ Bsq 131 | 132 | def transp_identity_rescale_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor) -> torch.Tensor: 133 | """ 134 | Parallel transport of the tensors in X around A to the identity matrix I 135 | Rescales the dispersion by the factor s 136 | Parallel transport from the identity to the new center (tensor B) 137 | """ 138 | Ainvsq = functionals.sym_invsqrtm.apply(A) 139 | Bsq = functionals.sym_sqrtm.apply(B) 140 | return Bsq @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ Bsq 141 | 142 | def transp_identity_rescale_rotate_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor, W : torch.Tensor) -> torch.Tensor: 143 | """ 144 | Parallel transport of the tensors in X around A to the identity matrix I 145 | Rescales the dispersion by the factor s 146 | Parallel transport from the identity to the new center (tensor B) 147 | """ 148 | Ainvsq = functionals.sym_invsqrtm.apply(A) 149 | Bsq = functionals.sym_sqrtm.apply(B) 150 | WBsq = W @ Bsq 151 | return WBsq.transpose(-2,-1) @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ WBsq 152 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseModel, FineTuneableModel, CPUModel, PatternInterpretableModel 2 | from .base import DomainAdaptBaseModel 3 | from .base import DomainAdaptFineTuneableModel, DomainAdaptJointTrainableModel 4 | 5 | from .eegnet import EEGNetv4, DANNEEGNet 6 | from .shconvnet import ShallowConvNet,DANNShallowConvNet,ShConvNetDSBN 7 | from .tsmnet import TSMNet, CNNNet 8 | from .tsmnetMLR import TSMNetMLR -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/__init__.cpython-310.pyc 
-------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/dann.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/dann.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/eegnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/eegnet.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/shconvnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/shconvnet.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/__pycache__/tsmnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/tsmnet.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(nn.Module): 6 | def __init__(self, nclasses=None, nchannels=None, nsamples=None, nbands=None, device=None, input_shape=None): 7 | super().__init__() 8 | self.device_ = device 9 | self.lossfn = torch.nn.CrossEntropyLoss() 10 | self.nclasses_ = nclasses 11 | self.nchannels_ = nchannels 12 | self.nsamples_ = nsamples 13 | self.nbands_ = nbands 14 | self.input_shape_ = input_shape 15 | 16 | # AUXILIARY METHODS 17 | def calculate_classification_accuracy(self, Y, Y_lat): 18 | Y_hat = Y_lat.argmax(1) 19 | acc = Y_hat.eq(Y).float().mean().item() 20 | P_hat = torch.softmax(Y_lat, dim=1) 21 | return acc, P_hat 22 | 23 | def calculate_objective(self, model_pred, y_true, model_inp=None): 24 | # Y_lat, l = self(X.to(self.device_), B) 25 | if isinstance(model_pred, (list, tuple)): 26 | y_class_hat = model_pred[0] 27 | else: 28 | y_class_hat = model_pred 29 | loss = self.lossfn(y_class_hat, y_true.to(y_class_hat.device)) 30 | return loss 31 | 32 | def get_hyperparameters(self): 33 | return dict(nchannels = self.nchannels_, 34 | nclasses=self.nclasses_, 35 | nsamples=self.nsamples_, 36 | nbands=self.nbands_) 37 | 38 | 39 | class CPUModel: 40 | pass 41 | 42 | 43 | class FineTuneableModel: 44 | def finetune(self, x, y, d): 45 | raise NotImplementedError() 46 | 47 | 48 | class DomainAdaptBaseModel(BaseModel): 49 | def __init__(self, domains = [], **kwargs): 50 | 
super().__init__(**kwargs) 51 | self.domains_ = domains 52 | 53 | 54 | class DomainAdaptFineTuneableModel(DomainAdaptBaseModel): 55 | def domainadapt_finetune(self, x, y, d, target_domains): 56 | raise NotImplementedError() 57 | 58 | 59 | class DomainAdaptJointTrainableModel(DomainAdaptBaseModel): 60 | def calculate_objective(self, model_pred, y_true, model_inp=None): 61 | # filter out masked observations 62 | keep = y_true != -1 # special label 63 | 64 | if isinstance(model_pred, (list, tuple)): 65 | y_class_hat = model_pred[0] 66 | else: 67 | y_class_hat = model_pred 68 | 69 | return super().calculate_objective(y_class_hat[keep], y_true[keep], None) 70 | 71 | 72 | class PatternInterpretableModel: 73 | def compute_patterns(self, x, y, d): 74 | raise NotImplementedError() 75 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/dann.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import DomainAdaptJointTrainableModel 3 | import spdnets.modules as modules 4 | 5 | class DANNBase(DomainAdaptJointTrainableModel): 6 | """ 7 | Domain adversarial neural network (DANN) proposed 8 | by Ganin et al. 2016, JMLR 9 | """ 10 | def __init__(self, daloss_scaling = 1., dann_mode = 'ganin2016', **kwargs): 11 | domains = kwargs['domains'] 12 | assert (domains.dtype == torch.long) 13 | kwargs['domains'] = domains.sort()[0] 14 | super().__init__(**kwargs) 15 | self.dann_mode_ = dann_mode 16 | 17 | if self.dann_mode_ == 'ganin2015': 18 | grad_reversal_scaling = daloss_scaling 19 | self.daloss_scaling_ = 1. 20 | elif self.dann_mode_ == 'ganin2016': 21 | grad_reversal_scaling = 1. 22 | self.daloss_scaling_ = daloss_scaling 23 | else: 24 | raise NotImplementedError() 25 | 26 | ndim_latent = self._ndim_latent() 27 | self.adversary_loss = torch.nn.CrossEntropyLoss() 28 | 29 | self.adversary = torch.nn.Sequential( 30 | torch.nn.Flatten(start_dim=1), 31 | modules.ReverseGradient(scaling=grad_reversal_scaling), 32 | torch.nn.Linear(ndim_latent, len(self.domains_)) 33 | ).to(self.device_) 34 | 35 | def _ndim_latent(self): 36 | raise NotImplementedError() 37 | 38 | def forward(self, l, d): 39 | # super().forward() 40 | # h = self.cnn(x[:,None,...]).flatten(start_dim=1) 41 | # y = self.classifier(h) 42 | y_domain = self.adversary(l) 43 | return y_domain 44 | 45 | def domainadapt(self, x, y, d, target_domain): 46 | pass # domain adaptation is done during the training process 47 | 48 | def calculate_objective(self, model_pred, y_true, model_inp): 49 | loss = super().calculate_objective(model_pred, y_true, model_inp) 50 | domain = model_inp['d'] 51 | y_dom_hat = model_pred[1] 52 | # check if all requested domains were declared 53 | assert ((self.domains_[..., None] == domain[None,...]).any(dim=0).all()) 54 | # assign to the class indices (buckets) 55 | y_dom = torch.bucketize(domain, self.domains_).to(y_dom_hat.device) 56 | 57 | adversarial_loss = self.adversary_loss(y_dom_hat, y_dom) 58 | loss = loss + self.daloss_scaling_ * adversarial_loss 59 | 60 | return loss -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/eegnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import BaseModel 3 | from .dann import DANNBase 4 | import spdnets.modules as modules 5 | 6 | 7 | class EEGNetv4(BaseModel): 8 | def __init__(self, is_within = False, srate = 128, f1 = 8, d = 2, **kwargs): 9 |
super().__init__(**kwargs) 10 | self.is_within_ = is_within 11 | self.srate_ = srate 12 | self.f1_ = f1 13 | self.d_ = d 14 | self.f2_ = self.f1_ * self.d_ 15 | momentum = 0.01 16 | 17 | kernel_length = int(self.srate_ // 2) 18 | nlatsamples_time = self.nsamples_ // 32 19 | 20 | temp2_kernel_length = int(self.srate_ // 2 // 4) 21 | 22 | if self.is_within_: 23 | drop_prob = 0.5 24 | else: 25 | drop_prob = 0.25 26 | 27 | bntemp = torch.nn.BatchNorm2d(self.f1_, momentum=momentum, affine=True, eps=1e-3) 28 | bnspat = torch.nn.BatchNorm2d(self.f1_ * self.d_, momentum=momentum, affine=True, eps=1e-3) 29 | 30 | self.cnn = torch.nn.Sequential( 31 | torch.nn.Conv2d(1,self.f1_,(1, kernel_length), bias=False, padding='same'), 32 | bntemp, 33 | modules.Conv2dWithNormConstraint(self.f1_, self.f1_ * self.d_, (self.nchannels_, 1), max_norm=1, 34 | stride=1, bias=False, groups=self.f1_, padding=(0, 0)), 35 | bnspat, 36 | torch.nn.ELU(), 37 | torch.nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)), 38 | torch.nn.Dropout(p=drop_prob), 39 | torch.nn.Conv2d(self.f1_ * self.d_, self.f1_ * self.d_, (1, temp2_kernel_length), 40 | stride=1, bias=False, groups=self.f1_ * self.d_, padding='same'), 41 | torch.nn.Conv2d(self.f1_ * self.d_, self.f2_, (1, 1), 42 | stride=1, bias=False, padding=(0, 0)), 43 | torch.nn.BatchNorm2d(self.f2_, momentum=momentum, affine=True, eps=1e-3), 44 | torch.nn.ELU(), 45 | torch.nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)), 46 | torch.nn.Dropout(p=drop_prob), 47 | ).to(self.device_) 48 | 49 | self.classifier = torch.nn.Sequential( 50 | torch.nn.Flatten(start_dim=1), 51 | modules.LinearWithNormConstraint(self.f2_ * nlatsamples_time, self.nclasses_, max_norm=0.25) 52 | ).to(self.device_) 53 | 54 | def get_hyperparameters(self): 55 | kwargs = super().get_hyperparameters() 56 | kwargs['nsamples'] = self.nsamples_ 57 | kwargs['is_within_subject'] = self.is_within_ 58 | kwargs['srate'] = self.srate_ 59 | kwargs['f1'] = self.f1_ 60 | kwargs['d'] = self.d_ 61 | return kwargs 62 | 63 | def forward(self, x, d): 64 | l = self.cnn(x[:,None,...]) 65 | y = self.classifier(l) 66 | return y, l 67 | 68 | 69 | class DANNEEGNet(DANNBase, EEGNetv4): 70 | """ 71 | Domain adversarial neural network (DANN) proposed for EEG MI classification 72 | by Ozdenizci et al.
2020, IEEE Access 73 | """ 74 | def __init__(self, daloss_scaling = 0.03, dann_mode = 'ganin2016', **kwargs): 75 | kwargs['daloss_scaling'] = daloss_scaling 76 | kwargs['dann_mode'] = dann_mode 77 | super().__init__(**kwargs) 78 | 79 | def _ndim_latent(self): 80 | return self.classifier[-1].weight.shape[-1] 81 | 82 | def forward(self, x, d): 83 | y, l = EEGNetv4.forward(self, x, d) 84 | y_domain = DANNBase.forward(self, l, d) 85 | return y, y_domain -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/shconvnet.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | from .base import BaseModel, DomainAdaptFineTuneableModel 4 | from .dann import DANNBase 5 | import spdnets.batchnorm as bn 6 | import spdnets.modules as modules 7 | 8 | 9 | class ShallowConvNet(BaseModel): 10 | def __init__(self, spatial_filters = 40, temporal_filters = 40, pdrop = 0.5, **kwargs): 11 | super().__init__(**kwargs) 12 | self.spatial_filters_ = spatial_filters 13 | self.temporal_filters_ = temporal_filters 14 | 15 | temp_cnn_kernel = 25 16 | temp_pool_kernel = 75 17 | temp_pool_stride = 15 18 | ntempconvout = int((self.nsamples_ - 1*(temp_cnn_kernel-1) - 1)/1 + 1) 19 | navgpoolout = int((ntempconvout - temp_pool_kernel)/temp_pool_stride + 1) 20 | 21 | self.bn = torch.nn.BatchNorm2d(self.spatial_filters_) 22 | drop = torch.nn.Dropout(p=pdrop) 23 | 24 | self.cnn = torch.nn.Sequential( 25 | torch.nn.Conv2d(1,self.temporal_filters_,(1, temp_cnn_kernel)), 26 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 27 | ).to(self.device_) 28 | self.pool = torch.nn.Sequential( 29 | modules.MySquare(), 30 | torch.nn.AvgPool2d(kernel_size=(1, temp_pool_kernel), stride=(1, temp_pool_stride)), 31 | modules.MyLog(), 32 | drop, 33 | torch.nn.Flatten(start_dim=1), 34 | ).to(self.device_) 35 | self.classifier = torch.nn.Sequential( 36 | torch.nn.Linear(self.spatial_filters_ * navgpoolout, self.nclasses_), 37 | ).to(self.device_) 38 | 39 | def forward(self,x, d): 40 | l = self.cnn(x.to(self.device_)[:,None,...]) 41 | l = self.bn(l) 42 | l = self.pool(l) 43 | y = self.classifier(l) 44 | return y, l 45 | 46 | 47 | class DANNShallowConvNet(DANNBase, ShallowConvNet): 48 | """ 49 | Domain adversarial neural network (DANN) proposed for EEG MI classification 50 | by Ozdenizci et al.
2020, IEEE Access 51 | """ 52 | def __init__(self, daloss_scaling = 0.05, dann_mode = 'ganin2016', **kwargs): 53 | kwargs['daloss_scaling'] = daloss_scaling 54 | kwargs['dann_mode'] = dann_mode 55 | super().__init__(**kwargs) 56 | 57 | def _ndim_latent(self): 58 | return self.classifier[-1].weight.shape[-1] 59 | 60 | def forward(self, x, d): 61 | y, l = ShallowConvNet.forward(self, x, d) 62 | y_domain = DANNBase.forward(self, l, d) 63 | return y, y_domain 64 | 65 | 66 | class ShConvNetDSBN(ShallowConvNet, DomainAdaptFineTuneableModel): 67 | def __init__(self, 68 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.VECTOR, 69 | **kwargs): 70 | super().__init__(**kwargs) 71 | 72 | if isinstance(bnorm_dispersion, str): 73 | bnorm_dispersion = bn.BatchNormDispersion[bnorm_dispersion] 74 | 75 | self.bn = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_, 1, 1), 76 | batchdim=[0,2,3], # same as batch norm 2D 77 | domains=self.domains_, 78 | dispersion=bnorm_dispersion, 79 | eta=1., eta_test=.1).to(self.device_) 80 | 81 | def forward(self,x, d): 82 | l = self.cnn(x.to(self.device_)[:,None,...]) 83 | l = self.bn(l,d.to(device=self.device_)) 84 | l = self.pool(l) 85 | y = self.classifier(l) 86 | return y, l 87 | 88 | def domainadapt_finetune(self, x, y, d, target_domains): 89 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 90 | for du in d.unique(): 91 | self.forward(x[d==du], d[d==du]) 92 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 93 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/models/tsmnet.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | import torch 3 | 4 | import spdnets.modules as modules 5 | import spdnets.batchnorm as bn 6 | import spdnets.Liebatchnorm as LieBN 7 | from spdnets.BaseBatchNorm import BatchNormDispersion 8 | from .base import DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel 9 | 10 | 11 | class TSMNet(DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel): 12 | def __init__(self, temporal_filters, spatial_filters = 40, 13 | subspacedims = 20, 14 | temp_cnn_kernel = 25, 15 | bnorm : Optional[str] = 'spdbn', 16 | bnorm_dispersion : Union[str, BatchNormDispersion] = BatchNormDispersion.SCALAR, 17 | metric: str ='AIM', theta: float =1., alpha: float=1.0, beta: float=0., learn_mean=False, 18 | **kwargs): 19 | super().__init__(**kwargs) 20 | self.learn_mean = learn_mean; self.metric = metric; self.theta = torch.tensor(theta); self.alpha = torch.tensor(alpha); self.beta = torch.tensor(beta) 21 | 22 | self.temporal_filters_ = temporal_filters 23 | self.spatial_filters_ = spatial_filters 24 | self.subspacedims = subspacedims 25 | self.bnorm_ = bnorm 26 | self.spd_device_ = torch.device('cpu') 27 | if isinstance(bnorm_dispersion, str): 28 | self.bnorm_dispersion_ = BatchNormDispersion[bnorm_dispersion] 29 | else: 30 | self.bnorm_dispersion_ = bnorm_dispersion 31 | 32 | tsdim = int(subspacedims*(subspacedims+1)/2) 33 | 34 | self.cnn = torch.nn.Sequential( 35 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 36 | padding='same', padding_mode='reflect'), 37 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 38 | torch.nn.Flatten(start_dim=2), 39 | ).to(self.device_) 40 | 41 | self.cov_pooling = torch.nn.Sequential( 42 | modules.CovariancePool(), 43 | ) 44 | 45 | if self.bnorm_ ==
'spddsbn': 46 | self.spddsbnorm = bn.AdaMomDomainSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0, 47 | domains=self.domains_, 48 | learn_mean=self.learn_mean,learn_std=True, 49 | dispersion=self.bnorm_dispersion_, 50 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_) 51 | elif self.bnorm_ == 'LieBN': 52 | self.spddsbnorm = LieBN.AdaMomDomainSPDLieBatchNorm((1,subspacedims,subspacedims), batchdim=0, 53 | domains=self.domains_, 54 | learn_mean=self.learn_mean,learn_std=True, 55 | dispersion=self.bnorm_dispersion_, 56 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_, 57 | metric=self.metric,theta=self.theta,beta=self.beta,alpha=self.alpha) 58 | 59 | elif self.bnorm_ is not None: 60 | raise NotImplementedError('requested undefined batch normalization method.') 61 | 62 | self.spdnet = torch.nn.Sequential( 63 | modules.BiMap((1,self.spatial_filters_,subspacedims), dtype=torch.double, device=self.spd_device_), 64 | modules.ReEig(threshold=1e-4), 65 | ) 66 | self.logeig = torch.nn.Sequential( 67 | modules.LogEig(subspacedims), 68 | torch.nn.Flatten(start_dim=1), 69 | ) 70 | self.classifier = torch.nn.Sequential( 71 | torch.nn.Linear(tsdim,self.nclasses_).double(), 72 | ).to(self.spd_device_) 73 | 74 | def to(self, device: Optional[Union[int, torch.device]] = None, dtype: Optional[Union[int, torch.dtype]] = None, non_blocking: bool = False): 75 | if device is not None: 76 | self.device_ = device 77 | self.cnn.to(self.device_) 78 | return super().to(device=None, dtype=dtype, non_blocking=non_blocking) 79 | 80 | def forward(self, x, d, return_latent=True, return_prebn=False, return_postbn=False): 81 | out = () 82 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 83 | C = self.cov_pooling(h).to(device=self.spd_device_, dtype=torch.double) 84 | l = self.spdnet(C) 85 | out += (l,) if return_prebn else () 86 | l = self.spdbnorm(l) if hasattr(self, 'spdbnorm') else l 87 | l = self.spddsbnorm(l,d.to(device=self.spd_device_)) if hasattr(self, 'spddsbnorm') else l 88 | out += (l,) if return_postbn else () 89 | l = self.logeig(l) 90 | l = self.tsbnorm(l) if hasattr(self, 'tsbnorm') else l 91 | l = self.tsdsbnorm(l,d) if hasattr(self, 'tsdsbnorm') else l 92 | out += (l,) if return_latent else () 93 | y = self.classifier(l) 94 | out = y if len(out) == 0 else (y, *out[::-1]) 95 | return out 96 | 97 | def domainadapt_finetune(self, x, y, d, target_domains): 98 | if hasattr(self, 'spddsbnorm'): 99 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 100 | if hasattr(self, 'tsdsbnorm'): 101 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 102 | 103 | with torch.no_grad(): 104 | for du in d.unique(): 105 | self.forward(x[d==du], d[d==du]) 106 | 107 | if hasattr(self, 'spddsbnorm'): 108 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 109 | if hasattr(self, 'tsdsbnorm'): 110 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 111 | 112 | def finetune(self, x, y, d): 113 | if hasattr(self, 'spdbnorm'): 114 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 115 | if hasattr(self, 'tsbnorm'): 116 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 117 | 118 | with torch.no_grad(): 119 | self.forward(x, d) 120 | 121 | if hasattr(self, 'spdbnorm'): 122 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 123 | if hasattr(self, 'tsbnorm'): 124 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 125 | 126 | def 
compute_patterns(self, x, y, d): 127 | pass 128 | 129 | 130 | 131 | class CNNNet(DomainAdaptFineTuneableModel, FineTuneableModel): 132 | def __init__(self, temporal_filters, spatial_filters = 40, 133 | temp_cnn_kernel = 25, 134 | bnorm : Optional[str] = 'bn', 135 | bnorm_dispersion : Union[str, BatchNormDispersion] = BatchNormDispersion.SCALAR, 136 | **kwargs): 137 | super().__init__(**kwargs) 138 | 139 | self.temporal_filters_ = temporal_filters 140 | self.spatial_filters_ = spatial_filters 141 | self.bnorm_ = bnorm 142 | 143 | if isinstance(bnorm_dispersion, str): 144 | self.bnorm_dispersion_ = BatchNormDispersion[bnorm_dispersion] 145 | else: 146 | self.bnorm_dispersion_ = bnorm_dispersion 147 | 148 | self.cnn = torch.nn.Sequential( 149 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 150 | padding='same', padding_mode='reflect'), 151 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 152 | torch.nn.Flatten(start_dim=2), 153 | ).to(self.device_) 154 | 155 | self.cov_pooling = torch.nn.Sequential( 156 | modules.CovariancePool(), 157 | ) 158 | 159 | if self.bnorm_ == 'bn': 160 | self.bnorm = bn.AdaMomBatchNorm((1, self.spatial_filters_), batchdim=0, dispersion=self.bnorm_dispersion_, 161 | eta=1., eta_test=.1).to(self.device_) 162 | elif self.bnorm_ == 'dsbn': 163 | self.dsbnorm = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_), batchdim=0, 164 | domains=self.domains_, 165 | dispersion=self.bnorm_dispersion_, 166 | eta=1., eta_test=.1).to(self.device_) 167 | elif self.bnorm_ is not None: 168 | raise NotImplementedError('requested undefined batch normalization method.') 169 | 170 | self.logarithm = torch.nn.Sequential( 171 | modules.MyLog(), 172 | torch.nn.Flatten(start_dim=1), 173 | ) 174 | self.classifier = torch.nn.Sequential( 175 | torch.nn.Linear(self.spatial_filters_,self.nclasses_), 176 | ).to(self.device_) 177 | 178 | def forward(self, x, d, return_latent=True): 179 | out = () 180 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 181 | C = self.cov_pooling(h) 182 | l = torch.diagonal(C, dim1=-2, dim2=-1) 183 | l = self.logarithm(l) 184 | l = self.bnorm(l) if hasattr(self, 'bnorm') else l 185 | l = self.dsbnorm(l,d) if hasattr(self, 'dsbnorm') else l 186 | out += (l,) if return_latent else () 187 | y = self.classifier(l) 188 | out = y if len(out) == 0 else (y, *out[::-1]) 189 | return out 190 | 191 | def domainadapt_finetune(self, x, y, d, target_domains): 192 | if hasattr(self, 'dsbnorm'): 193 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 194 | 195 | with torch.no_grad(): 196 | for du in d.unique(): 197 | self.forward(x[d==du], d[d==du]) 198 | 199 | if hasattr(self, 'dsbnorm'): 200 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 201 | 202 | def finetune(self, x, y, d): 203 | if hasattr(self, 'bnorm'): 204 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 205 | 206 | with torch.no_grad(): 207 | self.forward(x, d) 208 | 209 | if hasattr(self, 'bnorm'): 210 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 211 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.parameter import Parameter 6 | from torch.types import Number 7 | import torch.nn as nn 8 | from geoopt.tensor import 
ManifoldParameter 9 | from geoopt.manifolds import Stiefel, Sphere 10 | from . import functionals 11 | 12 | class Conv2dWithNormConstraint(torch.nn.Conv2d): 13 | def __init__(self, *args, max_norm=1, **kwargs): 14 | self.max_norm = max_norm 15 | super(Conv2dWithNormConstraint, self).__init__(*args, **kwargs) 16 | 17 | def forward(self, x): 18 | self.weight.data = torch.renorm( 19 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm 20 | ) 21 | return super(Conv2dWithNormConstraint, self).forward(x) 22 | 23 | 24 | class LinearWithNormConstraint(torch.nn.Linear): 25 | def __init__(self, *args, max_norm=1, **kwargs): 26 | self.max_norm = max_norm 27 | super(LinearWithNormConstraint, self).__init__(*args, **kwargs) 28 | 29 | def forward(self, x): 30 | self.weight.data = torch.renorm( 31 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm 32 | ) 33 | return super(LinearWithNormConstraint, self).forward(x) 34 | 35 | 36 | class MySquare(torch.nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | def forward(self, x): 40 | return x.square() 41 | 42 | class MyLog(torch.nn.Module): 43 | def __init__(self, eps = 1e-4): 44 | super().__init__() 45 | self.eps = eps 46 | def forward(self, x): 47 | return torch.log(x + self.eps) 48 | 49 | 50 | class MyConv2d(torch.nn.Conv2d): 51 | def __init__(self, *args, **kwargs): 52 | super(MyConv2d, self).__init__(*args, **kwargs) 53 | 54 | self.convshape = self.weight.shape 55 | w0 = self.weight.data.flatten(start_dim=1) 56 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere()) 57 | 58 | def forward(self, x): 59 | return self._conv_forward(x, self.weight.view(self.convshape), self.bias) 60 | 61 | 62 | class UnitNormLinear(torch.nn.Linear): 63 | def __init__(self, *args, **kwargs): 64 | super(UnitNormLinear, self).__init__(*args, **kwargs) 65 | 66 | w0 = self.weight.data.flatten(start_dim=1) 67 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere()) 68 | 69 | def forward(self, x): 70 | return super().forward(x) 71 | 72 | 73 | class MyLinear(nn.Module): 74 | def __init__(self, shape : Tuple[int, ...] or torch.Size, bias: bool = True, **kwargs): 75 | super().__init__() 76 | 77 | self.W = Parameter(torch.empty(shape, **kwargs)) 78 | 79 | if bias: 80 | self.bias = Parameter(torch.empty((*shape[:-2], shape[-1]), **kwargs)) 81 | else: 82 | self.register_parameter('bias', None) 83 | 84 | self.reset_parameters() 85 | 86 | def forward(self, X : Tensor) -> Tensor: 87 | A = (X.unsqueeze(-2) @ self.W).squeeze(-2) 88 | if self.bias is not None: 89 | A += self.bias 90 | return A 91 | 92 | @torch.no_grad() 93 | def reset_parameters(self): 94 | # kaiming initialization std2uniformbound * gain * fan_in 95 | bound = math.sqrt(3) * 1. / math.sqrt(self.W.shape[-2]) 96 | self.W.data.uniform_(-bound,bound) 97 | if self.bias is not None: 98 | bound = 1 / math.sqrt(self.W.shape[-2]) 99 | self.bias.data.uniform_(-bound, bound) 100 | 101 | 102 | class Encode2DPosition(nn.Module): 103 | """ 104 | Encodes the 2D position of a 2D CNN or 2D image 105 | as additional channels. 
106 | Input: (batch, chans, height, width) 107 | Output: (batch, chans+2, height, width) 108 | """ 109 | def __init__(self, flatten = True): 110 | super().__init__() 111 | self.flatten = flatten 112 | 113 | def forward(self, X : Tensor) -> Tensor: 114 | pos1 = torch.arange(X.shape[-2], device=X.device, dtype=X.dtype)[None,None,:,None].tile((X.shape[0],1, 1, X.shape[-1])) / X.shape[-2] 115 | pos2 = torch.arange(X.shape[-1], device=X.device, dtype=X.dtype)[None,None,None,:].tile((X.shape[0],1, X.shape[-2], 1)) / X.shape[-1] 116 | 117 | Z = torch.cat((X, pos1, pos2),dim=1) 118 | if self.flatten: 119 | Z = Z.flatten(start_dim=-2) 120 | 121 | return Z 122 | 123 | 124 | class CovariancePool(nn.Module): 125 | def __init__(self, alpha = None, unitvar = False): 126 | super().__init__() 127 | self.pooldim = -1 128 | self.chandim = -2 129 | self.alpha = alpha 130 | self.unitvar = unitvar 131 | 132 | def forward(self, X : Tensor) -> Tensor: 133 | X0 = X - X.mean(dim=self.pooldim, keepdim=True) 134 | if self.unitvar: 135 | X0 = X0 / X0.std(dim=self.pooldim, keepdim=True) 136 | X0.nan_to_num_(0) 137 | 138 | C = (X0 @ X0.transpose(-2, -1)) / X0.shape[self.pooldim] 139 | if self.alpha is not None: 140 | Cd = C.diagonal(dim1=self.pooldim, dim2=self.pooldim-1) 141 | Cd += self.alpha 142 | return C 143 | 144 | 145 | class ReverseGradient(nn.Module): 146 | def __init__(self, scaling = 1.): 147 | super().__init__() 148 | self.scaling_ = scaling 149 | 150 | def forward(self, X : Tensor) -> Tensor: 151 | return functionals.reverse_gradient.apply(X, self.scaling_) 152 | 153 | 154 | class BiMap(nn.Module): 155 | def __init__(self, shape : Tuple[int, ...] or torch.Size, W0 : Tensor = None, manifold='stiefel', **kwargs): 156 | super().__init__() 157 | 158 | if manifold == 'stiefel': 159 | assert(shape[-2] >= shape[-1]) 160 | mf = Stiefel() 161 | elif manifold == 'sphere': 162 | mf = Sphere() 163 | shape = list(shape) 164 | shape[-1], shape[-2] = shape[-2], shape[-1] 165 | else: 166 | raise NotImplementedError() 167 | 168 | # add constraint (also initializes the parameter to fulfill the constraint) 169 | self.W = ManifoldParameter(torch.empty(shape, **kwargs), manifold=mf) 170 | 171 | # optionally initialize the weights (initialization has to fulfill the constraint!) 172 | if W0 is not None: 173 | self.W.data = W0 # e.g., self.W = torch.nn.init.orthogonal_(self.W) 174 | else: 175 | self.reset_parameters() 176 | 177 | def forward(self, X : Tensor) -> Tensor: 178 | if isinstance(self.W.manifold, Sphere): 179 | return self.W @ X @ self.W.transpose(-2,-1) 180 | else: 181 | return self.W.transpose(-2,-1) @ X @ self.W 182 | 183 | @torch.no_grad() 184 | def reset_parameters(self): 185 | if isinstance(self.W.manifold, Stiefel): 186 | # uniform initialization on stiefel manifold after theorem 2.2.1 in Chikuse (2003): statistics on special manifolds 187 | W = torch.rand(self.W.shape, dtype=self.W.dtype, device=self.W.device) 188 | self.W.data = W @ functionals.sym_invsqrtm.apply(W.transpose(-1,-2) @ W) 189 | elif isinstance(self.W.manifold, Sphere): 190 | W = torch.empty(self.W.shape, dtype=self.W.dtype, device=self.W.device) 191 | # kaiming initialization std2uniformbound * gain * fan_in 192 | bound = math.sqrt(3) * 1.
/ W.shape[-1] 193 | W.uniform_(-bound, bound) 194 | # constraint has to be satisfied 195 | self.W.data = W / W.norm(dim=-1, keepdim=True) 196 | else: 197 | raise NotImplementedError() 198 | 199 | 200 | class ReEig(nn.Module): 201 | def __init__(self, threshold : Number = 1e-4): 202 | super().__init__() 203 | self.threshold = Tensor([threshold]) 204 | 205 | def forward(self, X : Tensor) -> Tensor: 206 | return functionals.sym_reeig.apply(X, self.threshold) 207 | 208 | 209 | class LogEig(nn.Module): 210 | def __init__(self, ndim, tril=True): 211 | super().__init__() 212 | 213 | self.tril = tril 214 | if self.tril: 215 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1) 216 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long) 217 | self.ixs = torch.cat((ixs_diag[None,:].tile((2,1)), ixs_lower), dim=1) 218 | self.ndim = ndim 219 | 220 | def forward(self, X : Tensor) -> Tensor: 221 | return self.embed(functionals.sym_logm.apply(X)) 222 | 223 | def embed(self, X : Tensor) -> Tensor: 224 | if self.tril: 225 | x_vec = X[...,self.ixs[0],self.ixs[1]] 226 | x_vec[...,self.ndim:] *= math.sqrt(2) 227 | else: 228 | x_vec = X.flatten(start_dim=-2) 229 | return x_vec 230 | 231 | 232 | class TrilEmbedder(nn.Module): 233 | 234 | def forward(self, X : Tensor) -> Tensor: 235 | 236 | ndim = X.shape[-1] 237 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1) 238 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long) 239 | ixs = torch.cat((ixs_diag[None,:].tile((2,1)), ixs_lower), dim=1) 240 | 241 | x_vec = X[...,ixs[0],ixs[1]] 242 | x_vec[...,ndim:] *= math.sqrt(2) 243 | return x_vec 244 | 245 | def inverse_transform(self, x_vec: Tensor) -> Tensor: 246 | 247 | ndim = int(-.5 + math.sqrt(.25 + 2*x_vec.shape[-1])) # c*(c+1)/2 = nts 248 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1) 249 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long) 250 | 251 | X = torch.zeros(x_vec.shape[:-1] + (ndim, ndim), device=x_vec.device, dtype=x_vec.dtype) 252 | 253 | # off diagonal elements 254 | X[...,ixs_lower[0],ixs_lower[1]] = x_vec[...,ndim:] / math.sqrt(2) 255 | X[...,ixs_lower[1],ixs_lower[0]] = x_vec[...,ndim:] / math.sqrt(2) 256 | X[...,ixs_diag,ixs_diag] = x_vec[...,:ndim] 257 | 258 | return X -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .logging import TrainLog 3 | from .network import DomainAdaptNeuralNetClassifier -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/__pycache__/network.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/network.cpython-310.pyc -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from skorch.callbacks.logging import EpochTimer, PrintLog 3 | 4 | log = logging.getLogger(__name__) 5 | 6 | class TrainLog(PrintLog): 7 | 8 | def __init__(self, prefix='') -> None: 9 | super().__init__() 10 | self.prefix = prefix 11 | 12 | def initialize(self): 13 | return self 14 | 15 | def on_epoch_end(self, net, **kwargs): 16 | r = net.history[-1] 17 | 18 | if r['epoch'] == 1 or r['epoch'] % 10 == 0: 19 | log.info(f"{self.prefix} {r['epoch']:3d} : trn={r['train_loss']:.3f}/{r['score_trn']:.2f} val={r['valid_loss']:.3f}/{r['score_val']:.2f} time: {r['dur']:.2f}") 20 | 21 | 22 | -------------------------------------------------------------------------------- /LieBN_TSMNet/spdnets/utils/skorch/network.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from skorch.classifier import NeuralNetClassifier 4 | from skorch.callbacks.logging import EpochTimer, PrintLog 5 | from skorch.callbacks.scoring import EpochScoring, PassthroughScoring 6 | 7 | from spdnets.models import BaseModel, DomainAdaptFineTuneableModel, FineTuneableModel 8 | 9 | from .logging import TrainLog 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | class DomainAdaptNeuralNetClassifier(NeuralNetClassifier): 14 | def __init__(self, module, *args, criterion=torch.nn.CrossEntropyLoss, **kwargs): 15 | super().__init__(module, *args, criterion=criterion, **kwargs) 16 | 17 | @property 18 | def _default_callbacks(self): 19 | return [ 20 | ('epoch_timer', EpochTimer()), 21 | ('train_loss', PassthroughScoring( 22 | name='train_loss', 23 | on_train=True, 24 | )), 25 | ('valid_loss', PassthroughScoring( 26 | name='valid_loss', 27 | )), 28 | ('print_log', TrainLog()), 29 | ] 30 | 31 | def get_loss(self, mdl_pred, y_true, X=None, **kwargs): 32 | if isinstance(self.module_, BaseModel): 33 | return self.module_.calculate_objective(mdl_pred, y_true, X) 34 | elif isinstance(mdl_pred, (list, tuple)): 35 | y_hat = mdl_pred[0] 36 | else: 37 | y_hat = mdl_pred 38 | return self.criterion_(y_hat, y_true.to(y_hat.device)) 39 | 40 | def domainadapt_finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor, target_domains=None): 41 | if isinstance(self.module_, DomainAdaptFineTuneableModel): 42 | self.module_.domainadapt_finetune(x=x.to(self.device), y=y, d=d, target_domains=target_domains) 43 | else: 44 | log.info("Model does not support domain adapt fine tuning.") 45 | 46 | def finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor): 47 | if isinstance(self.module_, FineTuneableModel): 48 | self.module_.finetune(x=x.to(self.device), y=y, d=d) 49 | else: 50 | log.info("Model does not support fine-tuning.") 51 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | [arXiv](https://arxiv.org/abs/2403.11261) 2 | [OpenReview](https://openreview.net/forum?id=okYdj8Ysru) 3 | [PDF](https://openreview.net/pdf?id=okYdj8Ysru) 4 | 5 | 6 | # A Lie Group Approach to Riemannian Batch Normalization 7
| 8 | **Updates (02/2025):** We have integrated the LieBN implementations into a toolbox, now supporting nine invariant metrics across different matrix manifolds: 9 | 10 | - Symmetric Positive Definite (SPD) Manifold: Four distinct metrics, including a newly introduced right-invariant metric. 11 | - Rotation Group: One bi-invariant metric. 12 | - Full-Rank Correlation Manifold: Four recently developed correlation geometries. 13 | 14 | The complete implementations can be found in the `LieBN` folder. This toolbox is designed to be plug-and-play, making it easy to apply LieBN as a drop-in normalization module across different neural architectures (see the usage sketch at the end of this document). 15 | 16 | ## Introduction 17 | 18 | This is the official code for our ICLR 2024 publication: *A Lie Group Approach to Riemannian Batch Normalization*. [[OpenReview](https://openreview.net/forum?id=okYdj8Ysru)]. 19 | 20 | If you find this project helpful, please consider citing us as follows: 21 | 22 | ```bib 23 | @inproceedings{chen2024liebn, 24 | title={A Lie Group Approach to Riemannian Batch Normalization}, 25 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe}, 26 | booktitle={The Twelfth International Conference on Learning Representations}, 27 | year={2024}, 28 | url={https://openreview.net/forum?id=okYdj8Ysru} 29 | } 30 | ``` 31 | 32 | If you have any problems, do not hesitate to contact me at ziheng_ch@163.com. 33 | 34 | ## Requirements 35 | 36 | Install the necessary dependencies with `conda`: 37 | 38 | ```setup 39 | conda env create --file environment.yaml 40 | ``` 41 | **Note** that the [hydra](https://hydra.cc/) package is used to manage configuration files. 42 | 43 | ## Experiments on the SPDNet 44 | 45 | The code for the experiments on SPDNet, SPDNetBN, and SPDNetLieBN is located in the folder `./LieBN_SPDNet`. 46 | 47 | The implementation is based on the official code of *Riemannian batch normalization for SPD neural networks* [[NeurIPS 2019](https://papers.nips.cc/paper_files/paper/2019/hash/6e69ebbfad976d4637bb4b39de261bf7-Abstract.html)] [[code](https://papers.nips.cc/paper_files/paper/2019/file/6e69ebbfad976d4637bb4b39de261bf7-Supplemental.zip)]. 48 | 49 | ### Dataset 50 | 51 | The synthetic [Radar](https://www.dropbox.com/s/dfnlx2bnyh3kjwy/data.zip?e=1&dl=0) dataset was released by SPDNetBN. We further release our preprocessed [HDM05](https://www.dropbox.com/scl/fi/x2ouxjwqj3zrb1idgkg2g/HDM05.zip?rlkey=4f90ktgzfz28x3i2i4ylu6dvu&dl=0) dataset. 52 | 53 | Please download the datasets, put them in your personal folder, and change the `path` accordingly in `./LieBN_SPDNet/conf/dataset/RADAR.yaml` and `./LieBN_SPDNet/conf/dataset/HDM05.yaml`. 54 | 55 | ### Running experiments 56 | 57 | To run all the experiments on the Radar and HDM05 datasets, go to the folder `LieBN_SPDNet` and run this command: 58 | 59 | ```train 60 | bash experiments.sh 61 | ``` 62 | This script runs the experiments on the Radar and HDM05 datasets shown in Tab. 4. 63 | 64 | ## Experiments on the TSMNet 65 | 66 | The code for the experiments on TSMNet, TSMNet + SPDDSMBN, and TSMNet + DSMLieBN is located in the folder `./LieBN_TSMNet`. 67 | 68 | The implementation is based on the official code of *SPD domain-specific batch normalization to crack interpretable unsupervised domain adaptation in EEG* [[NeurIPS 2022](https://openreview.net/forum?id=pp7onaiM4VB)] [[code](https://github.com/rkobler/TSMNet.git)]. 69 | 70 | ### Dataset 71 | 72 | The [Hinss2021](https://doi.org/10.5281/zenodo.5055046) dataset is publicly available.
The [moabb](https://neurotechx.github.io/moabb/) and [mne](https://mne.tools) packages are used to download and preprocess these datasets; there is no need to download or preprocess anything manually, as this is done automatically. If necessary, change the `data_dir` in `./LieBN_TSMNet/conf/LieBN.yaml` to your personal folder. 73 | 74 | ### Running experiments 75 | 76 | To run all the experiments on the Hinss2021 dataset, go to the folder `LieBN_TSMNet` and run this command: 77 | 78 | ```train 79 | bash experiments_Hinss21.sh 80 | ``` 81 | This script runs the experiments on the Hinss2021 dataset shown in Tab. 5. 82 | 83 | **Note:** You can also change the `data_dir` in `experiments_Hinss21.sh`, which will override the hydra config. 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: LieBN 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | # basic programs 7 | - python==3.8.* 8 | - pip 9 | # scientific python base packages 10 | - numpy==1.20.* 11 | - pandas==1.2.* 12 | - scipy==1.6.* 13 | # jupyter notebooks 14 | - ipykernel 15 | - notebook 16 | - jupyterlab 17 | - nb_conda_kernels 18 | # python visualization 19 | - matplotlib==3.4.* 20 | - seaborn==0.11.* 21 | # machine learning 22 | - scikit-learn==1.0.* 23 | - pytorch==2.2.* 24 | - torchvision==0.17.* 25 | - skorch==0.11.* 26 | - pip: 27 | # tensorboard 28 | - tensorboard==2.14.* 29 | # m/eeg analysis 30 | - mne==0.22.* 31 | - moabb==0.4.* 32 | # command line interfacing 33 | - hydra-core==1.3.* 34 | - hydra-joblib-launcher==1.2.* 35 | # machine learning 36 | - pyriemann==0.2.* 37 | - git+https://github.com/geoopt/geoopt.git@524330b11c0f9f6046bda59fe334803b4b74e13e 38 | # this package 39 | - -e . 40 | --------------------------------------------------------------------------------
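To make the plug-and-play claim in `Readme.md` concrete, the following is a minimal, hypothetical usage sketch of the LieBN toolbox. The class name `LieBNSPD` and its constructor arguments are assumptions inferred from the file `LieBN/LieBNSPD.py` and from the `metric`/`theta`/`alpha`/`beta` options appearing in the TSMNet integration; consult `LieBN/demos/liebnspd.py` for the actual interface.

```python
import torch
# Hypothetical import and constructor: the exported name and its arguments are
# assumptions, not the confirmed API -- see LieBN/demos/liebnspd.py.
from LieBN import LieBNSPD

# Build a batch of 32 random 8x8 SPD matrices.
A = torch.randn(32, 8, 8, dtype=torch.double)
X = A @ A.transpose(-2, -1) + 1e-3 * torch.eye(8, dtype=torch.double)

# Assumed: a LieBN layer over SPD matrices, parameterized by an invariant metric
# and its deformation parameters (theta, alpha, beta), as in tsmnet.py above.
bn = LieBNSPD(n=8, metric='AIM', theta=1.0, alpha=1.0, beta=0.0)
X_normalized = bn(X)  # same shape as X, recentered and rescaled on the manifold
```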