├── figures
│   ├── .gitkeep
│   ├── Lie_Nuerons_ICML24.pdf
│   ├── lie_neurons_icon.jpg
│   └── lie_neurons_modules.jpg
├── matlab
│   ├── inv_figure.fig
│   ├── vee_so3.m
│   ├── hat_so3.m
│   ├── hat_sl3.m
│   ├── vee_sl3.m
│   ├── test_euler_poincare.m
│   ├── test_Ad_N.m
│   ├── test_killing_relu.m
│   ├── find_adjoint_KAN.m
│   ├── sl3_equivariant_lift.asv
│   ├── sl3_equivariant_lift.m
│   ├── sl3_equivariant_lift_find_bound.asv
│   └── sl3_equivariant_lift_find_bound.m
├── playground
│   ├── emlp_test.py
│   ├── neural_ode_nonhomogeneous.py
│   └── neural_ode_demo.py
├── docker
│   ├── EMLP
│   │   ├── build_docker_container.bash
│   │   ├── README.md
│   │   └── Dockerfile
│   └── LieNeurons
│       ├── build_docker_container.bash
│       ├── README.md
│       └── Dockerfile
├── test
│   ├── test_sp4_hat.py
│   ├── test_se3_hat.py
│   ├── test_so3_bch.py
│   ├── test_so3_exp.py
│   ├── test_invariant.py
│   ├── test_equivariant_liebracket.py
│   ├── test_equivariant.py
│   ├── test_pooling_equivariant.py
│   ├── test_batch_norm_equivariant.py
│   └── test_so3_bch_time.py
├── config
│   ├── sl3_equiv
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── so3_bch
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── sl3_inv
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── platonic_solid_cls
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   └── sp4_inv
│       ├── testing_param.yaml
│       └── training_param.yaml
├── LICENSE
├── script
│   ├── run_euler_poincare.bash
│   └── evaluate_multiple.bash
├── data_loader
│   ├── sp4_inv_data_loader.py
│   ├── sl3_inv_data_loader.py
│   └── so3_bch_data_loader.py
├── README.md
├── core
│   ├── vn_layers.py
│   └── lie_group_util.py
├── .gitignore
├── experiment
│   ├── platonic_solid_cls_test.py
│   ├── sl3_inv_test.py
│   ├── sp4_inv_test.py
│   ├── sl3_equiv_test.py
│   ├── sl3_inv_train.py
│   ├── sp4_inv_train.py
│   ├── so3_bch_layers.py
│   ├── platonic_solid_cls_layers.py
│   ├── sl3_equiv_train.py
│   └── platonic_solid_cls_train.py
└── data_gen
    ├── gen_sl3_inv_data.py
    ├── gen_sp4_inv_data.py
    ├── gen_so3_bch.py
    └── gen_sl3_inv_5_input_data.py
/figures/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /matlab/inv_figure.fig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/matlab/inv_figure.fig -------------------------------------------------------------------------------- /figures/Lie_Nuerons_ICML24.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/Lie_Nuerons_ICML24.pdf -------------------------------------------------------------------------------- /figures/lie_neurons_icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/lie_neurons_icon.jpg -------------------------------------------------------------------------------- /figures/lie_neurons_modules.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/lie_neurons_modules.jpg -------------------------------------------------------------------------------- /matlab/vee_so3.m: -------------------------------------------------------------------------------- 1 | function v = vee_so3(M) 2 | 3 | v(1) = M(3, 2); 4 | v(2) = M(1, 3); 5 | v(3) = M(2, 1); 6 | 7 | end -------------------------------------------------------------------------------- /matlab/hat_so3.m: 
-------------------------------------------------------------------------------- 1 | function R = hat_so3(v) 2 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0]; 3 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0]; 4 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0]; 5 | 6 | R = v(1)*Ex+v(2)*Ey+v(3)*Ez; 7 | 8 | % if ndims(v)==1 9 | % R = v(1)*Ex+v(2)*Ey+v(3)*Ez; 10 | % elseif ndims(v) > 1 11 | % R = v(:,1).*Ex+v(:,2).*Ey+v(:,3).*Ez; 12 | % end 13 | 14 | end -------------------------------------------------------------------------------- /playground/emlp_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import emlp.nn.pytorch as nn 4 | from emlp.reps import T,V,sparsify_basis 5 | from emlp.groups import SO, SL 6 | 7 | 8 | 9 | if __name__ == '__main__': 10 | # Define the input representation 11 | G = SL(3) 12 | reps = V(G) 13 | reps_out = (V**1*V.T**1)(G) 14 | 15 | print(reps.rho(G)) 16 | print(reps.size()) 17 | print(reps_out.size()) 18 | Q = (reps>>reps_out).equivariant_basis() 19 | print(f"Basis matrix of shape {Q.shape}") 20 | # print(sparsify_basis(Q).reshape(3,3)) 21 | print(reps(G).size()) 22 | # Define the output representation -------------------------------------------------------------------------------- /docker/EMLP/build_docker_container.bash: -------------------------------------------------------------------------------- 1 | container_name=$1 2 | 3 | xhost +local: 4 | docker run -it --net=host --shm-size 8G --gpus all \ 5 | --user=$(id -u) \ 6 | -e DISPLAY=$DISPLAY \ 7 | -e QT_GRAPHICSSYSTEM=native \ 8 | -e NVIDIA_DRIVER_CAPABILITIES=all \ 9 | -e XAUTHORITY \ 10 | -e USER=$USER \ 11 | --workdir=/home/$USER/ \ 12 | -v "/tmp/.X11-unix:/tmp/.X11-unix:rw" \ 13 | -v "/etc/passwd:/etc/passwd:rw" \ 14 | -e "TERM=xterm-256color" \ 15 | -v "/home/$USER/DockerFolder:/home/$USER/" \ 16 | -v "/media/$USER/DATA/data:/home/$USER/data/" \ 17 | --device=/dev/dri:/dev/dri \ 18 | --name=${container_name} \ 19 | --security-opt seccomp=unconfined \ 20 | umcurly/emlp:latest 21 | -------------------------------------------------------------------------------- /docker/LieNeurons/build_docker_container.bash: -------------------------------------------------------------------------------- 1 | container_name=$1 2 | 3 | xhost +local: 4 | docker run -it --net=host --shm-size 8G --gpus all \ 5 | --user=$(id -u) \ 6 | -e DISPLAY=$DISPLAY \ 7 | -e QT_GRAPHICSSYSTEM=native \ 8 | -e NVIDIA_DRIVER_CAPABILITIES=all \ 9 | -e XAUTHORITY \ 10 | -e USER=$USER \ 11 | --workdir=/home/$USER/ \ 12 | -v "/tmp/.X11-unix:/tmp/.X11-unix:rw" \ 13 | -v "/etc/passwd:/etc/passwd:rw" \ 14 | -e "TERM=xterm-256color" \ 15 | -v "/home/$USER/DockerFolder:/home/$USER/" \ 16 | -v "/media/$USER/DATA/data:/home/$USER/data/" \ 17 | --device=/dev/dri:/dev/dri \ 18 | --name=${container_name} \ 19 | --security-opt seccomp=unconfined \ 20 | umcurly/lieneurons:latest 21 | -------------------------------------------------------------------------------- /matlab/hat_sl3.m: -------------------------------------------------------------------------------- 1 | function A = hat_sl3(v) 2 | 3 | % E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 4 | % E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 5 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 6 | % E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 7 | % E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 8 | % E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 9 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 10 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 11 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 12 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 13 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 14 | E4 = 
[1, 0, 0; 0, 1, 0; 0, 0, -2]; 15 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 16 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 17 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 18 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 19 | 20 | 21 | A = v(1)*E1+v(2)*E2+v(3)*E3+v(4)*E4+v(5)*E5+v(6)*E6+v(7)*E7+v(8)*E8; 22 | end 23 | 24 | -------------------------------------------------------------------------------- /docker/EMLP/README.md: -------------------------------------------------------------------------------- 1 | # EMLP Docker 2 | This folder contains instructions on creating a Docker image that includes PyTorch and the other libraries needed for EMLP. 3 | 4 | ## How to build 5 | 6 | build image 7 | ``` 8 | docker build --tag umcurly/emlp . 9 | ``` 10 | update the mounted folder paths on line 15 of build_docker_container.bash, then build the container 11 | ``` 12 | bash build_docker_container.bash [container_name] 13 | ``` 14 | After building the container, you will enter the docker container. To work stably in docker, we recommend running `exit` and then following the next section to run docker. 15 | 16 | ## How to use 17 | start docker 18 | ``` 19 | docker start [container_name] 20 | ``` 21 | run docker 22 | ``` 23 | docker exec -it [container_name] /bin/bash 24 | ``` 25 | run docker with root access 26 | ``` 27 | docker exec -u root -it [container_name] /bin/bash 28 | ``` 29 | -------------------------------------------------------------------------------- /docker/LieNeurons/README.md: -------------------------------------------------------------------------------- 1 | # LieNeurons Docker 2 | This folder contains instructions on creating a Docker image that includes PyTorch and the other libraries needed for LieNeurons. 3 | 4 | ## How to build 5 | 6 | build image 7 | ``` 8 | docker build --tag umcurly/lieneurons . 9 | ``` 10 | update the mounted folder paths on line 15 of build_docker_container.bash, then build the container 11 | ``` 12 | bash build_docker_container.bash [container_name] 13 | ``` 14 | After building the container, you will enter the docker container. To work stably in docker, we recommend running `exit` and then following the next section to run docker.
15 | 16 | ## How to use 17 | start docker 18 | ``` 19 | docker start [container_name] 20 | ``` 21 | run docker 22 | ``` 23 | docker exec -it [container_name] /bin/bash 24 | ``` 25 | run docker with root access 26 | ``` 27 | docker exec -u root -it [container_name] /bin/bash 28 | ``` 29 | -------------------------------------------------------------------------------- /matlab/vee_sl3.m: -------------------------------------------------------------------------------- 1 | function v = vee_sl3(M) 2 | v(4) = -0.5 * M(3, 3); 3 | v(5) = M(1, 3); 4 | v(6) = M(2, 3); 5 | v(7) = M(3, 1); 6 | v(8) = M(3, 2); 7 | v(1) = (M(1, 1) - v(4)); 8 | 9 | v(2) = 0.5*(M(1, 2) + M(2, 1)); 10 | v(3) = 0.5*(M(2, 1) - M(1, 2)); 11 | 12 | v = v'; 13 | end 14 | 15 | % def vee_sl3(M): 16 | % # [a1 + a4, a2 - a3, a5] 17 | % # [a2 + a3, a4 - a1, a6] 18 | % # [ a7, a8, -2*a4] 19 | % v = torch.zeros(M.shape[:-2]+(8,)).to(M.device) 20 | % v[..., 3] = -0.5*M[..., 2, 2] 21 | % v[..., 4] = M[..., 0, 2] 22 | % v[..., 5] = M[..., 1, 2] 23 | % v[..., 6] = M[..., 2, 0] 24 | % v[..., 7] = M[..., 2, 1] 25 | % v[..., 0] = (M[..., 0, 0] - v[..., 3]) 26 | % 27 | % v[..., 1] = 0.5*(M[..., 0, 1] + M[..., 1, 0]) 28 | % v[..., 2] = 0.5*(M[..., 1, 0] - M[..., 0, 1]) 29 | % return v -------------------------------------------------------------------------------- /test/test_sp4_hat.py: -------------------------------------------------------------------------------- 1 | import sys  # nopep8 2 | sys.path.append('.')  # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | 8 | from core.lie_alg_util import * 9 | 10 | if __name__ == "__main__": 11 | 12 | v = torch.Tensor([1,2,3,4,5,6,7,8,9,10]) 13 | print("v: ", v) 14 | sp4_hatlayer = HatLayer(algebra_type='sp4') 15 | M = sp4_hatlayer(v) 16 | print("M: ", M) 17 | v2 = vee(M, algebra_type='sp4') 18 | print("v after vee: ", v2) 19 | 20 | rnd_scale=5 21 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 10))) 22 | M = sp4_hatlayer(v).squeeze(0).squeeze(0) 23 | 24 | zeros = torch.zeros(2, 2) 25 | I = torch.eye(2) 26 | # Construct the 4x4 matrix by combining the components 27 | omega = torch.cat((torch.cat((zeros, I), dim=1), 28 | torch.cat((-I, zeros), dim=1)), dim=0) 29 | 30 | print("Checking if M satisfies the definition of sp(4): ") 31 | print(omega@M+M.T@omega) -------------------------------------------------------------------------------- /matlab/test_euler_poincare.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | true_y0 = [2., 1.,3.0]; 5 | t = [0:0.025:25]; 6 | 7 | 8 | [t,y] = ode45(@EulerPoincare, t, true_y0); 9 | 10 | 11 | f1 = figure(1); 12 | plot(t,y); 13 | legend("x","y","z"); 14 | xlabel("time") 15 | ylabel("x,y,z") 16 | 17 | f2 = figure(2); 18 | plot(y(:,1),y(:,2)); 19 | xlabel("x") 20 | ylabel("y") 21 | 22 | f3 = figure(3); 23 | plot3(y(:,1),y(:,2),y(:,3)); 24 | 25 | [X,Y,Z] = meshgrid(-2:0.1:2,-2:0.1:2,-2:0.1:2); 26 | 27 | function w_wedge = wedge(w) 28 | w_wedge = zeros(3,3); 29 | w_wedge(1,2) = -w(3); 30 | w_wedge(1,3) = w(2); 31 | w_wedge(2,1) = w(3); 32 | w_wedge(2,3) = -w(1); 33 | w_wedge(3,1) = -w(2); 34 | w_wedge(3,2) = w(1); 35 | end 36 | 37 | function dwdt = EulerPoincare(t,w) 38 | % I = [[12, 0, 0];[0, 20., 0];[0, 0, 5.]]; 39 | % I = [[12, -5., 7.];[-5., 20., -2.];[7., -2., 5.]]; 40 | I = [[5410880., -246595., 2967671.];[-246595., 29457838., -47804.];[2967671., -47804., 26744180.]]; 41 | w_wedge = wedge(w); 42 | 43 | 44 | dwdt = -I\w_wedge*I*w; 45 | end
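As a cross-check of the MATLAB script above, here is a minimal Python sketch of the same Euler-Poincare dynamics. It is not part of the repository: it assumes only numpy and scipy, uses the diagonal inertia matrix from the commented-out option in test_euler_poincare.m, and replaces ode45 with scipy's solve_ivp (also an RK45-type integrator).
```python
import numpy as np
from scipy.integrate import solve_ivp

def hat_so3(w):
    # so(3) hat map, same sign convention as matlab/hat_so3.m
    return np.array([[0.0, -w[2], w[1]],
                     [w[2], 0.0, -w[0]],
                     [-w[1], w[0], 0.0]])

inertia = np.diag([12.0, 20.0, 5.0])  # diagonal inertia option from the MATLAB file

def euler_poincare(t, w):
    # Euler-Poincare equation of a free rigid body: dw/dt = -I^{-1} (w^ I w)
    return -np.linalg.solve(inertia, hat_so3(w) @ (inertia @ w))

sol = solve_ivp(euler_poincare, (0.0, 25.0), [2.0, 1.0, 3.0],
                t_eval=np.arange(0.0, 25.0, 0.025))
print(sol.y[:, -1])  # body angular velocity at the final time step
```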
-------------------------------------------------------------------------------- /config/sl3_equiv/testing_param.yaml: -------------------------------------------------------------------------------- 1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz" 2 | # logger 3 | model_description: "SL3EquivariantFunctionFitting" 4 | print_freq: 100 5 | # model_type: "LN" 6 | # model_path: "/home/justin/code/LieNeurons/weights/0926_exp_sl3_equiv_lie_bracket_2input_LN_5_256_no_share_nonlin_train_10000_lr_ 0.000003_best_test_loss_acc.pt" 7 | 8 | # model_type: "LN_relu_bracket" 9 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_relu_one_bracket_final_best_test_loss_acc.pt" 10 | model_type: "MLP" 11 | model_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5_best_test_loss_acc.pt" 12 | # model_type: "MLP" 13 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5_best_test_loss_acc.pt" 14 | # model_type: "MLP" 15 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_MLP_best_test_loss_acc.pt" 16 | 17 | # model_type: "LN_bracket_no_residual" 18 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_bracket_no_residual_best_test_loss_acc.pt" 19 | batch_size: 100 20 | shuffle: False 21 | -------------------------------------------------------------------------------- /config/so3_bch/testing_param.yaml: -------------------------------------------------------------------------------- 1 | test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_test_data.npz" 2 | # logger 3 | model_description: "SO3BCHFunctionFitting" 4 | print_freq: 100 5 | 6 | # model_type: "LN_relu_bracket" 7 | # model_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_bracket_best_test_loss_acc.pt" 8 | 9 | # model_type: "LN_relu" 10 | # model_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_4_layers_1024" 11 | 12 | # model_type: "LN_bracket" 13 | # model_path: "/home/justin/code/LieNeurons/weights/0124_so3_bch_bracket_best_test_loss_acc.pt" 14 | 15 | # model_type: "MLP" 16 | # model_path: "/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented_best_test_loss_acc.pt" 17 | 18 | # model_type: "EMLP" 19 | # model_path: "/home/justin/code/LieNeurons/weights/EMLP_best_test_loss_acc.pt" 20 | 21 | # model_type: "e3nn_norm" 22 | # model_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_norm_256_best_test_loss_acc.pt" 23 | 24 | model_type: "e3nn_s2grid" 25 | model_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_s2grid_best_test_loss_acc.pt" 26 | 27 | batch_size: 1 28 | shuffle: False 29 | 30 | calculate_eq_approx_error: False -------------------------------------------------------------------------------- /config/sl3_inv/testing_param.yaml: -------------------------------------------------------------------------------- 1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz" 2 | # logger 3 | model_description: "SL3InvariantFunctionFitting" 4 | print_freq: 100 5 | # model_type: "LN" 6 | # model_path: "/home/justin/code/LieNeurons/weights/0928_sl3_inv_LN_256_no_share_nonlin_train_10000_lr_0.000003_new_best_test_loss_acc.pt" 7 | # model_type: "MLP" 8 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt" 9 | model_type: "MLP" 10 | model_path: 
"/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt" 11 | # model_type: "LN_bracket" 12 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_0.00003_best_test_loss_acc.pt" 13 | 14 | # model_type: "LN_bracket" 15 | # model_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_2_best_test_loss_acc.pt" 16 | 17 | # model_type: "MLP" 18 | # model_path: "/home/justin/code/LieNeurons/weights/1120_MLP_augmented_best_test_loss_acc.pt" 19 | 20 | # model_type: "LN_bracket_no_residual" 21 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_no_residual_best_test_loss_acc.pt" 22 | num_workers: 1 23 | batch_size: 100 24 | shuffle: True 25 | -------------------------------------------------------------------------------- /test/test_se3_hat.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | 8 | from core.lie_alg_util import * 9 | 10 | if __name__ == "__main__": 11 | 12 | v = torch.Tensor([1,2,3,4,5,6]) 13 | print("v: ", v) 14 | se3_hatlayer = HatLayer(algebra_type='se3') 15 | M = se3_hatlayer(v) 16 | print("M: ", M) 17 | v2 = vee(M, algebra_type='se3') 18 | print("v after vee: ", v2) 19 | 20 | rnd_scale=5 21 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 6))) 22 | print("v: ", v) 23 | print(v.shape) 24 | M = se3_hatlayer(v) 25 | M_SE3 = scipy.linalg.expm(M[0,0,:,:].numpy()) 26 | print("M_SE3: ", M_SE3) 27 | M2 = scipy.linalg.logm(M_SE3) 28 | 29 | print("M: ", M) 30 | print("M log: ", M2) 31 | 32 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 3))) 33 | print("v: ", v) 34 | print(v.shape) 35 | so3_hatlayer = HatLayer(algebra_type='so3') 36 | M = so3_hatlayer(v) 37 | print(M.shape) 38 | M_SE3 = scipy.linalg.expm(M[0,0,:,:].numpy()) 39 | print("M_SE3: ", M_SE3) 40 | M2 = scipy.linalg.logm(M_SE3) 41 | 42 | print("M: ", M) 43 | print("M log: ", M2) -------------------------------------------------------------------------------- /config/platonic_solid_cls/testing_param.yaml: -------------------------------------------------------------------------------- 1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_test_data.npz" 2 | # logger 3 | model_description: "SL3EquivariantFunctionFitting" 4 | print_freq: 100 5 | # model_type: "LN" 6 | # model_path: "/home/justin/code/LieNeurons/weights/0926_platonic_cls_LN_best_test_loss_good.pt" 7 | 8 | model_type: "MLP" 9 | model_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_mlp_aug_4_best_test_acc.pt" 10 | # model_type: "LN_relu_bracket" 11 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_rb_5_best_test_acc.pt" 12 | 13 | # model_type: "LN_relu" 14 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu_best_test_acc.pt" 15 | 16 | # model_type: "LN_bracket" 17 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket_best_test_acc.pt" 18 | 19 | # model_type: "MLP" 20 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_MLP_best_test_acc.pt" 21 | 22 | # model_type: "LN_bracket_no_residual" 23 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket_no_residual_best_test_acc.pt" 24 | batch_size: 100 25 | shuffle: False 26 | num_test: 1000 27 | num_rotations: 500 28 | rotation_factor: 0.174533 # 10 degrees 29 | 
-------------------------------------------------------------------------------- /test/test_so3_bch.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | import math 8 | 9 | from core.lie_alg_util import * 10 | from core.lie_group_util import * 11 | 12 | if __name__ == "__main__": 13 | 14 | 15 | v1 = torch.rand(1,3) 16 | v1 = v1/torch.norm(v1) 17 | phi = math.pi 18 | v1 = phi*v1 19 | # print("v: ", v1) 20 | 21 | v2 = v1 22 | v2 = v2/torch.norm(v2) 23 | phi2 = math.pi 24 | v2 = phi2*v2 25 | 26 | v1 = torch.Tensor([[1.4911, 0.6458, 1.0547]]) 27 | v2 = torch.Tensor([[0.2295, 2.0104, 0.0430]]) 28 | 29 | so3_hatlayer = HatLayer(algebra_type='so3') 30 | K1 = so3_hatlayer(v1) 31 | K2 = so3_hatlayer(v2) 32 | 33 | R1 = exp_so3(K1[0,:,:]) 34 | R2 = exp_so3(K2[0,:,:]) 35 | 36 | print("v1: ", v1) 37 | print("v2: ", v2) 38 | print("R1: ", R1) 39 | print("R2: ", R2) 40 | 41 | R3 = R1@R2 42 | print("R3:", R3) 43 | K3 = log_SO3(R3) 44 | K3_BCH = BCH_approx(K1[0,:,:],K2[0,:,:]) 45 | K3_BCH_SO3 = BCH_so3(K1[0,:,:],K2[0,:,:]) 46 | print("K1: ", K1) 47 | print("K2: ", K2) 48 | print("----") 49 | print("K3: ", K3) 50 | print("K3 BCH: ", K3_BCH) 51 | print("K1+K2: ", K1+K2) 52 | print("K3 BCH SO3: ", K3_BCH_SO3) 53 | print("norm K3: ", torch.norm(vee(K3, algebra_type='so3'))) -------------------------------------------------------------------------------- /test/test_so3_exp.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | import math 8 | 9 | from core.lie_alg_util import * 10 | from core.lie_group_util import * 11 | 12 | if __name__ == "__main__": 13 | 14 | rnd_scale=1 15 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 3))) 16 | v = v/torch.norm(v) 17 | phi = math.pi*torch.rand(1) 18 | v = phi*v 19 | print("v: ", v) 20 | so3_hatlayer = HatLayer(algebra_type='so3') 21 | K = so3_hatlayer(v) 22 | K_SO3 = torch.asarray(scipy.linalg.expm(K[0,0,:,:].numpy())) 23 | K_SO3_2 = exp_hat_and_so3(v.T) 24 | K_SO3_3 = exp_so3(K[0,0,:,:]) 25 | K_SO3_exp_gpu = exp_so3(K[0,0,:,:].to('cuda')) 26 | print("-------------------") 27 | print("SO3: ", K_SO3) 28 | print("SO3_2: ", K_SO3_2) 29 | print("SO3_3: ", K_SO3_3) 30 | print("-------------------") 31 | print("det SO3: ", np.linalg.det(K_SO3)) 32 | print("det SO3_2: ", np.linalg.det(K_SO3_2)) 33 | print("-------------------") 34 | K_after = log_SO3(K_SO3) 35 | K2_after = log_SO3(K_SO3_2) 36 | K_log_gpu = log_SO3(K_SO3_3.to('cuda')) 37 | K_exp_gpu_after = log_SO3(K_SO3_exp_gpu) 38 | 39 | print("K: ", K) 40 | print("M: ", K_after) 41 | print("M log: ", K2_after) 42 | print("M log gpu: ", K_exp_gpu_after) 43 | print("M log gpu: ", K_log_gpu) -------------------------------------------------------------------------------- /config/sp4_inv/testing_param.yaml: -------------------------------------------------------------------------------- 1 | test_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_test_data.npz" 2 | # logger 3 | model_description: "SP4InvariantFunctionFitting" 4 | print_freq: 100 5 | # model_type: "LN" 6 | # model_path: "/home/justin/code/LieNeurons/weights/0928_sl3_inv_LN_256_no_share_nonlin_train_10000_lr_0.000003_new_best_test_loss_acc.pt" 7 | # model_type: "MLP" 8 | # model_path: 
"/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt" 9 | # model_type: "MLP" 10 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt" 11 | 12 | # model_type: "LN_bracket" 13 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_0.00003_best_test_loss_acc.pt" 14 | 15 | model_type: "LN_relu_bracket" 16 | model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_relu_bracket_best_test_loss_acc.pt" 17 | 18 | # model_type: "LN_bracket" 19 | # model_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_2_best_test_loss_acc.pt" 20 | 21 | # model_type: "MLP512" 22 | # model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_augmented_best_test_loss_acc.pt" 23 | 24 | # model_type: "MLP512" 25 | # model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_best_test_loss_acc.pt" 26 | 27 | # model_type: "LN_bracket_no_residual" 28 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_no_residual_best_test_loss_acc.pt" 29 | num_workers: 1 30 | batch_size: 100 31 | shuffle: True 32 | -------------------------------------------------------------------------------- /docker/LieNeurons/Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel 2 | # FROM pytorch/pytorch:0.4_cuda9_cudnn7 3 | # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-devel 4 | FROM nvcr.io/nvidia/pytorch:21.12-py3 5 | 6 | LABEL version="0.5" 7 | 8 | # USER root 9 | 10 | ENV DEBIAN_FRONTEND noninteractive 11 | 12 | # build essentials 13 | # RUN apt-get update && apt-get -y install cmake 14 | RUN apt-get update && apt-get install -y vim 15 | RUN apt-get install -y build-essential snap 16 | RUN apt-get update && apt-get install -y git-all 17 | # RUN add-apt-repository ppa:rmescandon/yq 18 | # RUN apt update 19 | # RUN snap install yq 20 | RUN wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq &&\ 21 | chmod +x /usr/bin/yq 22 | 23 | # python essentials 24 | # Starting from numpy 1.24.0 np.float would pop up an error. 25 | # Cocoeval tool still uses np.float. As a result we can only use numpy 1.23.0. 26 | # Alternatively, we can clone the cocoapi and modify the cocoeval tool. 27 | # https://github.com/cocodataset/cocoapi/pull/569 28 | RUN pip install numpy==1.23.0 29 | RUN pip install -U matplotlib 30 | RUN pip install scipy 31 | RUN pip install tensorboardX 32 | RUN pip install wandb 33 | 34 | # torch 35 | RUN pip install torch>=1.7.0 36 | RUN pip install torchvision>=0.8.1 37 | 38 | # einops 39 | RUN pip install einops 40 | 41 | # pytest 42 | RUN pip install pytest 43 | 44 | # setuptool 45 | RUN pip install setuptools==59.5.0 46 | 47 | # torch diff eq 48 | RUN pip install git+https://github.com/rtqichen/torchdiffeq -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, UMich-CURLY 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /test/test_invariant.py: -------------------------------------------------------------------------------- 1 | import sys  # nopep8 2 | sys.path.append('.')  # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from core.lie_neurons_layers import * 8 | from core.lie_alg_util import * 9 | from experiment.sl3_inv_layers import * 10 | 11 | 12 | if __name__ == "__main__": 13 | print("testing the invariant layer") 14 | 15 | # test the invariant layer: outputs should be unchanged under an SL(3) conjugation of the inputs 16 | num_points = 1 17 | num_features = 10 18 | out_features = 3 19 | batch_size = 20 20 | rnd_scale = 0.5 21 | 22 | hat_layer = HatLayer() 23 | 24 | x = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (batch_size, num_features, 8, num_points))).reshape( 25 | batch_size, num_features, 8, num_points) 26 | y = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, 8)) 27 | 28 | # SL(3) transformation 29 | Y = torch.linalg.matrix_exp(hat_layer(y)) 30 | 31 | # model = LNInvariant(num_features,method='learned_killing') 32 | model = SL3InvariantLayers(num_features) 33 | 34 | x_hat = hat_layer(x.transpose(2, -1)) 35 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y))) 36 | new_x = vee_sl3(new_x_hat).transpose(2, -1) 37 | 38 | model.eval() 39 | with torch.no_grad(): 40 | out_x = model(x) 41 | out_new_x = model(new_x) 42 | 43 | test_result = torch.allclose( 44 | out_new_x, out_x, rtol=1e-4, atol=1e-4) 45 | 46 | print("out x[0,:]: ", out_x[0, :]) 47 | print("out new x[0,:]: ", out_new_x[0, :]) 48 | print("differences: ", out_x[0, :] - out_new_x[0, :]) 49 | 50 | print("The network is invariant: ", test_result) 51 | -------------------------------------------------------------------------------- /script/run_euler_poincare.bash: -------------------------------------------------------------------------------- 1 | # python experiment/euler_poincare_eq_train.py 
--adjoint --viz --inertia_type iss --model_type LN_ode7 --fig_save_path figures/final_LN_ode7_train_on_one --model_save_path weights/final_LN_ode7_train_on_one --num_training 1 4 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode --fig_save_path figures/final_neural_ode_train_on_one --model_save_path weights/final_neural_ode_train_on_one --num_training 1 5 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode2 --fig_save_path figures/final_neural_ode2_train_on_one --model_save_path weights/final_neural_ode2_train_on_one --num_training 1 6 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode3 --fig_save_path figures/final_neural_ode3 --model_save_path weights/final_neural_ode3 7 | 8 | python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode8 --fig_save_path figures/final_LN_ode8 --model_save_path weights/final_LN_ode8 9 | python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode8 --fig_save_path figures/final_LN_ode8_train_on_one --model_save_path weights/final_LN_ode8_train_on_one --num_training 1 -------------------------------------------------------------------------------- /matlab/test_Ad_N.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | num_sample = 100000; 5 | eps = 1e-5; 6 | for i=1:num_sample 7 | %% generate N 8 | n = rand(3,1); 9 | N = [1 n(1) n(2); 0 1 n(3); 0 0 1]; 10 | N = 1/det(N)^(1/3) * N; 11 | n1 = N(1,2); 12 | n2 = N(1,3); 13 | n3 = N(2,3); 14 | 15 | Ad_N =... 16 | [[ 1, n1, n1, 0, 0, 0, n2/2 + (n1*n3)/2, -n3/2]; 17 | [ -n1, 1 - n1^2/2, -n1^2/2, 0, 0, 0, n3/2 - (n1*n2)/2, n2/2]; 18 | [ n1, n1^2/2, n1^2/2 + 1, 0, 0, 0, n3/2 + (n1*n2)/2, -n2/2]; 19 | [ 0, 0, 0, 1, 0, 0, n2/2 - (n1*n3)/2, n3/2]; 20 | [2*n1*n3 - n2, - n3 - n1*(n2 - n1*n3), n3 - n1*(n2 - n1*n3), -3*n2, 1, n1, -n2*(n2 - n1*n3), -n2*n3]; 21 | [ n3, n1*n3 - n2, n1*n3 - n2, -3*n3, 0, 1, -n3*(n2 - n1*n3), -n3^2]; 22 | [ 0, 0, 0, 0, 0, 0, 1, 0]; 23 | [ 0, 0, 0, 0, 0, 0, -n1, 1]]; 24 | 25 | %% generate sl(3) 26 | v = rand(8,1); 27 | v_hat = hat_sl3(v); 28 | v_hat_ad = N * v_hat * inv(N); 29 | v_ad = Ad_N * v; % Linear map on v 30 | v_hat_ad_vee = vee_sl3(v_hat_ad); % vee after adjoint on v_hat 31 | 32 | diff = v_ad - v_hat_ad_vee; % They should be the same 33 | 34 | if norm(diff) > eps 35 | disp('error bigger than eps'); 36 | disp(diff) 37 | elseif i==num_sample 38 | disp('Passed all test case!'); 39 | end 40 | 41 | end 42 | 43 | -------------------------------------------------------------------------------- /matlab/test_killing_relu.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | %% sl(3) generators 5 | G1 = [1,0,0; 6 | 0, -1, 0; 7 | 0,0,0]; 8 | G2 = [0, 1,0; 9 | 1, 0 ,0; 10 | 0, 0 ,0]; 11 | G3 = [0, -1, 0; 12 | 1, 0, 0; 13 | 0, 0, 0]; 14 | G4 = [1, 0, 0; 15 | 0, 1, 0; 16 | 0, 0, -2]; 17 | G5 = [0,0,1; 18 | 0,0,0; 19 | 0,0,0]; 20 | G6 = [0,0,0; 21 | 0,0,1; 22 | 0,0,0]; 23 | G7 = [0,0,0; 24 | 0,0,0; 25 | 1,0,0]; 26 | G8 = [0,0,0; 27 | 0,0,0; 28 | 0,1,0]; 29 | 30 | %% 31 | num_sample = 100000; 32 | out_sample = zeros(num_sample,2); 33 | for i=1:num_sample 34 | 35 | a = 2*(rand(8,1)-0.5); 36 | % b = 2*(rand(8,1)-0.5); 37 | b = a; 38 | 39 | A = a(1)*G1+a(2)*G2+a(3)*G3+a(4)*G4+a(5)*G5+a(6)*G6+a(7)*G7+a(8)*G8; 40 | B = 
b(1)*G1+b(2)*G2+b(3)*G3+b(4)*G4+b(5)*G5+b(6)*G6+b(7)*G7+b(8)*G8; 41 | 42 | if(trace(A)~=0) 43 | disp('trace(A) is not zero'); 44 | end 45 | K = -6*trace(A*B); 46 | 47 | % relu 48 | if(K>0) 49 | out_sample(i,1)=K; 50 | out_sample(i,2)=K; 51 | else 52 | out_sample(i,1)=K; 53 | out_sample(i,2)=0; 54 | end 55 | end 56 | 57 | 58 | %% 59 | figure(1) 60 | scatter(out_sample(:,1),out_sample(:,2)); 61 | 62 | %% 63 | G = [G1,G2,G3,G4,G5,G6,G7,G8]; 64 | % 65 | % def sl3_to_R8(M): 66 | %     # [a1 + a4, a2 - a3, a5] 67 | %     # [a2 + a3, a4 - a1, a6] 68 | %     # [      a7,      a8, -2*a4] 69 | %     v = torch.zeros(8).to(M.device) 70 | %     v[3] = -0.5*M[2,2] 71 | %     v[4] = M[0,2] 72 | %     v[5] = M[1,2] 73 | %     v[6] = M[2,0] 74 | %     v[7] = M[2,1] 75 | %     v[0] = (M[0,0] - v[3]) 76 | % 77 | %     v[1] = 0.5*(M[0,1] + M[1,0]) 78 | %     v[2] = 0.5*(M[1,0] - M[0,1]) 79 | %     return v -------------------------------------------------------------------------------- /config/platonic_solid_cls/training_param.yaml: -------------------------------------------------------------------------------- 1 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_train_data.npz" 2 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_test_data.npz" 3 | # logger 4 | model_description: "SL3classification" 5 | print_freq: 1 6 | # model_type: "LN_relu_bracket" 7 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu_one_bracket" 8 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_relu_one_bracket" 9 | 10 | # model_type: "LN_relu" 11 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu" 12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_relu" 13 | 14 | # model_type: "LN_bracket" 15 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket" 16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_bracket" 17 | model_type: "MLP" 18 | model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_mlp_aug_4" 19 | log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_cls_mlp_aug_4" 20 | # model_type: "LN_relu_bracket" 21 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_rb_5" 22 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_cls_rb_5" 23 | resume_training: false 24 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_equiv_first_try.pth.tar" 25 | num_workers: 1 26 | batch_size: 100 27 | update_every_batch: 1 28 | # initial_learning_rate: 0.00001 29 | initial_learning_rate: 0.0001 30 | learning_rate_decay_rate: 0.0 31 | weight_decay_rate: 0.00 32 | num_train: 500000 33 | num_test: 1000 34 | shuffle: True 35 | train_augmentation: True 36 | rotation_factor: 0.174533 # 10 degrees 37 | -------------------------------------------------------------------------------- /test/test_equivariant_liebracket.py: -------------------------------------------------------------------------------- 1 | import sys  # nopep8 2 | sys.path.append('.')  # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from core.lie_neurons_layers import * 8 | from core.lie_alg_util import * 9 | 10 | 11 | if __name__ == "__main__": 12 | print("testing the equivariant Lie bracket layer") 13 | 14 | # test equivariant Lie bracket layer 15 | num_points = 100 16 | num_features = 10 17 | out_features = 3 18 | 19 | x = torch.Tensor(np.random.rand(num_features, 8, num_points) 20 | ).reshape(1, num_features, 8, num_points) 21 | y = torch.Tensor(np.random.rand(8)) 22 | 23 | 
hat_layer = HatLayer() 24 | 25 | # SL(3) transformation 26 | Y = torch.linalg.matrix_exp(hat_layer(y)) 27 | 28 | model = LNLieBracket( 29 | num_features, share_nonlinearity=False) 30 | 31 | x_hat = hat_layer(x.transpose(2, -1)) 32 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y))) 33 | new_x = vee_sl3(new_x_hat).transpose(2, -1) 34 | 35 | model.eval() 36 | with torch.no_grad(): 37 | out_x = model(x) 38 | out_new_x = model(new_x) 39 | 40 | out_x_hat = hat_layer(out_x.transpose(2, -1)) 41 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y))) 42 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1) 43 | 44 | test_result = torch.allclose( 45 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4) 46 | 47 | print("out x[0,0,:,0]", out_x[0, 0, :, 0]) 48 | print("out x conj[0,0,:,0]: ", out_x_conj[0, 0, :, 0]) 49 | print("out new x[0,0,:,0]: ", out_new_x[0, 0, :, 0]) 50 | print("differences: ", 51 | out_x_conj[0, 0, :, 0] - out_new_x[0, 0, :, 0]) 52 | 53 | print("The network is equivariant: ", test_result) 54 | -------------------------------------------------------------------------------- /test/test_equivariant.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from core.lie_neurons_layers import * 8 | from core.lie_alg_util import * 9 | 10 | 11 | if __name__ == "__main__": 12 | print("testing equivariant linear layer") 13 | 14 | # test equivariant linear layer 15 | num_points = 100 16 | num_features = 10 17 | out_features = 3 18 | 19 | x = torch.Tensor(np.random.rand(num_features, 8, num_points) 20 | ).reshape(1, num_features, 8, num_points) 21 | y = torch.Tensor(np.random.rand(8)) 22 | 23 | hat_layer = HatLayer() 24 | 25 | # SL(3) transformation 26 | Y = torch.linalg.matrix_exp(hat_layer(y)) 27 | 28 | model = LNLinearAndKillingRelu( 29 | num_features, out_features, share_nonlinearity=True) 30 | 31 | x_hat = hat_layer(x.transpose(2, -1)) 32 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y))) 33 | new_x = vee_sl3(new_x_hat).transpose(2, -1) 34 | 35 | model.eval() 36 | with torch.no_grad(): 37 | out_x = model(x) 38 | out_new_x = model(new_x) 39 | 40 | out_x_hat = hat_layer(out_x.transpose(2, -1)) 41 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y))) 42 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1) 43 | 44 | test_result = torch.allclose( 45 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4) 46 | 47 | print("out x[0,0,:,0]", out_x[0, 0, :, 0]) 48 | print("out x conj[0,0,:,0]: ", out_x_conj[0, 0, :, 0]) 49 | print("out new x[0,0,:,0]: ", out_new_x[0, 0, :, 0]) 50 | print("differences: ", 51 | out_x_conj[0, 0, :, 0] - out_new_x[0, 0, :, 0]) 52 | 53 | print("The network is equivariant: ", test_result) 54 | -------------------------------------------------------------------------------- /test/test_pooling_equivariant.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from core.lie_neurons_layers import * 8 | from core.lie_alg_util import * 9 | 10 | if __name__ == "__main__": 11 | print("testing equivariant linear layer") 12 | 13 | # test equivariant linear layer 14 | num_points = 100 15 | num_features = 10 16 | out_features = 3 17 | 18 | x = torch.Tensor(np.random.rand(num_features, 8, num_points) 19 | ).reshape(1, num_features, 8, 
num_points) 20 | y = torch.Tensor(np.random.rand(8)) 21 | 22 | hat_layer = HatLayer() 23 | 24 | # SL(3) transformation 25 | Y = torch.linalg.matrix_exp(hat_layer(y)) 26 | 27 | model = LNLinearAndKillingReluAndPooling( 28 | num_features, out_features, share_nonlinearity=True, abs_killing_form=False) 29 | 30 | x_hat = hat_layer(x.transpose(2, -1)) 31 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y))) 32 | new_x = vee_sl3(new_x_hat).transpose(2, -1) 33 | 34 | model.eval() 35 | with torch.no_grad(): 36 | out_x = model(x) 37 | out_new_x = model(new_x) 38 | 39 | out_x_hat = hat_layer(out_x.transpose(2, -1)) 40 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y))) 41 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1) 42 | 43 | test_result = torch.allclose( 44 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4) 45 | 46 | print("out x[0,0,:]", out_x[0, 0, :]) 47 | print("out x conj[0,0,:]: ", out_x_conj[0, 0, :]) 48 | print("out new x[0,0,:]: ", out_new_x[0, 0, :]) 49 | print("differences: ", out_x_conj[0, 0, :] - out_new_x[0, 0, :]) 50 | 51 | print("The network is equivariant: ", test_result) 52 | -------------------------------------------------------------------------------- /test/test_batch_norm_equivariant.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from core.lie_neurons_layers import * 8 | from core.lie_alg_util import * 9 | from experiment.sl3_equiv_layers import * 10 | 11 | if __name__ == "__main__": 12 | print("testing equivariant linear layer") 13 | 14 | # test equivariant linear layer 15 | num_points = 1 16 | num_features = 10 17 | out_features = 3 18 | batch_size = 20 19 | rnd_scale = 0.5 20 | 21 | hat_layer = HatLayer() 22 | 23 | x = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (batch_size, num_features, 8, num_points))).reshape( 24 | batch_size, num_features, 8, num_points) 25 | y = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, 8)) 26 | 27 | # SL(3) transformation 28 | Y = torch.linalg.matrix_exp(hat_layer(y)) 29 | 30 | # model = LNLinearAndKillingReluAndPooling( 31 | # num_features, out_features, share_nonlinearity=True, use_batch_norm=True, dim=4) 32 | 33 | model = SL3EquivariantLayers(num_features) 34 | 35 | x_hat = hat_layer(x.transpose(2, -1)) 36 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y))) 37 | new_x = vee_sl3(new_x_hat).transpose(2, -1) 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | out_x = model(x) 42 | out_new_x = model(new_x) 43 | 44 | out_x_hat = hat_layer(out_x) 45 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y))) 46 | out_x_conj = vee_sl3(out_x_hat_conj) 47 | 48 | test_result = torch.allclose( 49 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4) 50 | 51 | print("out x[0,0,:]", out_x[ 0, :]) 52 | print("out x conj[0,0,:]: ", out_x_conj[0, :]) 53 | print("out new x[0,0,:]: ", out_new_x[0, :]) 54 | print("differences: ", out_x_conj[0, :] - out_new_x[ 0, :]) 55 | 56 | print("The network is equivariant: ", test_result) 57 | -------------------------------------------------------------------------------- /matlab/find_adjoint_KAN.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | syms n1 n2 n3 a1 a2 v1 v2 v3 v4 v5 v6 v7 v8 5 | assumeAlso([n1 n2 n3 a1 a2 v1 v2 v3 v4 v5 v6 v7 v8],'real') 6 | 7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9 8 | assumeAlso([h1 h2 h3 h4 h5 
h6 h7 h8 h9],'real') 9 | %% basis 10 | 11 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 12 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 13 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 14 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 15 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 16 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 17 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 18 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 19 | 20 | E1_vec = reshape(E1,1,[])'; 21 | E2_vec = reshape(E2,1,[])'; 22 | E3_vec = reshape(E3,1,[])'; 23 | E4_vec = reshape(E4,1,[])'; 24 | E5_vec = reshape(E5,1,[])'; 25 | E6_vec = reshape(E6,1,[])'; 26 | E7_vec = reshape(E7,1,[])'; 27 | E8_vec = reshape(E8,1,[])'; 28 | 29 | 30 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec]; 31 | 32 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8; 33 | 34 | %% find Ad_N 35 | N = [1 n1 n2; 0 1 n3; 0 0 1]; 36 | Ad_N_hat = N*x_hat*inv(N); 37 | Ad_N_hat_vec = reshape(Ad_N_hat,1,[])'; 38 | 39 | % solve for least square 40 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_N_hat_vec; 41 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 42 | [Ad_N,b]=equationsToMatrix(x,var); 43 | 44 | %% find Ad_A 45 | A = [a1, 0,0; 0,a2,0;0,0,1/a1/a2]; 46 | Ad_A_hat = A*x_hat*inv(A); 47 | Ad_A_hat_vec = reshape(Ad_A_hat,1,[])'; 48 | 49 | % solve for least square 50 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_A_hat_vec; 51 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 52 | [Ad_A,b]=equationsToMatrix(x,var); 53 | 54 | %% find h 55 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9]; 56 | Ad_H_hat = H*x_hat*inv(H); 57 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])'; 58 | 59 | % solve for least square 60 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec; 61 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 62 | [Ad_H,b]=equationsToMatrix(x,var); 63 | 64 | Ad_kron_H = kron(inv(H)',Ad_H); 65 | 66 | %% 67 | % sigma = svd(Ad_kron_H); 68 | % z = null(Ad_kron_H - eye(24)) -------------------------------------------------------------------------------- /docker/EMLP/Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel 2 | # FROM pytorch/pytorch:0.4_cuda9_cudnn7 3 | # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-devel 4 | FROM nvcr.io/nvidia/pytorch:22.10-py3 5 | 6 | LABEL version="0.5" 7 | 8 | # USER root 9 | 10 | ENV DEBIAN_FRONTEND noninteractive 11 | 12 | # build essentials 13 | # RUN apt-get update && apt-get -y install cmake 14 | RUN apt-get update && apt-get install -y vim 15 | RUN apt-get install -y build-essential snap 16 | RUN apt-get update && apt-get install -y git-all 17 | # RUN add-apt-repository ppa:rmescandon/yq 18 | # RUN apt update 19 | # RUN snap install yq 20 | RUN wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq &&\ 21 | chmod +x /usr/bin/yq 22 | 23 | # RUN apt-get update && apt-get install python3.9 -y 24 | # RUN apt-get update && apt-get install python3-pip -y 25 | 26 | # python essentials 27 | # Starting from numpy 1.24.0 np.float would pop up an error. 28 | # Cocoeval tool still uses np.float. As a result we can only use numpy 1.23.0. 29 | # Alternatively, we can clone the cocoapi and modify the cocoeval tool. 
30 | # https://github.com/cocodataset/cocoapi/pull/569 31 | 32 | RUN pip install numpy==1.23.0 33 | RUN pip install -U matplotlib 34 | RUN pip install scipy 35 | RUN pip install tensorboardX 36 | RUN pip install wandb 37 | 38 | # torch 39 | RUN pip install torch>=1.7.0 40 | RUN pip install torchvision>=0.8.1 41 | 42 | # einops 43 | RUN pip install einops 44 | 45 | # pytest 46 | RUN pip install pytest 47 | 48 | # setuptool 49 | RUN pip install setuptools 50 | 51 | # torch diff eq 52 | RUN pip install git+https://github.com/rtqichen/torchdiffeq 53 | 54 | # EMLP 55 | RUN pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 56 | 57 | RUN pip install h5py 58 | RUN pip install objax 59 | RUN pip install optax==0.1.7 60 | # RUN pip install plum-dispatc 61 | RUN pip install scikit-learn 62 | # RUN pip install tqdm>=4.38 63 | RUN pip install emlp 64 | -------------------------------------------------------------------------------- /config/sp4_inv/training_param.yaml: -------------------------------------------------------------------------------- 1 | train_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_augmented_train_data.npz" 2 | test_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_test_data.npz" 3 | 4 | # logger 5 | model_description: "SP4InvariantFunctionFitting" 6 | print_freq: 100 7 | 8 | # model_type: "MLP" 9 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5" 10 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_inv_mlp_augmented_5" 11 | 12 | # model_type: "LN_relu_bracket" 13 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_relu_bracket" 14 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_relu_bracket" 15 | 16 | # model_type: "LN_relu" 17 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_test" 18 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_test" 19 | 20 | # model_type: "LN_bracket" 21 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_3" 22 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_3" 23 | 24 | # model_type: "MLP" 25 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_augmented" 26 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_MLP_augmented" 27 | 28 | model_type: "MLP512" 29 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_augmented" 30 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_MLP_512_augmented" 31 | 32 | # model_type: "LN_bracket_no_residual" 33 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_no_residual" 34 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_no_residual" 35 | resume_training: false 36 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_inv_first_try.pth.tar" 37 | num_workers: 1 38 | batch_size: 100 39 | update_every_batch: 1 40 | # initial_learning_rate: 0.00003 41 | initial_learning_rate: 0.0001 42 | learning_rate_decay_rate: 0.0 43 | weight_decay_rate: 0.00 44 | num_epochs: 600 45 | shuffle: True 46 | -------------------------------------------------------------------------------- /config/sl3_inv/training_param.yaml: -------------------------------------------------------------------------------- 1 | train_data_path: 
"/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_augmented_train_data.npz" 2 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz" 3 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_train_data.npz" 4 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz" 5 | 6 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_augmented_train_data.npz" 7 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz" 8 | # logger 9 | model_description: "SL3InvariantFunctionFitting" 10 | print_freq: 100 11 | model_type: "MLP" 12 | model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5" 13 | log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_inv_mlp_augmented_5" 14 | # model_type: "LN_relu_bracket" 15 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_one_bracket" 16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_one_bracket" 17 | 18 | # model_type: "LN_relu" 19 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_test" 20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_test" 21 | 22 | # model_type: "LN_bracket" 23 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_3" 24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_3" 25 | 26 | # model_type: "MLP" 27 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_MLP" 28 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_MLP" 29 | 30 | # model_type: "LN_bracket_no_residual" 31 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_no_residual" 32 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_no_residual" 33 | resume_training: false 34 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_inv_first_try.pth.tar" 35 | num_workers: 1 36 | batch_size: 100 37 | update_every_batch: 1 38 | # initial_learning_rate: 0.00003 39 | initial_learning_rate: 0.0001 40 | learning_rate_decay_rate: 0.0 41 | weight_decay_rate: 0.00 42 | num_epochs: 600 43 | shuffle: True 44 | -------------------------------------------------------------------------------- /config/sl3_equiv/training_param.yaml: -------------------------------------------------------------------------------- 1 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_augmented_train_data.npz" 2 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz" 3 | train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_train_data.npz" 4 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz" 5 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_augmented_train_data.npz" 6 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz" 7 | # logger 8 | model_description: "SL3EquivariantFunctionFitting" 9 | print_freq: 100 10 | # model_type: "LN_relu_bracket" 11 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_relu_one_bracket_final" 
12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_relu_one_bracket_final" 13 | 14 | # model_type: "LN_relu" 15 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_relu" 16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_relu" 17 | 18 | # model_type: "LN_bracket" 19 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_bracket_3" 20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_bracket_3" 21 | 22 | # model_type: "MLP" 23 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5" 24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_equiv_mlp_augmented_5" 25 | 26 | model_type: "LN_bracket_no_residual" 27 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_equiv_one_bracket_no_residual_test" 28 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_equiv_one_bracket_no_residual_test" 29 | 30 | resume_training: false 31 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_equiv_first_try.pth.tar" 32 | num_workers: 1 33 | batch_size: 100 34 | update_every_batch: 1 35 | # initial_learning_rate: 0.0001 36 | # initial_learning_rate: 0.000003 37 | # initial_learning_rate: 0.000001 38 | initial_learning_rate: 0.00001 39 | learning_rate_decay_rate: 0.0 40 | weight_decay_rate: 0.00 41 | num_epochs: 1500 42 | shuffle: True 43 | -------------------------------------------------------------------------------- /data_loader/sp4_inv_data_loader.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from scipy.linalg import expm 8 | from tqdm import tqdm 9 | 10 | 11 | from einops import rearrange, repeat 12 | from einops.layers.torch import Rearrange 13 | 14 | from core.lie_neurons_layers import * 15 | 16 | 17 | class sp4InvDataSet(Dataset): 18 | def __init__(self, data_path, device='cuda'): 19 | data = np.load(data_path) 20 | _, num_points = data['x1'].shape 21 | _,_,num_conjugate = data['x1_conjugate'].shape 22 | self.x1 = rearrange(torch.from_numpy(data['x1']).type( 23 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 24 | self.x2 = rearrange(torch.from_numpy(data['x2']).type( 25 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 26 | self.x = torch.cat((self.x1, self.x2), dim=1) 27 | 28 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type( 29 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 30 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type( 31 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 32 | self.x_conjugate = torch.cat( 33 | (self.x1_conjugate, self.x2_conjugate), dim=2) 34 | self.y = torch.from_numpy(data['y']).type( 35 | 'torch.FloatTensor').to(device).reshape(num_points, 1) 36 | 37 | self.num_data = self.x1.shape[0] 38 | 39 | def __len__(self): 40 | return self.num_data 41 | 42 | def __getitem__(self, idx): 43 | if torch.is_tensor(idx): 44 | idx = idx.tolist() 45 | 46 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :], 'x': self.x[idx, :, :, :], 47 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :], 48 | 'x_conjugate': self.x_conjugate[:,idx, :, :, :], 'y': self.y[idx, :]} 49 | return sample 50 | 51 | 52 | if __name__ == "__main__": 53 | 54 | DataLoader = 
sp4InvDataSet("data/sp4_inv_data/sp4_inv_10000_s_05_train_data.npz") 55 | 56 | print(DataLoader.x1.shape) 57 | print(DataLoader.x2.shape) 58 | print(DataLoader.x1_conjugate.shape) 59 | print(DataLoader.x2_conjugate.shape) 60 | print(DataLoader.x.shape) 61 | print(DataLoader.x_conjugate.shape) 62 | print(DataLoader.y.shape) 63 | for i, samples in tqdm(enumerate(DataLoader, start=0)): 64 | input_data = samples['x'] 65 | y = samples['y'] 66 | print(input_data.shape) 67 | print("x1: \n",input_data[0,:,0]) 68 | print("x2: \n",input_data[1,:,0]) 69 | print("y: ", y) 70 | print("------------------") 71 | # print(input_data.shape) 72 | -------------------------------------------------------------------------------- /test/test_so3_bch_time.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | import math 8 | import time 9 | 10 | from core.lie_alg_util import * 11 | from core.lie_group_util import * 12 | from experiment.so3_bch_layers import * 13 | 14 | if __name__ == "__main__": 15 | 16 | 17 | num_test = 10000 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | print('Using ', device) 21 | 22 | so3_hatlayer = HatLayer(algebra_type='so3') 23 | 24 | def gen_random_rotation_vector(): 25 | v = torch.rand(1,3) 26 | v = v/torch.norm(v) 27 | phi = (math.pi-1e-6)*torch.rand(1) 28 | v = phi*v 29 | return v 30 | sum_t_first_order = 0 31 | sum_t_second_order = 0 32 | sum_t_third_order = 0 33 | sum_t_mlp = 0 34 | sum_t_LN = 0 35 | 36 | model = SO3EquivariantReluBracketLayers(2).to(device) 37 | 38 | # model = SO3EquivariantReluLayers(2).to(device) 39 | # model = SO3EquivariantBracketLayers(2).to(device) 40 | mlp_model = MLP(6).to(device) 41 | 42 | checkpoint = torch.load('/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_bracket_best_test_loss_acc.pt') 43 | model.load_state_dict(checkpoint['model_state_dict'],strict=False) 44 | 45 | checkpoint = torch.load('/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented_best_test_loss_acc.pt') 46 | mlp_model.load_state_dict(checkpoint['model_state_dict'],strict=False) 47 | 48 | # generate training data 49 | for i in range(num_test): 50 | # generate random v1, v2 51 | v1 = gen_random_rotation_vector() 52 | v2 = gen_random_rotation_vector() 53 | 54 | K1 = so3_hatlayer(v1).to(device) 55 | K2 = so3_hatlayer(v2).to(device) 56 | 57 | 58 | t1 = time.time() 59 | out = BCH_first_order_approx(K1,K2) 60 | t2 = time.time() 61 | sum_t_first_order += t2-t1 62 | 63 | t1 = time.time() 64 | out = BCH_second_order_approx(K1,K2) 65 | t2 = time.time() 66 | sum_t_second_order += t2-t1 67 | 68 | t1 = time.time() 69 | out = BCH_third_order_approx(K1,K2) 70 | t2 = time.time() 71 | sum_t_third_order += t2-t1 72 | 73 | v1 = v1.to(device) 74 | v2 = v2.to(device) 75 | v = torch.cat((v1,v2),0) 76 | 77 | v = rearrange(v, 'n k -> 1 n k 1') 78 | 79 | t1 = time.time() 80 | out = model(v) 81 | t2 = time.time() 82 | sum_t_LN += t2-t1 83 | 84 | t1 = time.time() 85 | out = mlp_model(v) 86 | t2 = time.time() 87 | sum_t_mlp += t2-t1 88 | 89 | 90 | print("First order time: ", sum_t_first_order/num_test) 91 | print("Second order time: ", sum_t_second_order/num_test) 92 | print("Third order time: ", sum_t_third_order/num_test) 93 | print("LN time: ", sum_t_LN/num_test) 94 | print("MLP time: ", sum_t_mlp/num_test) 95 | 
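# Note on the timings above: CUDA kernel launches are asynchronous, so timing a
# single forward pass with time.time() can undercount the actual GPU work. A
# minimal synchronized variant is sketched below; the helper name and its
# defaults are illustrative assumptions, not part of this repository.
def timed_forward(model, x, device, n_iters=1000, warmup=10):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):  # warm-up: cuDNN autotuning, allocator caches
            model(x)
        if device.type == 'cuda':
            torch.cuda.synchronize()  # drain queued kernels before starting the clock
        t0 = time.time()
        for _ in range(n_iters):
            model(x)
        if device.type == 'cuda':
            torch.cuda.synchronize()  # wait for the last kernel before stopping the clock
    return (time.time() - t0) / n_iters
# Example usage: timed_forward(model, v, device) would replace the per-call
# t1/t2 bookkeeping above with one averaged, synchronized measurement.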
-------------------------------------------------------------------------------- /config/so3_bch/training_param.yaml: -------------------------------------------------------------------------------- 1 | train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_train_data.npz" 2 | # train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_augmented_train_data.npz" 3 | test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_test_data.npz" 4 | # train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_approx_no_conj_train_data.npz" 5 | # test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_approx_no_conj_test_data.npz" 6 | 7 | # logger 8 | model_description: "SO3 BCH 10000" 9 | print_freq: 100 10 | # model_type: "LN_relu_bracket" 11 | # model_save_path: "/home/justin/code/LieNeurons/weights/0125_so3_bch_relu_bracket" 12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0125_so3_bch_relu_bracket" 13 | 14 | # model_type: "LN_relu" 15 | # model_save_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_4_layers_1024" 16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0120_so3_bch_relu_4_layers_1024" 17 | 18 | # model_type: "LN_bracket" 19 | # model_save_path: "/home/justin/code/LieNeurons/weights/0124_so3_bch_bracket" 20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0124_so3_bch_bracket" 21 | 22 | # model_type: "MLP" 23 | # model_save_path: "/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented" 24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0327_so3_bch_mlp_4_layers_1024_augmented" 25 | 26 | # model_type: "LN_bracket_no_residual" 27 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_bracket_no_residual" 28 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_equiv_one_bracket_no_residual" 29 | 30 | # model_type: "EMLP" 31 | # model_save_path: "/home/justin/code/LieNeurons/weights/0326_BCH_EMLP_512_test" 32 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0326_BCH_EMLP_512_test" 33 | 34 | # model_type: "e3nn_norm" 35 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_norm_1024" 36 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_BCH_e3nn_norm_1024" 37 | 38 | model_type: "e3nn_s2grid" 39 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_s2grid_small_lr" 40 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_BCH_e3nn_s2grid_small_lr" 41 | 42 | # model_type: "VN_relu" 43 | # model_save_path: "/home/justin/code/LieNeurons/weights/0327_BCH_VN" 44 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0327_BCH_VN" 45 | 46 | resume_training: False 47 | resume_model_path: "/home/justin/code/LieNeurons/weights/0326_BCH_EMLP_512_test_best_test_loss_acc.pt" 48 | 49 | num_workers: 1 50 | batch_size: 100 51 | update_every_batch: 1 52 | initial_learning_rate: 0.0001 53 | # initial_learning_rate: 0.000003 54 | # initial_learning_rate: 0.00001 55 | # initial_learning_rate: 0.0005 # EMLP 56 | learning_rate_decay_rate: 0.0 57 | weight_decay_rate: 0.00 58 | num_epochs: 10000 59 | shuffle: True 60 | full_eval_during_training: False 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Lie Neurons: Adjoint Equivariant Neural Networks for Semi-simple Lie Algebras

Tzu-Yuan Lin*¹    Minghan Zhu*¹    Maani Ghaffari¹

*Equal contributions    ¹University of Michigan, Ann Arbor
21 | 22 | ## About 23 | An MLP framework that takes Lie algebraic data as inputs and is equivariant to the adjoint representation of the group by construction. 24 | 25 | ![front_figure](figures/lie_neurons_icon.jpg?raw=true "Title") 26 | 27 | ## Modules 28 | ![modules](figures/lie_neurons_modules.jpg?raw=true "Modules") 29 | 30 | ## Updates 31 | * [07/2024] The initial code is open-sourced. We are still re-organizing the code. We plan to release a cleaner version of the code soon. Feel free to reach out if you have any questions! :) 32 | * [07/2024] We presented our paper at ICML 24! 33 | 34 | ## Docker 35 | * We provide [docker](https://docs.docker.com/get-started/) files in [`docker/`](https://github.com/UMich-CURLY/LieNeurons/tree/main/docker). 36 | * Detailed tutorial on how to build the docker container can be found in the README in each docker folder. 37 | 38 | ## Training the Network 39 | * All the training codes for experiments are in [`experiment/`](https://github.com/UMich-CURLY/LieNeurons/tree/main/experiment). 40 | * Before training, you'll have to generate the data using Python scripts in [`data_gen`](https://github.com/UMich-CURLY/LieNeurons/tree/main/data_gen). 41 | * Empirically, we found out that using a lower learning rate (around `3e-5`) helps the convergence during training. This is likely due to the lack of normalization layers. 42 | * When working with $\mathfrak{so}(3)$, Lie Neurons specialize to [Vector Neurons](https://github.com/FlyingGiraffe/vnn) with an additional bracket nonlinearity and a channel mixing layer. Since the inner product is well-defined on $\mathfrak{so}(3)$, one can plug in the batch normalization layers proposed in Vector Neurons to improve stability during training. 43 | 44 | ## Citation 45 | If you find the work useful, please kindly cite our paper: 46 | ``` 47 | @inproceedings{lin2024ln, 48 | title={{Lie Neurons}: {Adjoint}-Equivariant Neural Networks for Semisimple {Lie} Algebras}, 49 | author={Lin, Tzu-Yuan and Zhu, Minghan and Ghaffari, Maani}, 50 | booktitle={International Conference on Machine Learning}, 51 | pages={30529--30545}, 52 | year={2024}, 53 | organization={PMLR} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /core/vn_layers.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | 6 | sys.path.append('.') 7 | 8 | 9 | EPS = 1e-6 10 | 11 | class VNLinear(nn.Module): 12 | def __init__(self, in_channels, out_channels): 13 | super(VNLinear, self).__init__() 14 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) 15 | 16 | def forward(self, x): 17 | ''' 18 | x: point features of shape [B, N_feat, 3, N_samples, ...] 19 | ''' 20 | x_out = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1) 21 | return x_out 22 | 23 | 24 | class VNLeakyReLU(nn.Module): 25 | def __init__(self, in_channels, share_nonlinearity=False, leaky_relu=False, negative_slope=0.2): 26 | super(VNLeakyReLU, self).__init__() 27 | if share_nonlinearity == True: 28 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False) 29 | else: 30 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) 31 | self.negative_slope = negative_slope 32 | self.leaky_relu = leaky_relu 33 | 34 | def forward(self, x): 35 | ''' 36 | x: point features of shape [B, N_feat, 3, N_samples, ...] 
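        For features whose inner product with a learned direction d is negative,
        the output is the projection of x onto the hyperplane orthogonal to d
        (blended with x in the leaky case); this generalizes (leaky) ReLU to
        3-vector features while keeping the layer rotation-equivariant.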
37 | ''' 38 | d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1) 39 | dotprod = (x*d).sum(2, keepdim=True) 40 | mask = (dotprod >= 0).float() 41 | d_norm_sq = (d*d).sum(2, keepdim=True) 42 | 43 | if self.leaky_relu: 44 | x_out = self.negative_slope * x + (1-self.negative_slope) * (mask*x + (1-mask)*(x-(dotprod/(d_norm_sq+EPS))*d)) 45 | else: 46 | x_out = torch.where(dotprod >= 0, x, x-(dotprod/(d_norm_sq+EPS))*d) 47 | return x_out 48 | 49 | class VNLinearAndLeakyReLU(nn.Module): 50 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, leaky_relu=False, use_batchnorm='norm', negative_slope=0.2): 51 | super(VNLinearAndLeakyReLU, self).__init__() 52 | self.dim = dim 53 | self.share_nonlinearity = share_nonlinearity 54 | self.use_batchnorm = use_batchnorm 55 | self.negative_slope = negative_slope 56 | 57 | self.linear = VNLinear(in_channels, out_channels) 58 | self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, negative_slope=negative_slope) 59 | 60 | # BatchNorm 61 | self.use_batchnorm = use_batchnorm 62 | if use_batchnorm != 'none': 63 | self.batchnorm = VNBatchNorm(out_channels, dim=dim) 64 | 65 | def forward(self, x): 66 | ''' 67 | x: point features of shape [B, N_feat, 3, N_samples, ...] 68 | ''' 69 | # Conv 70 | x = self.linear(x) 71 | # InstanceNorm 72 | if self.use_batchnorm != 'none': 73 | x = self.batchnorm(x) 74 | # LeakyReLU 75 | x_out = self.leaky_relu(x) 76 | return x_out 77 | 78 | 79 | class VNBatchNorm(nn.Module): 80 | def __init__(self, num_features, dim): 81 | super(VNBatchNorm, self).__init__() 82 | self.dim = dim 83 | if dim == 3 or dim == 4: 84 | self.bn = nn.BatchNorm1d(num_features) 85 | elif dim == 5: 86 | self.bn = nn.BatchNorm2d(num_features) 87 | 88 | def forward(self, x): 89 | ''' 90 | x: point features of shape [B, N_feat, 3, N_samples, ...] 
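        Batch normalization is applied to the per-vector norms only; each vector
        is then rescaled to its normalized norm, so directions are preserved and
        the layer remains rotation-equivariant.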
91 | ''' 92 | norm = torch.sqrt((x*x).sum(2)) 93 | norm_bn = self.bn(norm) 94 | norm = norm.unsqueeze(2) 95 | norm_bn = norm_bn.unsqueeze(2) 96 | x = x / norm * norm_bn 97 | 98 | return x 99 | -------------------------------------------------------------------------------- /matlab/sl3_equivariant_lift.asv: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | 4 | syms v1 v2 v3 v4 v5 v6 v7 v8 5 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real') 6 | 7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9 8 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real') 9 | %% 10 | 11 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 12 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 13 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 14 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 15 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 16 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 17 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 18 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 19 | 20 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0]; 21 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0]; 22 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0]; 23 | 24 | Ea1 = [1,0,0;0,0,0;0,0,-1]; 25 | Ea2 = [0,0,0;0,1,0;0,0,-1]; 26 | 27 | Enx = [0,0,1;0,0,0;0,0,0]; 28 | Eny = [0,0,0;0,0,1;0,0,0 29 | 30 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 31 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 32 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 33 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1]; 34 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 35 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0]; 36 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 37 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 38 | 39 | E = {E1,E2,E3,E4,E5,E6,E7,E8}; 40 | 41 | E1_vec = reshape(E1,1,[])'; 42 | E2_vec = reshape(E2,1,[])'; 43 | E3_vec = reshape(E3,1,[])'; 44 | E4_vec = reshape(E4,1,[])'; 45 | E5_vec = reshape(E5,1,[])'; 46 | E6_vec = reshape(E6,1,[])'; 47 | E7_vec = reshape(E7,1,[])'; 48 | E8_vec = reshape(E8,1,[])'; 49 | 50 | 51 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec]; 52 | 53 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8; 54 | 55 | 56 | %% find h 57 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9]; 58 | Ad_H_hat = H*x_hat*inv(H); 59 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])'; 60 | 61 | % solve least square to obtain x 62 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec; 63 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 64 | [Ad_H_sym,b]=equationsToMatrix(x,var); 65 | 66 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9); 67 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym; 68 | 69 | %% find Ad_Ei 70 | Ad_E = {}; 71 | dAd_E = {}; 72 | for i=1:8 73 | % Find Ad_E_i 74 | H = expm(E{i}); % Using numerical exponential for now 75 | Ad_E{i} = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 76 | dAd_E{i} = logm(Ad_E{i}); 77 | end 78 | 79 | %% construct dro_3 80 | 81 | dro_3 = {}; 82 | C = []; 83 | for i =1:8 84 | dro_3{i} = kron(-E{i}',eye(8))+kron(eye(3),dAd_E{i}); 85 | C = [C;dro_3{i}]; 86 | end 87 | 88 | %% solve for the null space 89 | [U,S,V] = svd(C); 90 | 91 | 92 | Q = null(C); 93 | 94 | %% 95 | % rank(C,1e-10) 96 | 97 | %% test solving for one h 98 | a = E1+E2+E3+E4+E5+E6+E7+E8; 99 | % a = E4 100 | H = expm(a); 101 | % rnd = 100*rand(3); 102 | % x=rnd(1); 103 | % y=rnd(2); 104 | % z=rnd(3); 105 | % H = expm(x*Ex+y*Ey+z*Ez); 106 | 107 | % rnd = 100*randn(3); 108 | % n1=rnd(1); 109 | % n2=rnd(2); 110 | % n3=rnd(3); 111 | % H = [1 n1 n2; 0 1 n3; 0 0 1]; 112 | 113 | % rnd = 2*randn(3); 114 | % a1=rnd(1); 115 | % a2=rnd(2); 116 | % H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2]; 117 | 118 | % H = [0.5, 0,0; 0,2,0;0,0,1/0.5/2]; 119 | Ad_test = 
double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 120 | 121 | D = kron(inv(H'),Ad_test)-eye(24); 122 | % D = kron(Ad_test,inv(H'))-eye(24); 123 | % [DU,DS,DV] = svd(D); 124 | 125 | DQ = null(D); 126 | 127 | rank(D) 128 | 129 | wd = DQ*DQ'*ones(24,1); 130 | 131 | Wd = reshape(wd,[8,3]); 132 | 133 | v_test = [2,3,1]'; 134 | H_v_test = H*v_test; 135 | 136 | x_test = Wd*v_test; 137 | x_H_test = Wd*H_v_test; 138 | x_ad_test = Ad_test*x_test; 139 | disp([x_H_test,x_ad_test]) 140 | % disp(x_H_test - x_ad_test) 141 | 142 | % rank([E1;E2;E3;E4;E5;E6;E7;E8]) 143 | 144 | %% 145 | -------------------------------------------------------------------------------- /playground/neural_ode_nonhomogeneous.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from scipy.signal import lti 7 | from scipy.signal import lsim2 8 | from scipy import interpolate 9 | from random import randint 10 | import matplotlib.pyplot as plt 11 | from torchdiffeq import odeint # modified version 12 | 13 | parser = argparse.ArgumentParser('Pendel') 14 | parser.add_argument('--data_size', type=int, default=1000) 15 | parser.add_argument('--batch_time', type=int, default=10) 16 | parser.add_argument('--batch_size', type=int, default=20) 17 | parser.add_argument('--niters', type=int, default=300) 18 | parser.add_argument('--test_freq', type=int, default=20) 19 | args = parser.parse_args() 20 | 21 | # initial conditions / system dynamics 22 | x0 = [0] 23 | # definition of the LTI-system 24 | A = np.array([0.05]) 25 | B = np.array([0.05]) # x' = 0.05 x + 0.05 u 26 | C = np.array([1]) 27 | D = np.array([0]) # y = x 28 | system = lti(A, B, C, D) 29 | 30 | t = torch.linspace(0., 25., args.data_size) 31 | u = torch.Tensor(np.ones_like(t)) # external input -- validation 32 | tout, y, x = lsim2(system, u, t, x0) 33 | true_y = torch.FloatTensor(y) 34 | ufunc_val = interpolate.interp1d(t, u.reshape(-1,1), kind='linear', axis=0, bounds_error=False, fill_value='extrapolate') 35 | y0 = torch.FloatTensor(x0) 36 | 37 | def get_batch(): 38 | u_in = torch.zeros(args.batch_time,args.batch_size) 39 | true_x = torch.zeros(args.batch_time,args.batch_size) 40 | batch_t = t[:args.batch_time] # (T) 41 | for i in range(args.batch_size): 42 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), 1, replace=False)) 43 | u = torch.ones(args.batch_time)+randint(0,10)/5 # external input -- training 44 | tout, y2, x = lsim2(system, u, batch_t, y[s]) 45 | tout = torch.tensor(tout) 46 | x = torch.tensor(x) 47 | u_in[:,i] = u.reshape(1, args.batch_time) 48 | true_x[:,i] = x.reshape(1, args.batch_time) 49 | batch_x0 = true_x[0,:] 50 | return batch_x0, batch_t, true_x, u_in 51 | 52 | 53 | def visualize(true_x, pred_x): 54 | plt.title('Federpendel') 55 | plt.xlabel('t') 56 | plt.ylabel('x,v') 57 | plt.plot(t.numpy(), true_x.numpy()) 58 | plt.plot(t.numpy(), pred_x.numpy(),'--') 59 | plt.xlim(t.min(), t.max()) 60 | plt.tight_layout() 61 | plt.draw() 62 | plt.pause(0.001) 63 | 64 | class ODEFunc(nn.Module): 65 | 66 | def __init__(self): 67 | super(ODEFunc, self).__init__() 68 | self.linear = nn.Linear(2,1) 69 | 70 | def forward(self, t, x, args): 71 | ufun = args[0] 72 | unew = torch.FloatTensor(ufun(t.detach().numpy())) 73 | in1 = torch.stack((x, unew), dim=1) 74 | out = self.linear(in1).squeeze() 75 | return out 76 | 77 | 78 | if __name__ == 
'__main__': 79 | 80 | func = ODEFunc() 81 | optimizer = optim.RMSprop(func.parameters(), lr=0.02) 82 | lossb = nn.MSELoss() 83 | for itr in range(1, args.niters + 1): 84 | optimizer.zero_grad() 85 | batch_x0, batch_t, batch_x, batch_u = get_batch() 86 | ufunc = interpolate.interp1d(batch_t, batch_u, kind='linear', axis=0, bounds_error=False, fill_value='extrapolate') 87 | pred_x = odeint(func, batch_x0, batch_t, args=(ufunc,), method='dopri5') 88 | loss = lossb(pred_x, batch_x) 89 | loss.backward() 90 | optimizer.step() 91 | if itr % args.test_freq == 0: 92 | with torch.no_grad(): 93 | pred_x = odeint(func, y0, t, args=(ufunc_val,), method='dopri5') 94 | loss = lossb(pred_x.squeeze(), true_y.squeeze()) 95 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item())) 96 | visualize(true_y.squeeze(), pred_x.squeeze()) -------------------------------------------------------------------------------- /core/lie_group_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | sys.path.append('.') 5 | 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | from core.lie_alg_util import lie_bracket, vee_so3 11 | 12 | def skew(v): 13 | # M = [[0 , -v[2], v[1]], 14 | # [v[2], 0, -v[0]], 15 | # [-v[1], v[0], 0]] 16 | 17 | v2 = v.clone() 18 | M = torch.zeros((3,3)) 19 | M[0,1] = -v2[2] 20 | M[0,2] = v2[1] 21 | M[1,0] = v2[2] 22 | M[1,2] = -v2[0] 23 | M[2,0] = -v2[1] 24 | M[2,1] = v2[0] 25 | print("M", M) 26 | return M 27 | 28 | 29 | def exp_hat_and_so3(w): 30 | I = torch.eye(3) 31 | theta = torch.norm(w) 32 | A = skew(w) 33 | return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta*theta))*torch.matmul(A,A) 34 | 35 | 36 | # def bacth_exp_so3(A): 37 | # # A (B,3,3) 38 | # I = torch.eye(3).to(A.device) 39 | # theta = torch.sqrt(-torch.trace(torch.matmul(A,A))/2.0) 40 | # return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta*theta))*torch.matmul(A,A) 41 | 42 | def batch_trace(A): 43 | return A.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) 44 | 45 | def exp_so3(A): 46 | I = torch.eye(3).to(A.device) 47 | # print("A: ",A.shape) 48 | theta = torch.sqrt(-batch_trace(torch.matmul(A,A))/2.0).reshape(-1,1,1) 49 | # print("theta: ", theta.shape) 50 | return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta**2))*torch.matmul(A,A) 51 | 52 | def log_SO3(R): 53 | 54 | # print("trace: ", torch.trace(R)) 55 | # print("R",R) 56 | cos_R = (batch_trace(R)-1)/2.0 57 | # print("cos_R", cos_R) 58 | theta = torch.acos(torch.clamp(cos_R, -1+1e-7, 1-1e-7)) 59 | # theta2 = torch.asin(torch.sqrt((3-batch_trace(R))*(1+batch_trace(R)))/2.0) 60 | # print("theta: ", theta) 61 | # print("theta2: ", theta2) 62 | 63 | if torch.isnan(theta): 64 | raise ValueError("theta is nan") 65 | # return torch.zeros((3,3)).to(R.device) 66 | # if torch.isnan(theta): 67 | # return torch.zeros((3,3)).to(R.device) 68 | # print("theta: ", theta) 69 | # if theta - np.pi < 1e-6: 70 | # return theta/2/(np.pi-theta)*(R-R.T) 71 | # elif theta > np.pi: 72 | # theta = np.pi-theta 73 | # K = (theta/(2*torch.sin(theta)))[:,None,None]*(R-R.transpose(-1,-2)) 74 | 75 | return (theta/(2*torch.sin(theta)))[:,None,None]*(R-R.transpose(-1,-2)) 76 | 77 | 78 | def BCH_first_order_approx(X,Y): 79 | return X+Y 80 | 81 | def BCH_second_order_approx(X,Y): 82 | return X+Y+1/2*lie_bracket(X,Y) 83 | 84 | def BCH_third_order_approx(X,Y): 85 | return 
X+Y+1/2*lie_bracket(X,Y)+1/12*lie_bracket(X,lie_bracket(X,Y))-1/12*lie_bracket(Y,lie_bracket(X,Y)) 86 | 87 | def BCH_approx(X,Y): 88 | return X+Y+1/2*lie_bracket(X,Y)+1/12*lie_bracket(X,lie_bracket(X,Y))-1/12*lie_bracket(Y,lie_bracket(X,Y))\ 89 | -1/24*lie_bracket(Y,lie_bracket(X,lie_bracket(X,Y)))-1/720*lie_bracket(Y,lie_bracket(Y,lie_bracket(Y,lie_bracket(Y,X))))\ 90 | -1/720*lie_bracket(X,lie_bracket(X,lie_bracket(X,lie_bracket(X,Y)))) 91 | 92 | def BCH_so3(X,Y): 93 | x = vee_so3(X) 94 | y = vee_so3(Y) 95 | theta = torch.norm(x) 96 | phi = torch.norm(y) 97 | delta = torch.acos(x.transpose(-1,0)@y/torch.norm(x)/torch.norm(y)) # angles between x and y 98 | a = torch.sin(theta)*torch.cos(phi/2.0)*torch.cos(phi/2.0)-torch.sin(phi)*torch.sin(theta/2.0)*torch.sin(theta/2.0)*torch.cos(delta) 99 | b = torch.sin(phi)*torch.cos(theta/2.0)*torch.cos(theta/2.0)-torch.sin(theta)*torch.sin(phi/2.0)*torch.sin(phi/2.0)*torch.cos(delta) 100 | c = 1/2.0*torch.sin(theta)*torch.sin(phi)-2.0*torch.sin(theta/2.0)*torch.sin(theta/2.0)*torch.sin(phi/2.0)*torch.sin(phi/2.0)*torch.cos(delta) 101 | d = torch.sqrt(a*a+b*b+2.0*a*b*torch.cos(delta)+c*c*torch.sin(delta)*torch.sin(delta)) 102 | 103 | alpha = torch.asin(d)*a/d/theta 104 | beta = torch.asin(d)*b/d/phi 105 | gamma = torch.asin(d)*c/d/theta/phi 106 | 107 | return alpha*X+beta*Y+gamma*lie_bracket(X,Y) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | 163 | # matlab 164 | .asv 165 | .fig 166 | 167 | weights/ 168 | logs/ 169 | data/ 170 | logs_rebuttal/ 171 | logs_archive/ 172 | logs_good/ 173 | png_ln_ode_demo/ 174 | png_neural_ode_demo/ 175 | figures/ 176 | training_all.txt 177 | eval_results.txt 178 | eval_aug_mlp.txt 179 | training_mlp_aug_equiv.txt 180 | -------------------------------------------------------------------------------- /matlab/sl3_equivariant_lift.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | 4 | syms v1 v2 v3 v4 v5 v6 v7 v8 5 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real') 6 | 7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9 8 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real') 9 | %% 10 | 11 | % E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 12 | % E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 13 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 14 | % E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 15 | % E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 16 | % E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 17 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 18 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 19 | 20 | Ekx = [0, 0, 0;0, 0, -1;0, 1, 0]; 21 | Eky = [0, 0, 1;0, 0, 0;-1, 0, 0]; 22 | Ekz = [0, -1, 0;1, 0, 0;0, 0, 0]; 23 | 24 | Ea1 = [1,0,0;0,0,0;0,0,-1]; 25 | Ea2 = [0,0,0;0,1,0;0,0,-1]; 26 | 27 | Ea3 = [1,0,0;0,-1,0;0,0,0]; 28 | 29 | Enx = [0,0,1;0,0,0;0,0,0]; 30 | Eny = [0,0,0;0,0,1;0,0,0]; 31 | Enz = [0,1,0;0,0,0;0,0,0]; 32 | 33 | E1 = Ekx; 34 | E2 = Eky; 35 | E3 = Ekz; 36 | E4 = Ea1; 37 | E5 = Ea2; 38 | E6 = Enx; 39 | E7 = Eny; 40 | E8 = Enz; 41 | 42 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 43 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 44 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 45 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1]; 46 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 47 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0]; 48 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 49 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 50 | 51 | % E = {E1,E2,E3,E4,E5,E6,E7,E8}; 52 | E = {Ekx,Eky,Ekz,Ea1,Ea2,Enx,Eny,Enz}; 53 | % E = {Ekx,Eky,Ekz}; 54 | 55 | E1_vec = reshape(E1,1,[])'; 56 | E2_vec = reshape(E2,1,[])'; 57 | E3_vec = reshape(E3,1,[])'; 58 | E4_vec = reshape(E4,1,[])'; 59 | E5_vec = reshape(E5,1,[])'; 60 | E6_vec = reshape(E6,1,[])'; 61 | E7_vec = reshape(E7,1,[])'; 62 | E8_vec = reshape(E8,1,[])'; 63 | 64 | 65 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec]; 66 | 67 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8; 68 | 69 | 70 | %% find h 71 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9]; 72 | Ad_H_hat = H*x_hat*inv(H); 73 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])'; 74 | 75 | % solve least square to obtain x 76 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec; 77 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 78 | [Ad_H_sym,b]=equationsToMatrix(x,var); 79 | 80 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9); 81 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym; 82 | 83 | %% find Ad_Ei 84 | Ad_E = {}; 85 | dAd_E = {}; 86 | for i=1:size(E,2) 87 | % Find Ad_E_i 88 | H = expm(E{i}); % Using numerical exponential for now 89 | Ad_E{i} = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 90 | dAd_E{i} = logm(Ad_E{i}); 91 | end 92 | 93 | %% construct dro_3 94 | 95 | dro_3 = {}; 96 | C = []; 97 | for i =1:size(E,2) 98 | dro_3{i} = kron(-E{i}',eye(8))+kron(eye(3),dAd_E{i}); 99 | C = [C;dro_3{i}]; 100 | end 101 | 102 | %% solve for the null space 103 | [U,S,V] = svd(C); 104 | 105 | 106 | Q = null(C); 107 | 108 | %% 109 | H = expm(hat_sl3(rand(8,1))); 110 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 111 | 112 | % w = 
Q*Q'*rand(24,1)*10; 113 | w = Q*Q'*ones(24,1)*10; 114 | 115 | W = reshape(w,[8,3]); 116 | 117 | v_test = [2,3,1]'; 118 | H_v_test = H*v_test; 119 | 120 | x_test = W*v_test; 121 | x_H_test = W*H_v_test; 122 | x_ad_test = Ad_test*x_test; 123 | disp([x_H_test,x_ad_test]) 124 | 125 | %% 126 | % rank(C,1e-10) 127 | 128 | %% test solving for one h 129 | % % a = E1+E2+E3+E4+E5+E6+E7+E8; 130 | % 131 | % % a = Ekx+Eky+Ekz+Ea1+Ea2+Enx+Eny+Enz; 132 | % 133 | % a = Ekx+Eky+Ekz+Ea1+Ea2+Enx+Eny+Enz; 134 | % % a = Ea1+Ea2; 135 | % % a = E4 136 | % H = expm(a); 137 | % % rnd = 100*rand(3); 138 | % % x=rnd(1); 139 | % % y=rnd(2); 140 | % % z=rnd(3); 141 | % % H = expm(x*Ekx+y*Eky+z*Ekz); 142 | % 143 | % % rnd = 100*randn(3); 144 | % % n1=rnd(1); 145 | % % n2=rnd(2); 146 | % % n3=rnd(3); 147 | % % H = [1 n1 n2; 0 1 n3; 0 0 1]; 148 | % 149 | % % rnd = 2*randn(3); 150 | % % a1=rnd(1); 151 | % % a2=rnd(2); 152 | % % H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2]; 153 | % 154 | % % H = [0.5, 0,0; 0,2,0;0,0,1/0.5/2]; 155 | % Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 156 | % 157 | % D = kron(inv(H'),Ad_test)-eye(24); 158 | % % D = kron(Ad_test,inv(H'))-eye(24); 159 | % % [DU,DS,DV] = svd(D); 160 | % 161 | % DQ = null(D); 162 | % 163 | % rank(D) 164 | % 165 | % wd = DQ*DQ'*ones(24,1); 166 | % 167 | % Wd = reshape(wd,[8,3]); 168 | % 169 | % v_test = [2,3,1]'; 170 | % H_v_test = H*v_test; 171 | % 172 | % x_test = Wd*v_test; 173 | % x_H_test = Wd*H_v_test; 174 | % x_ad_test = Ad_test*x_test; 175 | % disp([x_H_test,x_ad_test]) 176 | % disp(x_H_test - x_ad_test) 177 | 178 | % rank([E1;E2;E3;E4;E5;E6;E7;E8]) 179 | 180 | %% 181 | -------------------------------------------------------------------------------- /experiment/platonic_solid_cls_test.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | from scipy.spatial.transform import Rotation 13 | 14 | import torch 15 | from torch import nn 16 | import torch.optim as optim 17 | from torch.utils.tensorboard import SummaryWriter 18 | from torch.utils.data import Dataset, DataLoader, IterableDataset 19 | 20 | from core.lie_neurons_layers import * 21 | from experiment.platonic_solid_cls_layers import * 22 | from data_gen.gen_platonic_solids import * 23 | 24 | 25 | def test_perspective(model, test_loader, criterion, config, device): 26 | model.eval() 27 | hat_layer = HatLayer(algebra_type='sl3').to(device) 28 | rots = random_sample_rotations(config['num_rotations'], config['rotation_factor'],device) 29 | with torch.no_grad(): 30 | loss_sum = 0.0 31 | num_correct = 0 32 | num_correct_non_conj = 0 33 | for iter, samples in tqdm(enumerate(test_loader, start=0)): 34 | 35 | x = samples[0].to(device) 36 | y = samples[1].to(device) 37 | 38 | x_hat = hat_layer(x) 39 | x = rearrange(x,'b n f k -> b f k n') 40 | 41 | output = model(x) 42 | _, prediction_non_conj = torch.max(output,1) 43 | num_correct_non_conj += (prediction_non_conj==y).sum().item() 44 | for r in range(config['num_rotations']): 45 | cur_rot = rots[r,:,:] 46 | x_rot_hat = torch.matmul(cur_rot, torch.matmul(x_hat, torch.inverse(cur_rot))) 47 | x_rot = rearrange(vee_sl3(x_rot_hat),'b n f k -> b f k n') 48 | output_rot = model(x_rot) 49 | 50 | _, prediction = torch.max(output_rot,1) 51 | num_correct += (prediction==y).sum().item() 
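                # The labels y are unchanged under conjugation, so the loss and
                # accuracy below probe invariance of the classifier to the
                # sampled rotations acting by the adjoint.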
52 | 53 | loss = criterion(output_rot, y) 54 | loss_sum += loss.item() 55 | 56 | loss_avg = loss_sum/config['num_test']*config['batch_size']/config['num_rotations'] 57 | acc_avg = num_correct/config['num_test']/config['num_rotations'] 58 | acc_avg_non_conj = num_correct_non_conj/config['num_test']/1.0 59 | return loss_avg, acc_avg, acc_avg_non_conj 60 | 61 | def random_sample_rotations(num_rotation, rotation_factor: float = 1.0, device='cpu') -> np.ndarray: 62 | r = np.zeros((num_rotation, 3, 3)) 63 | for n in range(num_rotation): 64 | # angle_z, angle_y, angle_x 65 | euler = np.random.rand(3) * np.pi * 2 / rotation_factor # (0, 2 * pi / rotation_range) 66 | r[n,:,:] = Rotation.from_euler('zyx', euler).as_matrix() 67 | return torch.from_numpy(r).type('torch.FloatTensor').to(device) 68 | 69 | 70 | 71 | 72 | def main(): 73 | # torch.autograd.set_detect_anomaly(True) 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | print('Using ', device) 77 | 78 | parser = argparse.ArgumentParser(description='Train the network') 79 | parser.add_argument('--training_config', type=str, 80 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/platonic_solid_cls/testing_param.yaml') 81 | args = parser.parse_args() 82 | 83 | # load yaml file 84 | config = yaml.safe_load(open(args.training_config)) 85 | 86 | 87 | test_set = PlatonicDataset(config['num_test']) 88 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 89 | shuffle=config['shuffle']) 90 | 91 | if config['model_type'] == "LN_relu_bracket": 92 | model = LNReluBracketPlatonicSolidClassifier(3).to(device) 93 | elif config['model_type'] == "LN_relu": 94 | model = LNReluPlatonicSolidClassifier(3).to(device) 95 | elif config['model_type'] == "LN_bracket": 96 | model = LNBracketPlatonicSolidClassifier(3).to(device) 97 | elif config['model_type'] == "MLP": 98 | model = MLP(288).to(device) 99 | elif config['model_type'] == "LN_bracket_no_residual": 100 | model = LNBracketNoResidualConnectPlatonicSolidClassifier(3).to(device) 101 | 102 | print("Using model: ", config['model_type']) 103 | print("total number of parameters: ", sum(p.numel() for p in model.parameters())) 104 | 105 | checkpoint = torch.load(config['model_path']) 106 | model.load_state_dict(checkpoint['model_state_dict'],strict=False) 107 | 108 | criterion = nn.CrossEntropyLoss().to(device) 109 | test_loss, test_acc, test_acc_non_conj = test_perspective(model, test_loader, criterion, config, device) 110 | print("test loss: ", test_loss) 111 | print("test with conjugate acc: ", test_acc) 112 | print("test without conjugate acc: ", "{:.4f}".format(test_acc_non_conj)) 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /data_gen/gen_sl3_inv_data.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.linalg import expm 7 | 8 | from core.lie_neurons_layers import * 9 | 10 | 11 | def invariant_function(x1, x2): 12 | return torch.sin(torch.trace(x1@x1))+torch.cos(torch.trace(x2@x2))\ 13 | -torch.pow(torch.trace(x2@x2), 3)/2.0+torch.det(x1@x2)+torch.exp(torch.trace(x1@x1)) 14 | 15 | 16 | if __name__ == "__main__": 17 | data_saved_path = "data/sl3_inv_data/" 18 | data_name = "sl3_inv_10000_s_05_augmented" 19 | gen_augmented_training_data = True 20 | num_training = 5000 21 | num_testing = 10000 22 | 
num_conjugate = 1 23 | rnd_scale = 0.5 24 | 25 | train_data = {} 26 | test_data = {} 27 | 28 | hat_layer = HatLayer(algebra_type='sl3') 29 | 30 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 31 | .reshape(1, 1, 8, num_training) 32 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 33 | .reshape(1, 1, 8, num_training) 34 | 35 | # conjugate transformation 36 | h = torch.Tensor(np.random.uniform(-rnd_scale, 37 | rnd_scale, (num_conjugate, num_training, 8))) 38 | H = torch.linalg.matrix_exp(hat_layer(h)) 39 | # conjugate x1 40 | x1_hat = hat_layer(x1.transpose(2, -1)) 41 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 42 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c') 43 | 44 | # conjugate x2 45 | x2_hat = hat_layer(x2.transpose(2, -1)) 46 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 47 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c') 48 | 49 | inv_output = torch.zeros((1, num_training, 1)) 50 | # compute invariant function 51 | for n in range(num_training): 52 | inv_output[0, n, 0] = invariant_function( 53 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]) 54 | 55 | # print("--------------invariant function output: ------------------") 56 | # print(invariant_function(x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])) 57 | # print(invariant_function( 58 | # conj_x1_hat[0, 0, n, :, :], conj_x2_hat[0, 0, n, :, :])) 59 | 60 | # print(x1.shape) 61 | # print(conj_x1.shape) 62 | if(gen_augmented_training_data): 63 | # this is for training data only 64 | train_data['x1'] = torch.cat((x1.reshape(8, num_training),conj_x1.reshape(8, num_training*num_conjugate)),dim=1).numpy() 65 | train_data['x2'] = torch.cat((x2.reshape(8, num_training),conj_x2.reshape(8, num_training*num_conjugate)),dim=1).numpy() 66 | train_data['x1_conjugate'] = conj_x1.reshape(8, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy() 67 | train_data['x2_conjugate'] = conj_x2.reshape(8, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy() 68 | train_data['y'] = inv_output.reshape(1, num_training).repeat(1,num_conjugate+1).numpy() 69 | else: 70 | train_data['x1'] = x1.numpy().reshape(8, num_training) 71 | train_data['x2'] = x2.numpy().reshape(8, num_training) 72 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_training, num_conjugate) 73 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_training, num_conjugate) 74 | train_data['y'] = inv_output.numpy().reshape(1, num_training) 75 | 76 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data) 77 | 78 | 79 | ''' 80 | Generate testing data 81 | ''' 82 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 83 | .reshape(1, 1, 8, num_testing) 84 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 85 | .reshape(1, 1, 8, num_testing) 86 | 87 | # conjugate transformation 88 | h = torch.Tensor(np.random.uniform(-rnd_scale, 89 | rnd_scale, (num_conjugate, num_testing, 8))) 90 | H = torch.linalg.matrix_exp(hat_layer(h)) 91 | # conjugate x1 92 | x1_hat = hat_layer(x1.transpose(2, -1)) 93 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 94 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c') 95 | 96 | # conjugate x2 97 | x2_hat = hat_layer(x2.transpose(2, -1)) 98 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 99 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> 
b l t c') 100 | inv_output = torch.zeros((1, num_testing, 1)) 101 | # compute invariant function 102 | for n in range(num_testing): 103 | inv_output[0, n, 0] = invariant_function( 104 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]) 105 | 106 | # for i in range(num_conjugate): 107 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :])) 108 | 109 | test_data['x1'] = x1.numpy().reshape(8, num_testing) 110 | test_data['x2'] = x2.numpy().reshape(8, num_testing) 111 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_testing, num_conjugate) 112 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_testing, num_conjugate) 113 | test_data['y'] = inv_output.numpy().reshape(1, num_testing) 114 | 115 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data) 116 | 117 | print("Done! Data saved to: \n", data_saved_path + 118 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz") 119 | -------------------------------------------------------------------------------- /experiment/sl3_inv_test.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_neurons_layers import * 20 | from experiment.sl3_inv_layers import * 21 | from data_loader.sl3_inv_data_loader import * 22 | 23 | 24 | def test(model, test_loader, criterion, config, device): 25 | model.eval() 26 | with torch.no_grad(): 27 | loss_sum = 0.0 28 | for i, sample in tqdm(enumerate(test_loader, start=0)): 29 | x = sample['x'].to(device) 30 | y = sample['y'].to(device) 31 | 32 | output = model(x) 33 | 34 | loss = criterion(output, y) 35 | loss_sum += loss.item() 36 | 37 | loss_avg = loss_sum/len(test_loader) 38 | 39 | return loss_avg 40 | 41 | 42 | def test_invariance(model, test_loader, criterion, config, device): 43 | model.eval() 44 | with torch.no_grad(): 45 | loss_sum = 0.0 46 | loss_non_conj_sum = 0.0 47 | diff_output_sum = 0.0 48 | loss_all = [] 49 | loss_non_conj_all = [] 50 | 51 | for i, sample in tqdm(enumerate(test_loader, start=0)): 52 | x = sample['x'].to(device) 53 | x_conj = sample['x_conjugate'].to(device) 54 | y = sample['y'].to(device) 55 | 56 | output_x = model(x) 57 | loss_non_conj = criterion(output_x, y) 58 | loss_non_conj_sum += loss_non_conj.item() 59 | loss_non_conj_all.append(loss_non_conj.item()) 60 | # print(output_x) 61 | # print(x_conj.shape) 62 | for j in range(x_conj.shape[1]): 63 | x_conj_j = x_conj[:, j, :, :, :] 64 | output_conj = model(x_conj_j) 65 | diff_output = output_x - output_conj 66 | loss = criterion(output_conj, y) 67 | loss_all.append(loss.item()) 68 | loss_sum += loss.item() 69 | diff_output_sum += torch.sum(torch.abs(diff_output)) 70 | # print(output_conj) 71 | 72 | # print("diff", diff_output[0,:]) 73 | # print("conj_out", output_conj[0,:]) 74 | # print("out",output_x[0,:]) 75 | # print("y", y[0,:]) 76 | # print(loss.item()) 77 | # print("----------------------") 78 | # print(loss_all) 79 | # print(torch.Tensor(loss_all).shape) 80 | # loss_avg2 = torch.mean(torch.Tensor(loss_all)) 81 | # print("loss_avg 2: ", loss_avg2) 82 | loss_std = 
torch.std(torch.Tensor(loss_all)) 83 | loss_non_conj_std = torch.std(torch.Tensor(loss_non_conj_all)) 84 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1] 85 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1] 86 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader) 87 | 88 | return loss_avg, loss_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg 89 | 90 | def main(): 91 | # torch.autograd.set_detect_anomaly(True) 92 | 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | print('Using ', device) 95 | 96 | parser = argparse.ArgumentParser(description='Train the network') 97 | parser.add_argument('--test_config', type=str, 98 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_inv/testing_param.yaml') 99 | args = parser.parse_args() 100 | 101 | # load yaml file 102 | config = yaml.safe_load(open(args.test_config)) 103 | 104 | test_set = sl3InvDataSet2Input(config['test_data_path'], device=device) 105 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 106 | shuffle=config['shuffle']) 107 | 108 | 109 | if config['model_type'] == "LN_relu_bracket": 110 | model = SL3InvariantReluBracketLayers(2).to(device) 111 | elif config['model_type'] == "LN_relu": 112 | model = SL3InvariantReluLayers(2).to(device) 113 | elif config['model_type'] == "LN_bracket": 114 | model = SL3InvariantBracketLayers(2).to(device) 115 | elif config['model_type'] == "MLP": 116 | model = MLP(16).to(device) 117 | elif config['model_type'] == "LN_bracket_no_residual": 118 | model = SL3InvariantBracketNoResidualConnectLayers(2).to(device) 119 | 120 | print("Using model: ", config['model_type']) 121 | print("total number of parameters: ", sum(p.numel() for p in model.parameters())) 122 | checkpoint = torch.load(config['model_path']) 123 | model.load_state_dict(checkpoint['model_state_dict'],strict=False) 124 | 125 | criterion = nn.MSELoss().to(device) 126 | # test_loss = test(model, test_loader, criterion, config, device) 127 | test_loss_inv, test_loss_inv_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg = test_invariance(model, test_loader, criterion, config, device) 128 | 129 | print("test_loss type:",type(test_loss_inv)) 130 | print("test loss: ", test_loss_inv) 131 | print("test loss std", test_loss_inv_std) 132 | print("avg diff output: ", diff_output_avg) 133 | print("loss non conj: ", loss_non_conj_avg) 134 | print("loss non conj std: ", loss_non_conj_std) 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /data_gen/gen_sp4_inv_data.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.linalg import expm 7 | 8 | from core.lie_neurons_layers import * 9 | 10 | 11 | def invariant_function(x1, x2): 12 | return torch.sin(torch.trace(x1@x1))+torch.cos(torch.trace(x2@x2))\ 13 | -torch.pow(torch.trace(x2@x2), 3)/2.0+torch.det(x1@x2)+torch.exp(torch.trace(x1@x1)) 14 | 15 | 16 | if __name__ == "__main__": 17 | data_saved_path = "data/sp4_inv_data/" 18 | data_name = "sp4_inv_10000_s_05_augmented" 19 | gen_augmented_training_data = True 20 | num_training = 5000 21 | num_testing = 10000 22 | num_conjugate = 1 23 | rnd_scale = 0.5 24 | 25 | train_data = {} 26 | test_data = {} 27 | 28 | hat_layer = HatLayer(algebra_type='sp4') 29 | 30 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, 
rnd_scale, (1, 10, num_training)))\ 31 | .reshape(1, 1, 10, num_training) 32 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_training)))\ 33 | .reshape(1, 1, 10, num_training) 34 | 35 | # conjugate transformation 36 | h = torch.Tensor(np.random.uniform(-rnd_scale, 37 | rnd_scale, (num_conjugate, num_training, 10))) 38 | H = torch.linalg.matrix_exp(hat_layer(h)) 39 | # conjugate x1 40 | x1_hat = hat_layer(x1.transpose(2, -1)) 41 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 42 | conj_x1 = rearrange(vee_sp4(conj_x1_hat), 'b c t l -> b l t c') 43 | 44 | # conjugate x2 45 | x2_hat = hat_layer(x2.transpose(2, -1)) 46 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 47 | conj_x2 = rearrange(vee_sp4(conj_x2_hat), 'b c t l -> b l t c') 48 | 49 | inv_output = torch.zeros((1, num_training, 1)) 50 | # compute invariant function 51 | for n in range(num_training): 52 | inv_output[0, n, 0] = invariant_function( 53 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]) 54 | 55 | # print("--------------invariant function output: ------------------") 56 | # print(invariant_function(x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])) 57 | # print(invariant_function( 58 | # conj_x1_hat[0, 0, n, :, :], conj_x2_hat[0, 0, n, :, :])) 59 | 60 | # print(x1.shape) 61 | # print(conj_x1.shape) 62 | if(gen_augmented_training_data): 63 | # this is for training data only 64 | train_data['x1'] = torch.cat((x1.reshape(10, num_training),conj_x1.reshape(10, num_training*num_conjugate)),dim=1).numpy() 65 | train_data['x2'] = torch.cat((x2.reshape(10, num_training),conj_x2.reshape(10, num_training*num_conjugate)),dim=1).numpy() 66 | train_data['x1_conjugate'] = conj_x1.reshape(10, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy() 67 | train_data['x2_conjugate'] = conj_x2.reshape(10, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy() 68 | train_data['y'] = inv_output.reshape(1, num_training).repeat(1,num_conjugate+1).numpy() 69 | else: 70 | train_data['x1'] = x1.numpy().reshape(10, num_training) 71 | train_data['x2'] = x2.numpy().reshape(10, num_training) 72 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(10, num_training, num_conjugate) 73 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(10, num_training, num_conjugate) 74 | train_data['y'] = inv_output.numpy().reshape(1, num_training) 75 | 76 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data) 77 | 78 | 79 | ''' 80 | Generate testing data 81 | ''' 82 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_testing)))\ 83 | .reshape(1, 1, 10, num_testing) 84 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_testing)))\ 85 | .reshape(1, 1, 10, num_testing) 86 | 87 | # conjugate transformation 88 | h = torch.Tensor(np.random.uniform(-rnd_scale, 89 | rnd_scale, (num_conjugate, num_testing, 10))) 90 | H = torch.linalg.matrix_exp(hat_layer(h)) 91 | # conjugate x1 92 | x1_hat = hat_layer(x1.transpose(2, -1)) 93 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 94 | conj_x1 = rearrange(vee_sp4(conj_x1_hat), 'b c t l -> b l t c') 95 | 96 | # conjugate x2 97 | x2_hat = hat_layer(x2.transpose(2, -1)) 98 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 99 | # print(conj_x2_hat.shape) 100 | conj_x2 = rearrange(vee_sp4(conj_x2_hat), 'b c t l -> b l t c') 101 | inv_output = torch.zeros((1, num_testing, 1)) 102 | # compute invariant function 103 | for n in range(num_testing): 104 | 
inv_output[0, n, 0] = invariant_function( 105 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]) 106 | 107 | # for i in range(num_conjugate): 108 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :])) 109 | 110 | test_data['x1'] = x1.numpy().reshape(10, num_testing) 111 | test_data['x2'] = x2.numpy().reshape(10, num_testing) 112 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(10, num_testing, num_conjugate) 113 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(10, num_testing, num_conjugate) 114 | test_data['y'] = inv_output.numpy().reshape(1, num_testing) 115 | 116 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data) 117 | 118 | print("Done! Data saved to: \n", data_saved_path + 119 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz") 120 | -------------------------------------------------------------------------------- /data_loader/sl3_inv_data_loader.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from scipy.linalg import expm 8 | from tqdm import tqdm 9 | 10 | 11 | from einops import rearrange, repeat 12 | from einops.layers.torch import Rearrange 13 | 14 | from core.lie_neurons_layers import * 15 | 16 | 17 | class sl3InvDataSet2Input(Dataset): 18 | def __init__(self, data_path, device='cuda'): 19 | data = np.load(data_path) 20 | _, num_points = data['x1'].shape 21 | _,_,num_conjugate = data['x1_conjugate'].shape 22 | self.x1 = rearrange(torch.from_numpy(data['x1']).type( 23 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 24 | self.x2 = rearrange(torch.from_numpy(data['x2']).type( 25 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 26 | self.x = torch.cat((self.x1, self.x2), dim=1) 27 | 28 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type( 29 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 30 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type( 31 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 32 | self.x_conjugate = torch.cat( 33 | (self.x1_conjugate, self.x2_conjugate), dim=2) 34 | self.y = torch.from_numpy(data['y']).type( 35 | 'torch.FloatTensor').to(device).reshape(num_points, 1) 36 | 37 | self.num_data = self.x1.shape[0] 38 | 39 | def __len__(self): 40 | return self.num_data 41 | 42 | def __getitem__(self, idx): 43 | if torch.is_tensor(idx): 44 | idx = idx.tolist() 45 | 46 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :], 'x': self.x[idx, :, :, :], 47 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :], 48 | 'x_conjugate': self.x_conjugate[:,idx, :, :, :], 'y': self.y[idx, :]} 49 | return sample 50 | 51 | 52 | class sl3InvDataSet5Input(Dataset): 53 | def __init__(self, data_path, device='cuda'): 54 | data = np.load(data_path) 55 | _, num_points = data['x1'].shape 56 | _,_,num_conjugate = data['x1_conjugate'].shape 57 | self.x1 = rearrange(torch.from_numpy(data['x1']).type( 58 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 59 | self.x2 = rearrange(torch.from_numpy(data['x2']).type( 60 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 61 | self.x3 = rearrange(torch.from_numpy(data['x3']).type( 62 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 63 | self.x4 = rearrange(torch.from_numpy(data['x4']).type( 64 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 
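        # The rearranges map each (algebra dim k=8, sample n) array to the
        # (batch, channel=1, k, point=1) layout used by the LN layers; the five
        # inputs are concatenated along the channel dimension below.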
65 | self.x5 = rearrange(torch.from_numpy(data['x5']).type( 66 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1') 67 | self.x = torch.cat((self.x1, self.x2, self.x3, self.x4, self.x5), dim=1) 68 | 69 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type( 70 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 71 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type( 72 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 73 | self.x3_conjugate = rearrange(torch.from_numpy(data['x3_conjugate']).type( 74 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 75 | self.x4_conjugate = rearrange(torch.from_numpy(data['x4_conjugate']).type( 76 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 77 | self.x5_conjugate = rearrange(torch.from_numpy(data['x5_conjugate']).type( 78 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1') 79 | self.x_conjugate = torch.cat( 80 | (self.x1_conjugate, self.x2_conjugate,self.x3_conjugate,self.x4_conjugate,self.x5_conjugate), dim=2) 81 | self.y = torch.from_numpy(data['y']).type( 82 | 'torch.FloatTensor').to(device).reshape(num_points, 1) 83 | 84 | self.num_data = self.x1.shape[0] 85 | 86 | def __len__(self): 87 | return self.num_data 88 | 89 | def __getitem__(self, idx): 90 | if torch.is_tensor(idx): 91 | idx = idx.tolist() 92 | 93 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :],'x3': self.x3[idx, :, :, :],\ 94 | 'x4': self.x4[idx, :, :, :],'x5': self.x5[idx, :, :, :], 'x': self.x[idx, :, :, :], 95 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :],\ 96 | 'x3_conjugate': self.x3_conjugate[:,idx, :, :, :],'x4_conjugate': self.x4_conjugate[:,idx, :, :, :],\ 97 | 'x5_conjugate': self.x2_conjugate[:,idx, :, :, :], 'x_conjugate': self.x_conjugate[:,idx, :, :, :],\ 98 | 'y': self.y[idx, :]} 99 | return sample 100 | 101 | if __name__ == "__main__": 102 | 103 | DataLoader = sl3InvDataSet5Input("data/sl3_inv_5_input_data/sl3_inv_1000_s_05_train_data.npz") 104 | 105 | print(DataLoader.x1.shape) 106 | print(DataLoader.x2.shape) 107 | print(DataLoader.x1_conjugate.shape) 108 | print(DataLoader.x2_conjugate.shape) 109 | print(DataLoader.x.shape) 110 | print(DataLoader.x_conjugate.shape) 111 | print(DataLoader.y.shape) 112 | for i, samples in tqdm(enumerate(DataLoader, start=0)): 113 | input_data = samples['x'] 114 | y = samples['y'] 115 | print(y) 116 | # print(input_data.shape) 117 | -------------------------------------------------------------------------------- /experiment/sp4_inv_test.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_neurons_layers import * 20 | from experiment.sp4_inv_layers import * 21 | from data_loader.sp4_inv_data_loader import * 22 | 23 | 24 | def test(model, test_loader, criterion, config, device): 25 | model.eval() 26 | with torch.no_grad(): 27 | loss_sum = 0.0 28 | for i, sample in tqdm(enumerate(test_loader, start=0)): 29 | x = sample['x'].to(device) 30 | y = sample['y'].to(device) 31 | 32 | output = model(x) 33 | 34 | loss = 
criterion(output, y) 35 | loss_sum += loss.item() 36 | 37 | loss_avg = loss_sum/len(test_loader) 38 | 39 | return loss_avg 40 | 41 | 42 | def test_invariance(model, test_loader, criterion, config, device): 43 | model.eval() 44 | with torch.no_grad(): 45 | loss_sum = 0.0 46 | loss_non_conj_sum = 0.0 47 | diff_output_sum = 0.0 48 | loss_all = [] 49 | loss_non_conj_all = [] 50 | 51 | for i, sample in tqdm(enumerate(test_loader, start=0)): 52 | x = sample['x'].to(device) 53 | x_conj = sample['x_conjugate'].to(device) 54 | y = sample['y'].to(device) 55 | 56 | output_x = model(x) 57 | loss_non_conj = criterion(output_x, y) 58 | loss_non_conj_sum += loss_non_conj.item() 59 | loss_non_conj_all.append(loss_non_conj.item()) 60 | # print(output_x) 61 | # print(x_conj.shape) 62 | for j in range(x_conj.shape[1]): 63 | x_conj_j = x_conj[:, j, :, :, :] 64 | output_conj = model(x_conj_j) 65 | diff_output = output_x - output_conj 66 | loss = criterion(output_conj, y) 67 | loss_all.append(loss.item()) 68 | loss_sum += loss.item() 69 | diff_output_sum += torch.sum(torch.abs(diff_output)) 70 | # print(output_conj) 71 | 72 | # print("diff", diff_output[0,:]) 73 | # print("conj_out", output_conj[0,:]) 74 | # print("out",output_x[0,:]) 75 | # print("y", y[0,:]) 76 | # print(loss.item()) 77 | # print("----------------------") 78 | # print(loss_all) 79 | # print(torch.Tensor(loss_all).shape) 80 | # loss_avg2 = torch.mean(torch.Tensor(loss_all)) 81 | # print("loss_avg 2: ", loss_avg2) 82 | loss_std = torch.std(torch.Tensor(loss_all)) 83 | loss_non_conj_std = torch.std(torch.Tensor(loss_non_conj_all)) 84 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1] 85 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1] 86 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader) 87 | 88 | return loss_avg, loss_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg 89 | 90 | def main(): 91 | # torch.autograd.set_detect_anomaly(True) 92 | 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | print('Using ', device) 95 | 96 | parser = argparse.ArgumentParser(description='Test the network') 97 | parser.add_argument('--test_config', type=str, 98 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sp4_inv/testing_param.yaml') 99 | args = parser.parse_args() 100 | 101 | # load yaml file 102 | config = yaml.safe_load(open(args.test_config)) 103 | 104 | test_set = sp4InvDataSet(config['test_data_path'], device=device) 105 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 106 | shuffle=config['shuffle']) 107 | 108 | 109 | if config['model_type'] == "LN_relu_bracket": 110 | model = SP4InvariantReluBracketLayers(2).to(device) 111 | elif config['model_type'] == "LN_relu": 112 | model = SP4InvariantReluLayers(2).to(device) 113 | elif config['model_type'] == "LN_bracket": 114 | model = SP4InvariantBracketLayers(2).to(device) 115 | elif config['model_type'] == "MLP": 116 | model = MLP(20).to(device) 117 | elif config['model_type'] == "MLP512": 118 | model = MLP512(20).to(device) 119 | elif config['model_type'] == "LN_bracket_no_residual": 120 | model = SP4InvariantBracketNoResidualConnectLayers(2).to(device) 121 | else: 122 | raise ValueError("model type not supported") 123 | 124 | print("Using model: ", config['model_type']) 125 | print("total number of parameters: ", sum(p.numel() for p in model.parameters())) 126 | checkpoint = torch.load(config['model_path']) 127 | model.load_state_dict(checkpoint['model_state_dict'],strict=False) 128 | 129
| criterion = nn.MSELoss().to(device) 130 | # test_loss = test(model, test_loader, criterion, config, device) 131 | test_loss_inv, test_loss_inv_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg = test_invariance(model, test_loader, criterion, config, device) 132 | 133 | print("test_loss type:",type(test_loss_inv)) 134 | print("test loss: ", test_loss_inv) 135 | print("test loss std", test_loss_inv_std) 136 | print("avg diff output: ", diff_output_avg) 137 | print("loss non conj: ", loss_non_conj_avg) 138 | print("loss non conj std: ", loss_non_conj_std) 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /data_loader/so3_bch_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | import sys # nopep8 3 | sys.path.append('.') # nopep8 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | from scipy.linalg import expm 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | from einops import rearrange, repeat 14 | from einops.layers.torch import Rearrange 15 | 16 | from core.lie_neurons_layers import * 17 | from core.lie_group_util import * 18 | 19 | class so3BchDataSet(Dataset): 20 | def __init__(self, data_path, device='cuda'): 21 | data = np.load(data_path) 22 | num_points, _ = data['x1'].shape 23 | 24 | self.x1 = rearrange(torch.from_numpy(data['x1']).type( 25 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1') 26 | self.x2 = rearrange(torch.from_numpy(data['x2']).type( 27 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1') 28 | self.x = torch.cat((self.x1, self.x2), dim=1) # n 2 k 1 29 | 30 | self.y = torch.from_numpy(data['y']).type( 31 | 'torch.FloatTensor').to(device) # [N,3] 32 | # print(self.y.shape) 33 | self.num_data = self.x1.shape[0] 34 | 35 | def __len__(self): 36 | return self.num_data 37 | 38 | def __getitem__(self, idx): 39 | if torch.is_tensor(idx): 40 | idx = idx.tolist() 41 | 42 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :], 43 | 'x': self.x[idx, :, :, :], 'y': self.y[idx, :]} 44 | return sample 45 | 46 | class so3BchTestDataSet(Dataset): 47 | ''' 48 | Test data set that additionally contains adjoint-augmented information: conjugated inputs (x1_conj, x2_conj), conjugated outputs (y_conj), and the augmenting rotations (R) 49 | ''' 50 | def __init__(self, data_path, device='cuda'): 51 | data = np.load(data_path) 52 | num_points, _ = data['x1'].shape 53 | 54 | self.x1 = rearrange(torch.from_numpy(data['x1']).type( 55 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1') 56 | self.x2 = rearrange(torch.from_numpy(data['x2']).type( 57 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1') 58 | self.x = torch.cat((self.x1, self.x2), dim=1) # n 2 k 1 59 | self.x1_conj = rearrange(torch.from_numpy(data['x1_conjugate']).type( 60 | 'torch.FloatTensor').to(device), 'n c k -> c n 1 k 1') 61 | self.x2_conj = rearrange(torch.from_numpy(data['x2_conjugate']).type( 62 | 'torch.FloatTensor').to(device), 'n c k -> c n 1 k 1') 63 | self.x_conj = torch.cat((self.x1_conj, self.x2_conj), dim=2) # c n 2 k 1 64 | self.y = torch.from_numpy(data['y']).type( 65 | 'torch.FloatTensor').to(device) # [N,3] 66 | self.y_conj = rearrange(torch.from_numpy(data['y_conj']).type( 67 | 'torch.FloatTensor').to(device), 'n c k -> c n k') # [c, N,3] 68 | self.R = rearrange(torch.from_numpy(data['R_aug']).type( 69 | 'torch.FloatTensor').to(device), 'n c k1 k2 -> c n k1 k2') # [c, N,3,3] 70 | # print(self.y.shape) 71 | self.num_data = self.x1.shape[0] 72 | 73 | def __len__(self): 74 | return self.num_data 75 | 76 | def 
__getitem__(self, idx): 77 | if torch.is_tensor(idx): 78 | idx = idx.tolist() 79 | 80 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :],\ 81 | 'x': self.x[idx, :, :, :], 'y': self.y[idx, :],\ 82 | 'x1_conj': self.x1_conj[:,idx, :, :, :], 'x2_conj': self.x2_conj[:,idx, :, :, :],\ 83 | 'x_conj': self.x_conj[:,idx,:,:,:],'y_conj': self.y_conj[:,idx, :], \ 84 | 'R': self.R[:,idx, :, :] 85 | } 86 | return sample 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | DataLoader = so3BchDataSet( 92 | "data/so3_bch_data/so3_bch_10000_4_train_data.npz") 93 | 94 | hat = HatLayer(algebra_type='so3').to('cuda') 95 | # print(DataLoader.x1.shape) 96 | # print(DataLoader.x2.shape) 97 | # print(DataLoader.x.shape) 98 | # print(DataLoader.y.shape) 99 | 100 | sum = 0 101 | sum_inf = 0 102 | sum_nan = 0 103 | sum_y_nan = 0 104 | sum_y_inf = 0 105 | v3_list = torch.zeros((DataLoader.num_data, 3)) 106 | for i, samples in tqdm(enumerate(DataLoader, start=0)): 107 | input_data = samples['x'].to('cuda') 108 | R1 = exp_so3(hat(input_data[0, :, :].squeeze(-1))) 109 | R2 = exp_so3(hat(input_data[1, :, :].squeeze(-1))) 110 | v3 = vee(log_SO3(torch.matmul(R1,R2)),algebra_type='so3') 111 | v3_list[i,:] = v3 112 | # print("v1", input_data[0, :, :].squeeze(-1)) 113 | # print("v2", input_data[1, :, :].squeeze(-1)) 114 | # print("v3", v3) 115 | # print("norm v3", torch.norm(v3)) 116 | y = samples['y'].to('cuda') 117 | 118 | # print(torch.trace(hat(y))) 119 | # print(float(torch.norm(v3))) 120 | if(torch.isinf(input_data).any()): 121 | sum_inf += 1 122 | 123 | if(torch.isnan(input_data).any()): 124 | sum_nan += 1 125 | 126 | if(torch.isinf(y).any()): 127 | sum_y_inf += 1 128 | 129 | if(torch.isnan(y).any()): 130 | sum_y_nan += 1 131 | 132 | # print("norm", torch.norm(v3-y)) 133 | 134 | if(torch.norm(v3-y) > 1e-3): 135 | # print("\nv3", v3) 136 | # print("y", y) 137 | # print("diff", v3-y) 138 | # print("norm", torch.norm(v3-y)) 139 | # print("-------------") 140 | sum += 1 141 | 142 | # diff_norm = torch.norm(v3-y) 143 | 144 | # print(y) 145 | # print(input_data.shape) 146 | 147 | # plot a histogram of the first component of v3 across the data set 148 | plt.hist(v3_list[:,0].cpu().numpy(), bins=100) 149 | plt.show() 150 | 151 | print("error > 1e-3", sum) 152 | print("input inf num", sum_inf) 153 | print("input nan num", sum_nan) 154 | print("y inf num", sum_y_inf) 155 | print("y nan num", sum_y_nan) -------------------------------------------------------------------------------- /data_gen/gen_so3_bch.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import scipy.linalg 6 | import torch 7 | import math 8 | 9 | from core.lie_alg_util import * 10 | from core.lie_group_util import * 11 | 12 | if __name__ == "__main__": 13 | 14 | data_saved_path = "data/so3_bch_data/" 15 | data_name = "so3_bch_10000_augmented" 16 | gen_augmented_training_data = False 17 | 18 | num_training = 5000 19 | num_testing = 5000 20 | num_conjugate = 1 21 | 22 | train_data = {} 23 | test_data = {} 24 | so3_hatlayer = HatLayer(algebra_type='so3') 25 | 26 | def gen_random_rotation_vector(): 27 | v = torch.rand(1,3) 28 | v = v/torch.norm(v) 29 | phi = (math.pi-1e-6)*torch.rand(1) 30 | v = phi*v 31 | return v 32 | 33 | x1 = torch.zeros((num_training,3)) 34 | x2 = torch.zeros((num_training,3)) 35 | y = torch.zeros((num_training,3)) 36 | 37 | # if we want to generate adjoint augmented training data 38 | if(gen_augmented_training_data): 39 | 
x1_conj = torch.zeros((num_training, num_conjugate,3)) 40 | x2_conj = torch.zeros((num_training, num_conjugate,3)) 41 | y_conj = torch.zeros((num_training, num_conjugate,3)) 42 | R_aug = torch.zeros((num_training, num_conjugate,3,3)) 43 | 44 | # generate training data 45 | for i in range(num_training): 46 | # generate random v1, v2 47 | v1 = gen_random_rotation_vector() 48 | v2 = gen_random_rotation_vector() 49 | 50 | K1 = so3_hatlayer(v1) 51 | K2 = so3_hatlayer(v2) 52 | 53 | R1 = exp_so3(K1[0,:,:]) 54 | R2 = exp_so3(K2[0,:,:]) 55 | 56 | R3 = torch.matmul(R1,R2) 57 | 58 | v3 = vee(log_SO3(R3), algebra_type='so3') 59 | 60 | # v3 = vee(BCH_approx(K1[0,:,:], K2[0,:,:]), algebra_type='so3') 61 | if(torch.norm(v3) > math.pi): 62 | print("----------output bigger than pi---------") 63 | print("norm v1", torch.norm(v1)) 64 | print("norm v2", torch.norm(v2)) 65 | print("norm v3", torch.norm(v3)) 66 | print("v1",v1) 67 | print("v2",v2) 68 | print("v3",v3) 69 | print("R1",R1) 70 | print("R2",R2) 71 | print("R3",R3) 72 | 73 | x1[i,:] = v1 74 | x2[i,:] = v2 75 | y[i,:] = v3 76 | 77 | if(gen_augmented_training_data): 78 | for j in range(num_conjugate): 79 | v4 = gen_random_rotation_vector() 80 | K4 = so3_hatlayer(v4) 81 | R4 = exp_so3(K4[0,:,:]) 82 | 83 | R1_conj = R4@R1@R4.T 84 | R2_conj = R4@R2@R4.T 85 | R3_conj = R1_conj@R2_conj 86 | 87 | v1_conj = vee(log_SO3(R1_conj), algebra_type='so3') 88 | v2_conj = vee(log_SO3(R2_conj), algebra_type='so3') 89 | v3_conj = vee(log_SO3(R3_conj), algebra_type='so3') 90 | 91 | x1_conj[i,j,:] = v1_conj 92 | x2_conj[i,j,:] = v2_conj 93 | y_conj[i,j,:] = v3_conj 94 | R_aug[i,j,:,:] = R4 95 | 96 | if(gen_augmented_training_data): 97 | # concatenate the conjugate data 98 | train_data['x1'] = torch.cat((x1,x1_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy() 99 | train_data['x2'] = torch.cat((x2,x2_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy() 100 | train_data['y'] = torch.cat((y,y_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy() 101 | train_data['R_aug'] = R_aug.numpy() 102 | else: 103 | train_data['x1'] = x1.numpy() 104 | train_data['x2'] = x2.numpy() 105 | train_data['y'] = y.numpy() 106 | 107 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data) 108 | 109 | 110 | 111 | # generate testing data 112 | x1 = torch.zeros((num_testing,3)) 113 | x2 = torch.zeros((num_testing,3)) 114 | y = torch.zeros((num_testing,3)) 115 | 116 | x1_conj = torch.zeros((num_testing, num_conjugate,3)) 117 | x2_conj = torch.zeros((num_testing, num_conjugate,3)) 118 | y_conj = torch.zeros((num_testing, num_conjugate,3)) 119 | R_aug = torch.zeros((num_testing, num_conjugate,3,3)) 120 | 121 | for i in range(num_testing): 122 | # generate random v1, v2 123 | v1 = gen_random_rotation_vector() 124 | v2 = gen_random_rotation_vector() 125 | 126 | K1 = so3_hatlayer(v1) 127 | K2 = so3_hatlayer(v2) 128 | 129 | R1 = exp_so3(K1[0,:,:]) 130 | R2 = exp_so3(K2[0,:,:]) 131 | 132 | R3 = torch.matmul(R1,R2) 133 | v3 = vee(log_SO3(R3), algebra_type='so3') 134 | 135 | # v3 = vee(BCH_approx(K1[0,:,:], K2[0,:,:]), algebra_type='so3') 136 | 137 | x1[i,:] = v1 138 | x2[i,:] = v2 139 | y[i,:] = v3 140 | 141 | for j in range(num_conjugate): 142 | v4 = gen_random_rotation_vector() 143 | K4 = so3_hatlayer(v4) 144 | R4 = exp_so3(K4[0,:,:]) 145 | 146 | R1_conj = R4@R1@R4.T 147 | R2_conj = R4@R2@R4.T 148 | R3_conj = R1_conj@R2_conj 149 | 150 | v1_conj = vee(log_SO3(R1_conj), algebra_type='so3') 151 | v2_conj = vee(log_SO3(R2_conj), algebra_type='so3') 152 | v3_conj = 
vee(log_SO3(R3_conj), algebra_type='so3') 153 | 154 | x1_conj[i,j,:] = v1_conj 155 | x2_conj[i,j,:] = v2_conj 156 | y_conj[i,j,:] = v3_conj 157 | R_aug[i,j,:,:] = R4 158 | 159 | test_data['x1'] = x1.numpy() 160 | test_data['x2'] = x2.numpy() 161 | test_data['x1_conjugate'] = x1_conj.numpy() 162 | test_data['x2_conjugate'] = x2_conj.numpy() 163 | test_data['y'] = y.numpy() 164 | test_data['y_conj'] = y_conj.numpy() 165 | test_data['R_aug'] = R_aug.numpy() 166 | 167 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data) 168 | 169 | print("Done! Data saved to: \n", data_saved_path + 170 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz") -------------------------------------------------------------------------------- /experiment/sl3_equiv_test.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_alg_util import * 20 | from core.lie_neurons_layers import * 21 | from experiment.sl3_equiv_layers import * 22 | from data_loader.sl3_equiv_data_loader import * 23 | 24 | 25 | def test(model, test_loader, criterion, config, device): 26 | model.eval() 27 | with torch.no_grad(): 28 | loss_sum = 0.0 29 | for i, sample in tqdm(enumerate(test_loader, start=0)): 30 | x = sample['x'].to(device) 31 | y = sample['y'].to(device) 32 | 33 | output = model(x) 34 | 35 | loss = criterion(output, y) 36 | loss_sum += loss.item() 37 | 38 | loss_avg = loss_sum/len(test_loader) 39 | 40 | return loss_avg 41 | 42 | 43 | def test_equivariance(model, test_loader, criterion, config, device): 44 | model.eval() 45 | hat_layer = HatLayer(algebra_type='sl3').to(device) 46 | with torch.no_grad(): 47 | loss_sum = 0.0 48 | loss_non_conj_sum = 0.0 49 | diff_output_sum = 0.0 50 | for i, sample in tqdm(enumerate(test_loader, start=0)): 51 | x = sample['x'].to(device) # [B, 5, 8, 1] 52 | x_conj = sample['x_conjugate'].to(device) # [B, C, 5, 8, 1] 53 | y = sample['y'].to(device) # [B, 8] 54 | H = sample['H'].to(device) # [B, C, 3, 3] 55 | y_conj = sample['y_conj'] 56 | 57 | output_x = model(x) # [B, 8] 58 | output_x_hat = hat_layer(output_x) # [B, 3, 3] 59 | loss_non_conj_sum += criterion(output_x, y).item() 60 | 61 | for j in range(x_conj.shape[1]): 62 | x_conj_j = x_conj[:, j, :, :, :] # [B, 5, 8, 1] 63 | H_j = H[:, j, :, :] # [B, 3, 3] 64 | conj_output = model(x_conj_j) # [B, 8] 65 | # print('output hat', output_x_hat.shape) 66 | # print('H_j', H_j.shape) 67 | output_then_conj_hat = torch.matmul(H_j, torch.matmul(output_x_hat, torch.inverse(H_j))) 68 | # print('output then conj hat', output_then_conj_hat.shape) 69 | output_then_conj = vee_sl3(output_then_conj_hat) 70 | 71 | diff_output = output_then_conj - conj_output 72 | # print(output_then_conj) 73 | 74 | conj_output_hat = hat_layer(conj_output) 75 | conj_output_hat_conj_back = torch.matmul(torch.inverse(H_j), torch.matmul(conj_output_hat, H_j)) 76 | conj_output_conj_back = vee_sl3(conj_output_hat_conj_back) 77 | 78 | 79 | # y_hat = hat_layer(y) 80 | # conj_y_hat = torch.matmul(H_j, torch.matmul(y_hat, torch.inverse(H_j))) 81 | # conj_y = vee_sl3(conj_y_hat) 82 | 
loss = criterion(conj_output_conj_back, y) 83 | # print("conj_out", conj_output[0,:]) 84 | # print("y_conj", y_conj[0, j, :]) 85 | # print("diff_out_out_conj", diff_output[0,:]) 86 | # print("out",output_x[0,:]) 87 | # print("y", y[0,:]) 88 | # print("diff_out",output_x[0,:] - y[0,:]) 89 | # print(loss.item()) 90 | # print("----------------------") 91 | loss_sum += loss.item() 92 | diff_output_sum += torch.sum(torch.abs(diff_output)) 93 | 94 | # print('diff_output: ', diff_output) 95 | 96 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1] 97 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1]/x_conj.shape[3] 98 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader) 99 | 100 | return loss_avg, loss_non_conj_avg, diff_output_avg 101 | 102 | def main(): 103 | # torch.autograd.set_detect_anomaly(True) 104 | 105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 106 | print('Using ', device) 107 | 108 | parser = argparse.ArgumentParser(description='Test the network') 109 | parser.add_argument('--test_config', type=str, 110 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_equiv/testing_param.yaml') 111 | args = parser.parse_args() 112 | 113 | # load yaml file 114 | config = yaml.safe_load(open(args.test_config)) 115 | 116 | test_set = sl3EquivDataSetLieBracket2Input(config['test_data_path'], device=device) 117 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 118 | shuffle=config['shuffle']) 119 | 120 | if config['model_type'] == "LN_relu_bracket": 121 | model = SL3EquivariantReluBracketLayers(2).to(device) 122 | elif config['model_type'] == "LN_relu": 123 | model = SL3EquivariantReluLayers(2).to(device) 124 | elif config['model_type'] == "LN_bracket": 125 | model = SL3EquivariantBracketLayers(2).to(device) 126 | elif config['model_type'] == "MLP": 127 | model = MLP(16).to(device) 128 | elif config['model_type'] == "LN_bracket_no_residual": 129 | model = SL3EquivariantBracketNoResidualConnectLayers(2).to(device) 130 | else: raise ValueError("model type not supported") 131 | print("Using model: ", config['model_type']) 132 | print("total number of parameters: ", sum(p.numel() for p in model.parameters())) 133 | 134 | # model = SL3InvariantLayersTest(2).to(device) 135 | checkpoint = torch.load(config['model_path']) 136 | model.load_state_dict(checkpoint['model_state_dict'],strict=False) 137 | 138 | criterion = nn.MSELoss().to(device) 139 | test_loss_equiv, loss_non_conj_avg, diff_output_avg = test_equivariance(model, test_loader, criterion, config, device) 140 | print("test_loss type:",type(test_loss_equiv)) 141 | # print("avg diff output type: ", diff_output_avg.dtype) 142 | print("test loss: ", test_loss_equiv) 143 | print("avg diff output: ", diff_output_avg) 144 | print("loss non conj avg: ", loss_non_conj_avg) 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /playground/neural_ode_demo.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import os 5 | import argparse 6 | import time 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from core.lie_neurons_layers import * 13 | 14 | parser = argparse.ArgumentParser('ODE demo') 15 | parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5') 16 | parser.add_argument('--data_size', type=int, default=1000) 17 | 
parser.add_argument('--batch_time', type=int, default=10) 18 | parser.add_argument('--batch_size', type=int, default=20) 19 | parser.add_argument('--niters', type=int, default=2000) 20 | parser.add_argument('--test_freq', type=int, default=20) 21 | parser.add_argument('--viz', action='store_true') 22 | parser.add_argument('--gpu', type=int, default=0) 23 | parser.add_argument('--adjoint', action='store_true') 24 | args = parser.parse_args() 25 | 26 | if args.adjoint: 27 | from torchdiffeq import odeint_adjoint as odeint 28 | else: 29 | from torchdiffeq import odeint 30 | 31 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') 32 | 33 | true_y0 = torch.tensor([[2., 0.]]).to(device) 34 | t = torch.linspace(0., 25., args.data_size).to(device) 35 | true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device) 36 | 37 | 38 | class Lambda(nn.Module): 39 | 40 | def forward(self, t, y): 41 | return torch.mm(y**3, true_A) 42 | 43 | 44 | with torch.no_grad(): 45 | true_y = odeint(Lambda(), true_y0, t, method='dopri5') 46 | 47 | 48 | def get_batch(): 49 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False)) 50 | batch_y0 = true_y[s] # (M, D) 51 | batch_t = t[:args.batch_time] # (T) 52 | batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D) 53 | return batch_y0.to(device), batch_t.to(device), batch_y.to(device) 54 | 55 | 56 | def makedirs(dirname): 57 | if not os.path.exists(dirname): 58 | os.makedirs(dirname) 59 | 60 | 61 | if args.viz: 62 | makedirs('png_neural_ode_demo') 63 | import matplotlib.pyplot as plt 64 | fig = plt.figure(figsize=(12, 4), facecolor='white') 65 | ax_traj = fig.add_subplot(131, frameon=False) 66 | ax_phase = fig.add_subplot(132, frameon=False) 67 | ax_vecfield = fig.add_subplot(133, frameon=False) 68 | plt.show(block=False) 69 | 70 | 71 | def visualize(true_y, pred_y, odefunc, itr): 72 | 73 | if args.viz: 74 | 75 | ax_traj.cla() 76 | ax_traj.set_title('Trajectories') 77 | ax_traj.set_xlabel('t') 78 | ax_traj.set_ylabel('x,y') 79 | ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-') 80 | ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--') 81 | ax_traj.set_xlim(t.cpu().min(), t.cpu().max()) 82 | ax_traj.set_ylim(-2, 2) 83 | ax_traj.legend() 84 | 85 | ax_phase.cla() 86 | ax_phase.set_title('Phase Portrait') 87 | ax_phase.set_xlabel('x') 88 | ax_phase.set_ylabel('y') 89 | ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-') 90 | ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--') 91 | ax_phase.set_xlim(-2, 2) 92 | ax_phase.set_ylim(-2, 2) 93 | 94 | ax_vecfield.cla() 95 | ax_vecfield.set_title('Learned Vector Field') 96 | ax_vecfield.set_xlabel('x') 97 | ax_vecfield.set_ylabel('y') 98 | 99 | y, x = np.mgrid[-2:2:21j, -2:2:21j] 100 | dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy() 101 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1) 102 | dydt = (dydt / mag) 103 | dydt = dydt.reshape(21, 21, 2) 104 | 105 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black") 106 | ax_vecfield.set_xlim(-2, 2) 107 | ax_vecfield.set_ylim(-2, 2) 108 | 109 | fig.tight_layout() 110 | plt.savefig('png_neural_ode_demo/{:03d}'.format(itr)) 111 | plt.draw() 112 | plt.pause(0.001) 
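# ODEFunc below is the learned counterpart of the ground-truth dynamics Lambda defined
# above: instead of multiplying by the known matrix true_A, a small MLP maps the cubed
# state y**3 to dy/dt. odeint() integrates this network through time, and gradients flow
# back through the solver (or through the adjoint ODE when --adjoint swaps in
# odeint_adjoint). A minimal usage sketch, mirroring the calls made in the training loop
# at the bottom of this file (tensor values are illustrative only):
#
#   func = ODEFunc().to(device)                # learned dynamics f(t, y)
#   ts = torch.linspace(0., 25., 100).to(device)
#   traj = odeint(func, true_y0, ts)           # trajectory of shape (100, 1, 2)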
113 | 114 | 115 | class ODEFunc(nn.Module): 116 | 117 | def __init__(self): 118 | super(ODEFunc, self).__init__() 119 | 120 | self.net = nn.Sequential( 121 | nn.Linear(2, 50), 122 | nn.Tanh(), 123 | nn.Linear(50, 2), 124 | ) 125 | 126 | for m in self.net.modules(): 127 | if isinstance(m, nn.Linear): 128 | nn.init.normal_(m.weight, mean=0, std=0.1) 129 | nn.init.constant_(m.bias, val=0) 130 | 131 | def forward(self, t, y): 132 | # print(y.shape) 133 | return self.net(y**3) 134 | 135 | 136 | class RunningAverageMeter(object): 137 | """Computes and stores the average and current value""" 138 | 139 | def __init__(self, momentum=0.99): 140 | self.momentum = momentum 141 | self.reset() 142 | 143 | def reset(self): 144 | self.val = None 145 | self.avg = 0 146 | 147 | def update(self, val): 148 | if self.val is None: 149 | self.avg = val 150 | else: 151 | self.avg = self.avg * self.momentum + val * (1 - self.momentum) 152 | self.val = val 153 | 154 | 155 | if __name__ == '__main__': 156 | 157 | ii = 0 158 | 159 | func = ODEFunc().to(device) 160 | 161 | optimizer = optim.RMSprop(func.parameters(), lr=1e-3) 162 | end = time.time() 163 | 164 | time_meter = RunningAverageMeter(0.97) 165 | 166 | loss_meter = RunningAverageMeter(0.97) 167 | 168 | for itr in range(1, args.niters + 1): 169 | optimizer.zero_grad() 170 | batch_y0, batch_t, batch_y = get_batch() 171 | pred_y = odeint(func, batch_y0, batch_t).to(device) 172 | loss = torch.mean(torch.abs(pred_y - batch_y)) 173 | loss.backward() 174 | optimizer.step() 175 | 176 | time_meter.update(time.time() - end) 177 | loss_meter.update(loss.item()) 178 | 179 | if itr % args.test_freq == 0: 180 | with torch.no_grad(): 181 | pred_y = odeint(func, true_y0, t) 182 | loss = torch.mean(torch.abs(pred_y - true_y)) 183 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item())) 184 | visualize(true_y, pred_y, func, ii) 185 | ii += 1 186 | 187 | end = time.time() 188 | -------------------------------------------------------------------------------- /experiment/sl3_inv_train.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_neurons_layers import * 20 | from experiment.sl3_inv_layers import * 21 | from data_loader.sl3_inv_data_loader import * 22 | 23 | 24 | def init_writer(config): 25 | writer = SummaryWriter( 26 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description']) 27 | writer.add_text("train_data_path: ", config['train_data_path']) 28 | writer.add_text("model_save_path: ", config['model_save_path']) 29 | writer.add_text("log_writer_path: ", config['log_writer_path']) 30 | writer.add_text("shuffle: ", str(config['shuffle'])) 31 | writer.add_text("batch_size: ", str(config['batch_size'])) 32 | writer.add_text("init_lr: ", str(config['initial_learning_rate'])) 33 | writer.add_text("num_epochs: ", str(config['num_epochs'])) 34 | 35 | return writer 36 | 37 | 38 | def test(model, test_loader, criterion, config, device): 39 | model.eval() 40 | with torch.no_grad(): 41 | loss_sum = 0.0 42 | for i, sample in tqdm(enumerate(test_loader, start=0)): 43 | x = 
sample['x'].to(device) 44 | y = sample['y'].to(device) 45 | 46 | output = model(x) 47 | 48 | loss = criterion(output, y) 49 | loss_sum += loss.item() 50 | 51 | loss_avg = loss_sum/len(test_loader) 52 | 53 | return loss_avg 54 | 55 | 56 | def train(model, train_loader, test_loader, config, device='cpu'): 57 | 58 | writer = init_writer(config) 59 | 60 | # create criterion 61 | criterion = nn.MSELoss().to(device) 62 | # criterion = nn.L1Loss().to(device) 63 | optimizer = optim.Adam(model.parameters( 64 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate']) 65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate']) 66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs']) 67 | 68 | # if config['resume_training']: 69 | # checkpoint = torch.load(config['resume_model_path']) 70 | # model.load_state_dict(checkpoint['model_state_dict']) 71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 72 | # start_epoch = checkpoint['epoch'] 73 | # else: 74 | start_epoch = 0 75 | 76 | best_loss = float("inf") 77 | for epoch in range(start_epoch, config['num_epochs']): 78 | running_loss = 0.0 79 | loss_sum = 0.0 80 | model.train() 81 | optimizer.zero_grad() 82 | for i, sample in tqdm(enumerate(train_loader, start=0)): 83 | x = sample['x'].to(device) 84 | y = sample['y'].to(device) 85 | 86 | output = model(x) 87 | 88 | loss = criterion(output, y) 89 | loss.backward() 90 | 91 | # we only update the weights every config['update_every_batch'] iterations 92 | # This is to simulate a larger batch size 93 | # if (i+1) % config['update_every_batch'] == 0: 94 | optimizer.step() 95 | optimizer.zero_grad() 96 | 97 | # cur_training_loss_history.append(loss.item()) 98 | running_loss += loss.item() 99 | loss_sum += loss.item() 100 | 101 | if i % config['print_freq'] == 0: 102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" % 103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq'])) 104 | running_loss = 0.0 105 | 106 | 107 | # scheduler.step() 108 | 109 | 110 | train_loss = loss_sum/len(train_loader) 111 | 112 | test_loss = test( 113 | model, test_loader, criterion, config, device) 114 | 115 | # log down info in tensorboard 116 | writer.add_scalar('training loss', train_loss, epoch) 117 | writer.add_scalar('test loss', test_loss, epoch) 118 | 119 | # if we achieve best val loss, save the model 120 | if test_loss < best_loss: 121 | best_loss = test_loss 122 | 123 | state = {'epoch': epoch, 124 | 'model_state_dict': model.state_dict(), 125 | 'optimizer_state_dict': optimizer.state_dict(), 126 | 'loss': train_loss, 127 | 'test loss': test_loss} 128 | 129 | torch.save(state, config['model_save_path'] + 130 | '_best_test_loss_acc.pt') 131 | print("------------------------------") 132 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" % 133 | (epoch, config['num_epochs'], train_loss, test_loss)) 134 | 135 | # save model 136 | state = {'epoch': epoch, 137 | 'model_state_dict': model.state_dict(), 138 | 'optimizer_state_dict': optimizer.state_dict(), 139 | 'loss': train_loss, 140 | 'test loss': test_loss} 141 | 142 | torch.save(state, config['model_save_path']+'_last_epo.pt') 143 | 144 | writer.close() 145 | 146 | 147 | def main(): 148 | # torch.autograd.set_detect_anomaly(True) 149 | 150 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 151 | print('Using ', device) 152 | 153 | parser = argparse.ArgumentParser(description='Train the 
network') 154 | parser.add_argument('--training_config', type=str, 155 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_inv/training_param.yaml') 156 | args = parser.parse_args() 157 | 158 | # load yaml file 159 | config = yaml.safe_load(open(args.training_config)) 160 | 161 | 162 | training_set = sl3InvDataSet2Input(config['train_data_path'], device=device) 163 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'], 164 | shuffle=config['shuffle']) 165 | 166 | test_set = sl3InvDataSet2Input(config['test_data_path'], device=device) 167 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 168 | shuffle=config['shuffle']) 169 | 170 | 171 | if config['model_type'] == "LN_relu_bracket": 172 | model = SL3InvariantReluBracketLayers(2).to(device) 173 | elif config['model_type'] == "LN_relu": 174 | model = SL3InvariantReluLayers(2).to(device) 175 | elif config['model_type'] == "LN_bracket": 176 | model = SL3InvariantBracketLayers(2).to(device) 177 | elif config['model_type'] == "MLP": 178 | model = MLP(16).to(device) 179 | elif config['model_type'] == "LN_bracket_no_residual": 180 | model = SL3InvariantBracketNoResidualConnectLayers(2).to(device) 181 | 182 | train(model, train_loader, test_loader, config, device) 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | -------------------------------------------------------------------------------- /experiment/sp4_inv_train.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_neurons_layers import * 20 | from experiment.sp4_inv_layers import * 21 | from data_loader.sp4_inv_data_loader import * 22 | 23 | 24 | def init_writer(config): 25 | writer = SummaryWriter( 26 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description']) 27 | writer.add_text("train_data_path: ", config['train_data_path']) 28 | writer.add_text("model_save_path: ", config['model_save_path']) 29 | writer.add_text("log_writer_path: ", config['log_writer_path']) 30 | writer.add_text("shuffle: ", str(config['shuffle'])) 31 | writer.add_text("batch_size: ", str(config['batch_size'])) 32 | writer.add_text("init_lr: ", str(config['initial_learning_rate'])) 33 | writer.add_text("num_epochs: ", str(config['num_epochs'])) 34 | 35 | return writer 36 | 37 | 38 | def test(model, test_loader, criterion, config, device): 39 | model.eval() 40 | with torch.no_grad(): 41 | loss_sum = 0.0 42 | for i, sample in tqdm(enumerate(test_loader, start=0)): 43 | x = sample['x'].to(device) 44 | y = sample['y'].to(device) 45 | 46 | output = model(x) 47 | 48 | loss = criterion(output, y) 49 | loss_sum += loss.item() 50 | 51 | loss_avg = loss_sum/len(test_loader) 52 | 53 | return loss_avg 54 | 55 | 56 | def train(model, train_loader, test_loader, config, device='cpu'): 57 | 58 | writer = init_writer(config) 59 | 60 | # create criterion 61 | criterion = nn.MSELoss().to(device) 62 | # criterion = nn.L1Loss().to(device) 63 | optimizer = optim.Adam(model.parameters( 64 | ), lr=config['initial_learning_rate'], 
weight_decay=config['weight_decay_rate']) 65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate']) 66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs']) 67 | 68 | # if config['resume_training']: 69 | # checkpoint = torch.load(config['resume_model_path']) 70 | # model.load_state_dict(checkpoint['model_state_dict']) 71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 72 | # start_epoch = checkpoint['epoch'] 73 | # else: 74 | start_epoch = 0 75 | 76 | best_loss = float("inf") 77 | for epoch in range(start_epoch, config['num_epochs']): 78 | running_loss = 0.0 79 | loss_sum = 0.0 80 | model.train() 81 | optimizer.zero_grad() 82 | for i, sample in tqdm(enumerate(train_loader, start=0)): 83 | x = sample['x'].to(device) 84 | y = sample['y'].to(device) 85 | 86 | output = model(x) 87 | 88 | loss = criterion(output, y) 89 | loss.backward() 90 | 91 | # we only update the weights every config['update_every_batch'] iterations 92 | # This is to simulate a larger batch size 93 | # if (i+1) % config['update_every_batch'] == 0: 94 | optimizer.step() 95 | optimizer.zero_grad() 96 | 97 | # cur_training_loss_history.append(loss.item()) 98 | running_loss += loss.item() 99 | loss_sum += loss.item() 100 | 101 | if i % config['print_freq'] == 0: 102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" % 103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq'])) 104 | running_loss = 0.0 105 | 106 | 107 | # scheduler.step() 108 | 109 | 110 | train_loss = loss_sum/len(train_loader) 111 | 112 | test_loss = test( 113 | model, test_loader, criterion, config, device) 114 | 115 | # log down info in tensorboard 116 | writer.add_scalar('training loss', train_loss, epoch) 117 | writer.add_scalar('test loss', test_loss, epoch) 118 | 119 | # if we achieve best val loss, save the model 120 | if test_loss < best_loss: 121 | best_loss = test_loss 122 | 123 | state = {'epoch': epoch, 124 | 'model_state_dict': model.state_dict(), 125 | 'optimizer_state_dict': optimizer.state_dict(), 126 | 'loss': train_loss, 127 | 'test loss': test_loss} 128 | 129 | torch.save(state, config['model_save_path'] + 130 | '_best_test_loss_acc.pt') 131 | print("------------------------------") 132 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" % 133 | (epoch, config['num_epochs'], train_loss, test_loss)) 134 | 135 | # save model 136 | state = {'epoch': epoch, 137 | 'model_state_dict': model.state_dict(), 138 | 'optimizer_state_dict': optimizer.state_dict(), 139 | 'loss': train_loss, 140 | 'test loss': test_loss} 141 | 142 | torch.save(state, config['model_save_path']+'_last_epo.pt') 143 | 144 | writer.close() 145 | 146 | 147 | def main(): 148 | # torch.autograd.set_detect_anomaly(True) 149 | 150 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 151 | print('Using ', device) 152 | 153 | parser = argparse.ArgumentParser(description='Train the network') 154 | parser.add_argument('--training_config', type=str, 155 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sp4_inv/training_param.yaml') 156 | args = parser.parse_args() 157 | 158 | # load yaml file 159 | config = yaml.safe_load(open(args.training_config)) 160 | 161 | 162 | training_set = sp4InvDataSet(config['train_data_path'], device=device) 163 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'], 164 | shuffle=config['shuffle']) 165 | 166 | test_set = 
sp4InvDataSet(config['test_data_path'], device=device) 167 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 168 | shuffle=config['shuffle']) 169 | 170 | 171 | if config['model_type'] == "LN_relu_bracket": 172 | model = SP4InvariantReluBracketLayers(2).to(device) 173 | elif config['model_type'] == "LN_relu": 174 | model = SP4InvariantReluLayers(2).to(device) 175 | elif config['model_type'] == "LN_bracket": 176 | model = SP4InvariantBracketLayers(2).to(device) 177 | elif config['model_type'] == "MLP": 178 | model = MLP(20).to(device) 179 | elif config['model_type'] == "MLP512": 180 | model = MLP512(20).to(device) 181 | elif config['model_type'] == "LN_bracket_no_residual": 182 | model = SP4InvariantBracketNoResidualConnectLayers(2).to(device) 183 | 184 | train(model, train_loader, test_loader, config, device) 185 | 186 | 187 | if __name__ == "__main__": 188 | main() 189 | -------------------------------------------------------------------------------- /experiment/so3_bch_layers.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import os 5 | import copy 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | from einops.layers.torch import Rearrange 13 | 14 | 15 | from core.lie_alg_util import * 16 | from core.lie_neurons_layers import * 17 | from core.vn_layers import * 18 | 19 | 20 | class SO3EquivariantVNReluLayers(nn.Module): 21 | def __init__(self, in_channels): 22 | super(SO3EquivariantVNReluLayers, self).__init__() 23 | feat_dim = 1024 24 | share_nonlinearity = False 25 | leaky_relu = True 26 | self.ln_fc = VNLinearAndLeakyReLU( 27 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, dim=4) 28 | self.ln_fc2 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4) 29 | self.ln_fc3 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4) 30 | self.ln_fc4 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4) 31 | 32 | self.fc_final = nn.Linear(feat_dim, 1, bias=False) 33 | 34 | def forward(self, x): 35 | ''' 36 | x input of shape [B, F, 3, 1] 37 | ''' 38 | x = self.ln_fc(x) 39 | x = self.ln_fc2(x) 40 | x = self.ln_fc3(x) 41 | x = self.ln_fc4(x) 42 | 43 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F] 44 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3] 45 | return x_out 46 | 47 | class SO3EquivariantReluLayers(nn.Module): 48 | def __init__(self, in_channels): 49 | super(SO3EquivariantReluLayers, self).__init__() 50 | feat_dim = 1024 51 | share_nonlinearity = False 52 | leaky_relu = True 53 | self.ln_fc = LNLinearAndKillingRelu( 54 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3') 55 | self.ln_fc2 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3') 56 | self.ln_fc3 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3') 57 | self.ln_fc4 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3') 58 | 59 | self.fc_final = nn.Linear(feat_dim, 1, bias=False) 60 | 61 | def forward(self, 
x): 62 | ''' 63 | x input of shape [B, F, 3, 1] 64 | ''' 65 | x = self.ln_fc(x) # [B, F, 3, 1] 66 | x = self.ln_fc2(x) # [B, F, 3, 1] 67 | x = self.ln_fc3(x) 68 | x = self.ln_fc4(x) 69 | 70 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F] 71 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3] 72 | return x_out 73 | 74 | 75 | class SO3EquivariantBracketLayers(nn.Module): 76 | def __init__(self, in_channels): 77 | super(SO3EquivariantBracketLayers, self).__init__() 78 | feat_dim = 1024 79 | share_nonlinearity = False 80 | self.ln_fc = LNLinearAndLieBracket(in_channels, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 81 | self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 82 | self.ln_fc3 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 83 | self.ln_fc4 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 84 | # self.ln_fc5 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 85 | # self.ln_fc6 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 86 | # self.ln_fc7 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 87 | 88 | self.fc_final = nn.Linear(feat_dim, 1, bias=False) 89 | 90 | def forward(self, x): 91 | ''' 92 | x input of shape [B, F, 3, 1] 93 | ''' 94 | x = self.ln_fc(x) # [B, F, 3, 1] 95 | x = self.ln_fc2(x) # [B, F, 3, 1] 96 | x = self.ln_fc3(x) 97 | x = self.ln_fc4(x) 98 | # x = self.ln_fc5(x) 99 | # x = self.ln_fc6(x) 100 | # x = self.ln_fc7(x) 101 | 102 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F] 103 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3] 104 | return x_out 105 | 106 | class SO3EquivariantReluBracketLayers(nn.Module): 107 | def __init__(self, in_channels): 108 | super(SO3EquivariantReluBracketLayers, self).__init__() 109 | feat_dim = 1024 110 | share_nonlinearity = False 111 | leaky_relu = True 112 | self.ln_fc = LNLinearAndKillingRelu( 113 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3') 114 | self.ln_fc2 = LNLinearAndKillingRelu( 115 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3') 116 | self.ln_fc_bracket = LNLinearAndLieBracket(in_channels, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 117 | self.ln_fc_bracket2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3') 118 | 119 | self.fc_final = nn.Linear(feat_dim, 1, bias=False) 120 | 121 | def forward(self, x): 122 | ''' 123 | x input of shape [B, F, 3, 1] 124 | ''' 125 | 126 | x = self.ln_fc_bracket(x) # [B, F, 3, 1] 127 | x = self.ln_fc(x) # [B, F, 3, 1] 128 | x = self.ln_fc_bracket2(x) 129 | x = self.ln_fc2(x) 130 | 131 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F] 132 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3] 133 | # x_out = rearrange(self.ln_pooling(x), 'b 1 k 1 -> b k') # [B, F, 1, 1] 134 | return x_out 135 | 136 | 137 | class MLP(nn.Module): 138 | def __init__(self, in_channels): 139 | super(MLP, self).__init__() 140 | feat_dim = 1024 141 | self.fc = nn.Linear(in_channels, feat_dim) 142 | self.relu = nn.ReLU() 143 | self.fc2 = nn.Linear(feat_dim, feat_dim) 144 | self.fc3 = nn.Linear(feat_dim, feat_dim) 145 | 146 | self.fc4 = 
nn.Linear(feat_dim, feat_dim) 147 | self.fc_final = nn.Linear(feat_dim, 3) 148 | 149 | def forward(self, x): 150 | ''' 151 | x input of shape [B, F, 3, 1] 152 | ''' 153 | B, F, _, _ = x.shape 154 | x = torch.reshape(x, (B, -1)) 155 | x = self.fc(x) 156 | x = self.relu(x) 157 | x = self.fc2(x) 158 | x = self.relu(x) 159 | x = self.fc3(x) 160 | x = self.relu(x) 161 | x = self.fc4(x) 162 | x = self.relu(x) 163 | x_out = torch.reshape(self.fc_final(x), (B, 3)) # [B, 3] 164 | 165 | return x_out 166 | -------------------------------------------------------------------------------- /experiment/platonic_solid_cls_layers.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import os 5 | import copy 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | from einops.layers.torch import Rearrange 13 | 14 | from core.lie_alg_util import * 15 | from core.lie_neurons_layers import * 16 | 17 | 18 | class LNPlatonicSolidClassifier(nn.Module): 19 | def __init__(self, in_channels): 20 | super(LNPlatonicSolidClassifier, self).__init__() 21 | feat_dim = 256 22 | share_nonlinearity = False 23 | self.ln_fc = LNLinearAndKillingRelu( 24 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity) 25 | self.ln_fc2 = LNLinearAndKillingRelu( 26 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity) 27 | self.ln_fc3 = LNLinearAndKillingRelu( 28 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity) 29 | self.ln_pooling = LNMaxPool( 30 | feat_dim, abs_killing_form=False) # [B, F, 8, 1] 31 | self.ln_inv = LNInvariant(feat_dim, method='self_killing') 32 | self.fc_final = nn.Linear(feat_dim, 3, bias=False) 33 | 34 | def forward(self, x): 35 | ''' 36 | x input of shape [B, F, 8, 1] 37 | ''' 38 | x = self.ln_fc(x) # [B, F, 8, N] 39 | # x = self.ln_fc2(x) # [B, F, 8, N] 40 | # x = self.ln_fc3(x) # [B, F, 8, N] 41 | x = self.ln_pooling(x) # [B, F, 8, 1] 42 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1] 43 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F] 44 | x_out = rearrange(self.fc_final(x_inv), 45 | 'b 1 1 cls -> b cls') # [B, cls] 46 | 47 | return x_out 48 | 49 | 50 | class LNReluPlatonicSolidClassifier(nn.Module): 51 | def __init__(self, in_channels): 52 | super(LNReluPlatonicSolidClassifier, self).__init__() 53 | feat_dim = 256 54 | share_nonlinearity = False 55 | self.ln_fc = LNLinearAndKillingRelu( 56 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity) 57 | self.ln_pooling = LNMaxPool( 58 | feat_dim, abs_killing_form=False) # [B, F, 8, 1] 59 | self.ln_inv = LNInvariant(feat_dim, method='self_killing') 60 | self.fc_final = nn.Linear(feat_dim, 3, bias=False) 61 | 62 | def forward(self, x): 63 | ''' 64 | x input of shape [B, F, 8, 1] 65 | ''' 66 | x = self.ln_fc(x) # [B, F, 8, N] 67 | x = self.ln_pooling(x) # [B, F, 8, 1] 68 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1] 69 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F] 70 | x_out = rearrange(self.fc_final(x_inv), 71 | 'b 1 1 cls -> b cls') # [B, cls] 72 | 73 | return x_out 74 | 75 | 76 | class LNBracketPlatonicSolidClassifier(nn.Module): 77 | def __init__(self, in_channels): 78 | super(LNBracketPlatonicSolidClassifier, self).__init__() 79 | feat_dim = 256 80 | share_nonlinearity = False 81 | self.ln_fc = LNLinearAndLieBracket( 82 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity) 83 | # 
self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity) 84 | self.ln_pooling = LNMaxPool( 85 | feat_dim, abs_killing_form=False) # [B, F, 8, 1] 86 | self.ln_inv = LNInvariant(feat_dim, method='self_killing') 87 | self.fc_final = nn.Linear(feat_dim, 3, bias=False) 88 | 89 | def forward(self, x): 90 | ''' 91 | x input of shape [B, F, 8, 1] 92 | ''' 93 | x = self.ln_fc(x) # [B, F, 8, N] 94 | x = self.ln_pooling(x) # [B, F, 8, 1] 95 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1] 96 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F] 97 | x_out = rearrange(self.fc_final(x_inv), 98 | 'b 1 1 cls -> b cls') # [B, cls] 99 | 100 | return x_out 101 | 102 | 103 | class LNReluBracketPlatonicSolidClassifier(nn.Module): 104 | def __init__(self, in_channels): 105 | super(LNReluBracketPlatonicSolidClassifier, self).__init__() 106 | feat_dim = 256 107 | share_nonlinearity = False 108 | self.ln_fc = LNLinearAndKillingRelu( 109 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity) 110 | self.ln_fc2 = LNLinearAndLieBracket( 111 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity) 112 | self.ln_pooling = LNMaxPool( 113 | feat_dim, abs_killing_form=False) # [B, F, 8, 1] 114 | self.ln_inv = LNInvariant(feat_dim, method='self_killing') 115 | self.fc_final = nn.Linear(feat_dim, 3, bias=False) 116 | 117 | def forward(self, x): 118 | ''' 119 | x input of shape [B, F, 8, 1] 120 | ''' 121 | x = self.ln_fc(x) # [B, F, 8, N] 122 | x = self.ln_fc2(x) # [B, F, 8, N] 123 | x = self.ln_pooling(x) # [B, F, 8, 1] 124 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1] 125 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F] 126 | x_out = rearrange(self.fc_final(x_inv), 127 | 'b 1 1 cls -> b cls') # [B, cls] 128 | 129 | return x_out 130 | 131 | 132 | class LNBracketNoResidualConnectPlatonicSolidClassifier(nn.Module): 133 | def __init__(self, in_channels): 134 | super(LNBracketNoResidualConnectPlatonicSolidClassifier, self).__init__() 135 | feat_dim = 256 136 | share_nonlinearity = False 137 | self.ln_fc = LNLinearAndLieBracketNoResidualConnect( 138 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity) 139 | # self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity) 140 | self.ln_pooling = LNMaxPool( 141 | feat_dim, abs_killing_form=False) # [B, F, 8, 1] 142 | self.ln_inv = LNInvariant(feat_dim, method='self_killing') 143 | self.fc_final = nn.Linear(feat_dim, 3, bias=False) 144 | 145 | def forward(self, x): 146 | ''' 147 | x input of shape [B, F, 8, 1] 148 | ''' 149 | x = self.ln_fc(x) # [B, F, 8, N] 150 | x = self.ln_pooling(x) # [B, F, 8, 1] 151 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1] 152 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F] 153 | x_out = rearrange(self.fc_final(x_inv), 154 | 'b 1 1 cls -> b cls') # [B, cls] 155 | 156 | return x_out 157 | 158 | 159 | class MLP(nn.Module): 160 | def __init__(self, in_channels): 161 | super(MLP, self).__init__() 162 | feat_dim = 256 163 | self.fc = nn.Linear(in_channels, feat_dim) 164 | self.fc2 = nn.Linear(feat_dim, feat_dim) 165 | self.fc3 = nn.Linear(feat_dim, feat_dim) 166 | self.relu = nn.ReLU() 167 | self.fc_final = nn.Linear(feat_dim, 3) 168 | 169 | def forward(self, x): 170 | ''' 171 | x input of shape [B, F, 8, N] 172 | ''' 173 | B, F, _, _ = x.shape 174 | x = torch.reshape(x, (B, -1)) 175 | x = self.fc(x) 176 | x = self.relu(x) 177 | x = self.fc2(x) 178 | x = self.relu(x) 179 | x = self.fc3(x) 180 | x = self.relu(x) 181 | 
x_out = torch.reshape(self.fc_final(x), (B, 3)) # [B, cls] 182 | 183 | return x_out 184 | -------------------------------------------------------------------------------- /experiment/sl3_equiv_train.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | 13 | import torch 14 | from torch import nn 15 | import torch.optim as optim 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from core.lie_neurons_layers import * 20 | from experiment.sl3_equiv_layers import * 21 | from data_loader.sl3_equiv_data_loader import * 22 | 23 | 24 | def init_writer(config): 25 | writer = SummaryWriter( 26 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description']) 27 | writer.add_text("train_data_path: ", config['train_data_path']) 28 | writer.add_text("model_save_path: ", config['model_save_path']) 29 | writer.add_text("log_writer_path: ", config['log_writer_path']) 30 | writer.add_text("shuffle: ", str(config['shuffle'])) 31 | writer.add_text("batch_size: ", str(config['batch_size'])) 32 | writer.add_text("init_lr: ", str(config['initial_learning_rate'])) 33 | writer.add_text("num_epochs: ", str(config['num_epochs'])) 34 | 35 | return writer 36 | 37 | 38 | def test(model, test_loader, criterion, config, device): 39 | model.eval() 40 | with torch.no_grad(): 41 | loss_sum = 0.0 42 | for i, sample in tqdm(enumerate(test_loader, start=0)): 43 | x = sample['x'].to(device) 44 | y = sample['y'].to(device) 45 | 46 | output = model(x) 47 | 48 | loss = criterion(output, y) 49 | loss_sum += loss.item() 50 | 51 | loss_avg = loss_sum/len(test_loader) 52 | 53 | return loss_avg 54 | 55 | 56 | def train(model, train_loader, test_loader, config, device='cpu'): 57 | 58 | writer = init_writer(config) 59 | 60 | # create criterion 61 | criterion = nn.MSELoss().to(device) 62 | # criterion = nn.L1Loss().to(device) 63 | optimizer = optim.Adam(model.parameters( 64 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate']) 65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate']) 66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs']) 67 | 68 | # if config['resume_training']: 69 | # checkpoint = torch.load(config['resume_model_path']) 70 | # model.load_state_dict(checkpoint['model_state_dict']) 71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 72 | # start_epoch = checkpoint['epoch'] 73 | # else: 74 | start_epoch = 0 75 | 76 | best_loss = float("inf") 77 | for epoch in range(start_epoch, config['num_epochs']): 78 | running_loss = 0.0 79 | loss_sum = 0.0 80 | model.train() 81 | optimizer.zero_grad() 82 | for i, sample in tqdm(enumerate(train_loader, start=0)): 83 | x = sample['x'].to(device) 84 | y = sample['y'].to(device) 85 | 86 | output = model(x) 87 | 88 | loss = criterion(output, y) 89 | loss.backward() 90 | 91 | # we only update the weights every config['update_every_batch'] iterations 92 | # This is to simulate a larger batch size 93 | # if (i+1) % config['update_every_batch'] == 0: 94 | optimizer.step() 95 | optimizer.zero_grad() 96 | 97 | # cur_training_loss_history.append(loss.item()) 98 | running_loss += loss.item() 99 | loss_sum += 
loss.item() 100 | 101 | if i % config['print_freq'] == 0: 102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" % 103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq'])) 104 | running_loss = 0.0 105 | 106 | 107 | # scheduler.step() 108 | 109 | # train_top1, train_top5, _ = validate(train_loader, model, criterion, config, device) 110 | 111 | train_loss = loss_sum/len(train_loader) 112 | 113 | test_loss = test( 114 | model, test_loader, criterion, config, device) 115 | 116 | # log down info in tensorboard 117 | writer.add_scalar('training loss', train_loss, epoch) 118 | writer.add_scalar('test loss', test_loss, epoch) 119 | 120 | # if we achieve best val loss, save the model 121 | if test_loss < best_loss: 122 | best_loss = test_loss 123 | 124 | state = {'epoch': epoch, 125 | 'model_state_dict': model.state_dict(), 126 | 'optimizer_state_dict': optimizer.state_dict(), 127 | 'loss': train_loss, 128 | 'test loss': test_loss} 129 | 130 | torch.save(state, config['model_save_path'] + 131 | '_best_test_loss_acc.pt') 132 | print("------------------------------") 133 | # print("Finished epoch %d / %d, training top 1 acc: %.4f, training top 5 acc: %.4f, \ 134 | # validation top1 acc: %.4f, validation top 5 acc: %.4f" %\ 135 | # (epoch, config['num_epochs'], train_top1, train_top5, val_top1, val_top5)) 136 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" % 137 | (epoch, config['num_epochs'], train_loss, test_loss)) 138 | 139 | # save model 140 | state = {'epoch': epoch, 141 | 'model_state_dict': model.state_dict(), 142 | 'optimizer_state_dict': optimizer.state_dict(), 143 | 'loss': train_loss, 144 | 'test loss': test_loss} 145 | 146 | torch.save(state, config['model_save_path']+'_last_epo.pt') 147 | 148 | writer.close() 149 | 150 | 151 | def main(): 152 | # torch.autograd.set_detect_anomaly(True) 153 | 154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 155 | print('Using ', device) 156 | 157 | parser = argparse.ArgumentParser(description='Train the network') 158 | parser.add_argument('--training_config', type=str, 159 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_equiv/training_param.yaml') 160 | args = parser.parse_args() 161 | 162 | # load yaml file 163 | config = yaml.safe_load(open(args.training_config)) 164 | 165 | # 5 input 166 | training_set = sl3EquivDataSetLieBracket2Input( 167 | config['train_data_path'], device=device) 168 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'], 169 | shuffle=config['shuffle']) 170 | 171 | test_set = sl3EquivDataSetLieBracket2Input( 172 | config['test_data_path'], device=device) 173 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 174 | shuffle=config['shuffle']) 175 | 176 | if config['model_type'] == "LN_relu_bracket": 177 | model = SL3EquivariantReluBracketLayers(2).to(device) 178 | elif config['model_type'] == "LN_relu": 179 | model = SL3EquivariantReluLayers(2).to(device) 180 | elif config['model_type'] == "LN_bracket": 181 | model = SL3EquivariantBracketLayers(2).to(device) 182 | elif config['model_type'] == "MLP": 183 | model = MLP(16).to(device) 184 | elif config['model_type'] == "LN_bracket_no_residual": 185 | model = SL3EquivariantBracketNoResidualConnectLayers(2).to(device) 186 | 187 | train(model, train_loader, test_loader, config, device) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- 
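# Usage sketch for the training/evaluation pair above (an assumption: commands are run
# from the repository root, since every script does sys.path.append('.'); the config
# paths shown are simply the argparse defaults baked into each script):
#
#   python experiment/sl3_equiv_train.py --training_config config/sl3_equiv/training_param.yaml
#   python experiment/sl3_equiv_test.py --test_config config/sl3_equiv/testing_param.yaml
#
# Training logs train/test MSE to TensorBoard and writes two checkpoints,
# <model_save_path>_best_test_loss_acc.pt and <model_save_path>_last_epo.pt; the test
# script then loads the one pointed to by config['model_path'].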
/matlab/sl3_equivariant_lift_find_bound.asv: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | 4 | compute_K = false; 5 | num_K = 1000; 6 | rnd_scale_K = 100; 7 | 8 | compute_A = false; 9 | num_A = 1000; 10 | rnd_scale_A = 100; 11 | 12 | 13 | compute_N = false; 14 | num_N = 1000; 15 | rnd_scale_N = 100; 16 | 17 | compute_KN = true; 18 | num_KN = 1000; 19 | rnd_scale_KN = 100; 20 | 21 | syms v1 v2 v3 v4 v5 v6 v7 v8 22 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real') 23 | 24 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9 25 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real') 26 | %% 27 | 28 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 29 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 30 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 31 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 32 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 33 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 34 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 35 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 36 | 37 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0]; 38 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0]; 39 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0]; 40 | 41 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 42 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 43 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 44 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1]; 45 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 46 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0]; 47 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 48 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 49 | 50 | E = {E1,E2,E3,E4,E5,E6,E7,E8}; 51 | 52 | E1_vec = reshape(E1,1,[])'; 53 | E2_vec = reshape(E2,1,[])'; 54 | E3_vec = reshape(E3,1,[])'; 55 | E4_vec = reshape(E4,1,[])'; 56 | E5_vec = reshape(E5,1,[])'; 57 | E6_vec = reshape(E6,1,[])'; 58 | E7_vec = reshape(E7,1,[])'; 59 | E8_vec = reshape(E8,1,[])'; 60 | 61 | 62 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec]; 63 | 64 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8; 65 | 66 | 67 | %% find h 68 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9]; 69 | Ad_H_hat = H*x_hat*inv(H); 70 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])'; 71 | 72 | % solve least square to obtain x 73 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec; 74 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 75 | [Ad_H_sym,b]=equationsToMatrix(x,var); 76 | 77 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9); 78 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym; 79 | 80 | %% K 81 | if compute_K 82 | rnd = rnd_scale_K*rand(num_K,3); 83 | tol = 1e-8; 84 | K_results = zeros(num_K,4); 85 | K_color = zeros(num_K,3); 86 | K_norm = zeros(num_K,1); 87 | for i=1:num_K 88 | H = expm(hat_so3(rnd(i,:))); 89 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 90 | D = kron(inv(H'),Ad_test)-eye(24); 91 | 92 | K_results(i,1) = rnd(i,1); 93 | K_results(i,2) = rnd(i,2); 94 | K_results(i,3) = rnd(i,3); 95 | K_results(i,4) = rank(D,tol); 96 | K_norm(i,1) = norm(K_results(i,1:3)); 97 | if K_results(i,4)<24 98 | K_results(i,5) = 0; 99 | K_color(i,:) = [0, 0.4470, 0.7410]; 100 | elseif K_results(i,4) == 24 101 | % H 102 | K_results(i,5) = 1; 103 | K_color(i,:) = [0.8500, 0.3250, 0.0980]; 104 | end 105 | end 106 | 107 | K = figure(1); 108 | S = repmat(25,num_K,1); 109 | scatter3(K_results(:,1),K_results(:,2),K_results(:,3),S,K_color(:,:),'filled'); 110 | xlabel("x") 111 | ylabel("y") 112 | zlabel("z") 113 | 114 | K2 = figure(2); 115 | scatter(K_norm,K_results(:,5),S,K_color,'filled'); 116 | end 117 | %% 118 | % idx = K_results(:,5) == 1; 119 | % no_sol_K = K_results(idx,:); 120 | % no_sol_norm = K_norm(idx,:); 121 | % no_sol_norm_mod_pi = mod(no_sol_norm,pi); 122 | % K3 = figure(3); 123 | % 
histogram(no_sol_norm_mod_pi,314); 124 | 125 | %% A 126 | if compute_A 127 | rnd = rnd_scale_A*rand(num_A,2); 128 | tol = 1e-8; 129 | A_results = zeros(num_A,4); 130 | A_color = zeros(num_A,3); 131 | A_norm = zeros(num_A,1); 132 | for i = 1:num_A 133 | a1 = rnd(i,1); 134 | % a2 = rnd(i,2); 135 | a2 = 1/a1; 136 | H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2]; 137 | % H = [a1,0,0; 0,1/a1,0;0,0,1]; 138 | 139 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 140 | D = kron(inv(H'),Ad_test)-eye(24); 141 | 142 | A_results(i,1) = a1; 143 | A_results(i,2) = a2; 144 | A_results(i,4) = rank(D,tol); 145 | A_norm(i,1) = norm(A_results(i,1:2)); 146 | if A_results(i,4)<24 147 | A_results(i,5) = 0; 148 | A_color(i,:) = [0, 0.4470, 0.7410]; 149 | elseif A_results(i,4) == 24 150 | % H 151 | A_results(i,5) = 1; 152 | A_color(i,:) = [0.8500, 0.3250, 0.0980]; 153 | end 154 | end 155 | 156 | A_figure1 = figure(1); 157 | S = repmat(25,num_A,1); 158 | scatter(A_results(:,1),A_results(:,2),S,A_color(:,:),'filled'); 159 | xlabel("x") 160 | ylabel("y") 161 | zlabel("z") 162 | title("Solution existence for A") 163 | 164 | A_figure2 = figure(2); 165 | scatter(A_norm,A_results(:,5),S,A_color,'filled'); 166 | 167 | 168 | idx_A = A_results(:,5) == 1; 169 | no_sol_A = A_results(idx_A,:); 170 | end 171 | 172 | %% N 173 | if compute_N 174 | rnd = rnd_scale_N*rand(num_N,3); 175 | tol = 1e-8; 176 | N_results = zeros(num_N,4); 177 | N_color = zeros(num_N,3); 178 | N_norm = zeros(num_N,1); 179 | for i = 1:num_N 180 | n1 = rnd(i,1); 181 | n2 = rnd(i,2); 182 | n3 = rnd(i,3); 183 | H = [1 n1 n2; 0 1 n3; 0 0 1]; 184 | 185 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 186 | D = kron(inv(H'),Ad_test)-eye(24); 187 | 188 | N_results(i,1) = n1; 189 | N_results(i,2) = n2; 190 | N_results(i,3) = n3; 191 | N_results(i,4) = rank(D,tol); 192 | N_norm(i,1) = norm(N_results(i,1:3)); 193 | if N_results(i,4)<24 194 | N_results(i,5) = 0; 195 | N_color(i,:) = [0, 0.4470, 0.7410]; 196 | elseif N_results(i,4) == 24 197 | % H 198 | N_results(i,5) = 1; 199 | N_color(i,:) = [0.8500, 0.3250, 0.0980]; 200 | end 201 | end 202 | 203 | N_figure1 = figure(1); 204 | S = repmat(25,num_N,1); 205 | scatter3(N_results(:,1),N_results(:,2),N_results(:,3),S,N_color(:,:),'filled'); 206 | xlabel("x") 207 | ylabel("y") 208 | zlabel("z") 209 | title("Solution existence for N") 210 | 211 | N_figure2 = figure(2); 212 | scatter(N_norm,N_results(:,5),S,N_color,'filled'); 213 | 214 | 215 | idx_N = N_results(:,5) == 1; 216 | no_sol_N = N_results(idx_N,:); 217 | end 218 | 219 | %% KN 220 | if compute_KN 221 | rnd = rnd_scale_KN*rand(num_KN,3); 222 | rnd2 = rnd_scale_KN*rand(num_KN,3); 223 | tol = 1e-8; 224 | KN_results = zeros(num_KN,4); 225 | KN_color = zeros(num_KN,3); 226 | KN_norm = zeros(num_KN,1); 227 | for i = 1:num_KN 228 | n1 = rnd(i,1); 229 | n2 = rnd(i,2); 230 | n3 = rnd(i,3); 231 | N = [1 n1 n2; 0 1 n3; 0 0 1]; 232 | K = expm(hat_so3(rnd2(i,:))); 233 | 234 | H = K*N; 235 | 236 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 237 | D = kron(inv(H'),Ad_test)-eye(24); 238 | 239 | KN_results(i,1) = n1; 240 | KN_results(i,2) = n2; 241 | KN_results(i,3) = n3; 242 | KN_results(i,4) = rank(D,tol); 243 | KN_results(i,6) = rnd2(i,1); 244 | KN_results(i,7) = rnd2(i,2); 245 | KN_results(i,8) = rnd2(i,3); 246 | KN_norm(i,1) = norm(KN_results(i,1:3)); 247 | if KN_results(i,4)<24 248 | KN_results(i,5) = 0; 249 | KN_color(i,:) = [0, 0.4470, 0.7410]; 250 | elseif KN_results(i,4)
== 24 251 | % H 252 | KN_results(i,5) = 1; 253 | KN_color(i,:) = [0.8500, 0.3250, 0.0980]; 254 | end 255 | end 256 | 257 | KN_figure1 = figure(1); 258 | S = repmat(25,num_KN,1); 259 | scatter3(KN_results(:,1),KN_results(:,2),KN_results(:,3),S,KN_color(:,:),'filled'); 260 | xlabel("x") 261 | ylabel("y") 262 | zlabel("z") 263 | title("Solution existence for KN") 264 | 265 | KN_figure2 = figure(2); 266 | scatter(KN_norm,KN_results(:,5),S,KN_color,'filled'); 267 | 268 | 269 | idx_KN = KN_results(:,5) == 1; 270 | no_sol_KN = KN_results(idx_KN,:); 271 | end -------------------------------------------------------------------------------- /matlab/sl3_equivariant_lift_find_bound.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | 4 | compute_K = false; 5 | num_K = 1000; 6 | rnd_scale_K = 100; 7 | 8 | compute_A = false; 9 | num_A = 1000; 10 | rnd_scale_A = 100; 11 | 12 | 13 | compute_N = false; 14 | num_N = 1000; 15 | rnd_scale_N = 100; 16 | 17 | compute_KN = true; 18 | num_KN = 1000; 19 | rnd_scale_KN = 1; 20 | 21 | syms v1 v2 v3 v4 v5 v6 v7 v8 22 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real') 23 | 24 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9 25 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real') 26 | %% 27 | 28 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 29 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0]; 30 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 31 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2]; 32 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 33 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 34 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 35 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 36 | 37 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0]; 38 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0]; 39 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0]; 40 | 41 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0]; 42 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0]; 43 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0]; 44 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1]; 45 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0]; 46 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0]; 47 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0]; 48 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0]; 49 | 50 | E = {E1,E2,E3,E4,E5,E6,E7,E8}; 51 | 52 | E1_vec = reshape(E1,1,[])'; 53 | E2_vec = reshape(E2,1,[])'; 54 | E3_vec = reshape(E3,1,[])'; 55 | E4_vec = reshape(E4,1,[])'; 56 | E5_vec = reshape(E5,1,[])'; 57 | E6_vec = reshape(E6,1,[])'; 58 | E7_vec = reshape(E7,1,[])'; 59 | E8_vec = reshape(E8,1,[])'; 60 | 61 | 62 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec]; 63 | 64 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8; 65 | 66 | 67 | %% find h 68 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9]; 69 | Ad_H_hat = H*x_hat*inv(H); 70 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])'; 71 | 72 | % solve least squares to obtain x 73 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec; 74 | var = [v1,v2,v3,v4,v5,v6,v7,v8]; 75 | [Ad_H_sym,b]=equationsToMatrix(x,var); 76 | 77 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9); 78 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym; 79 | 80 | %% K 81 | if compute_K 82 | rnd = rnd_scale_K*rand(num_K,3); 83 | tol = 1e-8; 84 | K_results = zeros(num_K,4); 85 | K_color = zeros(num_K,3); 86 | K_norm = zeros(num_K,1); 87 | for i=1:num_K 88 | H = expm(hat_so3(rnd(i,:))); 89 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 90 | D = kron(inv(H'),Ad_test)-eye(24); 91 | 92 | K_results(i,1) = rnd(i,1); 93 | K_results(i,2) = rnd(i,2); 94 | K_results(i,3) = rnd(i,3); 95 | K_results(i,4) = rank(D,tol); 96 | K_norm(i,1) = norm(K_results(i,1:3)); 97 | if K_results(i,4)<24 98 | K_results(i,5) = 0; 99 | K_color(i,:) = [0, 0.4470, 0.7410]; 100 | elseif
K_results(i,4) == 24 101 | % H 102 | K_results(i,5) = 1; 103 | K_color(i,:) = [0.8500, 0.3250, 0.0980]; 104 | end 105 | end 106 | 107 | K = figure(1); 108 | S = repmat(25,num_K,1); 109 | scatter3(K_results(:,1),K_results(:,2),K_results(:,3),S,K_color(:,:),'filled'); 110 | xlabel("x") 111 | ylabel("y") 112 | zlabel("z") 113 | 114 | K2 = figure(2); 115 | scatter(K_norm,K_results(:,5),S,K_color,'filled'); 116 | end 117 | %% 118 | % idx = K_results(:,5) == 1; 119 | % no_sol_K = K_results(idx,:); 120 | % no_sol_norm = K_norm(idx,:); 121 | % no_sol_norm_mod_pi = mod(no_sol_norm,pi); 122 | % K3 = figure(3); 123 | % histogram(no_sol_norm_mod_pi,314); 124 | 125 | %% A 126 | if compute_A 127 | rnd = rnd_scale_A*rand(num_A,2); 128 | tol = 1e-8; 129 | A_results = zeros(num_A,4); 130 | A_color = zeros(num_A,3); 131 | A_norm = zeros(num_A,1); 132 | for i = 1:num_A 133 | a1 = rnd(i,1); 134 | % a2 = rnd(i,2); 135 | a2 = 1/a1; 136 | H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2]; 137 | % H = [a1,0,0; 0,1/a1,0;0,0,1]; 138 | 139 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 140 | D = kron(inv(H'),Ad_test)-eye(24); 141 | 142 | A_results(i,1) = a1; 143 | A_results(i,2) = a2; 144 | A_results(i,4) = rank(D,tol); 145 | A_norm(i,1) = norm(A_results(i,1:2)); 146 | if A_results(i,4)<24 147 | A_results(i,5) = 0; 148 | A_color(i,:) = [0, 0.4470, 0.7410]; 149 | elseif A_results(i,4) == 24 150 | % H 151 | A_results(i,5) = 1; 152 | A_color(i,:) = [0.8500, 0.3250, 0.0980]; 153 | end 154 | end 155 | 156 | A_figure1 = figure(1); 157 | S = repmat(25,num_A,1); 158 | scatter(A_results(:,1),A_results(:,2),S,A_color(:,:),'filled'); 159 | xlabel("x") 160 | ylabel("y") 161 | zlabel("z") 162 | title("Solution existence for A") 163 | 164 | A_figure2 = figure(2); 165 | scatter(A_norm,A_results(:,5),S,A_color,'filled'); 166 | 167 | 168 | idx_A = A_results(:,5) == 1; 169 | no_sol_A = A_results(idx_A,:); 170 | end 171 | 172 | %% N 173 | if compute_N 174 | rnd = rnd_scale_N*rand(num_N,3); 175 | tol = 1e-8; 176 | N_results = zeros(num_N,4); 177 | N_color = zeros(num_N,3); 178 | N_norm = zeros(num_N,1); 179 | for i = 1:num_N 180 | n1 = rnd(i,1); 181 | n2 = rnd(i,2); 182 | n3 = rnd(i,3); 183 | H = [1 n1 n2; 0 1 n3; 0 0 1]; 184 | 185 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 186 | D = kron(inv(H'),Ad_test)-eye(24); 187 | 188 | N_results(i,1) = n1; 189 | N_results(i,2) = n2; 190 | N_results(i,3) = n3; 191 | N_results(i,4) = rank(D,tol); 192 | N_norm(i,1) = norm(N_results(i,1:3)); 193 | if N_results(i,4)<24 194 | N_results(i,5) = 0; 195 | N_color(i,:) = [0, 0.4470, 0.7410]; 196 | elseif N_results(i,4) == 24 197 | % H 198 | N_results(i,5) = 1; 199 | N_color(i,:) = [0.8500, 0.3250, 0.0980]; 200 | end 201 | end 202 | 203 | N_figure1 = figure(1); 204 | S = repmat(25,num_N,1); 205 | scatter3(N_results(:,1),N_results(:,2),N_results(:,3),S,N_color(:,:),'filled'); 206 | xlabel("x") 207 | ylabel("y") 208 | zlabel("z") 209 | title("Solution existence for N") 210 | 211 | N_figure2 = figure(2); 212 | scatter(N_norm,N_results(:,5),S,N_color,'filled'); 213 | 214 | 215 | idx_N = N_results(:,5) == 1; 216 | no_sol_N = N_results(idx_N,:); 217 | end 218 | 219 | %% KN 220 | if compute_KN 221 | rnd = rnd_scale_KN*rand(num_KN,3); 222 | rnd2 = rnd_scale_KN*rand(num_KN,3); 223 | tol = 1e-8; 224 | KN_results = zeros(num_KN,4); 225 | KN_color = zeros(num_KN,3); 226 | KN_norm = zeros(num_KN,1); 227 | for i = 1:num_KN 228 | n1 = rnd(i,1); 229 | n2 = rnd(i,2); 230 | n3 = rnd(i,3); 231 | N = [1 n1 n2; 
0 1 n3; 0 0 1]; 232 | K = expm(hat_so3(rnd2(i,:))); 233 | 234 | H = K*N; 235 | 236 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3))); 237 | D = kron(inv(H'),Ad_test)-eye(24); 238 | 239 | KN_results(i,1) = n1; 240 | KN_results(i,2) = n2; 241 | KN_results(i,3) = n3; 242 | KN_results(i,4) = rank(D,tol); 243 | KN_results(i,6) = rnd2(i,1); 244 | KN_results(i,7) = rnd2(i,2); 245 | KN_results(i,8) = rnd2(i,3); 246 | KN_norm(i,1) = norm(KN_results(i,1:3)); 247 | if KN_results(i,4)<24 248 | KN_results(i,5) = 0; 249 | KN_color(i,:) = [0, 0.4470, 0.7410]; 250 | elseif KN_results(i,4) == 24 251 | % H 252 | KN_results(i,5) = 1; 253 | KN_color(i,:) = [0.8500, 0.3250, 0.0980]; 254 | end 255 | end 256 | 257 | KN_figure1 = figure(1); 258 | S = repmat(25,num_KN,1); 259 | scatter3(KN_results(:,1),KN_results(:,2),KN_results(:,3),S,KN_color(:,:),'filled'); 260 | xlabel("x") 261 | ylabel("y") 262 | zlabel("z") 263 | title("Solution existence for KN") 264 | 265 | KN_figure2 = figure(2); 266 | scatter(KN_norm,KN_results(:,5),S,KN_color,'filled'); 267 | 268 | 269 | idx_KN = KN_results(:,5) == 1; 270 | no_sol_KN = KN_results(idx_KN,:); 271 | end -------------------------------------------------------------------------------- /data_gen/gen_sl3_inv_5_input_data.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.linalg import expm 7 | 8 | 9 | from einops import rearrange, repeat 10 | from einops.layers.torch import Rearrange 11 | 12 | from core.lie_neurons_layers import * 13 | 14 | 15 | def invariant_function(x1, x2, x3, x4, x5): 16 | return torch.sin(torch.trace(x1@x2@x3))+torch.cos(torch.trace(x3@x4@x5))\ 17 | -torch.pow(torch.trace(x5@x1), 6)/2.0+torch.det(x3@x2)+torch.exp(torch.trace(x4@x1))\ 18 | +torch.trace(x1@x2@x3@x4@x5) 19 | 20 | 21 | if __name__ == "__main__": 22 | data_saved_path = "data/sl3_inv_5_input_data/" 23 | data_name = "sl3_inv_1000_s_05_test" 24 | num_training = 1000 25 | num_testing = 1000 26 | num_conjugate = 500 27 | rnd_scale = 0.5 28 | 29 | train_data = {} 30 | test_data = {} 31 | 32 | hat_layer = HatLayer(algebra_type='sl3') 33 | 34 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 35 | .reshape(1, 1, 8, num_training) 36 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 37 | .reshape(1, 1, 8, num_training) 38 | x3 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 39 | .reshape(1, 1, 8, num_training) 40 | x4 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 41 | .reshape(1, 1, 8, num_training) 42 | x5 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\ 43 | .reshape(1, 1, 8, num_training) 44 | 45 | # conjugate transformation: x -> vee(H hat(x) H^-1) with a random H = expm(hat(h)) in SL(3) 46 | h = torch.Tensor(np.random.uniform(-rnd_scale, 47 | rnd_scale, (num_conjugate, num_training, 8))) 48 | H = torch.linalg.matrix_exp(hat_layer(h)) 49 | # conjugate x1 50 | x1_hat = hat_layer(x1.transpose(2, -1)) 51 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 52 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c') 53 | 54 | # conjugate x2 55 | x2_hat = hat_layer(x2.transpose(2, -1)) 56 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 57 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c') 58 | 59 | # conjugate x3 60 | x3_hat = hat_layer(x3.transpose(2, -1))
61 | conj_x3_hat = torch.matmul(H, torch.matmul(x3_hat, torch.inverse(H))) 62 | conj_x3 = rearrange(vee_sl3(conj_x3_hat), 'b c t l -> b l t c') 63 | 64 | # conjugate x4 65 | x4_hat = hat_layer(x4.transpose(2, -1)) 66 | conj_x4_hat = torch.matmul(H, torch.matmul(x4_hat, torch.inverse(H))) 67 | conj_x4 = rearrange(vee_sl3(conj_x4_hat), 'b c t l -> b l t c') 68 | 69 | # conjugate x5 70 | x5_hat = hat_layer(x5.transpose(2, -1)) 71 | conj_x5_hat = torch.matmul(H, torch.matmul(x5_hat, torch.inverse(H))) 72 | conj_x5 = rearrange(vee_sl3(conj_x5_hat), 'b c t l -> b l t c') 73 | 74 | 75 | 76 | inv_output = torch.zeros((1, num_training, 1)) 77 | # compute invariant function 78 | for n in range(num_training): 79 | inv_output[0, n, 0] = invariant_function( 80 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :],x3_hat[0, 0, n, :, :],\ 81 | x4_hat[0, 0, n, :, :],x5_hat[0, 0, n, :, :]) 82 | 83 | # print("-------------------------------") 84 | # print(inv_output[0,n,0]) 85 | # for i in range(num_conjugate): 86 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :],\ 87 | # conj_x3_hat[0, i, n, :, :],conj_x4_hat[0, i, n, :, :],\ 88 | # conj_x5_hat[0, i, n, :, :])) 89 | 90 | train_data['x1'] = x1.numpy().reshape(8, num_training) 91 | train_data['x2'] = x2.numpy().reshape(8, num_training) 92 | train_data['x3'] = x3.numpy().reshape(8, num_training) 93 | train_data['x4'] = x4.numpy().reshape(8, num_training) 94 | train_data['x5'] = x5.numpy().reshape(8, num_training) 95 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_training, num_conjugate) 96 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_training, num_conjugate) 97 | train_data['x3_conjugate'] = conj_x3.numpy().reshape(8, num_training, num_conjugate) 98 | train_data['x4_conjugate'] = conj_x4.numpy().reshape(8, num_training, num_conjugate) 99 | train_data['x5_conjugate'] = conj_x5.numpy().reshape(8, num_training, num_conjugate) 100 | train_data['y'] = inv_output.numpy().reshape(1, num_training) 101 | 102 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data) 103 | 104 | 105 | ''' 106 | Generate testing data 107 | ''' 108 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 109 | .reshape(1, 1, 8, num_testing) 110 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 111 | .reshape(1, 1, 8, num_testing) 112 | x3 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 113 | .reshape(1, 1, 8, num_testing) 114 | x4 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 115 | .reshape(1, 1, 8, num_testing) 116 | x5 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\ 117 | .reshape(1, 1, 8, num_testing) 118 | 119 | # conjugate transformation 120 | h = torch.Tensor(np.random.uniform(-rnd_scale, 121 | rnd_scale, (num_conjugate, num_testing, 8))) 122 | H = torch.linalg.matrix_exp(hat_layer(h)) 123 | # conjugate x1 124 | x1_hat = hat_layer(x1.transpose(2, -1)) 125 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H))) 126 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c') 127 | 128 | # conjugate x2 129 | x2_hat = hat_layer(x2.transpose(2, -1)) 130 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H))) 131 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c') 132 | 133 | # conjugate x3 134 | x3_hat = hat_layer(x3.transpose(2, -1)) 135 | conj_x3_hat = torch.matmul(H, torch.matmul(x3_hat, torch.inverse(H)))
136 | conj_x3 = rearrange(vee_sl3(conj_x3_hat), 'b c t l -> b l t c') 137 | 138 | # conjugate x4 139 | x4_hat = hat_layer(x4.transpose(2, -1)) 140 | conj_x4_hat = torch.matmul(H, torch.matmul(x4_hat, torch.inverse(H))) 141 | conj_x4 = rearrange(vee_sl3(conj_x4_hat), 'b c t l -> b l t c') 142 | 143 | # conjugate x5 144 | x5_hat = hat_layer(x5.transpose(2, -1)) 145 | conj_x5_hat = torch.matmul(H, torch.matmul(x5_hat, torch.inverse(H))) 146 | conj_x5 = rearrange(vee_sl3(conj_x5_hat), 'b c t l -> b l t c') 147 | 148 | inv_output = torch.zeros((1, num_testing, 1)) 149 | # compute invariant function 150 | for n in range(num_testing): 151 | inv_output[0, n, 0] = invariant_function( 152 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :],x3_hat[0, 0, n, :, :],\ 153 | x4_hat[0, 0, n, :, :],x5_hat[0, 0, n, :, :]) 154 | 155 | # print("-------------------------------") 156 | # print(inv_output[0,n,0]) 157 | # for i in range(num_conjugate): 158 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :],\ 159 | # conj_x3_hat[0, i, n, :, :],conj_x4_hat[0, i, n, :, :],\ 160 | # conj_x5_hat[0, i, n, :, :])) 161 | 162 | test_data['x1'] = x1.numpy().reshape(8, num_testing) 163 | test_data['x2'] = x2.numpy().reshape(8, num_testing) 164 | test_data['x3'] = x3.numpy().reshape(8, num_testing) 165 | test_data['x4'] = x4.numpy().reshape(8, num_testing) 166 | test_data['x5'] = x5.numpy().reshape(8, num_testing) 167 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_testing, num_conjugate) 168 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_testing, num_conjugate) 169 | test_data['x3_conjugate'] = conj_x3.numpy().reshape(8, num_testing, num_conjugate) 170 | test_data['x4_conjugate'] = conj_x4.numpy().reshape(8, num_testing, num_conjugate) 171 | test_data['x5_conjugate'] = conj_x5.numpy().reshape(8, num_testing, num_conjugate) 172 | test_data['y'] = inv_output.numpy().reshape(1, num_testing) 173 | 174 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data) 175 | 176 | print("Done! 
Data saved to: \n", data_saved_path + 177 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz") 178 | -------------------------------------------------------------------------------- /script/evaluate_multiple.bash: -------------------------------------------------------------------------------- 1 | proj_dir="/home/justin/code/LieNeurons/" 2 | config_path="config/" 3 | num_experiment=4 4 | GREEN='\033[0;32m' 5 | ORANGE='\033[0;33m' 6 | NC='\033[0m' # No Color 7 | export proj_dir 8 | 9 | # ****************************************************************************** 10 | # *******************************Evaluations *********************************** 11 | # ****************************************************************************** 12 | # invariant tasks 13 | # mlp 14 | # for iter in $(seq 1 $num_experiment); do 15 | # echo -e "--------------------%% Running invariant task: mlp $iter %%--------------------" 16 | # export iter 17 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_mlp_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml" 18 | # yq e -i '.model_type = "MLP"' $config_path"sl3_inv/testing_param.yaml" 19 | # python experiment/sl3_inv_test.py 20 | # done 21 | # # invariant tasks 22 | # # LN-LR 23 | # for iter in $(seq 1 $num_experiment); do 24 | # echo -e "--------------------%% Running invariant task: LN-LR $iter %%--------------------" 25 | # export iter 26 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_relu_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml" 27 | # yq e -i '.model_type = "LN_relu"' $config_path"sl3_inv/testing_param.yaml" 28 | # python experiment/sl3_inv_test.py 29 | # done 30 | # # invariant tasks 31 | # # LN-LB 32 | # for iter in $(seq 1 $num_experiment); do 33 | # echo -e "--------------------%% Running invariant task: LN-LB $iter %%--------------------" 34 | # export iter 35 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_bracket_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml" 36 | # yq e -i '.model_type = "LN_bracket"' $config_path"sl3_inv/testing_param.yaml" 37 | # python experiment/sl3_inv_test.py 38 | # done 39 | 40 | # # invariant tasks 41 | # # LN-LR+LB 42 | # for iter in $(seq 1 $num_experiment); do 43 | # echo -e "--------------------%% Running invariant task: LN-LR+LB $iter %%--------------------" 44 | # export iter 45 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_rb_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml" 46 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"sl3_inv/testing_param.yaml" 47 | # python experiment/sl3_inv_test.py 48 | # done 49 | 50 | # # mlp augmented 51 | # for iter in $(seq 1 $num_experiment); do 52 | # echo -e "--------------------%% Running invariant task: mlp augmented $iter %%--------------------" 53 | # export iter 54 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_mlp_augmented_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml" 55 | # yq e -i '.model_type = "MLP"' $config_path"sl3_inv/testing_param.yaml" 56 | # python experiment/sl3_inv_test.py 57 | # done 58 | 59 | # # ============================================================================================================================================================================================ 60 | # # equivariant task 61 | # # mlp 62 | # for iter in $(seq 1 $num_experiment); do 63 | # 
echo -e "--------------------%% Running equivariant task: mlp $iter %%--------------------" 64 | # export iter 65 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_mlp_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml" 66 | # yq e -i '.model_type = "MLP"' $config_path"sl3_equiv/testing_param.yaml" 67 | # python experiment/sl3_equiv_test.py 68 | # done 69 | 70 | # # equivariant task 71 | # # LN-LR 72 | # for iter in $(seq 1 $num_experiment); do 73 | # echo -e "--------------------%% Running equivariant task: LN-LR $iter %%--------------------" 74 | # export iter 75 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_relu_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml" 76 | # yq e -i '.model_type = "LN_relu"' $config_path"sl3_equiv/testing_param.yaml" 77 | # python experiment/sl3_equiv_test.py 78 | # done 79 | 80 | # # equivariant task 81 | # # LN-LB 82 | # for iter in $(seq 1 $num_experiment); do 83 | # echo -e "--------------------%% Running equivariant task: LN-LB $iter %%--------------------" 84 | # export iter 85 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_bracket_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml" 86 | # yq e -i '.model_type = "LN_bracket"' $config_path"sl3_equiv/testing_param.yaml" 87 | # python experiment/sl3_equiv_test.py 88 | # done 89 | 90 | # # equivariant task 91 | # # LN-LR+LB 92 | # for iter in $(seq 1 $num_experiment); do 93 | # echo -e "--------------------%% Running equivariant task: LN-LR+LB $iter %%--------------------" 94 | # export iter 95 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_rb_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml" 96 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"sl3_equiv/testing_param.yaml" 97 | # python experiment/sl3_equiv_test.py 98 | # done 99 | 100 | # equivariant task 101 | # mlp augmented 102 | # for iter in $(seq 1 $num_experiment); do 103 | # echo -e "--------------------%% Running equivariant task: mlp augmented $iter %%--------------------" 104 | # export iter 105 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_mlp_augmented_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml" 106 | # yq e -i '.model_type = "MLP"' $config_path"sl3_equiv/testing_param.yaml" 107 | # python experiment/sl3_equiv_test.py 108 | # done 109 | 110 | # ============================================================================================================================================================================================ 111 | # classification task 112 | 113 | # mlp augmented 114 | for iter in $(seq 1 $num_experiment); do 115 | echo -e "--------------------%% Running classification task: mlp augmentation $iter %%--------------------" 116 | export iter 117 | yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_mlp_aug_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml" 118 | yq e -i '.model_type = "MLP"' $config_path"platonic_solid_cls/testing_param.yaml" 119 | python experiment/platonic_solid_cls_test.py 120 | done 121 | 122 | 123 | # mlp 124 | # for iter in $(seq 1 $num_experiment); do 125 | # echo -e "--------------------%% Running classification task: mlp $iter %%--------------------" 126 | # export iter 127 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_mlp_"+env(iter)+"_best_loss_acc.pt"' 
$config_path"platonic_solid_cls/testing_param.yaml" 128 | # yq e -i '.model_type = "MLP"' $config_path"platonic_solid_cls/testing_param.yaml" 129 | # python experiment/platonic_solid_cls_test.py 130 | # done 131 | 132 | # # classification task 133 | # # LN-LR 134 | # for iter in $(seq 1 $num_experiment); do 135 | # echo -e "--------------------%% Running classification task: LN-LR $iter %%--------------------" 136 | # export iter 137 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_relu_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml" 138 | # yq e -i '.model_type = "LN_relu"' $config_path"platonic_solid_cls/testing_param.yaml" 139 | # python experiment/platonic_solid_cls_test.py 140 | # done 141 | 142 | # # classification task 143 | # # LN-LB 144 | # for iter in $(seq 1 $num_experiment); do 145 | # echo -e "--------------------%% Running classification task: LN-LB $iter %%--------------------" 146 | # export iter 147 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_bracket_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml" 148 | # yq e -i '.model_type = "LN_bracket"' $config_path"platonic_solid_cls/testing_param.yaml" 149 | # python experiment/platonic_solid_cls_test.py 150 | # done 151 | # # classification task 152 | # # LN-LR+LB 153 | # for iter in $(seq 1 $num_experiment); do 154 | # echo -e "--------------------%% Running classification task: LN-LR+LB $iter %%--------------------" 155 | # export iter 156 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_rb_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml" 157 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"platonic_solid_cls/testing_param.yaml" 158 | # python experiment/platonic_solid_cls_test.py 159 | # done -------------------------------------------------------------------------------- /experiment/platonic_solid_cls_train.py: -------------------------------------------------------------------------------- 1 | import sys # nopep8 2 | sys.path.append('.') # nopep8 3 | 4 | import argparse 5 | import os 6 | import time 7 | import yaml 8 | import numpy as np 9 | from tqdm import tqdm 10 | import matplotlib.pyplot as plt 11 | import time 12 | from scipy.spatial.transform import Rotation 13 | 14 | import torch 15 | from torch import nn 16 | import torch.optim as optim 17 | from torch.utils.tensorboard import SummaryWriter 18 | from torch.utils.data import Dataset, DataLoader, IterableDataset 19 | 20 | from core.lie_neurons_layers import * 21 | from experiment.platonic_solid_cls_layers import * 22 | from data_gen.gen_platonic_solids import * 23 | 24 | 25 | def init_writer(config): 26 | writer = SummaryWriter( 27 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description']) 28 | # writer.add_text("train_data_path: ", config['train_data_path']) 29 | writer.add_text("model_save_path: ", config['model_save_path']) 30 | writer.add_text("log_writer_path: ", config['log_writer_path']) 31 | writer.add_text("shuffle: ", str(config['shuffle'])) 32 | writer.add_text("batch_size: ", str(config['batch_size'])) 33 | writer.add_text("init_lr: ", str(config['initial_learning_rate'])) 34 | writer.add_text("num_train: ", str(config['num_train'])) 35 | 36 | return writer 37 | 38 | def random_sample_rotations(num_rotation, rotation_factor: float = 1.0, device='cpu') -> np.ndarray: 39 | r = np.zeros((num_rotation, 3, 3)) 40 | for n in range(num_rotation): 41 | # 
angle_z, angle_y, angle_x 42 | euler = np.random.rand(3) * np.pi * 2 / rotation_factor # (0, 2 * pi / rotation_factor) 43 | r[n,:,:] = Rotation.from_euler('zyx', euler).as_matrix() 44 | return torch.from_numpy(r).type('torch.FloatTensor').to(device) 45 | 46 | def test(model, test_loader, criterion, config, device): 47 | model.eval() 48 | with torch.no_grad(): 49 | loss_sum = 0.0 50 | num_correct = 0 51 | for iter, samples in tqdm(enumerate(test_loader, start=0)): 52 | 53 | x = samples[0].to(device) 54 | y = samples[1].to(device) 55 | 56 | x = rearrange(x,'b n f k -> b f k n') 57 | output = model(x) 58 | 59 | _, prediction = torch.max(output,1) 60 | num_correct += (prediction==y).sum().item() 61 | 62 | loss = criterion(output, y) 63 | loss_sum += loss.item() 64 | 65 | loss_avg = loss_sum/config['num_test']*config['batch_size'] 66 | acc_avg = num_correct/config['num_test'] 67 | 68 | return loss_avg, acc_avg 69 | 70 | 71 | def train(model, train_loader, test_loader, config, device='cpu'): 72 | 73 | writer = init_writer(config) 74 | 75 | # create criterion 76 | criterion = nn.CrossEntropyLoss().to(device) 77 | optimizer = optim.Adam(model.parameters( 78 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate']) 79 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate']) 80 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs']) 81 | 82 | if config['resume_training']: 83 | checkpoint = torch.load(config['resume_model_path']) 84 | model.load_state_dict(checkpoint['model_state_dict']) 85 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 86 | start_iter = checkpoint['epoch'] 87 | else: 88 | start_iter = 0 89 | 90 | best_loss = float("inf") 91 | best_acc = 0.0 92 | running_loss = 0.0 93 | loss_sum = 0.0 94 | hat_layer = HatLayer(algebra_type='sl3').to(device) 95 | 96 | for iter, samples in tqdm(enumerate(train_loader, start=0)): 97 | 98 | model.train() 99 | optimizer.zero_grad() 100 | 101 | x = samples[0].to(device) 102 | cls = samples[1].to(device) 103 | if config['train_augmentation']: 104 | rot = random_sample_rotations(1, config['rotation_factor'],device) 105 | x_hat = hat_layer(x) 106 | x_rot_hat = torch.matmul(rot, torch.matmul(x_hat, torch.inverse(rot))) 107 | x = vee_sl3(x_rot_hat) 108 | 109 | x = rearrange(x,'b n f k -> b f k n') 110 | 111 | output = model(x) 112 | # print(output) 113 | loss = criterion(output, cls) 114 | loss.backward() 115 | 116 | # Optionally, update the weights only every config['update_every_batch'] iterations 117 | # to simulate a larger batch size; the guard below is currently disabled. 118 | # if (iter+1) % config['update_every_batch'] == 0: 119 | optimizer.step() 120 | optimizer.zero_grad() 121 | 122 | # cur_training_loss_history.append(loss.item()) 123 | running_loss += loss.item() 124 | loss_sum += loss.item() 125 | 126 | # if iter % config['print_freq'] == 0: 127 | # print("iteration %d / %d, loss: %.8f" % 128 | # (iter, config['num_train'], running_loss/config['print_freq'])) 129 | # running_loss = 0.0 130 | 131 | 132 | # scheduler.step() 133 | 134 | # train_top1, train_top5, _ = validate(train_loader, model, criterion, config, device) 135 | 136 | train_loss = loss_sum/(iter+1) 137 | 138 | test_loss, test_acc = test( 139 | model, test_loader, criterion, config, device) 140 | 141 | # log down info in tensorboard 142 | writer.add_scalar('training loss', train_loss, iter) 143 | writer.add_scalar('test loss', test_loss, iter) 144 | writer.add_scalar('test acc', test_acc, iter) 145 | 146 | # if
we achieve best val loss, save the model 147 | if test_loss < best_loss: 148 | best_loss = test_loss 149 | 150 | state = {'iter': iter, 151 | 'model_state_dict': model.state_dict(), 152 | 'optimizer_state_dict': optimizer.state_dict(), 153 | 'loss': train_loss, 154 | 'test loss': test_loss, 155 | 'test acc': test_acc} 156 | 157 | torch.save(state, config['model_save_path'] + 158 | '_best_test_loss.pt') 159 | 160 | if test_acc > best_acc: 161 | best_acc = test_acc 162 | 163 | state = {'iter': iter, 164 | 'model_state_dict': model.state_dict(), 165 | 'optimizer_state_dict': optimizer.state_dict(), 166 | 'loss': train_loss, 167 | 'test loss': test_loss, 168 | 'test acc': test_acc} 169 | 170 | torch.save(state, config['model_save_path'] + 171 | '_best_test_acc.pt') 172 | print("------------------------------") 173 | # print("Finished epoch %d / %d, training top 1 acc: %.4f, training top 5 acc: %.4f, \ 174 | # validation top1 acc: %.4f, validation top 5 acc: %.4f" %\ 175 | # (epoch, config['num_epochs'], train_top1, train_top5, val_top1, val_top5)) 176 | print("Finished iteration %d / %d, train loss: %.4f test loss: %.4f test acc: %.4f" % 177 | (iter, config['num_train']/config['batch_size'], train_loss, test_loss, test_acc)) 178 | 179 | # save model 180 | state = {'iter': iter, 181 | 'model_state_dict': model.state_dict(), 182 | 'optimizer_state_dict': optimizer.state_dict(), 183 | 'loss': train_loss, 184 | 'test loss': test_loss} 185 | 186 | torch.save(state, config['model_save_path']+'_last_iter.pt') 187 | 188 | writer.close() 189 | 190 | 191 | def main(): 192 | # torch.autograd.set_detect_anomaly(True) 193 | 194 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 195 | print('Using ', device) 196 | 197 | parser = argparse.ArgumentParser(description='Train the network') 198 | parser.add_argument('--training_config', type=str, 199 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/platonic_solid_cls/training_param.yaml') 200 | args = parser.parse_args() 201 | 202 | # load yaml file 203 | config = yaml.safe_load(open(args.training_config)) 204 | 205 | # create dataset and dataloader 206 | training_set = PlatonicDataset(config['num_train']) 207 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'], 208 | shuffle=config['shuffle']) 209 | 210 | test_set = PlatonicDataset(config['num_test']) 211 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], 212 | shuffle=config['shuffle']) 213 | 214 | if config['model_type'] == "LN_relu_bracket": 215 | model = LNReluBracketPlatonicSolidClassifier(3).to(device) 216 | elif config['model_type'] == "LN_relu": 217 | model = LNReluPlatonicSolidClassifier(3).to(device) 218 | elif config['model_type'] == "LN_bracket": 219 | model = LNBracketPlatonicSolidClassifier(3).to(device) 220 | elif config['model_type'] == "MLP": 221 | model = MLP(288).to(device) 222 | elif config['model_type'] == "LN_bracket_no_residual": 223 | model = LNBracketNoResidualConnectPlatonicSolidClassifier(3).to(device) 224 | 225 | train(model, train_loader, test_loader, config, device) 226 | 227 | 228 | if __name__ == "__main__": 229 | main() 230 | --------------------------------------------------------------------------------
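Closing note on the adjoint-action pattern shared by data_gen/gen_sl3_inv_5_input_data.py and the `train_augmentation` branch of experiment/platonic_solid_cls_train.py above: a Lie algebra feature x in R^8 is transformed by conjugating its matrix form, x -> vee(H hat(x) H^{-1}), where H is a group element (the exponential of a random algebra element, or a rotation). Below is a minimal, self-contained sketch of this pattern in plain PyTorch; the hat/vee helpers here are stand-ins written against the sl(3) basis E1..E8 from matlab/hat_sl3.m, not the repo's HatLayer/vee_sl3:

import torch

# sl(3) basis, matching E1..E8 in matlab/hat_sl3.m
E = torch.zeros(8, 3, 3)
E[0] = torch.tensor([[1., 0., 0.], [0., -1., 0.], [0., 0., 0.]])
E[1] = torch.tensor([[0., 1., 0.], [1., 0., 0.], [0., 0., 0.]])
E[2] = torch.tensor([[0., -1., 0.], [1., 0., 0.], [0., 0., 0.]])
E[3] = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., -2.]])
E[4] = torch.tensor([[0., 0., 1.], [0., 0., 0.], [0., 0., 0.]])
E[5] = torch.tensor([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.]])
E[6] = torch.tensor([[0., 0., 0.], [0., 0., 0.], [1., 0., 0.]])
E[7] = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 1., 0.]])

def hat(v):
    # R^8 -> 3x3 traceless matrix: hat(v) = sum_i v_i * E_i
    return torch.einsum('i,ijk->jk', v, E)

def vee(M):
    # 3x3 -> R^8 by least squares against the flattened basis,
    # mirroring the E_vec pseudo-inverse in the MATLAB scripts
    B = E.reshape(8, 9).T                       # 9 x 8
    return torch.linalg.lstsq(B, M.reshape(9)).solution

x = torch.randn(8)                              # an sl(3) feature
h = 0.1 * torch.randn(8)
H = torch.linalg.matrix_exp(hat(h))             # group element, det(H) = 1
x_adj = vee(H @ hat(x) @ torch.inverse(H))      # adjoint-transformed feature

# sanity check: tr(hat(x)^2) is invariant under conjugation
print(torch.trace(hat(x) @ hat(x)).item(),
      torch.trace(hat(x_adj) @ hat(x_adj)).item())

Because the trace of a matrix power is unchanged by conjugation, a check like the print above is a quick way to validate any hat/vee pair before using it for augmentation or data generation.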