├── .DS_Store ├── CVPR24_SPDMLR_PPT.pdf ├── CVPR24_SPDMLR_Poster.pdf ├── Hyperplane ├── .DS_Store ├── SPD_hyperplane_3D.m ├── confPath.m ├── hyperplanes_all.jpg └── utilities │ ├── gen_hyperplane.m │ ├── print_SPD.m │ └── spd_power.m ├── README.md ├── SPDNet-MLR.py ├── TSMNet-MLR.py ├── conf ├── .DS_Store ├── SPDNet │ ├── SPDNetMLR.yaml │ ├── dataset │ │ ├── HDM05.yaml │ │ └── RADAR.yaml │ └── nnet │ │ └── SPDNet.yaml └── TSMNet │ ├── TSMNetMLR.yaml │ ├── dataset │ └── hinss2021.yaml │ ├── evaluation │ ├── inter-session+uda.yaml │ ├── inter-session.yaml │ ├── inter-subject+uda.yaml │ └── inter-subject.yaml │ ├── nnet │ ├── tsmnet.yaml │ └── tsmnet_spddsmbn.yaml │ └── preprocessing │ └── bb4-36Hz.yaml ├── datasets ├── __init__.py ├── eeg │ ├── eeg_utils.py │ └── moabb │ │ ├── __init__.py │ │ ├── base.py │ │ └── hinss2021.py └── spdnet │ ├── HDM05_Loader.py │ └── Radar_Loader.py ├── environment.yaml ├── exp_eeg.sh ├── exp_spdnets.sh ├── library ├── .DS_Store ├── __init__.py ├── __pycache__ │ └── __init__.cpython-310.pyc └── utils │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-310.pyc │ ├── hydra │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-310.pyc │ ├── moabb │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-310.pyc │ ├── pyriemann │ └── __init__.py │ └── torch │ ├── __init__.py │ └── __pycache__ │ └── __init__.cpython-310.pyc ├── setup.py ├── spd ├── __init__.py ├── functional.py └── spd_matrices.py └── spdnets ├── .DS_Store ├── SPDMLR.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── batchnorm.cpython-310.pyc ├── manifolds.cpython-310.pyc └── modules.cpython-310.pyc ├── batchnorm.py ├── cplx ├── functional.py └── nn.py ├── functionals.py ├── manifolds.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── base.cpython-310.pyc │ ├── dann.cpython-310.pyc │ ├── eegnet.cpython-310.pyc │ ├── shconvnet.cpython-310.pyc │ └── tsmnet.cpython-310.pyc ├── base.py ├── dann.py ├── eegnet.py ├── shconvnet.py ├── spdnet.py ├── tsmnet.py └── tsmnetMLR.py ├── modules.py ├── training ├── eeg_training.py └── spdnet_training.py └── utils ├── common_utils.py ├── skorch ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── logging.cpython-310.pyc │ └── network.cpython-310.pyc ├── logging.py └── network.py └── spdnet ├── Get_Model.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/.DS_Store -------------------------------------------------------------------------------- /CVPR24_SPDMLR_PPT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/CVPR24_SPDMLR_PPT.pdf -------------------------------------------------------------------------------- /CVPR24_SPDMLR_Poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/CVPR24_SPDMLR_Poster.pdf -------------------------------------------------------------------------------- /Hyperplane/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/Hyperplane/.DS_Store -------------------------------------------------------------------------------- /Hyperplane/SPD_hyperplane_3D.m: -------------------------------------------------------------------------------- 1 | clear 2 | clc 3 | confPath 4 | rng(0); 5 | 6 | dim=2; 7 | num=20; 8 | max_bound = 3; 9 | size = 10; 10 | fontsize = 14; 11 | Z = zeros(num,num); 12 | mask = find(tril(ones(2,2))>0); 13 | SPD_X=zeros(num,num); 14 | SPD_Y=zeros(num,num); 15 | SPD_Z=zeros(num,num); 16 | 17 | [tmp_X,tmp_Z] = meshgrid(linspace(0,max_bound,num),linspace(0,max_bound,num)); 18 | tmp_Y = sqrt(tmp_X.*tmp_Z); 19 | X_spd=[tmp_X,tmp_X]; 20 | Z_spd=[tmp_Z,tmp_Z]; 21 | Y_spd=[tmp_Y,-tmp_Y]; 22 | 23 | metric={'LEM','LCM'}; 24 | total_num=2; 25 | % % LEM 26 | ith=1; 27 | subplot(1,total_num,ith) 28 | % A_vec={[2,0.5,1],[1,1,2],[1,1,10]}; 29 | A_vec={[10,0,0],[1,1,0],[0,0,10]}; 30 | print_SPD(X_spd,Y_spd,Z_spd,size,A_vec,metric{ith},num,fontsize,max_bound) 31 | 32 | % LCM 33 | ith=2; 34 | subplot(1,total_num,ith) 35 | % A_vec={[1,1,0.5],[1,1,2],[1,1,100]}; 36 | A_vec={[0.5,0,1],[2,0,1],[100,0,1]}; 37 | print_SPD(X_spd,Y_spd,Z_spd,size,A_vec,metric{ith},num,fontsize,max_bound) 38 | 39 | hfig = gcf; 40 | figWidth = 16; % Width of one column in inches (adjust as needed) 41 | figHeight = 5; % Height of the figure in inches (adjust as needed) 42 | set(hfig, 'Units', 'centimeters','Position', [1, 1, figWidth, figHeight]); 43 | % set(hfig, 'PaperSize', [figWidth+1 figHeight]) 44 | 45 | % figWidth = 14; % 设置图片宽度 14 7 46 | % figHeight = 8.6; % 设置图片高度 8.6 4.3 47 | % set(hfig,'PaperUnits','centimeters'); % 图片尺寸所用单位 48 | % set(hfig,'PaperPosition',[0 0 figWidth figHeight]); 49 | % set(hfig, 'PaperSize', [figWidth+1 figHeight]); 50 | % fileout = ['hyperplane.']; % 输出图片的文件名 51 | fileout = append('hyperplanes_all','.'); 52 | % print(hfig,[fileout,'tif'],'-r600','-dtiff'); % 设置图片格式、分辨率 53 | print(hfig,[fileout,'jpg'],'-r600','-djpeg'); % 设置图片格式、分辨率 54 | % print(hfig,[fileout,'pdf'],'-r600','-dpdf'); % 设置图片格式、分辨率 55 | 56 | function [x,y,z] = geodesic(vx,vy,vz,t) 57 | %% geodesic starting at I, with velolcity of v (correspongding velolcity in sym) 58 | v=[vx,vy;vy,vz]; 59 | p=[0.5819,0.8811; 0.8811,-0.5819]; 60 | % p=[2.0329,3.0810; 3.0810,-2.0326]; 61 | x=zeros(length(t),1); 62 | y=zeros(length(t),1); 63 | z=zeros(length(t),1); 64 | for ith = 1:length(t) 65 | tmp = expm(p+v*t(ith)); 66 | x(ith)=tmp(1); 67 | y(ith)=tmp(2); 68 | z(ith)=tmp(4); 69 | end 70 | end -------------------------------------------------------------------------------- /Hyperplane/confPath.m: -------------------------------------------------------------------------------- 1 | addpath(pwd); 2 | 3 | cd utilities; 4 | addpath(genpath(pwd)); 5 | cd ..; 6 | -------------------------------------------------------------------------------- /Hyperplane/hyperplanes_all.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/Hyperplane/hyperplanes_all.jpg -------------------------------------------------------------------------------- /Hyperplane/utilities/gen_hyperplane.m: -------------------------------------------------------------------------------- 1 | function [X,Y,Z] = gen_hyperplane(num,max,A_vec,metric,varargin) 2 | %% corresponding to the hyperplane in sym, satisfying =0, with S'= log(S); 3 | %% generate hyperplane 4 | %% P=I,A=diag(A1,A4) for LEM, satisfying =0; 5 | %% P=diag{P1,P4},P=I,A=diag(A1,A4) for AIM 6 | %% P=I,A=diag(A1,A4),for \theta-EM; 7 | 8 | X = zeros(num,1); 9 | Y = zeros(num,1); 10 | Z = zeros(num,1); 11 | tmp_spd = zeros(2,2); 12 | for ith = 1:num 13 | tmp_spd=gen_2D_spd(max,A_vec,metric,varargin); 14 | while (~ is_in_cone(tmp_spd,max)) 15 | tmp_spd=gen_2D_spd(max,A_vec,metric,varargin); 16 | end 17 | X(ith) = tmp_spd(1); 18 | Y(ith) = tmp_spd(2); 19 | Z(ith) = tmp_spd(4); 20 | end 21 | end 22 | 23 | function [spd] = gen_2D_spd(max,A_vec,metric,varargin) 24 | identity = eye(2); 25 | if strcmp(metric,'LEM') || strcmp(metric,'AIM') 26 | sym = Cal_sym(max,A_vec); 27 | spd=expm(sym); 28 | elseif strcmp(metric,'EM') 29 | theta=varargin{1}{1}; 30 | tmp_spd=zeros(2,2); 31 | while (is_not_spd(tmp_spd)) 32 | sym = Cal_sym(max,A_vec); 33 | tmp_spd = sym+identity; 34 | end 35 | spd = spd_power(tmp_spd,1\theta); 36 | elseif strcmp(metric,'BWM') 37 | tmp_spd=zeros(2,2); 38 | while (is_not_spd(tmp_spd)) 39 | sym = Cal_sym(max,A_vec); 40 | tmp_spd = sym+identity; 41 | end 42 | spd = tmp_spd*tmp_spd; 43 | elseif strcmp(metric,'LCM') 44 | A_vec(1)=0.5*A_vec(1); A_vec(3)=0.5*A_vec(3); 45 | sym = Cal_sym(max,A_vec); 46 | sym(1,2)=0; 47 | sym(1,1)=exp(sym(1,1));sym(2,2)=exp(sym(2,2)); 48 | spd=sym * sym'; 49 | end 50 | end 51 | 52 | function [sym] = Cal_sym(max,A_vec) 53 | %% calculating S in =0 with A_vec=[A1,A2,A4] 54 | sym = zeros(2); 55 | S1 = max*rand();S2 = max*rand();S3 = max*rand(); 56 | A1=A_vec(1); A2=A_vec(2); A4=A_vec(3); 57 | if A4~=0 58 | S4=(-1/A4) * (S1*A1+2*S2*A2); 59 | else 60 | S4=max*rand(); 61 | if A1~=0 62 | S1=((-2*A2)/A1) * S2 ; 63 | else 64 | S2=0; 65 | end 66 | end 67 | sym(1) = S1;sym(2)=S2;sym(3)=S2;sym(4) = S4; 68 | end 69 | 70 | function result = is_in_cone(X,max) 71 | if abs(X(1))<=max && abs(X(4))<=max && abs(X(2))<=max 72 | result = true; 73 | else 74 | result = false; 75 | end 76 | end 77 | 78 | function [spd] = gen_2dim_spd_t2(max) 79 | sym = zeros(2,2); 80 | x_1 = max*rand(); 81 | x_2 = max*rand(); 82 | sym(1) = x_1; 83 | sym(4) = x_2; 84 | spd=expm(sym); 85 | end 86 | 87 | function [results] = is_not_spd(X) 88 | [U,S] = eig(X); 89 | results = any(diag(S) <= 0); 90 | end 91 | -------------------------------------------------------------------------------- /Hyperplane/utilities/print_SPD.m: -------------------------------------------------------------------------------- 1 | function print_SPD(X_spd,Y_spd,Z_spd,size,A_vec,metric,num,fontsize,max_bound,varargin) 2 | % generate SPD 3 | if strcmp(metric,'EM') 4 | theta=varargin{1}; 5 | [X_h,Y_h,Z_h] = gen_hyperplane(num*num*10,max_bound,A_vec,metric,theta(1)); 6 | [X_h2,Y_h2,Z_h2] = gen_hyperplane(num*num*10,max_bound,A_vec,metric,theta(2)); 7 | [X_h3,Y_h3,Z_h3] = gen_hyperplane(num*num*10,max_bound,A_vec,metric,theta(3)); 8 | else 9 | [X_h,Y_h,Z_h] = gen_hyperplane(num*num*10,max_bound,A_vec{1},metric,varargin); 10 | [X_h2,Y_h2,Z_h2] = gen_hyperplane(num*num*10,max_bound,A_vec{2},metric,varargin); 11 | [X_h3,Y_h3,Z_h3] = gen_hyperplane(num*num*10,max_bound,A_vec{3},metric,varargin); 12 | end 13 | % Print SPD 14 | scatter3(X_spd,Y_spd,Z_spd,size,'k','.') 15 | hold on 16 | scatter3(X_h,Y_h,Z_h,size,'b','.') 17 | scatter3(X_h2,Y_h2,Z_h2,size,'r','.') 18 | scatter3(X_h3,Y_h3,Z_h3,size,'y','.') 19 | % scatter3(X_h4,Y_h4,Z_h4,'g','.') 20 | 21 | xlabel('$x$','interpreter','latex') 22 | ylabel('$y$','interpreter','latex') 23 | zlabel('$z$','interpreter','latex') 24 | metric_name =get_metric_name(metric); 25 | title(metric_name,'interpreter','latex'); 26 | % legend('Boudary of SPD Manifolds','P=I, A=diag(1,0)','P=I, A=diag(1,1)','P=I, A=diag(1,100)') 27 | set(gca,'FontSize',fontsize); 28 | view(-60,30) 29 | 30 | % hfig = gcf; 31 | % figWidth = 7; % 设置图片宽度 14 7 32 | % figHeight = 4.3; % 设置图片高度 8.6 4.3 33 | % set(hfig,'PaperUnits','centimeters'); % 图片尺寸所用单位 34 | % set(hfig,'PaperPosition',[0 0 figWidth figHeight]); 35 | % set(hfig, 'PaperSize', [figWidth+1 figHeight]); 36 | % % fileout = ['hyperplane.']; % 输出图片的文件名 37 | % fileout = append('hyperplane_',metric,'.') 38 | % % print(hfig,[fileout,'tif'],'-r600','-dtiff'); % 设置图片格式、分辨率 39 | % view(-60,30) 40 | % print(hfig,[fileout,'pdf'],'-r600','-dpdf'); % 设置图片格式、分辨率 41 | end 42 | 43 | function [name]=get_metric_name(metric) 44 | if strcmp(metric,'LEM') || strcmp(metric,'AIM') || strcmp(metric,'EM') 45 | name = strcat('$(\alpha,\beta)$-',metric); 46 | else 47 | name = strcat('$(\theta)$-',metric); 48 | end 49 | end -------------------------------------------------------------------------------- /Hyperplane/utilities/spd_power.m: -------------------------------------------------------------------------------- 1 | function [X_power] = spd_power(X,theta) 2 | %% for SPD matrix 3 | if det(X)<=0 4 | error("Wrong SPD") 5 | end 6 | [U, S, V] = svd(X); 7 | S_power = S.^theta; 8 | X_power = U * S_power * U'; 9 | end -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [](https://arxiv.org/abs/2305.11288) 2 | 3 | [//]: # ([](https://openreview.net/forum?id=okYdj8Ysru)) 4 | 5 | [//]: # ([](https://openreview.net/pdf?id=okYdj8Ysru)) 6 | 7 | 8 | # Riemannian Multinomial Logistics Regression for SPD Neural Networks 9 | 10 | **Update (2025.01):** Please check our [RMLR](https://arxiv.org/abs/2409.19433) ([code](https://github.com/GitZH-Chen/RMLR.git)) for additional SPD implementations. 11 | 12 | 13 | 14 | This is the official code for our CVPR 2024 publication: *Riemannian Multinomial Logistics Regression for SPD Neural Networks*. 15 | 16 | If you find this project helpful, please consider citing us as follows: 17 | 18 | 19 | ```bib 20 | 21 | @inproceedings{chen2024spdmlr, 22 | title={Riemannian Multinomial Logistics Regression for {SPD} Neural Networks}, 23 | author={Ziheng Chen and Yue Song and Gaowen Liu and Ramana Rao Kompella and Xiaojun Wu and Nicu Sebe}, 24 | booktitle={Conference on Computer Vision and Pattern Recognition 2024}, 25 | year={2024} 26 | } 27 | ``` 28 | And also our ICLR24 paper on Riemannian normalization over Lie groups: 29 | ```bib 30 | @inproceedings{chen2024liebn, 31 | title={A Lie Group Approach to Riemannian Batch Normalization}, 32 | author={Ziheng Chen and Yue Song and Yunmei Liu and Nicu Sebe}, 33 | booktitle={The Twelfth International Conference on Learning Representations}, 34 | year={2024}, 35 | url={https://openreview.net/forum?id=okYdj8Ysru} 36 | } 37 | ``` 38 | 39 | In case you have any problem, do not hesitate to contact me ziheng_ch@163.com. 40 | 41 | ## Requirements 42 | 43 | Install necessary dependencies by `conda`: 44 | 45 | ```setup 46 | conda env create --file environment.yaml 47 | ``` 48 | 49 | **Note** that the [hydra](https://hydra.cc/) package is used to manage configuration files. 50 | 51 | ## Experiments 52 | 53 | The implementation is based on the official code of 54 | 55 | - *Riemannian batch normalization for SPD neural networks* [[Neurips 2019](https://papers.nips.cc/paper_files/paper/2019/hash/6e69ebbfad976d4637bb4b39de261bf7-Abstract.html)] [[code](https://papers.nips.cc/paper_files/paper/2019/file/6e69ebbfad976d4637bb4b39de261bf7-Supplemental.zip)]. 56 | - *SPD domain-specific batch normalization to crack interpretable unsupervised domain adaptation in EEG* [[Neurips 2022](https://openreview.net/forum?id=pp7onaiM4VB)] [[code](https://github.com/rkobler/TSMNet.git)]. 57 | 58 | ### Dataset 59 | 60 | The synthetic [Radar](https://www.dropbox.com/s/dfnlx2bnyh3kjwy/data.zip?e=1&dl=0) dataset is released by SPDNetBN. We further release our preprocessed [HDM05](https://www.dropbox.com/scl/fi/x2ouxjwqj3zrb1idgkg2g/HDM05.zip?rlkey=4f90ktgzfz28x3i2i4ylu6dvu&dl=0) dataset. 61 | 62 | The [Hinss2021](https://doi.org/10.5281/zenodo.5055046) dataset is publicly available. 63 | The [moabb](https://neurotechx.github.io/moabb/) and [mne](https://mne.tools) packages are used to download and preprocess these datasets. 64 | There is no need to manually download and preprocess the datasets. 65 | This is done automatically. 66 | 67 | Please download the datasets and put them in your personal folder. 68 | If necessary, change the `path` accordingly in 69 | `conf/SPDNet/dataset/HDM05.yaml`, `conf/SPDNet/dataset/RADAR.yaml`, and `data_dir` in `conf/TSMNet/TSMNetMLR.yaml`. 70 | 71 | ### Running experiments 72 | 73 | To run experiments on the SPDNet, run this command: 74 | 75 | ```train 76 | bash exp_spdnets.sh 77 | ``` 78 | To run experiments on the TSMNet, run this command: 79 | ```train 80 | bash exp_eeg.sh 81 | ``` 82 | 83 | These scripts contain the experiments shown in Tabs. 3-5. 84 | 85 | **Note:** You also can change the `data_dir` in `exp_eeg.sh` or `xx_path` in `exp_spdnets.sh`, which will override the hydra config. 86 | 87 | To reproduce Fig. 1, please run `Hyperplane/SPD_hyperplane_3D.m` 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /SPDNet-MLR.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from spdnets.training.spdnet_training import training 5 | 6 | class Args: 7 | """ a Struct Class """ 8 | pass 9 | args=Args() 10 | args.config_name='SPDNetMLR.yaml' 11 | 12 | @hydra.main(config_path='./conf/SPDNet/', config_name=args.config_name, version_base='1.1') 13 | def main(cfg: DictConfig): 14 | training(cfg,args) 15 | 16 | if __name__ == "__main__": 17 | main() -------------------------------------------------------------------------------- /TSMNet-MLR.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | import warnings 4 | from sklearn.exceptions import FitFailedWarning, ConvergenceWarning 5 | from library.utils.hydra import hydra_helpers 6 | 7 | from spdnets.training.eeg_training import training 8 | 9 | warnings.filterwarnings("ignore", category=FitFailedWarning) 10 | warnings.filterwarnings("ignore", category=ConvergenceWarning) 11 | warnings.filterwarnings("ignore", category=FutureWarning) 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | warnings.filterwarnings("ignore", category=RuntimeWarning) 14 | 15 | class Args: 16 | """ a Struct Class """ 17 | pass 18 | args=Args() 19 | args.config_name='TSMNetMLR.yaml' 20 | args.architecture='[40,20]' 21 | 22 | @hydra_helpers 23 | @hydra.main(config_path='./conf/TSMNet/', config_name=args.config_name, version_base='1.1') 24 | def main(cfg: DictConfig): 25 | training(cfg,args) 26 | 27 | if __name__ == '__main__': 28 | 29 | main() 30 | -------------------------------------------------------------------------------- /conf/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/conf/.DS_Store -------------------------------------------------------------------------------- /conf/SPDNet/SPDNetMLR.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - nnet: SPDNet 4 | - dataset: RADAR # HDM05, RADAR 5 | - override hydra/launcher: joblib 6 | fit: 7 | epochs: 200 8 | batch_size: 30 9 | threadnum: 2 10 | is_writer: True 11 | cycle: 1 12 | seed: 1024 13 | is_save: True 14 | 15 | hydra: 16 | run: 17 | dir: ./outputs/${dataset.name} 18 | sweep: 19 | dir: ./outputs/${dataset.name} 20 | subdir: '.' 21 | launcher: 22 | n_jobs: -1 23 | job_logging: 24 | handlers: 25 | file: 26 | class: logging.FileHandler 27 | filename: default.log 28 | -------------------------------------------------------------------------------- /conf/SPDNet/dataset/HDM05.yaml: -------------------------------------------------------------------------------- 1 | name: HDM05 2 | class_num: 117 3 | path: /data #change this to your data folder 4 | -------------------------------------------------------------------------------- /conf/SPDNet/dataset/RADAR.yaml: -------------------------------------------------------------------------------- 1 | name: RADAR 2 | class_num: 3 3 | path: /data #change this to your data folder -------------------------------------------------------------------------------- /conf/SPDNet/nnet/SPDNet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: SPDNet 3 | init_mode: svd #uniform,svd 4 | bimap_manifold: stiefel #euclidean, stiefel 5 | architecture: [20,16,8] 6 | classifier: LogEigMLR #LogEigMLR, SPDMLR 7 | metric: SPDLogCholeskyMetric #SPDLogEuclideanMetric,SPDLogCholeskyMetric 8 | power: 1. 9 | alpha: 1.0 10 | beta: 0.0 11 | optimizer: 12 | mode: AMSGRAD #AMSGRAD,SGD,ADAM 13 | lr: 1e-2 14 | weight_decay: 0 -------------------------------------------------------------------------------- /conf/TSMNet/TSMNetMLR.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dataset: hinss2021 4 | - evaluation: inter-session+uda 5 | - preprocessing: bb4-36Hz 6 | - nnet: tsmnet_spddsmbn 7 | - override hydra/launcher: joblib 8 | 9 | fit: 10 | stratified: True 11 | epochs: 50 12 | batch_size_train: 50 13 | domains_per_batch: 5 14 | batch_size_test: -1 15 | validation_size: 0.2 #0.1 # float <1 for fraction; int for specific number 16 | test_size: 0.05 # percent of groups/domains used for testing 17 | 18 | score: balanced_accuracy # sklearn scores 19 | device: GPU 20 | threadnum: 2 21 | data_dir: /data #change this to your data folder 22 | is_debug: False 23 | seed: 42 24 | is_timing: False 25 | 26 | hydra: 27 | run: 28 | dir: outputs/${dataset.name}/${evaluation.strategy} 29 | sweep: 30 | dir: outputs/${dataset.name}/${evaluation.strategy} 31 | subdir: '.' 32 | launcher: 33 | n_jobs: -1 34 | job_logging: 35 | handlers: 36 | file: 37 | class: logging.FileHandler 38 | filename: default.log 39 | saving_model: 40 | is_save: False -------------------------------------------------------------------------------- /conf/TSMNet/dataset/hinss2021.yaml: -------------------------------------------------------------------------------- 1 | # name of the dataset 2 | name: Hinss2021 3 | # python type (and parameters) 4 | type: 5 | _target_: datasets.eeg.moabb.Hinss2021 6 | 7 | classes: ["easy", "medium", "difficult"] 8 | # channel selection 9 | channels: ['FP1', 'FP2', 'FPz', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'F2', 'F4', 'F6', 'F8', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C3', 'C4', 'CPz', 'PO3', 'PO4', 'POz', 'Oz' ] 10 | # if 'null' or not defined, all available channels will be used 11 | # resampling 12 | # if 'null' or not defined, the datasets sampling frequency will be used 13 | # resample: 250 # Hz 14 | ## epoching (relative to TASK CUE onset, as defined in the dataset) 15 | tmin: 0.0 16 | tmax: 1.996 17 | -------------------------------------------------------------------------------- /conf/TSMNet/evaluation/inter-session+uda.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - inter-session 3 | - _self_ 4 | adapt: 5 | name: uda 6 | nadapt_domain: 1. # int -> absolute number of observations per CLASS -------------------------------------------------------------------------------- /conf/TSMNet/evaluation/inter-session.yaml: -------------------------------------------------------------------------------- 1 | strategy: inter-session 2 | adapt: 3 | name: 'no' -------------------------------------------------------------------------------- /conf/TSMNet/evaluation/inter-subject+uda.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - inter-subject 3 | - _self_ 4 | adapt: 5 | name: uda 6 | nadapt_domain: 1. # int -> absolute number of observations per CLASS -------------------------------------------------------------------------------- /conf/TSMNet/evaluation/inter-subject.yaml: -------------------------------------------------------------------------------- 1 | strategy: inter-subject 2 | adapt: 3 | name: 'no' -------------------------------------------------------------------------------- /conf/TSMNet/nnet/tsmnet.yaml: -------------------------------------------------------------------------------- 1 | name: TSMNet 2 | inputtype: ${torchdtype:float32} 3 | model: 4 | _target_: spdnets.models.TSMNetMLR 5 | temporal_filters: 4 6 | spatial_filters: 40 7 | subspacedims: 20 8 | bnorm: null 9 | bnorm_dispersion: null 10 | classifier: LogEigMLR #LogEigMLR, SPDMLR 11 | metric: SPDLogCholeskyMetric #SPDLogEuclideanMetric,SPDLogCholeskyMetric 12 | power: 1. 13 | alpha: 1.0 14 | beta: 0.0 15 | optimizer: 16 | _target_: geoopt.optim.RiemannianAdam 17 | amsgrad: True 18 | weight_decay: 1e-4 19 | lr: 1e-3 20 | param_groups: 21 | - 22 | - 'spdnet.*.W' 23 | - weight_decay: 0 24 | scheduler: 25 | _target_: spdnets.batchnorm.DummyScheduler -------------------------------------------------------------------------------- /conf/TSMNet/nnet/tsmnet_spddsmbn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tsmnet 3 | - _self_ 4 | name: TSMNet+SPDDSMBN 5 | model: 6 | bnorm: spddsbn 7 | bnorm_dispersion: SCALAR 8 | scheduler: 9 | _target_: spdnets.batchnorm.MomentumBatchNormScheduler 10 | epochs: ${sub:${fit.epochs},10} 11 | bs: ${rdiv:${fit.batch_size_train},${fit.domains_per_batch}} 12 | bs0: ${fit.batch_size_train} 13 | tau0: 0.85 14 | 15 | optimizer: 16 | param_groups: 17 | - 18 | - 'spd*.mean' 19 | - weight_decay: 0 20 | - 21 | - 'spdnet.*.W' 22 | - weight_decay: 0 -------------------------------------------------------------------------------- /conf/TSMNet/preprocessing/bb4-36Hz.yaml: -------------------------------------------------------------------------------- 1 | bb4-36Hz: 2 | _target_: library.utils.moabb.CachedMotorImagery 3 | fmin: 4 # Hz 4 | fmax: 36 # Hz 5 | events: ${dataset.classes} 6 | channels: ${oc.select:dataset.channels} 7 | resample: ${oc.select:dataset.resample} 8 | tmin: ${oc.select:dataset.tmin, 0.0} 9 | tmax: ${oc.select:dataset.tmax} 10 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/eeg/eeg_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import fcntl 4 | 5 | import os 6 | from time import time 7 | from hydra.core.hydra_config import HydraConfig 8 | import pandas as pd 9 | from skorch.callbacks.scoring import EpochScoring 10 | from skorch.dataset import ValidSplit 11 | from skorch.callbacks import Checkpoint 12 | import torch as th 13 | 14 | import logging 15 | import hydra 16 | import torch 17 | import numpy as np 18 | from omegaconf import DictConfig, OmegaConf, open_dict 19 | 20 | import moabb 21 | from sklearn.metrics import get_scorer, make_scorer 22 | from sklearn.model_selection import StratifiedShuffleSplit, GroupKFold 23 | from library.utils.moabb import CachedParadigm 24 | from spdnets.models import DomainAdaptBaseModel, DomainAdaptJointTrainableModel, EEGNetv4 25 | from spdnets.models import CPUModel 26 | from library.utils.torch import BalancedDomainDataLoader, CombinedDomainDataset, DomainIndex, StratifiedDomainDataLoader 27 | from spdnets.models.base import DomainAdaptFineTuneableModel, FineTuneableModel 28 | 29 | from spdnets.utils.skorch import DomainAdaptNeuralNetClassifier 30 | import mne 31 | 32 | from spdnets.utils.common_utils import set_seed_thread 33 | 34 | def training(cfg,args): 35 | data_dir = cfg.data_dir 36 | 37 | mne.set_config("MNE_DATA", data_dir) 38 | mne.set_config("MNEDATASET_TMP_DIR", data_dir) 39 | mne.set_config("_MNE_FAKE_HOME_DIR", data_dir) 40 | args.threadnum = cfg.threadnum 41 | args.is_debug = cfg.is_debug 42 | args.seed = cfg.seed 43 | set_seed_thread(args.seed, args.threadnum) 44 | 45 | args.name = cfg.nnet.name 46 | args.classifier = cfg.nnet.model.classifier 47 | args.metric = cfg.nnet.model.metric 48 | args.power = cfg.nnet.model.power 49 | args.alpha = cfg.nnet.model.alpha 50 | args.beta = cfg.nnet.model.beta 51 | 52 | args.optimiz = 'AMSGRAD' if cfg.nnet.optimizer.amsgrad else 'ADAM' 53 | # cfg.nnet.optimizer.amsgrad = True if args.optimiz == 'AMSGRAD' else False 54 | args.lr = cfg.nnet.optimizer.lr 55 | args.weight_decay = cfg.nnet.optimizer.weight_decay 56 | 57 | args.model_name = get_model_name(args) 58 | rng_seed = args.seed 59 | log = logging.getLogger(args.model_name) 60 | 61 | moabb.set_log_level("info") 62 | 63 | # setting device 64 | if cfg.device =='CPU': 65 | device = torch.device('cpu') 66 | elif cfg.device == 'GPU': 67 | gpuid = f"cuda:{HydraConfig.get().job.get('num', 0) % torch.cuda.device_count()}" 68 | # log.info(f"GPU ID: {gpuid}") 69 | device = torch.device(gpuid) 70 | elif 0 <= cfg.device and cfg.device<= th.cuda.device_count(): 71 | device = torch.device(cfg.device) 72 | else: 73 | log.info('Wrong device or not available') 74 | log.info(f"device: {device}") 75 | cpu = torch.device('cpu') 76 | 77 | with open_dict(cfg): 78 | if 'ft_pipeline' not in cfg.nnet: 79 | cfg.nnet.ft_pipeline = None 80 | if 'prep_pipeline' not in cfg.nnet: 81 | cfg.nnet.prep_pipeline = None 82 | 83 | dataset = hydra.utils.instantiate(cfg.dataset.type, _convert_='partial') 84 | ppreprocessing_dict = hydra.utils.instantiate(cfg.preprocessing, _convert_='partial') 85 | assert (len(ppreprocessing_dict) == 1) # only 1 paradigm is allowed per call 86 | prep_name, paradigm = next(iter(ppreprocessing_dict.items())) 87 | 88 | res_dir = os.path.join(cfg.evaluation.strategy, prep_name) 89 | if not os.path.exists(res_dir): 90 | os.makedirs(res_dir) 91 | 92 | results = pd.DataFrame( \ 93 | columns=['dataset', 'subject', 'session', 'method', 'score_trn', 'score_tst', 94 | 'time', 'n_test', 'classes']) 95 | resix = 0 96 | results['score_trn'] = results['score_trn'].astype(np.double) 97 | results['score_tst'] = results['score_tst'].astype(np.double) 98 | results['time'] = results['time'].astype(np.double) 99 | results['n_test'] = results['n_test'].astype(int) 100 | results['classes'] = results['classes'].astype(int) 101 | 102 | results_fit = [] 103 | 104 | scorefun = get_scorer(cfg.score)._score_func 105 | 106 | def masked_scorefun(y_true, y_pred, **kwargs): 107 | masked = y_true == -1 108 | if np.all(masked): 109 | log.warning('Nothing to score because all target values are masked (value = -1).') 110 | return np.nan 111 | return scorefun(y_true[~masked], y_pred[~masked], **kwargs) 112 | 113 | scorer = make_scorer(masked_scorefun) 114 | 115 | dadapt = cfg.evaluation.adapt 116 | 117 | bacc_val_logger = EpochScoring(scoring=scorer, 118 | lower_is_better=False, 119 | on_train=False, 120 | name='score_val') 121 | bacc_trn_logger = EpochScoring(scoring=scorer, 122 | lower_is_better=False, 123 | on_train=True, 124 | name='score_trn') 125 | 126 | if 'inter-session' in cfg.evaluation.strategy: 127 | subset_iter = iter([[s] for s in dataset.subject_list]) 128 | groupvarname = 'session' 129 | elif 'inter-subject' in cfg.evaluation.strategy: 130 | subset_iter = iter([None]) 131 | groupvarname = 'subject' 132 | else: 133 | raise NotImplementedError() 134 | if args.is_debug: 135 | subset_iter = iter([[1]]) 136 | number=0 137 | for subset in subset_iter: 138 | 139 | if groupvarname == 'session': 140 | domain_expression = "session" 141 | elif groupvarname == 'subject': 142 | domain_expression = "session + subject * 1000" 143 | 144 | selected_sessions = cfg.dataset.get("sessions", None) 145 | 146 | ds = CombinedDomainDataset.from_moabb(paradigm, dataset, subjects=subset, domain_expression=domain_expression, 147 | dtype=cfg.nnet.inputtype, sessions=selected_sessions) 148 | 149 | if cfg.nnet.prep_pipeline is not None: 150 | ds = ds.cache() # we need to load the entire dataset for preprocessing 151 | 152 | sessions = ds.metadata.session.astype(np.int64).values 153 | subjects = ds.metadata.subject.astype(np.int64).values 154 | 155 | g = ds.metadata[groupvarname].astype(np.int64).values 156 | groups = np.unique(g) 157 | 158 | domains = ds.domains.unique() 159 | 160 | n_classes = len(ds.labels.unique()) 161 | 162 | if len(groups) < 2: 163 | log.warning( 164 | f"Insufficient number (n={len(groups)}) of groups ({groupvarname}) in the (sub-)dataset to run leave 1 group out CV!") 165 | continue 166 | 167 | mdl_kwargs = dict(nclasses=n_classes) 168 | 169 | mdl_kwargs['nchannels'] = ds.shape[1] 170 | mdl_kwargs['nsamples'] = ds.shape[2] 171 | mdl_kwargs['nbands'] = ds.shape[3] if ds.ndim == 4 else 1 172 | mdl_kwargs['input_shape'] = (1,) + ds.shape[1:] 173 | 174 | mdl_dict = OmegaConf.to_container(cfg.nnet.model, resolve=True) 175 | mdl_class = hydra.utils.get_class(mdl_dict.pop('_target_')) 176 | 177 | if issubclass(mdl_class, DomainAdaptBaseModel): 178 | mdl_kwargs['domains'] = domains 179 | if issubclass(mdl_class, EEGNetv4): 180 | if isinstance(paradigm, CachedParadigm): 181 | info = paradigm.get_info(dataset) 182 | mdl_kwargs['srate'] = int(info['sfreq']) 183 | else: 184 | raise NotImplementedError() 185 | if issubclass(mdl_class, FineTuneableModel) and isinstance(ds, CombinedDomainDataset): 186 | # we need to load the entire dataset 187 | ds = ds.cache() 188 | 189 | mdl_kwargs = {**mdl_kwargs, **mdl_dict} 190 | 191 | optim_kwargs = OmegaConf.to_container(cfg.nnet.optimizer, resolve=True) 192 | optim_class = hydra.utils.get_class(optim_kwargs.pop('_target_')) 193 | 194 | metaddata = { 195 | 'model_class': mdl_class, 196 | 'model_kwargs': mdl_kwargs, 197 | 'optim_class': optim_class, 198 | 'optim_kwargs': optim_kwargs 199 | } 200 | if cfg.saving_model.is_save: 201 | mdl_metadata_dir = os.path.join(res_dir, 'metadata') 202 | if not os.path.exists(mdl_metadata_dir): 203 | os.makedirs(mdl_metadata_dir) 204 | torch.save(metaddata, f=os.path.join(mdl_metadata_dir, f'meta-{cfg.nnet.name}.pth')) 205 | with open(os.path.join(mdl_metadata_dir, f'config-{cfg.nnet.name}.yaml'), 'w+') as f: 206 | f.writelines(OmegaConf.to_yaml(cfg)) 207 | 208 | if issubclass(mdl_class, CPUModel): 209 | device = cpu 210 | 211 | mdl_kwargs['device'] = device 212 | 213 | n_test_groups = int(np.clip(np.round(len(groups) * cfg.fit.test_size), 1, None)) 214 | 215 | log.info(f"Performing leave {n_test_groups} (={cfg.fit.test_size * 100:.0f}%) {groupvarname}(s) out CV") 216 | cv = GroupKFold(n_splits=int(len(groups) / n_test_groups)) 217 | number = number + len(ds.labels) 218 | print(f"number: {number}") 219 | ds.eval() # unmask labels 220 | for train, test in cv.split(ds.labels, ds.labels, g): 221 | 222 | target_domains = ds.domains[test].unique().numpy() 223 | torch.manual_seed(rng_seed + target_domains[0]) 224 | 225 | prep_pipeline = hydra.utils.instantiate(cfg.nnet.prep_pipeline, _convert_='partial') 226 | ft_pipeline = hydra.utils.instantiate(cfg.nnet.ft_pipeline, _convert_='partial') 227 | 228 | if dadapt is not None and dadapt.name != 'no': 229 | # extend training data with adaptation set 230 | if issubclass(mdl_class, DomainAdaptJointTrainableModel): 231 | stratvar = ds.labels + ds.domains * n_classes 232 | adapt_domain = test # extract_adapt_idxs(dadapt.nadapt_domain, test, stratvar) 233 | else: 234 | # some nets to not require target domain data during training 235 | adapt_domain = np.array([], dtype=np.int64) 236 | log.info("Model does not require adaptation. Using original training data.") 237 | 238 | train_source_doms = train 239 | train = np.concatenate((train, adapt_domain)) 240 | 241 | if dadapt.name == 'uda': 242 | ds.set_masked_labels(adapt_domain) 243 | elif dadapt.name == 'sda': 244 | test = np.setdiff1d(test, adapt_domain) 245 | 246 | if len(test) == 0: 247 | raise ValueError('No data left in the test set!') 248 | else: 249 | train_source_doms = train 250 | 251 | test_groups = np.unique(g[test]) 252 | test_group_list = [] 253 | for test_group in test_groups: 254 | test_dict = {} 255 | subject = np.unique(subjects[g == test_group]) 256 | assert (len(subject) == 1) # only one subject per group 257 | test_dict['subject'] = subject[0] 258 | if groupvarname == 'subject': 259 | test_dict['session'] = -1 260 | else: 261 | session = np.unique(sessions[g == test_group]) 262 | assert (len(session) == 1) # only one session per group 263 | test_dict['session'] = session[0] 264 | test_dict['idxs'] = np.intersect1d(test, np.nonzero(g == test_group)) 265 | test_group_list.append(test_dict) 266 | 267 | t_start = time() 268 | 269 | ## preprocessing 270 | dsprep = ds.copy(deep=False) 271 | dsprep.train() # mask labels 272 | 273 | if prep_pipeline is not None: 274 | prep_pipeline.fit(dsprep.features[train].numpy(), dsprep.labels[train]) 275 | dsprep.set_features(prep_pipeline.transform(dsprep.features)) 276 | 277 | # torch dataset generation 278 | batch_size_valid = cfg.fit.validation_size if type(cfg.fit.validation_size) == int else int( 279 | np.ceil(cfg.fit.validation_size * len(train))) 280 | 281 | dsprep.eval() # unmask labels 282 | # extract stratified (classes and groups) validation data 283 | stratvar = dsprep.labels[train] + dsprep.domains[train] * n_classes 284 | valid_cv = ValidSplit(iter(StratifiedShuffleSplit(n_splits=1, test_size=cfg.fit.validation_size, 285 | random_state=rng_seed + target_domains[0]).split(stratvar, 286 | stratvar))) 287 | 288 | netkwargs = {'module__' + k: v for k, v in mdl_kwargs.items()} 289 | netkwargs = {**netkwargs, **{'optimizer__' + k: v for k, v in optim_kwargs.items()}} 290 | if cfg.fit.stratified: 291 | 292 | n_train_domains = len(dsprep.domains[train].unique()) 293 | domains_per_batch = min(cfg.fit.domains_per_batch, n_train_domains) 294 | batch_size_train = int( 295 | max(np.round(cfg.fit.batch_size_train / domains_per_batch), 2) * domains_per_batch) 296 | 297 | netkwargs['iterator_train'] = StratifiedDomainDataLoader 298 | netkwargs['iterator_train__domains_per_batch'] = domains_per_batch 299 | netkwargs['iterator_train__shuffle'] = True 300 | netkwargs['iterator_train__batch_size'] = batch_size_train 301 | else: 302 | netkwargs['iterator_train'] = BalancedDomainDataLoader 303 | netkwargs['iterator_train__domains_per_batch'] = cfg.fit.domains_per_batch 304 | netkwargs['iterator_train__drop_last'] = True 305 | netkwargs['iterator_train__replacement'] = False 306 | netkwargs['iterator_train__batch_size'] = cfg.fit.batch_size_train 307 | netkwargs['iterator_valid__batch_size'] = batch_size_valid 308 | netkwargs['max_epochs'] = cfg.fit.epochs 309 | netkwargs[ 310 | 'callbacks__print_log__prefix'] = f'{dataset.code} {n_classes}cl | {test_groups} | {args.model_name} :' 311 | 312 | scheduler = hydra.utils.instantiate(cfg.nnet.scheduler, _convert_='partial') 313 | 314 | # save model 315 | if cfg.saving_model.is_save: 316 | mdl_path_tmp = os.path.join(res_dir, 'models', 'tmp', f'{test_groups}_{cfg.nnet.name}.pth') 317 | if not os.path.exists(os.path.split(mdl_path_tmp)[0]): 318 | os.makedirs(os.path.split(mdl_path_tmp)[0]) 319 | checkpoint = Checkpoint( 320 | f_params=mdl_path_tmp, f_criterion=None, f_optimizer=None, f_history=None, 321 | monitor='valid_loss_best', load_best=True) 322 | net = DomainAdaptNeuralNetClassifier( 323 | mdl_class, 324 | train_split=valid_cv, 325 | callbacks=[bacc_trn_logger, bacc_val_logger, scheduler, checkpoint], 326 | optimizer=optim_class, 327 | verbose=0, 328 | device=device, 329 | **netkwargs) 330 | else: 331 | net = DomainAdaptNeuralNetClassifier( 332 | mdl_class, 333 | train_split=valid_cv, 334 | callbacks=[bacc_trn_logger, bacc_val_logger, scheduler], 335 | optimizer=optim_class, 336 | verbose=0, 337 | device=device, 338 | **netkwargs) 339 | 340 | dsprep.train() # mask labels 341 | dstrn = torch.utils.data.Subset(dsprep, train) 342 | net.fit(dstrn, None) 343 | 344 | res = pd.DataFrame(net.history) 345 | res = res.drop(res.filter(regex='.*batches|_best|_count').columns, axis=1) 346 | res = res.drop(res.filter(regex='event.*').columns, axis=1) 347 | res = res.rename(columns=dict(train_loss="loss_trn", valid_loss="loss_val", dur="time")) 348 | res['domains'] = str(test_groups) 349 | res['method'] = cfg.nnet.name 350 | res['dataset'] = dataset.code 351 | results_fit.append(res) 352 | if cfg.is_timing: 353 | time_epochs = res.time; 354 | log.info('{} average time: {:.2f} and average of smallest 5 time: {:.2f} in total {} epoch'.format( 355 | cfg.evaluation.strategy,\ 356 | np.mean(time_epochs[-5:]),np.mean(np.sort(time_epochs)[:5]),len(time_epochs))) 357 | return 358 | 359 | 360 | if cfg.evaluation.adapt.name == "uda": 361 | if isinstance(net.module_, DomainAdaptFineTuneableModel): 362 | dsprep.train() # mask target domain labels 363 | for du in dsprep.domains.unique(): 364 | domain_data = dsprep[DomainIndex(du.item())] 365 | net.module_.domainadapt_finetune(x=domain_data[0]['x'], y=domain_data[1], d=domain_data[0]['d'], 366 | target_domains=target_domains) 367 | elif cfg.evaluation.adapt.name == "no": 368 | if isinstance(net.module_, FineTuneableModel): 369 | dsprep.train() # mask target domain labels 370 | net.module_.finetune(x=dsprep.features[train], y=dsprep.labels[train], d=dsprep.domains[train]) 371 | 372 | duration = time() - t_start 373 | 374 | # save the final model 375 | if cfg.saving_model.is_save: 376 | for test_group in test_group_list: 377 | mdl_path = os.path.join(res_dir, 'models', f'{test_group["subject"]}', f'{test_group["session"]}', 378 | f'{cfg.nnet.name}.pth') 379 | if not os.path.exists(os.path.split(mdl_path)[0]): 380 | os.makedirs(os.path.split(mdl_path)[0]) 381 | net.save_params(f_params=mdl_path) 382 | 383 | ## evaluation 384 | dsprep.eval() # unmask target domain labels 385 | 386 | y_hat = np.empty(dsprep.labels.shape) 387 | # find out latent space dimensionality 388 | _, l0 = net.forward(dsprep[DomainIndex(dsprep.domains[0])][0]) 389 | l = np.empty((len(dsprep),) + l0.shape[1:]) 390 | 391 | for du in dsprep.domains.unique(): 392 | ixs = np.flatnonzero(dsprep.domains == du) 393 | domain_data = dsprep[DomainIndex(du)] 394 | 395 | y_hat_domain, l_domain, *_ = net.forward(domain_data[0]) 396 | y_hat_domain, l_domain = y_hat_domain.numpy().argmax(axis=1), l_domain.to(device=cpu).numpy() 397 | y_hat[ixs] = y_hat_domain 398 | l[ixs] = l_domain 399 | 400 | score_trn = scorefun(dsprep.labels[train_source_doms], y_hat[train_source_doms]) 401 | 402 | for test_group in test_group_list: 403 | score_tst = scorefun(dsprep.labels[test_group["idxs"]], y_hat[test_group["idxs"]]) 404 | 405 | res = pd.DataFrame({'dataset': dataset.code, 406 | 'subject': test_group["subject"], 407 | 'session': test_group["session"], 408 | 'method': cfg.nnet.name, 409 | 'score_trn': score_trn, 410 | 'score_tst': score_tst, 411 | 'time': duration, 412 | 'n_test': len(test), 413 | 'classes': n_classes}, index=[resix]) 414 | results = results.append(res) 415 | resix += 1 416 | r = res.iloc[0, :] 417 | log.info( 418 | f'{r.dataset} {r.classes}cl | {r.subject} | {r.session} : trn={r.score_trn:.2f} tst={r.score_tst:.2f} time={duration:.2f}') 419 | 420 | ## fine tuning 421 | if ft_pipeline is not None: 422 | # fitting 423 | dsprep.train() # mask target domain labels 424 | ft_pipeline.fit(l[train], dsprep.labels[train]) 425 | y_hat_ft = ft_pipeline.predict(l) 426 | 427 | # evaluation 428 | dsprep.eval() # unmask target domain labels 429 | ft_score_trn = scorefun(dsprep.labels[train_source_doms], y_hat_ft[train_source_doms]) 430 | 431 | for test_group in test_group_list: 432 | ft_score_tst = scorefun(dsprep.labels[test_group["idxs"]], y_hat_ft[test_group["idxs"]]) 433 | 434 | res = pd.DataFrame({'dataset': dataset.code, 435 | 'subject': test_group["subject"], 436 | 'session': test_group["session"], 437 | 'method': f'{cfg.nnet.name}+FT', 438 | 'score_trn': ft_score_trn, 439 | 'score_tst': ft_score_tst, 440 | 'time': duration, 441 | 'n_test': len(test), 442 | 'classes': n_classes}, index=[resix]) 443 | results = results.append(res) 444 | resix += 1 445 | r = res.iloc[0, :] 446 | log.info( 447 | f'{r.dataset} {r.classes}cl | {r.subject} | {r.session} | {r.method} : trn={r.score_trn:.2f} tst={r.score_tst:.2f}') 448 | 449 | if len(results_fit): 450 | results_fit = pd.concat(results_fit) 451 | 452 | results_fit['preprocessing'] = prep_name 453 | results_fit['evaluation'] = cfg.evaluation.strategy 454 | results_fit['adaptation'] = cfg.evaluation.adapt.name 455 | 456 | for method in results_fit['method'].unique(): 457 | method_res = results[results['method'] == method] 458 | results_fit.to_csv(os.path.join(res_dir, f'nnfitscores_{method}.csv'), index=False) 459 | 460 | if len(results) > 0: 461 | 462 | results['preprocessing'] = prep_name 463 | results['evaluation'] = cfg.evaluation.strategy 464 | results['adaptation'] = cfg.evaluation.adapt.name 465 | if cfg.saving_model.is_save: 466 | for method in results['method'].unique(): 467 | method_res = results[results['method'] == method] 468 | method_res.to_csv(os.path.join(res_dir, f'scores_{method}.csv'), index=False) 469 | tmp = results.groupby('method').agg(['mean', 'std']) 470 | column_labels = [('score_trn', 'mean'), ('score_trn', 'std'), \ 471 | ('score_tst', 'mean'), ('score_tst', 'std')] 472 | time_lables = [('time', 'mean'), ('time', 'std')] 473 | row_label = tmp.index.tolist()[0] 474 | # print(tmp) 475 | # log.info(tmp.loc[row_label, column_labels] * 100) 476 | # log.info(tmp.loc[row_label, time_lables]) 477 | final_results = "final results: score_trn: {}/{}, score_tst: {}/{}, time: {}/{}".format( \ 478 | tmp.loc[row_label, column_labels[0]] * 100, tmp.loc[row_label, column_labels[1]] * 100, \ 479 | tmp.loc[row_label, column_labels[2]] * 100, tmp.loc[row_label, column_labels[3]] * 100, \ 480 | tmp.loc[row_label, time_lables[0]], tmp.loc[row_label, time_lables[1]] 481 | ) 482 | log.info(final_results) 483 | 484 | log_filename = HydraConfig.get().job_logging.handlers.file.filename 485 | split_filename = log_filename.rsplit('.',1) 486 | final_filename = f"final_result_{split_filename[0]}.txt" 487 | final_file_path = os.path.join(os.getcwd(),final_filename) 488 | log.info("results file path: {}, and saving the results".format(final_file_path)) 489 | write_final_results(final_file_path, args.model_name+'_'+final_results) 490 | 491 | def get_model_name(args): 492 | if args.classifier == 'SPDMLR': 493 | if args.metric == 'SPDLogEuclideanMetric': 494 | description = f'{args.metric}-[{args.alpha},{args.beta:.4f}]' 495 | elif args.metric == 'SPDLogCholeskyMetric': 496 | description = f'{args.metric}-[{args.power}]' 497 | 498 | description = '-' + description + '-' 499 | elif args.classifier == 'LogEigMLR': 500 | description='' 501 | else: 502 | raise NotImplementedError 503 | 504 | name = f'{args.lr}-{args.name}{description}-{args.classifier}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}' 505 | return name 506 | 507 | def write_final_results(file_path,message): 508 | # Create a file lock 509 | with open(file_path, "a") as file: 510 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock 511 | 512 | # Write the message to the file 513 | file.write(message + "\n") 514 | 515 | fcntl.flock(file.fileno(), fcntl.LOCK_UN) # Release the lock -------------------------------------------------------------------------------- /datasets/eeg/moabb/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import PreprocessedDataset, CachableDatase 2 | from .hinss2021 import Hinss2021 -------------------------------------------------------------------------------- /datasets/eeg/moabb/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import mne 3 | import json 4 | 5 | from moabb.datasets.base import BaseDataset 6 | 7 | class CachableDatase(BaseDataset): 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | def __repr__(self) -> str: 13 | return json.dumps({self.__class__.__name__: self.__dict__}) 14 | 15 | class PreprocessedDataset(CachableDatase): 16 | 17 | def __init__(self, *args, channels : Optional[list] = None, srate : Optional[int] = None, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.channels = channels 20 | self.srate = srate 21 | 22 | def preprocess(self, raw): 23 | 24 | # find the events, first check stim_channels 25 | if len(mne.pick_types(raw.info, stim=True)) > 0: 26 | events = mne.find_events(raw, shortest_event=0, verbose=False) 27 | else: 28 | events = None # the dataset already uses annotations 29 | 30 | # optional resampling 31 | if self.srate is not None: 32 | ret = raw.resample(self.srate, events=events) 33 | raw, events = (ret, events) if events is None else (ret[0], ret[1]) 34 | 35 | # convert optional events to annotations (before we discard the stim channels) 36 | if events is not None: 37 | rev_event_it = dict(zip(self.event_id.values(), self.event_id.keys())) 38 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=rev_event_it) 39 | raw.set_annotations(annot) 40 | 41 | # pick subset of all channels 42 | if self.channels is not None: 43 | raw.pick_channels(self.channels) 44 | else: 45 | raw.pick_types(eeg=True) 46 | 47 | return raw 48 | 49 | def __repr__(self) -> str: 50 | return json.dumps({self.__class__.__name__: self.__dict__}) 51 | -------------------------------------------------------------------------------- /datasets/eeg/moabb/hinss2021.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mne 3 | import numpy as np 4 | import pooch 5 | import logging 6 | import requests 7 | import json 8 | import subprocess 9 | import re 10 | import glob 11 | 12 | from scipy.io import loadmat 13 | import moabb.datasets.download as dl 14 | 15 | from .base import PreprocessedDataset 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | 21 | def doi_to_url(doi, api_url = lambda x : f"https://doi.org/api/handles/{x}?type=URL"): 22 | 23 | url = None 24 | headers = {"Content-Type": "application/json"} 25 | response_data = dl.fs_issue_request("GET", api_url(doi), headers=headers) 26 | 27 | if 'values' in response_data: 28 | candidates = [ val['data']['value'] for val in response_data['values'] if 'data' in val and isinstance(val['data'], dict) and 'value' in val['data']] 29 | url = candidates[0] if len(candidates)> 0 else None 30 | 31 | return url 32 | 33 | 34 | 35 | def url_get_json(url : str): 36 | 37 | headers = {"Content-Type": "application/json"} 38 | response = dl.fs_issue_request("GET", url, headers=headers) 39 | return response 40 | 41 | 42 | 43 | class Hinss2021(PreprocessedDataset): 44 | 45 | ZENODO_JSON_API_URL = lambda x : f"https://zenodo.org/api/{x}" 46 | 47 | TASK_TO_EVENTID = dict(RS='rest', MATBeasy='easy', MATBmed='medium', MATBdiff='difficult') 48 | 49 | def __init__(self, interval = [0, 2], channels = None, srate = None): 50 | super().__init__( 51 | subjects=list(range(1, 15+1)), 52 | sessions_per_subject=2, 53 | events=dict(easy=1, medium=2, difficult=3, rest=4), 54 | code="Hinss2021", 55 | interval=interval, 56 | paradigm="imagery", 57 | doi="10.5281/zenodo.4917217", 58 | channels=channels, 59 | srate=srate 60 | ) 61 | 62 | 63 | 64 | def preprocess(self, raw): 65 | # interpolate channels marked as bad 66 | if len(raw.info['bads']) > 0: 67 | raw.interpolate_bads() 68 | return super().preprocess(raw) 69 | 70 | def data_path( 71 | self, subject, path=None, force_update=False, update_path=None, verbose=None 72 | ): 73 | if subject not in self.subject_list: 74 | raise (ValueError("Invalid subject number")) 75 | 76 | key_dest = f"MNE-{self.code:s}-data" 77 | path = os.path.join(dl.get_dataset_path(self.code, path), key_dest) 78 | 79 | url = doi_to_url(self.doi) 80 | if url is None: 81 | raise ValueError("Could not find zenodo id based on dataset DOI!") 82 | 83 | zenodoid = url.split('/')[-1] 84 | 85 | metadata = url_get_json(Hinss2021.ZENODO_JSON_API_URL(f"records/{zenodoid}")) 86 | 87 | fnames = [] 88 | for record in metadata['files']: 89 | 90 | fname = record['key'] 91 | fpath = os.path.join(path, fname) 92 | 93 | 94 | # metadata 95 | # if record['type'] != 'zip' and not os.path.exists(fpath): # subject data 96 | # pooch.retrieve(record['links']['self'], record['checksum'], fname, path, downloader=pooch.HTTPDownloader(progressbar=True)) 97 | # subject specific data 98 | if record['type'] == 'zip' and fname == f"P{subject:02d}.zip": 99 | if not os.path.exists(fpath): 100 | files = pooch.retrieve(record['links']['self'], record['checksum'], fname, path, 101 | processor=pooch.Unzip(), 102 | downloader=pooch.HTTPDownloader(progressbar=True)) 103 | 104 | # load the data 105 | tasks = list(Hinss2021.TASK_TO_EVENTID.keys()) 106 | taskpattern = '('+ '|'.join(tasks)+')' 107 | pattern = f'{fpath}.unzip/P{subject:02d}/S?/eeg/alldata_*.set' 108 | candidates = glob.glob(pattern, recursive=True) 109 | fnames += [c for c in candidates if re.search(f'.*{taskpattern}.set', c)] 110 | 111 | return fnames 112 | 113 | 114 | def _get_single_subject_data(self, subject): 115 | fnames = self.data_path(subject) 116 | 117 | subject_data = {} 118 | for fn in fnames: 119 | meta = re.search('alldata_sbj(?P\d\d)_sess(?P\d)_((?P\w+))', 120 | os.path.basename(fn)) 121 | sid = int(meta['session']) 122 | 123 | if sid not in range(1,self.n_sessions+1): 124 | continue 125 | 126 | epochs = mne.io.read_epochs_eeglab(fn, verbose=False) 127 | assert(len(epochs.event_id) == 1) 128 | event_id = Hinss2021.TASK_TO_EVENTID[list(epochs.event_id.keys())[0]] 129 | epochs.event_id = {event_id : self.event_id[event_id]} 130 | epochs.events[:,2] = epochs.event_id[event_id] 131 | 132 | # covnert to continuous raw object with correct annotations 133 | continuous_data = np.swapaxes(epochs.get_data(),0,1).reshape((len(epochs.info['chs']),-1)) 134 | raw = mne.io.RawArray(data=continuous_data, info=epochs.info, verbose=False, first_samp=1) 135 | # XXX use standard electrode layout rather than invidividual positions 136 | # raw.set_montage(epochs.get_montage()) 137 | raw.set_montage('standard_1005') 138 | events = epochs.events.copy() 139 | evt_desc = dict(zip(epochs.event_id.values(),epochs.event_id.keys())) 140 | 141 | annot = mne.annotations_from_events(events, raw.info['sfreq'], event_desc=evt_desc, first_samp=1) 142 | 143 | raw.set_annotations(annot) 144 | 145 | if sid in subject_data: 146 | subject_data[sid][0].append(raw) 147 | else: 148 | subject_data[sid] = {0 : raw} 149 | 150 | # discard boundary annotations 151 | keep = [i for i, desc in enumerate(subject_data[sid][0].annotations.description) if desc in self.event_id] 152 | subject_data[sid][0].set_annotations(subject_data[sid][0].annotations[keep]) 153 | 154 | return subject_data 155 | -------------------------------------------------------------------------------- /datasets/spdnet/HDM05_Loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import numpy as np 5 | import torch as th 6 | import random 7 | 8 | from torch.utils import data 9 | 10 | device = 'cpu' 11 | class DatasetHDM05(data.Dataset): 12 | def __init__(self, path, names): 13 | self._path = path 14 | self._names = names 15 | 16 | def __len__(self): 17 | return len(self._names) 18 | 19 | def __getitem__(self, item): 20 | x = np.load(self._path + self._names[item])[None, :, :].real 21 | x = th.from_numpy(x).double() 22 | y = int(self._names[item].split('.')[0].split('_')[-1]) 23 | y = th.from_numpy(np.array(y)).long() 24 | return x.to(device), y.to(device) 25 | 26 | 27 | class DataLoaderHDM05: 28 | def __init__(self, data_path, pval, batch_size): 29 | for filenames in os.walk(data_path): 30 | names = sorted(filenames[2]) 31 | random.Random(1024).shuffle(names) 32 | N_test = int(pval * len(names)) 33 | train_set = DatasetHDM05(data_path, names[N_test:]) 34 | test_set = DatasetHDM05(data_path, names[:N_test]) 35 | self._train_generator = data.DataLoader(train_set, batch_size=batch_size, shuffle='True') 36 | self._test_generator = data.DataLoader(test_set, batch_size=batch_size, shuffle='False') -------------------------------------------------------------------------------- /datasets/spdnet/Radar_Loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | import random 5 | from torch.utils import data 6 | 7 | pval=0.25 #validation percentage 8 | ptest=0.25 #test percentage 9 | # th.cuda.device('cpu') 10 | 11 | class DatasetRadar(data.Dataset): 12 | def __init__(self, path, names): 13 | self._path = path 14 | self._names = names 15 | def __len__(self): 16 | return len(self._names) 17 | def __getitem__(self, item): 18 | x=np.load(self._path+self._names[item]) 19 | x=np.concatenate((x.real[:,None],x.imag[:,None]),axis=1).T 20 | x=th.from_numpy(x) 21 | y=int(self._names[item].split('.')[0].split('_')[-1]) 22 | y=th.from_numpy(np.array(y)) 23 | return x.float(),y.long() 24 | class DataLoaderRadar: 25 | def __init__(self,data_path,pval,batch_size): 26 | for filenames in os.walk(data_path): 27 | names=sorted(filenames[2]) 28 | random.Random().shuffle(names) 29 | N_val=int(pval*len(names)) 30 | N_test=int(ptest*len(names)) 31 | N_train=len(names)-N_test-N_val 32 | train_set=DatasetRadar(data_path,names[N_val+N_test:int(N_train)+N_test+N_val]) 33 | test_set=DatasetRadar(data_path,names[:N_test]) 34 | val_set=DatasetRadar(data_path,names[N_test:N_test+N_val]) 35 | self._train_generator=data.DataLoader(train_set,batch_size=batch_size,shuffle='True') 36 | self._test_generator=data.DataLoader(test_set,batch_size=batch_size,shuffle='False') 37 | self._val_generator=data.DataLoader(val_set,batch_size=batch_size,shuffle='False') -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: LieBN 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | # basic programs 7 | - python==3.8.* 8 | - pip 9 | # scientific python base packages 10 | - numpy==1.20.* 11 | - pandas==1.2.* 12 | - scipy==1.6.* 13 | # jupyter notebooks 14 | - ipykernel 15 | - notebook 16 | - jupyterlab 17 | - nb_conda_kernels 18 | # python visualization 19 | - matplotlib==3.4.* 20 | - seaborn==0.11.* 21 | # machine learning 22 | - scikit-learn==1.0.* 23 | - pytorch==2.2.* 24 | - torchvision==0.17.* 25 | - skorch==0.11.* 26 | - pip: 27 | # tensorboard 28 | - tensorboard==2.14.* 29 | # m/eeg analysis 30 | - mne==0.22.* 31 | - moabb==0.4.* 32 | # command line interfacing 33 | - hydra-core==1.3.* 34 | - hydra-joblib-launcher==1.2.* 35 | # machine learning 36 | - pyriemann==0.2.* 37 | - git+https://github.com/geoopt/geoopt.git@524330b11c0f9f6046bda59fe334803b4b74e13e 38 | # this package 39 | - -e . 40 | -------------------------------------------------------------------------------- /exp_eeg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir=/data #change this to your data folder 4 | ### Experiments on SPDDSMBN 5 | [ $? -eq 0 ] && python TSMNet-MLR.py -m data_dir=$data_dir evaluation=inter-subject+uda,inter-session+uda 6 | 7 | ### Experiments on SPDDSMBN+SPDMLR-(1,0)-LEM 8 | [ $? -eq 0 ] && python TSMNet-MLR.py -m data_dir=$data_dir evaluation=inter-subject+uda,inter-session+uda nnet.model.metric=SPDLogEuclideanMetric 9 | 10 | ### Experiments on SPDDSMBN+SPDMLR-(\theta)-LCM 11 | # inter-session inter-subject 12 | [ $? -eq 0 ] && python TSMNet-MLR.py -m data_dir=$data_dir evaluation=inter-session+uda,inter-subject+uda nnet.model.metric=SPDLogCholeskyMetric nnet.model.power=1.,1.5 -------------------------------------------------------------------------------- /exp_spdnets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #----RADAR---- 4 | radar_path=/data #change this to your data folder 5 | 6 | ### Experiments on SPDNet 7 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=RADAR dataset.path=$radar_path nnet.model.architecture=[20,16,14,12,10,8],[20,16,8] nnet.model.classifier=LogEigMLR 8 | 9 | ### Experiments on SPDNet-LEM 10 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=RADAR dataset.path=$radar_path nnet.model.architecture=[20,16,14,12,10,8],[20,16,8] nnet.model.classifier=SPDMLR\ 11 | nnet.model.metric=SPDLogEuclideanMetric nnet.model.beta=1.,0. 12 | 13 | ### Experiments on SPDNet-LCM 14 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=RADAR dataset.path=$radar_path nnet.model.architecture=[20,16,14,12,10,8],[20,16,8] nnet.model.classifier=SPDMLR\ 15 | nnet.model.metric=SPDLogCholeskyMetric nnet.model.power=1.,0.5 16 | 17 | #----HDM05---- 18 | hdm05_path=/data #change this to your data folder 19 | 20 | ### Experiments on SPDNet 21 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=HDM05 dataset.path=$hdm05_path nnet.model.architecture=[93,30],[93,70,30],[93,70,50,30] nnet.model.classifier=LogEigMLR 22 | 23 | ### Experiments on SPDNet-LEM 24 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=HDM05 dataset.path=$hdm05_path nnet.model.architecture=[93,30],[93,70,30],[93,70,50,30] nnet.model.classifier=SPDMLR\ 25 | nnet.model.metric=SPDLogEuclideanMetric 26 | 27 | ### Experiments on SPDNet-LCM 28 | [ $? -eq 0 ] && python SPDNet-MLR.py -m dataset=HDM05 dataset.path=$hdm05_path nnet.model.architecture=[93,30],[93,70,30],[93,70,50,30] nnet.model.classifier=SPDMLR\ 29 | nnet.model.metric=SPDLogCholeskyMetric nnet.model.power=1.,0.5 -------------------------------------------------------------------------------- /library/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/.DS_Store -------------------------------------------------------------------------------- /library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/__init__.py -------------------------------------------------------------------------------- /library/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /library/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/utils/__init__.py -------------------------------------------------------------------------------- /library/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /library/utils/hydra/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | from hydra.core.hydra_config import HydraConfig 4 | from omegaconf import OmegaConf, DictConfig 5 | import torch 6 | 7 | from sklearn.pipeline import Pipeline 8 | 9 | def hydra_helpers(func): 10 | def inner(*args, **kwargs): 11 | # setup helpers 12 | 13 | # omega conf helpers 14 | OmegaConf.register_new_resolver("len", lambda x:len(x), replace=True) 15 | OmegaConf.register_new_resolver("add", lambda x,y:x+y, replace=True) 16 | OmegaConf.register_new_resolver("sub", lambda x,y:x-y, replace=True) 17 | OmegaConf.register_new_resolver("mul", lambda x,y:x*y, replace=True) 18 | OmegaConf.register_new_resolver("rdiv", lambda x,y:x/y, replace=True) 19 | 20 | STR2TORCHDTYPE = { 21 | 'float32': torch.float32, 22 | 'float64': torch.float64, 23 | 'double': torch.double, 24 | } 25 | OmegaConf.register_new_resolver("torchdtype", lambda x:STR2TORCHDTYPE[x], replace=True) 26 | # if func is not None and list(kwargs.keys())[0] !='args': 27 | if func is not None: 28 | # func(args = kwargs['args']) 29 | func(*args, **kwargs) 30 | return inner 31 | 32 | 33 | def make_sklearn_pipeline(steps_config) -> Pipeline: 34 | 35 | steps = [] 36 | for step_config in steps_config: 37 | 38 | # retrieve the name and parameter dictionary of the current steps 39 | step_name, step_transform = next(iter(step_config.items())) 40 | # instantiate the pipeline step, and append to the list of steps 41 | if isinstance(step_transform, DictConfig): 42 | pipeline_step = (step_name, hydra.utils.instantiate(step_transform, _convert_='partial')) 43 | else: 44 | pipeline_step = (step_name, step_transform) 45 | steps.append(pipeline_step) 46 | 47 | return Pipeline(steps) 48 | -------------------------------------------------------------------------------- /library/utils/hydra/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/utils/hydra/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /library/utils/moabb/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import json 4 | import os 5 | import hashlib 6 | import numpy as np 7 | import pandas as pd 8 | import mne 9 | 10 | from sklearn.base import BaseEstimator 11 | from moabb.paradigms.base import BaseParadigm 12 | from moabb.paradigms.motor_imagery import FilterBankMotorImagery, MotorImagery 13 | from mne import get_config, set_config 14 | from mne.datasets.utils import _get_path 15 | from mne.io import read_info, write_info 16 | from skorch import dataset 17 | 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | class CachedParadigm(BaseParadigm): 22 | 23 | def _get_string_rep(self, obj): 24 | if issubclass(type(obj), BaseEstimator): 25 | str_repr = repr(obj.get_params()) 26 | else: 27 | str_repr = repr(obj) 28 | str_no_addresses = re.sub("0x[a-z0-9]*", "0x__", str_repr) 29 | return str_no_addresses.replace("\n", "") 30 | 31 | def _get_rep(self, dataset): 32 | return self._get_string_rep(dataset) + '\n' + self._get_string_rep(self) 33 | 34 | def _get_cache_dir(self, rep): 35 | if get_config("MNEDATASET_TMP_DIR") is None: 36 | set_config("MNEDATASET_TMP_DIR", os.path.join(os.path.expanduser("~"), "mne_data")) 37 | base_dir = _get_path(None, "MNEDATASET_TMP_DIR", "preprocessed") 38 | 39 | digest = hashlib.sha1(rep.encode("utf8")).hexdigest() 40 | 41 | cache_dir = os.path.join( 42 | base_dir, 43 | "preprocessed", 44 | digest 45 | ) 46 | return cache_dir 47 | 48 | 49 | def process_raw(self, raw, dataset, return_epochs=False,return_raws=False): 50 | # get events id 51 | event_id = self.used_events(dataset) 52 | 53 | # find the events, first check stim_channels then annotations 54 | stim_channels = mne.utils._get_stim_channel(None, raw.info, 55 | raise_error=False) 56 | if len(stim_channels) > 0: 57 | events = mne.find_events(raw, shortest_event=0, verbose=False) 58 | else: 59 | events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False) 60 | 61 | # picks channels 62 | if self.channels is None: 63 | picks = mne.pick_types(raw.info, eeg=True, stim=False) 64 | else: 65 | picks = mne.pick_types(raw.info, stim=False, include=self.channels) 66 | 67 | # pick events, based on event_id 68 | try: 69 | events = mne.pick_events(events, include=list(event_id.values())) 70 | except RuntimeError: 71 | # skip raw if no event found 72 | return 73 | 74 | # get interval 75 | tmin = self.tmin + dataset.interval[0] 76 | if self.tmax is None: 77 | tmax = dataset.interval[1] 78 | else: 79 | tmax = self.tmax + dataset.interval[0] 80 | 81 | X = [] 82 | for bandpass in self.filters: 83 | fmin, fmax = bandpass 84 | # filter data 85 | if fmin is None and fmax is None: 86 | raw_f = raw 87 | else: 88 | raw_f = raw.copy().filter(fmin, fmax, method='iir', 89 | picks=picks, verbose=False) 90 | # epoch data 91 | epochs = mne.Epochs(raw_f, events, event_id=event_id, 92 | tmin=tmin, tmax=tmax, proj=False, 93 | baseline=None, preload=True, 94 | verbose=False, picks=picks, 95 | event_repeated='drop', 96 | on_missing='ignore') 97 | if self.resample is not None: 98 | epochs = epochs.resample(self.resample) 99 | # rescale to work with uV 100 | if return_epochs: 101 | X.append(epochs) 102 | else: 103 | X.append(dataset.unit_factor * epochs.get_data()) 104 | 105 | inv_events = {k: v for v, k in event_id.items()} 106 | labels = np.array([inv_events[e] for e in epochs.events[:, -1]]) 107 | 108 | # if only one band, return a 3D array, otherwise return a 4D 109 | if len(self.filters) == 1: 110 | X = X[0] 111 | else: 112 | X = np.array(X).transpose((1, 2, 3, 0)) 113 | 114 | metadata = pd.DataFrame(index=range(len(labels))) 115 | return X, labels, metadata 116 | 117 | 118 | def get_data(self, dataset, subjects=None, return_epochs=False): 119 | 120 | if return_epochs: 121 | raise ValueError("Only return_epochs=False is supported.") 122 | 123 | rep = self._get_rep(dataset) 124 | cache_dir = self._get_cache_dir(rep) 125 | os.makedirs(cache_dir, exist_ok=True) 126 | 127 | X = [] if return_epochs else np.array([]) 128 | labels = [] 129 | metadata = pd.Series([]) 130 | 131 | if subjects is None: 132 | subjects = dataset.subject_list 133 | 134 | if not os.path.isfile(os.path.join(cache_dir, 'repr.json')): 135 | with open(os.path.join(cache_dir, 'repr.json'), 'w+') as f: 136 | f.write(self._get_rep(dataset)) 137 | 138 | for subject in subjects: 139 | if not os.path.isfile(os.path.join(cache_dir, f'{subject}.npy')): 140 | # compute 141 | x, lbs, meta = super().get_data(dataset, [subject], return_epochs) 142 | np.save(os.path.join(cache_dir, f'{subject}.npy'), x) 143 | meta['label'] = lbs 144 | meta.to_csv(os.path.join(cache_dir, f'{subject}.csv'), index=False) 145 | log.info(f'saved cached data in directory {cache_dir}') 146 | 147 | # load from cache 148 | log.info(f'loading cached data from directory {cache_dir}') 149 | x = np.load(os.path.join(cache_dir, f'{subject}.npy'), mmap_mode ='r') 150 | meta = pd.read_csv(os.path.join(cache_dir, f'{subject}.csv')) 151 | lbs = meta['label'].tolist() 152 | 153 | if return_epochs: 154 | X.append(x) 155 | else: 156 | X = np.append(X, x, axis=0) if len(X) else x 157 | labels = np.append(labels, lbs, axis=0) 158 | metadata = pd.concat([metadata, meta], ignore_index=True) 159 | 160 | return X, labels, metadata 161 | 162 | def get_info(self, dataset): 163 | # check if the info has been saved 164 | rep = self._get_rep(dataset) 165 | cache_dir = self._get_cache_dir(rep) 166 | os.makedirs(cache_dir, exist_ok=True) 167 | info_file = os.path.join(cache_dir, f'raw-info.fif') 168 | if not os.path.isfile(info_file): 169 | x, _, _ = super().get_data(dataset, [dataset.subject_list[0]], True) 170 | info = x.info 171 | write_info(info_file, info) 172 | log.info(f'saved cached info in directory {cache_dir}') 173 | else: 174 | log.info(f'loading cached info from directory {cache_dir}') 175 | info = read_info(info_file) 176 | return info 177 | 178 | def __repr__(self) -> str: 179 | return json.dumps({self.__class__.__name__: self.__dict__}) 180 | 181 | 182 | class CachedMotorImagery(CachedParadigm, MotorImagery): 183 | 184 | def __init__(self, **kwargs): 185 | n_classes = len(kwargs['events']) 186 | super().__init__(n_classes=n_classes, **kwargs) 187 | 188 | 189 | class CachedFilterBankMotorImagery(CachedParadigm, FilterBankMotorImagery): 190 | 191 | def __init__(self, **kwargs): 192 | n_classes = len(kwargs['events']) 193 | super().__init__(n_classes=n_classes, **kwargs) 194 | 195 | -------------------------------------------------------------------------------- /library/utils/moabb/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/utils/moabb/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /library/utils/pyriemann/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import eigvalsh 3 | from pyriemann.tangentspace import TangentSpace 4 | 5 | def squared_airm(A, B): 6 | return np.square(np.log(eigvalsh(A, B))).sum() 7 | 8 | def airm(A,B): 9 | return np.sqrt(squared_airm(A,B)) 10 | 11 | def geom_mean(As): 12 | ts = TangentSpace() 13 | ts.fit(As) 14 | return ts.reference_ 15 | 16 | def tsm(As): 17 | ts = TangentSpace() 18 | ts.fit(As) 19 | return ts.transform(As) -------------------------------------------------------------------------------- /library/utils/torch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from importlib.metadata import metadata 3 | from geoopt.manifolds.sphere import Sphere 4 | from omegaconf.dictconfig import DictConfig 5 | from torch.utils.data.dataloader import DataLoader 6 | from torch.utils.data.dataset import Subset, TensorDataset 7 | from library.utils.moabb import CachedParadigm 8 | from spdnets.manifolds import SymmetricPositiveDefinite 9 | from geoopt.tensor import ManifoldParameter 10 | from geoopt.manifolds import Stiefel 11 | import numpy as np 12 | from copy import deepcopy 13 | import time 14 | import torch 15 | from typing import Iterator, Sequence, Tuple 16 | from sklearn.model_selection import StratifiedKFold 17 | from datasets.eeg.moabb import CachableDatase 18 | from sklearn.preprocessing import LabelEncoder 19 | import pandas as pd 20 | import warnings 21 | 22 | 23 | class BufferDataset(torch.utils.data.Dataset): 24 | 25 | def __init__(self, items) -> None: 26 | super().__init__() 27 | self.items = items 28 | 29 | def __len__(self): 30 | return self.items[0].shape[0] 31 | 32 | def __getitem__(self, index): 33 | return [item[index] for item in self.items] 34 | 35 | 36 | class DomainIndex(int): 37 | ''' 38 | Place holder class to get an entire domain 39 | ''' 40 | pass 41 | 42 | 43 | class DomainDataset(torch.utils.data.Dataset): 44 | 45 | def __init__(self, 46 | labels : torch.LongTensor, 47 | domains : torch.LongTensor, 48 | metadata : pd.DataFrame, 49 | training : bool = True, 50 | dtype : torch.dtype = torch.double, 51 | mask_indices : Sequence[int] = None): 52 | 53 | self.dtype = dtype 54 | self._training = training 55 | self._metadata = metadata 56 | self._mask_indices = mask_indices 57 | assert(len(metadata) == len(labels)) 58 | assert(len(metadata) == len(domains)) 59 | self._metadata = metadata 60 | self._domains = domains 61 | self._labels = labels 62 | 63 | @property 64 | def features(self) -> torch.Tensor: 65 | return self.get_features(range(len(self))).to(dtype=self.dtype) 66 | 67 | @property 68 | def metadata(self) -> pd.DataFrame: 69 | return self._metadata 70 | 71 | @property 72 | def domains(self) -> torch.Tensor: 73 | return self._domains 74 | 75 | @property 76 | def labels(self) -> torch.Tensor: 77 | labels = self._labels.clone() 78 | if self._mask_indices is not None and self.training: 79 | labels[self._mask_indices] = -1 80 | return labels 81 | 82 | @property 83 | def training(self) -> bool: 84 | return self._training 85 | 86 | @property 87 | def shape(self): 88 | raise NotImplementedError() 89 | 90 | @property 91 | def ndim(self): 92 | raise NotImplementedError() 93 | 94 | def train(self): 95 | self._training = True 96 | 97 | def eval(self): 98 | self._training = False 99 | 100 | def set_masked_labels(self, indices): 101 | self._mask_indices = indices 102 | 103 | def get_feature(self, index : int) -> torch.Tensor: 104 | raise NotImplementedError() 105 | 106 | def get_features(self, indices) -> torch.Tensor: 107 | raise NotImplementedError() 108 | 109 | def copy(self, deep=False): 110 | raise NotImplementedError() 111 | 112 | def __len__(self): 113 | return len(self.metadata) 114 | 115 | def __getitem__(self, index): 116 | if isinstance(index, DomainIndex): 117 | # load the data of an entire domain 118 | indices = np.flatnonzero(self.domains.numpy() == index) 119 | features = self.get_features(indices) 120 | return [dict(x=features.to(dtype=self.dtype),d=self.domains[indices]), self.labels[indices]] 121 | else: 122 | feature = self.get_feature(index) 123 | return [dict(x=feature.to(dtype=self.dtype),d=self.domains[index]), self.labels[index]] 124 | 125 | 126 | class CachedDomainDataset(DomainDataset): 127 | 128 | def __init__(self, features, **kwargs) -> None: 129 | super().__init__(**kwargs) 130 | assert(len(self) == len(features)) 131 | self._features = features 132 | 133 | @property 134 | def shape(self): 135 | return self._features.shape 136 | 137 | @property 138 | def ndim(self): 139 | return self._features.ndim 140 | 141 | def get_feature(self, index : int) -> torch.Tensor: 142 | return self._features[index] 143 | 144 | def get_features(self, indices) -> torch.Tensor: 145 | return self._features[indices] 146 | 147 | def set_features(self, features) -> torch.Tensor: 148 | if isinstance(features, np.ndarray): 149 | self._features = torch.from_numpy(features) 150 | elif isinstance(features, torch.Tensor): 151 | self._features = features 152 | else: 153 | raise ValueError() 154 | 155 | def copy(self, deep=False): 156 | 157 | features = self._features.clone() if deep else self._features 158 | labels = self._labels.clone() if deep else self._labels 159 | domains = self._domains.clone() if deep else self._domains 160 | 161 | obj = CachedDomainDataset(features, labels=labels, domains=domains, 162 | metadata=self._metadata.copy(deep=deep), 163 | training=self.training, dtype=self.dtype, 164 | mask_indices=self._mask_indices) 165 | return obj 166 | 167 | 168 | class CombinedDomainDataset(DomainDataset, torch.utils.data.ConcatDataset): 169 | 170 | def __init__(self, features : Sequence[torch.Tensor], **kwargs): 171 | 172 | torch.utils.data.ConcatDataset.__init__(self, features) 173 | DomainDataset.__init__(self, **kwargs) 174 | 175 | @classmethod 176 | def from_moabb(cls, paradigm : CachedParadigm, ds : CachableDatase, 177 | subjects : list = None, domain_expression = "session + subject * 1000", sessions : DictConfig = None, 178 | **kwargs): 179 | if subjects is None: 180 | subjects = ds.subject_list 181 | features = [] 182 | metadata = [] 183 | labels = [] 184 | with warnings.catch_warnings(): 185 | warnings.simplefilter("ignore", UserWarning) 186 | for ix, subject in enumerate(subjects): 187 | x, l, md = paradigm.get_data(ds, [subject], False) 188 | 189 | if sessions is not None: 190 | unique_sessions = md.session.unique() 191 | if 'order' in sessions and sessions['order']== 'last': 192 | unique_sessions = unique_sessions[::-1] 193 | msk = md.session.isin(unique_sessions[0:sessions.get('n', len(unique_sessions))]) 194 | x = x[msk] 195 | l = l[msk] 196 | md = md[msk] 197 | 198 | features += [torch.from_numpy(x)] 199 | md['setindex'] = ix 200 | metadata += [md] 201 | labels += [l] 202 | metadata = pd.concat(metadata, ignore_index=True) 203 | labels = torch.from_numpy(LabelEncoder().fit_transform(np.concatenate(labels))).to(dtype=torch.long) 204 | domains = torch.from_numpy(metadata.eval(domain_expression).to_numpy(dtype=np.int64)) 205 | 206 | return CombinedDomainDataset(features=features, labels=labels, domains=domains, metadata=metadata, **kwargs) 207 | 208 | @property 209 | def shape(self): 210 | shape = list(self.datasets[0].shape) 211 | shape[0] = len(self) 212 | return tuple(shape) 213 | 214 | @property 215 | def ndim(self): 216 | return self.datasets[0].ndim 217 | 218 | def get_feature(self, index : int) -> torch.Tensor: 219 | return torch.utils.data.ConcatDataset.__getitem__(self, index) 220 | 221 | def get_features(self, indices) -> torch.Tensor: 222 | setix = self.metadata.loc[indices, 'setindex'].unique() 223 | if len(setix) > 1: 224 | raise ValueError('Domain data has to be contained in a single subset!') 225 | setix = setix[0] 226 | 227 | if setix == 0: 228 | subindices = indices 229 | else: 230 | subindices = indices - self.cumulative_sizes[setix - 1] 231 | 232 | return self.datasets[setix][subindices] 233 | 234 | def cache(self) -> CachedDomainDataset: 235 | features = torch.cat([ds.to(dtype=self.dtype) for ds in self.datasets]) 236 | obj = CachedDomainDataset(features, labels=self.labels, domains=self._domains, metadata=self.metadata, 237 | training = self.training, dtype=self.dtype, 238 | mask_indices = self._mask_indices) 239 | return obj 240 | 241 | def copy(self, deep=False): 242 | features = [dataset.clone() if deep else dataset for dataset in self.datasets] 243 | obj = CombinedDomainDataset(features, labels=self.labels, domains=self._domains, metadata=self.metadata, 244 | training = self.training, dtype=self.dtype, 245 | mask_indices = self._mask_indices) 246 | return obj 247 | 248 | 249 | class StratifyableDataset(torch.utils.data.Dataset): 250 | def __init__(self, dataset, stratvar) -> None: 251 | super().__init__() 252 | self.dataset = dataset 253 | self.stratvar = stratvar 254 | assert(self.stratvar.shape[0] == len(dataset)) 255 | 256 | def __len__(self): 257 | return len(self.dataset) 258 | 259 | def __getitem__(self, index): 260 | return self.dataset[index] 261 | 262 | class BalancedDomainDataLoader(DataLoader): 263 | 264 | def __init__(self, dataset = None, batch_size = 1, domains_per_batch = 1, shuffle=True, replacement=False, **kwargs): 265 | if isinstance(dataset, Subset) and isinstance(dataset.dataset, DomainDataset): 266 | domains = dataset.dataset.domains[dataset.indices] 267 | elif isinstance(dataset, DomainDataset): 268 | domains = dataset.domains 269 | else: 270 | raise NotImplementedError() 271 | sampler = BalancedDomainSampler(domains, 272 | int(batch_size/domains_per_batch), 273 | shuffle=shuffle, replacement=replacement) 274 | super().__init__(dataset=dataset, sampler=sampler, batch_size=batch_size, **kwargs) 275 | 276 | 277 | class StratifiedDataLoader(DataLoader): 278 | 279 | def __init__(self, dataset = None, batch_size = 1, shuffle=True, **kwargs): 280 | if isinstance(dataset, Subset) and isinstance(dataset.dataset, StratifyableDataset): 281 | stratvar = dataset.dataset.stratvar[dataset.indices] 282 | elif isinstance(dataset, StratifyableDataset): 283 | stratvar = dataset.stratvar 284 | else: 285 | raise NotImplementedError() 286 | 287 | sampler = StratifiedSampler(stratvar=stratvar, batch_size=batch_size, shuffle=shuffle) 288 | super().__init__(dataset=dataset, sampler=sampler, batch_size=batch_size, **kwargs) 289 | 290 | 291 | class StratifiedDomainDataLoader(DataLoader): 292 | 293 | def __init__(self, dataset = None, batch_size = 1, domains_per_batch = 1, shuffle=True, **kwargs): 294 | 295 | if isinstance(dataset, Subset) and isinstance(dataset.dataset, Subset) and isinstance(dataset.dataset.dataset, (DomainDataset, CachedDomainDataset)): 296 | domains = dataset.dataset.dataset.domains[dataset.dataset.indices][dataset.indices] 297 | labels = dataset.dataset.dataset.domains[dataset.dataset.indices][dataset.indices] 298 | elif isinstance(dataset, Subset) and isinstance(dataset.dataset, (DomainDataset, CachedDomainDataset)): 299 | domains = dataset.dataset.domains[dataset.indices] 300 | labels = dataset.dataset.domains[dataset.indices] 301 | elif isinstance(dataset, (DomainDataset, CachedDomainDataset)): 302 | domains = dataset.domains 303 | labels = dataset.labels 304 | else: 305 | raise NotImplementedError() 306 | 307 | sampler = StratifiedDomainSampler(domains, labels, 308 | int(batch_size/domains_per_batch), domains_per_batch, 309 | shuffle=shuffle) 310 | 311 | super().__init__(dataset=dataset, sampler=sampler, batch_size=batch_size, **kwargs) 312 | 313 | 314 | def sample_without_replacement(domainlist, shuffle = True): 315 | du, counts = domainlist.unique(return_counts=True) 316 | dl = [] 317 | while counts.sum() > 0: 318 | mask = counts > 0 319 | if shuffle: 320 | ixs = torch.randperm(du[mask].shape[0]) 321 | else: 322 | ixs = range(du[mask].shape[0]) 323 | counts[mask] -= 1 324 | dl.append(du[mask][ixs]) 325 | return torch.cat(dl, dim=0) 326 | 327 | 328 | class BalancedDomainSampler(torch.utils.data.Sampler[int]): 329 | def __init__(self, domains, samples_per_domain:int, shuffle = False, replacement = True) -> None: 330 | super().__init__(domains) 331 | self.samples_per_domain = samples_per_domain 332 | self.shuffle = shuffle 333 | self.replacement = replacement 334 | 335 | du, didxs, counts = domains.unique(return_inverse=True, return_counts=True) 336 | du = du.tolist() 337 | didxs = didxs.tolist() 338 | counts = counts.tolist() 339 | 340 | self.domainlist = torch.cat( 341 | [domain * torch.ones((counts[ix]//self.samples_per_domain), 342 | dtype=torch.long) for ix, domain in enumerate(du)]) 343 | 344 | self.domaindict = {} 345 | for domix, domid in enumerate(du): 346 | self.domaindict[domid] = torch.LongTensor( 347 | [idx for idx,dom in enumerate(didxs) if dom == domix]) 348 | 349 | def __iter__(self) -> Iterator[int]: 350 | if self.shuffle: 351 | if self.replacement: 352 | permidxs = torch.randperm(self.domainlist.shape[0]) 353 | domainlist = self.domainlist[permidxs] 354 | else: 355 | domainlist = sample_without_replacement(self.domainlist, shuffle=True) 356 | else: 357 | if self.replacement: 358 | domainlist = self.domainlist 359 | else: 360 | domainlist = sample_without_replacement(self.domainlist, shuffle=False) 361 | 362 | generators = {} 363 | for domain in self.domaindict.keys(): 364 | if self.shuffle: 365 | permidxs = torch.randperm(self.domaindict[domain].shape[0]) 366 | else: 367 | permidxs = range(self.domaindict[domain].shape[0]) 368 | generators[domain] = iter( 369 | torch.utils.data.BatchSampler( 370 | self.domaindict[domain][permidxs].tolist(), 371 | batch_size=self.samples_per_domain, drop_last=True)) 372 | 373 | for item in domainlist.tolist(): 374 | batch = next(generators[item]) 375 | yield from batch 376 | 377 | def __len__(self) -> int: 378 | return len(self.domainlist) * self.samples_per_domain 379 | 380 | 381 | class StratifiedDomainSampler(): 382 | 383 | def __init__(self, domains, stratvar, samples_per_domain, domains_per_batch, shuffle = True) -> None: 384 | self.samples_per_domain = samples_per_domain 385 | self.domains_per_batch = domains_per_batch 386 | self.shuffle = shuffle 387 | self.stratvar = stratvar 388 | 389 | du, didxs, counts = domains.unique(return_inverse=True, return_counts=True) 390 | du = du.tolist() 391 | didxs = didxs.tolist() 392 | 393 | self.domaincounts = torch.LongTensor((counts/self.samples_per_domain).tolist()) 394 | 395 | self.domaindict = {} 396 | for domix, _ in enumerate(du): 397 | self.domaindict[domix] = torch.LongTensor( 398 | [idx for idx,dom in enumerate(didxs) if dom == domix]) 399 | 400 | def __iter__(self) -> Iterator[int]: 401 | 402 | domaincounts = self.domaincounts.clone() 403 | 404 | generators = {} 405 | for domain in self.domaindict.keys(): 406 | if self.shuffle: 407 | permidxs = torch.randperm(self.domaindict[domain].shape[0]) 408 | else: 409 | permidxs = range(self.domaindict[domain].shape[0]) 410 | generators[domain] = \ 411 | iter( 412 | StratifiedSampler( 413 | self.stratvar[self.domaindict[domain]], 414 | batch_size=self.samples_per_domain, 415 | shuffle=self.shuffle 416 | )) 417 | # torch.utils.data.BatchSampler( 418 | # self.domaindict[domain][permidxs].tolist(), 419 | # batch_size=self.samples_per_domain, drop_last=True)) 420 | 421 | 422 | while domaincounts.sum() > 0: 423 | 424 | assert((domaincounts >= 0).all()) 425 | # candidates = [ix for ix, count in enumerate(domaincounts.tolist()) if count > 0] 426 | candidates = torch.nonzero(domaincounts, as_tuple=False).flatten() 427 | if candidates.shape[0] < self.domains_per_batch: 428 | break 429 | 430 | # candidates = torch.LongTensor(candidates) 431 | permidxs = torch.randperm(candidates.shape[0]) 432 | candidates = candidates[permidxs] 433 | 434 | # icap = min(len(candidates), self.domains_per_batch) 435 | batchdomains = candidates[:self.domains_per_batch] 436 | 437 | for item in batchdomains.tolist(): 438 | within_domain_idxs = [next(generators[item]) for _ in range(self.samples_per_domain)] 439 | batch = self.domaindict[item][within_domain_idxs] 440 | # batch = next(generators[item]) 441 | domaincounts[item] = domaincounts[item] - 1 442 | yield from batch 443 | yield from [] 444 | 445 | def __len__(self) -> int: 446 | return self.domaincounts.sum() * self.samples_per_domain 447 | 448 | 449 | class StratifiedSampler(torch.utils.data.Sampler[int]): 450 | """Stratified Sampling 451 | Provides equal representation of target classes in each batch 452 | """ 453 | def __init__(self, stratvar, batch_size, shuffle = True): 454 | self.n_splits = max(int(stratvar.shape[0] / batch_size), 2) 455 | self.stratvar = stratvar 456 | self.shuffle = shuffle 457 | 458 | def gen_sample_array(self): 459 | s = StratifiedKFold(n_splits=self.n_splits, shuffle=self.shuffle) 460 | indices = [test for _, test in s.split(self.stratvar, self.stratvar)] 461 | return np.hstack(indices) 462 | 463 | def __iter__(self): 464 | return iter(self.gen_sample_array()) 465 | 466 | def __len__(self): 467 | return len(self.stratvar) 468 | -------------------------------------------------------------------------------- /library/utils/torch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/library/utils/torch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='spddsmbn', 4 | version='1.0.0', 5 | packages=find_packages()) -------------------------------------------------------------------------------- /spd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spd/__init__.py -------------------------------------------------------------------------------- /spd/functional.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def trace(A): 5 | """" 6 | compute the batch trace of A [...,n,n] 7 | """ 8 | # trace_vec = th.diagonal(A, dim1=-2, dim2=-1).sum(dim=-1) 9 | r_trace = th.einsum("...ii->...", A) 10 | return r_trace 11 | 12 | def inner_product(A, B): 13 | """" 14 | compute the batch inner product of A and B, with [...,n,n] [...,n,n] 15 | """ 16 | r_inner_prod = th.einsum("...ij,...ij->...", A, B) 17 | return r_inner_prod 18 | 19 | 20 | def tril_half_diag(A): 21 | """"[...n,n] A, strictly lower part + 1/2 * half of diagonal part""" 22 | str_tril_A = A.tril(-1) 23 | diag_A_vec = th.diagonal(A, dim1=-2, dim2=-1) 24 | half_diag_A = str_tril_A + 0.5 * th.diag_embed(diag_A_vec) 25 | return half_diag_A 26 | -------------------------------------------------------------------------------- /spd/spd_matrices.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | from spdnets.functionals import sym_powm,sym_logm 6 | from spd.functional import inner_product,trace,tril_half_diag 7 | 8 | class SPDMatrices(nn.Module): 9 | """Computation for SPD data with [...,n,n]""" 10 | def __init__(self, n,power=1.): 11 | super().__init__() 12 | self.n=n; self.dim = int(n * (n + 1) / 2) 13 | self.register_buffer('power', th.tensor(power)) 14 | self.register_buffer('I', th.eye(n)) 15 | 16 | if power == 0: 17 | raise Exception('power should not be zero with power={:.4f}'.format(power)) 18 | self.sgn_power = -1 if self.power < 0 else 1 19 | 20 | def spd_pow(self, S): 21 | if self.power == 1.: 22 | Power_S = S; 23 | else: 24 | Power_S = sym_powm.apply(S, self.power) 25 | return Power_S 26 | 27 | def RMLR(self, S, P, A): 28 | """ 29 | RMLR based on margin distance, generating A by parallel transportation 30 | Inputs: 31 | S: [b,c,n,n] SPD 32 | P: [class,n,n] SPD matrices 33 | A: [class,n,n] symmetric matrices 34 | """ 35 | raise NotImplementedError 36 | 37 | class SPDOnInvariantMetric(SPDMatrices): 38 | """ 39 | Computation for SPD data with [b,c,n,n], the base class of (\theta,\alpha,\beta)-EM/LEM/AIM/ 40 | """ 41 | def __init__(self, n, alpha=1.0, beta=0.,power=1.): 42 | super(__class__, self).__init__(n,power) 43 | if alpha <= 0 or beta <= -alpha / n: 44 | raise Exception('wrong alpha or beta with alpha={:.4f},beta={:.4f}'.format(alpha, beta)) 45 | self.alpha = alpha;self.beta = beta; 46 | self.p = (self.alpha + self.n * self.beta) ** 0.5 47 | self.q = self.alpha ** 0.5 48 | 49 | def alpha_beta_Euc_inner_product(self, tangent_vector1, tangent_vector2): 50 | """"computing the O(n)-invariant Euclidean inner product""" 51 | if self.alpha==1. and self.beta==0.: 52 | X_new = inner_product(tangent_vector1, tangent_vector2) 53 | else: 54 | item1 = inner_product(tangent_vector1, tangent_vector2) 55 | trace_vec1 = trace(tangent_vector1) 56 | trace_vec2 = trace(tangent_vector2) 57 | item2 = th.mul(trace_vec1, trace_vec2) 58 | X_new = self.alpha * item1 + self.beta * item2 59 | return X_new 60 | 61 | class SPDLogEuclideanMetric(SPDOnInvariantMetric): 62 | """ (\alpha,\beta)-LEM """ 63 | def __init__(self,n,alpha=1.0, beta=0.): 64 | super(__class__, self).__init__(n,alpha, beta) 65 | 66 | def RMLR(self, S, P, A): 67 | P_phi = sym_logm.apply(P) 68 | S_phi = sym_logm.apply(S) 69 | X_new = self.alpha_beta_Euc_inner_product(S_phi - P_phi, A) 70 | 71 | return X_new 72 | 73 | class SPDLogCholeskyMetric(SPDMatrices): 74 | """ \theta-LCM """ 75 | def __init__(self, n,power=1.): 76 | super(__class__, self).__init__(n,power) 77 | 78 | def RMLR(self, S, P, A): 79 | Power_S = self.spd_pow(S) 80 | Power_P = self.spd_pow(P) 81 | 82 | Chol_of_Power_S = th.linalg.cholesky(Power_S) 83 | Chol_of_Power_P = th.linalg.cholesky(Power_P) 84 | 85 | item1_diag_vec = th.log(th.diagonal(Chol_of_Power_S, dim1=-2, dim2=-1)) - th.log(th.diagonal(Chol_of_Power_P, dim1=-2, dim2=-1)) 86 | item1 = Chol_of_Power_S.tril(-1) - Chol_of_Power_P.tril(-1) + th.diag_embed(item1_diag_vec) 87 | X_new = (1 / self.power) * inner_product(item1, tril_half_diag(A)) 88 | 89 | return X_new -------------------------------------------------------------------------------- /spdnets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/.DS_Store -------------------------------------------------------------------------------- /spdnets/SPDMLR.py: -------------------------------------------------------------------------------- 1 | """ 2 | Official Implementation of the SPD MLR presented in 3 | @inproceedings{chen2024spdmlr, 4 | title={Riemannian Multinomial Logistics Regression for SPD Neural Networks}, 5 | author={Chen, Ziheng and Song, Yue and Liu, Gaowen and Kompella, Ramana Rao and Wu, Xiaojun and Sebe, Nicu}, 6 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 7 | year={2024} 8 | } 9 | """ 10 | 11 | import math 12 | import geoopt 13 | import torch as th 14 | import torch.nn as nn 15 | 16 | import spd.spd_matrices as spd_matrices 17 | 18 | class SPDRMLR(nn.Module): 19 | def __init__(self,n,c,metric='SPDLogEuclideanMetric',power=1.0,alpha=1.0,beta=0.): 20 | """ 21 | Input X: (N,h,n,n) SPD matrices 22 | Output P: (N,dim_vec) vectors 23 | SPD parameter of size (c,n,n), where c denotes the number of classes 24 | Sym parameters (c,n,n) 25 | """ 26 | super(__class__, self).__init__() 27 | self.n = n;self.c = c; 28 | self.metric = metric;self.power = power;self.alpha = alpha;self.beta = beta; 29 | 30 | self.P = geoopt.ManifoldParameter(th.empty(c, n, n), manifold=geoopt.manifolds.SymmetricPositiveDefinite()) 31 | init_3Didentity(self.P) 32 | self.A = nn.Parameter(th.zeros_like(self.P)) 33 | init_matrix_uniform(self.A, int(n ** 2)) 34 | 35 | if self.metric == 'SPDLogEuclideanMetric': 36 | self.spd = spd_matrices.SPDLogEuclideanMetric(n=self.n,alpha=self.alpha,beta=self.beta) 37 | elif self.metric=='SPDLogCholeskyMetric': 38 | self.power = power; 39 | self.spd = spd_matrices.SPDLogCholeskyMetric(n=self.n,power=self.power) 40 | else: 41 | raise Exception('unknown metric {}'.format(metric)) 42 | 43 | def forward(self,X): 44 | A_sym = symmetrize_by_tril(self.A) 45 | X_new = self.spd.RMLR(X, self.P, A_sym) 46 | return X_new 47 | 48 | def __repr__(self): 49 | return f"{self.__class__.__name__}(n={self.n},c={self.c},metric={self.metric},power={self.power},alpha={self.alpha},beta={self.beta})" 50 | 51 | def symmetrize_by_tril(A): 52 | """" 53 | symmetrize A by the lower part of A, with [...,n,n] 54 | """ 55 | str_tril_A = A.tril(-1) 56 | diag_A_vec = th.diagonal(A, dim1=-2, dim2=-1) 57 | tmp_A_sym = str_tril_A + str_tril_A.transpose(-1, -2) + th.diag_embed(diag_A_vec) 58 | return tmp_A_sym 59 | 60 | def init_matrix_uniform(A,fan_in,factor=6): 61 | bound = math.sqrt(factor / fan_in) if fan_in > 0 else 0 62 | nn.init.uniform_(A, -bound, bound) 63 | 64 | def init_3Didentity(S): 65 | """ initializes to identity a (h,ni,no) 3D-SPDParameter""" 66 | h,n1,n2=S.shape 67 | for i in range(h): 68 | S.data[i] = th.eye(n1, n2) -------------------------------------------------------------------------------- /spdnets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/__init__.py -------------------------------------------------------------------------------- /spdnets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/__pycache__/batchnorm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/__pycache__/batchnorm.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/__pycache__/manifolds.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/__pycache__/manifolds.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/cplx/functional.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | def conv_cplx(X,conv_re,conv_im): 4 | if(isinstance(X,list)): 5 | XX=[x.split(x.shape[0]//2,0) for x in X] 6 | tmp=X[0].split(X[0].shape[0]//2,0) 7 | tmpp=conv_re(tmp[0][None,...])-conv_im(tmp[1][None,...]) 8 | P_Re=[conv_re(xx[0][None,...])-conv_im(xx[1][None,...]) for xx in XX] 9 | P_Im=[conv_im(xx[0][None,...])+conv_re(xx[1][None,...]) for xx in XX] 10 | P=[th.cat((p_Re,p_Im),0) for p_Re,p_Im in zip(P_Re,P_Im)] 11 | else: 12 | XX=X.split(X.shape[1]//2,1) 13 | P_Re=conv_re(XX[0])-conv_im(XX[1]) 14 | P_Im=conv_im(XX[0])+conv_re(XX[1]) 15 | P=th.cat((P_Re,P_Im),1) 16 | return P 17 | 18 | def split_signal_cplx(X,conv_re,conv_im): 19 | ''' 20 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 21 | ''' 22 | XX=X.split(X.shape[1]//2,1) 23 | P_Re=conv_re(XX[0]) 24 | P_Im=conv_re(XX[1]) 25 | P=th.cat((P_Re[:,None,:,:],P_Im[:,None,:,:]),1) 26 | return P 27 | 28 | def roll(X): 29 | return th.cat((X[:,:,X.shape[-2]//2:,:],X[:,:,:X.shape[-2]//2,:]),2) 30 | 31 | def decibel(X): 32 | # n_fft=X.shape[-2]//2 33 | # X_Re=X[:,:n_fft,:] 34 | # X_Im=X[:,n_fft:,:] 35 | if(isinstance(X,list)): 36 | absX=AbsolutSquared()(X) 37 | X_db=[10*th.log(x) for x in absX] 38 | else: 39 | X_db=10*th.log(absolut_squared(X)) 40 | return X_db 41 | 42 | def absolut_squared(X): 43 | ''' 44 | Inputs a (N,2*C,H,W) complex tensor 45 | Outputs a (N,C,H,W) real tensor containing the input's squared module 46 | ''' 47 | XX=X.split(X.shape[1]//2,1) 48 | return (XX[0]**2+XX[1]**2) 49 | 50 | def oneD2twoD_cplx(x): 51 | ''' 52 | Inputs a 3D complex tensor (N,2*n,T) 53 | Outputs a 4D complex tensor (N,2,n,T) 54 | ''' 55 | return x.view(x.shape[0],2,-1,x.shape[-1]) 56 | 57 | def batch_norm2d_cplx(X,running_mean,running_var,gamma11,gamma12,gamma22,bias,momentum,training): 58 | N,C,H,W=X.shape 59 | XX=list(X.split(C//2,1)) 60 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 61 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 62 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW) 63 | if(training): 64 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,) 65 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2) 66 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,) 67 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,) 68 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre) 69 | with th.no_grad(): 70 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2) 71 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2) 72 | # cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:] 73 | cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre) 74 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None])) 75 | else: 76 | # running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:] 77 | running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0]) 78 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None])) 79 | weight=utils.twobytwo_covmat_from_coeffs(gamma11,gamma22,gamma12) 80 | Z=weight.matmul(Y)+bias[:,:,None] 81 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1) 82 | 83 | def batch_norm2d_cplx_spd(X,running_mean,running_var,weight,bias,momentum,training): 84 | N,C,H,W=X.shape 85 | XX=list(X.split(C//2,1)) 86 | X_re=XX[0].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 87 | X_im=XX[1].transpose(0,1).contiguous().view(C//2,N*H*W) #(C//2,NHW) 88 | X_cplx=th.cat((X_re[:,None,:],X_im[:,None,:]),1) #(C//2,2,NHW) 89 | if(training): 90 | mu_re=X_re.mean(1); mu_im=X_im.mean(1) #(C//2,) 91 | mu_cplx=th.cat((mu_re[:,None],mu_im[:,None]),1) #(C//2,2) 92 | var_re=X_re.var(1); var_im=X_im.var(1) #(C//2,) 93 | cov_imre=((X_re-mu_re[:,None])*(X_im-mu_im[:,None])).sum(1)/(N*H*W-1) #(C//2,) 94 | cov_cplx=utils.twobytwo_covmat_from_coeffs(var_re,var_im,cov_imre) 95 | with th.no_grad(): 96 | running_mean=(1-momentum)*running_mean+momentum*mu_cplx #(C//2,2) 97 | running_var=(1-momentum)*running_var+momentum*cov_cplx #(C//2,2,2) 98 | cov_cplx_sqinv=(functional_spd.SqminvEig()(cov_cplx[:,None,:,:]))[:,0,:,:] 99 | # cov_cplx_sqinv=utils.twobytwo_sqinv(var_re,var_im,cov_imre) 100 | Y=cov_cplx_sqinv.matmul((X_cplx-mu_cplx[:,:,None])) 101 | else: 102 | running_var_sqinv=(functional_spd.SqminvEig()(running_var[:,None,:,:].double())).float()[:,0,:,:] 103 | # running_var_sqinv=utils.twobytwo_sqinv(running_var[:,0,0],running_var[:,1,1],running_var[:,1,0]) 104 | Y=running_var_sqinv.matmul((X_cplx-running_mean[:,:,None])) 105 | weight=utils.twobytwo_covmat_from_coeffs(_gamma11,_gamma22,_gamma12) ##################### CHANGE 106 | Z=weight.matmul(Y)+bias[:,:,None] 107 | return Z.view(C,-1).view(C,N,H,W).transpose(0,1) 108 | 109 | 110 | def cov_pool_cplx(f,reg_mode='mle',N_estimates=None): 111 | """ 112 | Input f: Temporal n-dimensionnal complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T) 113 | Output X: Complex covariance matrix of size (batch_size,2,n,n) 114 | """ 115 | N,_,n,T=f.shape 116 | ff=f.split(f.shape[1]//2,1) 117 | f_re=ff[0]; f_im=ff[1] 118 | if(N_estimates is not None): 119 | f_re=f_re.split(self._Nestimates,-1) 120 | if(f_re[-1].shape[-1]!=self._Nestimates): 121 | f_re=th.cat(f_re[:-1]+(th.cat((f_re[-1],th.zeros(N,1,n,self._Nestimates-f_re[-1].shape[-1])),-1),),1) 122 | else: 123 | f_re=th.cat(f_re,1) 124 | f_im=f_im.split(self._Nestimates,-1) 125 | if(f_im[-1].shape[-1]!=self._Nestimates): 126 | f_im=th.cat(f_im[:-1]+(th.cat((f_im[-1],th.zeros(N,1,n,self._Nestimates-f_im[-1].shape[-1])),-1),),1) 127 | else: 128 | f_im=th.cat(f_im,1) 129 | f_re-=f_re.mean(-1,True); f_im-=f_im.mean(-1,True) 130 | f_re=f_re.double(); f_im=f_im.double() 131 | X_Re=((f_re.matmul(f_re.transpose(-1,-2))+f_im.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1)) 132 | X_Im=((f_im.matmul(f_re.transpose(-1,-2))-f_re.matmul(f_im.transpose(-1,-2)))/(f.shape[-1]-1)) 133 | if(reg_mode=='mle'): 134 | pass 135 | elif(self._reg_mode=='add_id'): 136 | X_Re=RegulEig(1e-6)(X_Re) 137 | X_Im=RegulEig(1e-6)(X_Im) 138 | elif(self._reg_mode=='adjust_eig'): 139 | X_Re=AdjustEig(0.75)(X_Re) 140 | X_Im=AdjustEig(0.75)(X_Im) 141 | X=(X_Re+X_Im)/2 ############## later, do cat for HPD 142 | # X=th.cat((X_Re,X_Im),1) #for real complex matrices 143 | return X 144 | -------------------------------------------------------------------------------- /spdnets/cplx/nn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import numpy.random 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from . import functional 7 | 8 | class Conv_cplx(nn.Module): 9 | ''' 10 | Interface complex conv layer 11 | ''' 12 | def forward(self,X): 13 | return functional.conv_cplx(X,self._conv_Re,self._conv_Im) 14 | 15 | class Conv1d_cplx(Conv_cplx): 16 | ''' 17 | 1D complex conv layer 18 | Inputs a 3D Tensor (batch_size,2*C,T) 19 | C is the number of channels, 2*C=in_channels the effective number of channels for handling complex data 20 | Contains two real-valued conv layers 21 | Output is of shape (batch_size,out_channels,T) (complex channels is out_channels//2) 22 | ''' 23 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False): 24 | super(__class__,self).__init__() 25 | self._conv_Re=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias) 26 | self._conv_Im=nn.Conv1d(in_channels//2,out_channels,kernel_size,stride,bias=bias) 27 | 28 | class FFT(Conv1d_cplx): 29 | ''' 30 | 1D complex conv layer, where the weights are the Fourier atoms 31 | ''' 32 | def __init__(self,in_channels,out_channels,kernel_size,stride,bias=False): 33 | super(__class__, self).__init__(in_channels,out_channels,kernel_size,stride,bias) 34 | atoms=signal_utils.fourier_atoms(out_channels,kernel_size).conj() 35 | atoms_Re=th.from_numpy(utils.cplx2bichan(atoms)[:,0,:][:,None,:]).float() 36 | atoms_Im=th.from_numpy(utils.cplx2bichan(atoms)[:,1,:][:,None,:]).float() 37 | self._conv_Re.weight.data=nn.Parameter(atoms_Re) 38 | self._conv_Im.weight.data=nn.Parameter(atoms_Im) 39 | for param in list(self.parameters()): 40 | param.requires_grad=False 41 | 42 | class SplitSignal_cplx(Conv1d_cplx): 43 | ''' 44 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 45 | Input is still (batch_size,2,T) 46 | Output is 4-D complex signal (batch_size,2,window_size,T') instead of (batch_size,2*window_size,T') 47 | ''' 48 | def __init__(self,in_channels,window_size,hop_length): 49 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False) 50 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:] 51 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:] 52 | for param in list(self._conv_Re.parameters()): 53 | param.requires_grad=False 54 | for param in list(self._conv_Im.parameters()): 55 | param.requires_grad=False 56 | def forward(self,X): 57 | return functional.split_signal_cplx(X,self._conv_Re,self._conv_Im) 58 | 59 | class SplitSignalBlock_cplx(Conv1d_cplx): 60 | ''' 61 | 1D to 2D complex conv layer, where the weights are adequately placed zeroes and ones to split a signal 62 | Input is now L blocks of 3-D complex signals (batch_size,L,2*C_in,T), which are acted on independently by the conv layer 63 | Output is L blocks of 3-D complex signals (batch_size,L,2*C_out,window_size,T') 64 | ''' 65 | def __init__(self,in_channels,window_size,hop_length): 66 | super(__class__, self).__init__(in_channels,window_size,window_size,hop_length,False) 67 | # self._conv_Re.weight.data=th.ones_like(self._conv_Re.weight.data) 68 | self._conv_Re.weight.data=th.eye(window_size)[:,None,:] 69 | # self._conv_Im.weight.data=th.ones_like(self._conv_Im.weight.data) 70 | self._conv_Im.weight.data=th.eye(window_size)[:,None,:] 71 | for param in list(self._conv_Re.parameters()): 72 | param.requires_grad=False 73 | for param in list(self._conv_Im.parameters()): 74 | param.requires_grad=False 75 | def forward(self,X): 76 | XX=[functional.split_signal_cplx(X[:,i,:,:],self._conv_Re,self._conv_Im)[:,None,:,:,:] 77 | for i in range(X.shape[1])] 78 | return th.cat(XX,1) 79 | 80 | class Conv2d_cplx(Conv_cplx): 81 | ''' 82 | 2D complex conv layer 83 | Inputs a 4D Tensor (N,2*C,H,W) 84 | C is the number of channels, 2*C the effective number of channels for handling complex data 85 | Contains two real-valued conv layers 86 | ''' 87 | def __init__(self,in_channels,out_channels,kernel_size,stride=(1,1),bias=True): 88 | super(Conv2d_cplx,self).__init__() 89 | self._conv_Re=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias) 90 | self._conv_Im=nn.Conv2d(in_channels//2,out_channels//2,kernel_size,stride,bias=bias) 91 | sigma=2./(in_channels//2+out_channels//2) 92 | ampli=numpy.random.rayleigh(sigma,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1])) 93 | phase=numpy.random.uniform(-np.pi,np.pi,(out_channels//2,in_channels//2,kernel_size[0],kernel_size[1])) 94 | self._conv_Re.weight.data=nn.Parameter(th.Tensor(ampli*np.cos(phase))) 95 | self._conv_Im.weight.data=nn.Parameter(th.Tensor(ampli*np.sin(phase))) 96 | 97 | class ReLU_cplx(nn.ReLU): 98 | pass 99 | # def forward(self,X): 100 | # return F.relu(X) 101 | 102 | class Roll(nn.Module): 103 | def forward(self,X): 104 | return functional.roll(X) 105 | 106 | class Decibel(nn.Module): 107 | ''' 108 | Inputs a (N,2*C,H,W) complex tensor 109 | Outputs a (N,C,H,W) real tensor containing the input's decibel amplitude 110 | ''' 111 | def forward(self,X): 112 | return functional.decibel(X) 113 | 114 | class AbsolutSquared(nn.Module): 115 | ''' 116 | Inputs a (N,2*C,H,W) complex tensor 117 | Outputs a (N,C,H,W) real tensor containing the input's squared module 118 | ''' 119 | def forward(self,X): 120 | return functional.absolut_squared(X) 121 | 122 | class oneD2twoD_cplx(nn.Module): 123 | ''' 124 | Inputs a 3D complex tensor (N,2*n,T) 125 | Outputs a 4D complex tensor (N,2,n,T) 126 | ''' 127 | def forward(self,x): 128 | return functional.oneD2twoD_cplx(x) 129 | 130 | class MaxPool2d_cplx(nn.MaxPool2d): 131 | pass 132 | # def __init__(self,kernel_size,stride=None,padding=0,dilation=1,return_indices=False,ceil_mode=False): 133 | # super(MaxPool2d_cplx, self).__init__(kernel_size,stride,padding,dilation,return_indices,ceil_mode) 134 | # self._abs=AbsolutSquared() 135 | # def forward(self,X): 136 | # N,C,H,W=X.shape 137 | # X_abs=self._abs(X) 138 | # _,idx=F.max_pool2d(X_abs,self.kernel_size,self.stride,self.padding,self.dilation,self.ceil_mode,return_indices=True) 139 | # XX=X.split(C//2,1) 140 | # Y_re=th.gather(XX[0].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape) 141 | # Y_im=th.gather(XX[1].view(N,C//2,-1),-1,idx.view(N,C//2,-1)).view(idx.shape) 142 | # Y=th.cat((Y_re,Y_im),1) 143 | # return Y 144 | 145 | class BatchNorm2d_cplx(nn.BatchNorm2d): 146 | ''' 147 | Inputs a (N,2*C,H,W) complex tensor 148 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor 149 | ''' 150 | # def __init__(self,in_channels): 151 | # super(BatchNorm2d_cplx,self).__init__(in_channels) 152 | # self.momentum=0.1 153 | # self.running_mean=th.zeros(in_channels//2,2) 154 | # self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5) 155 | # self._gamma11=nn.Parameter(th.ones(in_channels//2)/2**.5) 156 | # self._gamma22=nn.Parameter(th.ones(in_channels//2)/2**.5) 157 | # self._gamma12=nn.Parameter(th.zeros(in_channels//2)) 158 | # self.bias=nn.Parameter(th.zeros(in_channels//2,2)) 159 | # def forward(self,X): 160 | # return functional.batch_norm2d_cplx(X,self.running_mean,self.running_var, 161 | # self._gamma11,self._gamma12,self._gamma22,self.bias,self.momentum,self.training) 162 | pass 163 | 164 | class BatchNorm2d_cplxSPD(nn.BatchNorm2d): 165 | ''' 166 | Inputs a (N,2*C,H,W) complex tensor 167 | Outputs a whitened and parametrically rescaled (N,2*C,H,W) complex tensor 168 | ''' 169 | def __init__(self,in_channels): 170 | super(BatchNorm2d_cplxSPD,self).__init__(in_channels) 171 | self.momentum=0.1 172 | self.running_mean=th.zeros(in_channels//2,2) 173 | self.running_var=th.eye(2)[None,:,:].repeat(in_channels//2,1,1)/(2**.5) 174 | self.weight_=nn.ParameterList([functional_spd.SPDParameter(th.eye(2)/(2**.5)) for _ in range(in_channels//2)]) 175 | self.bias=nn.Parameter(th.zeros(in_channels//2,2)) 176 | def forward(self,X): 177 | return functional.batch_norm2d_cplx_spd(X,self.running_mean,self.running_var, 178 | self.weight_,self.bias,self.momentum,self.training) 179 | 180 | class BatchNorm2d(nn.BatchNorm2d): 181 | pass 182 | # def __init__(self,in_channels): 183 | # super(BatchNorm2d,self).__init__(in_channels) 184 | # self.momentum=0.1 185 | # self.running_mean=th.zeros(in_channels) 186 | # self.running_var=th.zeros(in_channels) 187 | # self.weight=nn.Parameter(th.ones(in_channels)) 188 | # self.bias=nn.Parameter(th.zeros(in_channels)) 189 | # def forward(self,X): 190 | # N,C,H,W=X.shape 191 | # X_vec=X.transpose(0,1).contiguous().view(C,N*H*W) 192 | # if(self.training): 193 | # mu=X_vec.mean(1); var=X_vec.var(1) 194 | # with th.no_grad(): 195 | # self.running_mean=(1-self.momentum)*self.running_mean+self.momentum*mu 196 | # self.running_var=(1-self.momentum)*self.running_var+self.momentum*var 197 | # Y=(X_vec-mu.view(-1,1))/(var.view(-1,1)**.5+self.eps) 198 | # else: 199 | # Y=(X_vec-self.running_mean.view(-1,1))/(self.running_var.view(-1,1)**.5+self.eps) 200 | # Z=self.weight.view(-1,1)*Y+self.bias.view(-1,1) 201 | # return Z.view(C,N,H,W).transpose(0,1) 202 | # 203 | # def forward(self, x): 204 | # self._check_input_dim(x) 205 | # y = x.transpose(0,1) 206 | # return_shape = y.shape 207 | # y = y.contiguous().view(x.size(1), -1) 208 | # mu = y.mean(dim=1) 209 | # sigma2 = y.var(dim=1) 210 | # if self.training is not True: 211 | # y = y - self.running_mean.view(-1, 1) 212 | # y = y / (self.running_var.view(-1, 1)**.5 + self.eps) 213 | # else: 214 | # if self.track_running_stats is True: 215 | # with torch.no_grad(): 216 | # self.running_mean = (1-self.momentum)*self.running_mean + self.momentum*mu 217 | # self.running_var = (1-self.momentum)*self.running_var + self.momentum*sigma2 218 | # y = y - mu.view(-1,1) 219 | # y = y / (sigma2.view(-1,1)**.5 + self.eps) 220 | # 221 | # y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1) 222 | # return y.view(return_shape).transpose(0,1) 223 | 224 | class CovPool_cplx(nn.Module): 225 | """ 226 | Input f: Temporal n-dimensionnal complex feature map of length T (T=1 for a unitary signal) (batch_size,2,n,T) 227 | Output X: Complex covariance matrix of size (batch_size,2,n,n) 228 | """ 229 | def __init__(self,reg_mode='mle',N_estimates=None): 230 | super(__class__,self).__init__() 231 | self._reg_mode=reg_mode; self.N_estimates=N_estimates 232 | def forward(self,f): 233 | return functional.cov_pool_cplx(f,self._reg_mode,self.N_estimates) 234 | 235 | class CovPoolBlock_cplx(nn.Module): 236 | """ 237 | Input f: L blocks of temporal n-dimensionnal complex feature map of length T of shape (batch_size,L,2,n,T) 238 | Output X: L blocks of complex covariance matrix of size (batch_size,L,2,n,n) 239 | """ 240 | def __init__(self,reg_mode='mle',N_estimates=None): 241 | super(__class__,self).__init__() 242 | self._reg_mode=reg_mode; self.N_estimates=N_estimates 243 | def forward(self,f): 244 | XX=[functional.cov_pool_cplx(f[:,i,:,:,:],self._reg_mode,self.N_estimates)[:,None,:,:,:] 245 | for i in range(f.shape[1])] 246 | return th.cat(XX,1) -------------------------------------------------------------------------------- /spdnets/functionals.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Callable, Tuple 4 | from typing import Any 5 | from torch.autograd import Function, gradcheck 6 | from torch.functional import Tensor 7 | from torch.types import Number 8 | 9 | # define the epsilon precision depending on the tensor datatype 10 | EPS = {torch.float32: 1e-4, torch.float64: 1e-7} 11 | 12 | 13 | def ensure_sym(A: Tensor) -> Tensor: 14 | """Ensures that the last two dimensions of the tensor are symmetric. 15 | Parameters 16 | ---------- 17 | A : torch.Tensor 18 | with the last two dimensions being identical 19 | ------- 20 | Returns : torch.Tensor 21 | """ 22 | return 0.5 * (A + A.transpose(-1,-2)) 23 | 24 | 25 | def broadcast_dims(A: torch.Size, B: torch.Size, raise_error:bool=True) -> Tuple: 26 | """Return the dimensions that can be broadcasted. 27 | Parameters 28 | ---------- 29 | A : torch.Size 30 | shape of first tensor 31 | B : torch.Size 32 | shape of second tensor 33 | raise_error : bool (=True) 34 | flag that indicates if an error should be raised if A and B cannot be broadcasted 35 | ------- 36 | Returns : torch.Tensor 37 | """ 38 | # check if the tensors can be broadcasted 39 | if raise_error: 40 | if len(A) != len(B): 41 | raise ValueError('The number of dimensions must be equal!') 42 | 43 | tdim = torch.tensor((A, B), dtype=torch.int32) 44 | 45 | # find differing dimensions 46 | bdims = tuple(torch.where(tdim[0].ne(tdim[1]))[0].tolist()) 47 | 48 | # check if one of the different dimensions has size 1 49 | if raise_error: 50 | if not tdim[:,bdims].eq(1).any(dim=0).all(): 51 | raise ValueError('Broadcast not possible! One of the dimensions must be 1.') 52 | 53 | return bdims 54 | 55 | 56 | def sum_bcastdims(A: Tensor, shape_out: torch.Size) -> Tensor: 57 | """Returns a tensor whose values along the broadcast dimensions are summed. 58 | Parameters 59 | ---------- 60 | A : torch.Tensor 61 | tensor that should be modified 62 | shape_out : torch.Size 63 | desired shape of the tensor after aggregation 64 | ------- 65 | Returns : the aggregated tensor with the desired shape 66 | """ 67 | bdims = broadcast_dims(A.shape, shape_out) 68 | 69 | if len(bdims) == 0: 70 | return A 71 | else: 72 | return A.sum(dim=bdims, keepdim=True) 73 | 74 | 75 | def randn_sym(shape, **kwargs): 76 | ndim = shape[-1] 77 | X = torch.randn(shape, **kwargs) 78 | ixs = torch.tril_indices(ndim,ndim, offset=-1) 79 | X[...,ixs[0],ixs[1]] /= math.sqrt(2) 80 | X[...,ixs[1],ixs[0]] = X[...,ixs[0],ixs[1]] 81 | return X 82 | 83 | 84 | def spd_2point_interpolation(A : Tensor, B : Tensor, t : Number) -> Tensor: 85 | """ 86 | A with 1-t, B with t 87 | """ 88 | rm_sq, rm_invsq = sym_invsqrtm2.apply(A) 89 | return rm_sq @ sym_powm.apply(rm_invsq @ B @ rm_invsq, torch.tensor(t)) @ rm_sq 90 | 91 | 92 | class reverse_gradient(Function): 93 | """ 94 | Reversal of the gradient 95 | Parameters 96 | --------- 97 | scaling : Number 98 | A constant number that is multiplied to the sign-reversed gradients (1.0 default) 99 | """ 100 | @staticmethod 101 | def forward(ctx, x, scaling = 1.0): 102 | ctx.scaling = scaling 103 | return x.view_as(x) 104 | 105 | @staticmethod 106 | def backward(ctx, grad_output): 107 | grad_output = grad_output.neg() * ctx.scaling 108 | return grad_output, None 109 | 110 | 111 | class sym_modeig: 112 | """Basic class that modifies the eigenvalues with an arbitrary elementwise function 113 | """ 114 | 115 | @staticmethod 116 | def forward(M : Tensor, fun : Callable[[Tensor], Tensor], fun_param : Tensor = None, 117 | ensure_symmetric : bool = False, ensure_psd : bool = False, decom_mode='svd') -> Tensor: 118 | """Modifies the eigenvalues of a batch of symmetric matrices in the tensor M (last two dimensions). 119 | 120 | Source: Brooks et al. 2019, Riemannian batch normalization for SPD neural networks, NeurIPS 121 | 122 | Parameters 123 | ---------- 124 | M : torch.Tensor 125 | (batch) of symmetric matrices 126 | fun : Callable[[Tensor], Tensor] 127 | elementwise function 128 | ensure_symmetric : bool = False (optional) 129 | if ensure_symmetric=True, then M is symmetrized 130 | ensure_psd : bool = False (optional) 131 | if ensure_psd=True, then the eigenvalues are clamped so that they are > 0 132 | ------- 133 | Returns : torch.Tensor with modified eigenvalues 134 | """ 135 | if ensure_symmetric: 136 | M = ensure_sym(M) 137 | 138 | # compute the eigenvalues and vectors 139 | if decom_mode=='eigh': 140 | s, U = torch.linalg.eigh(M) 141 | elif decom_mode=='svd': 142 | U, s,_ = torch.linalg.svd(M) 143 | if ensure_psd: 144 | s = s.clamp(min=EPS[s.dtype]) 145 | 146 | # modify the eigenvalues 147 | smod = fun(s, fun_param) 148 | X = U @ torch.diag_embed(smod) @ U.transpose(-1,-2) 149 | 150 | return X, s, smod, U 151 | 152 | @staticmethod 153 | def backward(dX : Tensor, s : Tensor, smod : Tensor, U : Tensor, 154 | fun_der : Callable[[Tensor], Tensor], fun_der_param : Tensor = None) -> Tensor: 155 | """Backpropagates the derivatives 156 | 157 | Source: Brooks et al. 2019, Riemannian batch normalization for SPD neural networks, NeurIPS 158 | 159 | Parameters 160 | ---------- 161 | dX : torch.Tensor 162 | (batch) derivatives that should be backpropagated 163 | s : torch.Tensor 164 | eigenvalues of the original input 165 | smod : torch.Tensor 166 | modified eigenvalues 167 | U : torch.Tensor 168 | eigenvector of the input 169 | fun_der : Callable[[Tensor], Tensor] 170 | elementwise function derivative 171 | ------- 172 | Returns : torch.Tensor containing the backpropagated derivatives 173 | """ 174 | 175 | # compute Lowener matrix 176 | # denominator 177 | L_den = s[...,None] - s[...,None].transpose(-1,-2) 178 | # find cases (similar or different eigenvalues, via threshold) 179 | is_eq = L_den.abs() < EPS[s.dtype] 180 | L_den[is_eq] = 1.0 181 | # case: sigma_i != sigma_j 182 | L_num_ne = smod[...,None] - smod[...,None].transpose(-1,-2) 183 | L_num_ne[is_eq] = 0 184 | # case: sigma_i == sigma_j 185 | sder = fun_der(s, fun_der_param) 186 | L_num_eq = 0.5 * (sder[...,None] + sder[...,None].transpose(-1,-2)) 187 | L_num_eq[~is_eq] = 0 188 | # compose Loewner matrix 189 | L = (L_num_ne + L_num_eq) / L_den 190 | dM = U @ (L * (U.transpose(-1,-2) @ ensure_sym(dX) @ U)) @ U.transpose(-1,-2) 191 | return dM 192 | 193 | 194 | class sym_reeig(Function): 195 | """ 196 | Rectifies the eigenvalues of a batch of symmetric matrices in the tensor M (last two dimensions). 197 | """ 198 | @staticmethod 199 | def value(s : Tensor, threshold : Tensor) -> Tensor: 200 | return s.clamp(min=threshold.item()) 201 | 202 | @staticmethod 203 | def derivative(s : Tensor, threshold : Tensor) -> Tensor: 204 | return (s>threshold.item()).type(s.dtype) 205 | 206 | @staticmethod 207 | def forward(ctx: Any, M: Tensor, threshold : Tensor, ensure_symmetric : bool = False) -> Tensor: 208 | X, s, smod, U = sym_modeig.forward(M, sym_reeig.value, threshold, ensure_symmetric=ensure_symmetric) 209 | ctx.save_for_backward(s, smod, U, threshold) 210 | return X 211 | 212 | @staticmethod 213 | def backward(ctx: Any, dX: Tensor): 214 | s, smod, U, threshold = ctx.saved_tensors 215 | return sym_modeig.backward(dX, s, smod, U, sym_reeig.derivative, threshold), None, None 216 | 217 | @staticmethod 218 | def tests(): 219 | """ 220 | Basic unit tests and test to check gradients 221 | """ 222 | ndim = 2 223 | nb = 1 224 | # generate random base SPD matrix 225 | A = torch.randn((1,ndim,ndim), dtype=torch.double) 226 | U, s, _ = torch.linalg.svd(A) 227 | 228 | threshold = torch.tensor([1e-3], dtype=torch.double) 229 | 230 | # generate batches 231 | # linear case (all eigenvalues are above the threshold) 232 | s = threshold * 1e1 + torch.rand((nb,ndim), dtype=torch.double) * threshold 233 | M = U @ torch.diag_embed(s) @ U.transpose(-1,-2) 234 | 235 | assert (sym_reeig.apply(M, threshold, False).allclose(M)) 236 | M.requires_grad_(True) 237 | assert(gradcheck(sym_reeig.apply, (M, threshold, True))) 238 | 239 | # non-linear case (some eigenvalues are below the threshold) 240 | s = torch.rand((nb,ndim), dtype=torch.double) * threshold 241 | s[::2] += threshold 242 | M = U @ torch.diag_embed(s) @ U.transpose(-1,-2) 243 | assert (~sym_reeig.apply(M, threshold, False).allclose(M)) 244 | M.requires_grad_(True) 245 | assert(gradcheck(sym_reeig.apply, (M, threshold, True))) 246 | 247 | # linear case, all eigenvalues are identical 248 | s = torch.ones((nb,ndim), dtype=torch.double) 249 | M = U @ torch.diag_embed(s) @ U.transpose(-1,-2) 250 | assert (sym_reeig.apply(M, threshold, True).allclose(M)) 251 | M.requires_grad_(True) 252 | assert(gradcheck(sym_reeig.apply, (M, threshold, True))) 253 | 254 | 255 | class sym_abseig(Function): 256 | """ 257 | Computes the absolute values of all eigenvalues for a batch symmetric matrices. 258 | """ 259 | @staticmethod 260 | def value(s : Tensor, param:Tensor = None) -> Tensor: 261 | return s.abs() 262 | 263 | @staticmethod 264 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 265 | return s.sign() 266 | 267 | @staticmethod 268 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 269 | X, s, smod, U = sym_modeig.forward(M, sym_abseig.value, ensure_symmetric=ensure_symmetric) 270 | ctx.save_for_backward(s, smod, U) 271 | return X 272 | 273 | @staticmethod 274 | def backward(ctx: Any, dX: Tensor): 275 | s, smod, U = ctx.saved_tensors 276 | return sym_modeig.backward(dX, s, smod, U, sym_abseig.derivative), None 277 | 278 | 279 | class sym_logm(Function): 280 | """ 281 | Computes the matrix logarithm for a batch of SPD matrices. 282 | Ensures that the input matrices are SPD by clamping eigenvalues. 283 | During backprop, the update along the clamped eigenvalues is zeroed 284 | """ 285 | @staticmethod 286 | def value(s : Tensor, param:Tensor = None) -> Tensor: 287 | # ensure that the eigenvalues are positive 288 | return s.clamp(min=EPS[s.dtype]).log() 289 | 290 | @staticmethod 291 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 292 | # compute derivative 293 | sder = s.reciprocal() 294 | # pick subgradient 0 for clamped eigenvalues 295 | sder[s<=EPS[s.dtype]] = 0 296 | return sder 297 | 298 | @staticmethod 299 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 300 | X, s, smod, U = sym_modeig.forward(M, sym_logm.value, ensure_symmetric=ensure_symmetric) 301 | ctx.save_for_backward(s, smod, U) 302 | return X 303 | 304 | @staticmethod 305 | def backward(ctx: Any, dX: Tensor): 306 | s, smod, U = ctx.saved_tensors 307 | return sym_modeig.backward(dX, s, smod, U, sym_logm.derivative), None 308 | 309 | 310 | class sym_expm(Function): 311 | """ 312 | Computes the matrix exponential for a batch of symmetric matrices. 313 | """ 314 | @staticmethod 315 | def value(s : Tensor, param:Tensor = None) -> Tensor: 316 | return s.exp() 317 | 318 | @staticmethod 319 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 320 | return s.exp() 321 | 322 | @staticmethod 323 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 324 | X, s, smod, U = sym_modeig.forward(M, sym_expm.value, ensure_symmetric=ensure_symmetric,decom_mode='eigh') 325 | ctx.save_for_backward(s, smod, U) 326 | return X 327 | 328 | @staticmethod 329 | def backward(ctx: Any, dX: Tensor): 330 | s, smod, U = ctx.saved_tensors 331 | return sym_modeig.backward(dX, s, smod, U, sym_expm.derivative), None 332 | 333 | 334 | class sym_powm(Function): 335 | """ 336 | Computes the matrix power for a batch of symmetric matrices. 337 | """ 338 | @staticmethod 339 | def value(s : Tensor, exponent : Tensor) -> Tensor: 340 | return s.pow(exponent=exponent) 341 | 342 | @staticmethod 343 | def derivative(s : Tensor, exponent : Tensor) -> Tensor: 344 | return exponent * s.pow(exponent=exponent-1.) 345 | 346 | @staticmethod 347 | def forward(ctx: Any, M: Tensor, exponent : Tensor, ensure_symmetric : bool = False) -> Tensor: 348 | X, s, smod, U = sym_modeig.forward(M, sym_powm.value, exponent, ensure_symmetric=ensure_symmetric) 349 | ctx.save_for_backward(s, smod, U, exponent) 350 | return X 351 | 352 | @staticmethod 353 | def backward(ctx: Any, dX: Tensor): 354 | s, smod, U, exponent = ctx.saved_tensors 355 | dM = sym_modeig.backward(dX, s, smod, U, sym_powm.derivative, exponent) 356 | 357 | dXs = (U.transpose(-1,-2) @ ensure_sym(dX) @ U).diagonal(dim1=-1,dim2=-2) 358 | dexp = dXs * smod * s.log() 359 | 360 | return dM, dexp, None 361 | 362 | 363 | class sym_sqrtm(Function): 364 | """ 365 | Computes the matrix square root for a batch of SPD matrices. 366 | """ 367 | @staticmethod 368 | def value(s : Tensor, param:Tensor = None) -> Tensor: 369 | return s.clamp(min=EPS[s.dtype]).sqrt() 370 | 371 | @staticmethod 372 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 373 | sder = 0.5 * s.rsqrt() 374 | # pick subgradient 0 for clamped eigenvalues 375 | sder[s<=EPS[s.dtype]] = 0 376 | return sder 377 | 378 | @staticmethod 379 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 380 | X, s, smod, U = sym_modeig.forward(M, sym_sqrtm.value, ensure_symmetric=ensure_symmetric) 381 | ctx.save_for_backward(s, smod, U) 382 | return X 383 | 384 | @staticmethod 385 | def backward(ctx: Any, dX: Tensor): 386 | s, smod, U = ctx.saved_tensors 387 | return sym_modeig.backward(dX, s, smod, U, sym_sqrtm.derivative), None 388 | 389 | 390 | class sym_invsqrtm(Function): 391 | """ 392 | Computes the inverse matrix square root for a batch of SPD matrices. 393 | """ 394 | @staticmethod 395 | def value(s : Tensor, param:Tensor = None) -> Tensor: 396 | return s.clamp(min=EPS[s.dtype]).rsqrt() 397 | 398 | @staticmethod 399 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 400 | sder = -0.5 * s.pow(-1.5) 401 | # pick subgradient 0 for clamped eigenvalues 402 | sder[s<=EPS[s.dtype]] = 0 403 | return sder 404 | 405 | @staticmethod 406 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 407 | X, s, smod, U = sym_modeig.forward(M, sym_invsqrtm.value, ensure_symmetric=ensure_symmetric) 408 | ctx.save_for_backward(s, smod, U) 409 | return X 410 | 411 | @staticmethod 412 | def backward(ctx: Any, dX: Tensor): 413 | s, smod, U = ctx.saved_tensors 414 | return sym_modeig.backward(dX, s, smod, U, sym_invsqrtm.derivative), None 415 | 416 | 417 | class sym_invsqrtm2(Function): 418 | """ 419 | Computes the square root and inverse square root matrices for a batch of SPD matrices. 420 | """ 421 | 422 | @staticmethod 423 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 424 | Xsq, s, smod, U = sym_modeig.forward(M, sym_sqrtm.value, ensure_symmetric=ensure_symmetric) 425 | smod2 = sym_invsqrtm.value(s) 426 | Xinvsq = U @ torch.diag_embed(smod2) @ U.transpose(-1,-2) 427 | ctx.save_for_backward(s, smod, smod2, U) 428 | return Xsq, Xinvsq 429 | 430 | @staticmethod 431 | def backward(ctx: Any, dXsq: Tensor, dXinvsq: Tensor): 432 | s, smod, smod2, U = ctx.saved_tensors 433 | dMsq = sym_modeig.backward(dXsq, s, smod, U, sym_sqrtm.derivative) 434 | dMinvsq = sym_modeig.backward(dXinvsq, s, smod2, U, sym_invsqrtm.derivative) 435 | 436 | return dMsq + dMinvsq, None 437 | 438 | 439 | class sym_invm(Function): 440 | """ 441 | Computes the inverse matrices for a batch of SPD matrices. 442 | """ 443 | @staticmethod 444 | def value(s : Tensor, param:Tensor = None) -> Tensor: 445 | return s.clamp(min=EPS[s.dtype]).reciprocal() 446 | 447 | @staticmethod 448 | def derivative(s : Tensor, param:Tensor = None) -> Tensor: 449 | sder = -1. * s.pow(-2) 450 | # pick subgradient 0 for clamped eigenvalues 451 | sder[s<=EPS[s.dtype]] = 0 452 | return sder 453 | 454 | @staticmethod 455 | def forward(ctx: Any, M: Tensor, ensure_symmetric : bool = False) -> Tensor: 456 | X, s, smod, U = sym_modeig.forward(M, sym_invm.value, ensure_symmetric=ensure_symmetric) 457 | ctx.save_for_backward(s, smod, U) 458 | return X 459 | 460 | @staticmethod 461 | def backward(ctx: Any, dX: Tensor): 462 | s, smod, U = ctx.saved_tensors 463 | return sym_modeig.backward(dX, s, smod, U, sym_invm.derivative), None 464 | 465 | 466 | def spd_mean_kracher_flow(X : Tensor, G0 : Tensor = None, maxiter : int = 50, dim = 0, weights = None, return_dist = False, return_XT = False) -> Tensor: 467 | 468 | if X.shape[dim] == 1: 469 | if return_dist: 470 | return X, torch.tensor([0.0], dtype=X.dtype, device=X.device) 471 | else: 472 | return X 473 | 474 | if weights is None: 475 | n = X.shape[dim] 476 | weights = torch.ones((*X.shape[:-2], 1, 1), dtype=X.dtype, device=X.device) 477 | weights /= n 478 | 479 | if G0 is None: 480 | G = (X * weights).sum(dim=dim, keepdim=True) 481 | else: 482 | G = G0.clone() 483 | 484 | nu = 1. 485 | dist = tau = crit = torch.finfo(X.dtype).max 486 | i = 0 487 | 488 | while (crit > EPS[X.dtype]) and (i < maxiter) and (nu > EPS[X.dtype]): 489 | i += 1 490 | 491 | Gsq, Ginvsq = sym_invsqrtm2.apply(G) 492 | XT = sym_logm.apply(Ginvsq @ X @ Ginvsq) 493 | GT = (XT * weights).sum(dim=dim, keepdim=True) 494 | G = Gsq @ sym_expm.apply(nu * GT) @ Gsq 495 | 496 | if return_dist: 497 | dist = torch.norm(XT - GT, p='fro', dim=(-2,-1)) 498 | crit = torch.norm(GT, p='fro', dim=(-2,-1)).max() 499 | h = nu * crit 500 | if h < tau: 501 | nu = 0.95 * nu 502 | tau = h 503 | else: 504 | nu = 0.5 * nu 505 | 506 | if return_dist: 507 | return G, dist 508 | if return_XT: 509 | return G, XT 510 | return G 511 | -------------------------------------------------------------------------------- /spdnets/manifolds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Optional, Tuple 3 | from geoopt.manifolds import Manifold 4 | from . import functionals 5 | 6 | __all__ = ["SymmetricPositiveDefinite"] 7 | 8 | class SymmetricPositiveDefinite(Manifold): 9 | """ 10 | Subclass of the SymmetricPositiveDefinite manifold using the 11 | affine invariant Riemannian metric (AIRM) as default metric 12 | """ 13 | 14 | __scaling__ = Manifold.__scaling__.copy() 15 | name = "SymmetricPositiveDefinite" 16 | ndim = 2 17 | reversible = False 18 | 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def dist(self, x: torch.Tensor, y: torch.Tensor, keepdim) -> torch.Tensor: 23 | """ 24 | Computes the affine invariant Riemannian metric (AIM) 25 | """ 26 | inv_sqrt_x = functionals.sym_invsqrtm.apply(x) 27 | return torch.norm( 28 | functionals.sym_logm.apply(inv_sqrt_x @ y @ inv_sqrt_x), 29 | dim=[-1, -2], 30 | keepdim=keepdim, 31 | ) 32 | 33 | def _check_point_on_manifold( 34 | self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5 35 | ) -> Union[Tuple[bool, Optional[str]], bool]: 36 | ok = torch.allclose(x, x.transpose(-1, -2), atol=atol, rtol=rtol) 37 | if not ok: 38 | return False, "`x != x.transpose` with atol={}, rtol={}".format(atol, rtol) 39 | e = torch.linalg.eigvalsh(x) 40 | ok = (e > -atol).min() 41 | if not ok: 42 | return False, "eigenvalues of x are not all greater than 0." 43 | return True, None 44 | 45 | def _check_vector_on_tangent( 46 | self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5 47 | ) -> Union[Tuple[bool, Optional[str]], bool]: 48 | ok = torch.allclose(u, u.transpose(-1, -2), atol=atol, rtol=rtol) 49 | if not ok: 50 | return False, "`u != u.transpose` with atol={}, rtol={}".format(atol, rtol) 51 | return True, None 52 | 53 | def projx(self, x: torch.Tensor) -> torch.Tensor: 54 | symx = functionals.ensure_sym(x) 55 | return functionals.sym_abseig.apply(symx) 56 | 57 | def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 58 | return functionals.ensure_sym(u) 59 | 60 | def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 61 | return x @ self.proju(x, u) @ x 62 | 63 | def inner(self, x: torch.Tensor, u: torch.Tensor, v: Optional[torch.Tensor], keepdim) -> torch.Tensor: 64 | if v is None: 65 | v = u 66 | inv_x = functionals.sym_invm.apply(x) 67 | ret = torch.diagonal(inv_x @ u @ inv_x @ v, dim1=-2, dim2=-1).sum(-1) 68 | if keepdim: 69 | return torch.unsqueeze(torch.unsqueeze(ret, -1), -1) 70 | return ret 71 | 72 | def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 73 | inv_x = functionals.sym_invm.apply(x) 74 | return functionals.ensure_sym(x + u + 0.5 * u @ inv_x @ u) 75 | # return self.expmap(x, u) 76 | 77 | def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 78 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x) 79 | return sqrt_x @ functionals.sym_expm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x 80 | 81 | def logmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor: 82 | sqrt_x, inv_sqrt_x = functionals.sym_invsqrtm2.apply(x) 83 | return sqrt_x @ functionals.sym_logm.apply(inv_sqrt_x @ u @ inv_sqrt_x) @ sqrt_x 84 | 85 | def extra_repr(self) -> str: 86 | return "default_metric=AIM" 87 | 88 | def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 89 | 90 | xinvy = torch.linalg.solve(x.double(),y.double()) 91 | s, U = torch.linalg.eig(xinvy.transpose(-2,-1)) 92 | s = s.real 93 | U = U.real 94 | 95 | Ut = U.transpose(-2,-1) 96 | Esqm = torch.linalg.solve(Ut, torch.diag_embed(s.sqrt()) @ Ut).transpose(-2,-1).to(y.dtype) 97 | 98 | return Esqm @ v @ Esqm.transpose(-1,-2) 99 | 100 | def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor: 101 | tens = torch.randn(*size, dtype=dtype, device=device, **kwargs) 102 | tens = functionals.ensure_sym(tens) 103 | tens = functionals.sym_expm.apply(tens) 104 | return tens 105 | 106 | def barycenter(self, X : torch.Tensor, steps : int = 1, dim = 0) -> torch.Tensor: 107 | """ 108 | Compute several steps of the Kracher flow algorithm to estimate the 109 | Barycenter on the manifold. 110 | """ 111 | return functionals.spd_mean_kracher_flow(X, None, maxiter=steps, dim=dim, return_dist=False) 112 | 113 | def geodesic(self, A : torch.Tensor, B : torch.Tensor, t : torch.Tensor) -> torch.Tensor: 114 | """ 115 | Compute geodesic between two SPD tensors A and B and return 116 | point on the geodesic at length t \in [0,1] 117 | if t = 0, then A is returned 118 | if t = 1, then B is returned 119 | """ 120 | Asq, Ainvsq = functionals.sym_invsqrtm2.apply(A) 121 | return Asq @ functionals.sym_powm.apply(Ainvsq @ B @ Ainvsq, t) @ Asq 122 | 123 | def transp_via_identity(self, X : torch.Tensor, A : torch.Tensor, B : torch.Tensor) -> torch.Tensor: 124 | """ 125 | Parallel transport of the tensors in X around A to the identity matrix I 126 | Parallel transport from around the identity matrix to the new center (tensor B) 127 | """ 128 | Ainvsq = functionals.sym_invsqrtm.apply(A) 129 | Bsq = functionals.sym_sqrtm.apply(B) 130 | return Bsq @ (Ainvsq @ X @ Ainvsq) @ Bsq 131 | 132 | def transp_identity_rescale_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor) -> torch.Tensor: 133 | """ 134 | Parallel transport of the tensors in X around A to the identity matrix I 135 | Rescales the dispersion by the factor s 136 | Parallel transport from the identity to the new center (tensor B) 137 | """ 138 | Ainvsq = functionals.sym_invsqrtm.apply(A) 139 | Bsq = functionals.sym_sqrtm.apply(B) 140 | return Bsq @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ Bsq 141 | 142 | def transp_identity_rescale_rotate_transp(self, X : torch.Tensor, A : torch.Tensor, s : torch.Tensor, B : torch.Tensor, W : torch.Tensor) -> torch.Tensor: 143 | """ 144 | Parallel transport of the tensors in X around A to the identity matrix I 145 | Rescales the dispersion by the factor s 146 | Parallel transport from the identity to the new center (tensor B) 147 | """ 148 | Ainvsq = functionals.sym_invsqrtm.apply(A) 149 | Bsq = functionals.sym_sqrtm.apply(B) 150 | WBsq = W @ Bsq 151 | return WBsq.transpose(-2,-1) @ functionals.sym_powm.apply(Ainvsq @ X @ Ainvsq, s) @ WBsq 152 | -------------------------------------------------------------------------------- /spdnets/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseModel, FineTuneableModel, CPUModel, PatternInterpretableModel 2 | from .base import DomainAdaptBaseModel 3 | from .base import DomainAdaptFineTuneableModel, DomainAdaptJointTrainableModel 4 | 5 | from .eegnet import EEGNetv4, DANNEEGNet 6 | from .shconvnet import ShallowConvNet,DANNShallowConvNet,ShConvNetDSBN 7 | from .tsmnet import TSMNet, CNNNet 8 | from .tsmnetMLR import TSMNetMLR -------------------------------------------------------------------------------- /spdnets/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/__pycache__/base.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/base.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/__pycache__/dann.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/dann.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/__pycache__/eegnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/eegnet.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/__pycache__/shconvnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/shconvnet.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/__pycache__/tsmnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/models/__pycache__/tsmnet.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(nn.Module): 6 | def __init__(self, nclasses=None, nchannels=None, nsamples=None, nbands=None, device=None, input_shape=None): 7 | super().__init__() 8 | self.device_ = device 9 | self.lossfn = torch.nn.CrossEntropyLoss() 10 | self.nclasses_ = nclasses 11 | self.nchannels_ = nchannels 12 | self.nsamples_ = nsamples 13 | self.nbands_ = nbands 14 | self.input_shape_ = input_shape 15 | 16 | # AUXILIARY METHODS 17 | def calculate_classification_accuracy(self, Y, Y_lat): 18 | Y_hat = Y_lat.argmax(1) 19 | acc = Y_hat.eq(Y).float().mean().item() 20 | P_hat = torch.softmax(Y_lat, dim=1) 21 | return acc, P_hat 22 | 23 | def calculate_objective(self, model_pred, y_true, model_inp=None): 24 | # Y_lat, l = self(X.to(self.device_), B) 25 | if isinstance(model_pred, (list, tuple)): 26 | y_class_hat = model_pred[0] 27 | else: 28 | y_class_hat = model_pred 29 | loss = self.lossfn(y_class_hat, y_true.to(y_class_hat.device)) 30 | return loss 31 | 32 | def get_hyperparameters(self): 33 | return dict(nchannels = self.nchannels_, 34 | nclasses=self.nclasses_, 35 | nsamples=self.nsamples_, 36 | nbands=self.nbands_) 37 | 38 | 39 | class CPUModel: 40 | pass 41 | 42 | 43 | class FineTuneableModel: 44 | def finetune(self, x, y, d): 45 | raise NotImplementedError() 46 | 47 | 48 | class DomainAdaptBaseModel(BaseModel): 49 | def __init__(self, domains = [], **kwargs): 50 | super().__init__(**kwargs) 51 | self.domains_ = domains 52 | 53 | 54 | class DomainAdaptFineTuneableModel(DomainAdaptBaseModel): 55 | def domainadapt_finetune(self, x, y, d, target_domains): 56 | raise NotImplementedError() 57 | 58 | 59 | class DomainAdaptJointTrainableModel(DomainAdaptBaseModel): 60 | def calculate_objective(self, model_pred, y_true, model_inp=None): 61 | # filter out masked observations 62 | keep = y_true != -1 # special label 63 | 64 | if isinstance(model_pred, (list, tuple)): 65 | y_class_hat = model_pred[0] 66 | else: 67 | y_class_hat = model_pred 68 | 69 | return super().calculate_objective(y_class_hat[keep], y_true[keep], None) 70 | 71 | 72 | class PatternInterpretableModel: 73 | def compute_patterns(self, x, y, d): 74 | raise NotImplementedError() 75 | -------------------------------------------------------------------------------- /spdnets/models/dann.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import DomainAdaptJointTrainableModel 3 | import spdnets.modules as modules 4 | 5 | class DANNBase(DomainAdaptJointTrainableModel): 6 | """ 7 | Domain adeversarial neural network (DANN) proposed 8 | by Ganin et al. 2016, JMLR 9 | """ 10 | def __init__(self, daloss_scaling = 1., dann_mode = 'ganin2016', **kwargs): 11 | domains = kwargs['domains'] 12 | assert (domains.dtype == torch.long) 13 | kwargs['domains'] = domains.sort()[0] 14 | super().__init__(**kwargs) 15 | self.dann_mode_ = dann_mode 16 | 17 | if self.dann_mode_ == 'ganin2015': 18 | grad_reversal_scaling = daloss_scaling 19 | self.daloss_scaling_ = 1. 20 | elif self.dann_mode_ == 'ganin2016': 21 | grad_reversal_scaling = 1. 22 | self.daloss_scaling_ = daloss_scaling 23 | else: 24 | raise NotImplementedError() 25 | 26 | ndim_latent = self._ndim_latent() 27 | self.adversary_loss = torch.nn.CrossEntropyLoss() 28 | 29 | self.adversary = torch.nn.Sequential( 30 | torch.nn.Flatten(start_dim=1), 31 | modules.ReverseGradient(scaling=grad_reversal_scaling), 32 | torch.nn.Linear(ndim_latent, len(self.domains_)) 33 | ).to(self.device_) 34 | 35 | def _ndim_latent(self): 36 | raise NotImplementedError() 37 | 38 | def forward(self, l, d): 39 | # super().forward() 40 | # h = self.cnn(x[:,None,...]).flatten(start_dim=1) 41 | # y = self.classifier(h) 42 | y_domain = self.adversary(l) 43 | return y_domain 44 | 45 | def domainadapt(self, x, y, d, target_domain): 46 | pass # domain adaptation is done during the training process 47 | 48 | def calculate_objective(self, model_pred, y_true, model_inp): 49 | loss = super().calculate_objective(model_pred, y_true, model_inp) 50 | domain = model_inp['d'] 51 | y_dom_hat = model_pred[1] 52 | # check if all requested domains were declared 53 | assert ((self.domains_[..., None] == domain[None,...]).any(dim=0).all()) 54 | # assign to the class indices (buckets) 55 | y_dom = torch.bucketize(domain, self.domains_).to(y_dom_hat.device) 56 | 57 | adversarial_loss = self.adversary_loss(y_dom_hat, y_dom) 58 | loss = loss + self.daloss_scaling_ * adversarial_loss 59 | 60 | return loss -------------------------------------------------------------------------------- /spdnets/models/eegnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import BaseModel 3 | from .dann import DANNBase 4 | import spdnets.modules as modules 5 | 6 | 7 | class EEGNetv4(BaseModel): 8 | def __init__(self, is_within = False, srate = 128, f1 = 8, d = 2, **kwargs): 9 | super().__init__(**kwargs) 10 | self.is_within_ = is_within 11 | self.srate_ = srate 12 | self.f1_ = f1 13 | self.d_ = d 14 | self.f2_ = self.f1_ * self.d_ 15 | momentum = 0.01 16 | 17 | kernel_length = int(self.srate_ // 2) 18 | nlatsamples_time = self.nsamples_ // 32 19 | 20 | temp2_kernel_length = int(self.srate_ // 2 // 4) 21 | 22 | if self.is_within_: 23 | drop_prob = 0.5 24 | else: 25 | drop_prob = 0.25 26 | 27 | bntemp = torch.nn.BatchNorm2d(self.f1_, momentum=momentum, affine=True, eps=1e-3) 28 | bnspat = torch.nn.BatchNorm2d(self.f1_ * self.d_, momentum=momentum, affine=True, eps=1e-3) 29 | 30 | self.cnn = torch.nn.Sequential( 31 | torch.nn.Conv2d(1,self.f1_,(1, kernel_length), bias=False, padding='same'), 32 | bntemp, 33 | modules.Conv2dWithNormConstraint(self.f1_, self.f1_ * self.d_, (self.nchannels_, 1), max_norm=1, 34 | stride=1, bias=False, groups=self.f1_, padding=(0, 0)), 35 | bnspat, 36 | torch.nn.ELU(), 37 | torch.nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)), 38 | torch.nn.Dropout(p=drop_prob), 39 | torch.nn.Conv2d(self.f1_ * self.d_, self.f1_ * self.d_, (1, temp2_kernel_length), 40 | stride=1, bias=False, groups=self.f1_ * self.d_, padding='same'), 41 | torch.nn.Conv2d(self.f1_ * self.d_, self.f2_, (1, 1), 42 | stride=1, bias=False, padding=(0, 0)), 43 | torch.nn.BatchNorm2d(self.f2_, momentum=momentum, affine=True, eps=1e-3), 44 | torch.nn.ELU(), 45 | torch.nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)), 46 | torch.nn.Dropout(p=drop_prob), 47 | ).to(self.device_) 48 | 49 | self.classifier = torch.nn.Sequential( 50 | torch.nn.Flatten(start_dim=1), 51 | modules.LinearWithNormConstraint(self.f2_ * nlatsamples_time, self.nclasses_, max_norm=0.25) 52 | ).to(self.device_) 53 | 54 | def get_hyperparameters(self): 55 | kwargs = super().get_hyperparameters() 56 | kwargs['nsamples'] = self.nsamples_ 57 | kwargs['is_within_subject'] = self.is_within_subject_ 58 | kwargs['srate'] = self.srate_ 59 | kwargs['f1'] = self.f1_ 60 | kwargs['d'] = self.d_ 61 | return kwargs 62 | 63 | def forward(self, x, d): 64 | l = self.cnn(x[:,None,...]) 65 | y = self.classifier(l) 66 | return y, l 67 | 68 | 69 | class DANNEEGNet(DANNBase, EEGNetv4): 70 | """ 71 | Domain adeversarial neural network (DANN) proposed for EEG MI classification 72 | by Ozdenizci et al. 2020, IEEE Access 73 | """ 74 | def __init__(self, daloss_scaling = 0.03, dann_mode = 'ganin2016', **kwargs): 75 | kwargs['daloss_scaling'] = daloss_scaling 76 | kwargs['dann_mode'] = dann_mode 77 | super().__init__(**kwargs) 78 | 79 | def _ndim_latent(self): 80 | return self.classifier[-1].weight.shape[-1] 81 | 82 | def forward(self, x, d): 83 | y, l = EEGNetv4.forward(self, x, d) 84 | y_domain = DANNBase.forward(self, l, d) 85 | return y, y_domain -------------------------------------------------------------------------------- /spdnets/models/shconvnet.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | from .base import BaseModel, DomainAdaptFineTuneableModel 4 | from .dann import DANNBase 5 | import spdnets.batchnorm as bn 6 | import spdnets.modules as modules 7 | 8 | 9 | class ShallowConvNet(BaseModel): 10 | def __init__(self, spatial_filters = 40, temporal_filters = 40, pdrop = 0.5, **kwargs): 11 | super().__init__(**kwargs) 12 | self.spatial_filters_ = spatial_filters 13 | self.temporal_filters_ = temporal_filters 14 | 15 | temp_cnn_kernel = 25 16 | temp_pool_kernel = 75 17 | temp_pool_stride = 15 18 | ntempconvout = int((self.nsamples_ - 1*(temp_cnn_kernel-1) - 1)/1 + 1) 19 | navgpoolout = int((ntempconvout - temp_pool_kernel)/temp_pool_stride + 1) 20 | 21 | self.bn = torch.nn.BatchNorm2d(self.spatial_filters_) 22 | drop = torch.nn.Dropout(p=pdrop) 23 | 24 | self.cnn = torch.nn.Sequential( 25 | torch.nn.Conv2d(1,self.temporal_filters_,(1, temp_cnn_kernel)), 26 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 27 | ).to(self.device_) 28 | self.pool = torch.nn.Sequential( 29 | modules.MySquare(), 30 | torch.nn.AvgPool2d(kernel_size=(1, temp_pool_kernel), stride=(1, temp_pool_stride)), 31 | modules.MyLog(), 32 | drop, 33 | torch.nn.Flatten(start_dim=1), 34 | ).to(self.device_) 35 | self.classifier = torch.nn.Sequential( 36 | torch.nn.Linear(self.spatial_filters_ * navgpoolout, self.nclasses_), 37 | ).to(self.device_) 38 | 39 | def forward(self,x, d): 40 | l = self.cnn(x.to(self.device_)[:,None,...]) 41 | l = self.bn(l) 42 | l = self.pool(l) 43 | y = self.classifier(l) 44 | return y, l 45 | 46 | 47 | class DANNShallowConvNet(DANNBase, ShallowConvNet): 48 | """ 49 | Domain adeversarial neural network (DANN) proposed for EEG MI classification 50 | by Ozdenizci et al. 2020, IEEE Access 51 | """ 52 | def __init__(self, daloss_scaling = 0.05, dann_mode = 'ganin2016', **kwargs): 53 | kwargs['daloss_scaling'] = daloss_scaling 54 | kwargs['dann_mode'] = dann_mode 55 | super().__init__(**kwargs) 56 | 57 | def _ndim_latent(self): 58 | return self.classifier[-1].weight.shape[-1] 59 | 60 | def forward(self, x, d): 61 | y, l = ShallowConvNet.forward(self, x, d) 62 | y_domain = DANNBase.forward(self, l, d) 63 | return y, y_domain 64 | 65 | 66 | class ShConvNetDSBN(ShallowConvNet, DomainAdaptFineTuneableModel): 67 | def __init__(self, 68 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.VECTOR, 69 | **kwargs): 70 | super().__init__(**kwargs) 71 | 72 | if isinstance(bnorm_dispersion, str): 73 | bnorm_dispersion = bn.BatchNormDispersion[bnorm_dispersion] 74 | 75 | self.bn = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_, 1, 1), 76 | batchdim=[0,2,3], # same as batch norm 2D 77 | domains=self.domains_, 78 | dispersion=bnorm_dispersion, 79 | eta=1., eta_test=.1).to(self.device_) 80 | 81 | def forward(self,x, d): 82 | l = self.cnn(x.to(self.device_)[:,None,...]) 83 | l = self.bn(l,d.to(device=self.device_)) 84 | l = self.pool(l) 85 | y = self.classifier(l) 86 | return y, l 87 | 88 | def domainadapt_finetune(self, x, y, d, target_domains): 89 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 90 | for du in d.unique(): 91 | self.forward(x[d==du], d[d==du]) 92 | self.bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 93 | -------------------------------------------------------------------------------- /spdnets/models/spdnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union 3 | import torch 4 | import torch.nn as nn 5 | 6 | import spdnets.cplx.nn as nn_cplx 7 | import spdnets.modules as modules 8 | from spdnets.SPDMLR import SPDRMLR 9 | 10 | 11 | class SPDNet(nn.Module): 12 | def __init__(self,args): 13 | super(__class__, self).__init__() 14 | dims = [int(dim) for dim in args.architecture] 15 | self.feature = [] 16 | if args.dataset == 'RADAR': 17 | self.feature.append(nn_cplx.SplitSignal_cplx(2, 20, 10)) 18 | self.feature.append(nn_cplx.CovPool_cplx()) 19 | self.feature.append(modules.ReEig()) 20 | 21 | for i in range(len(dims) - 2): 22 | shape=[dims[i], dims[i + 1]] 23 | self.feature.append(modules.BiMap(shape,init_mode=args.init_mode,manifold=args.bimap_manifold)) 24 | self.feature.append(modules.ReEig()) 25 | 26 | self.feature.append(modules.BiMap([dims[-2], dims[-1]],init_mode=args.init_mode,manifold=args.bimap_manifold)) 27 | self.feature = nn.Sequential(*self.feature) 28 | 29 | self.construct_classifier(args.classifier,dims[-1],args.class_num,args.metric,args.power,args.alpha,args.beta) 30 | 31 | def forward(self, x): 32 | x_spd = self.feature(x) 33 | y = self.classifier(x_spd) 34 | return y 35 | 36 | def construct_classifier(self,classifier,subspacedims,nclasses_,metric,power,alpha,beta): 37 | if classifier=='SPDMLR': 38 | self.classifier = torch.nn.Sequential( 39 | SPDRMLR(n=subspacedims,c=nclasses_,metric=metric,power=power,alpha=alpha,beta=beta) 40 | ) 41 | elif classifier=='LogEigMLR': 42 | """Following SPDNet and SPDNetBN, we use the full matrices""" 43 | tsdim = int( subspacedims ** 2 ) 44 | self.classifier = torch.nn.Sequential( 45 | modules.LogEig(subspacedims), 46 | torch.nn.Linear(tsdim, nclasses_), 47 | ) 48 | else: 49 | raise Exception(f'wrong clssifier {classifier}') 50 | -------------------------------------------------------------------------------- /spdnets/models/tsmnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union 3 | import torch 4 | 5 | import spdnets.modules as modules 6 | import spdnets.batchnorm as bn 7 | from .base import DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel 8 | 9 | 10 | class TSMNet(DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel): 11 | def __init__(self, temporal_filters, spatial_filters = 40, 12 | subspacedims = 20, 13 | temp_cnn_kernel = 25, 14 | bnorm : Optional[str] = 'spdbn', 15 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.SCALAR, 16 | **kwargs): 17 | super().__init__(**kwargs) 18 | 19 | self.temporal_filters_ = temporal_filters 20 | self.spatial_filters_ = spatial_filters 21 | self.subspacedimes = subspacedims 22 | self.bnorm_ = bnorm 23 | self.spd_device_ = torch.device('cpu') 24 | if isinstance(bnorm_dispersion, str): 25 | self.bnorm_dispersion_ = bn.BatchNormDispersion[bnorm_dispersion] 26 | else: 27 | self.bnorm_dispersion_ = bnorm_dispersion 28 | 29 | tsdim = int(subspacedims*(subspacedims+1)/2) 30 | 31 | self.cnn = torch.nn.Sequential( 32 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 33 | padding='same', padding_mode='reflect'), 34 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 35 | torch.nn.Flatten(start_dim=2), 36 | ).to(self.device_) 37 | 38 | self.cov_pooling = torch.nn.Sequential( 39 | modules.CovariancePool(), 40 | ) 41 | 42 | if self.bnorm_ == 'spdbn': 43 | self.spdbnorm = bn.AdaMomSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0, 44 | dispersion=self.bnorm_dispersion_, 45 | learn_mean=False,learn_std=True, 46 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_) 47 | elif self.bnorm_ == 'brooks': 48 | self.spdbnorm = modules.BatchNormSPDBrooks((1,subspacedims,subspacedims), batchdim=0, dtype=torch.double, device=self.spd_device_) 49 | elif self.bnorm_ == 'tsbn': 50 | self.tsbnorm = bn.AdaMomBatchNorm((1, tsdim), batchdim=0, dispersion=self.bnorm_dispersion_, 51 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_).to(self.device_) 52 | elif self.bnorm_ == 'spddsbn': 53 | self.spddsbnorm = bn.AdaMomDomainSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0, 54 | domains=self.domains_, 55 | learn_mean=False,learn_std=True, 56 | dispersion=self.bnorm_dispersion_, 57 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_) 58 | elif self.bnorm_ == 'tsdsbn': 59 | self.tsdsbnorm = bn.AdaMomDomainBatchNorm((1, tsdim), batchdim=0, 60 | domains=self.domains_, 61 | dispersion=self.bnorm_dispersion_, 62 | eta=1., eta_test=.1, dtype=torch.double).to(self.device_) 63 | elif self.bnorm_ is not None: 64 | raise NotImplementedError('requested undefined batch normalization method.') 65 | 66 | self.spdnet = torch.nn.Sequential( 67 | modules.BiMap((1,self.spatial_filters_,subspacedims), dtype=torch.double, device=self.spd_device_), 68 | modules.ReEig(threshold=1e-4), 69 | ) 70 | self.logeig = torch.nn.Sequential( 71 | modules.LogEig(subspacedims), 72 | torch.nn.Flatten(start_dim=1), 73 | ) 74 | self.classifier = torch.nn.Sequential( 75 | torch.nn.Linear(tsdim,self.nclasses_).double(), 76 | ).to(self.spd_device_) 77 | 78 | def to(self, device: Optional[Union[int, torch.device]] = None, dtype: Optional[Union[int, torch.dtype]] = None, non_blocking: bool = False): 79 | if device is not None: 80 | self.device_ = device 81 | self.cnn.to(self.device_) 82 | return super().to(device=None, dtype=dtype, non_blocking=non_blocking) 83 | 84 | def forward(self, x, d, return_latent=True, return_prebn=False, return_postbn=False): 85 | out = () 86 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 87 | C = self.cov_pooling(h).to(device=self.spd_device_, dtype=torch.double) 88 | l = self.spdnet(C) 89 | out += (l,) if return_prebn else () 90 | l = self.spdbnorm(l) if hasattr(self, 'spdbnorm') else l 91 | l = self.spddsbnorm(l,d.to(device=self.spd_device_)) if hasattr(self, 'spddsbnorm') else l 92 | out += (l,) if return_postbn else () 93 | l = self.logeig(l) 94 | l = self.tsbnorm(l) if hasattr(self, 'tsbnorm') else l 95 | l = self.tsdsbnorm(l,d) if hasattr(self, 'tsdsbnorm') else l 96 | out += (l,) if return_latent else () 97 | y = self.classifier(l) 98 | out = y if len(out) == 0 else (y, *out[::-1]) 99 | return out 100 | 101 | def domainadapt_finetune(self, x, y, d, target_domains): 102 | if hasattr(self, 'spddsbnorm'): 103 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 104 | if hasattr(self, 'tsdsbnorm'): 105 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 106 | 107 | with torch.no_grad(): 108 | for du in d.unique(): 109 | self.forward(x[d==du], d[d==du]) 110 | 111 | if hasattr(self, 'spddsbnorm'): 112 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 113 | if hasattr(self, 'tsdsbnorm'): 114 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 115 | 116 | def finetune(self, x, y, d): 117 | if hasattr(self, 'spdbnorm'): 118 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 119 | if hasattr(self, 'tsbnorm'): 120 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 121 | 122 | with torch.no_grad(): 123 | self.forward(x, d) 124 | 125 | if hasattr(self, 'spdbnorm'): 126 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 127 | if hasattr(self, 'tsbnorm'): 128 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 129 | 130 | def compute_patterns(self, x, y, d): 131 | pass 132 | 133 | 134 | 135 | class CNNNet(DomainAdaptFineTuneableModel, FineTuneableModel): 136 | def __init__(self, temporal_filters, spatial_filters = 40, 137 | temp_cnn_kernel = 25, 138 | bnorm : Optional[str] = 'bn', 139 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.SCALAR, 140 | **kwargs): 141 | super().__init__(**kwargs) 142 | 143 | self.temporal_filters_ = temporal_filters 144 | self.spatial_filters_ = spatial_filters 145 | self.bnorm_ = bnorm 146 | 147 | if isinstance(bnorm_dispersion, str): 148 | self.bnorm_dispersion_ = bn.BatchNormDispersion[bnorm_dispersion] 149 | else: 150 | self.bnorm_dispersion_ = bnorm_dispersion 151 | 152 | self.cnn = torch.nn.Sequential( 153 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 154 | padding='same', padding_mode='reflect'), 155 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 156 | torch.nn.Flatten(start_dim=2), 157 | ).to(self.device_) 158 | 159 | self.cov_pooling = torch.nn.Sequential( 160 | modules.CovariancePool(), 161 | ) 162 | 163 | if self.bnorm_ == 'bn': 164 | self.bnorm = bn.AdaMomBatchNorm((1, self.spatial_filters_), batchdim=0, dispersion=self.bnorm_dispersion_, 165 | eta=1., eta_test=.1).to(self.device_) 166 | elif self.bnorm_ == 'dsbn': 167 | self.dsbnorm = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_), batchdim=0, 168 | domains=self.domains_, 169 | dispersion=self.bnorm_dispersion_, 170 | eta=1., eta_test=.1).to(self.device_) 171 | elif self.bnorm_ is not None: 172 | raise NotImplementedError('requested undefined batch normalization method.') 173 | 174 | self.logarithm = torch.nn.Sequential( 175 | modules.MyLog(), 176 | torch.nn.Flatten(start_dim=1), 177 | ) 178 | self.classifier = torch.nn.Sequential( 179 | torch.nn.Linear(self.spatial_filters_,self.nclasses_), 180 | ).to(self.device_) 181 | 182 | def forward(self, x, d, return_latent=True): 183 | out = () 184 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 185 | C = self.cov_pooling(h) 186 | l = torch.diagonal(C, dim1=-2, dim2=-1) 187 | l = self.logarithm(l) 188 | l = self.bnorm(l) if hasattr(self, 'bnorm') else l 189 | l = self.dsbnorm(l,d) if hasattr(self, 'dsbnorm') else l 190 | out += (l,) if return_latent else () 191 | y = self.classifier(l) 192 | out = y if len(out) == 0 else (y, *out[::-1]) 193 | return out 194 | 195 | def domainadapt_finetune(self, x, y, d, target_domains): 196 | if hasattr(self, 'dsbnorm'): 197 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 198 | 199 | with torch.no_grad(): 200 | for du in d.unique(): 201 | self.forward(x[d==du], d[d==du]) 202 | 203 | if hasattr(self, 'dsbnorm'): 204 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 205 | 206 | def finetune(self, x, y, d): 207 | if hasattr(self, 'bnorm'): 208 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 209 | 210 | with torch.no_grad(): 211 | self.forward(x, d) 212 | 213 | if hasattr(self, 'bnorm'): 214 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 215 | -------------------------------------------------------------------------------- /spdnets/models/tsmnetMLR.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union 3 | import torch 4 | 5 | import spdnets.modules as modules 6 | import spdnets.batchnorm as bn 7 | from spdnets.SPDMLR import SPDRMLR 8 | from .base import DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel 9 | 10 | 11 | class TSMNetMLR(DomainAdaptFineTuneableModel, FineTuneableModel, PatternInterpretableModel): 12 | def __init__(self, temporal_filters, spatial_filters = 40, 13 | subspacedims = 20, 14 | temp_cnn_kernel = 25, 15 | bnorm : Optional[str] = 'spdbn', 16 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.SCALAR, 17 | classifier='LogEigMLR', 18 | metric='SPDEuclideanMetric',power=1.0,alpha=1.0,beta=0.,**kwargs): 19 | super().__init__(**kwargs) 20 | 21 | self.temporal_filters_ = temporal_filters 22 | self.spatial_filters_ = spatial_filters 23 | self.subspacedimes = subspacedims 24 | self.bnorm_ = bnorm 25 | self.spd_device_ = torch.device('cpu') 26 | if isinstance(bnorm_dispersion, str): 27 | self.bnorm_dispersion_ = bn.BatchNormDispersion[bnorm_dispersion] 28 | else: 29 | self.bnorm_dispersion_ = bnorm_dispersion 30 | 31 | tsdim = int(subspacedims*(subspacedims+1)/2) 32 | 33 | self.cnn = torch.nn.Sequential( 34 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 35 | padding='same', padding_mode='reflect'), 36 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 37 | torch.nn.Flatten(start_dim=2), 38 | ).to(self.device_) 39 | 40 | self.cov_pooling = torch.nn.Sequential( 41 | modules.CovariancePool(), 42 | ) 43 | 44 | if self.bnorm_ == 'spdbn': 45 | self.spdbnorm = bn.AdaMomSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0, 46 | dispersion=self.bnorm_dispersion_, 47 | learn_mean=False,learn_std=True, 48 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_) 49 | elif self.bnorm_ == 'brooks': 50 | self.spdbnorm = modules.BatchNormSPDBrooks((1,subspacedims,subspacedims), batchdim=0, dtype=torch.double, device=self.spd_device_) 51 | elif self.bnorm_ == 'tsbn': 52 | self.tsbnorm = bn.AdaMomBatchNorm((1, tsdim), batchdim=0, dispersion=self.bnorm_dispersion_, 53 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_).to(self.device_) 54 | elif self.bnorm_ == 'spddsbn': 55 | self.spddsbnorm = bn.AdaMomDomainSPDBatchNorm((1,subspacedims,subspacedims), batchdim=0, 56 | domains=self.domains_, 57 | learn_mean=False,learn_std=True, 58 | dispersion=self.bnorm_dispersion_, 59 | eta=1., eta_test=.1, dtype=torch.double, device=self.spd_device_) 60 | elif self.bnorm_ == 'tsdsbn': 61 | self.tsdsbnorm = bn.AdaMomDomainBatchNorm((1, tsdim), batchdim=0, 62 | domains=self.domains_, 63 | dispersion=self.bnorm_dispersion_, 64 | eta=1., eta_test=.1, dtype=torch.double).to(self.device_) 65 | elif self.bnorm_ is not None: 66 | raise NotImplementedError('requested undefined batch normalization method.') 67 | 68 | self.spdnet = torch.nn.Sequential( 69 | modules.BiMap((1,self.spatial_filters_,subspacedims), dtype=torch.double, device=self.spd_device_), 70 | modules.ReEig(threshold=1e-4), 71 | ) 72 | self.construct_classifier(classifier,subspacedims,self.nclasses_,metric,power,alpha,beta) 73 | 74 | 75 | def to(self, device: Optional[Union[int, torch.device]] = None, dtype: Optional[Union[int, torch.dtype]] = None, non_blocking: bool = False): 76 | if device is not None: 77 | self.device_ = device 78 | self.cnn.to(self.device_) 79 | return super().to(device=None, dtype=dtype, non_blocking=non_blocking) 80 | 81 | def forward(self, x, d, return_latent=True, return_prebn=False, return_postbn=False): 82 | out = () 83 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 84 | C = self.cov_pooling(h).to(device=self.spd_device_, dtype=torch.double) 85 | l = self.spdnet(C) 86 | out += (l,) if return_prebn else () 87 | l = self.spdbnorm(l) if hasattr(self, 'spdbnorm') else l 88 | l = self.spddsbnorm(l,d.to(device=self.spd_device_)) if hasattr(self, 'spddsbnorm') else l 89 | out += (l,) if return_postbn else () 90 | out += (l,) if return_latent else () 91 | y = self.classifier(l) 92 | out = y if len(out) == 0 else (y, *out[::-1]) 93 | return out 94 | 95 | def domainadapt_finetune(self, x, y, d, target_domains): 96 | if hasattr(self, 'spddsbnorm'): 97 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 98 | if hasattr(self, 'tsdsbnorm'): 99 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 100 | 101 | with torch.no_grad(): 102 | for du in d.unique(): 103 | self.forward(x[d==du], d[d==du]) 104 | 105 | if hasattr(self, 'spddsbnorm'): 106 | self.spddsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 107 | if hasattr(self, 'tsdsbnorm'): 108 | self.tsdsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 109 | 110 | def finetune(self, x, y, d): 111 | if hasattr(self, 'spdbnorm'): 112 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 113 | if hasattr(self, 'tsbnorm'): 114 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 115 | 116 | with torch.no_grad(): 117 | self.forward(x, d) 118 | 119 | if hasattr(self, 'spdbnorm'): 120 | self.spdbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 121 | if hasattr(self, 'tsbnorm'): 122 | self.tsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 123 | 124 | def compute_patterns(self, x, y, d): 125 | pass 126 | 127 | def construct_classifier(self,classifier,subspacedims,nclasses_,metric,power,alpha,beta): 128 | if classifier=='SPDMLR': 129 | self.classifier = torch.nn.Sequential( 130 | modules.UnsqueezeLayer(1), 131 | SPDRMLR(n=subspacedims,c=nclasses_,metric=metric,power=power,alpha=alpha,beta=beta).to(self.spd_device_).double() 132 | ) 133 | elif classifier=='LogEigMLR': 134 | tsdim = int(subspacedims * (subspacedims + 1) / 2) 135 | self.classifier = torch.nn.Sequential( 136 | modules.LogEig(subspacedims,tril=True), 137 | torch.nn.Linear(tsdim, self.nclasses_).double(), 138 | ).to(self.spd_device_) 139 | else: 140 | raise Exception(f'wrong clssifier {classifier}') 141 | 142 | 143 | 144 | class CNNNet(DomainAdaptFineTuneableModel, FineTuneableModel): 145 | def __init__(self, temporal_filters, spatial_filters = 40, 146 | temp_cnn_kernel = 25, 147 | bnorm : Optional[str] = 'bn', 148 | bnorm_dispersion : Union[str, bn.BatchNormDispersion] = bn.BatchNormDispersion.SCALAR, 149 | **kwargs): 150 | super().__init__(**kwargs) 151 | 152 | self.temporal_filters_ = temporal_filters 153 | self.spatial_filters_ = spatial_filters 154 | self.bnorm_ = bnorm 155 | 156 | if isinstance(bnorm_dispersion, str): 157 | self.bnorm_dispersion_ = bn.BatchNormDispersion[bnorm_dispersion] 158 | else: 159 | self.bnorm_dispersion_ = bnorm_dispersion 160 | 161 | self.cnn = torch.nn.Sequential( 162 | torch.nn.Conv2d(1, self.temporal_filters_, kernel_size=(1,temp_cnn_kernel), 163 | padding='same', padding_mode='reflect'), 164 | torch.nn.Conv2d(self.temporal_filters_, self.spatial_filters_,(self.nchannels_, 1)), 165 | torch.nn.Flatten(start_dim=2), 166 | ).to(self.device_) 167 | 168 | self.cov_pooling = torch.nn.Sequential( 169 | modules.CovariancePool(), 170 | ) 171 | 172 | if self.bnorm_ == 'bn': 173 | self.bnorm = bn.AdaMomBatchNorm((1, self.spatial_filters_), batchdim=0, dispersion=self.bnorm_dispersion_, 174 | eta=1., eta_test=.1).to(self.device_) 175 | elif self.bnorm_ == 'dsbn': 176 | self.dsbnorm = bn.AdaMomDomainBatchNorm((1, self.spatial_filters_), batchdim=0, 177 | domains=self.domains_, 178 | dispersion=self.bnorm_dispersion_, 179 | eta=1., eta_test=.1).to(self.device_) 180 | elif self.bnorm_ is not None: 181 | raise NotImplementedError('requested undefined batch normalization method.') 182 | 183 | self.logarithm = torch.nn.Sequential( 184 | modules.MyLog(), 185 | torch.nn.Flatten(start_dim=1), 186 | ) 187 | self.classifier = torch.nn.Sequential( 188 | torch.nn.Linear(self.spatial_filters_,self.nclasses_), 189 | ).to(self.device_) 190 | 191 | def forward(self, x, d, return_latent=True): 192 | out = () 193 | h = self.cnn(x.to(device=self.device_)[:,None,...]) 194 | C = self.cov_pooling(h) 195 | l = torch.diagonal(C, dim1=-2, dim2=-1) 196 | l = self.logarithm(l) 197 | l = self.bnorm(l) if hasattr(self, 'bnorm') else l 198 | l = self.dsbnorm(l,d) if hasattr(self, 'dsbnorm') else l 199 | out += (l,) if return_latent else () 200 | y = self.classifier(l) 201 | out = y if len(out) == 0 else (y, *out[::-1]) 202 | return out 203 | 204 | def domainadapt_finetune(self, x, y, d, target_domains): 205 | if hasattr(self, 'dsbnorm'): 206 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 207 | 208 | with torch.no_grad(): 209 | for du in d.unique(): 210 | self.forward(x[d==du], d[d==du]) 211 | 212 | if hasattr(self, 'dsbnorm'): 213 | self.dsbnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 214 | 215 | def finetune(self, x, y, d): 216 | if hasattr(self, 'bnorm'): 217 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.REFIT) 218 | 219 | with torch.no_grad(): 220 | self.forward(x, d) 221 | 222 | if hasattr(self, 'bnorm'): 223 | self.bnorm.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER) 224 | -------------------------------------------------------------------------------- /spdnets/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.parameter import Parameter 6 | from torch.types import Number 7 | import torch.nn as nn 8 | from geoopt.tensor import ManifoldParameter 9 | from geoopt.manifolds import Stiefel, Sphere, Euclidean 10 | from . import functionals 11 | 12 | class Conv2dWithNormConstraint(torch.nn.Conv2d): 13 | def __init__(self, *args, max_norm=1, **kwargs): 14 | self.max_norm = max_norm 15 | super(Conv2dWithNormConstraint, self).__init__(*args, **kwargs) 16 | 17 | def forward(self, x): 18 | self.weight.data = torch.renorm( 19 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm 20 | ) 21 | return super(Conv2dWithNormConstraint, self).forward(x) 22 | 23 | 24 | class LinearWithNormConstraint(torch.nn.Linear): 25 | def __init__(self, *args, max_norm=1, **kwargs): 26 | self.max_norm = max_norm 27 | super(LinearWithNormConstraint, self).__init__(*args, **kwargs) 28 | 29 | def forward(self, x): 30 | self.weight.data = torch.renorm( 31 | self.weight.data, p=2, dim=0, maxnorm=self.max_norm 32 | ) 33 | return super(LinearWithNormConstraint, self).forward(x) 34 | 35 | 36 | class MySquare(torch.nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | def forward(self, x): 40 | return x.square() 41 | 42 | class MyLog(torch.nn.Module): 43 | def __init__(self, eps = 1e-4): 44 | super().__init__() 45 | self.eps = eps 46 | def forward(self, x): 47 | return torch.log(x + self.eps) 48 | 49 | 50 | class MyConv2d(torch.nn.Conv2d): 51 | def __init__(self, *args, **kwargs): 52 | super(MyConv2d, self).__init__(*args, **kwargs) 53 | 54 | self.convshape = self.weight.shape 55 | w0 = self.weight.data.flatten(start_dim=1) 56 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere()) 57 | 58 | def forward(self, x): 59 | return self._conv_forward(x, self.weight.view(self.convshape), self.bias) 60 | 61 | 62 | class UnitNormLinear(torch.nn.Linear): 63 | def __init__(self, *args, **kwargs): 64 | super(UnitNormLinear, self).__init__(*args, **kwargs) 65 | 66 | w0 = self.weight.data.flatten(start_dim=1) 67 | self.weight = ManifoldParameter(w0 / w0.norm(dim=-1, keepdim=True), manifold=Sphere()) 68 | 69 | def forward(self, x): 70 | return super().forward(x) 71 | 72 | 73 | class MyLinear(nn.Module): 74 | def __init__(self, shape : Tuple[int, ...] or torch.Size, bias: bool = True, **kwargs): 75 | super().__init__() 76 | 77 | self.W = Parameter(torch.empty(shape, **kwargs)) 78 | 79 | if bias: 80 | self.bias = Parameter(torch.empty((*shape[:-2], shape[-1]), **kwargs)) 81 | else: 82 | self.register_parameter('bias', None) 83 | 84 | self.reset_parameters() 85 | 86 | def forward(self, X : Tensor) -> Tensor: 87 | A = (X.unsqueeze(-2) @ self.W).squeeze(-2) 88 | if self.bias is not None: 89 | A += self.bias 90 | return A 91 | 92 | @torch.no_grad() 93 | def reset_parameters(self): 94 | # kaiming initialization std2uniformbound * gain * fan_in 95 | bound = math.sqrt(3) * 1. / math.sqrt(self.W.shape[-2]) 96 | self.W.data.uniform_(-bound,bound) 97 | if self.bias is not None: 98 | bound = 1 / math.sqrt(self.W.shape[-2]) 99 | self.bias.data.uniform_(-bound, bound) 100 | 101 | 102 | class Encode2DPosition(nn.Module): 103 | """ 104 | Encodes the 2D position of a 2D CNN or 2D image 105 | as additional channels. 106 | Input: (batch, chans, height, width) 107 | Output: (batch, chans+2, height, width) 108 | """ 109 | def __init__(self, flatten = True): 110 | super().__init__() 111 | self.flatten = flatten 112 | 113 | def forward(self, X : Tensor) -> Tensor: 114 | pos1 = torch.arange(X.shape[-2])[None,None,:,None].tile((X.shape[0],1, 1, X.shape[-1])) / X.shape[-2] 115 | pos2 = torch.arange(X.shape[-1])[None,None,None,:].tile((X.shape[0],1, X.shape[-2], 1)) / X.shape[-1] 116 | 117 | Z = torch.cat((X, pos1, pos2),dim=1) 118 | if self.flatten: 119 | Z = Z.flatten(start_dim=-2) 120 | 121 | return Z 122 | 123 | 124 | class CovariancePool(nn.Module): 125 | def __init__(self, alpha = None, unitvar = False): 126 | super().__init__() 127 | self.pooldim = -1 128 | self.chandim = -2 129 | self.alpha = alpha 130 | self.unitvar = unitvar 131 | 132 | def forward(self, X : Tensor) -> Tensor: 133 | X0 = X - X.mean(dim=self.pooldim, keepdim=True) 134 | if self.unitvar: 135 | X0 = X0 / X0.std(dim=self.pooldim, keepdim=True) 136 | X0.nan_to_num_(0) 137 | 138 | C = (X0 @ X0.transpose(-2, -1)) / X0.shape[self.pooldim] 139 | if self.alpha is not None: 140 | Cd = C.diagonal(dim1=self.pooldim, dim2=self.pooldim-1) 141 | Cd += self.alpha 142 | return C 143 | 144 | 145 | class ReverseGradient(nn.Module): 146 | def __init__(self, scaling = 1.): 147 | super().__init__() 148 | self.scaling_ = scaling 149 | 150 | def forward(self, X : Tensor) -> Tensor: 151 | return functionals.reverse_gradient.apply(X, self.scaling_) 152 | 153 | 154 | #----- Layers for SPDNet and TSMNet ----- 155 | class UnsqueezeLayer(nn.Module): 156 | def __init__(self, dim): 157 | super(__class__, self).__init__() 158 | self.dim = dim 159 | 160 | def forward(self, x): 161 | return x.unsqueeze(self.dim) 162 | 163 | 164 | class BiMap(nn.Module): 165 | """Note that 166 | following TSMNet, we use uniform distribution on the stiefel as initialization on the TSMNet 167 | following SPDNet and SPDNetBN, we use svd for initialization on the SPDNet 168 | """ 169 | def __init__(self, shape : Tuple[int, ...] or torch.Size, W0 : Tensor = None, manifold='stiefel', init_mode='uniform', **kwargs): 170 | super().__init__() 171 | 172 | self.shape=shape;self.manifold=manifold;self.init_mode=init_mode 173 | if manifold == 'euclidean': 174 | mf = Euclidean() 175 | # self.W = nn.Parameter(torch.empty(shape, **kwargs)) 176 | else: 177 | if manifold == 'stiefel': 178 | assert(shape[-2] >= shape[-1]) 179 | mf = Stiefel() 180 | elif manifold == 'sphere': 181 | mf = Sphere() 182 | shape = list(shape) 183 | shape[-1], shape[-2] = shape[-2], shape[-1] 184 | else: 185 | raise NotImplementedError() 186 | 187 | # add constraint (also initializes the parameter to fulfill the constraint) 188 | self.W = ManifoldParameter(torch.empty(shape, **kwargs), manifold=mf) 189 | 190 | # optionally initialize the weights (initialization has to fulfill the constraint!) 191 | if W0 is not None: 192 | self.W.data = W0 # e.g., self.W = torch.nn.init.orthogonal_(self.W) 193 | else: 194 | self.reset_parameters() 195 | 196 | def forward(self, X : Tensor) -> Tensor: 197 | if isinstance(self.W.manifold, Sphere): 198 | return self.W @ X @ self.W.transpose(-2,-1) 199 | else: 200 | return self.W.transpose(-2,-1) @ X @ self.W 201 | 202 | @torch.no_grad() 203 | def reset_parameters(self): 204 | if isinstance(self.W.manifold, Euclidean): 205 | v = torch.empty_like(self.W).uniform_(0., 1.) 206 | vv = torch.svd(v.matmul(v.t()))[0][:, :self.W.shape[-1]] 207 | self.W.data = vv 208 | elif isinstance(self.W.manifold, Stiefel): 209 | if self.init_mode=='uniform': 210 | # uniform initialization on stiefel manifold after theorem 2.2.1 in Chikuse (2003): statistics on special manifolds 211 | W = torch.rand(self.W.shape, dtype=self.W.dtype, device=self.W.device) 212 | self.W.data = W @ functionals.sym_invsqrtm.apply(W.transpose(-1,-2) @ W) 213 | elif self.init_mode=='svd': 214 | v = torch.empty_like(self.W).uniform_(0., 1.) 215 | vv = torch.svd(v.matmul(v.t()))[0][:, :self.W.shape[-1]] 216 | self.W.data = vv 217 | elif isinstance(self.W.manifold, Sphere): 218 | W = torch.empty(self.W.shape, dtype=self.W.dtype, device=self.W.device) 219 | # kaiming initialization std2uniformbound * gain * fan_in 220 | bound = math.sqrt(3) * 1. / W.shape[-1] 221 | W.uniform_(-bound, bound) 222 | # constraint has to be satisfied 223 | self.W.data = W / W.norm(dim=-1, keepdim=True) 224 | 225 | 226 | def __repr__(self): 227 | return f"{self.__class__.__name__}(shape={self.shape},manifold={self.manifold},init_mode={self.init_mode})" 228 | 229 | 230 | class ReEig(nn.Module): 231 | def __init__(self, threshold : Number = 1e-4): 232 | super().__init__() 233 | self.threshold = Tensor([threshold]) 234 | 235 | def forward(self, X : Tensor) -> Tensor: 236 | return functionals.sym_reeig.apply(X, self.threshold) 237 | 238 | def __repr__(self): 239 | return f"{self.__class__.__name__}(threshold={self.threshold})" 240 | 241 | 242 | class LogEig(nn.Module): 243 | """Note that following TSMNet and SPDNet, we set tril=True for TSMNet and False for SPDNet """ 244 | def __init__(self, ndim, tril=False): 245 | super().__init__() 246 | 247 | self.tril = tril 248 | if self.tril: 249 | ixs_lower = torch.tril_indices(ndim,ndim, offset=-1) 250 | ixs_diag = torch.arange(start=0, end=ndim, dtype=torch.long) 251 | self.ixs = torch.cat((ixs_diag[None,:].tile((2,1)), ixs_lower), dim=1) 252 | self.ndim = ndim 253 | 254 | def forward(self, X : Tensor) -> Tensor: 255 | return self.embed(functionals.sym_logm.apply(X)) 256 | 257 | def embed(self, X : Tensor) -> Tensor: 258 | if self.tril: 259 | x_vec = X[...,self.ixs[0],self.ixs[1]] 260 | x_vec[...,self.ndim:] *= math.sqrt(2) 261 | else: 262 | x_vec = X.flatten(start_dim=1) 263 | return x_vec 264 | 265 | def __repr__(self): 266 | return f"{self.__class__.__name__}(ndim={self.ndim},tril={self.tril}" 267 | 268 | -------------------------------------------------------------------------------- /spdnets/training/spdnet_training.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import os 3 | import time 4 | import logging 5 | import torch as th 6 | import torch.nn as nn 7 | import numpy as np 8 | import fcntl 9 | 10 | from spdnets.utils.spdnet import Get_Model 11 | from spdnets.utils.spdnet.utils import get_dataset_settings,optimzer,parse_cfg,train_per_epoch,val_per_epoch 12 | import spdnets.utils.spdnet.utils as spdnet_utils 13 | from spdnets.utils.common_utils import set_seed_thread 14 | 15 | def training(cfg,args): 16 | args=parse_cfg(args,cfg) 17 | 18 | #set logger 19 | logger = logging.getLogger(args.modelname) 20 | logger.setLevel(logging.INFO) 21 | args.logger = logger 22 | logger.info('begin model {} on dataset: {}'.format(args.modelname,args.dataset)) 23 | 24 | #set seed and threadnum 25 | set_seed_thread(args.seed,args.threadnum) 26 | 27 | # set dataset, model and optimizer 28 | args.DataLoader = get_dataset_settings(args) 29 | model = Get_Model.get_model(args) 30 | loss_fn = nn.CrossEntropyLoss() 31 | args.loss_fn = loss_fn.cuda() 32 | args.opti = optimzer(model.parameters(), lr=args.lr, mode=args.optimizer,weight_decay=args.weight_decay) 33 | # begin training 34 | val_acc = training_loop(model,args) 35 | 36 | return val_acc 37 | 38 | def training_loop(model, args): 39 | #setting tensorboard 40 | if args.is_writer: 41 | args.writer_path = os.path.join('./tensorboard_logs/',f"{args.modelname}") 42 | args.logger.info('writer path {}'.format(args.writer_path)) 43 | args.writer = SummaryWriter(args.writer_path) 44 | 45 | acc_val = [];loss_val = [];acc_train = [];loss_train = [];training_time=[] 46 | logger = args.logger 47 | # training loop 48 | for epoch in range(0, args.epochs): 49 | # training 50 | elapse,epoch_loss_train,epoch_acc_train = train_per_epoch(model,args) 51 | training_time.append(elapse) 52 | acc_train.append(np.asarray(epoch_acc_train).mean() * 100) 53 | loss_train.append(np.asarray(epoch_loss_train).mean()) 54 | 55 | # validation 56 | epoch_loss_val,epoch_acc_val = val_per_epoch(model, args) 57 | loss_val.append(np.asarray(epoch_loss_val).mean()) 58 | acc_val.append(np.asarray(epoch_acc_val).mean() * 100) 59 | 60 | # save data into tensorboard 61 | if args.is_writer: 62 | args.writer.add_scalar('Loss/val', loss_val[epoch], epoch) 63 | args.writer.add_scalar('Accuracy/val', acc_val[epoch], epoch) 64 | args.writer.add_scalar('Loss/train', loss_train[epoch], epoch) 65 | args.writer.add_scalar('Accuracy/train', acc_train[epoch], epoch) 66 | 67 | # print results 68 | spdnet_utils.print_results(logger,training_time,acc_val,loss_val,epoch,args) 69 | 70 | # save final data 71 | save_results(logger,training_time, acc_val, args) 72 | 73 | if args.is_writer: 74 | args.writer.close() 75 | return acc_val 76 | 77 | def write_final_results(file_path,message): 78 | # Create a file lock 79 | with open(file_path, "a") as file: 80 | fcntl.flock(file.fileno(), fcntl.LOCK_EX) # Acquire an exclusive lock 81 | # Write the message to the file 82 | file.write(message + "\n") 83 | fcntl.flock(file.fileno(), fcntl.LOCK_UN) 84 | 85 | def save_results(logger,training_time,acc_val,args): 86 | if args.is_save: 87 | average_time = np.asarray(training_time[-10:]).mean() 88 | final_val_acc = acc_val[-1] 89 | final_results = f'Final validation accuracy : {final_val_acc:.2f}% with average time: {average_time:.2f}' 90 | final_results_path = os.path.join(os.getcwd(), 'final_results_' + args.dataset) 91 | logger.info(f"results file path: {final_results_path}, and saving the results") 92 | write_final_results(final_results_path, args.modelname + '- ' + final_results) 93 | torch_results_dir = './torch_resutls' 94 | if not os.path.exists(torch_results_dir): 95 | os.makedirs(torch_results_dir) 96 | th.save({ 97 | 'acc_val': acc_val, 98 | }, os.path.join(torch_results_dir,args.modelname.rsplit('-',1)[0])) -------------------------------------------------------------------------------- /spdnets/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch as th 4 | 5 | def set_seed_only(seed): 6 | seed = seed 7 | random.seed(seed) 8 | # th.cuda.set_device(args.gpu) 9 | np.random.seed(seed) 10 | th.manual_seed(seed) 11 | th.cuda.manual_seed(seed) 12 | 13 | def set_seed_thread(seed,threadnum): 14 | th.set_num_threads(threadnum) 15 | seed = seed 16 | random.seed(seed) 17 | # th.cuda.set_device(args.gpu) 18 | np.random.seed(seed) 19 | th.manual_seed(seed) 20 | th.cuda.manual_seed(seed) 21 | 22 | def set_up(args): 23 | set_seed_thread(args.seed, args.threadnum) 24 | print('begin model {}'.format(args.modelname)) 25 | print('writer path {}'.format(args.writer_path)) 26 | -------------------------------------------------------------------------------- /spdnets/utils/skorch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .logging import TrainLog 3 | from .network import DomainAdaptNeuralNetClassifier -------------------------------------------------------------------------------- /spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/utils/skorch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/utils/skorch/__pycache__/logging.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/utils/skorch/__pycache__/network.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitZH-Chen/SPDMLR/cefb58bd686a03b90fb97d89b273f6043f8e920a/spdnets/utils/skorch/__pycache__/network.cpython-310.pyc -------------------------------------------------------------------------------- /spdnets/utils/skorch/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from skorch.callbacks.logging import EpochTimer, PrintLog 3 | 4 | log = logging.getLogger(__name__) 5 | 6 | class TrainLog(PrintLog): 7 | 8 | def __init__(self, prefix='') -> None: 9 | super().__init__() 10 | self.prefix = prefix 11 | 12 | def initialize(self): 13 | return self 14 | 15 | def on_epoch_end(self, net, **kwargs): 16 | r = net.history[-1] 17 | 18 | if r['epoch'] == 1 or r['epoch'] % 10 == 0: 19 | log.info(f"{self.prefix} {r['epoch']:3d} : trn={r['train_loss']:.3f}/{r['score_trn']:.2f} val={r['valid_loss']:.3f}/{r['score_val']:.2f} time: {r['dur']:.2f}") 20 | 21 | 22 | -------------------------------------------------------------------------------- /spdnets/utils/skorch/network.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from skorch.classifier import NeuralNetClassifier 4 | from skorch.callbacks.logging import EpochTimer, PrintLog 5 | from skorch.callbacks.scoring import EpochScoring, PassthroughScoring 6 | 7 | from spdnets.models import BaseModel, DomainAdaptFineTuneableModel, FineTuneableModel 8 | 9 | from .logging import TrainLog 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | class DomainAdaptNeuralNetClassifier(NeuralNetClassifier): 14 | def __init__(self, module, *args, criterion=torch.nn.CrossEntropyLoss, **kwargs): 15 | super().__init__(module, *args, criterion=criterion, **kwargs) 16 | 17 | @property 18 | def _default_callbacks(self): 19 | return [ 20 | ('epoch_timer', EpochTimer()), 21 | ('train_loss', PassthroughScoring( 22 | name='train_loss', 23 | on_train=True, 24 | )), 25 | ('valid_loss', PassthroughScoring( 26 | name='valid_loss', 27 | )), 28 | ('print_log', TrainLog()), 29 | ] 30 | 31 | def get_loss(self, mdl_pred, y_true, X=None, **kwargs): 32 | if isinstance(self.module_, BaseModel): 33 | return self.module_.calculate_objective(mdl_pred, y_true, X) 34 | elif isinstance(mdl_pred, (list, tuple)): 35 | y_hat = mdl_pred[0] 36 | else: 37 | y_hat = mdl_pred 38 | return self.criterion_(y_hat, y_true.to(y_hat.device)) 39 | 40 | def domainadapt_finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor, target_domains=None): 41 | if isinstance(self.module_, DomainAdaptFineTuneableModel): 42 | self.module_.domainadapt_finetune(x=x.to(self.device), y=y, d=d, target_domains=target_domains) 43 | else: 44 | log.info("Model does not support domain adapt fine tuning.") 45 | 46 | def finetune(self, x: torch.Tensor, y: torch.Tensor, d : torch.Tensor): 47 | if isinstance(self.module_, FineTuneableModel): 48 | self.module_.finetune(x=x.to(self.device), y=y, d=d) 49 | else: 50 | log.info("Model does not support fine-tuning.") 51 | -------------------------------------------------------------------------------- /spdnets/utils/spdnet/Get_Model.py: -------------------------------------------------------------------------------- 1 | from spdnets.models.spdnet import SPDNet 2 | 3 | def get_model(args): 4 | model = SPDNet(args) 5 | print(model) 6 | return model.double() -------------------------------------------------------------------------------- /spdnets/utils/spdnet/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import geoopt 3 | import time 4 | import torch as th 5 | 6 | from datasets.spdnet.Radar_Loader import DataLoaderRadar 7 | from datasets.spdnet.HDM05_Loader import DataLoaderHDM05 8 | 9 | def get_dataset_settings(args): 10 | if args.dataset=='HDM05': 11 | pval = 0.5 12 | DataLoader = DataLoaderHDM05(args.path, pval, args.batchsize) 13 | elif args.dataset== 'RADAR' : 14 | pval = 0.25 15 | DataLoader = DataLoaderRadar(args.path,pval,args.batchsize) 16 | else: 17 | raise Exception('unknown dataset {}'.format(args.dataset)) 18 | return DataLoader 19 | 20 | def get_model_name(args): 21 | if args.classifier == 'SPDMLR': 22 | if args.metric == 'SPDLogEuclideanMetric': 23 | description = f'{args.metric}-[{args.alpha},{args.beta:.4f}]' 24 | elif args.metric == 'SPDLogCholeskyMetric': 25 | description = f'{args.metric}-[{args.power}]' 26 | 27 | description = '-' + description 28 | elif args.classifier == 'LogEigMLR': 29 | description='' 30 | else: 31 | raise NotImplementedError 32 | optim = f'{args.lr}-{args.optimizer}-{args.weight_decay}' 33 | name = f'{optim}-{args.model_type}-{args.init_mode}-{args.bimap_manifold}-{args.classifier}{description}-{args.architecture}-{datetime.datetime.now().strftime("%H_%M")}' 34 | return name 35 | 36 | def optimzer(parameters,lr,mode='AMSGRAD',weight_decay=0.): 37 | if mode=='ADAM': 38 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,weight_decay=weight_decay) 39 | elif mode=='SGD': 40 | optim = geoopt.optim.RiemannianSGD(parameters, lr=lr,weight_decay=weight_decay) 41 | elif mode=='AMSGRAD': 42 | optim = geoopt.optim.RiemannianAdam(parameters, lr=lr,amsgrad=True,weight_decay=weight_decay) 43 | else: 44 | raise Exception('unknown optimizer {}'.format(mode)) 45 | return optim 46 | 47 | def parse_cfg(args,cfg): 48 | # setting args from cfg 49 | 50 | #fit 51 | args.epochs = cfg.fit.epochs 52 | args.batchsize = cfg.fit.batch_size 53 | args.threadnum = cfg.fit.threadnum 54 | args.is_writer = cfg.fit.is_writer 55 | args.cycle = cfg.fit.cycle 56 | args.seed = cfg.fit.seed 57 | args.is_save = cfg.fit.is_save 58 | 59 | # model 60 | args.model_type = cfg.nnet.model.model_type 61 | args.init_mode = cfg.nnet.model.init_mode 62 | args.bimap_manifold = cfg.nnet.model.bimap_manifold 63 | args.architecture = cfg.nnet.model.architecture 64 | args.classifier = cfg.nnet.model.classifier 65 | args.metric = cfg.nnet.model.metric 66 | args.power = cfg.nnet.model.power 67 | args.alpha = cfg.nnet.model.alpha 68 | args.beta = eval(cfg.nnet.model.beta) if isinstance(cfg.nnet.model.beta, str) else cfg.nnet.model.beta 69 | 70 | #optimizer 71 | args.optimizer = cfg.nnet.optimizer.mode 72 | args.lr = cfg.nnet.optimizer.lr 73 | args.weight_decay = cfg.nnet.optimizer.weight_decay 74 | 75 | #dataset 76 | args.dataset = cfg.dataset.name 77 | args.class_num=cfg.dataset.class_num 78 | args.path = cfg.dataset.path 79 | 80 | # get model name 81 | args.modelname = get_model_name(args) 82 | 83 | return args 84 | 85 | def train_per_epoch(model,args): 86 | start = time.time() 87 | epoch_loss, epoch_acc = [], [] 88 | model.train() 89 | for local_batch, local_labels in args.DataLoader._train_generator: 90 | local_batch = local_batch.to(th.double) 91 | args.opti.zero_grad() 92 | out = model(local_batch) 93 | l = args.loss_fn(out, local_labels) 94 | acc, loss = (out.argmax(1) == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 95 | epoch_loss.append(loss) 96 | epoch_acc.append(acc) 97 | l.backward() 98 | args.opti.step() 99 | end = time.time() 100 | elapse = end - start 101 | return elapse,epoch_loss,epoch_acc 102 | 103 | def val_per_epoch(model,args): 104 | epoch_loss, epoch_acc = [], [] 105 | y_true, y_pred = [], [] 106 | model.eval() 107 | with th.no_grad(): 108 | for local_batch, local_labels in args.DataLoader._test_generator: 109 | local_batch = local_batch.to(th.double) 110 | out = model(local_batch) 111 | l = args.loss_fn(out, local_labels) 112 | predicted_labels = out.argmax(1) 113 | y_true.extend(list(local_labels.cpu().numpy())) 114 | y_pred.extend(list(predicted_labels.cpu().numpy())) 115 | acc, loss = (predicted_labels == local_labels).cpu().numpy().sum() / out.shape[0], l.cpu().data.numpy() 116 | epoch_acc.append(acc) 117 | epoch_loss.append(loss) 118 | return epoch_loss,epoch_acc 119 | 120 | def print_results(logger,training_time,acc_val,loss_val,epoch,args): 121 | if epoch % args.cycle == 0: 122 | logger.info(f'Time: {training_time[epoch]:.2f}, Val acc: {acc_val[epoch]:.2f}, loss: {loss_val[epoch]:.2f} at epoch {epoch + 1:d}/{args.epochs:d}') --------------------------------------------------------------------------------