├── .DS_Store
├── ICLR24_LieBN_PPT.pdf
├── ICLR24_LieBN_Poster.pdf
├── LieBN
├── Geometry
│ ├── Base
│ │ ├── LieGroups.py
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ │ ├── LieGroups.cpython-38.pyc
│ │ │ └── __init__.cpython-38.pyc
│ ├── Correlation
│ │ ├── CorMatrices.py
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── CorMatrices.cpython-38.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── cor_functions.cpython-38.pyc
│ │ │ └── ppbcm_functionals.cpython-38.pyc
│ │ ├── cor_functions.py
│ │ └── ppbcm_functionals.py
│ ├── Rotations
│ │ ├── RotMatrices.py
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ │ ├── RotMatrices.cpython-38.pyc
│ │ │ └── __init__.cpython-38.pyc
│ ├── SPD
│ │ ├── SPDMatrices.py
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── SPDMatrices.cpython-38.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ └── sym_functional.cpython-38.pyc
│ │ └── sym_functional.py
│ ├── __init__.py
│ └── __pycache__
│ │ └── __init__.cpython-38.pyc
├── LieBNBase.py
├── LieBNCor.py
├── LieBNRot.py
├── LieBNSPD.py
├── LieBN_illustration.png
├── README.md
├── __init__.py
├── __pycache__
│ ├── LieBNBase.cpython-38.pyc
│ ├── LieBNCor.cpython-38.pyc
│ ├── LieBNRot.cpython-38.pyc
│ ├── LieBNSPD.cpython-38.pyc
│ └── __init__.cpython-38.pyc
├── demos
│ ├── liebncor.py
│ ├── liebnrot.py
│ └── liebnspd.py
└── sum_metrics.png
├── LieBN_SPDNet
├── Network
│ ├── Get_Model.py
│ ├── SPDNetRBN.py
│ └── __init__.py
├── SPDNetLieBN.py
├── conf
│ ├── LieBN.yaml
│ ├── dataset
│ │ ├── HDM05.yaml
│ │ └── RADAR.yaml
│ └── nnet
│ │ ├── SPDNet.yaml
│ │ ├── SPDNetBN.yaml
│ │ └── SPDNetLieBN.yaml
├── cplx
│ ├── functional.py
│ └── nn.py
├── experiments.sh
└── spd
│ ├── DataLoader
│ ├── FPHA_Loader.py
│ ├── HDM05_Loader.py
│ └── Radar_Loader.py
│ ├── LieBN.py
│ ├── __init__.py
│ ├── functional.py
│ ├── nn.py
│ ├── sym_functional.py
│ ├── training_script.py
│ └── utils.py
├── LieBN_TSMNet
├── LieBN_utilities
│ ├── Training.py
│ ├── __init__.py
│ └── utils.py
├── TSMNet-LieBN.py
├── conf
│ ├── LieBN.yaml
│ ├── dataset
│ │ └── hinss2021.yaml
│ ├── evaluation
│ │ ├── inter-session+uda.yaml
│ │ └── inter-subject+uda.yaml
│ ├── nnet
│ │ ├── tsmnet.yaml
│ │ ├── tsmnet_LieBN.yaml
│ │ └── tsmnet_spddsmbn.yaml
│ └── preprocessing
│ │ └── bb4-36Hz.yaml
├── datasetio
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-310.pyc
│ └── eeg
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ └── __init__.cpython-310.pyc
│ │ └── moabb
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-310.pyc
│ │ ├── base.cpython-310.pyc
│ │ ├── hinss2021.cpython-310.pyc
│ │ └── stieger2021.cpython-310.pyc
│ │ ├── base.py
│ │ ├── hinss2021.py
│ │ └── stieger2021.py
├── experiments_Hinss21.sh
├── library
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-310.pyc
│ └── utils
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ └── __init__.cpython-310.pyc
│ │ ├── hydra
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ │ └── __init__.cpython-310.pyc
│ │ ├── moabb
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ │ └── __init__.cpython-310.pyc
│ │ ├── pyriemann
│ │ └── __init__.py
│ │ └── torch
│ │ ├── __init__.py
│ │ └── __pycache__
│ │ └── __init__.cpython-310.pyc
└── spdnets
│ ├── BaseBatchNorm.py
│ ├── LieBNImpl.py
│ ├── Liebatchnorm.py
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-310.pyc
│ ├── batchnorm.cpython-310.pyc
│ ├── functionals.cpython-310.pyc
│ ├── manifolds.cpython-310.pyc
│ └── modules.cpython-310.pyc
│ ├── batchnorm.py
│ ├── functionals.py
│ ├── manifolds.py
│ ├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-310.pyc
│ │ ├── base.cpython-310.pyc
│ │ ├── dann.cpython-310.pyc
│ │ ├── eegnet.cpython-310.pyc
│ │ ├── shconvnet.cpython-310.pyc
│ │ └── tsmnet.cpython-310.pyc
│ ├── base.py
│ ├── dann.py
│ ├── eegnet.py
│ ├── shconvnet.py
│ └── tsmnet.py
│ ├── modules.py
│ └── utils
│ └── skorch
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-310.pyc
│ ├── logging.cpython-310.pyc
│ └── network.cpython-310.pyc
│ ├── logging.py
│ └── network.py
├── Readme.md
└── environment.yaml
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/.DS_Store
--------------------------------------------------------------------------------
/ICLR24_LieBN_PPT.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/ICLR24_LieBN_PPT.pdf
--------------------------------------------------------------------------------
/ICLR24_LieBN_Poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/ICLR24_LieBN_Poster.pdf
--------------------------------------------------------------------------------
/LieBN/Geometry/Base/LieGroups.py:
--------------------------------------------------------------------------------
1 | """
2 | Lie group computations under LieBN:
3 | @inproceedings{chen2024liebn,
4 | title={A Lie Group Approach to Riemannian Batch Normalization},
5 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe},
6 | booktitle={The Twelfth International Conference on Learning Representations},
7 | year={2024},
8 | url={https://openreview.net/forum?id=okYdj8Ysru}
9 | }
10 | """
11 | import torch as th
12 | import torch.nn as nn
13 |
14 | from abc import ABC
15 |
16 |
17 | class LieGroup(nn.Module):
18 | """LieBN for [...,n,n]: Following SPDNetBN, we use .detach() for mean and variance calculation
19 | """
20 | def __init__(self,is_detach=True):
21 | super().__init__()
22 | self.is_detach=is_detach
23 |
24 | def geodesic(self,P,Q,t):
25 | '''The geodesic from P to Q at t'''
26 | raise NotImplementedError
27 |
28 | #--- Methods required in LieBN ---
29 | def deformation(self,X):
30 | raise NotImplementedError
31 |
32 | def inv_deformation(self,X):
33 | raise NotImplementedError
34 | def dist2Isquare(self,X):
35 | """geodesic square distance to the identity element.
36 | it should keep all the dim, such that cal_geom_var only eliminate the batch dim.
37 | This will support [bs,...,n,n] data
38 | """
39 | raise NotImplementedError
40 |
41 | def cal_geom_mean_(self,X,batchdim=[0]):
42 | raise NotImplementedError
43 |
44 | def cal_geom_mean(self,X,batchdim=[0]):
45 | batch = X.detach() if self.is_detach else X
46 | return self.cal_geom_mean_(batch,batchdim)
47 |
48 | def cal_geom_var(self, X,batchdim=[0]):
49 | """Frechet variance"""
50 | batch = X.detach() if self.is_detach else X
51 | return self.dist2Isquare(batch).mean(dim=batchdim)
52 |
53 | def translation(self, X, P, is_inverse):
54 | raise NotImplementedError
55 |
56 | def scaling(self, X, factor):
57 | raise NotImplementedError
58 |
59 | def __repr__(self):
60 | attributes = []
61 | for key, value in self.__dict__.items():
62 | if key.startswith("_"):
63 | continue
64 | if isinstance(value, th.Tensor):
65 | attributes.append(f"{key}.shape={tuple(value.shape)}") # 只显示 Tensor 形状
66 | else:
67 | attributes.append(f"{key}={value}")
68 |
69 | # handle variables registered via register_buffer
70 | for name, buffer in self.named_buffers(recurse=False): # only print this class's buffers
71 | attributes.append(f"{name}={buffer.item() if buffer.numel() == 1 else tuple(buffer.shape)}")
72 |
73 | return f"{self.__class__.__name__}({', '.join(attributes)})"
74 |
75 | class PullbackEuclideanMetric(ABC):
76 | """The Euclidean computation in the co-domain of PullBack Euclidean Metric:
77 | Ziheng Chen, etal, Adaptive Log-Euclidean Metrics for SPD Matrix Learning
78 | For the subclass of this class, we only need to implement
79 | dist2Isquare, deformation, and inv_deformation
80 | """
81 | # def __init__(self, is_detach=True):
82 | # super().__init__(is_detach=is_detach)
83 |
84 | def cal_geom_mean_(self, batch,batchdim=[0]):
85 | return batch.mean(dim=batchdim)
86 |
87 | def translation(self, X, P, is_inverse):
88 | X_new = X - P if is_inverse else X + P
89 | return X_new
90 |
91 | def scaling(self, X, factor):
92 | return X * factor
93 |
94 | def geodesic(self,P,Q,t):
95 | return (1 - t) * P + t * Q
--------------------------------------------------------------------------------
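The base class above fixes the LieBN interface; per its docstring, a `PullbackEuclideanMetric` subclass only needs `dist2Isquare`, `deformation`, and `inv_deformation`. A minimal sketch (not part of the repository; `ToyEuclideanGroup` is hypothetical) of such a subclass, using the additive matrix group with the identity deformation:

```python
import torch as th
from LieBN.Geometry.Base.LieGroups import LieGroup, PullbackEuclideanMetric

class ToyEuclideanGroup(PullbackEuclideanMetric, LieGroup):
    """Toy geometry: the additive group of matrices with the identity deformation."""
    def deformation(self, X):
        return X

    def inv_deformation(self, X):
        return X

    def dist2Isquare(self, X):
        # Squared distance to the identity element (0 under addition);
        # keep the trailing dims so cal_geom_var only reduces the batch dim.
        return X.square().sum(dim=(-2, -1), keepdim=True)

g = ToyEuclideanGroup()
X = th.randn(8, 4, 4)
mean = g.cal_geom_mean(X)                             # closed-form (arithmetic) mean
X_centered = g.translation(X, mean, is_inverse=True)  # centering
var = g.cal_geom_var(X_centered)                      # Fréchet variance of the centered batch
```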
/LieBN/Geometry/Base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__init__.py
--------------------------------------------------------------------------------
/LieBN/Geometry/Base/__pycache__/LieGroups.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__pycache__/LieGroups.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/__init__.py:
--------------------------------------------------------------------------------
1 | from .CorMatrices import Correlation,\
2 | CorEuclideanCholeskyMetric,CorLogEuclideanCholeskyMetric,\
3 | CorOffLogMetric,CorLogScaledMetric
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/__pycache__/CorMatrices.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/CorMatrices.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/__pycache__/cor_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/cor_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/__pycache__/ppbcm_functionals.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Correlation/__pycache__/ppbcm_functionals.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Correlation/ppbcm_functionals.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | @th.jit.script
4 | def hemisphere_to_poincare(x):
5 | """
6 | Transform points from the open hemisphere HS^n to the Poincaré ball P^n.
7 |
8 | Parameters:
9 | - x: torch.Tensor of shape [..., n+1], where the last dimension
10 | represents the coordinates (x^T, x_{n+1}).
11 |
12 | Returns:
13 | - y: torch.Tensor of shape [..., n], transformed coordinates in the Poincaré ball.
14 | """
15 | x_T, x_n1 = x[..., :-1], x[..., -1] # Split x into (x_T, x_n+1)
16 | y = x_T / (1 + x_n1.unsqueeze(-1)) # Apply the transformation: y = x_T / (1 + x_n1)
17 | return y
18 |
19 | @th.jit.script
20 | def poincare_to_hemisphere(y):
21 | """Map from Poincaré ball to open hemisphere."""
22 | norm_y = y.norm(p=2,dim=-1, keepdim=True) ** 2
23 | factor = 1 / (1 + norm_y)
24 | mapped = th.cat((2 * y * factor, (1 - norm_y) * factor), dim=-1)
25 | return mapped
26 |
27 |
28 |
--------------------------------------------------------------------------------
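A quick sanity check of the two maps above (illustrative, not part of the repository): `poincare_to_hemisphere` lands on the unit sphere with a positive last coordinate, and the two maps are mutually inverse.

```python
import torch as th
from LieBN.Geometry.Correlation.ppbcm_functionals import (
    hemisphere_to_poincare, poincare_to_hemisphere)

y = th.randn(10, 3)
y = 0.9 * y / (1 + y.norm(dim=-1, keepdim=True))  # points strictly inside the ball

x = poincare_to_hemisphere(y)
assert th.allclose(x.norm(dim=-1), th.ones(10), atol=1e-6)   # on the unit sphere
assert (x[..., -1] > 0).all()                                # on the open upper hemisphere
assert th.allclose(hemisphere_to_poincare(x), y, atol=1e-6)  # round trip recovers y
```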
/LieBN/Geometry/Rotations/RotMatrices.py:
--------------------------------------------------------------------------------
1 | """
2 | Rotation computations under LieBN:
3 | @inproceedings{chen2024liebn,
4 | title={A Lie Group Approach to Riemannian Batch Normalization},
5 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe},
6 | booktitle={The Twelfth International Conference on Learning Representations},
7 | year={2024},
8 | url={https://openreview.net/forum?id=okYdj8Ysru}
9 | }
10 | """
11 | import torch as th
12 | from ..Base.LieGroups import LieGroup
13 |
14 | from geoopt.manifolds import Manifold
15 | from pytorch3d.transforms import matrix_to_axis_angle
16 |
17 | class RotMatrices(LieGroup,Manifold):
18 | """
19 | Computation for SO3 data with size of [...,n,n]
20 | Following manopt, we use the Lie algebra representation for the tangent spaces
21 | """
22 | # __scaling__ = Manifold.__scaling__.copy()
23 | name = "Rotation"
24 | ndim = 2
25 | reversible = False
26 | def __init__(self,eps=1e-5,is_detach=True,karcher_steps=1):
27 | super().__init__(is_detach=is_detach)
28 | self.eps=eps;
29 | self.register_buffer('I', th.eye(3))
30 | self.karcher_steps=karcher_steps
31 |
32 | # === Generating random matrices ===
33 |
34 | def random(self, *shape):
35 | """Generate random 3D rotation matrices of shape [..., 3, 3]"""
36 | A = th.rand(shape) # Generate a random matrix
37 | u, _, v = th.linalg.svd(A) # Perform SVD
38 |
39 | # Ensure determinant is +1
40 | det_u = u.det()[..., None] # Reshape to match the last two dimensions
41 | u[..., :, -1] *= th.sign(det_u) # Flip last column where needed
42 | # u_new= u[..., :, -1] * th.sign(det_u) # Flip last column where needed
43 |
44 | # Verify that the output is a valid SO(3) matrix
45 | result, _ = self._check_point_on_manifold(u)
46 | if not result:
47 | raise ValueError("SO(3) init value error")
48 |
49 | return u
50 |
51 | def rand_skrew_sym(self,n,f):
52 | A=th.rand(n,f,3,3)
53 | return A-A.transpose(-1,-2)
54 |
55 | # === For geoopt ===
56 |
57 | def quat_axis2log_axis(self,axis_quat):
58 | """"convert quaternion axis into matrix log axis
59 | https://github.com/facebookresearch/pytorch3d/issues/188
60 | """
61 | log_axis_result = axis_quat.clone()
62 | angle_in_2pi = log_axis_result.norm(p=2, dim=-1, keepdim=True)
63 | mask = angle_in_2pi > th.pi
64 | tmp = 1 - 2 * th.pi / angle_in_2pi
65 | new_values = tmp * log_axis_result
66 | results = th.where(mask, new_values, log_axis_result)
67 | return results
68 |
69 | def matrix2euler_axis(self,R):
70 | quat_vec = matrix_to_axis_angle(R)
71 | log_axis = self.quat_axis2log_axis(quat_vec)
72 | euler_axis = log_axis.div(log_axis.norm(dim=-1,keepdim=True))
73 | return euler_axis
74 |
75 | def mLog(self, R):
76 | """
77 | Note that Exp(\alpha A)=Exp(A), with \alpha = \frac{\theta-2\pi}{\theta}
78 | So, for a single rotation matrix, using quat_axis or log_axis does not affect self.mExp.
79 | """
80 | vec = matrix_to_axis_angle(R)
81 | log_vec = self.quat_axis2log_axis(vec)
82 | skew_symmetric = self.vec2skrew(log_vec)
83 | return skew_symmetric
84 |
85 | def vec2skrew(self,vec):
86 | # skew_symmetric = th.zeros_like(vec).unsqueeze(-1).expand(*vec.shape, 3).contiguous()
87 | skew_symmetric = th.zeros(*vec.shape, 3,dtype=vec.dtype,device=vec.device)
88 | skew_symmetric[..., 0, 1] = -vec[..., 2]
89 | skew_symmetric[..., 1, 0] = vec[..., 2]
90 | skew_symmetric[..., 0, 2] = vec[..., 1]
91 | skew_symmetric[..., 2, 0] = -vec[..., 1]
92 | skew_symmetric[..., 1, 2] = -vec[..., 0]
93 | skew_symmetric[..., 2, 1] = vec[..., 0]
94 | return skew_symmetric
95 | def mExp(self, S):
96 | """Computing matrix exponential for skrew symmetric matrices"""
97 | a, b, c = S[..., 0, 1], S[..., 0, 2], S[..., 1, 2]
98 | theta = th.sqrt(a ** 2 + b ** 2 + c ** 2).unsqueeze(-1).unsqueeze(-1)
99 |
100 | S_normalized = S / theta
101 | S_norm_squared = S_normalized.matmul(S_normalized)
102 | sin_theta = th.sin(theta)
103 | cos_theta = th.cos(theta)
104 | tmp_S = self.I + sin_theta * S_normalized + (1 - cos_theta) * S_norm_squared
105 |
106 | S_new = th.where(theta < self.eps, S-S.detach()+self.I, tmp_S) # S - S.detach() + I returns I while keeping the autograd graph through S
107 |
108 | return S_new
109 |
110 | def transp(self, x, y, v):
111 | return v
112 |
113 | def inner(self, x, u, v, keepdim=False):
114 | if v is None:
115 | v = u
116 | return th.sum(u * v, dim=[-2, -1], keepdim=keepdim)
117 |
118 | def projx(self, x):
119 | u, s, vt = th.linalg.svd(x)
120 | ones = th.ones_like(s)[..., :-1]
121 | signs = th.sign(th.det(th.matmul(u, vt))).unsqueeze(-1)
122 | flip = th.cat([ones, signs], dim=-1)
123 | result = u.matmul(th.diag_embed(flip)).matmul(vt)
124 | return result
125 |
126 | def proju(self,X, H):
127 | k = self.multiskew(H)
128 | return k
129 |
130 | def egrad2rgrad(self, x, u):
131 | """Map the Euclidean gradient :math:`u` in the ambient space on the tangent
132 | space at :math:`x`.
133 | """
134 | k = self.multiskew(x.transpose(-1, -2).matmul(u))
135 | return k
136 |
137 | def retr(self, X,U):
138 | Y = X + X.matmul(U)
139 | Q, R = th.linalg.qr(Y)
140 | New = th.matmul(Q, th.diag_embed(th.sign(th.sign(th.diagonal(R, dim1=-2, dim2=-1)) + 0.5)))
141 | return New
142 |
143 | def multiskew(self,A):
144 | return 0.5 * (A - A.transpose(-1,-2))
145 |
146 | def logmap(self,R,S):
147 | """ return skrew symmetric matrices """
148 | return self.mLog(R.transpose(-1,-2).matmul(S))
149 |
150 | def expmap(self,R,V):
151 | """ V is the skrew symmetric matrices """
152 | return R.matmul(self.mExp(V))
153 |
154 | def geodesic(self,R,S,t):
155 | """ the geodesic connecting R and s """
156 | vector = self.logmap(R,S)
157 | X_new = R.matmul(self.mExp(t*vector))
158 | return X_new
159 |
160 | def trace(self,m):
161 | """Computation for trace of m of [...,n,n]"""
162 | return th.einsum("...ii->...", m)
163 |
164 | def _check_point_on_manifold(
165 | self, x: th.Tensor, *, atol=1e-5, rtol=1e-8
166 | ):
167 |
168 | if x.shape[-1] != 3 or x.shape[-2] != 3:
169 | raise ValueError("Input matrices must be 3x3.")
170 |
171 | # Check orthogonality
172 | is_orthogonal = th.allclose(x @ x.transpose(-1, -2),self.I, atol=atol,rtol=rtol)
173 |
174 | # Check determinant
175 | det = th.det(x)
176 | is_det_one = th.allclose(det, th.tensor(1.0, device=x.device,dtype=x.dtype), atol=atol,rtol=rtol)
177 |
178 | # Combine both conditions
179 | is_SO3 = is_orthogonal & is_det_one
180 |
181 | return is_SO3, None
182 |
183 | def _check_vector_on_tangent(
184 | self, x: th.Tensor, u: th.Tensor, *, atol=1e-5, rtol=1e-8
185 | ):
186 | """Check whether u is a skrew symmetric matrices"""
187 | diff = u + u.transpose(-1,-2)
188 | ok = th.allclose(diff, th.zeros_like(diff), atol=atol, rtol=rtol)
189 | return ok, None
190 |
191 | def is_not_equal(self, a, b, eps=0.01):
192 | """ Return true if not eaqual"""
193 | return th.nonzero(th.abs(a - b) > eps)
194 |
195 | # === Computing angles ===
196 |
197 | def cal_roc_angel_batch(self,r,epsilon=1e-4):
198 | """
199 | Following the MATLAB implementation, we set derivative=0 for tr near -1 or near 3.
200 | Besides, tr may fall slightly outside [-1, 3] due to numerical error.
201 | We treat values beyond [-1, 3] as -1 or 3, with derivative 0.
202 | return:
203 | tr <= -1+epsilon: theta=pi (derivative is 0)
204 | tr >= 3-epsilon: theta=0 (derivative is 0)
205 | """
206 | assert epsilon >= 0, "Epsilon must be non-negative"
207 |
208 | mtrc = self.trace(r)
209 |
210 | maskpi = (mtrc + 1) <= epsilon # tr <= -1 + epsilon
211 | # mtrcpi = -mtrc * maskpi * np.pi # this differs from the MATLAB implementation, as its derivative is -pi
212 | mtrcpi = maskpi * th.pi
213 | maskacos = ((mtrc + 1) > epsilon) * ((3- mtrc) > epsilon) # -1+epsilon < tr < 3 - epsilon
214 |
215 | mtrcacos = th.acos((mtrc * maskacos - 1) / 2) * maskacos # -1+epsilon < tr < 3 - epsilon, use the acos
216 | results = mtrcpi + mtrcacos # for tr <= -1+epsilon and tr >= 3-epsilon, the derivative is 0
217 | return results
218 |
219 | # === For LieBN ===
220 |
221 | def cal_geom_mean_(self, X, batchdim=[0,1]):
222 | "Karcher flow"
223 | init_point = X[tuple(0 for _ in batchdim)] # Dynamically selects first elements along batchdim
224 | for ith in range(self.karcher_steps):
225 | tan_data = self.mLog(init_point.transpose(-1,-2).matmul(X))
226 | tan_mean = tan_data.mean(dim=batchdim)
227 | condition = tan_mean.norm(dim=(-1, -2))
228 | # print(f'{ith+1}: {condition}')
229 | if th.all(condition<=1e-4):
230 | # th.sum(condition>=0.1)
231 | # print('early stop')
232 | break
233 | init_point = init_point.matmul(self.mExp(tan_mean))
234 | return init_point
235 |
236 | def scaling(self, X,shift,running_var=None,batchdim=[0,1]):
237 | """Frechet variance"""
238 | Log_X = self.mLog(X)
239 | if running_var is None:
240 | Log_X_norm_square = Log_X.norm(p='fro', dim=(-2, -1), keepdim=True).square()
241 | var = th.mean(Log_X_norm_square, dim=batchdim)
242 | var = var.detach() if self.is_detach else var
243 | else:
244 | var = running_var
245 | factor = shift / (var + self.eps).sqrt()
246 | scale_Log = factor * Log_X
247 | X_scaled = self.mExp(scale_Log)
248 | return X_scaled,var
249 |
250 | def translation(self, X, P, is_inverse, is_left):
251 | """translation by P"""
252 | if is_left:
253 | X_new = P.transpose(-1, -2).matmul(X) if is_inverse else P.matmul(X)
254 | else:
255 | X_new = X.matmul(P.transpose(-1, -2)) if is_inverse else X.matmul(P)
256 | return X_new
257 |
258 |
--------------------------------------------------------------------------------
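An illustrative sanity check for the utilities above (not part of the repository; requires `pytorch3d`): `mLog` produces skew-symmetric Lie-algebra elements, and `mExp` (Rodrigues' formula) inverts it up to the numerical precision of the axis-angle conversion.

```python
import torch as th
from LieBN.Geometry.Rotations import RotMatrices

so3 = RotMatrices()
R = so3.random(4, 2, 3, 3)                     # random SO(3) matrices
S = so3.mLog(R)                                # skew-symmetric matrices
assert th.allclose(S, -S.transpose(-1, -2))
assert th.allclose(so3.mExp(S), R, atol=1e-3)  # Exp(Log(R)) = R, approximately

# The bi-invariant geodesic reaches its endpoint at t=1.
R0, R1 = so3.random(3, 3), so3.random(3, 3)
assert th.allclose(so3.geodesic(R0, R1, 1.0), R1, atol=1e-3)
```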
/LieBN/Geometry/Rotations/__init__.py:
--------------------------------------------------------------------------------
1 | from .RotMatrices import RotMatrices
--------------------------------------------------------------------------------
/LieBN/Geometry/Rotations/__pycache__/RotMatrices.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Rotations/__pycache__/RotMatrices.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/Rotations/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/Rotations/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/SPD/__init__.py:
--------------------------------------------------------------------------------
1 | from .SPDMatrices import SPDMatrices,\
2 | SPDLogEuclideanMetric,SPDAdaptiveLogEuclideanMetric,\
3 | SPDLogCholeskyMetric,SPDAffineInvariantMetric,\
4 | SPDCholeskyRightInvariantMetric
--------------------------------------------------------------------------------
/LieBN/Geometry/SPD/__pycache__/SPDMatrices.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/SPDMatrices.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/SPD/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/SPD/__pycache__/sym_functional.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/SPD/__pycache__/sym_functional.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/Geometry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/__init__.py
--------------------------------------------------------------------------------
/LieBN/Geometry/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/Geometry/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/LieBNBase.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024.
6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024.
7 |
8 | Copyright (C) 2024 Ziheng Chen
9 | All rights reserved.
10 | """
11 |
12 | import torch as th
13 | import torch.nn as nn
14 |
15 | class LieBNBase(nn.Module):
16 | """
17 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
18 | Input X: (batchdim,...,n,n) matrices
19 | Output X_new: (batchdim,...,n,n) batch-normalized matrices
20 | arguments:
21 | shape: excluding the batch dim
22 | batchdim: the first k dims are batch dims, such as [0], [0,1], ...
23 | is_detach: whether to use detach when calculating Fréchet statistics.
24 | parameters:
25 | self.weight: [...,n,n] manifold parameter
26 | self.shift: [...,1,1] scalar dispersion
27 | """
28 | def __init__(self,shape, batchdim=[0], momentum=0.1,is_detach=True):
29 | super().__init__()
30 | self.shape=shape; self.momentum = momentum; self.batchdim=batchdim
31 | self.eps = 1e-5;
32 | self.manifold = None;
33 | self.is_detach=is_detach
34 |
35 | # Handle channel vs non-channel case
36 | if len(self.shape) > 2:
37 | # --- running statistics ---
38 | self.register_buffer("running_mean", th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1))
39 | self.register_buffer("running_var", th.ones(*shape[:-2], 1, 1))
40 | # --- parameters ---
41 | self.shift = nn.Parameter(th.ones(*shape[:-2], 1, 1))
42 | else:
43 | # --- running statistics ---
44 | self.register_buffer("running_mean", th.eye(self.shape[-1]))
45 | self.register_buffer("running_var", th.ones(1))
46 | # --- parameters ---
47 | self.shift = nn.Parameter(th.ones(1))
48 |
49 | # biasing weight should be set specifically for each manifold
50 | # self.set_weight()
51 |
52 | def updating_running_statistics(self,batch_mean,batch_var):
53 | self.running_mean.data = self.manifold.geodesic(self.running_mean, batch_mean,self.momentum)
54 | self.running_var.data = (1 - self.momentum) * self.running_var + batch_var * self.momentum
55 |
56 | def set_weight(self):
57 | raise NotImplementedError
58 |
59 | def get_manifold(self):
60 | raise NotImplementedError
61 |
62 | def __repr__(self):
63 | attributes = []
64 | for key, value in self.__dict__.items():
65 | if key.startswith("_"):
66 | continue
67 | if isinstance(value, th.Tensor):
68 | attributes.append(f"{key}.shape={tuple(value.shape)}") # only show Tensor shape
69 | else:
70 | attributes.append(f"{key}={value}")
71 |
72 | # register_buffer
73 | for name, buffer in self.named_buffers(recurse=False):
74 | attributes.append(f"{name}={buffer.item() if buffer.numel() == 1 else tuple(buffer.shape)}")
75 |
76 | return f"{self.__class__.__name__}({', '.join(attributes)}) \n {self.manifold}"
77 |
--------------------------------------------------------------------------------
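The running statistics above generalize the usual BN exponential moving average: the running mean moves along the geodesic toward the batch mean, and the running variance is updated linearly,

$$\bar{M} \leftarrow \gamma_{\bar{M} \to M}(\text{momentum}), \qquad \bar{v} \leftarrow (1-\text{momentum})\,\bar{v} + \text{momentum}\,v.$$

For a pullback-Euclidean geometry, where $\gamma_{P \to Q}(t) = (1-t)P + tQ$, the mean update reduces exactly to the standard Euclidean EMA $(1-\text{momentum})\,\bar{M} + \text{momentum}\,M$.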
/LieBN/LieBNCor.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024.
6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024.
7 |
8 | Copyright (C) 2024 Ziheng Chen
9 | All rights reserved.
10 | """
11 |
12 | import torch as th
13 | import geoopt
14 | from geoopt.manifolds import PoincareBall
15 |
16 | from .Geometry.Correlation.CorMatrices import CorEuclideanCholeskyMetric,CorLogEuclideanCholeskyMetric,CorOffLogMetric,CorLogScaledMetric
17 | from .Geometry.Correlation.ppbcm_functionals import poincare_to_hemisphere
18 |
19 | from . import LieBNBase
20 |
21 | class LieBNCor(LieBNBase):
22 | """ Implemented metrics: ECM,LECM,OLM,LSM; all are bi-invariant metric with Abelian groups.
23 | param_mode: trivial,Riem
24 |
25 | Note:
26 | - Our experiments indicate that Riemannian optimization is better than trivialization for the correlation biasing parameters;
27 | - We do not use detach when calculating Fréchet statistics, as our experiments show this works better.
28 | - is_inv_deformation: when False, this can simplify computations (for certain subsequent layers).
29 | """
30 |
31 | def __init__(self,shape ,batchdim=[0],momentum=0.1,is_detach=False,
32 | metric: str ='ECM',alpha=1.,param_1=0.,param_2=0.,max_iter=100,
33 | is_inv_deformation=True,param_mode='Riem'):
34 | super().__init__(shape, batchdim=batchdim, momentum=momentum,is_detach=is_detach)
35 | self.metric = metric; self.alpha = alpha;self.param_1 = param_1;self.param_2 = param_2;
36 | self.max_iter = max_iter;self.is_inv_deformation=is_inv_deformation
37 | self.param_mode=param_mode
38 | self.get_manifold()
39 | self.set_weight()
40 |
41 | def set_weight(self):
42 | """Initializes (n-1) separate Poincare vectors of increasing dimensions (1 to n-1) using the Poincare Ball manifold."""
43 | # use poincare updates
44 | if self.param_mode=='Riem':
45 | self.weight = th.nn.ParameterList([
46 | geoopt.ManifoldParameter(
47 | th.zeros(*self.shape[:-2], i)
48 | if len(self.shape) > 2 else th.zeros(i), # Handle channel vs non-channel case
49 | manifold=PoincareBall(c=1.0, learnable=False)
50 | )
51 | for i in range(1, self.shape[-1]) # Each vector has i dimensions
52 | ])
53 | elif self.param_mode=='trivial':
54 | # use identical metric updates
55 | self.weight = th.nn.Parameter(th.randn(self.shape))
56 |
57 | def poincare_to_correlation(self,weight):
58 | """Converts stored Poincare vectors into a correlation matrix without extra memory storage."""
59 | # Initialize lower-triangular Cholesky factor L
60 | L = th.zeros(self.shape, dtype=weight[0].dtype, device=weight[0].device)
61 | L[...,0, 0] = 1 # Set (0,0) explicitly to 1
62 |
63 | # Directly compute Hemisphere representations and assign to L
64 | for i in range(1, self.shape[-1]):
65 | L[...,i, :i+1] = poincare_to_hemisphere(weight[i - 1]) # No extra storage, direct assignment
66 | return L @ L.transpose(-1,-2)
67 |
68 | def Euc2Codomain(self,weight):
69 | if self.metric in ['ECM','LECM']:
70 | weight_codomain = weight.tril(-1)
71 | elif self.metric == 'OLM':
72 | weight_codomain = weight.tril(-1) + weight.tril(-1).transpose(-1,-2)
73 | elif self.metric == 'LSM':
74 | n = self.shape[-1]
75 | weight_codomain = th.zeros_like(weight)
76 | weight_codomain[..., :n - 1, :n - 1] = weight.tril(-1)[..., 1:n , :n-1]
77 | # **Symmetrization**: ensure the upper-left (n-1) x (n-1) block is symmetric
78 | weight_codomain = weight_codomain.tril() + weight_codomain.tril(-1).transpose(-1, -2)
79 | # **Step 2**: Compute the last column so that the sum of each row is 0
80 | weight_codomain[..., :-1, -1] = -th.sum(weight_codomain[..., :-1, :-1], dim=-1)
81 | # **Step 3**: Compute the `[n, n]` element so that the last row can be recovered symmetrically
82 | weight_codomain[..., -1, -1] = -th.sum(weight_codomain[..., :-1, -1], dim=-1)
83 | # **Step 4**: Restore the last row to be equal to the transpose of the last column
84 | weight_codomain[..., -1, :-1] = weight_codomain[..., :-1, -1]
85 | return weight_codomain
86 |
87 | def forward(self,X):
88 | #deformation
89 | X_deformed = self.manifold.deformation(X)
90 | if self.param_mode == 'Riem':
91 | weight = self.manifold.deformation(self.poincare_to_correlation(self.weight))
92 | elif self.param_mode == 'trivial':
93 | weight = self.Euc2Codomain(self.weight)
94 |
95 | if(self.training):
96 | # centering
97 | batch_mean = self.manifold.cal_geom_mean(X_deformed,batchdim=self.batchdim)
98 | X_centered = self.manifold.translation(X_deformed,batch_mean,is_inverse=True)
99 |
100 | # scaling and shifting
101 | # Note that as the mean is calculated in closed-form, cal_geom_var (by dist(X_i,I)) is accurate
102 | batch_var = self.manifold.cal_geom_var(X_centered,batchdim=self.batchdim)
103 | factor = self.shift / (batch_var + self.eps).sqrt()
104 | X_scaled = self.manifold.scaling(X_centered, factor)
105 | self.updating_running_statistics(batch_mean, batch_var)
106 | else:
107 | # centering, scaling and shifting
108 | X_centered = self.manifold.translation(X_deformed, self.running_mean, is_inverse=True)
109 | factor = self.shift / (self.running_var + self.eps).sqrt()
110 | X_scaled = self.manifold.scaling(X_centered, factor)
111 | #biasing
112 | X_normalized = self.manifold.translation(X_scaled, weight, is_inverse=False)
113 | # inv_deformation
114 | if self.is_inv_deformation:
115 | X_new = self.manifold.inv_deformation(X_normalized)
116 | else:
117 | X_new = X_normalized
118 | return X_new
119 |
120 | def get_manifold(self):
121 | # ECM,LECM,OLM,LSM
122 | classes = {
123 | "ECM": CorEuclideanCholeskyMetric,
124 | "LECM": CorLogEuclideanCholeskyMetric,
125 | "OLM": CorOffLogMetric,
126 | "LSM": CorLogScaledMetric,
127 | }
128 |
129 | if self.metric == 'OLM':
130 | self.manifold = classes[self.metric](n=self.shape[-1],
131 | alpha=self.alpha,beta=self.param_1, gamma=self.param_2,
132 | max_iter=self.max_iter)
133 | elif self.metric == 'LSM':
134 | self.manifold = classes[self.metric](n=self.shape[-1],
135 | alpha=self.alpha, delta=self.param_1,zeta=self.param_2,
136 | max_iter=self.max_iter)
137 | elif self.metric in ['ECM','LECM']:
138 | self.manifold = classes[self.metric](n=self.shape[-1])
139 | else:
140 | raise NotImplementedError
141 |
142 | self.manifold.is_detach = self.is_detach
143 |
144 |
--------------------------------------------------------------------------------
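Since `poincare_to_hemisphere` maps the zero vector to $(0,\dots,0,1)$, the zero-initialized Poincaré vectors in `set_weight` yield the Cholesky factor $L = I$, so the biasing parameter starts at the identity correlation matrix. A quick check (illustrative; assumes the `LieBN` package is importable):

```python
import torch as th
from LieBN import LieBNCor

n = 5
liebn = LieBNCor(shape=[n, n], metric='ECM', param_mode='Riem')
W = liebn.poincare_to_correlation(liebn.weight)
assert th.allclose(W, th.eye(n))  # zero Poincaré vectors -> identity correlation
```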
/LieBN/LieBNRot.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024.
6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024.
7 |
8 | Copyright (C) 2024 Ziheng Chen
9 | All rights reserved.
10 | """
11 |
12 | import torch as th
13 | import geoopt
14 | from .Geometry.Rotations import RotMatrices
15 | from . import LieBNBase
16 |
17 | class LieBNRot(LieBNBase):
18 | """ LieBN on the 3 x 3 rotations under the bi-invariant metric.
19 | is_left: if True, use left translation for centering and biasing; otherwise right translation.
20 | is_detach in RotMatrices is set to True by default.
21 | """
22 | def __init__(self,shape ,batchdim=[0,1],momentum=0.1,is_detach=True,
23 | karcher_steps=1,is_left=True):
24 | super().__init__(shape,batchdim,momentum,is_detach)
25 | self.karcher_steps = karcher_steps;self.is_left=is_left
26 | self.get_manifold()
27 | self.set_weight()
28 |
29 | def set_weight(self):
30 | if len(self.shape) > 2:
31 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1),
32 | manifold=RotMatrices())
33 | else:
34 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]),manifold=RotMatrices())
35 |
36 | def forward(self,S):
37 | if(self.training):
38 | batch_mean = self.manifold.cal_geom_mean(S,batchdim=self.batchdim)
39 | X_centered = self.manifold.translation(S,batch_mean,is_inverse=True,is_left=self.is_left)
40 | X_scaled,var = self.manifold.scaling(X_centered,shift=self.shift,batchdim=self.batchdim)
41 | self.updating_running_statistics(batch_mean, var)
42 | else:
43 | X_centered = self.manifold.translation(S, self.running_mean, is_inverse=True,is_left=self.is_left)
44 | X_scaled,_ = self.manifold.scaling(X_centered,shift=self.shift,running_var=self.running_var)
45 | X_normalized = self.manifold.translation(X_scaled, self.weight, is_inverse=False,is_left=self.is_left)
46 |
47 | return X_normalized
48 |
49 | def get_manifold(self):
50 | self.manifold = RotMatrices(karcher_steps=self.karcher_steps)
51 | self.manifold.is_detach = self.is_detach
52 |
53 |
--------------------------------------------------------------------------------
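For rotations the group inverse is the transpose, so the centering and biasing in `LieBNRot.forward` amount to (summarizing `RotMatrices.translation`)

$$X \mapsto M^\top X \ \ (\texttt{is\_left=True}), \qquad X \mapsto X M^\top \ \ (\texttt{is\_left=False})$$

for centering by the batch mean $M$, followed by the corresponding non-inverse translation $X \mapsto W X$ or $X \mapsto X W$ by the weight $W$.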
/LieBN/LieBNSPD.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024.
6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024.
7 |
8 | Copyright (C) 2024 Ziheng Chen
9 | All rights reserved.
10 | """
11 |
12 | import torch as th
13 | import geoopt
14 | from geoopt.manifolds import SymmetricPositiveDefinite
15 |
16 | from .Geometry.SPD.SPDMatrices import SPDLogEuclideanMetric,SPDAdaptiveLogEuclideanMetric,SPDLogCholeskyMetric,\
17 | SPDAffineInvariantMetric,SPDCholeskyRightInvariantMetric,tril_param_metric,bi_param_metric,single_param_metric
18 |
19 | from . import LieBNBase
20 |
21 | class LieBNSPD(LieBNBase):
22 | """ Arguments:
23 | metric: LEM, ALEM, LCM, AIM, CRIM, where the last one is right-invariant
24 | karcher_steps=1: for AIM and CRIM
25 | power=1: power deformation for AIM, CRIM, and LCM
26 | (alpha, beta)=(1,0): O(n)-invariant inner product in AIM, CRIM, LEM, and ALEM.
27 |
28 | Note:
29 | - The SPD biasing parameter is optimized by AIM-based geoopt.
30 | - For ALEM, the parameter A lives in the SPDAdaptiveLogEuclideanMetric.
31 | - There are two ways to calculate the variance: 1 via d(P_i, M); 2 via d(\bar{P}_i, I).
32 | If the Fréchet mean is accurate, as for LEM, LCM, and ALEM, the two are equivalent.
33 | For AIM and CRIM, as we set karcher_steps=1, the 2nd is an approximation; we adopt the 2nd one for efficiency.
34 | If one lets the Karcher flow converge, 1 is equivalent to 2.
35 | - Following SPDNetBN, we use detach when calculating Fréchet statistics.
36 | """
37 |
38 | def __init__(self,shape ,batchdim=[0],momentum=0.1,is_detach=True,
39 | metric: str ='AIM',power=1.,alpha=1.0,beta=0.,karcher_steps=1,):
40 | super().__init__(shape,batchdim,momentum,is_detach)
41 | self.metric = metric; self.power = power;self.alpha = alpha;self.beta = beta;
42 | self.karcher_steps = karcher_steps
43 | self.get_manifold()
44 | self.set_weight()
45 |
46 | def set_weight(self):
47 | if len(self.shape) > 2:
48 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]).repeat(*self.shape[:-2], 1, 1),
49 | manifold=SymmetricPositiveDefinite())
50 | else:
51 | self.weight = geoopt.ManifoldParameter(th.eye(self.shape[-1]),
52 | manifold=SymmetricPositiveDefinite())
53 |
54 | def forward(self,X):
55 | #deformation
56 | X_deformed = self.manifold.deformation(X)
57 | weight = self.manifold.deformation(self.weight)
58 |
59 | if(self.training):
60 | # centering
61 | batch_mean = self.manifold.cal_geom_mean(X_deformed,batchdim=self.batchdim)
62 | X_centered = self.manifold.translation(X_deformed,batch_mean,is_inverse=True)
63 |
64 | # scaling and shifting
65 | # As centering is an isometry, the batch variance equals that of the centered data (w.r.t. the identity element).
66 | # This is more efficient than calculating the original batch variance.
67 | # Note that if the mean is calculated approximately, e.g., via the Karcher flow, cal_geom_var is also approximate.
68 | # One can also try the original batch variance.
69 | batch_var = self.manifold.cal_geom_var(X_centered,batchdim=self.batchdim)
70 | factor = self.shift / (batch_var + self.eps).sqrt()
71 | X_scaled = self.manifold.scaling(X_centered, factor)
72 | self.updating_running_statistics(batch_mean, batch_var)
73 |
74 | else:
75 | # centering, scaling and shifting
76 | X_centered = self.manifold.translation(X_deformed, self.running_mean, is_inverse=True)
77 | factor = self.shift / (self.running_var + self.eps).sqrt()
78 | X_scaled = self.manifold.scaling(X_centered, factor)
79 | #biasing
80 | X_normalized = self.manifold.translation(X_scaled, weight, is_inverse=False)
81 | # inv_deformation
82 | X_new = self.manifold.inv_deformation(X_normalized)
83 |
84 | return X_new
85 |
86 | def get_manifold(self):
87 | classes = {
88 | "LEM": SPDLogEuclideanMetric,
89 | "ALEM": SPDAdaptiveLogEuclideanMetric,
90 | "LCM": SPDLogCholeskyMetric,
91 | "AIM": SPDAffineInvariantMetric,
92 | "CRIM": SPDCholeskyRightInvariantMetric,
93 | }
94 | n=self.shape[-1]
95 | if self.metric in tril_param_metric:
96 | self.manifold = classes[self.metric](n=n, power=self.power,alpha=self.alpha, beta=self.beta,
97 | karcher_steps=self.karcher_steps)
98 | elif self.metric in bi_param_metric:
99 | self.manifold = classes[self.metric](n=n, alpha=self.alpha, beta=self.beta)
100 | elif self.metric in single_param_metric:
101 | self.manifold = classes[self.metric](n=n, power=self.power)
102 | else:
103 | raise NotImplementedError
104 |
105 | self.manifold.is_detach=self.is_detach
106 |
107 |
--------------------------------------------------------------------------------
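The variance shortcut used in `forward` can be stated precisely: writing $\phi_{M^{-1}}$ for the group translation sending the batch mean $M$ to the identity $I$, the isometry of centering gives

$$v = \frac{1}{N}\sum_{i=1}^{N} d^2(P_i, M) = \frac{1}{N}\sum_{i=1}^{N} d^2\big(\phi_{M^{-1}}(P_i),\, I\big),$$

which is exactly `cal_geom_var` applied to the centered batch. When $M$ is only approximate (one Karcher step for AIM and CRIM), the right-hand side is the approximation actually used.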
/LieBN/LieBN_illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/LieBN_illustration.png
--------------------------------------------------------------------------------
/LieBN/README.md:
--------------------------------------------------------------------------------
1 | [arXiv](https://arxiv.org/abs/2403.11261)
2 | [OpenReview](https://openreview.net/forum?id=okYdj8Ysru)
3 | [Paper PDF](https://openreview.net/pdf?id=okYdj8Ysru)
4 |
5 | # Lie Group Batch Normalization
6 |
7 |
8 | ![LieBN illustration](LieBN_illustration.png)
9 | Figure 1: Visualization of LieBN on the SPD, rotation, and correlation manifolds.
10 |
11 |
12 | ## Introduction
13 | This is the official implementation of our ICLR 2024 paper and its extension:
14 | - *A Lie Group Approach to Riemannian Batch Normalization*
15 | [[OpenReview](https://openreview.net/forum?id=okYdj8Ysru)]
16 |
17 | - *LieBN: Batch Normalization over Lie Groups* (Conference extension, in submission)
18 |
19 | If you find this project helpful, please consider citing us as follows:
20 |
21 | ```bib
22 | @inproceedings{chen2024liebn,
23 | title={A Lie Group Approach to Riemannian Batch Normalization},
24 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe},
25 | booktitle={ICLR},
26 | year={2024},
27 | }
28 | ```
29 |
30 | If you have any problems, please contact me at ziheng_ch@163.com.
31 |
32 | ## Implementations
33 | This source code contains LieBN on the following manifolds:
34 | - SPD manifolds: Log-Euclidean Metric (LEM), Affine-Invariant Metric (AIM), Log-Cholesky Metric (LCM), and our proposed Cholesky Right Invariant Metric (CRIM);
35 | - Rotation groups: the canonical bi-invariant metric;
36 | - Full-rank correlation manifolds: Euclidean-Cholesky Metric (ECM), Log-Euclidean-Cholesky Metric (LECM), Off-Log Metric (OLM), and Log-Scaled Metric (LSM).
37 |
38 |
39 | ![Summary of invariant metrics](sum_metrics.png)
40 |
41 | Figure 2: Summary of Invariant Metrics.
42 |
44 |
45 | **Notes:**
46 | - By default, LieBN uses `.detach()` (`is_detach=True`) when computing Fréchet statistics on the SPD and rotation matrices, following prior work (SPDNetBN).
47 | - For the correlation matrix, however, LieBN sets `is_detach=False` by default, as our experiments suggest this yields better results.
48 |
49 | ## Requirements
50 | Requirements: `torch`, `geoopt`, and `pytorch3d`.
51 |
52 | ## Demos
53 | Demos of LieBN on different geometries can be found in `./demos`:
54 |
55 |
--------------------------------------------------------------------------------
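A minimal usage sketch distilled from `demos/liebnspd.py` below (any of LEM, ALEM, LCM, AIM, CRIM can be passed as `metric`):

```python
import torch as th
from LieBN import LieBNSPD
from LieBN.Geometry.SPD import SPDMatrices

bs, c, n = 4, 2, 5
X = SPDMatrices(n=n).random(bs, c, n, n).to(th.double)  # random SPD input
liebn = LieBNSPD(shape=[c, n, n], metric="LCM", batchdim=[0]).to(th.double)
X_new = liebn(X)  # batch-normalized SPD matrices, same shape as X
```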
/LieBN/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024.
6 | Ziheng Chen, Yue Song, Tianyang Xu, Zhiwu Huang, Xiao-Jun Wu, and Nicu Sebe. Adaptive Log-Euclidean metrics for SPD matrix learning TIP 2024.
7 |
8 | Copyright (C) 2024 Ziheng Chen
9 | All rights reserved.
10 | """
11 |
12 | from .LieBNBase import LieBNBase
13 | from .LieBNCor import LieBNCor
14 | from .LieBNSPD import LieBNSPD
15 | from .LieBNRot import LieBNRot
--------------------------------------------------------------------------------
/LieBN/__pycache__/LieBNBase.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNBase.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/__pycache__/LieBNCor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNCor.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/__pycache__/LieBNRot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNRot.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/__pycache__/LieBNSPD.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/LieBNSPD.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/LieBN/demos/liebncor.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo for LieBN on Correlation Matrices
3 |
4 | @author: Ziheng Chen
5 | Please cite:
6 |
7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe.
8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
9 |
10 | Copyright (C) 2024 Ziheng Chen
11 | All rights reserved.
12 | """
13 |
14 | import torch as th
15 | import geoopt
16 | from LieBN import LieBNCor
17 | from LieBN.Geometry.Correlation import Correlation
18 |
19 | # Set the random seed for reproducibility
20 | SEED = 42
21 | th.manual_seed(SEED)
22 |
23 |
24 | # Helper function to generate random correlation matrices
25 | def random_correlation_matrix(bs, c, n):
26 | corr_module = Correlation(n)
27 | return corr_module.random_cor([bs, c, n, n]).to(th.double)
28 |
29 |
30 | # Set dimensions
31 | bs, c, n = 4, 2, 5 # Batch size, channels, matrix size
32 | shape = [c, n, n]
33 | manifold = Correlation(n=n)
34 |
35 | # Instantiate correlation metric module and generate random input
36 | P = manifold.random(bs, c, n,n).requires_grad_(True).to(th.double) # Input correlation matrices
37 | target = manifold.random(bs, c, n,n).to(th.double) # Target correlation matrices for loss computation
38 |
39 | # Instantiate LieBN for Correlation Manifold (ECM, LECM, OLM, LSM)
40 | liebn = LieBNCor(shape=shape, metric='OLM',batchdim=[0]).to(th.double)
41 |
42 | print("\nLieBNCor Layer Initialized:", liebn)
43 | print("Manifold:", liebn.manifold)
44 |
45 | print("\n=== LieBNCor Parameters ===")
46 | for name, param in liebn.named_parameters():
47 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
48 |
49 | # Define optimizer and loss function
50 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3)
51 | criterion = th.nn.MSELoss()
52 |
53 | # Print initial loss
54 | with th.no_grad():
55 | initial_loss = criterion(liebn(P), target).item()
56 | print(f"\nInitial Loss: {initial_loss:.6f}")
57 |
58 | # Training loop
59 | num_epochs = 5 # Set higher if needed
60 | for i in range(num_epochs):
61 | liebn.train() # Set to training mode
62 | optimizer.zero_grad() # Clear previous gradients
63 | output = liebn(P)
64 | # Compute loss
65 | loss = criterion(output, target)
66 |
67 | # Backpropagation
68 | loss.backward()
69 |
70 | # Gradient Norm Check (Optional)
71 | grad_norm = th.nn.utils.clip_grad_norm_(liebn.parameters(), max_norm=1.0)
72 |
73 | optimizer.step()
74 |
75 | # Print loss every iteration
76 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f} | Grad Norm: {grad_norm:.6f}")
77 |
78 | print("\nProcessed Correlation Matrices in Training Mode:", output.shape)
79 | print(f"Final Training Loss: {loss.item():.6f}")
80 |
81 | # Evaluation mode
82 | liebn.eval()
83 | with th.no_grad():
84 | test_output = liebn(P)
85 |
86 | print("\nProcessed Correlation Matrices in Testing Mode:", test_output.shape)
87 |
88 | print("\nTraining and Evaluation completed successfully! 🚀")
89 |
--------------------------------------------------------------------------------
/LieBN/demos/liebnrot.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo for LieBN on SO(3) matrices
3 |
4 | @author: Ziheng Chen
5 | Please cite:
6 |
7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe.
8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
9 |
10 | Copyright (C) 2024 Ziheng Chen
11 | All rights reserved.
12 | """
13 |
14 | import torch as th
15 | import geoopt
16 | from LieBN import LieBNRot
17 | from LieBN.Geometry.Rotations import RotMatrices
18 |
19 | # Set the random seed for reproducibility
20 | SEED = 42
21 | th.manual_seed(SEED)
22 |
23 | # Batch size, frame, space (SO(3) are 3x3 matrices)
24 | # In LieNet, the shape is [bs,f,s,3,3]
25 | bs, f, s = 5, 2, 3
26 | manifold = RotMatrices()
27 |
28 | # Generate random SO(3) matrices as inputs and targets
29 | input_rot = manifold.random(bs,f, s,3,3).to(th.double).requires_grad_(True) # Input SO(3) matrices
30 | target_rot = manifold.random(bs,f, s,3,3).to(th.double) # Target matrices for loss computation
31 |
32 | # Instantiate LieBN for rotation matrices
33 | # is_left=False/True
34 | # liebn = LieBNRot(shape=[s,3,3], batchdim=[0, 1], is_left=False,karcher_steps=100).to(th.double)
35 | liebn = LieBNRot(shape=[s,3,3], batchdim=[0,1], is_left=False).to(th.double)
36 |
37 | print("\nLieBNRot Layer Initialized:", liebn)
38 | print("Manifold:", liebn.manifold)
39 |
40 | print("\n=== LieBNRot Parameters ===")
41 | for name, param in liebn.named_parameters():
42 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
43 |
44 | # Define optimizer and loss function
45 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3)
46 | criterion = th.nn.MSELoss()
47 |
48 | # Training loop
49 | num_epochs = 2 # Set higher if needed
50 | for i in range(num_epochs):
51 | liebn.train() # Set to training mode
52 | optimizer.zero_grad() # Clear previous gradients
53 |
54 | # Forward pass
55 | output_rot = liebn(input_rot)
56 |
57 | # Compute loss
58 | loss = criterion(output_rot, target_rot)
59 |
60 | # Backpropagation
61 | loss.backward()
62 | optimizer.step()
63 |
64 | # Print loss every iteration
65 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f}")
66 |
67 | print("\nProcessed SO(3) Matrices in Training Mode:", output_rot.shape)
68 | print(f"Final Training Loss: {loss.item():.6f}")
69 |
70 | # Evaluation mode
71 | liebn.eval()
72 | with th.no_grad():
73 | test_output = liebn(input_rot)
74 |
75 | print("\nProcessed SO(3) Matrices in Testing Mode:", test_output.shape)
76 |
77 | print("\nTraining and Evaluation completed successfully! 🚀")
78 |
--------------------------------------------------------------------------------
/LieBN/demos/liebnspd.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo for LieBN on SPD Matrices
3 |
4 | @author: Ziheng Chen
5 | Please cite:
6 |
7 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe.
8 | A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
9 |
10 | Copyright (C) 2024 Ziheng Chen
11 | All rights reserved.
12 | """
13 |
14 | import torch as th
15 | import geoopt
16 | from LieBN import LieBNSPD
17 | from LieBN.Geometry.SPD import SPDMatrices
18 |
19 | # Set the random seed for reproducibility
20 | SEED = 42
21 | th.manual_seed(SEED)
22 |
23 |
24 | # Batch size, channels, SPD matrix size
25 | # In SPDNet and SPDNetBN, the shape is [bs,1,n,n]
26 | bs, c, n = 4, 2, 5
27 |
28 | # Generate input SPD matrices and target SPD matrices
29 | manifold=SPDMatrices(n=n)
30 | P = manifold.random(bs, c, n,n).requires_grad_(True).to(th.double) # Input SPD matrices
31 | target = manifold.random(bs, c, n,n).to(th.double) # Target SPD matrices for loss computation
32 |
33 | # LEM,ALEM,LCM,AIM,CRIM
34 | liebn = LieBNSPD(shape=[c, n, n], metric="LCM", batchdim=[0]).to(th.double)
35 |
36 | print("\nLieBNSPD Layer Initialized:", liebn)
37 |
38 | print("\n=== LieBNSPD Parameters ===")
39 | for name, param in liebn.named_parameters():
40 | print(f"Parameter Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
41 |
42 | # Define optimizer (Adam) and loss function (MSE)
43 | optimizer = geoopt.optim.RiemannianAdam(liebn.parameters(), lr=1e-3)
44 | criterion = th.nn.MSELoss()
45 |
46 | # Print initial loss
47 | with th.no_grad():
48 | initial_loss = criterion(liebn(P), target).item()
49 | print(f"\nInitial Loss: {initial_loss:.6f}")
50 |
51 | # Training loop
52 | num_epochs = 2 # Increase if needed
53 | for i in range(num_epochs):
54 | liebn.train() # Set to training mode
55 | optimizer.zero_grad() # Clear previous gradients
56 |
57 | # Forward pass
58 | output = liebn(P)
59 |
60 | # Compute loss
61 | loss = criterion(output, target)
62 |
63 | # Backpropagation
64 | loss.backward()
65 |
66 | # Gradient Norm Check (Optional)
67 | grad_norm = th.nn.utils.clip_grad_norm_(liebn.parameters(), max_norm=1.0)
68 |
69 | optimizer.step()
70 |
71 | # Print loss every iteration
72 | print(f"Epoch {i + 1} | Loss: {loss.item():.6f} | Grad Norm: {grad_norm:.6f}")
73 |
74 | print("\nProcessed SPD Matrices in Training Mode:", output.shape)
75 | print(f"Final Training Loss: {loss.item():.6f}")
76 |
77 | # Evaluation mode
78 | liebn.eval()
79 | with th.no_grad():
80 | test_output = liebn(P)
81 |
82 | print("\nProcessed SPD Matrices in Testing Mode:", test_output.shape)
83 |
84 | print("\nTraining and Evaluation completed successfully! 🚀")
85 |
--------------------------------------------------------------------------------
/LieBN/sum_metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN/sum_metrics.png
--------------------------------------------------------------------------------
/LieBN_SPDNet/Network/Get_Model.py:
--------------------------------------------------------------------------------
1 | from Network.SPDNetRBN import SPDNetLieBN,SPDNet
2 |
3 |
4 | def get_model(args):
5 | if args.model_type in args.total_BN_model_types:
6 | model = SPDNetLieBN(args)
7 | elif args.model_type=='SPDNet':
8 | model = SPDNet(args)
9 | else:
10 | raise Exception('unknown model {} or metric {}'.format(args.model_type,args.metric))
11 | return model
--------------------------------------------------------------------------------
/LieBN_SPDNet/Network/SPDNetRBN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import cplx.nn as nn_cplx
3 | import spd.nn as nn_spd
4 | from spd.LieBN import LieBatchNormSPD
5 |
6 | class SPDNetLieBN(nn.Module):
7 | def __init__(self,args):
8 | super(__class__, self).__init__()
9 | dims = [int(dim) for dim in args.architecture]
10 | self.feature = []
11 | if args.dataset == 'RADAR':
12 | self.feature.append(nn_cplx.SplitSignal_cplx(2, 20, 10))
13 | self.feature.append(nn_cplx.CovPool_cplx())
14 | self.feature.append(nn_spd.ReEig())
15 |
16 | for i in range(len(dims) - 2):
17 | self.feature.append(nn_spd.BiMap(1,1,dims[i], dims[i + 1]))
18 | self.feature.append(SPDBN(dims[i + 1],args))
19 | self.feature.append(nn_spd.ReEig())
20 |
21 | self.feature.append(nn_spd.BiMap(1,1,dims[-2], dims[-1]))
22 | self.feature.append(SPDBN(dims[-1],args))
23 | self.feature = nn.Sequential(*self.feature)
24 |
25 | self.classifier = LogEigMLR(dims[-1]**2,args.class_num)
26 |
27 | def forward(self, x):
28 | x_spd = self.feature(x)
29 | y = self.classifier(x_spd)
30 | return y
31 |
32 | # LogEig MLR
33 | class LogEigMLR(nn.Module):
34 | def __init__(self, input_dim, classnum):
35 | super(__class__, self).__init__()
36 | self.logeig = nn_spd.LogEig()
37 | self.linear = nn.Linear(input_dim, classnum).double()
38 |
39 | def forward(self, x):
40 | x_vec = self.logeig(x).view(x.shape[0], -1)
41 | y = self.linear(x_vec)
42 | return y
43 |
44 | class SPDBN(nn.Module):
45 | def __init__(self, n,args,ddevice='cpu'):
46 | super(__class__, self).__init__()
47 | if args.BN_type == 'brooks':
48 | self.BN = nn_spd.BatchNormSPD(n,args.momentum)
49 | elif args.BN_type == 'LieBN':
50 | self.BN = LieBatchNormSPD(n,
51 | metric=args.metric,
52 | theta=args.theta, alpha=args.alpha, beta=args.beta,momentum=args.momentum)
53 | else:
54 | raise Exception('unknown BN {}'.format(args.BN_type))
55 |
56 | def forward(self, x):
57 | x_spd = self.BN(x)
58 | return x_spd
59 |
60 | class SPDNet(nn.Module):
61 | def __init__(self,args):
62 | super(__class__, self).__init__()
63 | dims = [int(dim) for dim in args.architecture]
64 | self.feature = []
65 | if args.dataset == 'RADAR':
66 | self.feature.append(nn_cplx.SplitSignal_cplx(2, 20, 10))
67 | self.feature.append(nn_cplx.CovPool_cplx())
68 | self.feature.append(nn_spd.ReEig())
69 |
70 | for i in range(len(dims) - 2):
71 | self.feature.append(nn_spd.BiMap(1,1,dims[i], dims[i + 1]))
72 | self.feature.append(nn_spd.ReEig())
73 |
74 | self.feature.append(nn_spd.BiMap(1,1,dims[-2], dims[-1]))
75 | self.feature = nn.Sequential(*self.feature)
76 |
77 | self.classifier = LogEigMLR(dims[-1]**2,args.class_num)
78 |
79 | def forward(self, x):
80 | x_spd = self.feature(x)
81 | y = self.classifier(x_spd)
82 | return y
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
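
A minimal sketch (not part of the repository) of how SPDNetLieBN above can be instantiated without Hydra. The Args fields mirror what parse_cfg in spd/utils.py populates from the config files; the HDM05 sizes ([93,30] architecture, 117 classes) come from conf/dataset/HDM05.yaml, and the random SPD inputs are purely illustrative. Assumes it is run from the LieBN_SPDNet directory.

import torch as th
from Network.SPDNetRBN import SPDNetLieBN

class Args: pass
args = Args()
args.dataset = 'HDM05'            # anything but 'RADAR' skips the complex front-end
args.architecture = [93, 30]      # BiMap dimensions, as in conf/dataset/HDM05.yaml
args.class_num = 117
args.BN_type = 'LieBN'
args.metric, args.theta, args.alpha, args.beta = 'AIM', 1.0, 1.0, 0.0
args.momentum = 0.1

model = SPDNetLieBN(args)
A = th.randn(4, 1, 93, 93, dtype=th.double)
X = A @ A.transpose(-1, -2) + 1e-3 * th.eye(93, dtype=th.double)  # random SPD batch
logits = model(X)                 # -> (4, 117) class scores
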
/LieBN_SPDNet/Network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_SPDNet/Network/__init__.py
--------------------------------------------------------------------------------
/LieBN_SPDNet/SPDNetLieBN.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
6 |
7 | Copyright (C) 2024 Ziheng Chen
8 | All rights reserved.
9 | """
10 |
11 | import hydra
12 | from omegaconf import DictConfig
13 |
14 | from spd.training_script import training
15 |
16 | class Args:
17 | """ a Struct Class """
18 | pass
19 | args=Args()
20 | args.config_name='LieBN.yaml'
21 |
22 |
23 | args.total_BN_model_types=['SPDNetLieBN', 'SPDNetBN']
24 | args.total_LieBN_model_types=['SPDNetLieBN']
25 |
26 | @hydra.main(config_path='./conf/', config_name=args.config_name, version_base='1.1')
27 | def main(cfg: DictConfig):
28 | training(cfg,args)
29 |
30 | if __name__ == "__main__":
31 | main()
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/LieBN.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - nnet: SPDNetLieBN # SPDNet SPDNetLieBN SPDNetBN
4 | - dataset: RADAR # HDM05, RADAR, FPHA
5 | - override hydra/launcher: joblib
6 | fit:
7 | epochs: 200
8 | batch_size: 30
9 | threadnum: 1
10 | is_writer: True
11 | cycle: 1
12 | seed: 1024
13 | is_save: True
14 |
15 | hydra:
16 | run:
17 | dir: ./outputs/${dataset.name}
18 | sweep:
19 | dir: ./outputs/${dataset.name}
20 | subdir: '.'
21 | launcher:
22 | n_jobs: -1
23 | job_logging:
24 | handlers:
25 | file:
26 | class: logging.FileHandler
27 | filename: default.log
28 |
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/dataset/HDM05.yaml:
--------------------------------------------------------------------------------
1 | name: HDM05
2 | architecture: [93,30]
3 | path: /data #change this to your data folder
4 |
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/dataset/RADAR.yaml:
--------------------------------------------------------------------------------
1 | name: RADAR
2 | architecture: [20,16,12]
3 | path: /data #change this to your data folder
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/nnet/SPDNet.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_type: SPDNet
3 | optimizer:
4 | mode: AMSGRAD
5 | lr: 5e-3
6 | weight_decay: 0.
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/nnet/SPDNetBN.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - SPDNet
3 | - _self_
4 | model:
5 | model_type: SPDNetBN
6 | BN_type: brooks
7 | momentum: 0.1
--------------------------------------------------------------------------------
/LieBN_SPDNet/conf/nnet/SPDNetLieBN.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - SPDNetBN
3 | - _self_
4 | model:
5 | model_type: SPDNetLieBN
6 | BN_type: LieBN
7 | metric: AIM # AIM,LEM,LCM
8 | theta: 1.0
9 | alpha: 1.0
10 | beta: 0.0
--------------------------------------------------------------------------------
/LieBN_SPDNet/cplx/functional.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | from spd import functional as functional_spd  # assumed to provide SqminvEig, used by batch_norm2d_cplx_spd below
3 | def conv_cplx(X,conv_re,conv_im):
4 | if(isinstance(X,list)):
5 | XX=[x.split(x.shape[0]//2,0) for x in X]
6 |         # XX[i] holds the (real, imaginary) halves of the i-th signal; the complex
7 |         # product (a+ib)(c+id) = (ac-bd) + i(ad+bc) is applied filter-wise below
8 | P_Re=[conv_re(xx[0][None,...])-conv_im(xx[1][None,...]) for xx in XX]
9 | P_Im=[conv_im(xx[0][None,...])+conv_re(xx[1][None,...]) for xx in XX]
10 | P=[th.cat((p_Re,p_Im),0) for p_Re,p_Im in zip(P_Re,P_Im)]
11 | else:
12 | XX=X.split(X.shape[1]//2,1)
13 | P_Re=conv_re(XX[0])-conv_im(XX[1])
14 | P_Im=conv_im(XX[0])+conv_re(XX[1])
15 | P=th.cat((P_Re,P_Im),1)
16 | return P
17 |
18 | def split_signal_cplx(X,conv_re,conv_im):
19 | '''
20 |     1D to 2D complex conv layer; its fixed weights are zeros and ones placed so as to split the signal into windows
21 | '''
22 | XX=X.split(X.shape[1]//2,1)
23 | P_Re=conv_re(XX[0])
24 | P_Im=conv_re(XX[1])
25 | P=th.cat((P_Re[:,None,:,:],P_Im[:,None,:,:]),1)
26 | return P
27 |
28 | def roll(X):
29 | return th.cat((X[:,:,X.shape[-2]//2:,:],X[:,:,:X.shape[-2]//2,:]),2)
30 |
31 | def decibel(X):
32 | # n_fft=X.shape[-2]//2
33 | # X_Re=X[:,:n_fft,:]
34 | # X_Im=X[:,n_fft:,:]
35 | if(isinstance(X,list)):
36 |         absX=[absolut_squared(x) for x in X]  # use the functional form (AbsolutSquared lives in cplx.nn)
37 | X_db=[10*th.log(x) for x in absX]
38 | else:
39 | X_db=10*th.log(absolut_squared(X))
40 | return X_db
41 |
42 | def absolut_squared(X):
43 | '''
44 | Inputs a (N,2*C,H,W) complex tensor
45 | Outputs a (N,C,H,W) real tensor containing the input's squared module
46 | '''
47 | XX=X.split(X.shape[1]//2,1)
48 | return (XX[0]**2+XX[1]**2)
49 |
50 | def oneD2twoD_cplx(x):
51 | '''
52 | Inputs a 3D complex tensor (N,2*n,T)
53 | Outputs a 4D complex tensor (N,2,n,T)
54 | '''
55 | return x.view(x.shape[0],2,-1,x.shape[-1])
56 |
57 | def batch_norm2d_cplx(X,running_mean,running_var,gamma11,gamma12,gamma22,bias,momentum,training):
58 | N,C,H,W=X.shape
59 | XX=list(X.split(C//2,1))
60 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW)
61 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW)
62 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW)
63 | if(training):
64 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,)
65 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2)
66 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,)
67 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,)
68 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre)
69 | with th.no_grad():
70 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2)
71 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2)
72 | # cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:]
73 | cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre)
74 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None]))
75 | else:
76 | # running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:]
77 | running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0])
78 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None]))
79 | weight=utils.twobytwo_covmat_from_coeffs(gamma11,gamma22,gamma12)
80 | Z=weight.matmul(Y)+bias[:,:,None]
81 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1)
82 |
83 | def batch_norm2d_cplx_spd(X,running_mean,running_var,weight,bias,momentum,training):
84 | N,C,H,W=X.shape
85 | XX=list(X.split(C//2,1))
86 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW)
87 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW)
88 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW)
89 | if(training):
90 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,)
91 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2)
92 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,)
93 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,)
94 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre)
95 | with th.no_grad():
96 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2)
97 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2)
98 | cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:]
99 | # cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre)
100 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None]))
101 | else:
102 | running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:]
103 | # running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0])
104 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None]))
105 |     # the SPD scaling parameter 'weight' is supplied directly by the caller
106 | Z=weight.matmul(Y)+bias[:,:,None]
107 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1)
108 |
109 |
110 | def cov_pool_cplx(f,reg_mode='mle',N_estimates=None):
111 | """
112 |     Input f: Temporal n-dimensional complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T)
113 | Output X: Complex covariance matrix of size (batch_size,2,n,n)
114 | """
115 | N,_,n,T=f.shape
116 | ff=f.split(f.shape[1]//2,1)
117 | f_re=ff[0]; f_im=ff[1]
118 |     if(N_estimates is not None):
119 |         f_re=f_re.split(N_estimates,-1)
120 |         if(f_re[-1].shape[-1]!=N_estimates):
121 |             f_re=th.cat(f_re[:-1]+(th.cat((f_re[-1],th.zeros(N,1,n,N_estimates-f_re[-1].shape[-1])),-1),),1)
122 |         else:
123 |             f_re=th.cat(f_re,1)
124 |         f_im=f_im.split(N_estimates,-1)
125 |         if(f_im[-1].shape[-1]!=N_estimates):
126 |             f_im=th.cat(f_im[:-1]+(th.cat((f_im[-1],th.zeros(N,1,n,N_estimates-f_im[-1].shape[-1])),-1),),1)
127 |         else:
128 |             f_im=th.cat(f_im,1)
129 | f_re-=f_re.mean(-1,True); f_im-=f_im.mean(-1,True)
130 | f_re=f_re.double(); f_im=f_im.double()
131 | X_Re=((f_re.matmul(f_re.transpose(-1,-2))+f_im.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1))
132 | X_Im=((f_im.matmul(f_re.transpose(-1,-2))-f_re.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1))
133 |     if(reg_mode=='mle'):
134 |         pass  # default: raw MLE estimate, no regularization
135 |     elif(reg_mode=='add_id'):
136 |         X_Re=RegulEig(1e-6)(X_Re)
137 |         X_Im=RegulEig(1e-6)(X_Im)
138 |     elif(reg_mode=='adjust_eig'):
139 |         X_Re=AdjustEig(0.75)(X_Re)
140 |         X_Im=AdjustEig(0.75)(X_Im)
141 |     X=(X_Re+X_Im)/2  # combine real/imaginary parts (TODO: cat the two for full HPD support)
142 | # X=th.cat((X_Re,X_Im),1) #for real complex matrices
143 | return X
144 |
--------------------------------------------------------------------------------
/LieBN_SPDNet/cplx/nn.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import torch.nn as nn
3 | import numpy.random
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from . import functional
7 | import spd.functional as functional_spd  # assumed to provide SPDParameter, used by BatchNorm2d_cplxSPD below
8 | class Conv_cplx(nn.Module):
9 | '''
10 | Interface complex conv layer
11 | '''
12 | def forward(self,X):
13 | return functional.conv_cplx(X,self._conv_Re,self._conv_Im)
14 |
15 | class Conv1d_cplx(Conv_cplx):
16 | '''
17 | 1D complex conv layer
18 | Inputs a 3D Tensor (batch_size,2*C,T)
19 | C is the number of channels, 2*C=in_channels the effective number of channels for handling complex data
20 | Contains two real-valued conv layers
21 | Output is of shape (batch_size,out_channels,T) (complex channels is out_channels//2)
22 | '''
23 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False):
24 | super(__class__,self).__init__()
25 | self._conv_Re=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias)
26 | self._conv_Im=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias)
27 |
28 | class FFT(Conv1d_cplx):
29 | '''
30 |     1D complex conv layer whose fixed weights are the Fourier atoms; relies on signal_utils/utils helpers not included in this directory
31 | '''
32 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False):
33 | super(__class__, self).__init__(in_channels,out_channels,kernel_size,stride,bias)
34 | atoms=signal_utils.fourier_atoms(out_channels,kernel_size).conj()
35 | atoms_Re=th.from_numpy(utils.cplx2bichan(atoms)[:,0,:][:,None,:]).float()
36 | atoms_Im=th.from_numpy(utils.cplx2bichan(atoms)[:,1,:][:,None,:]).float()
37 | self._conv_Re.weight.data=nn.Parameter(atoms_Re)
38 | self._conv_Im.weight.data=nn.Parameter(atoms_Im)
39 | for param in list(self.parameters()):
40 | param.requires_grad=False
41 |
42 | class SplitSignal_cplx(Conv1d_cplx):
43 | '''
44 |     1D to 2D complex conv layer; its fixed weights are zeros and ones placed so as to split the signal into windows
45 | Input is still (batch_size,2,T)
46 | Output is 4-D complex signal (batch_size,2,window_size,T') instead of (batch_size,2*window_size,T')
47 | '''
48 | def __init__(self,in_channels,window_size,hop_length):
49 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False)
50 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:]
51 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:]
52 | for param in list(self._conv_Re.parameters()):
53 | param.requires_grad=False
54 | for param in list(self._conv_Im.parameters()):
55 | param.requires_grad=False
56 | def forward(self,X):
57 | return functional.split_signal_cplx(X,self._conv_Re,self._conv_Im)
58 |
59 | class SplitSignalBlock_cplx(Conv1d_cplx):
60 | '''
61 |     1D to 2D complex conv layer; its fixed weights are zeros and ones placed so as to split the signal into windows
62 | Input is now L blocks of 3-D complex signals (batch_size,L,2*C_in,T), which are acted on independently by the conv layer
63 | Output is L blocks of 3-D complex signals (batch_size,L,2*C_out,window_size,T')
64 | '''
65 | def __init__(self,in_channels,window_size,hop_length):
66 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False)
67 | # self._conv_Re.weight.data=th.ones_like(self._conv_Re.weight.data)
68 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:]
69 | # self._conv_Im.weight.data=th.ones_like(self._conv_Im.weight.data)
70 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:]
71 | for param in list(self._conv_Re.parameters()):
72 | param.requires_grad=False
73 | for param in list(self._conv_Im.parameters()):
74 | param.requires_grad=False
75 | def forward(self,X):
76 | XX=[functional.split_signal_cplx(X[:,i,:,:],self._conv_Re,self._conv_Im)[:,None,:,:,:]
77 | for i in range(X.shape[1])]
78 | return th.cat(XX,1)
79 |
80 | class Conv2d_cplx(Conv_cplx):
81 | '''
82 | 2D complex conv layer
83 | Inputs a 4D Tensor (N,2*C,H,W)
84 | C is the number of channels, 2*C the effective number of channels for handling complex data
85 | Contains two real-valued conv layers
86 | '''
87 | def __init__(self,in_channels,out_channels,kernel_size,stride=(1,1),bias=True):
88 | super(Conv2d_cplx,self).__init__()
89 | self._conv_Re=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias)
90 | self._conv_Im=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias)
91 | sigma=2./(in_channels//2+out_channels//2)
92 | ampli=numpy.random.rayleigh(sigma,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1]))
93 | phase=numpy.random.uniform(-np.pi,np.pi,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1]))
94 | self._conv_Re.weight.data=nn.Parameter(th.Tensor(ampli*np.cos(phase)))
95 | self._conv_Im.weight.data=nn.Parameter(th.Tensor(ampli*np.sin(phase)))
96 |
97 | class ReLU_cplx(nn.ReLU):
98 | pass
99 | # def forward(self,X):
100 | # return F.relu(X)
101 |
102 | class Roll(nn.Module):
103 | def forward(self,X):
104 | return functional.roll(X)
105 |
106 | class Decibel(nn.Module):
107 | '''
108 | Inputs a (N,2*C,H,W) complex tensor
109 | Outputs a (N,C,H,W) real tensor containing the input's decibel amplitude
110 | '''
111 | def forward(self,X):
112 | return functional.decibel(X)
113 |
114 | class AbsolutSquared(nn.Module):
115 | '''
116 | Inputs a (N,2*C,H,W) complex tensor
117 | Outputs a (N,C,H,W) real tensor containing the input's squared module
118 | '''
119 | def forward(self,X):
120 | return functional.absolut_squared(X)
121 |
122 | class oneD2twoD_cplx(nn.Module):
123 | '''
124 | Inputs a 3D complex tensor (N,2*n,T)
125 | Outputs a 4D complex tensor (N,2,n,T)
126 | '''
127 | def forward(self,x):
128 | return functional.oneD2twoD_cplx(x)
129 |
130 | class MaxPool2d_cplx(nn.MaxPool2d):
131 | pass
132 | # def __init__(self,kernel_size,stride=None,padding=0,dilation=1,return_indices=False,ceil_mode=False):
133 | # super(MaxPool2d_cplx, self).__init__(kernel_size,stride,padding,dilation,return_indices,ceil_mode)
134 | # self._abs=AbsolutSquared()
135 | # def forward(self,X):
136 | # N,C,H,W=X.shape
137 | # X_abs=self._abs(X)
138 | # _,idx=F.max_pool2d(X_abs,self.kernel_size,self.stride,self.padding,self.dilation,self.ceil_mode,return_indices=True)
139 | # XX=X.split(C//2,1)
140 | # Y_re=th.gather(XX[0].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape)
141 | # Y_im=th.gather(XX[1].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape)
142 | # Y=th.cat((Y_re,Y_im),1)
143 | # return Y
144 |
145 | class BatchNorm2d_cplx(nn.BatchNorm2d):
146 | '''
147 | Inputs a (N,2*C,H,W) complex tensor
148 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor
149 | '''
150 | # def __init__(self,in_channels):
151 | # super(BatchNorm2d_cplx,self).__init__(in_channels)
152 | # self.momentum=0.1
153 | # self.running_mean=th.zeros(in_channels//2,2)
154 | # self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5)
155 | # self._gamma11=nn.Parameter(th.ones(in_channels//2)/2**.5)
156 | # self._gamma22=nn.Parameter(th.ones(in_channels//2)/2**.5)
157 | # self._gamma12=nn.Parameter(th.zeros(in_channels//2))
158 | # self.bias=nn.Parameter(th.zeros(in_channels//2,2))
159 | # def forward(self,X):
160 | # return functional.batch_norm2d_cplx(X,self.running_mean,self.running_var,
161 | # self._gamma11,self._gamma12,self._gamma22,self.bias,self.momentum,self.training)
162 | pass
163 |
164 | class BatchNorm2d_cplxSPD(nn.BatchNorm2d):
165 | '''
166 | Inputs a (N,2*C,H,W) complex tensor
167 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor
168 | '''
169 | def __init__(self,in_channels):
170 | super(BatchNorm2d_cplxSPD,self).__init__(in_channels)
171 | self.momentum=0.1
172 | self.running_mean=th.zeros(in_channels//2,2)
173 | self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5)
174 | self.weight_=nn.ParameterList([functional_spd.SPDParameter(th.eye(2)/(2**.5)) for _ in range(in_channels//2)])
175 | self.bias=nn.Parameter(th.zeros(in_channels//2,2))
176 | def forward(self,X):
177 | return functional.batch_norm2d_cplx_spd(X,self.running_mean,self.running_var,
178 | self.weight_,self.bias,self.momentum,self.training)
179 |
180 | class BatchNorm2d(nn.BatchNorm2d):
181 | pass
182 | # def __init__(self,in_channels):
183 | # super(BatchNorm2d,self).__init__(in_channels)
184 | # self.momentum=0.1
185 | # self.running_mean=th.zeros(in_channels)
186 | # self.running_var=th.zeros(in_channels)
187 | # self.weight=nn.Parameter(th.ones(in_channels))
188 | # self.bias=nn.Parameter(th.zeros(in_channels))
189 | # def forward(self,X):
190 | # N,C,H,W=X.shape
191 | # X_vec=X.transpose(0,1).contiguous().view(C,N*H*W)
192 | # if(self.training):
193 | # mu=X_vec.mean(1); var=X_vec.var(1)
194 | # with th.no_grad():
195 | # self.running_mean=(1-self.momentum)*self.running_mean+self.momentum*mu
196 | # self.running_var=(1-self.momentum)*self.running_var+self.momentum*var
197 | # Y=(X_vec-mu.view(-1,1))/(var.view(-1,1)**.5+self.eps)
198 | # else:
199 | # Y=(X_vec-self.running_mean.view(-1,1))/(self.running_var.view(-1,1)**.5+self.eps)
200 | # Z=self.weight.view(-1,1)*Y+self.bias.view(-1,1)
201 | # return Z.view(C,N,H,W).transpose(0,1)
202 | #
203 | # def forward(self, x):
204 | # self._check_input_dim(x)
205 | # y = x.transpose(0,1)
206 | # return_shape = y.shape
207 | # y = y.contiguous().view(x.size(1), -1)
208 | # mu = y.mean(dim=1)
209 | # sigma2 = y.var(dim=1)
210 | # if self.training is not True:
211 | # y = y - self.running_mean.view(-1, 1)
212 | # y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
213 | # else:
214 | # if self.track_running_stats is True:
215 | # with torch.no_grad():
216 | # self.running_mean = (1-self.momentum)*self.running_mean + self.momentum*mu
217 | # self.running_var = (1-self.momentum)*self.running_var + self.momentum*sigma2
218 | # y = y - mu.view(-1,1)
219 | # y = y / (sigma2.view(-1,1)**.5 + self.eps)
220 | #
221 | # y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
222 | # return y.view(return_shape).transpose(0,1)
223 |
224 | class CovPool_cplx(nn.Module):
225 | """
226 |     Input f: Temporal n-dimensional complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T)
227 | Output X: Complex covariance matrix of size (batch_size,2,n,n)
228 | """
229 | def __init__(self,reg_mode='mle',N_estimates=None):
230 | super(__class__,self).__init__()
231 | self._reg_mode=reg_mode; self.N_estimates=N_estimates
232 | def forward(self,f):
233 | return functional.cov_pool_cplx(f,self._reg_mode,self.N_estimates)
234 |
235 | class CovPoolBlock_cplx(nn.Module):
236 | """
237 |     Input f: L blocks of temporal n-dimensional complex feature map of length T of shape (batch_size,L,2,n,T)
238 | Output X: L blocks of complex covariance matrix of size (batch_size,L,2,n,n)
239 | """
240 | def __init__(self,reg_mode='mle',N_estimates=None):
241 | super(__class__,self).__init__()
242 | self._reg_mode=reg_mode; self.N_estimates=N_estimates
243 | def forward(self,f):
244 | XX=[functional.cov_pool_cplx(f[:,i,:,:,:],self._reg_mode,self.N_estimates)[:,None,:,:,:]
245 | for i in range(f.shape[1])]
246 | return th.cat(XX,1)
--------------------------------------------------------------------------------
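
A hedged sketch of the RADAR front-end assembled in Network/SPDNetRBN.py from the layers above: window a complex 1D signal, pool covariance matrices, then rectify eigenvalues. The batch size and the signal length T=400 are hypothetical; the layer arguments (2, 20, 10) are the ones used for the RADAR dataset.

import torch as th
import cplx.nn as nn_cplx
import spd.nn as nn_spd

x = th.randn(8, 2, 400)                      # (batch, real/imag channels, T)
f = nn_cplx.SplitSignal_cplx(2, 20, 10)(x)   # -> (8, 2, 20, 39): windows of length 20, hop 10
C = nn_cplx.CovPool_cplx()(f)                # -> (8, 1, 20, 20) covariance estimates (double)
P = nn_spd.ReEig()(C)                        # rectify small eigenvalues so matrices are SPD
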
/LieBN_SPDNet/experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | epochs=200
4 | datasets=RADAR,HDM05
5 |
6 | ### Experiments on SPDNet,SPDNetBN
7 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=$datasets nnet=SPDNetBN,SPDNet fit.epochs=$epochs
8 |
9 | ### Experiments on SPDNetLieBN-LEM
10 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=$datasets nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LEM
11 |
12 | ### Experiments on SPDNetLieBN-AIM
13 | #RADAR
14 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=RADAR nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=AIM nnet.model.theta=1.
15 | #HDM05
16 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=HDM05 nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=AIM nnet.model.theta=1,1.5
17 |
18 | ### Experiments on SPDNetLieBN-LCM
19 | #RADAR
20 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=RADAR nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LCM nnet.model.theta=1.,-0.5
21 | ##HDM05
22 | [ $? -eq 0 ] && python SPDNetLieBN.py -m dataset=HDM05 nnet=SPDNetLieBN fit.epochs=$epochs nnet.model.metric=LCM nnet.model.theta=1,0.5
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/DataLoader/FPHA_Loader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import numpy as np
5 | import torch as th
6 | import random
7 |
8 | from torch.utils import data
9 |
10 | class DatasetSPD(data.Dataset):
11 | def __init__(self, path, names):
12 | self._path = path
13 | self._names = names
14 |
15 | def __len__(self):
16 | return len(self._names)
17 |
18 | def __getitem__(self, item):
19 | x = np.load(self._path + self._names[item])[None, :, :].real
20 | x = th.from_numpy(x).double()
21 | y = int(self._names[item].split('.')[0].split('_')[-1])
22 | y = th.from_numpy(np.array(y)).long()
23 | # return x.to(device),y.to(device)
24 | return x, y
25 |
26 |
27 | class DataLoaderFPHA:
28 | def __init__(self, data_path, batch_size):
29 | path_train, path_test = data_path + 'train/', data_path + 'val/'
30 | for filenames in os.walk(path_train):
31 | names_train = sorted(filenames[2])
32 | for filenames in os.walk(path_test):
33 | names_test = sorted(filenames[2])
34 |         self._train_generator = data.DataLoader(DatasetSPD(path_train, names_train), batch_size=batch_size,
35 |                                                 shuffle=True)
36 |         self._test_generator = data.DataLoader(DatasetSPD(path_test, names_test), batch_size=batch_size,
37 |                                                 shuffle=False)
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/DataLoader/HDM05_Loader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import numpy as np
5 | import torch as th
6 | import random
7 |
8 | from torch.utils import data
9 |
10 | device = 'cpu'
11 | class DatasetHDM05(data.Dataset):
12 | def __init__(self, path, names):
13 | self._path = path
14 | self._names = names
15 |
16 | def __len__(self):
17 | return len(self._names)
18 |
19 | def __getitem__(self, item):
20 | x = np.load(self._path + self._names[item])[None, :, :].real
21 | x = th.from_numpy(x).double()
22 | y = int(self._names[item].split('.')[0].split('_')[-1])
23 | y = th.from_numpy(np.array(y)).long()
24 | return x.to(device), y.to(device)
25 |
26 |
27 | class DataLoaderHDM05:
28 | def __init__(self, data_path, pval, batch_size):
29 | for filenames in os.walk(data_path):
30 | names = sorted(filenames[2])
31 | random.Random(1024).shuffle(names)
32 | N_test = int(pval * len(names))
33 | train_set = DatasetHDM05(data_path, names[N_test:])
34 | test_set = DatasetHDM05(data_path, names[:N_test])
35 |         self._train_generator = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
36 |         self._test_generator = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/DataLoader/Radar_Loader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import numpy as np
5 | import torch as th
6 | import random
7 |
8 | from torch.utils import data
9 |
10 | pval=0.25 #validation percentage
11 | ptest=0.25 #test percentage
12 | # th.cuda.device('cpu')
13 |
14 | class DatasetRadar(data.Dataset):
15 | def __init__(self, path, names):
16 | self._path = path
17 | self._names = names
18 | def __len__(self):
19 | return len(self._names)
20 | def __getitem__(self, item):
21 | x=np.load(self._path+self._names[item])
22 | x=np.concatenate((x.real[:,None],x.imag[:,None]),axis=1).T
23 | x=th.from_numpy(x)
24 | y=int(self._names[item].split('.')[0].split('_')[-1])
25 | y=th.from_numpy(np.array(y))
26 | return x.float(),y.long()
27 | class DataLoaderRadar:
28 | def __init__(self,data_path,pval,batch_size):
29 | for filenames in os.walk(data_path):
30 | names=sorted(filenames[2])
31 | random.Random().shuffle(names)
32 | N_val=int(pval*len(names))
33 | N_test=int(ptest*len(names))
34 | N_train=(len(names)-N_test-N_val)
35 | train_set=DatasetRadar(data_path,names[N_val+N_test:int(N_train)+N_test+N_val])
36 | test_set=DatasetRadar(data_path,names[:N_test])
37 | val_set=DatasetRadar(data_path,names[N_test:N_test+N_val])
38 |         self._train_generator=data.DataLoader(train_set,batch_size=batch_size,shuffle=True)
39 |         self._test_generator=data.DataLoader(test_set,batch_size=batch_size,shuffle=False)
40 |         self._val_generator=data.DataLoader(val_set,batch_size=batch_size,shuffle=False)
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/LieBN.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
6 |
7 | Copyright (C) 2024 Ziheng Chen
8 | All rights reserved.
9 | """
10 |
11 | import torch as th
12 | import torch.nn as nn
13 | from . import functional
14 | import geoopt
15 | from geoopt.manifolds import SymmetricPositiveDefinite
16 |
17 | from spd import sym_functional
18 |
19 | dtype=th.double
20 | device=th.device('cpu')
21 |
22 | class LieBatchNormSPD(nn.Module):
23 | """
24 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
25 | Input X: (...,n,n) SPD matrices
26 | Output X_new: (...,n,n) batch-normalized matrices
27 | Parameters:
28 | self.weight: (n,n) SPD parameter
29 | self.shift: scalar dispersion
30 | """
31 | def __init__(self,n: int ,metric: str ='AIM',theta: float =1.,alpha: float=1.0,beta: float=0.,ddevice='cpu',momentum=0.1):
32 | super(__class__,self).__init__()
33 | self.momentum=momentum;self.n=n; self.ddevice=ddevice;
34 | self.metric = metric;self.theta=th.tensor(theta,dtype=dtype);self.alpha=th.tensor(alpha,dtype=dtype);self.beta=th.tensor(beta,dtype=dtype)
35 | self.running_mean=th.eye(n,dtype=dtype)
36 | self.weight = geoopt.ManifoldParameter(th.eye(n, n, dtype=dtype, device=ddevice),
37 | manifold=SymmetricPositiveDefinite())
38 | self.eps=1e-05;
39 | self.running_var=th.ones(1,dtype=dtype);
40 | self.shift = nn.Parameter(th.ones(1, dtype=dtype))
41 |
42 |         if metric not in ("AIM", "LCM", "LEM"):
43 | raise Exception('unknown metric {}'.format(metric))
44 |
45 | def forward(self,X):
46 | #deformation
47 | X_deformed = self.deformation(X)
48 | weight = self.deformation(self.weight)
49 |
50 | if(self.training):
51 | # centering
52 | batch_mean = self.cal_geom_mean(X_deformed)
53 | X_centered = self.Left_translation(X_deformed,batch_mean,is_inverse=True)
54 | # scaling and shifting
55 | batch_var = self.cal_geom_var(X_centered)
56 | X_scaled = self.scaling(X_centered, batch_var, self.shift)
57 | self.updating_running_statistics(batch_mean, batch_var)
58 |
59 | else:
60 | # centering, scaling and shifting
61 | X_centered = self.Left_translation(X_deformed, self.running_mean, is_inverse=True)
62 | X_scaled = self.scaling(X_centered, self.running_var, self.shift)
63 | #biasing
64 | X_normalized = self.Left_translation(X_scaled, weight, is_inverse=False)
65 | # inv_deformation
66 | X_new = self.inv_deformation(X_normalized)
67 |
68 | return X_new
69 | def alpha_beta_dist(self, X):
70 | """"computing the O(n)-invariant Euclidean distance to the identity (element)"""
71 | if self.beta==0.:
72 | dist = self.alpha * th.linalg.matrix_norm(X).square()
73 | else:
74 | item1 = th.linalg.matrix_norm(X)
75 | item2 = functional.trace(X)
76 | dist = self.alpha * item1.square() + self.beta * item2.square()
77 | return dist
78 |
79 | def spd_power(self,X):
80 | if self.theta == 1.:
81 | X_power = X
82 | else:
83 | X_power = sym_functional.sym_powm.apply(X, self.theta)
84 | return X_power
85 |
86 | def inv_power(self,X):
87 | if self.theta == 1.:
88 | X_power = X
89 | else:
90 | X_power = sym_functional.sym_powm.apply(X, 1/self.theta)
91 | return X_power
92 |
93 | def deformation(self,X):
94 |
95 | if self.metric=='AIM':
96 | X_deformed = self.spd_power(X)
97 | elif self.metric == 'LEM':
98 | X_deformed = sym_functional.sym_logm.apply(X)
99 | elif self.metric == 'LCM':
100 | X_power = self.spd_power(X)
101 | L = th.linalg.cholesky(X_power)
102 | diag_part = th.diag_embed(th.log(th.diagonal(L, dim1=-2, dim2=-1)))
103 | X_deformed = L.tril(-1) + diag_part
104 |
105 | return X_deformed
106 |
107 | def inv_deformation(self,X):
108 | if self.metric=='AIM':
109 | X_inv_deformed = self.inv_power(X)
110 | elif self.metric == 'LEM':
111 | X_inv_deformed = sym_functional.sym_expm.apply(X)
112 | elif self.metric == 'LCM':
113 | Cho = X.tril(-1) + th.diag_embed(th.exp(th.diagonal(X, dim1=-2, dim2=-1)))
114 | spd = Cho.matmul(Cho.transpose(-1,-2))
115 | X_inv_deformed = self.inv_power(spd)
116 | return X_inv_deformed
117 | def cal_geom_mean(self, X):
118 | """Frechet mean"""
119 | if self.metric == 'AIM':
120 | mean = self.BaryGeom(X.detach())
121 | elif self.metric == 'LEM' or self.metric == 'LCM':
122 | mean = X.detach().mean(dim=0, keepdim=True)
123 |
124 | return mean
125 | def cal_geom_var(self, X):
126 | """Frechet variance"""
127 | spd = X.detach()
128 | if self.metric == 'AIM':
129 | dists = self.alpha * th.linalg.matrix_norm(sym_functional.sym_logm.apply(spd)).square() + self.beta * th.logdet(spd).square()
130 | var = dists.mean()
131 |
132 | elif self.metric == 'LEM' or self.metric == 'LCM':
133 | dists = self.alpha_beta_dist(spd)
134 | var = dists.mean()
135 |
136 | if self.metric == 'AIM' or self.metric == 'LCM':
137 | var_final = var * (1 / (self.theta ** 2))
138 | else:
139 | var_final=var
140 | return var_final.unsqueeze(dim=0)
141 | def Left_translation(self, X, P, is_inverse):
142 | """Left translation by P"""
143 | if self.metric == 'AIM':
144 | L = th.linalg.cholesky(P);
145 | if is_inverse:
146 | tmp = th.linalg.solve(L, X)
147 | X_new = th.linalg.solve(L.transpose(-1, -2), tmp, left=False)
148 | else:
149 | X_new = L.matmul(X).matmul(L.transpose(-1, -2))
150 |
151 | elif self.metric == 'LEM' or self.metric == 'LCM':
152 | X_new = X - P if is_inverse else X + P
153 |
154 | return X_new
155 | def scaling(self, X, var,scale):
156 | """Scaling by variance"""
157 | factor = scale / (var+self.eps).sqrt()
158 | if self.metric == 'AIM':
159 | X_new = sym_functional.sym_powm.apply(X, factor)
160 | elif self.metric == 'LEM' or self.metric == 'LCM':
161 | X_new = X * factor
162 |
163 | return X_new
164 | def updating_running_statistics(self,batch_mean,batch_var=None):
165 | """updating running mean"""
166 | with th.no_grad():
167 | # updating mean
168 | if self.metric == 'AIM':
169 | self.running_mean.data = functional.geodesic(self.running_mean, batch_mean,self.momentum)
170 | elif self.metric == 'LEM' or self.metric == 'LCM':
171 | self.running_mean.data = (1-self.momentum) * self.running_mean+ batch_mean * self.momentum
172 | # updating var
173 | self.running_var.data = (1 - self.momentum) * self.running_var + batch_var * self.momentum
174 |
175 | def BaryGeom(self,X,karcher_steps=1,batchdim=0):
176 | '''
177 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow
178 | Input x is a batch of SPD matrices (...,n,n) to average
179 | Output is (n,n) Riemannian mean
180 | '''
181 | batch_mean = X.mean(dim=batchdim,keepdim=True)
182 | for _ in range(karcher_steps):
183 | bm_sq, bm_invsq = sym_functional.sym_invsqrtm2.apply(batch_mean)
184 | XT = sym_functional.sym_logm.apply(bm_invsq @ X @ bm_invsq)
185 | GT = XT.mean(dim=batchdim,keepdim=True)
186 | batch_mean = bm_sq @ sym_functional.sym_expm.apply(GT) @ bm_sq
187 | return batch_mean.squeeze()
--------------------------------------------------------------------------------
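
A minimal usage sketch for LieBatchNormSPD (not part of the repository). The layer expects double-precision (...,n,n) SPD inputs; the random SPD construction and the sizes below are illustrative.

import torch as th
from spd.LieBN import LieBatchNormSPD

bn = LieBatchNormSPD(n=12, metric='AIM', theta=1.0, alpha=1.0, beta=0.0, momentum=0.1)
A = th.randn(30, 1, 12, 12, dtype=th.double)
X = A @ A.transpose(-1, -2) + 1e-3 * th.eye(12, dtype=th.double)  # batch of SPD matrices

bn.train()
Y = bn(X)       # deform -> center by the batch Frechet mean -> scale by the Frechet variance -> bias
bn.eval()
Y_eval = bn(X)  # same pipeline, but using the running mean/variance
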
/LieBN_SPDNet/spd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_SPDNet/spd/__init__.py
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/nn.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | import torch.nn as nn
3 | from . import functional
4 | import geoopt
5 | from geoopt.manifolds import SymmetricPositiveDefinite
6 |
7 |
8 | dtype=th.double
9 | # dtype=th.float32
10 | device=th.device('cpu')
11 |
12 | class BiMap(nn.Module):
13 | """
14 |     Input X: (batch_size,hi) SPD matrices of size (ni,ni)
15 | Output P: (batch_size,ho) of bilinearly mapped matrices of size (no,no)
16 | Stiefel parameter of size (ho,hi,ni,no)
17 | """
18 | def __init__(self,ho,hi,ni,no,device = th.device('cpu'),dtype=dtype,optimizer='geoopt'):
19 | super(BiMap, self).__init__()
20 | if optimizer == 'geoopt':
21 | self._W = geoopt.ManifoldParameter(th.empty(ho, hi, ni, no, dtype=dtype, device=device), manifold=geoopt.manifolds.Stiefel())
22 | elif optimizer == 'homemade':
23 | self._W=functional.StiefelParameter(th.empty(ho,hi,ni,no,dtype=dtype,device=device))
24 | else:
25 | raise Exception('unknown param_mode {}'.format(optimizer))
26 | self._ho=ho; self._hi=hi; self._ni=ni; self._no=no
27 | functional.init_bimap_parameter(self._W)
28 | def forward(self,X):
29 | return functional.bimap_channels(X,self._W)
30 |
31 | class LogEig(nn.Module):
32 | """
33 |     Input P: (batch_size,h) SPD matrices of size (n,n)
34 | Output X: (batch_size,h) of log eigenvalues matrices of size (n,n)
35 | """
36 | def forward(self,P):
37 | return functional.LogEig.apply(P)
38 |
39 | class SqmEig(nn.Module):
40 | """
41 |     Input P: (batch_size,h) SPD matrices of size (n,n)
42 | Output X: (batch_size,h) of sqrt eigenvalues matrices of size (n,n)
43 | """
44 | def forward(self,P):
45 | return functional.SqmEig.apply(P)
46 |
47 | class ReEig(nn.Module):
48 | """
49 |     Input P: (batch_size,h) SPD matrices of size (n,n)
50 | Output X: (batch_size,h) of rectified eigenvalues matrices of size (n,n)
51 | """
52 | def forward(self,P):
53 | return functional.ReEig.apply(P)
54 |
55 | class BaryGeom(nn.Module):
56 | '''
57 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow
58 |     Input x is a batch of SPD matrices (batch_size,1,n,n) to average
59 | Output is (n,n) Riemannian mean
60 | '''
61 | def forward(self,x):
62 | return functional.BaryGeom(x)
63 |
64 | class BatchNormSPD(nn.Module):
65 | """
66 | Input X: (N,h) SPD matrices of size (n,n) with h channels and batch size N
67 | Output P: (N,h) batch-normalized matrices
68 | SPD parameter of size (n,n)
69 | """
70 | def __init__(self,n,momentum=0.1):
71 | super(__class__,self).__init__()
72 | self.momentum=momentum;
73 | self.running_mean=th.eye(n,dtype=dtype) ################################
74 | self.weight = geoopt.ManifoldParameter(th.eye(n, n, dtype=dtype),
75 | manifold=SymmetricPositiveDefinite())
76 | def forward(self,X):
77 | N,h,n,n=X.shape
78 | X_batched=X.permute(2,3,0,1).contiguous().view(n,n,N*h,1).permute(2,3,0,1).contiguous()
79 | if(self.training):
80 | mean=functional.BaryGeom(X_batched)
81 | with th.no_grad():
82 | self.running_mean.data=functional.geodesic(self.running_mean,mean,self.momentum)
83 | X_centered=functional.CongrG(X_batched,mean,'neg')
84 | else:
85 | X_centered=functional.CongrG(X_batched,self.running_mean,'neg')
86 | X_normalized=functional.CongrG(X_centered,self.weight,'pos')
87 | # num_x = X.data.numpy()
88 | # num_X_centered = X_centered.data.numpy()
89 | # num_mean = mean.data.numpy()
90 | X_new = X_normalized.permute(2,3,0,1).contiguous().view(n,n,N,h).permute(2,3,0,1).contiguous()
91 | return X_new
92 |
93 | class CovPool(nn.Module):
94 | """
95 |     Input f: Temporal n-dimensional feature map of length T (T=1 for a unitary signal) (batch_size,n,T)
96 | Output X: Covariance matrix of size (batch_size,1,n,n)
97 | """
98 | def __init__(self,reg_mode='mle'):
99 | super(__class__,self).__init__()
100 | self._reg_mode=reg_mode
101 | def forward(self,f):
102 | return functional.cov_pool(f,self._reg_mode)
103 |
104 | class CovPoolBlock(nn.Module):
105 | """
106 |     Input f: L blocks of temporal n-dimensional feature map of length T (T=1 for a unitary signal) (batch_size,L,n,T)
107 | Output X: L covariance matrices, shape (batch_size,L,1,n,n)
108 | """
109 | def __init__(self,reg_mode='mle'):
110 | super(__class__,self).__init__()
111 | self._reg_mode=reg_mode
112 | def forward(self,f):
113 | ff=[functional.cov_pool(f[:,i,:,:],self._reg_mode)[:,None,:,:,:] for i in range(f.shape[1])]
114 | return th.cat(ff,1)
115 |
116 | class CovPoolMean(nn.Module):
117 | """
118 |     Input f: Temporal n-dimensional feature map of length T (T=1 for a unitary signal) (batch_size,n,T)
119 | Output X: Covariance matrix of size (batch_size,1,n,n)
120 | """
121 | def __init__(self,reg_mode='mle'):
122 | super(__class__,self).__init__()
123 | self._reg_mode=reg_mode
124 | def forward(self,f):
125 | return functional.cov_pool_mu(f,self._reg_mode)
--------------------------------------------------------------------------------
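
A sketch of one SPDNet stage built from the layers above (BiMap -> BatchNormSPD -> ReEig, followed by LogEig vectorization); all sizes are illustrative.

import torch as th
import spd.nn as nn_spd

A = th.randn(4, 1, 20, 20, dtype=th.double)
X = A @ A.transpose(-1, -2) + 1e-3 * th.eye(20, dtype=th.double)  # (4,1,20,20) SPD inputs

P = X
for layer in (nn_spd.BiMap(1, 1, 20, 16), nn_spd.BatchNormSPD(16), nn_spd.ReEig()):
    P = layer(P)                        # -> (4,1,16,16), still SPD after each step
v = nn_spd.LogEig()(P).view(4, -1)      # (4, 256) log-domain features for a linear classifier
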
/LieBN_SPDNet/spd/training_script.py:
--------------------------------------------------------------------------------
1 | from spd.utils import get_dataset_settings
2 | from Network import Get_Model
3 | import spd.utils as nn_spd_utils
4 |
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | import os
8 | import datetime
9 | import time
10 | import logging
11 |
12 | import torch as th
13 | import torch.nn as nn
14 | import numpy as np
15 | import fcntl
16 |
17 | def training(cfg,args):
18 | args=nn_spd_utils.parse_cfg(args,cfg)
19 |
20 | #set logger
21 | logger = logging.getLogger(args.modelname)
22 | logger.setLevel(logging.INFO)
23 | args.logger = logger
24 | logger.info('begin model {} on dataset: {}'.format(args.modelname,args.dataset))
25 |
26 | #set seed and threadnum
27 | nn_spd_utils.set_seed_thread(args.seed,args.threadnum)
28 |
29 | # set dataset, model and optimizer
30 | args.class_num, args.DataLoader = get_dataset_settings(args)
31 | model = Get_Model.get_model(args)
32 | loss_fn = nn.CrossEntropyLoss()
33 | args.loss_fn = loss_fn.cuda()
34 | args.opti = nn_spd_utils.optimzer(model.parameters(), lr=args.lr, mode=args.optimizer,weight_decay=args.weight_decay)
35 | # begin training
36 | val_acc = training_loop(model,args)
37 |
38 | return val_acc
39 |
40 | def training_loop(model, args):
41 | #setting tensorboard
42 | if args.is_writer:
43 | args.writer_path = os.path.join('./tensorboard_logs/',f"{args.modelname}")
44 | args.logger.info('writer path {}'.format(args.writer_path))
45 | args.writer = SummaryWriter(args.writer_path)
46 |
47 | acc_val = [];loss_val = [];acc_train = [];loss_train = [];training_time=[]
48 | logger = args.logger
49 | # training loop
50 | for epoch in range(0, args.epochs):
51 | # train one epoch
52 | start = time.time()
53 | temp_loss_train, temp_acc_train = [], []
54 | model.train()
55 | for local_batch, local_labels in args.DataLoader._train_generator:
56 | args.opti.zero_grad()
57 | out = model(local_batch)
58 | l = args.loss_fn(out, local_labels)
59 | acc, loss = (out.argmax(1) == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy()
60 | temp_loss_train.append(loss)
61 | temp_acc_train.append(acc)
62 | l.backward()
63 | args.opti.step()
64 | end = time.time()
65 | training_time.append(end-start)
66 | acc_train.append(np.asarray(temp_acc_train).mean() * 100)
67 | loss_train.append(np.asarray(temp_loss_train).mean())
68 |
69 | # validation
70 | acc_val_list = [];loss_val_list = [];y_true, y_pred = [], []
71 | model.eval()
72 | with th.no_grad():
73 | for local_batch, local_labels in args.DataLoader._test_generator:
74 | out = model(local_batch)
75 | l = args.loss_fn(out, local_labels)
76 | predicted_labels = out.argmax(1)
77 | y_true.extend(list(local_labels.cpu().numpy()));
78 | y_pred.extend(list(predicted_labels.cpu().numpy()))
79 | acc, loss = (predicted_labels == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy()
80 | acc_val_list.append(acc)
81 | loss_val_list.append(loss)
82 | loss_val.append(np.asarray(loss_val_list).mean())
83 | acc_val.append(np.asarray(acc_val_list).mean() * 100)
84 |
85 | if args.is_writer:
86 | args.writer.add_scalar('Loss/val', loss_val[epoch], epoch)
87 | args.writer.add_scalar('Accuracy/val', acc_val[epoch], epoch)
88 | args.writer.add_scalar('Loss/train', loss_train[epoch], epoch)
89 | args.writer.add_scalar('Accuracy/train', acc_train[epoch], epoch)
90 |
91 | if epoch % args.cycle == 0:
92 | logger.info(
93 | 'Time: {:.2f}, Val acc: {:.2f}, loss: {:.2f} at epoch {:d}/{:d} '.format(
94 | training_time[epoch], acc_val[epoch], loss_val[epoch], epoch + 1, args.epochs))
95 |
96 | if args.is_save:
97 | average_time = np.asarray(training_time[-10:]).mean()
98 | final_val_acc = acc_val[-1]
99 | final_results = f'Final validation accuracy : {final_val_acc:.2f}% with average time: {average_time:.2f}'
100 | final_results_path = os.path.join(os.getcwd(), 'final_results_' + args.dataset)
101 | logger.info(f"results file path: {final_results_path}, and saving the results")
102 | write_final_results(final_results_path, args.modelname + '- ' + final_results)
103 |         torch_results_dir = './torch_results'
104 | if not os.path.exists(torch_results_dir):
105 | os.makedirs(torch_results_dir)
106 | th.save({
107 | 'acc_val': acc_val,
108 | }, os.path.join(torch_results_dir,args.modelname.rsplit('-',1)[0]))
109 |
110 | if args.is_writer:
111 | args.writer.close()
112 | return acc_val
113 |
114 | def write_final_results(file_path,message):
115 | # Create a file lock
116 | with open(file_path, "a") as file:
117 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock
118 |
119 | # Write the message to the file
120 | file.write(message + "\n")
121 |
122 | fcntl.flock(file.fileno(), fcntl.LOCK_UN)
--------------------------------------------------------------------------------
/LieBN_SPDNet/spd/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import time
4 | import datetime
5 |
6 | import random
7 | import geoopt
8 | import numpy as np
9 | import torch as th
10 |
11 | from spd.DataLoader.FPHA_Loader import DataLoaderFPHA
12 | from spd.DataLoader.HDM05_Loader import DataLoaderHDM05
13 | from spd.DataLoader.Radar_Loader import DataLoaderRadar
14 |
15 | def get_model_name(args):
16 | if args.model_type == 'SPDNet':
17 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}'
18 | elif args.model_type == 'SPDNetBN':
19 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{args.model_type}-{args.optimizer}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}'
20 | elif args.model_type in args.total_LieBN_model_types:
21 | if args.model_type=='SPDNetLieBN_RS':
22 |             model_type = args.model_type + '_init_RS' if args.init_by_RS else args.model_type
23 | else:
24 | model_type = args.model_type
25 | if args.metric == 'AIM' or args.metric == 'LEM':
26 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{model_type}-{args.optimizer}-{args.architecture}-{args.metric}-({args.theta},{args.alpha},{args.beta:.4f})-{datetime.datetime.now().strftime("%H_%M")}'
27 | elif args.metric== 'LCM':
28 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-m_{args.momentum}-{model_type}-{args.optimizer}-{args.architecture}-{args.metric}-({args.theta})-{datetime.datetime.now().strftime("%H_%M")}'
29 | else:
30 |             raise Exception('unknown metric {} or model {}'.format(args.metric,args.model_type))
31 | return name
32 |
33 | def get_dataset_settings(args):
34 | if args.dataset=='FPHA':
35 | class_num = 45
36 | DataLoader = DataLoaderFPHA(args.path,args.batchsize)
37 | elif args.dataset=='HDM05':
38 | class_num = 117
39 | pval = 0.5
40 | DataLoader = DataLoaderHDM05(args.path, pval, args.batchsize)
41 | elif args.dataset== 'RADAR' :
42 | class_num = 3
43 | pval = 0.25
44 | DataLoader = DataLoaderRadar(args.path,pval,args.batchsize)
45 | else:
46 | raise Exception('unknown dataset {}'.format(args.dataset))
47 | return class_num,DataLoader
48 |
49 | def set_seed_thread(seed,threadnum):
50 | th.set_num_threads(threadnum)
51 |     # seed every RNG source for reproducibility
52 | random.seed(seed)
53 | # th.cuda.set_device(args.gpu)
54 | np.random.seed(seed)
55 | th.manual_seed(seed)
56 | th.cuda.manual_seed(seed)
57 |
58 | def optimzer(parameters,lr,mode='AMSGRAD',weight_decay=0.):
59 | if mode=='ADAM':
60 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,weight_decay=weight_decay)
61 | elif mode=='SGD':
62 | optim = geoopt.optim.RiemannianSGD(parameters, lr=lr,weight_decay=weight_decay)
63 | elif mode=='AMSGRAD':
64 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,amsgrad=True,weight_decay=weight_decay)
65 | else:
66 | raise Exception('unknown optimizer {}'.format(mode))
67 | return optim
68 |
69 | def training_loop(model, data_loader, opti, loss_fn,writer, args,model_path, begin_epoch):
70 | acc_val = [];loss_val = [];acc_train = [];loss_train = []
71 | # training loop
72 | for epoch in range(begin_epoch, args.epochs):
73 | # train one epoch
74 | start = time.time()
75 | temp_loss_train, temp_acc_train = [], []
76 | model.train()
77 | for local_batch, local_labels in data_loader._train_generator:
78 | opti.zero_grad()
79 | out = model(local_batch)
80 | l = loss_fn(out, local_labels)
81 | acc, loss = (out.argmax(1) == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy()
82 | temp_loss_train.append(loss)
83 | temp_acc_train.append(acc)
84 | l.backward()
85 | opti.step()
86 | if args.is_gpu:
87 | th.cuda.synchronize()
88 | end = time.time()
89 | acc_train.append(np.asarray(temp_acc_train).mean() * 100)
90 | loss_train.append(np.asarray(temp_loss_train).mean())
91 |
92 | # validation
93 | acc_val_list = [];loss_val_list = [];y_true, y_pred = [], []
94 | model.eval()
95 | with th.no_grad():
96 | for local_batch, local_labels in data_loader._test_generator:
97 | out = model(local_batch)
98 | l = loss_fn(out, local_labels)
99 | predicted_labels = out.argmax(1)
100 | y_true.extend(list(local_labels.cpu().numpy()));
101 | y_pred.extend(list(predicted_labels.cpu().numpy()))
102 | acc, loss = (predicted_labels == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy()
103 | acc_val_list.append(acc)
104 | loss_val_list.append(loss)
105 | loss_val.append(np.asarray(loss_val_list).mean())
106 | acc_val.append(np.asarray(acc_val_list).mean() * 100)
107 | if args.is_writer:
108 | writer.add_scalar('Loss/val', loss_val[epoch], epoch)
109 | writer.add_scalar('Accuracy/val', acc_val[epoch], epoch)
110 | writer.add_scalar('Loss/train', loss_train[epoch], epoch)
111 | writer.add_scalar('Accuracy/train', acc_train[epoch], epoch)
112 | print(
113 | '{}: time: {:.2f}, Val acc: {:.2f}, loss: {:.2f}, at epoch {:d}/{:d} '.format(
114 | args.modelname,end - start, acc_val[epoch], loss_val[epoch], epoch + 1, args.epochs))
115 | if epoch + 1 == args.epochs and args.is_save:
116 | th.save({
117 | 'epoch': epoch,
118 | 'model_state_dict': model.state_dict(),
119 | 'lr': args.lr,
120 | 'acc_val': acc_val,
121 | 'acc_train': acc_train,
122 | 'loss_val': loss_val,
123 | 'loss_train': loss_train
124 | }, model_path + '-' + str(epoch))
125 | print('{}: Final validation accuracy: {}%'.format(args.modelname,acc_val[-1]))
126 | if args.is_writer:
127 | writer.close()
128 | return acc_val
129 |
130 | def del_file(path):
131 | ls = os.listdir(path)
132 | for i in ls:
133 | c_path = os.path.join(path, i)
134 | if os.path.isdir(c_path):
135 | del_file(c_path)
136 | else:
137 | os.remove(c_path)
138 |
139 | def resuming_writer(begin_epoch, writer,loss_val,loss_train,acc_val,acc_train):
140 | for epoch in range(begin_epoch):
141 | writer.add_scalar('Loss/val', loss_val[epoch], epoch)
142 | writer.add_scalar('Accuracy/val', acc_val[epoch], epoch)
143 | writer.add_scalar('Loss/train', loss_train[epoch], epoch)
144 | writer.add_scalar('Accuracy/train', acc_train[epoch], epoch)
145 |
146 | def parse_cfg(args,cfg):
147 | # setting args from cfg
148 |
149 | args.seed = cfg.fit.seed
150 | args.model_type = cfg.nnet.model.model_type
151 | args.is_save = cfg.fit.is_save
152 |
153 | if args.model_type in args.total_BN_model_types:
154 | args.BN_type = cfg.nnet.model.BN_type
155 | args.momentum = cfg.nnet.model.momentum
156 | if args.model_type in args.total_LieBN_model_types:
157 | args.metric = cfg.nnet.model.metric
158 | args.theta = cfg.nnet.model.theta
159 | args.alpha = cfg.nnet.model.alpha
160 | args.beta = eval(cfg.nnet.model.beta) if isinstance(cfg.nnet.model.beta, str) else cfg.nnet.model.beta
161 |
162 | args.dataset = cfg.dataset.name
163 | args.architecture = cfg.dataset.architecture
164 | args.path = cfg.dataset.path
165 |
166 | args.optimizer = cfg.nnet.optimizer.mode
167 | args.lr = cfg.nnet.optimizer.lr
168 | args.weight_decay = cfg.nnet.optimizer.weight_decay
169 |
170 | args.epochs = cfg.fit.epochs
171 | args.batchsize = cfg.fit.batch_size
172 |
173 | args.threadnum = cfg.fit.threadnum
174 | args.is_writer = cfg.fit.is_writer
175 | args.cycle = cfg.fit.cycle
176 |
177 | # get model name
178 | args.modelname = get_model_name(args)
179 |
180 | return args
--------------------------------------------------------------------------------
/LieBN_TSMNet/LieBN_utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/LieBN_utilities/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/LieBN_utilities/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import fcntl
3 | import random
4 |
5 | import numpy as np
6 | import torch as th
7 |
8 | def get_model_name(args):
9 | bias = 'bias' if args.learn_mean else 'non_bias'
10 |
11 | if args.model_type=='TSMNet+SPDDSMBN' or args.model_type=='TSMNet':
12 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{datetime.datetime.now().strftime("%H_%M")}'
13 | elif args.model_type == 'TSMNet+LieBN':
14 | if args.metric == 'AIM' or args.metric == 'LEM':
15 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{args.metric}-({args.theta},{args.alpha},{args.beta})-{datetime.datetime.now().strftime("%H_%M")}'
16 | elif args.metric== 'LCM':
17 | name = f'{args.seed}-{args.lr}-wd_{args.weight_decay}-{args.model_type}-{args.optimizer}-{args.architecture}-{bias}-{args.metric}-({args.theta})-{datetime.datetime.now().strftime("%H_%M")}'
18 | else:
19 |         raise Exception('unknown metric {} or model {}'.format(args.metric,args.model_type))
20 | return name
21 |
22 | def write_final_results(file_path,message):
23 | # Create a file lock
24 | with open(file_path, "a") as file:
25 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock
26 | # Write the message to the file
27 | file.write(message + "\n")
28 | fcntl.flock(file.fileno(), fcntl.LOCK_UN) # Release the lock
29 |
30 | def set_seed_thread(seed,threadnum):
31 | th.set_num_threads(threadnum)
32 |     # seed every RNG source for reproducibility
33 | random.seed(seed)
34 | # th.cuda.set_device(args.gpu)
35 | np.random.seed(seed)
36 | th.manual_seed(seed)
37 | th.cuda.manual_seed(seed)
--------------------------------------------------------------------------------
/LieBN_TSMNet/TSMNet-LieBN.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
6 |
7 | Copyright (C) 2024 Ziheng Chen
8 | All rights reserved.
9 | """
10 |
11 | import hydra
12 | from omegaconf import DictConfig
13 |
14 | import warnings
15 | from sklearn.exceptions import FitFailedWarning, ConvergenceWarning
16 |
17 | from library.utils.hydra import hydra_helpers
18 | from LieBN_utilities.Training import training
19 |
20 | warnings.filterwarnings("ignore", category=FitFailedWarning)
21 | warnings.filterwarnings("ignore", category=ConvergenceWarning)
22 | warnings.filterwarnings("ignore", category=FutureWarning)
23 | warnings.filterwarnings("ignore", category=UserWarning)
24 | warnings.filterwarnings("ignore", category=RuntimeWarning)
25 |
26 | class Args:
27 | """ a Struct Class """
28 | pass
29 | args=Args()
30 | args.config_name='LieBN.yaml'
31 |
32 | @hydra_helpers
33 | @hydra.main(config_path='./conf/', config_name=args.config_name, version_base='1.1')
34 | # @hydra.main(config_path='./conf/', config_name=args.config_name)
35 | def main(cfg: DictConfig):
36 | training(cfg,args)
37 |
38 | if __name__ == '__main__':
39 |
40 | main()
41 |
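
The script is a thin hydra entry point: it registers the helper resolvers, composes conf/LieBN.yaml, and hands the resulting DictConfig to training. The same composition can be reproduced programmatically, which is handy for inspecting a configuration before launching a sweep. A sketch assuming hydra >= 1.2 (the overrides mirror experiments_Hinss21.sh):

    from hydra import initialize, compose

    with initialize(config_path='./conf', version_base='1.1'):
        cfg = compose(config_name='LieBN.yaml',
                      overrides=['nnet=tsmnet_LieBN', 'nnet.model.metric=LEM'])
    print(cfg.nnet.model.metric)  # LEM

Interpolations such as ${torchdtype:float32} or ${sub:${fit.epochs},10} resolve lazily, so the resolvers registered by hydra_helpers must be in place before those fields are accessed.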
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/LieBN.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - dataset: hinss2021
4 | - evaluation: inter-session+uda
5 | - preprocessing: bb4-36Hz
6 | - nnet: tsmnet_spddsmbn
7 | - override hydra/launcher: joblib
8 | fit:
9 | stratified: True
10 | epochs: 50
11 | batch_size_train: 50
12 | domains_per_batch: 5
13 | batch_size_test: -1
14 | validation_size: 0.2 #0.1 # float <1 for fraction; int for specific number
15 | test_size: 0.05 # percent of groups/domains used for testing
16 |
17 | saving_model:
18 | is_save_model: False
19 | is_save_results_each_fold: True
20 |
21 | score: balanced_accuracy # sklearn scores
22 | device: GPU
23 | threadnum: 1
24 | data_dir: /data
25 | is_debug: False
26 | seed: 42
27 | is_timing: False
28 |
29 | hydra:
30 | run:
31 | dir: outputs/${dataset.name}/${evaluation.strategy}
32 | sweep:
33 | dir: outputs/${dataset.name}/${evaluation.strategy}
34 | subdir: '.'
35 | launcher:
36 | n_jobs: -1
37 | job_logging:
38 | handlers:
39 | file:
40 | class: logging.FileHandler
41 | filename: default.log
42 | # job:
43 | # num_threads: 2
44 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/dataset/hinss2021.yaml:
--------------------------------------------------------------------------------
1 | # name of the dataset
2 | name: Hinss2021
3 | # python type (and parameters)
4 | type:
5 | _target_: datasetio.eeg.moabb.Hinss2021
6 |
7 | classes: ["easy", "medium", "difficult"]
8 | # channel selection
9 | channels: ['FP1', 'FP2', 'FPz', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'F2', 'F4', 'F6', 'F8', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C3', 'C4', 'CPz', 'PO3', 'PO4', 'POz', 'Oz' ]
10 | # if 'null' or not defined, all available channels will be used
11 | # resampling
12 | # if 'null' or not defined, the datasets sampling frequency will be used
13 | # resample: 250 # Hz
14 | ## epoching (relative to TASK CUE onset, as defined in the dataset)
15 | tmin: 0.0
16 | tmax: 1.996
17 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/evaluation/inter-session+uda.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - inter-session
3 | - _self_
4 | adapt:
5 | name: uda
6 | nadapt_domain: 1. # int -> absolute number of observations per CLASS
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/evaluation/inter-subject+uda.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - inter-subject
3 | - _self_
4 | adapt:
5 | name: uda
6 | nadapt_domain: 1. # int -> absolute number of observations per CLASS
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/nnet/tsmnet.yaml:
--------------------------------------------------------------------------------
1 | name: TSMNet
2 | inputtype: ${torchdtype:float32}
3 | model:
4 | _target_: spdnets.models.TSMNet
5 | temporal_filters: 4
6 | spatial_filters: 40
7 | subspacedims: 20
8 | learn_mean: False # Following SPDDSMBN, we set bias=I
9 | bnorm: null
10 | bnorm_dispersion: null
11 | optimizer:
12 | _target_: geoopt.optim.RiemannianAdam
13 | amsgrad: True
14 | weight_decay: 1e-4
15 | lr: 1e-3
16 | param_groups:
17 | -
18 | - 'spdnet.*.W'
19 | - weight_decay: 0
20 | scheduler:
21 | _target_: spdnets.batchnorm.DummyScheduler
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/nnet/tsmnet_LieBN.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - tsmnet
3 | - _self_
4 | name: TSMNet+LieBN
5 | model:
6 | bnorm: LieBN
7 | bnorm_dispersion: SCALAR
8 | metric: AIM
9 | theta: 1.0
10 | alpha: 1.0
11 | beta: 0.0
12 | scheduler:
13 | _target_: spdnets.Liebatchnorm.MomentumBatchNormScheduler
14 | epochs: ${sub:${fit.epochs},10}
15 | bs: ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}}
16 | bs0: ${fit.batch_size_train}
17 | tau0: 0.85
18 |
19 | optimizer:
20 | param_groups:
21 | -
22 | - 'spd*.mean'
23 | - weight_decay: 0
24 | -
25 | - 'spdnet.*.W'
26 | - weight_decay: 0
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/nnet/tsmnet_spddsmbn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - tsmnet
3 | - _self_
4 | name: TSMNet+SPDDSMBN
5 | model:
6 | bnorm: spddsbn
7 | bnorm_dispersion: SCALAR
8 | scheduler:
9 | _target_: spdnets.batchnorm.MomentumBatchNormScheduler
10 | epochs: ${sub:${fit.epochs},10}
11 | bs: ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}}
12 | bs0: ${fit.batch_size_train}
13 | tau0: 0.85
14 |
15 | optimizer:
16 | param_groups:
17 | -
18 | - 'spd*.mean'
19 | - weight_decay: 0
20 | -
21 | - 'spdnet.*.W'
22 | - weight_decay: 0
--------------------------------------------------------------------------------
/LieBN_TSMNet/conf/preprocessing/bb4-36Hz.yaml:
--------------------------------------------------------------------------------
1 | bb4-36Hz:
2 | _target_: library.utils.moabb.CachedMotorImagery
3 | fmin: 4 # Hz
4 | fmax: 36 # Hz
5 | events: ${dataset.classes}
6 | channels: ${oc.select:dataset.channels}
7 | resample: ${oc.select:dataset.resample}
8 | tmin: ${oc.select:dataset.tmin, 0.0}
9 | tmax: ${oc.select:dataset.tmax}
10 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/__init__.py:
--------------------------------------------------------------------------------
1 | from moabb.datasets.bnci import BNCI2014001 as moabbBNCI2014001
2 | from moabb.datasets.bnci import BNCI2015001 as moabbBNCI2015001
3 | from moabb.datasets import Lee2019_MI as moabbLee2019
4 |
5 | from .base import PreprocessedDataset, CachableDatase
6 | from .stieger2021 import Stieger2021
7 | from .hinss2021 import Hinss2021
8 |
9 | class BNCI2014001(moabbBNCI2014001, CachableDatase):
10 |
11 | def _get_single_subject_data(self, subject):
12 | """return data for a single subject"""
13 | sessions = super()._get_single_subject_data(subject=subject)
14 | map = dict(session_T=1,session_E=2)
15 | sessions = dict([(map[k],v) for k, v in sessions.items()])
16 | return sessions
17 |
18 | class BNCI2015001(moabbBNCI2015001, CachableDatase):
19 |
20 | def _get_single_subject_data(self, subject):
21 | """return data for a single subject"""
22 | sessions = super()._get_single_subject_data(subject=subject)
23 | map = dict(session_A=1,session_B=2,session_C=3)
24 | sessions = dict([(map[k],v) for k, v in sessions.items()])
25 | return sessions
26 |
27 | class Lee2019(moabbLee2019, CachableDatase):
28 |
29 | def _get_single_subject_data(self, subject):
30 | """return data for a single subject"""
31 | sessions = super()._get_single_subject_data(subject=subject)
32 | map = dict(session_1=1,session_2=2)
33 | sessions = dict([(map[k],v) for k, v in sessions.items()])
34 | return sessions
35 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/base.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/base.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/hinss2021.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/hinss2021.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/stieger2021.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/datasetio/eeg/moabb/__pycache__/stieger2021.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import mne
3 | import json
4 |
5 | from moabb.datasets.base import BaseDataset
6 |
7 | class CachableDatase(BaseDataset):
8 |
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 |
12 | def __repr__(self) -> str:
13 | return json.dumps({self.__class__.__name__: self.__dict__})
14 |
15 | class PreprocessedDataset(CachableDatase):
16 |
17 | def __init__(self, *args, channels : Optional[list] = None, srate : Optional[int] = None, **kwargs):
18 | super().__init__(*args, **kwargs)
19 | self.channels = channels
20 | self.srate = srate
21 |
22 | def preprocess(self, raw):
23 |
24 | # find the events, first check stim_channels
25 | if len(mne.pick_types(raw.info, stim=True)) > 0:
26 | events = mne.find_events(raw, shortest_event=0, verbose=False)
27 | else:
28 | events = None # the dataset already uses annotations
29 |
30 | # optional resampling
31 | if self.srate is not None:
32 | ret = raw.resample(self.srate, events=events)
33 | raw, events = (ret, events) if events is None else (ret[0], ret[1])
34 |
35 | # convert optional events to annotations (before we discard the stim channels)
36 | if events is not None:
37 | rev_event_it = dict(zip(self.event_id.values(), self.event_id.keys()))
38 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=rev_event_it)
39 | raw.set_annotations(annot)
40 |
41 | # pick subset of all channels
42 | if self.channels is not None:
43 | raw.pick_channels(self.channels)
44 | else:
45 | raw.pick_types(eeg=True)
46 |
47 | return raw
48 |
49 | def __repr__(self) -> str:
50 | return json.dumps({self.__class__.__name__: self.__dict__})
51 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/hinss2021.py:
--------------------------------------------------------------------------------
1 | import os
2 | import mne
3 | import numpy as np
4 | import pooch
5 | import logging
6 | import requests
7 | import json
8 | import subprocess
9 | import re
10 | import glob
11 |
12 | from scipy.io import loadmat
13 | import moabb.datasets.download as dl
14 |
15 | from .base import PreprocessedDataset
16 |
17 | log = logging.getLogger(__name__)
18 |
19 |
20 |
21 | def doi_to_url(doi, api_url = lambda x : f"https://doi.org/api/handles/{x}?type=URL"):
22 |
23 | url = None
24 | headers = {"Content-Type": "application/json"}
25 | response_data = dl.fs_issue_request("GET", api_url(doi), headers=headers)
26 |
27 | if 'values' in response_data:
28 | candidates = [ val['data']['value'] for val in response_data['values'] if 'data' in val and isinstance(val['data'], dict) and 'value' in val['data']]
29 | url = candidates[0] if len(candidates)> 0 else None
30 |
31 | return url
32 |
33 |
34 |
35 | def url_get_json(url : str):
36 |
37 | headers = {"Content-Type": "application/json"}
38 | response = dl.fs_issue_request("GET", url, headers=headers)
39 | return response
40 |
41 |
42 |
43 | class Hinss2021(PreprocessedDataset):
44 |
45 | ZENODO_JSON_API_URL = lambda x : f"https://zenodo.org/api/{x}"
46 |
47 | TASK_TO_EVENTID = dict(RS='rest', MATBeasy='easy', MATBmed='medium', MATBdiff='difficult')
48 |
49 | def __init__(self, interval = [0, 2], channels = None, srate = None):
50 | super().__init__(
51 | subjects=list(range(1, 15+1)),
52 | sessions_per_subject=2,
53 | events=dict(easy=1, medium=2, difficult=3, rest=4),
54 | code="Hinss2021",
55 | interval=interval,
56 | paradigm="imagery",
57 | doi="10.5281/zenodo.4917217",
58 | channels=channels,
59 | srate=srate
60 | )
61 |
62 |
63 |
64 | def preprocess(self, raw):
65 | # interpolate channels marked as bad
66 | if len(raw.info['bads']) > 0:
67 | raw.interpolate_bads()
68 | return super().preprocess(raw)
69 |
70 | def data_path(
71 | self, subject, path=None, force_update=False, update_path=None, verbose=None
72 | ):
73 | if subject not in self.subject_list:
74 | raise (ValueError("Invalid subject number"))
75 |
76 | key_dest = f"MNE-{self.code:s}-data"
77 | path = os.path.join(dl.get_dataset_path(self.code, path), key_dest)
78 |
79 | url = doi_to_url(self.doi)
80 | if url is None:
81 | raise ValueError("Could not find zenodo id based on dataset DOI!")
82 |
83 | zenodoid = url.split('/')[-1]
84 |
85 | metadata = url_get_json(Hinss2021.ZENODO_JSON_API_URL(f"records/{zenodoid}"))
86 |
87 | fnames = []
88 | for record in metadata['files']:
89 |
90 | fname = record['key']
91 | fpath = os.path.join(path, fname)
92 |
93 |
94 | # metadata
95 | # if record['type'] != 'zip' and not os.path.exists(fpath): # subject data
96 | # pooch.retrieve(record['links']['self'], record['checksum'], fname, path, downloader=pooch.HTTPDownloader(progressbar=True))
97 | # subject specific data
98 | if record['type'] == 'zip' and fname == f"P{subject:02d}.zip":
99 | if not os.path.exists(fpath):
100 | files = pooch.retrieve(record['links']['self'], record['checksum'], fname, path,
101 | processor=pooch.Unzip(),
102 | downloader=pooch.HTTPDownloader(progressbar=True))
103 |
104 | # load the data
105 | tasks = list(Hinss2021.TASK_TO_EVENTID.keys())
106 | taskpattern = '('+ '|'.join(tasks)+')'
107 | pattern = f'{fpath}.unzip/P{subject:02d}/S?/eeg/alldata_*.set'
108 | candidates = glob.glob(pattern, recursive=True)
109 | fnames += [c for c in candidates if re.search(f'.*{taskpattern}.set', c)]
110 |
111 | return fnames
112 |
113 |
114 | def _get_single_subject_data(self, subject):
115 | fnames = self.data_path(subject)
116 |
117 | subject_data = {}
118 | for fn in fnames:
119 | meta = re.search(r'alldata_sbj(?P<subject>\d\d)_sess(?P<session>\d)_((?P<task>\w+))',
120 | os.path.basename(fn))
121 | sid = int(meta['session'])
122 |
123 | if sid not in range(1,self.n_sessions+1):
124 | continue
125 |
126 | epochs = mne.io.read_epochs_eeglab(fn, verbose=False)
127 | assert(len(epochs.event_id) == 1)
128 | event_id = Hinss2021.TASK_TO_EVENTID[list(epochs.event_id.keys())[0]]
129 | epochs.event_id = {event_id : self.event_id[event_id]}
130 | epochs.events[:,2] = epochs.event_id[event_id]
131 |
132 | # convert to continuous raw object with correct annotations
133 | continuous_data = np.swapaxes(epochs.get_data(),0,1).reshape((len(epochs.info['chs']),-1))
134 | raw = mne.io.RawArray(data=continuous_data, info=epochs.info, verbose=False, first_samp=1)
135 | # XXX use standard electrode layout rather than individual positions
136 | # raw.set_montage(epochs.get_montage())
137 | raw.set_montage('standard_1005')
138 | events = epochs.events.copy()
139 | evt_desc = dict(zip(epochs.event_id.values(),epochs.event_id.keys()))
140 |
141 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=evt_desc, first_samp=1)
142 |
143 | raw.set_annotations(annot)
144 |
145 | if sid in subject_data:
146 | subject_data[sid][0].append(raw)
147 | else:
148 | subject_data[sid] = {0 : raw}
149 |
150 | # discard boundary annotations
151 | keep = [i for i, desc in enumerate(subject_data[sid][0].annotations.description) if desc in self.event_id]
152 | subject_data[sid][0].set_annotations(subject_data[sid][0].annotations[keep])
153 |
154 | return subject_data
155 |
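
Each EEGLAB .set file encodes subject, session, and task in its name, and _get_single_subject_data recovers them with the named groups of the regular expression above (the group names subject and task are reconstructions; only session is referenced afterwards). A self-contained sketch on a hypothetical filename:

    import re

    pattern = r'alldata_sbj(?P<subject>\d\d)_sess(?P<session>\d)_(?P<task>\w+)'
    m = re.search(pattern, 'alldata_sbj03_sess2_MATBeasy.set')  # hypothetical filename
    print(m['subject'], m['session'], m['task'])                # 03 2 MATBeasy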
--------------------------------------------------------------------------------
/LieBN_TSMNet/datasetio/eeg/moabb/stieger2021.py:
--------------------------------------------------------------------------------
1 | import os
2 | import mne
3 | import numpy as np
4 | import pooch
5 | import logging
6 |
7 | from scipy.io import loadmat
8 | import moabb.datasets.download as dl
9 |
10 | from .base import PreprocessedDataset
11 |
12 | log = logging.getLogger(__name__)
13 |
14 | class Stieger2021(PreprocessedDataset):
15 |
16 | BASE_URL = "https://ndownloader.figshare.com/files/"
17 |
18 | def __init__(self, interval = [0, 3], channels = None, srate = None, sessions = None):
19 | super().__init__(
20 | subjects=list(range(1, 63)),
21 | sessions_per_subject=11,
22 | events=dict(right_hand=1, left_hand=2, both_hand=3, rest=4),
23 | code="Stieger2021",
24 | interval=interval,
25 | paradigm="imagery",
26 | doi="10.1038/s41597-021-00883-1",
27 | channels=channels,
28 | srate=srate
29 | )
30 |
31 | self.sessions = sessions
32 | self.figshare_id = 13123148 # id on figshare
33 |
34 | assert(interval[0] >= 0.00) # the interval has to start after the cue onset
35 |
36 |
37 | def preprocess(self, raw):
38 | # interpolate channels marked as bad
39 | if len(raw.info['bads']) > 0:
40 | raw.interpolate_bads()
41 | return super().preprocess(raw)
42 |
43 | def data_path(
44 | self, subject, path=None, force_update=False, update_path=None, verbose=None
45 | ):
46 | if subject not in self.subject_list:
47 | raise (ValueError("Invalid subject number"))
48 |
49 | key_dest = f"MNE-{self.code:s}-data"
50 | path = os.path.join(dl.get_dataset_path(self.code, path), key_dest)
51 |
52 | filelist = dl.fs_get_file_list(self.figshare_id)
53 | reg = dl.fs_get_file_hash(filelist)
54 | fsn = dl.fs_get_file_id(filelist)
55 |
56 | spath = []
57 | for f in fsn.keys():
58 | if ".mat" not in f:
59 | continue
60 | sbj = int(f.split('_')[0][1:])
61 | ses = int(f.split('_')[-1].split('.')[0])
62 | if sbj == subject and ses in self.sessions:
63 | fpath = os.path.join(path, f)
64 | if not os.path.exists(fpath):
65 | pooch.retrieve(Stieger2021.BASE_URL + fsn[f], reg[fsn[f]], f, path, downloader=pooch.HTTPDownloader(progressbar=True))
66 | spath.append(fpath)
67 | return spath
68 |
69 | def _get_single_subject_data(self, subject):
70 | fname = self.data_path(subject)
71 |
72 | subject_data = {}
73 |
74 | for fn in fname:
75 |
76 | session = int(os.path.basename(fn).split('_')[2].split('.')[0])
77 |
78 | if self.sessions is not None:
79 | if session not in set(self.sessions):
80 | continue
81 |
82 | container = loadmat(
83 | fn,
84 | squeeze_me=True,
85 | struct_as_record=False,
86 | verify_compressed_data_integrity=False,
87 | )['BCI']
88 |
89 | srate = container.SRATE
90 |
91 | eeg_ch_names = container.chaninfo.label.tolist()
92 | # adjust naming convention
93 | eeg_ch_names = [ch.replace('Z', 'z').replace('FP','Fp') for ch in eeg_ch_names]
94 | # extract all standard EEG channels
95 | montage = mne.channels.make_standard_montage("standard_1005")
96 | channel_mask = np.isin(eeg_ch_names, montage.ch_names)
97 | ch_names = [ch for ch, found in zip(eeg_ch_names, channel_mask) if found] + ["Stim"]
98 | ch_types = ["eeg"] * channel_mask.sum() + ["stim"]
99 |
100 | X_flat = []
101 | stim_flat = []
102 | for i in range(container.data.shape[0]):
103 | x = container.data[i][channel_mask,:]
104 | y = container.TrialData[i].targetnumber
105 | stim = np.zeros_like(container.time[i])
106 | if container.TrialData[i].artifact == 0 and (container.TrialData[i].triallength + 2) > self.interval[1]:
107 | assert(container.time[i][2*srate] == 0) # this should be the cue time-point
108 | stim[2*srate] = y
109 | X_flat.append(x)
110 | stim_flat.append(stim[None,:])
111 |
112 | X_flat = np.concatenate(X_flat, axis=1)
113 | stim_flat = np.concatenate(stim_flat, axis=1)
114 |
115 | p_keep = np.flatnonzero(stim_flat).shape[0]/container.data.shape[0]
116 |
117 | message = f'Record {subject}/{session} (subject/session): rejecting {(1 - p_keep)*100:.0f}% of the trials.'
118 | if p_keep < 0.5:
119 | log.warning(message)
120 | else:
121 | log.info(message)
122 |
123 | eeg_data = np.concatenate([X_flat * 1e-6, stim_flat], axis=0)
124 |
125 | info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=srate)
126 | raw = mne.io.RawArray(data=eeg_data, info=info, verbose=False)
127 | raw.set_montage(montage)
128 | if isinstance(container.chaninfo.noisechan, int):
129 | badchanidxs = [container.chaninfo.noisechan]
130 | elif isinstance(container.chaninfo.noisechan, np.ndarray):
131 | badchanidxs = container.chaninfo.noisechan
132 | else:
133 | badchanidxs = []
134 |
135 | for idx in badchanidxs:
136 | # badchan = [eeg_ch_names[idx-1] for idx in badchanidxs]
137 | used_channels = ch_names if self.channels is None else self.channels
138 | if eeg_ch_names[idx-1] in used_channels:
139 | raw.info['bads'].append(eeg_ch_names[idx-1])
140 |
141 | if len(raw.info['bads']) > 0:
142 | log.info(f'Record {subject}/{session} (subject/session): bad channels that will be interpolated: {raw.info["bads"]}')
143 |
144 | # subject_data[session] = {"run_0": self._common_prep(raw)}
145 | subject_data[session] = {"run_0": self.preprocess(raw)}
146 | return subject_data
147 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/experiments_Hinss21.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | data_dir=/data #change this to your data folder
3 |
4 | ### Experiments on SPDDSMBN
5 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda,inter-session+uda nnet=tsmnet_spddsmbn
6 | ### Experiments on DSMLieBN under standard metrics
7 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda,inter-session+uda nnet=tsmnet_LieBN nnet.model.metric=LEM,LCM,AIM
8 |
9 | ### Experiments on DSMLieBN under 0.5-LCM for inter-session
10 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-session+uda nnet=tsmnet_LieBN nnet.model.metric=LCM nnet.model.theta=0.5
11 | ### Experiments on DSMLieBN under -0.5-AIM for inter-subject
12 | [ $? -eq 0 ] && python TSMNet-LieBN.py -m data_dir=$data_dir dataset=hinss2021 evaluation=inter-subject+uda nnet=tsmnet_LieBN nnet.model.metric=AIM nnet.model.theta=-0.5
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/hydra/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hydra
3 | from hydra.core.hydra_config import HydraConfig
4 | from omegaconf import OmegaConf, DictConfig
5 | import torch
6 |
7 | from sklearn.pipeline import Pipeline
8 |
9 | def hydra_helpers(func):
10 | def inner(*args, **kwargs):
11 | # setup helpers
12 |
13 | # omega conf helpers
14 | OmegaConf.register_new_resolver("len", lambda x:len(x), replace=True)
15 | OmegaConf.register_new_resolver("add", lambda x,y:x+y, replace=True)
16 | OmegaConf.register_new_resolver("sub", lambda x,y:x-y, replace=True)
17 | OmegaConf.register_new_resolver("mul", lambda x,y:x*y, replace=True)
18 | OmegaConf.register_new_resolver("rdiv", lambda x,y:x/y, replace=True)
19 |
20 | STR2TORCHDTYPE = {
21 | 'float32': torch.float32,
22 | 'float64': torch.float64,
23 | 'double': torch.double,
24 | }
25 | OmegaConf.register_new_resolver("torchdtype", lambda x:STR2TORCHDTYPE[x], replace=True)
26 | # if func is not None and list(kwargs.keys())[0] !='args':
27 | if func is not None:
28 | # func(args = kwargs['args'])
29 | func(*args, **kwargs)
30 | return inner
31 |
32 |
33 | def make_sklearn_pipeline(steps_config) -> Pipeline:
34 |
35 | steps = []
36 | for step_config in steps_config:
37 |
38 | # retrieve the name and parameter dictionary of the current steps
39 | step_name, step_transform = next(iter(step_config.items()))
40 | # instantiate the pipeline step, and append to the list of steps
41 | if isinstance(step_transform, DictConfig):
42 | pipeline_step = (step_name, hydra.utils.instantiate(step_transform, _convert_='partial'))
43 | else:
44 | pipeline_step = (step_name, step_transform)
45 | steps.append(pipeline_step)
46 |
47 | return Pipeline(steps)
48 |
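
The resolvers registered by hydra_helpers are what make interpolations like ${sub:${fit.epochs},10} and ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}} in the nnet configs work. A minimal sketch of the same mechanism with plain OmegaConf:

    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver('sub', lambda x, y: x - y, replace=True)
    OmegaConf.register_new_resolver('rdiv', lambda x, y: x / y, replace=True)

    cfg = OmegaConf.create({
        'fit': {'epochs': 50, 'batch_size_train': 50, 'domains_per_batch': 5},
        'scheduler': {'epochs': '${sub:${fit.epochs},10}',
                      'bs': '${rdiv:${fit.batch_size_train},${fit.domains_per_batch}}'},
    })
    print(cfg.scheduler.epochs, cfg.scheduler.bs)  # 40 10.0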
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/hydra/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/hydra/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/moabb/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | import json
4 | import os
5 | import hashlib
6 | import numpy as np
7 | import pandas as pd
8 | import mne
9 |
10 | from sklearn.base import BaseEstimator
11 | from moabb.paradigms.base import BaseParadigm
12 | from moabb.paradigms.motor_imagery import FilterBankMotorImagery, MotorImagery
13 | from mne import get_config, set_config
14 | from mne.datasets.utils import _get_path
15 | from mne.io import read_info, write_info
16 | from skorch import dataset
17 |
18 |
19 | log = logging.getLogger(__name__)
20 |
21 | class CachedParadigm(BaseParadigm):
22 |
23 | def _get_string_rep(self, obj):
24 | if issubclass(type(obj), BaseEstimator):
25 | str_repr = repr(obj.get_params())
26 | else:
27 | str_repr = repr(obj)
28 | str_no_addresses = re.sub("0x[a-z0-9]*", "0x__", str_repr)
29 | return str_no_addresses.replace("\n", "")
30 |
31 | def _get_rep(self, dataset):
32 | return self._get_string_rep(dataset) + '\n' + self._get_string_rep(self)
33 |
34 | def _get_cache_dir(self, rep):
35 | if get_config("MNEDATASET_TMP_DIR") is None:
36 | set_config("MNEDATASET_TMP_DIR", os.path.join(os.path.expanduser("~"), "mne_data"))
37 | base_dir = _get_path(None, "MNEDATASET_TMP_DIR", "preprocessed")
38 |
39 | digest = hashlib.sha1(rep.encode("utf8")).hexdigest()
40 |
41 | cache_dir = os.path.join(
42 | base_dir,
43 | "preprocessed",
44 | digest
45 | )
46 | return cache_dir
47 |
48 |
49 | def process_raw(self, raw, dataset, return_epochs=False,return_raws=False):
50 | # get events id
51 | event_id = self.used_events(dataset)
52 |
53 | # find the events, first check stim_channels then annotations
54 | stim_channels = mne.utils._get_stim_channel(None, raw.info,
55 | raise_error=False)
56 | if len(stim_channels) > 0:
57 | events = mne.find_events(raw, shortest_event=0, verbose=False)
58 | else:
59 | events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False)
60 |
61 | # picks channels
62 | if self.channels is None:
63 | picks = mne.pick_types(raw.info, eeg=True, stim=False)
64 | else:
65 | picks = mne.pick_types(raw.info, stim=False, include=self.channels)
66 |
67 | # pick events, based on event_id
68 | try:
69 | events = mne.pick_events(events, include=list(event_id.values()))
70 | except RuntimeError:
71 | # skip raw if no event found
72 | return
73 |
74 | # get interval
75 | tmin = self.tmin + dataset.interval[0]
76 | if self.tmax is None:
77 | tmax = dataset.interval[1]
78 | else:
79 | tmax = self.tmax + dataset.interval[0]
80 |
81 | X = []
82 | for bandpass in self.filters:
83 | fmin, fmax = bandpass
84 | # filter data
85 | if fmin is None and fmax is None:
86 | raw_f = raw
87 | else:
88 | raw_f = raw.copy().filter(fmin, fmax, method='iir',
89 | picks=picks, verbose=False)
90 | # epoch data
91 | epochs = mne.Epochs(raw_f, events, event_id=event_id,
92 | tmin=tmin, tmax=tmax, proj=False,
93 | baseline=None, preload=True,
94 | verbose=False, picks=picks,
95 | event_repeated='drop',
96 | on_missing='ignore')
97 | if self.resample is not None:
98 | epochs = epochs.resample(self.resample)
99 | # rescale to work with uV
100 | if return_epochs:
101 | X.append(epochs)
102 | else:
103 | X.append(dataset.unit_factor * epochs.get_data())
104 |
105 | inv_events = {k: v for v, k in event_id.items()}
106 | labels = np.array([inv_events[e] for e in epochs.events[:, -1]])
107 |
108 | # if only one band, return a 3D array, otherwise return a 4D
109 | if len(self.filters) == 1:
110 | X = X[0]
111 | else:
112 | X = np.array(X).transpose((1, 2, 3, 0))
113 |
114 | metadata = pd.DataFrame(index=range(len(labels)))
115 | return X, labels, metadata
116 |
117 |
118 | def get_data(self, dataset, subjects=None, return_epochs=False):
119 |
120 | if return_epochs:
121 | raise ValueError("Only return_epochs=False is supported.")
122 |
123 | rep = self._get_rep(dataset)
124 | cache_dir = self._get_cache_dir(rep)
125 | os.makedirs(cache_dir, exist_ok=True)
126 |
127 | X = [] if return_epochs else np.array([])
128 | labels = []
129 | metadata = pd.Series([])
130 |
131 | if subjects is None:
132 | subjects = dataset.subject_list
133 |
134 | if not os.path.isfile(os.path.join(cache_dir, 'repr.json')):
135 | with open(os.path.join(cache_dir, 'repr.json'), 'w+') as f:
136 | f.write(self._get_rep(dataset))
137 |
138 | for subject in subjects:
139 | if not os.path.isfile(os.path.join(cache_dir, f'{subject}.npy')):
140 | # compute
141 | x, lbs, meta = super().get_data(dataset, [subject], return_epochs)
142 | np.save(os.path.join(cache_dir, f'{subject}.npy'), x)
143 | meta['label'] = lbs
144 | meta.to_csv(os.path.join(cache_dir, f'{subject}.csv'), index=False)
145 | log.info(f'saved cached data in directory {cache_dir}')
146 |
147 | # load from cache
148 | log.info(f'loading cached data from directory {cache_dir}')
149 | x = np.load(os.path.join(cache_dir, f'{subject}.npy'), mmap_mode ='r')
150 | meta = pd.read_csv(os.path.join(cache_dir, f'{subject}.csv'))
151 | lbs = meta['label'].tolist()
152 |
153 | if return_epochs:
154 | X.append(x)
155 | else:
156 | X = np.append(X, x, axis=0) if len(X) else x
157 | labels = np.append(labels, lbs, axis=0)
158 | metadata = pd.concat([metadata, meta], ignore_index=True)
159 |
160 | return X, labels, metadata
161 |
162 | def get_info(self, dataset):
163 | # check if the info has been saved
164 | rep = self._get_rep(dataset)
165 | cache_dir = self._get_cache_dir(rep)
166 | os.makedirs(cache_dir, exist_ok=True)
167 | info_file = os.path.join(cache_dir, f'raw-info.fif')
168 | if not os.path.isfile(info_file):
169 | x, _, _ = super().get_data(dataset, [dataset.subject_list[0]], True)
170 | info = x.info
171 | write_info(info_file, info)
172 | log.info(f'saved cached info in directory {cache_dir}')
173 | else:
174 | log.info(f'loading cached info from directory {cache_dir}')
175 | info = read_info(info_file)
176 | return info
177 |
178 | def __repr__(self) -> str:
179 | return json.dumps({self.__class__.__name__: self.__dict__})
180 |
181 |
182 | class CachedMotorImagery(CachedParadigm, MotorImagery):
183 |
184 | def __init__(self, **kwargs):
185 | n_classes = len(kwargs['events'])
186 | super().__init__(n_classes=n_classes, **kwargs)
187 |
188 |
189 | class CachedFilterBankMotorImagery(CachedParadigm, FilterBankMotorImagery):
190 |
191 | def __init__(self, **kwargs):
192 | n_classes = len(kwargs['events'])
193 | super().__init__(n_classes=n_classes, **kwargs)
194 |
195 |
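
CachedParadigm keys its on-disk cache on a SHA-1 digest of the dataset and paradigm string representations; memory addresses are masked to 0x__ in _get_string_rep so the digest stays stable across runs. A sketch of the directory derivation with hypothetical repr strings:

    import hashlib
    import os

    rep = 'Hinss2021({...})' + '\n' + 'CachedMotorImagery({...})'  # hypothetical reprs
    digest = hashlib.sha1(rep.encode('utf8')).hexdigest()
    cache_dir = os.path.join(os.path.expanduser('~'), 'mne_data', 'preprocessed', digest)
    print(cache_dir)

Any change to the dataset parameters or the paradigm therefore lands in a fresh cache directory, while identical configurations re-use the cached per-subject .npy/.csv files.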
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/moabb/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/moabb/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/pyriemann/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.linalg import eigvalsh
3 | from pyriemann.tangentspace import TangentSpace
4 |
5 | def squared_airm(A, B):
6 | return np.square(np.log(eigvalsh(A, B))).sum()
7 |
8 | def airm(A,B):
9 | return np.sqrt(squared_airm(A,B))
10 |
11 | def geom_mean(As):
12 | ts = TangentSpace()
13 | ts.fit(As)
14 | return ts.reference_
15 |
16 | def tsm(As):
17 | ts = TangentSpace()
18 | ts.fit(As)
19 | return ts.transform(As)
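
squared_airm computes the affine-invariant Riemannian distance from the generalized eigenvalues of the pair (A, B). A quick numeric sketch using only numpy/scipy (geom_mean and tsm additionally require pyriemann):

    import numpy as np
    from scipy.linalg import eigvalsh

    def airm(A, B):
        # AIRM distance: sqrt(sum(log(gen. eigenvalues of (A, B))^2))
        return np.sqrt(np.square(np.log(eigvalsh(A, B))).sum())

    rng = np.random.default_rng(0)
    M = rng.standard_normal((4, 4))
    A = M @ M.T + 4 * np.eye(4)   # random SPD matrix
    print(airm(A, A))             # 0.0: a matrix has zero distance to itself
    print(airm(A, 2 * A))         # 2*log(2): scaling by c gives sqrt(n)*|log c| for n=4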
--------------------------------------------------------------------------------
/LieBN_TSMNet/library/utils/torch/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/library/utils/torch/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/BaseBatchNorm.py:
--------------------------------------------------------------------------------
1 |
2 | from enum import Enum
3 | import torch.nn as nn
4 |
5 | class BatchNormDispersion(Enum):
6 | NONE = 'mean'
7 | SCALAR = 'scalar'
8 | VECTOR = 'vector'
9 | class BatchNormTestStatsMode(Enum):
10 | BUFFER = 'buffer'
11 | REFIT = 'refit'
12 | ADAPT = 'adapt'
13 | class BatchNormTestStatsInterface:
14 | def set_test_stats_mode(self, mode: BatchNormTestStatsMode):
15 | pass
16 |
17 | class BaseBatchNorm(nn.Module, BatchNormTestStatsInterface):
18 | def __init__(self, eta=1.0, eta_test=0.1, test_stats_mode: BatchNormTestStatsMode = BatchNormTestStatsMode.BUFFER):
19 | super().__init__()
20 | self.eta = eta
21 | self.eta_test = eta_test
22 | self.test_stats_mode = test_stats_mode
23 |
24 | def set_test_stats_mode(self, mode: BatchNormTestStatsMode):
25 | self.test_stats_mode = mode
26 |
27 |
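
BatchNormTestStatsInterface gives evaluation code a uniform way to switch every batch-norm layer in a network between buffered running statistics and statistics refitted on adaptation data. A minimal sketch (model is a hypothetical nn.Module containing such layers):

    from spdnets.BaseBatchNorm import BatchNormTestStatsInterface, BatchNormTestStatsMode

    def set_bn_test_mode(model, mode: BatchNormTestStatsMode):
        # flip every submodule that implements the test-stats interface
        for m in model.modules():
            if isinstance(m, BatchNormTestStatsInterface):
                m.set_test_stats_mode(mode)

    # e.g., refit test-time statistics on the adaptation set before evaluating:
    # set_bn_test_mode(model, BatchNormTestStatsMode.REFIT)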
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/LieBNImpl.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Ziheng Chen
3 | Please cite the paper below if you use the code:
4 |
5 | Ziheng Chen, Yue Song, Yunmei Liu, Nicu Sebe. A Lie Group Approach to Riemannian Batch Normalization. ICLR 2024
6 |
7 | Copyright (C) 2024 Ziheng Chen
8 | All rights reserved.
9 | """
10 |
11 | from builtins import NotImplementedError
12 | from typing import Tuple, Union
13 | import torch as th
14 | import torch.nn as nn
15 |
16 | from geoopt.tensor import ManifoldParameter, ManifoldTensor
17 | from .manifolds import SymmetricPositiveDefinite
18 | from . import functionals
19 | from spdnets.BaseBatchNorm import BaseBatchNorm,BatchNormDispersion,BatchNormTestStatsMode
20 |
21 |
22 | class SPDLieBatchNormImpl(BaseBatchNorm):
23 | def __init__(self, shape: Union[Tuple[int, ...], th.Size], batchdim: int,
24 | eta=1., eta_test=0.1,
25 | karcher_steps: int = 1, learn_mean=False, learn_std=True,
26 | dispersion: BatchNormDispersion = BatchNormDispersion.SCALAR,
27 | eps=1e-5, mean=None, std=None,
28 | metric='AIM',theta: float =1.,alpha: float=1.0,beta: float=0.,
29 | **kwargs):
30 | super().__init__(eta, eta_test)
31 | if metric not in ("AIM", "LEM", "LCM"):
32 | raise Exception('unknown metric {}'.format(metric))
33 | self.metric = metric; self.theta = th.tensor(theta); self.alpha = th.tensor(alpha); self.beta = th.tensor(beta)
34 | self.identity = th.eye(shape[-1])
35 |
36 | # the last two dimensions are used for SPD manifold
37 | assert (shape[-1] == shape[-2])
38 |
39 | if dispersion == BatchNormDispersion.VECTOR:
40 | raise NotImplementedError()
41 |
42 | self.dispersion = dispersion
43 | self.learn_mean = learn_mean
44 | self.learn_std = learn_std
45 | self.batchdim = batchdim
46 | self.karcher_steps = karcher_steps
47 | self.eps = eps
48 |
49 | init_mean = th.diag_embed(th.ones(shape[:-1], **kwargs))
50 | init_var = th.ones((*shape[:-2], 1), **kwargs)
51 |
52 | self.register_buffer('running_mean', ManifoldTensor(init_mean,
53 | manifold=SymmetricPositiveDefinite()))
54 | self.register_buffer('running_var', init_var)
55 | self.register_buffer('running_mean_test', ManifoldTensor(init_mean,
56 | manifold=SymmetricPositiveDefinite()))
57 | self.register_buffer('running_var_test', init_var)
58 |
59 | if mean is not None:
60 | self.mean = mean
61 | else:
62 | if self.learn_mean:
63 | self.mean = ManifoldParameter(init_mean.clone(), manifold=SymmetricPositiveDefinite())
64 | else:
65 | self.mean = ManifoldTensor(init_mean.clone(), manifold=SymmetricPositiveDefinite())
66 |
67 | if self.dispersion is not BatchNormDispersion.NONE:
68 | if std is not None:
69 | self.std = std
70 | else:
71 | if self.learn_std:
72 | self.std = nn.parameter.Parameter(init_var.clone())
73 | else:
74 | self.std = init_var.clone()
75 |
76 | @th.no_grad()
77 | def initrunningstats(self, X):
78 | self.running_mean.data = self.cal_geom_mean(X)
79 | self.running_mean_test.data = self.running_mean.data.clone()
80 |
81 | if self.dispersion is BatchNormDispersion.SCALAR:
82 | self.running_var.data = self.cal_geom_var(X, self.running_mean)
83 | self.running_var_test.data = self.running_var.data.clone()
84 |
85 | def forward(self, X):
86 | X_deformed = self.deformation(X)
87 | if self.learn_mean:
88 | weight = self.deformation(self.mean)
89 |
90 | if self.training:
91 | # compute batch mean
92 | batch_mean = self.cal_geom_mean(X_deformed)
93 | # update the running_mean, running_mean_test
94 | self.updating_running_means(batch_mean)
95 | # compute batch vars w.r.t. running_mean and running_mean_test, update the running_var and running_var_test
96 | if self.dispersion is not BatchNormDispersion.NONE:
97 | batch_var = self.cal_geom_var(X_deformed,self.running_mean)
98 | batch_var_test = self.cal_geom_var(X_deformed, self.running_mean_test)
99 | self.updating_running_var(batch_var,batch_var_test)
100 | else:
101 | if self.test_stats_mode == BatchNormTestStatsMode.BUFFER:
102 | pass # nothing to do: use the ones in the buffer
103 | elif self.test_stats_mode == BatchNormTestStatsMode.REFIT:
104 | self.initrunningstats(X_deformed)
105 | elif self.test_stats_mode == BatchNormTestStatsMode.ADAPT:
106 | raise NotImplementedError()
107 |
108 | rm = self.running_mean if self.training else self.running_mean_test
109 | if self.dispersion is BatchNormDispersion.SCALAR:
110 | rv = self.running_var if self.training else self.running_var_test
111 |
112 | # subtracting mean
113 | X_centered = self.Left_translation(X_deformed,rm,True)
114 | # scaling and shifting
115 | if self.dispersion is BatchNormDispersion.SCALAR:
116 | X_scaled = self.scaling(X_centered, rv,self.std)
117 | # biasing
118 | X_normalized = self.Left_translation(X_scaled, weight,False) if self.learn_mean else X_scaled
119 | # inv of deformation
120 | X_new = self.inv_deformation(X_normalized)
121 |
122 | return X_new
123 |
124 | def spd_power(self,X):
125 | if self.theta == 1.:
126 | X_power = X
127 | else:
128 | X_power = functionals.sym_powm.apply(X, self.theta)
129 | return X_power
130 |
131 | def inv_power(self,X):
132 | if self.theta == 1.:
133 | X_power = X
134 | else:
135 | X_power = functionals.sym_powm.apply(X, 1/self.theta)
136 | return X_power
137 | def deformation(self,X):
138 |
139 | if self.metric=='AIM':
140 | X_deformed = self.spd_power(X)
141 | elif self.metric == 'LEM':
142 | X_deformed = functionals.sym_logm.apply(X)
143 | elif self.metric == 'LCM':
144 | X_power = self.spd_power(X)
145 | L = th.linalg.cholesky(X_power)
146 | diag_part = th.diag_embed(th.log(th.diagonal(L, dim1=-2, dim2=-1)))
147 | X_deformed = L.tril(-1) + diag_part
148 |
149 | return X_deformed
150 |
151 | def inv_deformation(self,X):
152 | if self.metric=='AIM':
153 | X_inv_deformed = self.inv_power(X)
154 | elif self.metric == 'LEM':
155 | X_inv_deformed = functionals.sym_expm.apply(X)
156 | elif self.metric == 'LCM':
157 | Cho = X.tril(-1) + th.diag_embed(th.exp(th.diagonal(X, dim1=-2, dim2=-1)))
158 | spd = Cho.matmul(Cho.transpose(-1,-2))
159 | X_inv_deformed = self.inv_power(spd)
160 | return X_inv_deformed
161 |
162 | def cal_geom_mean(self, X):
163 | """Frechet mean"""
164 | if self.metric == 'AIM':
165 | mean = self.KF_AIM(X.detach())
166 | elif self.metric == 'LEM' or self.metric == 'LCM':
167 | mean = X.detach().mean(dim=self.batchdim, keepdim=True)
168 |
169 | return mean
170 | def cal_geom_var(self, X, rm):
171 | """Frechet variance w.r.t. rm"""
172 | spd = X.detach()
173 | if self.metric == 'AIM':
174 | rm_invsq = functionals.sym_invsqrtm.apply(rm)
175 | if self.beta == 0.:
176 | dists = self.alpha * th.linalg.matrix_norm(
177 | functionals.sym_logm.apply(rm_invsq @ spd @ rm_invsq)).square()
178 | else:
179 | dists = self.alpha * th.linalg.matrix_norm(functionals.sym_logm.apply(rm_invsq @ spd @ rm_invsq)).square()\
180 | + self.beta * th.logdet(th.linalg.solve(rm,spd)).square()
181 |
182 | elif self.metric == 'LEM' or self.metric == 'LCM':
183 | tmp = spd - rm
184 | if self.beta == 0.:
185 | dists = self.alpha * th.linalg.matrix_norm(tmp).square()
186 | else:
187 | item1 = th.linalg.matrix_norm(tmp)
188 | item2 = functionals.trace(tmp)
189 | dists = self.alpha * item1.square() + self.beta * item2.square()
190 |
191 | var = dists.mean(dim=self.batchdim, keepdim=True)
192 | if self.metric == 'AIM' or self.metric == 'LCM':
193 | var_final = var* (1/(self.theta**2))
194 | else:
195 | var_final = var
196 | return var_final.unsqueeze(dim=0)
197 | def Left_translation(self, X, P, is_inverse):
198 | """Left translation by P"""
199 | if self.metric == 'AIM':
200 | L = th.linalg.cholesky(P)
201 | if is_inverse:
202 | tmp = th.linalg.solve(L, X)
203 | X_new = th.linalg.solve(L.transpose(-1, -2), tmp, left=False)
204 | else:
205 | X_new = L.matmul(X).matmul(L.transpose(-1, -2))
206 |
207 | elif self.metric == 'LEM' or self.metric == 'LCM':
208 | X_new = X - P if is_inverse else X + P
209 |
210 | return X_new
211 | def scaling(self, X, var,scale):
212 | """Scaling by variance"""
213 | factor = scale / (var+self.eps).sqrt()
214 | if self.metric == 'AIM':
215 | X_new = functionals.sym_powm.apply(X, factor)
216 | elif self.metric == 'LEM' or self.metric == 'LCM':
217 | X_new = X * factor
218 |
219 | return X_new
220 | def updating_running_means(self,batch_mean):
221 | """updating running means"""
222 | with th.no_grad():
223 | if self.metric == 'AIM':
224 | self.running_mean.data = functionals.spd_2point_interpolation(self.running_mean, batch_mean,self.eta)
225 | self.running_mean_test.data = functionals.spd_2point_interpolation(self.running_mean_test, batch_mean,self.eta_test)
226 | elif self.metric == 'LEM' or self.metric == 'LCM':
227 | self.running_mean.data = (1-self.eta) * self.running_mean+ batch_mean * self.eta
228 | self.running_mean_test.data = (1 - self.eta_test) * self.running_mean_test + batch_mean * self.eta_test
229 |
230 | def updating_running_var(self, batch_var, batch_var_test):
231 | """updating running vars"""
232 | with th.no_grad():
233 | if self.dispersion is BatchNormDispersion.SCALAR:
234 | self.running_var = (1. - self.eta) * self.running_var + self.eta * batch_var
235 | self.running_var_test = (1. - self.eta_test) * self.running_var_test + self.eta_test * batch_var_test
236 |
237 | def KF_AIM(self,X,karcher_steps=1):
238 | '''
239 | Function which computes the Riemannian barycenter for a batch of data using the Karcher flow
240 | Input x is a batch of SPD matrices (batch_size,1,n,n) to average
241 | Output is the Riemannian mean with the batch dimension kept (e.g., (1,1,n,n))
242 | '''
243 | batch_mean = X.mean(dim=self.batchdim, keepdim=True)
244 | for _ in range(karcher_steps):
245 | bm_sq, bm_invsq = functionals.sym_invsqrtm2.apply(batch_mean)
246 | XT = functionals.sym_logm.apply(bm_invsq @ X @ bm_invsq)
247 | GT = XT.mean(dim=self.batchdim, keepdim=True)
248 | batch_mean = bm_sq @ functionals.sym_expm.apply(GT) @ bm_sq
249 | return batch_mean
250 |
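
In training mode the layer deforms the input (matrix power, log, or Cholesky depending on the metric), estimates the batch Frechet mean and variance, updates the running statistics, and then centers and rescales the batch via left translation and power scaling. A minimal smoke-test sketch, assuming the (batch, channel, n, n) layout with batchdim=0 and that the surrounding functionals module is importable:

    import torch as th
    from spdnets.LieBNImpl import SPDLieBatchNormImpl

    n, batch = 8, 32
    A = th.randn(batch, 1, n, n)
    X = A @ A.transpose(-1, -2) + 1e-3 * th.eye(n)  # random SPD batch

    bn = SPDLieBatchNormImpl(shape=(1, n, n), batchdim=0, metric='AIM')
    bn.train()
    Y = bn(X)          # normalized SPD batch
    print(Y.shape)     # torch.Size([32, 1, 8, 8])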
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__init__.py
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__pycache__/batchnorm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/batchnorm.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__pycache__/functionals.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/functionals.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__pycache__/manifolds.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/manifolds.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/__pycache__/modules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/__pycache__/modules.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/manifolds.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Optional, Tuple
3 | from geoopt.manifolds import Manifold
4 | from . import functionals
5 |
6 | __all__ = ["SymmetricPositiveDefinite"]
7 |
8 | class SymmetricPositiveDefinite(Manifold):
9 | """
10 | Subclass of the SymmetricPositiveDefinite manifold using the
11 | affine invariant Riemannian metric (AIRM) as default metric
12 | """
13 |
14 | __scaling__ = Manifold.__scaling__.copy()
15 | name = "SymmetricPositiveDefinite"
16 | ndim = 2
17 | reversible = False
18 |
19 | def __init__(self):
20 | super().__init__()
21 |
22 | def dist(self, x: torch.Tensor, y: torch.Tensor, keepdim=False) -> torch.Tensor:
23 | """
24 | Computes the geodesic distance under the affine invariant Riemannian metric (AIRM)
25 | """
26 | inv_sqrt_x = functionals.sym_invsqrtm.apply(x)
27 | return torch.norm(
28 | functionals.sym_logm.apply(inv_sqrt_x @ y @ inv_sqrt_x),
29 | dim=[-1, -2],
30 | keepdim=keepdim,
31 | )
32 |
33 | def _check_point_on_manifold(
34 | self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
35 | ) -> Union[Tuple[bool, Optional[str]], bool]:
36 | ok = torch.allclose(x, x.transpose(-1, -2), atol=atol, rtol=rtol)
37 | if not ok:
38 | return False, "`x != x.transpose` with atol={}, rtol={}".format(atol, rtol)
39 | e = torch.linalg.eigvalsh(x)
40 | ok = (e > -atol).min()
41 | if not ok:
42 | return False, "eigenvalues of x are not all greater than 0."
43 | return True, None
44 |
45 | def _check_vector_on_tangent(
46 | self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
47 | ) -> Union[Tuple[bool, Optional[str]], bool]:
48 | ok = torch.allclose(u, u.transpose(-1, -2), atol=atol, rtol=rtol)
49 | if not ok:
50 | return False, "`u != u.transpose` with atol={}, rtol={}".format(atol, rtol)
51 | return True, None
52 |
53 | def projx(self, x: torch.Tensor) -> torch.Tensor:
54 | symx = functionals.ensure_sym(x)
55 | return functionals.sym_abseig.apply(symx)
56 |
57 | def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
58 | return functionals.ensure_sym(u)
59 |
60 | def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
61 | return x @ self.proju(x, u) @ x
62 |
63 | def inner(self, x: torch.Tensor, u: torch.Tensor, v: Optional[torch.Tensor] = None, keepdim=False) -> torch.Tensor:
64 | if v is None:
65 | v = u
66 | inv_x = functionals.sym_invm.apply(x)
67 | ret = torch.diagonal(inv_x @ u @ inv_x @ v, dim1=-2, dim2=-1).sum(-1)
68 | if keepdim:
69 | return torch.unsqueeze(torch.unsqueeze(ret, -1), -1)
70 | return ret
71 |
72 | def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
73 | inv_x = functionals.sym_invm.apply(x)
74 | return functionals.ensure_sym(x + u + 0.5 * u @ inv_x @ u)
75 | # return self.expmap(x, u)
76 |
77 | def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
78 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x)
79 | return sqrt_x @ functionals.sym_expm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x
80 |
81 | def logmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
82 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x)
83 | return sqrt_x @ functionals.sym_logm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x
84 |
85 | def extra_repr(self) -> str:
86 | return "default_metric=AIM"
87 |
88 | def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
89 |
90 | xinvy = torch.linalg.solve(x.double(),y.double())
91 | s, U = torch.linalg.eig(xinvy.transpose(-2,-1))
92 | s = s.real
93 | U = U.real
94 |
95 | Ut = U.transpose(-2,-1)
96 | Esqm = torch.linalg.solve(Ut, torch.diag_embed(s.sqrt()) @ Ut).transpose(-2,-1).to(y.dtype)
97 |
98 | return Esqm @ v @ Esqm.transpose(-1,-2)
99 |
100 | def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
101 | tens = torch.randn(*size, dtype=dtype, device=device, **kwargs)
102 | tens = functionals.ensure_sym(tens)
103 | tens = functionals.sym_expm.apply(tens)
104 | return tens
105 |
106 | def barycenter(self, X : torch.Tensor, steps : int = 1, dim = 0) -> torch.Tensor:
107 | """
108 | Compute several steps of the Karcher flow algorithm to estimate the
109 | barycenter on the manifold.
110 | """
111 | return functionals.spd_mean_kracher_flow(X, None, maxiter=steps, dim=dim, return_dist=False)
112 |
113 | def geodesic(self, A : torch.Tensor, B : torch.Tensor, t : torch.Tensor) -> torch.Tensor:
114 | """
115 | Compute geodesic between two SPD tensors A and B and return
116 | point on the geodesic at length t \in [0,1]
117 | if t = 0, then A is returned
118 | if t = 1, then B is returned
119 | """
120 | Asq, Ainvsq = functionals.sym_invsqrtm2.apply(A)
121 | return Asq @ functionals.sym_powm.apply(Ainvsq @ B @ Ainvsq, t) @ Asq
122 |
123 | def transp_via_identity(self, X : torch.Tensor, A : torch.Tensor, B : torch.Tensor) -> torch.Tensor:
124 | """
125 | Parallel transport of the tensors in X around A to the identity matrix I
126 | Parallel transport from around the identity matrix to the new center (tensor B)
127 | """
128 | Ainvsq = functionals.sym_invsqrtm.apply(A)
129 | Bsq = functionals.sym_sqrtm.apply(B)
130 | return Bsq @ (Ainvsq @ X @ Ainvsq) @ Bsq
131 |
132 | def transp_identity_rescale_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor) -> torch.Tensor:
133 | """
134 | Parallel transport of the tensors in X around A to the identity matrix I
135 | Rescales the dispersion by the factor s
136 | Parallel transport from the identity to the new center (tensor B)
137 | """
138 | Ainvsq = functionals.sym_invsqrtm.apply(A)
139 | Bsq = functionals.sym_sqrtm.apply(B)
140 | return Bsq @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ Bsq
141 |
142 | def transp_identity_rescale_rotate_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor, W : torch.Tensor) -> torch.Tensor:
143 | """
144 | Parallel transport of the tensors in X around A to the identity matrix I
145 | Rescales the dispersion by the factor s
146 | Parallel transport from the identity to the new center (tensor B)
147 | """
148 | Ainvsq = functionals.sym_invsqrtm.apply(A)
149 | Bsq = functionals.sym_sqrtm.apply(B)
150 | WBsq = W @ Bsq
151 | return WBsq.transpose(-2,-1) @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ WBsq
152 |
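
A short sketch exercising the manifold API above (random SPD sampling, the AIRM geodesic, the Karcher-flow barycenter, and pairwise distances):

    import torch
    from spdnets.manifolds import SymmetricPositiveDefinite

    spd = SymmetricPositiveDefinite()
    A = spd.random(4, 3, 3)                      # batch of 4 random SPD matrices
    B = spd.random(4, 3, 3)
    mid = spd.geodesic(A, B, torch.tensor(0.5))  # geodesic midpoints under AIRM
    mean = spd.barycenter(A, steps=3, dim=0)     # Karcher-flow estimate of the Frechet mean
    print(spd.dist(A, B, keepdim=False).shape)   # torch.Size([4])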
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseModel, FineTuneableModel, CPUModel, PatternInterpretableModel
2 | from .base import DomainAdaptBaseModel
3 | from .base import DomainAdaptFineTuneableModel, DomainAdaptJointTrainableModel
4 |
5 | from .eegnet import EEGNetv4, DANNEEGNet
6 | from .shconvnet import ShallowConvNet, DANNShallowConvNet, ShConvNetDSBN
7 | from .tsmnet import TSMNet, CNNNet
8 | from .tsmnetMLR import TSMNetMLR
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/base.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/base.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/dann.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/dann.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/eegnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/eegnet.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/shconvnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/shconvnet.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/__pycache__/tsmnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/models/__pycache__/tsmnet.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BaseModel(nn.Module):
6 | def __init__(self, nclasses=None, nchannels=None, nsamples=None, nbands=None, device=None, input_shape=None):
7 | super().__init__()
8 | self.device_ = device
9 | self.lossfn = torch.nn.CrossEntropyLoss()
10 | self.nclasses_ = nclasses
11 | self.nchannels_ = nchannels
12 | self.nsamples_ = nsamples
13 | self.nbands_ = nbands
14 | self.input_shape_ = input_shape
15 |
16 | # AUXILIARY METHODS
17 | def calculate_classification_accuracy(self, Y, Y_lat):
18 | Y_hat = Y_lat.argmax(1)
19 | acc = Y_hat.eq(Y).float().mean().item()
20 | P_hat = torch.softmax(Y_lat, dim=1)
21 | return acc, P_hat
22 |
23 | def calculate_objective(self, model_pred, y_true, model_inp=None):
24 | # Y_lat, l = self(X.to(self.device_), B)
25 | if isinstance(model_pred, (list, tuple)):
26 | y_class_hat = model_pred[0]
27 | else:
28 | y_class_hat = model_pred
29 | loss = self.lossfn(y_class_hat, y_true.to(y_class_hat.device))
30 | return loss
31 |
32 | def get_hyperparameters(self):
33 | return dict(nchannels = self.nchannels_,
34 | nclasses=self.nclasses_,
35 | nsamples=self.nsamples_,
36 | nbands=self.nbands_)
37 |
38 |
39 | class CPUModel:
40 | pass
41 |
42 |
43 | class FineTuneableModel:
44 | def finetune(self, x, y, d):
45 | raise NotImplementedError()
46 |
47 |
48 | class DomainAdaptBaseModel(BaseModel):
49 | def __init__(self, domains = [], **kwargs):
50 | super().__init__(**kwargs)
51 | self.domains_ = domains
52 |
53 |
54 | class DomainAdaptFineTuneableModel(DomainAdaptBaseModel):
55 | def domainadapt_finetune(self, x, y, d, target_domains):
56 | raise NotImplementedError()
57 |
58 |
59 | class DomainAdaptJointTrainableModel(DomainAdaptBaseModel):
60 | def calculate_objective(self, model_pred, y_true, model_inp=None):
61 | # filter out masked observations
62 | keep = y_true != -1 # special label
63 |
64 | if isinstance(model_pred, (list, tuple)):
65 | y_class_hat = model_pred[0]
66 | else:
67 | y_class_hat = model_pred
68 |
69 | return super().calculate_objective(y_class_hat[keep], y_true[keep], None)
70 |
71 |
72 | class PatternInterpretableModel:
73 | def compute_patterns(self, x, y, d):
74 | raise NotImplementedError()
75 |
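76 |
77 | if __name__ == '__main__':
78 |     # Hedged smoke test (illustrative sketch): the accuracy helper on dummy
79 |     # logits and labels; the values are assumptions, not a released test.
80 |     m = BaseModel(nclasses=3)
81 |     acc, P_hat = m.calculate_classification_accuracy(torch.tensor([0, 1, 2, 1]), torch.randn(4, 3))
82 |     print(acc, P_hat.shape)  # scalar accuracy and (4, 3) class probabilities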
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/dann.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base import DomainAdaptJointTrainableModel
3 | import spdnets.modules as modules
4 |
5 | class DANNBase(DomainAdaptJointTrainableModel):
6 | """
7 | Domain adversarial neural network (DANN) proposed
8 | by Ganin et al. 2016, JMLR
9 | """
10 | def __init__(self, daloss_scaling = 1., dann_mode = 'ganin2016', **kwargs):
11 | domains = kwargs['domains']
12 | assert (domains.dtype == torch.long)
13 | kwargs['domains'] = domains.sort()[0]
14 | super().__init__(**kwargs)
15 | self.dann_mode_ = dann_mode
16 |
17 | if self.dann_mode_ == 'ganin2015':
18 | grad_reversal_scaling = daloss_scaling
19 | self.daloss_scaling_ = 1.
20 | elif self.dann_mode_ == 'ganin2016':
21 | grad_reversal_scaling = 1.
22 | self.daloss_scaling_ = daloss_scaling
23 | else:
24 | raise NotImplementedError()
25 |
26 | ndim_latent = self._ndim_latent()
27 | self.adversary_loss = torch.nn.CrossEntropyLoss()
28 |
29 | self.adversary = torch.nn.Sequential(
30 | torch.nn.Flatten(start_dim=1),
31 | modules.ReverseGradient(scaling=grad_reversal_scaling),
32 | torch.nn.Linear(ndim_latent, len(self.domains_))
33 | ).to(self.device_)
34 |
35 | def _ndim_latent(self):
36 | raise NotImplementedError()
37 |
38 | def forward(self, l, d):
39 | # super().forward()
40 | # h = self.cnn(x[:,None,...]).flatten(start_dim=1)
41 | # y = self.classifier(h)
42 | y_domain = self.adversary(l)
43 | return y_domain
44 |
45 | def domainadapt(self, x, y, d, target_domain):
46 | pass # domain adaptation is done during the training process
47 |
48 | def calculate_objective(self, model_pred, y_true, model_inp):
49 | loss = super().calculate_objective(model_pred, y_true, model_inp)
50 | domain = model_inp['d']
51 | y_dom_hat = model_pred[1]
52 | # check if all requested domains were declared
53 | assert ((self.domains_[..., None] == domain[None,...]).any(dim=0).all())
54 | # assign to the class indices (buckets)
55 | y_dom = torch.bucketize(domain, self.domains_).to(y_dom_hat.device)
56 |
57 | adversarial_loss = self.adversary_loss(y_dom_hat, y_dom)
58 | loss = loss + self.daloss_scaling_ * adversarial_loss
59 |
60 | return loss
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/eegnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base import BaseModel
3 | from .dann import DANNBase
4 | import spdnets.modules as modules
5 |
6 |
7 | class EEGNetv4(BaseModel):
8 | def __init__(self, is_within = False, srate = 128, f1 = 8, d = 2, **kwargs):
9 | super().__init__(**kwargs)
10 | self.is_within_ = is_within
11 | self.srate_ = srate
12 | self.f1_ = f1
13 | self.d_ = d
14 | self.f2_ = self.f1_ * self.d_
15 | momentum = 0.01
16 |
17 | kernel_length = int(self.srate_ // 2)
18 | nlatsamples_time = self.nsamples_ // 32
19 |
20 | temp2_kernel_length = int(self.srate_ // 2 // 4)
21 |
22 | if self.is_within_:
23 | drop_prob = 0.5
24 | else:
25 | drop_prob = 0.25
26 |
27 | bntemp = torch.nn.BatchNorm2d(self.f1_, momentum=momentum, affine=True, eps=1e-3)
28 | bnspat = torch.nn.BatchNorm2d(self.f1_ * self.d_, momentum=momentum, affine=True, eps=1e-3)
29 |
30 | self.cnn = torch.nn.Sequential(
31 | torch.nn.Conv2d(1,self.f1_,(1, kernel_length), bias=False, padding='same'),
32 | bntemp,
33 | modules.Conv2dWithNormConstraint(self.f1_, self.f1_ * self.d_, (self.nchannels_, 1), max_norm=1,
34 | stride=1, bias=False, groups=self.f1_, padding=(0, 0)),
35 | bnspat,
36 | torch.nn.ELU(),
37 | torch.nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
38 | torch.nn.Dropout(p=drop_prob),
39 | torch.nn.Conv2d(self.f1_ * self.d_, self.f1_ * self.d_, (1, temp2_kernel_length),
40 | stride=1, bias=False, groups=self.f1_ * self.d_, padding='same'),
41 | torch.nn.Conv2d(self.f1_ * self.d_, self.f2_, (1, 1),
42 | stride=1, bias=False, padding=(0, 0)),
43 | torch.nn.BatchNorm2d(self.f2_, momentum=momentum, affine=True, eps=1e-3),
44 | torch.nn.ELU(),
45 | torch.nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
46 | torch.nn.Dropout(p=drop_prob),
47 | ).to(self.device_)
48 |
49 | self.classifier = torch.nn.Sequential(
50 | torch.nn.Flatten(start_dim=1),
51 | modules.LinearWithNormConstraint(self.f2_ * nlatsamples_time, self.nclasses_, max_norm=0.25)
52 | ).to(self.device_)
53 |
54 | def get_hyperparameters(self):
55 | kwargs = super().get_hyperparameters()
56 | kwargs['nsamples'] = self.nsamples_
57 | kwargs['is_within'] = self.is_within_
58 | kwargs['srate'] = self.srate_
59 | kwargs['f1'] = self.f1_
60 | kwargs['d'] = self.d_
61 | return kwargs
62 |
63 | def forward(self, x, d):
64 | l = self.cnn(x[:,None,...])
65 | y = self.classifier(l)
66 | return y, l
67 |
68 |
69 | class DANNEEGNet(DANNBase, EEGNetv4):
70 | """
71 | Domain adversarial neural network (DANN) proposed for EEG MI classification
72 | by Ozdenizci et al. 2020, IEEE Access
73 | """
74 | def __init__(self, daloss_scaling = 0.03, dann_mode = 'ganin2016', **kwargs):
75 | kwargs['daloss_scaling'] = daloss_scaling
76 | kwargs['dann_mode'] = dann_mode
77 | super().__init__(**kwargs)
78 |
79 | def _ndim_latent(self):
80 | return self.classifier[-1].weight.shape[-1]
81 |
82 | def forward(self, x, d):
83 | y, l = EEGNetv4.forward(self, x, d)
84 | y_domain = DANNBase.forward(self, l, d)
85 | return y, y_domain
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/shconvnet.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import torch
3 | from .base import BaseModel, DomainAdaptFineTuneableModel
4 | from .dann import DANNBase
5 | import spdnets.batchnorm as bn
6 | import spdnets.modules as modules
7 |
8 |
9 | class ShallowConvNet(BaseModel):
10 | def __init__(self, spatial_filters = 40, temporal_filters = 40, pdrop = 0.5, **kwargs):
11 | super().__init__(**kwargs)
12 | self.spatial_filters_ = spatial_filters
13 | self.temporal_filters_ = temporal_filters
14 |
15 | temp_cnn_kernel = 25
16 | temp_pool_kernel = 75
17 | temp_pool_stride = 15
18 | ntempconvout = int((self.nsamples_ - 1*(temp_cnn_kernel-1) - 1)/1 + 1)  # temporal conv output length (stride 1, no padding)
19 | navgpoolout = int((ntempconvout - temp_pool_kernel)/temp_pool_stride + 1)  # average pooling output length
20 |
21 | self.bn = torch.nn.BatchNorm2d(self.spatial_filters_)
22 | drop = torch.nn.Dropout(p=pdrop)
23 |
24 | self.cnn = torch.nn.Sequential(
25 | torch.nn.Conv2d(1,self.temporal_filters_,(1, temp_cnn_kernel)),
26 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)),
27 | ).to(self.device_)
28 | self.pool = torch.nn.Sequential(
29 | modules.MySquare(),
30 | torch.nn.AvgPool2d(kernel_size=(1, temp_pool_kernel), stride=(1, temp_pool_stride)),
31 | modules.MyLog(),
32 | drop,
33 | torch.nn.Flatten(start_dim=1),
34 | ).to(self.device_)
35 | self.classifier = torch.nn.Sequential(
36 | torch.nn.Linear(self.spatial_filters_ * navgpoolout, self.nclasses_),
37 | ).to(self.device_)
38 |
39 | def forward(self,x, d):
40 | l = self.cnn(x.to(self.device_)[:,None,...])
41 | l = self.bn(l)
42 | l = self.pool(l)
43 | y = self.classifier(l)
44 | return y, l
45 |
46 |
47 | class DANNShallowConvNet(DANNBase, ShallowConvNet):
48 | """
49 | Domain adversarial neural network (DANN) proposed for EEG MI classification
50 | by Ozdenizci et al. 2020, IEEE Access
51 | """
52 | def __init__(self, daloss_scaling = 0.05, dann_mode = 'ganin2016', **kwargs):
53 | kwargs['daloss_scaling'] = daloss_scaling
54 | kwargs['dann_mode'] = dann_mode
55 | super().__init__(**kwargs)
56 |
57 | def _ndim_latent(self):
58 | return self.classifier[-1].weight.shape[-1]
59 |
60 | def forward(self, x, d):
61 | y, l = ShallowConvNet.forward(self, x, d)
62 | y_domain = DANNBase.forward(self, l, d)
63 | return y, y_domain
64 |
65 |
66 | class ShConvNetDSBN(ShallowConvNet, DomainAdaptFineTuneableModel):
67 | def __init__(self,
68 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.VECTOR,
69 | **kwargs):
70 | super().__init__(**kwargs)
71 |
72 | if isinstance(bnorm_dispersion, str):
73 | bnorm_dispersion = bn.BatchNormDispersion[bnorm_dispersion]
74 |
75 | self.bn = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_, 1, 1),
76 | batchdim=[0,2,3], # same as batch norm 2D
77 | domains=self.domains_,
78 | dispersion=bnorm_dispersion,
79 | eta=1., eta_test=.1).to(self.device_)
80 |
81 | def forward(self,x, d):
82 | l = self.cnn(x.to(self.device_)[:,None,...])
83 | l = self.bn(l,d.to(device=self.device_))
84 | l = self.pool(l)
85 | y = self.classifier(l)
86 | return y, l
87 |
88 | def domainadapt_finetune(self, x, y, d, target_domains):
89 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
90 | for du in d.unique():
91 | self.forward(x[d==du], d[d==du])
92 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
93 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/models/tsmnet.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | import torch
3 |
4 | import spdnets.modules as modules
5 | import spdnets.batchnorm as bn
6 | import spdnets.Liebatchnorm as LieBN
7 | from spdnets.BaseBatchNorm import BatchNormDispersion
8 | from .base import DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel
9 |
10 |
11 | class TSMNet(DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel):
12 | def __init__(self, temporal_filters, spatial_filters = 40,
13 | subspacedims = 20,
14 | temp_cnn_kernel = 25,
15 | bnorm : Optional[str] = 'spdbn',
16 | bnorm_dispersion : Union[str, BatchNormDispersion] = BatchNormDispersion.SCALAR,
17 | metric: str ='AIM',theta: float =1., alpha: float=1.0, beta: float=0.,learn_mean=False,
18 | **kwargs):
19 | super().__init__(**kwargs)
20 | self.learn_mean = learn_mean; self.metric = metric; self.theta = torch.tensor(theta); self.alpha = torch.tensor(alpha); self.beta = torch.tensor(beta)
21 |
22 | self.temporal_filters_ = temporal_filters
23 | self.spatial_filters_ = spatial_filters
24 | self.subspacedims_ = subspacedims
25 | self.bnorm_ = bnorm
26 | self.spd_device_ = torch.device('cpu')
27 | if isinstance(bnorm_dispersion, str):
28 | self.bnorm_dispersion_ = BatchNormDispersion[bnorm_dispersion]
29 | else:
30 | self.bnorm_dispersion_ = bnorm_dispersion
31 |
32 | tsdim = int(subspacedims*(subspacedims+1)/2)
33 |
34 | self.cnn = torch.nn.Sequential(
35 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel),
36 | padding='same', padding_mode='reflect'),
37 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)),
38 | torch.nn.Flatten(start_dim=2),
39 | ).to(self.device_)
40 |
41 | self.cov_pooling = torch.nn.Sequential(
42 | modules.CovariancePool(),
43 | )
44 |
45 | if self.bnorm_ == 'spddsbn':
46 | self.spddsbnorm = bn.AdaMomDomainSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0,
47 | domains=self.domains_,
48 | learn_mean=self.learn_mean,learn_std=True,
49 | dispersion=self.bnorm_dispersion_,
50 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_)
51 | elif self.bnorm_ == 'LieBN':
52 | self.spddsbnorm = LieBN.AdaMomDomainSPDLieBatchNorm((1,subspacedims,subspacedims), batchdim=0,
53 | domains=self.domains_,
54 | learn_mean=self.learn_mean,learn_std=True,
55 | dispersion=self.bnorm_dispersion_,
56 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_,
57 | metric=self.metric,theta=self.theta,beta=self.beta,alpha=self.alpha)
58 |
59 | elif self.bnorm_ is not None:
60 | raise NotImplementedError('requested an undefined batch normalization method.')
61 |
62 | self.spdnet = torch.nn.Sequential(
63 | modules.BiMap((1,self.spatial_filters_,subspacedims), dtype=torch.double, device=self.spd_device_),
64 | modules.ReEig(threshold=1e-4),
65 | )
66 | self.logeig = torch.nn.Sequential(
67 | modules.LogEig(subspacedims),
68 | torch.nn.Flatten(start_dim=1),
69 | )
70 | self.classifier = torch.nn.Sequential(
71 | torch.nn.Linear(tsdim,self.nclasses_).double(),
72 | ).to(self.spd_device_)
73 |
74 | def to(self, device: Optional[Union[int, torch.device]] = None, dtype: Optional[Union[int, torch.dtype]] = None, non_blocking: bool = False):
75 | if device is not None:
76 | self.device_ = device
77 | self.cnn.to(self.device_)
78 | return super().to(device=None, dtype=dtype, non_blocking=non_blocking)
79 |
80 | def forward(self, x, d, return_latent=True, return_prebn=False, return_postbn=False):
81 | out = ()
82 | h = self.cnn(x.to(device=self.device_)[:,None,...])
83 | C = self.cov_pooling(h).to(device=self.spd_device_, dtype=torch.double)
84 | l = self.spdnet(C)
85 | out += (l,) if return_prebn else ()
86 | l = self.spdbnorm(l) if hasattr(self, 'spdbnorm') else l
87 | l = self.spddsbnorm(l,d.to(device=self.spd_device_)) if hasattr(self, 'spddsbnorm') else l
88 | out += (l,) if return_postbn else ()
89 | l = self.logeig(l)
90 | l = self.tsbnorm(l) if hasattr(self, 'tsbnorm') else l
91 | l = self.tsdsbnorm(l,d) if hasattr(self, 'tsdsbnorm') else l
92 | out += (l,) if return_latent else ()
93 | y = self.classifier(l)
94 | out = y if len(out) == 0 else (y, *out[::-1])
95 | return out
96 |
97 | def domainadapt_finetune(self, x, y, d, target_domains):
98 | if hasattr(self, 'spddsbnorm'):
99 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
100 | if hasattr(self, 'tsdsbnorm'):
101 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
102 |
103 | with torch.no_grad():
104 | for du in d.unique():
105 | self.forward(x[d==du], d[d==du])
106 |
107 | if hasattr(self, 'spddsbnorm'):
108 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
109 | if hasattr(self, 'tsdsbnorm'):
110 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
111 |
112 | def finetune(self, x, y, d):
113 | if hasattr(self, 'spdbnorm'):
114 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
115 | if hasattr(self, 'tsbnorm'):
116 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
117 |
118 | with torch.no_grad():
119 | self.forward(x, d)
120 |
121 | if hasattr(self, 'spdbnorm'):
122 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
123 | if hasattr(self, 'tsbnorm'):
124 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
125 |
126 | def compute_patterns(self, x, y, d):
127 | pass
128 |
129 |
130 |
131 | class CNNNet(DomainAdaptFineTuneableModel, FineTuneableModel):
132 | def __init__(self, temporal_filters, spatial_filters = 40,
133 | temp_cnn_kernel = 25,
134 | bnorm : Optional[str] = 'bn',
135 | bnorm_dispersion : Union[str, BatchNormDispersion] = BatchNormDispersion.SCALAR,
136 | **kwargs):
137 | super().__init__(**kwargs)
138 |
139 | self.temporal_filters_ = temporal_filters
140 | self.spatial_filters_ = spatial_filters
141 | self.bnorm_ = bnorm
142 |
143 | if isinstance(bnorm_dispersion, str):
144 | self.bnorm_dispersion_ = BatchNormDispersion[bnorm_dispersion]
145 | else:
146 | self.bnorm_dispersion_ = bnorm_dispersion
147 |
148 | self.cnn = torch.nn.Sequential(
149 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel),
150 | padding='same', padding_mode='reflect'),
151 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)),
152 | torch.nn.Flatten(start_dim=2),
153 | ).to(self.device_)
154 |
155 | self.cov_pooling = torch.nn.Sequential(
156 | modules.CovariancePool(),
157 | )
158 |
159 | if self.bnorm_ == 'bn':
160 | self.bnorm = bn.AdaMomBatchNorm((1, self.spatial_filters_), batchdim=0, dispersion=self.bnorm_dispersion_,
161 | eta=1., eta_test=.1).to(self.device_)
162 | elif self.bnorm_ == 'dsbn':
163 | self.dsbnorm = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_), batchdim=0,
164 | domains=self.domains_,
165 | dispersion=self.bnorm_dispersion_,
166 | eta=1., eta_test=.1).to(self.device_)
167 | elif self.bnorm_ is not None:
168 | raise NotImplementedError('requested an undefined batch normalization method.')
169 |
170 | self.logarithm = torch.nn.Sequential(
171 | modules.MyLog(),
172 | torch.nn.Flatten(start_dim=1),
173 | )
174 | self.classifier = torch.nn.Sequential(
175 | torch.nn.Linear(self.spatial_filters_,self.nclasses_),
176 | ).to(self.device_)
177 |
178 | def forward(self, x, d, return_latent=True):
179 | out = ()
180 | h = self.cnn(x.to(device=self.device_)[:,None,...])
181 | C = self.cov_pooling(h)
182 | l = torch.diagonal(C, dim1=-2, dim2=-1)
183 | l = self.logarithm(l)
184 | l = self.bnorm(l) if hasattr(self, 'bnorm') else l
185 | l = self.dsbnorm(l,d) if hasattr(self, 'dsbnorm') else l
186 | out += (l,) if return_latent else ()
187 | y = self.classifier(l)
188 | out = y if len(out) == 0 else (y, *out[::-1])
189 | return out
190 |
191 | def domainadapt_finetune(self, x, y, d, target_domains):
192 | if hasattr(self, 'dsbnorm'):
193 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
194 |
195 | with torch.no_grad():
196 | for du in d.unique():
197 | self.forward(x[d==du], d[d==du])
198 |
199 | if hasattr(self, 'dsbnorm'):
200 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
201 |
202 | def finetune(self, x, y, d):
203 | if hasattr(self, 'bnorm'):
204 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT)
205 |
206 | with torch.no_grad():
207 | self.forward(x, d)
208 |
209 | if hasattr(self, 'bnorm'):
210 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)
211 |
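212 |
213 | if __name__ == '__main__':
214 |     # Hedged smoke test (illustrative sketch): instantiate the LieBN variant of
215 |     # TSMNet with made-up hyperparameters (not a released configuration).
216 |     model = TSMNet(temporal_filters=4, spatial_filters=40, subspacedims=20,
217 |                    bnorm='LieBN', bnorm_dispersion='SCALAR',
218 |                    metric='AIM', theta=1., alpha=1.0, beta=0.,
219 |                    nclasses=3, nchannels=32, nsamples=512,
220 |                    domains=torch.arange(5), device=torch.device('cpu'))
221 |     print(sum(p.numel() for p in model.parameters()), 'parameters')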
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.parameter import Parameter
6 | from torch.types import Number
7 | import torch.nn as nn
8 | from geoopt.tensor import ManifoldParameter
9 | from geoopt.manifolds import Stiefel, Sphere
10 | from . import functionals
11 |
12 | class Conv2dWithNormConstraint(torch.nn.Conv2d):
13 | def __init__(self, *args, max_norm=1, **kwargs):
14 | self.max_norm = max_norm
15 | super(Conv2dWithNormConstraint, self).__init__(*args, **kwargs)
16 |
17 | def forward(self, x):
18 | self.weight.data = torch.renorm(
19 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm
20 | )
21 | return super(Conv2dWithNormConstraint, self).forward(x)
22 |
23 |
24 | class LinearWithNormConstraint(torch.nn.Linear):
25 | def __init__(self, *args, max_norm=1, **kwargs):
26 | self.max_norm = max_norm
27 | super(LinearWithNormConstraint, self).__init__(*args, **kwargs)
28 |
29 | def forward(self, x):
30 | self.weight.data = torch.renorm(
31 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm
32 | )
33 | return super(LinearWithNormConstraint, self).forward(x)
34 |
35 |
36 | class MySquare(torch.nn.Module):
37 | def __init__(self):
38 | super().__init__()
39 | def forward(self, x):
40 | return x.square()
41 |
42 | class MyLog(torch.nn.Module):
43 | def __init__(self, eps = 1e-4):
44 | super().__init__()
45 | self.eps = eps
46 | def forward(self, x):
47 | return torch.log(x + self.eps)
48 |
49 |
50 | class MyConv2d(torch.nn.Conv2d):
51 | def __init__(self, *args, **kwargs):
52 | super(MyConv2d, self).__init__(*args, **kwargs)
53 |
54 | self.convshape = self.weight.shape
55 | w0 = self.weight.data.flatten(start_dim=1)
56 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere())
57 |
58 | def forward(self, x):
59 | return self._conv_forward(x, self.weight.view(self.convshape), self.bias)
60 |
61 |
62 | class UnitNormLinear(torch.nn.Linear):
63 | def __init__(self, *args, **kwargs):
64 | super(UnitNormLinear, self).__init__(*args, **kwargs)
65 |
66 | w0 = self.weight.data.flatten(start_dim=1)
67 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere())
68 |
69 | def forward(self, x):
70 | return super().forward(x)
71 |
72 |
73 | class MyLinear(nn.Module):
74 | def __init__(self, shape : Tuple[int, ...] or torch.Size, bias: bool = True, **kwargs):
75 | super().__init__()
76 |
77 | self.W = Parameter(torch.empty(shape, **kwargs))
78 |
79 | if bias:
80 | self.bias = Parameter(torch.empty((*shape[:-2], shape[-1]), **kwargs))
81 | else:
82 | self.register_parameter('bias', None)
83 |
84 | self.reset_parameters()
85 |
86 | def forward(self, X : Tensor) -> Tensor:
87 | A = (X.unsqueeze(-2) @ self.W).squeeze(-2)
88 | if self.bias is not None:
89 | A += self.bias
90 | return A
91 |
92 | @torch.no_grad()
93 | def reset_parameters(self):
94 | # Kaiming-style uniform initialization: bound = sqrt(3) * gain / sqrt(fan_in)
95 | bound = math.sqrt(3) * 1. / math.sqrt(self.W.shape[-2])
96 | self.W.data.uniform_(-bound,bound)
97 | if self.bias is not None:
98 | bound = 1 / math.sqrt(self.W.shape[-2])
99 | self.bias.data.uniform_(-bound, bound)
100 |
101 |
102 | class Encode2DPosition(nn.Module):
103 | """
104 | Encodes the 2D positions of a 2D CNN feature map or image
105 | as additional channels.
106 | Input: (batch, chans, height, width)
107 | Output: (batch, chans+2, height, width)
108 | """
109 | def __init__(self, flatten = True):
110 | super().__init__()
111 | self.flatten = flatten
112 |
113 | def forward(self, X : Tensor) -> Tensor:
114 | pos1 = torch.arange(X.shape[-2], device=X.device, dtype=X.dtype)[None,None,:,None].tile((X.shape[0],1, 1, X.shape[-1])) / X.shape[-2]
115 | pos2 = torch.arange(X.shape[-1], device=X.device, dtype=X.dtype)[None,None,None,:].tile((X.shape[0],1, X.shape[-2], 1)) / X.shape[-1]
116 |
117 | Z = torch.cat((X, pos1, pos2),dim=1)
118 | if self.flatten:
119 | Z = Z.flatten(start_dim=-2)
120 |
121 | return Z
122 |
123 |
124 | class CovariancePool(nn.Module):
125 | def __init__(self, alpha = None, unitvar = False):
126 | super().__init__()
127 | self.pooldim = -1
128 | self.chandim = -2
129 | self.alpha = alpha
130 | self.unitvar = unitvar
131 |
132 | def forward(self, X : Tensor) -> Tensor:
133 | X0 = X - X.mean(dim=self.pooldim, keepdim=True)
134 | if self.unitvar:
135 | X0 = X0 / X0.std(dim=self.pooldim, keepdim=True)
136 | X0.nan_to_num_(0)
137 |
138 | C = (X0 @ X0.transpose(-2, -1)) / X0.shape[self.pooldim]
139 | if self.alpha is not None:
140 | Cd = C.diagonal(dim1=self.pooldim, dim2=self.pooldim-1)
141 | Cd += self.alpha
142 | return C
143 |
144 |
145 | class ReverseGradient(nn.Module):
146 | def __init__(self, scaling = 1.):
147 | super().__init__()
148 | self.scaling_ = scaling
149 |
150 | def forward(self, X : Tensor) -> Tensor:
151 | return functionals.reverse_gradient.apply(X, self.scaling_)
152 |
153 |
154 | class BiMap(nn.Module):
155 | def __init__(self, shape : Tuple[int, ...] or torch.Size, W0 : Tensor = None, manifold='stiefel', **kwargs):
156 | super().__init__()
157 |
158 | if manifold == 'stiefel':
159 | assert(shape[-2] >= shape[-1])
160 | mf = Stiefel()
161 | elif manifold == 'sphere':
162 | mf = Sphere()
163 | shape = list(shape)
164 | shape[-1], shape[-2] = shape[-2], shape[-1]
165 | else:
166 | raise NotImplementedError()
167 |
168 | # add constraint (also initializes the parameter to fulfill the constraint)
169 | self.W = ManifoldParameter(torch.empty(shape, **kwargs), manifold=mf)
170 |
171 | # optionally initialize the weights (initialization has to fulfill the constraint!)
172 | if W0 is not None:
173 | self.W.data = W0 # e.g., self.W = torch.nn.init.orthogonal_(self.W)
174 | else:
175 | self.reset_parameters()
176 |
177 | def forward(self, X : Tensor) -> Tensor:
178 | if isinstance(self.W.manifold, Sphere):
179 | return self.W @ X @ self.W.transpose(-2,-1)
180 | else:
181 | return self.W.transpose(-2,-1) @ X @ self.W
182 |
183 | @torch.no_grad()
184 | def reset_parameters(self):
185 | if isinstance(self.W.manifold, Stiefel):
186 | # uniform initialization on the Stiefel manifold following Theorem 2.2.1 in Chikuse (2003): Statistics on Special Manifolds
187 | W = torch.rand(self.W.shape, dtype=self.W.dtype, device=self.W.device)
188 | self.W.data = W @ functionals.sym_invsqrtm.apply(W.transpose(-1,-2) @ W)
189 | elif isinstance(self.W.manifold, Sphere):
190 | W = torch.empty(self.W.shape, dtype=self.W.dtype, device=self.W.device)
191 | # Kaiming-style uniform initialization (uniform bound derived from the std)
192 | bound = math.sqrt(3) * 1. / W.shape[-1]
193 | W.uniform_(-bound, bound)
194 | # constraint has to be satisfied
195 | self.W.data = W / W.norm(dim=-1, keepdim=True)
196 | else:
197 | raise NotImplementedError()
198 |
199 |
200 | class ReEig(nn.Module):
201 | def __init__(self, threshold : Number = 1e-4):
202 | super().__init__()
203 | self.threshold = Tensor([threshold])
204 |
205 | def forward(self, X : Tensor) -> Tensor:
206 | return functionals.sym_reeig.apply(X, self.threshold)
207 |
208 |
209 | class LogEig(nn.Module):
210 | def __init__(self, ndim, tril=True):
211 | super().__init__()
212 |
213 | self.tril = tril
214 | if self.tril:
215 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1)
216 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long)
217 | self.ixs = torch.cat((ixs_diag[None,:].tile((2,1)), ixs_lower), dim=1)
218 | self.ndim = ndim
219 |
220 | def forward(self, X : Tensor) -> Tensor:
221 | return self.embed(functionals.sym_logm.apply(X))
222 |
223 | def embed(self, X : Tensor) -> Tensor:
224 | if self.tril:
225 | x_vec = X[...,self.ixs[0],self.ixs[1]]
226 | x_vec[...,self.ndim:] *= math.sqrt(2)
227 | else:
228 | x_vec = X.flatten(start_dim=-2)
229 | return x_vec
230 |
231 |
232 | class TrilEmbedder(nn.Module):
233 |
234 | def forward(self, X : Tensor) -> Tensor:
235 |
236 | ndim = X.shape[-1]
237 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1)
238 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long)
239 | ixs = torch.cat((ixs_diag[None,:].tile((2,1)), ixs_lower), dim=1)
240 |
241 | x_vec = X[...,ixs[0],ixs[1]]
242 | x_vec[...,ndim:] *= math.sqrt(2)
243 | return x_vec
244 |
245 | def inverse_transform(self, x_vec: Tensor) -> Tensor:
246 |
247 | ndim = int(-.5 + math.sqrt(.25 + 2*x_vec.shape[-1]))  # invert ndim*(ndim+1)/2 = x_vec.shape[-1]
248 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1)
249 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long)
250 |
251 | X = torch.zeros(x_vec.shape[:-1] + (ndim, ndim), device=x_vec.device, dtype=x_vec.dtype)
252 |
253 | # off diagonal elements
254 | X[...,ixs_lower[0],ixs_lower[1]] = x_vec[...,ndim:] / math.sqrt(2)
255 | X[...,ixs_lower[1],ixs_lower[0]] = x_vec[...,ndim:] / math.sqrt(2)
256 | X[...,ixs_diag,ixs_diag] = x_vec[...,:ndim]
257 |
258 | return X
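259 |
260 | if __name__ == '__main__':
261 |     # Hedged demo (illustrative sketch): chain the SPDNet-style layers above in
262 |     # the same order as models/tsmnet.py (BiMap -> ReEig -> LogEig); the shapes
263 |     # are assumptions.
264 |     bimap = BiMap((1, 40, 20), dtype=torch.double)  # maps 40x40 SPD to 20x20 SPD
265 |     reeig = ReEig(threshold=1e-4)                   # rectifies small eigenvalues
266 |     logeig = LogEig(20)                             # tangent-space vectorization
267 |     C = torch.eye(40, dtype=torch.double).expand(8, 40, 40)  # placeholder SPD batch
268 |     print(logeig(reeig(bimap(C))).shape)            # torch.Size([8, 210])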
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .logging import TrainLog
3 | from .network import DomainAdaptNeuralNetClassifier
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/network.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitZH-Chen/LieBN/ff2e6a5276360ab862b38e827f6640233cf665d2/LieBN_TSMNet/spdnets/utils/skorch/__pycache__/network.cpython-310.pyc
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from skorch.callbacks.logging import EpochTimer, PrintLog
3 |
4 | log = logging.getLogger(__name__)
5 |
6 | class TrainLog(PrintLog):
7 |
8 | def __init__(self, prefix='') -> None:
9 | super().__init__()
10 | self.prefix = prefix
11 |
12 | def initialize(self):
13 | return self
14 |
15 | def on_epoch_end(self, net, **kwargs):
16 | r = net.history[-1]
17 |
18 | if r['epoch'] == 1 or r['epoch'] % 10 == 0:
19 | log.info(f"{self.prefix} {r['epoch']:3d} : trn={r['train_loss']:.3f}/{r['score_trn']:.2f} val={r['valid_loss']:.3f}/{r['score_val']:.2f} time: {r['dur']:.2f}")
20 |
21 |
22 |
--------------------------------------------------------------------------------
/LieBN_TSMNet/spdnets/utils/skorch/network.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from skorch.classifier import NeuralNetClassifier
4 | from skorch.callbacks.logging import EpochTimer, PrintLog
5 | from skorch.callbacks.scoring import EpochScoring, PassthroughScoring
6 |
7 | from spdnets.models import BaseModel, DomainAdaptFineTuneableModel, FineTuneableModel
8 |
9 | from .logging import TrainLog
10 |
11 | log = logging.getLogger(__name__)
12 |
13 | class DomainAdaptNeuralNetClassifier(NeuralNetClassifier):
14 | def __init__(self, module, *args, criterion=torch.nn.CrossEntropyLoss, **kwargs):
15 | super().__init__(module, *args, criterion=criterion, **kwargs)
16 |
17 | @property
18 | def _default_callbacks(self):
19 | return [
20 | ('epoch_timer', EpochTimer()),
21 | ('train_loss', PassthroughScoring(
22 | name='train_loss',
23 | on_train=True,
24 | )),
25 | ('valid_loss', PassthroughScoring(
26 | name='valid_loss',
27 | )),
28 | ('print_log', TrainLog()),
29 | ]
30 |
31 | def get_loss(self, mdl_pred, y_true, X=None, **kwargs):
32 | if isinstance(self.module_, BaseModel):
33 | return self.module_.calculate_objective(mdl_pred, y_true, X)
34 | elif isinstance(mdl_pred, (list, tuple)):
35 | y_hat = mdl_pred[0]
36 | else:
37 | y_hat = mdl_pred
38 | return self.criterion_(y_hat, y_true.to(y_hat.device))
39 |
40 | def domainadapt_finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor, target_domains=None):
41 | if isinstance(self.module_, DomainAdaptFineTuneableModel):
42 | self.module_.domainadapt_finetune(x=x.to(self.device), y=y, d=d, target_domains=target_domains)
43 | else:
44 | log.info("Model does not support domain adapt fine tuning.")
45 |
46 | def finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor):
47 | if isinstance(self.module_, FineTuneableModel):
48 | self.module_.finetune(x=x.to(self.device), y=y, d=d)
49 | else:
50 | log.info("Model does not support fine-tuning.")
51 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | [arXiv](https://arxiv.org/abs/2403.11261)
2 | [OpenReview](https://openreview.net/forum?id=okYdj8Ysru)
3 | [PDF](https://openreview.net/pdf?id=okYdj8Ysru)
4 |
5 |
6 | # A Lie Group Approach to Riemannian Batch Normalization
7 |
8 | **Updates (02/2025):** We have integrated the LieBN implementations into a toolbox, now supporting nine invariant metrics across different matrix manifolds:
9 |
10 | - Symmetric Positive Definite (SPD) Manifold: Four distinct metrics, including a newly introduced right-invariant metric.
11 | - Rotation Group: One bi-invariant metric.
12 | - Full-Rank Correlation Manifold: Four recently developed correlation geometries.
13 |
14 | The complete implementations can be found in the `LieBN` folder. This toolbox is designed to be plug-and-play, making it easy to apply LieBN as a drop-in normalization module across different neural architectures.
15 |
16 | ## Introduction
17 |
18 | This is the official code for our ICLR 2024 publication: *A Lie Group Approach to Riemannian Batch Normalization*. [[OpenReview](https://openreview.net/forum?id=okYdj8Ysru)].
19 |
20 | If you find this project helpful, please consider citing us as follows:
21 |
22 | ```bib
23 | @inproceedings{chen2024liebn,
24 | title={A Lie Group Approach to Riemannian Batch Normalization},
25 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe},
26 | booktitle={The Twelfth International Conference on Learning Representations},
27 | year={2024},
28 | url={https://openreview.net/forum?id=okYdj8Ysru}
29 | }
30 | ```
31 |
32 | If you have any problems, do not hesitate to contact me at ziheng_ch@163.com.
33 |
34 | ## Requirements
35 |
36 | Install the necessary dependencies with `conda`:
37 |
38 | ```setup
39 | conda env create --file environment.yaml
40 | ```
41 | **Note** that the [hydra](https://hydra.cc/) package is used to manage configuration files.
42 |
43 | ## Experiments on the SPDNet
44 |
45 | The code for the experiments on SPDNet, SPDNetBN, and SPDNetLieBN is contained in the folder `./LieBN_SPDNet`.
46 |
47 | The implementation is based on the official code of *Riemannian batch normalization for SPD neural networks* [[Neurips 2019](https://papers.nips.cc/paper_files/paper/2019/hash/6e69ebbfad976d4637bb4b39de261bf7-Abstract.html)] [[code](https://papers.nips.cc/paper_files/paper/2019/file/6e69ebbfad976d4637bb4b39de261bf7-Supplemental.zip)].
48 |
49 | ### Dataset
50 |
51 | The synthetic [Radar](https://www.dropbox.com/s/dfnlx2bnyh3kjwy/data.zip?e=1&dl=0) dataset was released with SPDNetBN. We further release our preprocessed [HDM05](https://www.dropbox.com/scl/fi/x2ouxjwqj3zrb1idgkg2g/HDM05.zip?rlkey=4f90ktgzfz28x3i2i4ylu6dvu&dl=0) dataset.
52 |
53 | Please download the datasets, put them in a personal folder, and change the `path` accordingly in `./LieBN_SPDNet/conf/dataset/RADAR.yaml` and `./LieBN_SPDNet/conf/dataset/HDM05.yaml`.
54 |
55 | ### Running experiments
56 |
57 | To run all the experiments on the Radar and HDM05 datasets, go to the folder `LieBN_SPDNet` and run this command:
58 |
59 | ```train
60 | bash experiments.sh
61 | ```
62 | This script contains the experiments on the Radar and HDM05 datasets shown in Tab. 4.
63 |
64 | ## Experiments on the TSMNet
65 |
66 | The code for the experiments on TSMNet, TSMNet + SPDDSMBN, and TSMNet + DSMLieBN is contained in the folder `./LieBN_TSMNet`.
67 |
68 | The implementation is based on the official code of *SPD domain-specific batch normalization to crack interpretable unsupervised domain adaptation in EEG* [[Neurips 2022](https://openreview.net/forum?id=pp7onaiM4VB)] [[code](https://github.com/rkobler/TSMNet.git)].
69 |
70 | ### Dataset
71 |
72 | The [Hinss2021](https://doi.org/10.5281/zenodo.5055046) dataset is publicly available. The [moabb](https://neurotechx.github.io/moabb/) and [mne](https://mne.tools) packages are used to download and preprocess the dataset automatically, so there is no need to do this manually. If necessary, change the `data_dir` in `./LieBN_TSMNet/conf/LieBN.yaml` to your personal folder.
73 |
74 | ### Running experiments
75 |
76 | To run all the experiments on the Hinss2021 dataset, go to the folder `LieBN_TSMNet` and run this command:
77 |
78 | ```train
79 | bash experiments_Hinss21.sh
80 | ```
81 | This script contains the experiments on the Hinss2021 dataset shown in Tab. 5.
82 |
83 | **Note:** You can also change the `data_dir` in `experiments_Hinss21.sh`, which will override the hydra config.
84 |
85 |
86 |
87 |
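88 | ## Minimal LieBN usage sketch
89 |
90 | The snippet below is a minimal, illustrative sketch (not one of the released scripts) of applying the domain-specific SPD LieBN layer directly. It mirrors the constructor call in `LieBN_TSMNet/spdnets/models/tsmnet.py`; the batch size, matrix dimension, and domain labels are assumptions chosen for illustration. Run it from within `./LieBN_TSMNet`:
91 |
92 | ```python
93 | import torch
94 | import spdnets.Liebatchnorm as LieBN
95 | from spdnets.BaseBatchNorm import BatchNormDispersion
96 |
97 | n, batch = 20, 32          # SPD matrix size and batch size (illustrative)
98 | domains = torch.arange(4)  # domain labels (long dtype, as expected downstream)
99 |
100 | bn = LieBN.AdaMomDomainSPDLieBatchNorm(
101 |     (1, n, n), batchdim=0, domains=domains,
102 |     learn_mean=False, learn_std=True,
103 |     dispersion=BatchNormDispersion.SCALAR,
104 |     eta=1., eta_test=.1, dtype=torch.double, device=torch.device('cpu'),
105 |     metric='AIM', theta=torch.tensor(1.), alpha=torch.tensor(1.), beta=torch.tensor(0.))
106 |
107 | # random SPD batch: matrix exponentials of symmetric Gaussian matrices
108 | A = torch.randn(batch, n, n, dtype=torch.double)
109 | X = torch.matrix_exp(0.5 * (A + A.transpose(-1, -2)))
110 | d = domains[torch.randint(0, len(domains), (batch,))]
111 |
112 | Xn = bn(X, d)  # normalized SPD batch, same shape as X
113 | ```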
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: LieBN
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | # basic programs
7 | - python==3.8.*
8 | - pip
9 | # scientific python base packages
10 | - numpy==1.20.*
11 | - pandas==1.2.*
12 | - scipy==1.6.*
13 | # jupyter notebooks
14 | - ipykernel
15 | - notebook
16 | - jupyterlab
17 | - nb_conda_kernels
18 | # python visualization
19 | - matplotlib==3.4.*
20 | - seaborn==0.11.*
21 | # machine learning
22 | - scikit-learn==1.0.*
23 | - pytorch==2.2.*
24 | - torchvision==0.17.*
25 | - skorch==0.11.*
26 | - pip:
27 | # tensorboard
28 | - tensorboard==2.14.*
29 | # m/eeg analysis
30 | - mne==0.22.*
31 | - moabb==0.4.*
32 | # command line interfacing
33 | - hydra-core==1.3.*
34 | - hydra-joblib-launcher==1.2.*
35 | # machine learning
36 | - pyriemann==0.2.*
37 | - git+https://github.com/geoopt/geoopt.git@524330b11c0f9f6046bda59fe334803b4b74e13e
38 | # this package
39 | - -e .
40 |
--------------------------------------------------------------------------------