├── figures
│   ├── .gitkeep
│   ├── Lie_Nuerons_ICML24.pdf
│   ├── lie_neurons_icon.jpg
│   └── lie_neurons_modules.jpg
├── matlab
│   ├── inv_figure.fig
│   ├── vee_so3.m
│   ├── hat_so3.m
│   ├── hat_sl3.m
│   ├── vee_sl3.m
│   ├── test_euler_poincare.m
│   ├── test_Ad_N.m
│   ├── test_killing_relu.m
│   ├── find_adjoint_KAN.m
│   ├── sl3_equivariant_lift.asv
│   ├── sl3_equivariant_lift.m
│   ├── sl3_equivariant_lift_find_bound.asv
│   └── sl3_equivariant_lift_find_bound.m
├── playground
│   ├── emlp_test.py
│   ├── neural_ode_nonhomogeneous.py
│   └── neural_ode_demo.py
├── docker
│   ├── EMLP
│   │   ├── build_docker_container.bash
│   │   ├── README.md
│   │   └── Dockerfile
│   └── LieNeurons
│       ├── build_docker_container.bash
│       ├── README.md
│       └── Dockerfile
├── test
│   ├── test_sp4_hat.py
│   ├── test_se3_hat.py
│   ├── test_so3_bch.py
│   ├── test_so3_exp.py
│   ├── test_invariant.py
│   ├── test_equivariant_liebracket.py
│   ├── test_equivariant.py
│   ├── test_pooling_equivariant.py
│   ├── test_batch_norm_equivariant.py
│   └── test_so3_bch_time.py
├── config
│   ├── sl3_equiv
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── so3_bch
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── sl3_inv
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   ├── platonic_solid_cls
│   │   ├── testing_param.yaml
│   │   └── training_param.yaml
│   └── sp4_inv
│       ├── testing_param.yaml
│       └── training_param.yaml
├── LICENSE
├── script
│   ├── run_euler_poincare.bash
│   └── evaluate_multiple.bash
├── data_loader
│   ├── sp4_inv_data_loader.py
│   ├── sl3_inv_data_loader.py
│   └── so3_bch_data_loader.py
├── README.md
├── core
│   ├── vn_layers.py
│   └── lie_group_util.py
├── .gitignore
├── experiment
│   ├── platonic_solid_cls_test.py
│   ├── sl3_inv_test.py
│   ├── sp4_inv_test.py
│   ├── sl3_equiv_test.py
│   ├── sl3_inv_train.py
│   ├── sp4_inv_train.py
│   ├── so3_bch_layers.py
│   ├── platonic_solid_cls_layers.py
│   ├── sl3_equiv_train.py
│   └── platonic_solid_cls_train.py
└── data_gen
    ├── gen_sl3_inv_data.py
    ├── gen_sp4_inv_data.py
    ├── gen_so3_bch.py
    └── gen_sl3_inv_5_input_data.py
/figures/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/matlab/inv_figure.fig:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/matlab/inv_figure.fig
--------------------------------------------------------------------------------
/figures/Lie_Nuerons_ICML24.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/Lie_Nuerons_ICML24.pdf
--------------------------------------------------------------------------------
/figures/lie_neurons_icon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/lie_neurons_icon.jpg
--------------------------------------------------------------------------------
/figures/lie_neurons_modules.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UMich-CURLY/LieNeurons/HEAD/figures/lie_neurons_modules.jpg
--------------------------------------------------------------------------------
/matlab/vee_so3.m:
--------------------------------------------------------------------------------
1 | function v = vee_so3(M)
2 |
3 | v(1) = M(3, 2);
4 | v(2) = M(1, 3);
5 | v(3) = M(2, 1);
6 |
7 | end
--------------------------------------------------------------------------------
/matlab/hat_so3.m:
--------------------------------------------------------------------------------
1 | function R = hat_so3(v)
2 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0];
3 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0];
4 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0];
5 |
6 | R = v(1)*Ex+v(2)*Ey+v(3)*Ez;
7 |
8 | % if ndims(v)==1
9 | % R = v(1)*Ex+v(2)*Ey+v(3)*Ez;
10 | % elseif ndims(v) > 1
11 | % R = v(:,1).*Ex+v(:,2).*Ey+v(:,3).*Ez;
12 | % end
13 |
14 | end
--------------------------------------------------------------------------------
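As a quick reference, `hat_so3` and `vee_so3` above are mutually inverse maps between R^3 and the skew-symmetric matrices so(3). A minimal, self-contained NumPy sketch of the same round trip (illustrative only, independent of the repository code):

```python
import numpy as np

def hat_so3(v):
    # R^3 -> so(3): v maps to the skew-symmetric matrix v^
    return np.array([[0.0, -v[2], v[1]],
                     [v[2], 0.0, -v[0]],
                     [-v[1], v[0], 0.0]])

def vee_so3(M):
    # so(3) -> R^3: inverse of hat_so3
    return np.array([M[2, 1], M[0, 2], M[1, 0]])

v = np.array([1.0, 2.0, 3.0])
assert np.allclose(vee_so3(hat_so3(v)), v)  # hat followed by vee recovers v
```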
/playground/emlp_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import emlp.nn.pytorch as nn
4 | from emlp.reps import T,V,sparsify_basis
5 | from emlp.groups import SO, SL
6 |
7 |
8 |
9 | if __name__ == '__main__':
10 | # Define the input representation
11 | G = SL(3)
12 | reps = V(G)
13 | reps_out = (V**1*V.T**1)(G)
14 |
15 | print(reps.rho(G))
16 | print(reps.size())
17 | print(reps_out.size())
18 | Q = (reps>>reps_out).equivariant_basis()
19 | print(f"Basis matrix of shape {Q.shape}")
20 | # print(sparsify_basis(Q).reshape(3,3))
21 | print(reps(G).size())
22 | # Define the output representation
--------------------------------------------------------------------------------
/docker/EMLP/build_docker_container.bash:
--------------------------------------------------------------------------------
1 | container_name=$1
2 |
3 | xhost +local:
4 | docker run -it --net=host --shm-size 8G --gpus all \
5 | --user=$(id -u) \
6 | -e DISPLAY=$DISPLAY \
7 | -e QT_GRAPHICSSYSTEM=native \
8 | -e NVIDIA_DRIVER_CAPABILITIES=all \
9 | -e XAUTHORITY \
10 | -e USER=$USER \
11 | --workdir=/home/$USER/ \
12 | -v "/tmp/.X11-unix:/tmp/.X11-unix:rw" \
13 | -v "/etc/passwd:/etc/passwd:rw" \
14 | -e "TERM=xterm-256color" \
15 | -v "/home/$USER/DockerFolder:/home/$USER/" \
16 | -v "/media/$USER/DATA/data:/home/$USER/data/" \
17 | --device=/dev/dri:/dev/dri \
18 | --name=${container_name} \
19 | --security-opt seccomp=unconfined \
20 | umcurly/emlp:latest
21 |
--------------------------------------------------------------------------------
/docker/LieNeurons/build_docker_container.bash:
--------------------------------------------------------------------------------
1 | container_name=$1
2 |
3 | xhost +local:
4 | docker run -it --net=host --shm-size 8G --gpus all \
5 | --user=$(id -u) \
6 | -e DISPLAY=$DISPLAY \
7 | -e QT_GRAPHICSSYSTEM=native \
8 | -e NVIDIA_DRIVER_CAPABILITIES=all \
9 | -e XAUTHORITY \
10 | -e USER=$USER \
11 | --workdir=/home/$USER/ \
12 | -v "/tmp/.X11-unix:/tmp/.X11-unix:rw" \
13 | -v "/etc/passwd:/etc/passwd:rw" \
14 | -e "TERM=xterm-256color" \
15 | -v "/home/$USER/DockerFolder:/home/$USER/" \
16 | -v "/media/$USER/DATA/data:/home/$USER/data/" \
17 | --device=/dev/dri:/dev/dri \
18 | --name=${container_name} \
19 | --security-opt seccomp=unconfined \
20 | umcurly/lieneurons:latest
21 |
--------------------------------------------------------------------------------
/matlab/hat_sl3.m:
--------------------------------------------------------------------------------
1 | function A = hat_sl3(v)
2 |
3 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
4 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
5 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
6 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
7 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
8 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
9 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
10 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
11 |
12 | A = v(1)*E1+v(2)*E2+v(3)*E3+v(4)*E4+v(5)*E5+v(6)*E6+v(7)*E7+v(8)*E8;
13 | end
--------------------------------------------------------------------------------
/docker/EMLP/README.md:
--------------------------------------------------------------------------------
1 | # EMLP Docker
2 | This folder contains instructions for building a Docker image with PyTorch and the other libraries needed to run EMLP.
3 |
4 | ## How to build
5 |
6 | build image
7 | ```
8 | docker build --tag umcurly/emlp .
9 | ```
10 | change the folder path on line 15 of the script, then build the container
11 | ```
12 | bash build_docker_container.bash [container_name]
13 | ```
14 | After the build command finishes, you will be inside the container. For stable use, we recommend running `exit` and then following the next section to start and attach to the container.
15 |
16 | ## How to use
17 | start docker
18 | ```
19 | docker start [container_name]
20 | ```
21 | run docker
22 | ```
23 | docker exec -it [container_name] /bin/bash
24 | ```
25 | run docker with root access
26 | ```
27 | docker exec -u root -it [container_name] /bin/bash
28 | ```
29 |
--------------------------------------------------------------------------------
/docker/LieNeurons/README.md:
--------------------------------------------------------------------------------
1 | # LieNeurons Docker
2 | This folder contains instructions for building a Docker image with PyTorch and the other libraries needed to run Lie Neurons.
3 |
4 | ## How to build
5 |
6 | build image
7 | ```
8 | docker build --tag umcurly/lieneurons .
9 | ```
10 | change the folder path on line 15 of the script, then build the container
11 | ```
12 | bash build_docker_container.bash [container_name]
13 | ```
14 | After the build command finishes, you will be inside the container. For stable use, we recommend running `exit` and then following the next section to start and attach to the container.
15 |
16 | ## How to use
17 | start docker
18 | ```
19 | docker start [container_name]
20 | ```
21 | run docker
22 | ```
23 | docker exec -it [container_name] /bin/bash
24 | ```
25 | run docker with root access
26 | ```
27 | docker exec -u root -it [container_name] /bin/bash
28 | ```
29 |
--------------------------------------------------------------------------------
/matlab/vee_sl3.m:
--------------------------------------------------------------------------------
1 | function v = vee_sl3(M)
2 | v(4) = -0.5 * M(3, 3);
3 | v(5) = M(1, 3);
4 | v(6) = M(2, 3);
5 | v(7) = M(3, 1);
6 | v(8) = M(3, 2);
7 | v(1) = (M(1, 1) - v(4));
8 |
9 | v(2) = 0.5*(M(1, 2) + M(2, 1));
10 | v(3) = 0.5*(M(2, 1) - M(1, 2));
11 |
12 | v = v';
13 | end
14 |
15 | % def vee_sl3(M):
16 | % # [a1 + a4, a2 - a3, a5]
17 | % # [a2 + a3, a4 - a1, a6]
18 | % # [ a7, a8, -2*a4]
19 | % v = torch.zeros(M.shape[:-2]+(8,)).to(M.device)
20 | % v[..., 3] = -0.5*M[..., 2, 2]
21 | % v[..., 4] = M[..., 0, 2]
22 | % v[..., 5] = M[..., 1, 2]
23 | % v[..., 6] = M[..., 2, 0]
24 | % v[..., 7] = M[..., 2, 1]
25 | % v[..., 0] = (M[..., 0, 0] - v[..., 3])
26 | %
27 | % v[..., 1] = 0.5*(M[..., 0, 1] + M[..., 1, 0])
28 | % v[..., 2] = 0.5*(M[..., 1, 0] - M[..., 0, 1])
29 | % return v
--------------------------------------------------------------------------------
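The commented-out block above is the batched PyTorch counterpart of this function. As a self-contained, runnable sketch (0-based indexing; the coefficients follow the `hat_sl3.m` basis above):

```python
import torch

def vee_sl3(M: torch.Tensor) -> torch.Tensor:
    # Batched vee for sl(3): (..., 3, 3) matrices -> (..., 8) basis coefficients.
    # Inverse of hat_sl3 for M = a1*E1 + ... + a8*E8 with the basis used above.
    v = torch.zeros(M.shape[:-2] + (8,), device=M.device, dtype=M.dtype)
    v[..., 3] = -0.5 * M[..., 2, 2]
    v[..., 4] = M[..., 0, 2]
    v[..., 5] = M[..., 1, 2]
    v[..., 6] = M[..., 2, 0]
    v[..., 7] = M[..., 2, 1]
    v[..., 0] = M[..., 0, 0] - v[..., 3]
    v[..., 1] = 0.5 * (M[..., 0, 1] + M[..., 1, 0])
    v[..., 2] = 0.5 * (M[..., 1, 0] - M[..., 0, 1])
    return v
```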
/test/test_sp4_hat.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 |
8 | from core.lie_alg_util import *
9 |
10 | if __name__ == "__main__":
11 |
12 | v = torch.Tensor([1,2,3,4,5,6,7,8,9,10])
13 | print("v: ", v)
14 | sp4_hatlayer = HatLayer(algebra_type='sp4')
15 | M = sp4_hatlayer(v)
16 | print("M: ", M)
17 | v2 = vee(M, algebra_type='sp4')
18 | print("v after vee: ", v2)
19 |
20 | rnd_scale=5
21 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 10)))
22 | M = sp4_hatlayer(v).squeeze(0).squeeze(0)
23 |
24 | zeros = torch.zeros(2, 2)
25 | I = torch.eye(2)
26 | # Construct the 4x4 matrix by combining the components
27 | omega = torch.cat((torch.cat((zeros, I), dim=1),
28 | torch.cat((-I, zeros), dim=1)), dim=0)
29 |
30 | print("Checking that M satisfies the definition of sp(4) (omega@M + M.T@omega should be zero): ")
31 | print(omega@M+M.T@omega)
--------------------------------------------------------------------------------
/matlab/test_euler_poincare.m:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | true_y0 = [2., 1.,3.0];
5 | t = [0:0.025:25];
6 |
7 |
8 | [t,y] = ode45(@EulerPoincare, t, true_y0);
9 |
10 |
11 | f1 = figure(1);
12 | plot(t,y);
13 | legend("x","y","z");
14 | xlabel("time")
15 | ylabel("x,y,z")
16 |
17 | f2 = figure(2);
18 | plot(y(:,1),y(:,2));
19 | xlabel("x")
20 | ylabel("y")
21 |
22 | f3 = figure(3);
23 | plot3(y(:,1),y(:,2),y(:,3));
24 |
25 | [X,Y,Z] = meshgrid(-2:0.1:2,-2:0.1:2,-2:0.1:2);
26 |
27 | function w_wedge = wedge(w)
28 | w_wedge = zeros(3,3);
29 | w_wedge(1,2) = -w(3);
30 | w_wedge(1,3) = w(2);
31 | w_wedge(2,1) = w(3);
32 | w_wedge(2,3) = -w(1);
33 | w_wedge(3,1) = -w(2);
34 | w_wedge(3,2) = w(1);
35 | end
36 |
37 | function dwdt = EulerPoincare(t,w)
38 | % I = [[12, 0, 0];[0, 20., 0];[0, 0, 5.]];
39 | % I = [[12, -5., 7.];[-5., 20., -2.];[7., -2., 5.]];
40 | I = [[5410880., -246595., 2967671.];[-246595., 29457838., -47804.];[2967671., -47804., 26744180.]];
41 | w_wedge = wedge(w);
42 |
43 |
44 | dwdt = -I\w_wedge*I*w;
45 | end
--------------------------------------------------------------------------------
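For comparison with the Python experiments, the same Euler-Poincaré simulation can be reproduced with SciPy; a minimal sketch (the inertia matrix and initial condition are copied from the script above, and the plotting is omitted):

```python
import numpy as np
from scipy.integrate import solve_ivp

I = np.array([[5410880., -246595., 2967671.],
              [-246595., 29457838., -47804.],
              [2967671., -47804., 26744180.]])

def wedge(w):
    return np.array([[0., -w[2], w[1]],
                     [w[2], 0., -w[0]],
                     [-w[1], w[0], 0.]])

def euler_poincare(t, w):
    # dw/dt = -I^{-1} (w^ I w): rigid-body dynamics on so(3)
    return -np.linalg.solve(I, wedge(w) @ I @ w)

t_eval = np.arange(0.0, 25.0, 0.025)
sol = solve_ivp(euler_poincare, (0.0, 25.0), [2., 1., 3.], t_eval=t_eval)
print(sol.y.shape)  # (3, len(t_eval)), one row per body-frame velocity component
```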
/config/sl3_equiv/testing_param.yaml:
--------------------------------------------------------------------------------
1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz"
2 | # logger
3 | model_description: "SL3EquivariantFunctionFitting"
4 | print_freq: 100
5 | # model_type: "LN"
6 | # model_path: "/home/justin/code/LieNeurons/weights/0926_exp_sl3_equiv_lie_bracket_2input_LN_5_256_no_share_nonlin_train_10000_lr_ 0.000003_best_test_loss_acc.pt"
7 |
8 | # model_type: "LN_relu_bracket"
9 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_relu_one_bracket_final_best_test_loss_acc.pt"
10 | model_type: "MLP"
11 | model_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5_best_test_loss_acc.pt"
12 | # model_type: "MLP"
13 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5_best_test_loss_acc.pt"
14 | # model_type: "MLP"
15 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_MLP_best_test_loss_acc.pt"
16 |
17 | # model_type: "LN_bracket_no_residual"
18 | # model_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_bracket_no_residual_best_test_loss_acc.pt"
19 | batch_size: 100
20 | shuffle: False
21 |
--------------------------------------------------------------------------------
/config/so3_bch/testing_param.yaml:
--------------------------------------------------------------------------------
1 | test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_test_data.npz"
2 | # logger
3 | model_description: "SO3BCHFunctionFitting"
4 | print_freq: 100
5 |
6 | # model_type: "LN_relu_bracket"
7 | # model_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_bracket_best_test_loss_acc.pt"
8 |
9 | # model_type: "LN_relu"
10 | # model_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_4_layers_1024"
11 |
12 | # model_type: "LN_bracket"
13 | # model_path: "/home/justin/code/LieNeurons/weights/0124_so3_bch_bracket_best_test_loss_acc.pt"
14 |
15 | # model_type: "MLP"
16 | # model_path: "/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented_best_test_loss_acc.pt"
17 |
18 | # model_type: "EMLP"
19 | # model_path: "/home/justin/code/LieNeurons/weights/EMLP_best_test_loss_acc.pt"
20 |
21 | # model_type: "e3nn_norm"
22 | # model_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_norm_256_best_test_loss_acc.pt"
23 |
24 | model_type: "e3nn_s2grid"
25 | model_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_s2grid_best_test_loss_acc.pt"
26 |
27 | batch_size: 1
28 | shuffle: False
29 |
30 | calculate_eq_approx_error: False
--------------------------------------------------------------------------------
/config/sl3_inv/testing_param.yaml:
--------------------------------------------------------------------------------
1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz"
2 | # logger
3 | model_description: "SL3InvariantFunctionFitting"
4 | print_freq: 100
5 | # model_type: "LN"
6 | # model_path: "/home/justin/code/LieNeurons/weights/0928_sl3_inv_LN_256_no_share_nonlin_train_10000_lr_0.000003_new_best_test_loss_acc.pt"
7 | # model_type: "MLP"
8 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt"
9 | model_type: "MLP"
10 | model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt"
11 | # model_type: "LN_bracket"
12 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_0.00003_best_test_loss_acc.pt"
13 |
14 | # model_type: "LN_bracket"
15 | # model_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_2_best_test_loss_acc.pt"
16 |
17 | # model_type: "MLP"
18 | # model_path: "/home/justin/code/LieNeurons/weights/1120_MLP_augmented_best_test_loss_acc.pt"
19 |
20 | # model_type: "LN_bracket_no_residual"
21 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_no_residual_best_test_loss_acc.pt"
22 | num_workers: 1
23 | batch_size: 100
24 | shuffle: True
25 |
--------------------------------------------------------------------------------
/test/test_se3_hat.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 |
8 | from core.lie_alg_util import *
9 |
10 | if __name__ == "__main__":
11 |
12 | v = torch.Tensor([1,2,3,4,5,6])
13 | print("v: ", v)
14 | se3_hatlayer = HatLayer(algebra_type='se3')
15 | M = se3_hatlayer(v)
16 | print("M: ", M)
17 | v2 = vee(M, algebra_type='se3')
18 | print("v after vee: ", v2)
19 |
20 | rnd_scale=5
21 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 6)))
22 | print("v: ", v)
23 | print(v.shape)
24 | M = se3_hatlayer(v)
25 | M_SE3 = scipy.linalg.expm(M[0,0,:,:].numpy())
26 | print("M_SE3: ", M_SE3)
27 | M2 = scipy.linalg.logm(M_SE3)
28 |
29 | print("M: ", M)
30 | print("M log: ", M2)
31 |
32 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 3)))
33 | print("v: ", v)
34 | print(v.shape)
35 | so3_hatlayer = HatLayer(algebra_type='so3')
36 | M = so3_hatlayer(v)
37 | print(M.shape)
38 | M_SE3 = scipy.linalg.expm(M[0,0,:,:].numpy())
39 | print("M_SE3: ", M_SE3)
40 | M2 = scipy.linalg.logm(M_SE3)
41 |
42 | print("M: ", M)
43 | print("M log: ", M2)
--------------------------------------------------------------------------------
/config/platonic_solid_cls/testing_param.yaml:
--------------------------------------------------------------------------------
1 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_test_data.npz"
2 | # logger
3 | model_description: "SL3EquivariantFunctionFitting"
4 | print_freq: 100
5 | # model_type: "LN"
6 | # model_path: "/home/justin/code/LieNeurons/weights/0926_platonic_cls_LN_best_test_loss_good.pt"
7 |
8 | model_type: "MLP"
9 | model_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_mlp_aug_4_best_test_acc.pt"
10 | # model_type: "LN_relu_bracket"
11 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_rb_5_best_test_acc.pt"
12 |
13 | # model_type: "LN_relu"
14 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu_best_test_acc.pt"
15 |
16 | # model_type: "LN_bracket"
17 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket_best_test_acc.pt"
18 |
19 | # model_type: "MLP"
20 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_MLP_best_test_acc.pt"
21 |
22 | # model_type: "LN_bracket_no_residual"
23 | # model_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket_no_residual_best_test_acc.pt"
24 | batch_size: 100
25 | shuffle: False
26 | num_test: 1000
27 | num_rotations: 500
28 | rotation_factor: 0.174533 # 10 degrees
29 |
--------------------------------------------------------------------------------
/test/test_so3_bch.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 | import math
8 |
9 | from core.lie_alg_util import *
10 | from core.lie_group_util import *
11 |
12 | if __name__ == "__main__":
13 |
14 |
15 | v1 = torch.rand(1,3)
16 | v1 = v1/torch.norm(v1)
17 | phi = math.pi
18 | v1 = phi*v1
19 | # print("v: ", v1)
20 |
21 | v2 = v1
22 | v2 = v2/torch.norm(v2)
23 | phi2 = math.pi
24 | v2 = phi2*v2
25 |
26 | v1 = torch.Tensor([[1.4911, 0.6458, 1.0547]])
27 | v2 = torch.Tensor([[0.2295, 2.0104, 0.0430]])
28 |
29 | so3_hatlayer = HatLayer(algebra_type='so3')
30 | K1 = so3_hatlayer(v1)
31 | K2 = so3_hatlayer(v2)
32 |
33 | R1 = exp_so3(K1[0,:,:])
34 | R2 = exp_so3(K2[0,:,:])
35 |
36 | print("v1: ", v1)
37 | print("v2: ", v2)
38 | print("R1: ", R1)
39 | print("R2: ", R2)
40 |
41 | R3 = R1@R2
42 | print("R3:", R3)
43 | K3 = log_SO3(R3)
44 | K3_BCH = BCH_approx(K1[0,:,:],K2[0,:,:])
45 | K3_BCH_SO3 = BCH_so3(K1[0,:,:],K2[0,:,:])
46 | print("K1: ", K1)
47 | print("K2: ", K2)
48 | print("----")
49 | print("K3: ", K3)
50 | print("K3 BCH: ", K3_BCH)
51 | print("K1+K2: ", K1+K2)
52 | print("K3 BCH SO3: ", K3_BCH_SO3)
53 | print("norm K3: ", torch.norm(vee(K3, algebra_type='so3')))
--------------------------------------------------------------------------------
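`BCH_approx` and `BCH_so3` come from `core.lie_group_util`. A library-free numerical cross-check of the same identity, truncating the Baker-Campbell-Hausdorff series at second order (a sketch using SciPy, not the repository's implementation):

```python
import numpy as np
from scipy.linalg import expm, logm

def hat(v):
    return np.array([[0., -v[2], v[1]],
                     [v[2], 0., -v[0]],
                     [-v[1], v[0], 0.]])

K1 = hat([0.1, 0.2, 0.3])
K2 = hat([-0.2, 0.1, 0.05])

# Ground truth: log(exp(K1) exp(K2))
K3 = logm(expm(K1) @ expm(K2)).real

# Second-order BCH: K1 + K2 + 0.5*[K1, K2]
K3_bch2 = K1 + K2 + 0.5 * (K1 @ K2 - K2 @ K1)
print(np.abs(K3 - K3_bch2).max())  # small for small rotations
```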
/test/test_so3_exp.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 | import math
8 |
9 | from core.lie_alg_util import *
10 | from core.lie_group_util import *
11 |
12 | if __name__ == "__main__":
13 |
14 | rnd_scale=1
15 | v = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 1, 3)))
16 | v = v/torch.norm(v)
17 | phi = math.pi*torch.rand(1)
18 | v = phi*v
19 | print("v: ", v)
20 | so3_hatlayer = HatLayer(algebra_type='so3')
21 | K = so3_hatlayer(v)
22 | K_SO3 = torch.asarray(scipy.linalg.expm(K[0,0,:,:].numpy()))
23 | K_SO3_2 = exp_hat_and_so3(v.T)
24 | K_SO3_3 = exp_so3(K[0,0,:,:])
25 | K_SO3_exp_gpu = exp_so3(K[0,0,:,:].to('cuda'))
26 | print("-------------------")
27 | print("SO3: ", K_SO3)
28 | print("SO3_2: ", K_SO3_2)
29 | print("SO3_3: ", K_SO3_3)
30 | print("-------------------")
31 | print("det SO3: ", np.linalg.det(K_SO3))
32 | print("det SO3_2: ", np.linalg.det(K_SO3_2))
33 | print("-------------------")
34 | K_after = log_SO3(K_SO3)
35 | K2_after = log_SO3(K_SO3_2)
36 | K_log_gpu = log_SO3(K_SO3_3.to('cuda'))
37 | K_exp_gpu_after = log_SO3(K_SO3_exp_gpu)
38 |
39 | print("K: ", K)
40 | print("M: ", K_after)
41 | print("M log: ", K2_after)
42 | print("M log gpu: ", K_exp_gpu_after)
43 | print("M log gpu: ", K_log_gpu)
--------------------------------------------------------------------------------
/config/sp4_inv/testing_param.yaml:
--------------------------------------------------------------------------------
1 | test_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_test_data.npz"
2 | # logger
3 | model_description: "SP4InvariantFunctionFitting"
4 | print_freq: 100
5 | # model_type: "LN"
6 | # model_path: "/home/justin/code/LieNeurons/weights/0928_sl3_inv_LN_256_no_share_nonlin_train_10000_lr_0.000003_new_best_test_loss_acc.pt"
7 | # model_type: "MLP"
8 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt"
9 | # model_type: "MLP"
10 | # model_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5_best_test_loss_acc.pt"
11 |
12 | # model_type: "LN_bracket"
13 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_0.00003_best_test_loss_acc.pt"
14 |
15 | model_type: "LN_relu_bracket"
16 | model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_relu_bracket_best_test_loss_acc.pt"
17 |
18 | # model_type: "LN_bracket"
19 | # model_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_2_best_test_loss_acc.pt"
20 |
21 | # model_type: "MLP512"
22 | # model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_augmented_best_test_loss_acc.pt"
23 |
24 | # model_type: "MLP512"
25 | # model_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_best_test_loss_acc.pt"
26 |
27 | # model_type: "LN_bracket_no_residual"
28 | # model_path: "/home/justin/code/LieNeurons/weights/0928_one_bracket_no_residual_best_test_loss_acc.pt"
29 | num_workers: 1
30 | batch_size: 100
31 | shuffle: True
32 |
--------------------------------------------------------------------------------
/docker/LieNeurons/Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
2 | # FROM pytorch/pytorch:0.4_cuda9_cudnn7
3 | # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-devel
4 | FROM nvcr.io/nvidia/pytorch:21.12-py3
5 |
6 | LABEL version="0.5"
7 |
8 | # USER root
9 |
10 | ENV DEBIAN_FRONTEND noninteractive
11 |
12 | # build essentials
13 | # RUN apt-get update && apt-get -y install cmake
14 | RUN apt-get update && apt-get install -y vim
15 | RUN apt-get install -y build-essential snap
16 | RUN apt-get update && apt-get install -y git-all
17 | # RUN add-apt-repository ppa:rmescandon/yq
18 | # RUN apt update
19 | # RUN snap install yq
20 | RUN wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq &&\
21 | chmod +x /usr/bin/yq
22 |
23 | # python essentials
24 | # Starting from numpy 1.24.0 np.float would pop up an error.
25 | # Cocoeval tool still uses np.float. As a result we can only use numpy 1.23.0.
26 | # Alternatively, we can clone the cocoapi and modify the cocoeval tool.
27 | # https://github.com/cocodataset/cocoapi/pull/569
28 | RUN pip install numpy==1.23.0
29 | RUN pip install -U matplotlib
30 | RUN pip install scipy
31 | RUN pip install tensorboardX
32 | RUN pip install wandb
33 |
34 | # torch
35 | RUN pip install "torch>=1.7.0"
36 | RUN pip install "torchvision>=0.8.1"
37 |
38 | # einops
39 | RUN pip install einops
40 |
41 | # pytest
42 | RUN pip install pytest
43 |
44 | # setuptools
45 | RUN pip install setuptools==59.5.0
46 |
47 | # torch diff eq
48 | RUN pip install git+https://github.com/rtqichen/torchdiffeq
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, UMich-CURLY
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/test/test_invariant.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from core.lie_neurons_layers import *
8 | from core.lie_alg_util import *
9 | from experiment.sl3_inv_layers import *
10 |
11 |
12 | if __name__ == "__main__":
13 | print("testing the invariant layer")
14 |
15 | # test the invariant layer
16 | num_points = 1
17 | num_features = 10
18 | out_features = 3
19 | batch_size = 20
20 | rnd_scale = 0.5
21 |
22 | hat_layer = HatLayer()
23 |
24 | x = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (batch_size, num_features, 8, num_points))).reshape(
25 | batch_size, num_features, 8, num_points)
26 | y = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, 8))
27 |
28 | # SL(3) transformation
29 | Y = torch.linalg.matrix_exp(hat_layer(y))
30 |
31 | # model = LNInvariant(num_features,method='learned_killing')
32 | model = SL3InvariantLayers(num_features)
33 |
34 | x_hat = hat_layer(x.transpose(2, -1))
35 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y)))
36 | new_x = vee_sl3(new_x_hat).transpose(2, -1)
37 |
38 | model.eval()
39 | with torch.no_grad():
40 | out_x = model(x)
41 | out_new_x = model(new_x)
42 |
43 | test_result = torch.allclose(
44 | out_new_x, out_x, rtol=1e-4, atol=1e-4)
45 |
46 | print("out x[0,:]", out_x[0, :])
47 | print("out new x[0,:]: ", out_new_x[0, :])
48 | print("differences: ", out_x[0, :] - out_new_x[0, :])
49 |
50 | print("The network is invariant: ", test_result)
51 |
--------------------------------------------------------------------------------
/script/run_euler_poincare.bash:
--------------------------------------------------------------------------------
1 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode7 --fig_save_path figures/final_LN_ode7 --model_save_path weights/final_LN_ode7
2 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode --fig_save_path figures/final_neural_ode --model_save_path weights/final_neural_ode
3 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode7 --fig_save_path figures/final_LN_ode7_train_on_one --model_save_path weights/final_LN_ode7_train_on_one --num_training 1
4 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode --fig_save_path figures/final_neural_ode_train_on_one --model_save_path weights/final_neural_ode_train_on_one --num_training 1
5 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode2 --fig_save_path figures/final_neural_ode2_train_on_one --model_save_path weights/final_neural_ode2_train_on_one --num_training 1
6 | # python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type neural_ode3 --fig_save_path figures/final_neural_ode3 --model_save_path weights/final_neural_ode3
7 |
8 | python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode8 --fig_save_path figures/final_LN_ode8 --model_save_path weights/final_LN_ode8
9 | python experiment/euler_poincare_eq_train.py --adjoint --viz --inertia_type iss --model_type LN_ode8 --fig_save_path figures/final_LN_ode8_train_on_one --model_save_path weights/final_LN_ode8_train_on_one --num_training 1
--------------------------------------------------------------------------------
/matlab/test_Ad_N.m:
--------------------------------------------------------------------------------
1 | clear;
2 | clc;
3 |
4 | num_sample = 100000;
5 | eps = 1e-5;
6 | for i=1:num_sample
7 | %% generate N
8 | n = rand(3,1);
9 | N = [1 n(1) n(2); 0 1 n(3); 0 0 1];
10 | N = 1/det(N)^(1/3) * N;
11 | n1 = N(1,2);
12 | n2 = N(1,3);
13 | n3 = N(2,3);
14 |
15 | Ad_N =...
16 | [[ 1, n1, n1, 0, 0, 0, n2/2 + (n1*n3)/2, -n3/2];
17 | [ -n1, 1 - n1^2/2, -n1^2/2, 0, 0, 0, n3/2 - (n1*n2)/2, n2/2];
18 | [ n1, n1^2/2, n1^2/2 + 1, 0, 0, 0, n3/2 + (n1*n2)/2, -n2/2];
19 | [ 0, 0, 0, 1, 0, 0, n2/2 - (n1*n3)/2, n3/2];
20 | [2*n1*n3 - n2, - n3 - n1*(n2 - n1*n3), n3 - n1*(n2 - n1*n3), -3*n2, 1, n1, -n2*(n2 - n1*n3), -n2*n3];
21 | [ n3, n1*n3 - n2, n1*n3 - n2, -3*n3, 0, 1, -n3*(n2 - n1*n3), -n3^2];
22 | [ 0, 0, 0, 0, 0, 0, 1, 0];
23 | [ 0, 0, 0, 0, 0, 0, -n1, 1]];
24 |
25 | %% generate sl(3)
26 | v = rand(8,1);
27 | v_hat = hat_sl3(v);
28 | v_hat_ad = N * v_hat * inv(N);
29 | v_ad = Ad_N * v; % Linear map on v
30 | v_hat_ad_vee = vee_sl3(v_hat_ad); % vee after adjoint on v_hat
31 |
32 | diff = v_ad - v_hat_ad_vee; % They should be the same
33 |
34 | if norm(diff) > eps
35 | disp('error bigger than eps');
36 | disp(diff)
37 | elseif i==num_sample
38 | disp('Passed all test cases!');
39 | end
40 |
41 | end
42 |
43 |
--------------------------------------------------------------------------------
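The closed-form `Ad_N` above was derived symbolically (see `find_adjoint_KAN.m` below). Numerically, the same 8x8 matrix can be assembled column by column as `vee(N * Ei * inv(N))` over the sl(3) basis, since vee is linear. A NumPy sketch of this generic construction (function names are illustrative; the basis and vee convention mirror `hat_sl3.m` and `vee_sl3.m`):

```python
import numpy as np

# sl(3) basis matching hat_sl3.m
E = [np.array(m, dtype=float) for m in [
    [[1, 0, 0], [0, -1, 0], [0, 0, 0]], [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
    [[0, -1, 0], [1, 0, 0], [0, 0, 0]], [[1, 0, 0], [0, 1, 0], [0, 0, -2]],
    [[0, 0, 1], [0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]]]]

def vee_sl3(M):
    # Same coefficient extraction as matlab/vee_sl3.m, 0-based.
    v = np.zeros(8)
    v[3] = -0.5 * M[2, 2]
    v[4], v[5], v[6], v[7] = M[0, 2], M[1, 2], M[2, 0], M[2, 1]
    v[0] = M[0, 0] - v[3]
    v[1] = 0.5 * (M[0, 1] + M[1, 0])
    v[2] = 0.5 * (M[1, 0] - M[0, 1])
    return v

def adjoint_matrix(N):
    # Column i of Ad_N is vee(N @ Ei @ N^{-1}); then Ad_N @ v == vee(N @ hat(v) @ N^{-1}).
    N_inv = np.linalg.inv(N)
    return np.stack([vee_sl3(N @ Ei @ N_inv) for Ei in E], axis=1)

n1, n2, n3 = np.random.rand(3)
N = np.array([[1.0, n1, n2], [0.0, 1.0, n3], [0.0, 0.0, 1.0]])
v = np.random.rand(8)
hat_v = sum(v[i] * E[i] for i in range(8))
assert np.allclose(adjoint_matrix(N) @ v, vee_sl3(N @ hat_v @ np.linalg.inv(N)))
```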
/matlab/test_killing_relu.m:
--------------------------------------------------------------------------------
1 | clear;
2 | clc;
3 |
4 | %% sl(3) generators
5 | G1 = [1,0,0;
6 | 0, -1, 0;
7 | 0,0,0];
8 | G2 = [0, 1,0;
9 | 1, 0 ,0;
10 | 0, 0 ,0];
11 | G3 = [0, -1, 0;
12 | 1, 0, 0;
13 | 0, 0, 0];
14 | G4 = [1, 0, 0;
15 | 0, 1, 0;
16 | 0, 0, -2];
17 | G5 = [0,0,1;
18 | 0,0,0;
19 | 0,0,0];
20 | G6 = [0,0,0;
21 | 0,0,1;
22 | 0,0,0];
23 | G7 = [0,0,0;
24 | 0,0,0;
25 | 1,0,0];
26 | G8 = [0,0,0;
27 | 0,0,0;
28 | 0,1,0];
29 |
30 | %%
31 | num_sample = 100000;
32 | out_sample = zeros(num_sample,2);
33 | for i=1:num_sample
34 |
35 | a = 2*(rand(8,1)-0.5);
36 | % b = 2*(rand(8,1)-0.5);
37 | b = a;
38 |
39 | A = a(1)*G1+a(2)*G2+a(3)*G3+a(4)*G4+a(5)*G5+a(6)*G6+a(7)*G7+a(8)*G8;
40 | B = b(1)*G1+b(2)*G2+b(3)*G3+b(4)*G4+b(5)*G5+b(6)*G6+b(7)*G7+b(8)*G8;
41 |
42 | if(trace(A)~=0)
43 | disp("trace(A) is not zero");
44 | end
45 | K = -6*trace(A*B);
46 |
47 | % relu
48 | if(K>0)
49 | out_sample(i,1)=K;
50 | out_sample(i,2)=K;
51 | else
52 | out_sample(i,1)=K;
53 | out_sample(i,2)=0;
54 | end
55 | end
56 |
57 |
58 | %%
59 | figure(1)
60 | scatter(out_sample(:,1),out_sample(:,2));
61 |
62 | %%
63 | G = [G1,G2,G3,G4,G5,G6,G7,G8];
64 | %
65 | % def sl3_to_R8(M):
66 | % # [a1 + a4, a2 - a3, a5]
67 | % # [a2 + a3, a4 - a1, a6]
68 | % # [ a7, a8, -2*a4]
69 | % v = torch.zeros(8).to(M.device)
70 | % v[3] = -0.5*M[2,2]
71 | % v[4] = M[0,2]
72 | % v[5] = M[1,2]
73 | % v[6] = M[2,0]
74 | % v[7] = M[2,1]
75 | % v[0] = (M[0,0] - v[3])
76 | %
77 | % v[1] = 0.5*(M[0,1] + M[1,0])
78 | % v[2] = 0.5*(M[1,0] - M[0,1])
79 | % return v
--------------------------------------------------------------------------------
/config/platonic_solid_cls/training_param.yaml:
--------------------------------------------------------------------------------
1 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_train_data.npz"
2 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_5_input_data/sl3_equiv_100_s_05_test_data.npz"
3 | # logger
4 | model_description: "SL3classification"
5 | print_freq: 1
6 | # model_type: "LN_relu_bracket"
7 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu_one_bracket"
8 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_relu_one_bracket"
9 |
10 | # model_type: "LN_relu"
11 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_relu"
12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_relu"
13 |
14 | # model_type: "LN_bracket"
15 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_cls_one_bracket"
16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_cls_one_bracket"
17 | model_type: "MLP"
18 | model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_mlp_aug_4"
19 | log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_cls_mlp_aug_4"
20 | # model_type: "LN_relu_bracket"
21 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_cls_rb_5"
22 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_cls_rb_5"
23 | resume_training: false
24 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_equiv_first_try.pth.tar"
25 | num_workers: 1
26 | batch_size: 100
27 | update_every_batch: 1
28 | # initial_learning_rate: 0.00001
29 | initial_learning_rate: 0.0001
30 | learning_rate_decay_rate: 0.0
31 | weight_decay_rate: 0.00
32 | num_train: 500000
33 | num_test: 1000
34 | shuffle: True
35 | train_augmentation: True
36 | rotation_factor: 0.174533 # 10 degrees
37 |
--------------------------------------------------------------------------------
/test/test_equivariant_liebracket.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from core.lie_neurons_layers import *
8 | from core.lie_alg_util import *
9 |
10 |
11 | if __name__ == "__main__":
12 | print("testing the equivariant Lie bracket layer")
13 |
14 | # test equivariant linear layer
15 | num_points = 100
16 | num_features = 10
17 | out_features = 3
18 |
19 | x = torch.Tensor(np.random.rand(num_features, 8, num_points)
20 | ).reshape(1, num_features, 8, num_points)
21 | y = torch.Tensor(np.random.rand(8))
22 |
23 | hat_layer = HatLayer()
24 |
25 | # SL(3) transformation
26 | Y = torch.linalg.matrix_exp(hat_layer(y))
27 |
28 | model = LNLieBracket(
29 | num_features, share_nonlinearity=False)
30 |
31 | x_hat = hat_layer(x.transpose(2, -1))
32 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y)))
33 | new_x = vee_sl3(new_x_hat).transpose(2, -1)
34 |
35 | model.eval()
36 | with torch.no_grad():
37 | out_x = model(x)
38 | out_new_x = model(new_x)
39 |
40 | out_x_hat = hat_layer(out_x.transpose(2, -1))
41 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y)))
42 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1)
43 |
44 | test_result = torch.allclose(
45 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4)
46 |
47 | print("out x[0,0,:,0]", out_x[0, 0, :, 0])
48 | print("out x conj[0,0,:,0]: ", out_x_conj[0, 0, :, 0])
49 | print("out new x[0,0,:,0]: ", out_new_x[0, 0, :, 0])
50 | print("differences: ",
51 | out_x_conj[0, 0, :, 0] - out_new_x[0, 0, :, 0])
52 |
53 | print("The network is equivariant: ", test_result)
54 |
--------------------------------------------------------------------------------
/test/test_equivariant.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from core.lie_neurons_layers import *
8 | from core.lie_alg_util import *
9 |
10 |
11 | if __name__ == "__main__":
12 | print("testing equivariant linear layer")
13 |
14 | # test equivariant linear layer
15 | num_points = 100
16 | num_features = 10
17 | out_features = 3
18 |
19 | x = torch.Tensor(np.random.rand(num_features, 8, num_points)
20 | ).reshape(1, num_features, 8, num_points)
21 | y = torch.Tensor(np.random.rand(8))
22 |
23 | hat_layer = HatLayer()
24 |
25 | # SL(3) transformation
26 | Y = torch.linalg.matrix_exp(hat_layer(y))
27 |
28 | model = LNLinearAndKillingRelu(
29 | num_features, out_features, share_nonlinearity=True)
30 |
31 | x_hat = hat_layer(x.transpose(2, -1))
32 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y)))
33 | new_x = vee_sl3(new_x_hat).transpose(2, -1)
34 |
35 | model.eval()
36 | with torch.no_grad():
37 | out_x = model(x)
38 | out_new_x = model(new_x)
39 |
40 | out_x_hat = hat_layer(out_x.transpose(2, -1))
41 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y)))
42 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1)
43 |
44 | test_result = torch.allclose(
45 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4)
46 |
47 | print("out x[0,0,:,0]", out_x[0, 0, :, 0])
48 | print("out x conj[0,0,:,0]: ", out_x_conj[0, 0, :, 0])
49 | print("out new x[0,0,:,0]: ", out_new_x[0, 0, :, 0])
50 | print("differences: ",
51 | out_x_conj[0, 0, :, 0] - out_new_x[0, 0, :, 0])
52 |
53 | print("The network is equivariant: ", test_result)
54 |
--------------------------------------------------------------------------------
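`test_equivariant.py`, `test_pooling_equivariant.py`, and `test_batch_norm_equivariant.py` all repeat the same pattern: feed the model both `x` and its conjugate `Ad_Y(x)`, then check that conjugating the clean output reproduces the transformed output. A hedged sketch of that pattern as a single reusable helper (names are illustrative; `HatLayer` and `vee_sl3` are assumed from `core.lie_alg_util`):

```python
import torch

def check_sl3_equivariance(model, x, y, hat_layer, vee_sl3, tol=1e-4):
    # True iff model(Ad_Y x) == Ad_Y model(x) for Y = exp(hat(y)).
    Y = torch.linalg.matrix_exp(hat_layer(y))
    Y_inv = torch.inverse(Y)

    # Conjugate the input in hat space: x -> vee(Y x^ Y^{-1}).
    new_x = vee_sl3(Y @ hat_layer(x.transpose(2, -1)) @ Y_inv).transpose(2, -1)

    model.eval()
    with torch.no_grad():
        out_x, out_new_x = model(x), model(new_x)

    # Conjugate the untransformed output and compare.
    out_x_conj = vee_sl3(Y @ hat_layer(out_x.transpose(2, -1)) @ Y_inv).transpose(2, -1)
    return torch.allclose(out_new_x, out_x_conj, rtol=tol, atol=tol)
```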
/test/test_pooling_equivariant.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from core.lie_neurons_layers import *
8 | from core.lie_alg_util import *
9 |
10 | if __name__ == "__main__":
11 | print("testing the equivariant pooling layer")
12 |
13 | # test equivariant linear layer
14 | num_points = 100
15 | num_features = 10
16 | out_features = 3
17 |
18 | x = torch.Tensor(np.random.rand(num_features, 8, num_points)
19 | ).reshape(1, num_features, 8, num_points)
20 | y = torch.Tensor(np.random.rand(8))
21 |
22 | hat_layer = HatLayer()
23 |
24 | # SL(3) transformation
25 | Y = torch.linalg.matrix_exp(hat_layer(y))
26 |
27 | model = LNLinearAndKillingReluAndPooling(
28 | num_features, out_features, share_nonlinearity=True, abs_killing_form=False)
29 |
30 | x_hat = hat_layer(x.transpose(2, -1))
31 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y)))
32 | new_x = vee_sl3(new_x_hat).transpose(2, -1)
33 |
34 | model.eval()
35 | with torch.no_grad():
36 | out_x = model(x)
37 | out_new_x = model(new_x)
38 |
39 | out_x_hat = hat_layer(out_x.transpose(2, -1))
40 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y)))
41 | out_x_conj = vee_sl3(out_x_hat_conj).transpose(2, -1)
42 |
43 | test_result = torch.allclose(
44 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4)
45 |
46 | print("out x[0,0,:]", out_x[0, 0, :])
47 | print("out x conj[0,0,:]: ", out_x_conj[0, 0, :])
48 | print("out new x[0,0,:]: ", out_new_x[0, 0, :])
49 | print("differences: ", out_x_conj[0, 0, :] - out_new_x[0, 0, :])
50 |
51 | print("The network is equivariant: ", test_result)
52 |
--------------------------------------------------------------------------------
/test/test_batch_norm_equivariant.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from core.lie_neurons_layers import *
8 | from core.lie_alg_util import *
9 | from experiment.sl3_equiv_layers import *
10 |
11 | if __name__ == "__main__":
12 | print("testing equivariance of batch norm layers")
13 |
14 | # test equivariant linear layer
15 | num_points = 1
16 | num_features = 10
17 | out_features = 3
18 | batch_size = 20
19 | rnd_scale = 0.5
20 |
21 | hat_layer = HatLayer()
22 |
23 | x = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (batch_size, num_features, 8, num_points))).reshape(
24 | batch_size, num_features, 8, num_points)
25 | y = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, 8))
26 |
27 | # SL(3) transformation
28 | Y = torch.linalg.matrix_exp(hat_layer(y))
29 |
30 | # model = LNLinearAndKillingReluAndPooling(
31 | # num_features, out_features, share_nonlinearity=True, use_batch_norm=True, dim=4)
32 |
33 | model = SL3EquivariantLayers(num_features)
34 |
35 | x_hat = hat_layer(x.transpose(2, -1))
36 | new_x_hat = torch.matmul(Y, torch.matmul(x_hat, torch.inverse(Y)))
37 | new_x = vee_sl3(new_x_hat).transpose(2, -1)
38 |
39 | model.eval()
40 | with torch.no_grad():
41 | out_x = model(x)
42 | out_new_x = model(new_x)
43 |
44 | out_x_hat = hat_layer(out_x)
45 | out_x_hat_conj = torch.matmul(Y, torch.matmul(out_x_hat, torch.inverse(Y)))
46 | out_x_conj = vee_sl3(out_x_hat_conj)
47 |
48 | test_result = torch.allclose(
49 | out_new_x, out_x_conj, rtol=1e-4, atol=1e-4)
50 |
51 | print("out x[0,:]", out_x[0, :])
52 | print("out x conj[0,:]: ", out_x_conj[0, :])
53 | print("out new x[0,:]: ", out_new_x[0, :])
54 | print("differences: ", out_x_conj[0, :] - out_new_x[0, :])
55 |
56 | print("The network is equivariant: ", test_result)
57 |
--------------------------------------------------------------------------------
/matlab/find_adjoint_KAN.m:
--------------------------------------------------------------------------------
1 | clear;
2 | clc;
3 |
4 | syms n1 n2 n3 a1 a2 v1 v2 v3 v4 v5 v6 v7 v8
5 | assumeAlso([n1 n2 n3 a1 a2 v1 v2 v3 v4 v5 v6 v7 v8],'real')
6 |
7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9
8 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real')
9 | %% basis
10 |
11 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
12 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
13 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
14 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
15 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
16 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
17 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
18 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
19 |
20 | E1_vec = reshape(E1,1,[])';
21 | E2_vec = reshape(E2,1,[])';
22 | E3_vec = reshape(E3,1,[])';
23 | E4_vec = reshape(E4,1,[])';
24 | E5_vec = reshape(E5,1,[])';
25 | E6_vec = reshape(E6,1,[])';
26 | E7_vec = reshape(E7,1,[])';
27 | E8_vec = reshape(E8,1,[])';
28 |
29 |
30 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec];
31 |
32 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8;
33 |
34 | %% find Ad_N
35 | N = [1 n1 n2; 0 1 n3; 0 0 1];
36 | Ad_N_hat = N*x_hat*inv(N);
37 | Ad_N_hat_vec = reshape(Ad_N_hat,1,[])';
38 |
39 | % solve for least square
40 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_N_hat_vec;
41 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
42 | [Ad_N,b]=equationsToMatrix(x,var);
43 |
44 | %% find Ad_A
45 | A = [a1, 0,0; 0,a2,0;0,0,1/a1/a2];
46 | Ad_A_hat = A*x_hat*inv(A);
47 | Ad_A_hat_vec = reshape(Ad_A_hat,1,[])';
48 |
49 | % solve for least square
50 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_A_hat_vec;
51 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
52 | [Ad_A,b]=equationsToMatrix(x,var);
53 |
54 | %% find h
55 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9];
56 | Ad_H_hat = H*x_hat*inv(H);
57 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])';
58 |
59 | % solve for least square
60 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec;
61 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
62 | [Ad_H,b]=equationsToMatrix(x,var);
63 |
64 | Ad_kron_H = kron(inv(H)',Ad_H);
65 |
66 | %%
67 | % sigma = svd(Ad_kron_H);
68 | % z = null(Ad_kron_H - eye(24))
--------------------------------------------------------------------------------
/docker/EMLP/Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
2 | # FROM pytorch/pytorch:0.4_cuda9_cudnn7
3 | # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-devel
4 | FROM nvcr.io/nvidia/pytorch:22.10-py3
5 |
6 | LABEL version="0.5"
7 |
8 | # USER root
9 |
10 | ENV DEBIAN_FRONTEND noninteractive
11 |
12 | # build essentials
13 | # RUN apt-get update && apt-get -y install cmake
14 | RUN apt-get update && apt-get install -y vim
15 | RUN apt-get install -y build-essential snap
16 | RUN apt-get update && apt-get install -y git-all
17 | # RUN add-apt-repository ppa:rmescandon/yq
18 | # RUN apt update
19 | # RUN snap install yq
20 | RUN wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_amd64 -O /usr/bin/yq &&\
21 | chmod +x /usr/bin/yq
22 |
23 | # RUN apt-get update && apt-get install python3.9 -y
24 | # RUN apt-get update && apt-get install python3-pip -y
25 |
26 | # python essentials
27 | # Starting from numpy 1.24.0 np.float would pop up an error.
28 | # Cocoeval tool still uses np.float. As a result we can only use numpy 1.23.0.
29 | # Alternatively, we can clone the cocoapi and modify the cocoeval tool.
30 | # https://github.com/cocodataset/cocoapi/pull/569
31 |
32 | RUN pip install numpy==1.23.0
33 | RUN pip install -U matplotlib
34 | RUN pip install scipy
35 | RUN pip install tensorboardX
36 | RUN pip install wandb
37 |
38 | # torch
39 | RUN pip install "torch>=1.7.0"
40 | RUN pip install "torchvision>=0.8.1"
41 |
42 | # einops
43 | RUN pip install einops
44 |
45 | # pytest
46 | RUN pip install pytest
47 |
48 | # setuptools
49 | RUN pip install setuptools
50 |
51 | # torch diff eq
52 | RUN pip install git+https://github.com/rtqichen/torchdiffeq
53 |
54 | # EMLP
55 | RUN pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
56 |
57 | RUN pip install h5py
58 | RUN pip install objax
59 | RUN pip install optax==0.1.7
60 | # RUN pip install plum-dispatch
61 | RUN pip install scikit-learn
62 | # RUN pip install tqdm>=4.38
63 | RUN pip install emlp
64 |
--------------------------------------------------------------------------------
/config/sp4_inv/training_param.yaml:
--------------------------------------------------------------------------------
1 | train_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_augmented_train_data.npz"
2 | test_data_path: "/home/justin/code/LieNeurons/data/sp4_inv_data/sp4_inv_10000_s_05_test_data.npz"
3 |
4 | # logger
5 | model_description: "SP4InvariantFunctionFitting"
6 | print_freq: 100
7 |
8 | # model_type: "MLP"
9 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5"
10 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_inv_mlp_augmented_5"
11 |
12 | # model_type: "LN_relu_bracket"
13 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_relu_bracket"
14 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_relu_bracket"
15 |
16 | # model_type: "LN_relu"
17 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_test"
18 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_test"
19 |
20 | # model_type: "LN_bracket"
21 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_3"
22 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_3"
23 |
24 | # model_type: "MLP"
25 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_augmented"
26 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_MLP_augmented"
27 |
28 | model_type: "MLP512"
29 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_sp4_inv_MLP_512_augmented"
30 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_sp4_inv_MLP_512_augmented"
31 |
32 | # model_type: "LN_bracket_no_residual"
33 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_no_residual"
34 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_no_residual"
35 | resume_training: false
36 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_inv_first_try.pth.tar"
37 | num_workers: 1
38 | batch_size: 100
39 | update_every_batch: 1
40 | # initial_learning_rate: 0.00003
41 | initial_learning_rate: 0.0001
42 | learning_rate_decay_rate: 0.0
43 | weight_decay_rate: 0.00
44 | num_epochs: 600
45 | shuffle: True
46 |
--------------------------------------------------------------------------------
/config/sl3_inv/training_param.yaml:
--------------------------------------------------------------------------------
1 | train_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_augmented_train_data.npz"
2 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz"
3 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_train_data.npz"
4 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz"
5 |
6 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_augmented_train_data.npz"
7 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_inv_data/sl3_inv_10000_s_05_test_data.npz"
8 | # logger
9 | model_description: "SL3InvariantFunctionFitting"
10 | print_freq: 100
11 | model_type: "MLP"
12 | model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_inv_mlp_augmented_5"
13 | log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_inv_mlp_augmented_5"
14 | # model_type: "LN_relu_bracket"
15 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_one_bracket"
16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_one_bracket"
17 |
18 | # model_type: "LN_relu"
19 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_relu_test"
20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_relu_test"
21 |
22 | # model_type: "LN_bracket"
23 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_3"
24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_3"
25 |
26 | # model_type: "MLP"
27 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_MLP"
28 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_MLP"
29 |
30 | # model_type: "LN_bracket_no_residual"
31 | # model_save_path: "/home/justin/code/LieNeurons/weights/1120_one_bracket_no_residual"
32 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1120_one_bracket_no_residual"
33 | resume_training: false
34 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_inv_first_try.pth.tar"
35 | num_workers: 1
36 | batch_size: 100
37 | update_every_batch: 1
38 | # initial_learning_rate: 0.00003
39 | initial_learning_rate: 0.0001
40 | learning_rate_decay_rate: 0.0
41 | weight_decay_rate: 0.00
42 | num_epochs: 600
43 | shuffle: True
44 |
--------------------------------------------------------------------------------
/config/sl3_equiv/training_param.yaml:
--------------------------------------------------------------------------------
1 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_augmented_train_data.npz"
2 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz"
3 | train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_train_data.npz"
4 | test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz"
5 | # train_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_augmented_train_data.npz"
6 | # test_data_path: "/home/justin/code/LieNeurons/data/sl3_equiv_lie_bracket_data/sl3_equiv_10000_lie_bracket_2inputs_test_data.npz"
7 | # logger
8 | model_description: "SL3EquivariantFunctionFitting"
9 | print_freq: 100
10 | # model_type: "LN_relu_bracket"
11 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_relu_one_bracket_final"
12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_relu_one_bracket_final"
13 |
14 | # model_type: "LN_relu"
15 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_relu"
16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_relu"
17 |
18 | # model_type: "LN_bracket"
19 | # model_save_path: "/home/justin/code/LieNeurons/weights/1119_equiv_one_bracket_3"
20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/1119_equiv_one_bracket_3"
21 |
22 | # model_type: "MLP"
23 | # model_save_path: "/home/justin/code/LieNeurons/weights/rebuttal_equiv_mlp_augmented_5"
24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/rebuttal_equiv_mlp_augmented_5"
25 |
26 | model_type: "LN_bracket_no_residual"
27 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_equiv_one_bracket_no_residual_test"
28 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_equiv_one_bracket_no_residual_test"
29 |
30 | resume_training: false
31 | resume_model_path: "/home/justin/code/LieNeurons/weights/0921_sl3_equiv_first_try.pth.tar"
32 | num_workers: 1
33 | batch_size: 100
34 | update_every_batch: 1
35 | # initial_learning_rate: 0.0001
36 | # initial_learning_rate: 0.000003
37 | # initial_learning_rate: 0.000001
38 | initial_learning_rate: 0.00001
39 | learning_rate_decay_rate: 0.0
40 | weight_decay_rate: 0.00
41 | num_epochs: 1500
42 | shuffle: True
43 |
--------------------------------------------------------------------------------
/data_loader/sp4_inv_data_loader.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from scipy.linalg import expm
8 | from tqdm import tqdm
9 |
10 |
11 | from einops import rearrange, repeat
12 | from einops.layers.torch import Rearrange
13 |
14 | from core.lie_neurons_layers import *
15 |
16 |
17 | class sp4InvDataSet(Dataset):
18 | def __init__(self, data_path, device='cuda'):
19 | data = np.load(data_path)
20 | _, num_points = data['x1'].shape
21 | _,_,num_conjugate = data['x1_conjugate'].shape
22 | self.x1 = rearrange(torch.from_numpy(data['x1']).type(
23 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
24 | self.x2 = rearrange(torch.from_numpy(data['x2']).type(
25 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
26 | self.x = torch.cat((self.x1, self.x2), dim=1)
27 |
28 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type(
29 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
30 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type(
31 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
32 | self.x_conjugate = torch.cat(
33 | (self.x1_conjugate, self.x2_conjugate), dim=2)
34 | self.y = torch.from_numpy(data['y']).type(
35 | 'torch.FloatTensor').to(device).reshape(num_points, 1)
36 |
37 | self.num_data = self.x1.shape[0]
38 |
39 | def __len__(self):
40 | return self.num_data
41 |
42 | def __getitem__(self, idx):
43 | if torch.is_tensor(idx):
44 | idx = idx.tolist()
45 |
46 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :], 'x': self.x[idx, :, :, :],
47 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :],
48 | 'x_conjugate': self.x_conjugate[:,idx, :, :, :], 'y': self.y[idx, :]}
49 | return sample
50 |
51 |
52 | if __name__ == "__main__":
53 |
54 |     dataset = sp4InvDataSet("data/sp4_inv_data/sp4_inv_10000_s_05_train_data.npz")
55 |
56 |     print(dataset.x1.shape)
57 |     print(dataset.x2.shape)
58 |     print(dataset.x1_conjugate.shape)
59 |     print(dataset.x2_conjugate.shape)
60 |     print(dataset.x.shape)
61 |     print(dataset.x_conjugate.shape)
62 |     print(dataset.y.shape)
63 |     for i, samples in tqdm(enumerate(dataset, start=0)):
64 |         input_data = samples['x']
65 |         y = samples['y']
66 |         print(input_data.shape)
67 |         print("x1: \n",input_data[0,:,0])
68 |         print("x2: \n",input_data[1,:,0])
69 |         print("y: ", y)
70 |         print("------------------")
71 |         # print(input_data.shape)
72 |
--------------------------------------------------------------------------------
/test/test_so3_bch_time.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 | import math
8 | import time
9 |
10 | from core.lie_alg_util import *
11 | from core.lie_group_util import *
12 | from experiment.so3_bch_layers import *
13 |
14 | if __name__ == "__main__":
15 |
16 |
17 | num_test = 10000
18 |
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 | print('Using ', device)
21 |
22 | so3_hatlayer = HatLayer(algebra_type='so3')
23 |
24 | def gen_random_rotation_vector():
25 | v = torch.rand(1,3)
26 | v = v/torch.norm(v)
27 | phi = (math.pi-1e-6)*torch.rand(1)
28 | v = phi*v
29 | return v
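    # Timing caveat: CUDA kernels are launched asynchronously, so the
    # time.time() deltas below measure dispatch plus any forced sync rather
    # than exact kernel time; precise GPU timings would wrap each block in
    # torch.cuda.synchronize().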
30 | sum_t_first_order = 0
31 | sum_t_second_order = 0
32 | sum_t_third_order = 0
33 | sum_t_mlp = 0
34 | sum_t_LN = 0
35 |
36 | model = SO3EquivariantReluBracketLayers(2).to(device)
37 |
38 | # model = SO3EquivariantReluLayers(2).to(device)
39 | # model = SO3EquivariantBracketLayers(2).to(device)
40 | mlp_model = MLP(6).to(device)
41 |
42 | checkpoint = torch.load('/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_bracket_best_test_loss_acc.pt')
43 | model.load_state_dict(checkpoint['model_state_dict'],strict=False)
44 |
45 | checkpoint = torch.load('/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented_best_test_loss_acc.pt')
46 | mlp_model.load_state_dict(checkpoint['model_state_dict'],strict=False)
47 |
48 |     # time each BCH approximation and each learned model on random inputs
49 | for i in range(num_test):
50 | # generate random v1, v2
51 | v1 = gen_random_rotation_vector()
52 | v2 = gen_random_rotation_vector()
53 |
54 | K1 = so3_hatlayer(v1).to(device)
55 | K2 = so3_hatlayer(v2).to(device)
56 |
57 |
58 | t1 = time.time()
59 | out = BCH_first_order_approx(K1,K2)
60 | t2 = time.time()
61 | sum_t_first_order += t2-t1
62 |
63 | t1 = time.time()
64 | out = BCH_second_order_approx(K1,K2)
65 | t2 = time.time()
66 | sum_t_second_order += t2-t1
67 |
68 | t1 = time.time()
69 | out = BCH_third_order_approx(K1,K2)
70 | t2 = time.time()
71 | sum_t_third_order += t2-t1
72 |
73 | v1 = v1.to(device)
74 | v2 = v2.to(device)
75 | v = torch.cat((v1,v2),0)
76 |
77 | v = rearrange(v, 'n k -> 1 n k 1')
78 |
79 | t1 = time.time()
80 | out = model(v)
81 | t2 = time.time()
82 | sum_t_LN += t2-t1
83 |
84 | t1 = time.time()
85 | out = mlp_model(v)
86 | t2 = time.time()
87 | sum_t_mlp += t2-t1
88 |
89 |
90 | print("First order time: ", sum_t_first_order/num_test)
91 | print("Second order time: ", sum_t_second_order/num_test)
92 | print("Third order time: ", sum_t_third_order/num_test)
93 | print("LN time: ", sum_t_LN/num_test)
94 | print("MLP time: ", sum_t_mlp/num_test)
95 |
--------------------------------------------------------------------------------
/config/so3_bch/training_param.yaml:
--------------------------------------------------------------------------------
1 | train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_train_data.npz"
2 | # train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_augmented_train_data.npz"
3 | test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_test_data.npz"
4 | # train_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_approx_no_conj_train_data.npz"
5 | # test_data_path: "/home/justin/code/LieNeurons/data/so3_bch_data/so3_bch_10000_approx_no_conj_test_data.npz"
6 |
7 | # logger
8 | model_description: "SO3 BCH 10000"
9 | print_freq: 100
10 | # model_type: "LN_relu_bracket"
11 | # model_save_path: "/home/justin/code/LieNeurons/weights/0125_so3_bch_relu_bracket"
12 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0125_so3_bch_relu_bracket"
13 |
14 | # model_type: "LN_relu"
15 | # model_save_path: "/home/justin/code/LieNeurons/weights/0120_so3_bch_relu_4_layers_1024"
16 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0120_so3_bch_relu_4_layers_1024"
17 |
18 | # model_type: "LN_bracket"
19 | # model_save_path: "/home/justin/code/LieNeurons/weights/0124_so3_bch_bracket"
20 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0124_so3_bch_bracket"
21 |
22 | # model_type: "MLP"
23 | # model_save_path: "/home/justin/code/LieNeurons/weights/0327_so3_bch_mlp_4_layers_1024_augmented"
24 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0327_so3_bch_mlp_4_layers_1024_augmented"
25 |
26 | # model_type: "LN_bracket_no_residual"
27 | # model_save_path: "/home/justin/code/LieNeurons/weights/0928_equiv_one_bracket_no_residual"
28 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0928_equiv_one_bracket_no_residual"
29 |
30 | # model_type: "EMLP"
31 | # model_save_path: "/home/justin/code/LieNeurons/weights/0326_BCH_EMLP_512_test"
32 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0326_BCH_EMLP_512_test"
33 |
34 | # model_type: "e3nn_norm"
35 | # model_save_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_norm_1024"
36 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0404_BCH_e3nn_norm_1024"
37 |
38 | model_type: "e3nn_s2grid"
39 | model_save_path: "/home/justin/code/LieNeurons/weights/0404_BCH_e3nn_s2grid_small_lr"
40 | log_writer_path: "/home/justin/code/LieNeurons/logs/0404_BCH_e3nn_s2grid_small_lr"
41 |
42 | # model_type: "VN_relu"
43 | # model_save_path: "/home/justin/code/LieNeurons/weights/0327_BCH_VN"
44 | # log_writer_path: "/home/justin/code/LieNeurons/logs/0327_BCH_VN"
45 |
46 | resume_training: False
47 | resume_model_path: "/home/justin/code/LieNeurons/weights/0326_BCH_EMLP_512_test_best_test_loss_acc.pt"
48 |
49 | num_workers: 1
50 | batch_size: 100
51 | update_every_batch: 1
52 | initial_learning_rate: 0.0001
53 | # initial_learning_rate: 0.000003
54 | # initial_learning_rate: 0.00001
55 | # initial_learning_rate: 0.0005 # EMLP
56 | learning_rate_decay_rate: 0.0
57 | weight_decay_rate: 0.00
58 | num_epochs: 10000
59 | shuffle: True
60 | full_eval_during_training: False
61 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lie Neurons: Adjoint Equivariant Neural Networks for Semi-simple Lie Algebras
2 |
3 | Tzu-Yuan Lin\*¹ · Minghan Zhu\*¹ · Maani Ghaffari¹
4 |
5 | \*Equal contributions ¹University of Michigan, Ann Arbor
6 |
22 | ## About
23 | Lie Neurons is an MLP framework that takes Lie-algebraic data as input and is, by construction, equivariant to the adjoint representation of the group.
24 |
25 | 
26 |
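Concretely, adjoint equivariance means $f(\mathrm{Ad}_H x) = \mathrm{Ad}_H f(x)$ for every group element $H$. Below is a minimal, self-contained numerical sketch on $\mathfrak{so}(3)$ (it does not use this repo's layers; all names are illustrative) showing that linear maps which mix channels, never coordinates, commute with the adjoint action:

```python
import torch

def hat_so3(v):   # R^3 -> so(3)
    M = torch.zeros(3, 3)
    M[0, 1], M[0, 2], M[1, 2] = -v[2], v[1], -v[0]
    M[1, 0], M[2, 0], M[2, 1] = v[2], -v[1], v[0]
    return M

def vee_so3(M):   # so(3) -> R^3, inverse of hat_so3
    return torch.stack([M[2, 1], M[0, 2], M[1, 0]])

H = torch.linalg.matrix_exp(hat_so3(torch.randn(3)))        # random rotation
adj = lambda X: torch.stack([vee_so3(H @ hat_so3(x) @ H.T) for x in X])

W = torch.randn(2, 4)   # channel-mixing weights (4 channels -> 2 channels)
X = torch.randn(4, 3)   # 4 so(3) features in vee coordinates

print(torch.allclose(W @ adj(X), adj(W @ X), atol=1e-5))    # True
```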
27 | ## Modules
28 | 
29 |
30 | ## Updates
31 | * [07/2024] The initial code is open-sourced. We are still re-organizing it and plan to release a cleaner version soon. Feel free to reach out if you have any questions! :)
32 | * [07/2024] We presented our paper at ICML 24!
33 |
34 | ## Docker
35 | * We provide [docker](https://docs.docker.com/get-started/) files in [`docker/`](https://github.com/UMich-CURLY/LieNeurons/tree/main/docker).
36 | * A detailed tutorial on building each Docker container can be found in the README inside the corresponding docker folder.
37 |
38 | ## Training the Network
39 | * All training scripts for the experiments are in [`experiment/`](https://github.com/UMich-CURLY/LieNeurons/tree/main/experiment).
40 | * Before training, you will need to generate the data using the Python scripts in [`data_gen`](https://github.com/UMich-CURLY/LieNeurons/tree/main/data_gen).
41 | * Empirically, we found that a lower learning rate (around `3e-5`) helps convergence during training, likely because the network lacks normalization layers.
42 | * When working with $\mathfrak{so}(3)$, Lie Neurons specialize to [Vector Neurons](https://github.com/FlyingGiraffe/vnn) with an additional bracket nonlinearity and a channel-mixing layer. Since the inner product is well-defined on $\mathfrak{so}(3)$, one can plug in the batch normalization layers proposed in Vector Neurons to improve training stability, as sketched below.
43 |
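A minimal usage sketch of that batch norm (the `VNBatchNorm` layer and its input convention come from `core/vn_layers.py` in this repo; the channel count, batch shape, and placement are illustrative assumptions, and the script is assumed to run from the repo root):

```python
import torch
from core.vn_layers import VNBatchNorm

# so(3) features of shape [batch, channels, 3, samples]; VNBatchNorm
# batch-normalizes the per-channel vector norms while keeping directions
# intact, so the operation stays rotation equivariant.
bn = VNBatchNorm(num_features=32, dim=4)
x = torch.randn(8, 32, 3, 100)
x = bn(x)          # same shape, normalized norms
print(x.shape)     # torch.Size([8, 32, 3, 100])
```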
44 | ## Citation
45 | If you find the work useful, please kindly cite our paper:
46 | ```
47 | @inproceedings{lin2024ln,
48 | title={{Lie Neurons}: {Adjoint}-Equivariant Neural Networks for Semisimple {Lie} Algebras},
49 | author={Lin, Tzu-Yuan and Zhu, Minghan and Ghaffari, Maani},
50 | booktitle={International Conference on Machine Learning},
51 | pages={30529--30545},
52 | year={2024},
53 | organization={PMLR}
54 | }
55 | ```
56 |
--------------------------------------------------------------------------------
/core/vn_layers.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 |
6 | sys.path.append('.')
7 |
8 |
9 | EPS = 1e-6
10 |
11 | class VNLinear(nn.Module):
12 | def __init__(self, in_channels, out_channels):
13 | super(VNLinear, self).__init__()
14 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
15 |
16 | def forward(self, x):
17 | '''
18 | x: point features of shape [B, N_feat, 3, N_samples, ...]
19 | '''
20 | x_out = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1)
21 | return x_out
22 |
23 |
24 | class VNLeakyReLU(nn.Module):
25 | def __init__(self, in_channels, share_nonlinearity=False, leaky_relu=False, negative_slope=0.2):
26 | super(VNLeakyReLU, self).__init__()
27 |         if share_nonlinearity:
28 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
29 | else:
30 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
31 | self.negative_slope = negative_slope
32 | self.leaky_relu = leaky_relu
33 |
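    # Geometric note: wherever <x, d> < 0, the forward pass replaces x by its
    # projection onto the hyperplane orthogonal to the learned direction d,
    #     x - (<x, d> / ||d||^2) d,
    # a rotation-equivariant analogue of zeroing out negative activations.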
34 | def forward(self, x):
35 | '''
36 | x: point features of shape [B, N_feat, 3, N_samples, ...]
37 | '''
38 | d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1)
39 | dotprod = (x*d).sum(2, keepdim=True)
40 | mask = (dotprod >= 0).float()
41 | d_norm_sq = (d*d).sum(2, keepdim=True)
42 |
43 | if self.leaky_relu:
44 | x_out = self.negative_slope * x + (1-self.negative_slope) * (mask*x + (1-mask)*(x-(dotprod/(d_norm_sq+EPS))*d))
45 | else:
46 | x_out = torch.where(dotprod >= 0, x, x-(dotprod/(d_norm_sq+EPS))*d)
47 | return x_out
48 |
49 | class VNLinearAndLeakyReLU(nn.Module):
50 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, leaky_relu=False, use_batchnorm='norm', negative_slope=0.2):
51 | super(VNLinearAndLeakyReLU, self).__init__()
52 | self.dim = dim
53 | self.share_nonlinearity = share_nonlinearity
54 | self.use_batchnorm = use_batchnorm
55 | self.negative_slope = negative_slope
56 |
57 | self.linear = VNLinear(in_channels, out_channels)
58 | self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, negative_slope=negative_slope)
59 |
60 |         # BatchNorm
62 | if use_batchnorm != 'none':
63 | self.batchnorm = VNBatchNorm(out_channels, dim=dim)
64 |
65 | def forward(self, x):
66 | '''
67 | x: point features of shape [B, N_feat, 3, N_samples, ...]
68 | '''
69 |         # Linear
70 |         x = self.linear(x)
71 |         # BatchNorm
72 | if self.use_batchnorm != 'none':
73 | x = self.batchnorm(x)
74 | # LeakyReLU
75 | x_out = self.leaky_relu(x)
76 | return x_out
77 |
78 |
79 | class VNBatchNorm(nn.Module):
80 | def __init__(self, num_features, dim):
81 | super(VNBatchNorm, self).__init__()
82 | self.dim = dim
83 | if dim == 3 or dim == 4:
84 | self.bn = nn.BatchNorm1d(num_features)
85 | elif dim == 5:
86 | self.bn = nn.BatchNorm2d(num_features)
87 |
88 | def forward(self, x):
89 | '''
90 | x: point features of shape [B, N_feat, 3, N_samples, ...]
91 | '''
92 | norm = torch.sqrt((x*x).sum(2))
93 | norm_bn = self.bn(norm)
94 | norm = norm.unsqueeze(2)
95 | norm_bn = norm_bn.unsqueeze(2)
96 |         x = x / (norm + EPS) * norm_bn   # EPS guards zero-norm features
97 |
98 | return x
99 |
--------------------------------------------------------------------------------
/matlab/sl3_equivariant_lift.asv:
--------------------------------------------------------------------------------
1 | clc;
2 | clear;
3 |
4 | syms v1 v2 v3 v4 v5 v6 v7 v8
5 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real')
6 |
7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9
8 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real')
9 | %%
10 |
11 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
12 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
13 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
14 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
15 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
16 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
17 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
18 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
19 |
20 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0];
21 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0];
22 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0];
23 |
24 | Ea1 = [1,0,0;0,0,0;0,0,-1];
25 | Ea2 = [0,0,0;0,1,0;0,0,-1];
26 |
27 | Enx = [0,0,1;0,0,0;0,0,0];
28 | Eny = [0,0,0;0,0,1;0,0,0];
29 |
30 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
31 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
32 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
33 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1];
34 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
35 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0];
36 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
37 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
38 |
39 | E = {E1,E2,E3,E4,E5,E6,E7,E8};
40 |
41 | E1_vec = reshape(E1,1,[])';
42 | E2_vec = reshape(E2,1,[])';
43 | E3_vec = reshape(E3,1,[])';
44 | E4_vec = reshape(E4,1,[])';
45 | E5_vec = reshape(E5,1,[])';
46 | E6_vec = reshape(E6,1,[])';
47 | E7_vec = reshape(E7,1,[])';
48 | E8_vec = reshape(E8,1,[])';
49 |
50 |
51 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec];
52 |
53 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8;
54 |
55 |
56 | %% find h
57 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9];
58 | Ad_H_hat = H*x_hat*inv(H);
59 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])';
60 |
61 | % solve least square to obtain x
62 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec;
63 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
64 | [Ad_H_sym,b]=equationsToMatrix(x,var);
65 |
66 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9);
67 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym;
68 |
69 | %% find Ad_Ei
70 | Ad_E = {};
71 | dAd_E = {};
72 | for i=1:8
73 | % Find Ad_E_i
74 | H = expm(E{i}); % Using numerical exponential for now
75 | Ad_E{i} = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
76 | dAd_E{i} = logm(Ad_E{i});
77 | end
78 |
79 | %% construct dro_3
80 |
81 | dro_3 = {};
82 | C = [];
83 | for i =1:8
84 | dro_3{i} = kron(-E{i}',eye(8))+kron(eye(3),dAd_E{i});
85 | C = [C;dro_3{i}];
86 | end
87 |
88 | %% solve for the null space
89 | [U,S,V] = svd(C);
90 |
91 |
92 | Q = null(C);
93 |
94 | %%
95 | % rank(C,1e-10)
96 |
97 | %% test solving for one h
98 | a = E1+E2+E3+E4+E5+E6+E7+E8;
99 | % a = E4
100 | H = expm(a);
101 | % rnd = 100*rand(3);
102 | % x=rnd(1);
103 | % y=rnd(2);
104 | % z=rnd(3);
105 | % H = expm(x*Ex+y*Ey+z*Ez);
106 |
107 | % rnd = 100*randn(3);
108 | % n1=rnd(1);
109 | % n2=rnd(2);
110 | % n3=rnd(3);
111 | % H = [1 n1 n2; 0 1 n3; 0 0 1];
112 |
113 | % rnd = 2*randn(3);
114 | % a1=rnd(1);
115 | % a2=rnd(2);
116 | % H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2];
117 |
118 | % H = [0.5, 0,0; 0,2,0;0,0,1/0.5/2];
119 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
120 |
121 | D = kron(inv(H'),Ad_test)-eye(24);
122 | % D = kron(Ad_test,inv(H'))-eye(24);
123 | % [DU,DS,DV] = svd(D);
124 |
125 | DQ = null(D);
126 |
127 | rank(D)
128 |
129 | wd = DQ*DQ'*ones(24,1);
130 |
131 | Wd = reshape(wd,[8,3]);
132 |
133 | v_test = [2,3,1]';
134 | H_v_test = H*v_test;
135 |
136 | x_test = Wd*v_test;
137 | x_H_test = Wd*H_v_test;
138 | x_ad_test = Ad_test*x_test;
139 | disp([x_H_test,x_ad_test])
140 | % disp(x_H_test - x_ad_test)
141 |
142 | % rank([E1;E2;E3;E4;E5;E6;E7;E8])
143 |
144 | %%
145 |
--------------------------------------------------------------------------------
/playground/neural_ode_nonhomogeneous.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from scipy.signal import lti
7 | from scipy.signal import lsim2
8 | from scipy import interpolate
9 | from random import randint
10 | import matplotlib.pyplot as plt
11 | from torchdiffeq import odeint # modified version
12 |
13 | parser = argparse.ArgumentParser('Pendulum')
14 | parser.add_argument('--data_size', type=int, default=1000)
15 | parser.add_argument('--batch_time', type=int, default=10)
16 | parser.add_argument('--batch_size', type=int, default=20)
17 | parser.add_argument('--niters', type=int, default=300)
18 | parser.add_argument('--test_freq', type=int, default=20)
19 | args = parser.parse_args()
20 |
21 | # initial conditions / system dynamics
22 | x0 = [0]
23 | # definition of the LTI-system
24 | A = np.array([0.05])
25 | B = np.array([0.05]) # x' = 0.05 x + 0.05 u
26 | C = np.array([1])
27 | D = np.array([0]) # y = x
28 | system = lti(A, B, C, D)
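# For this scalar system x' = 0.05*x + 0.05*u with constant input u = 1 and
# x0 = 0, the exact validation trajectory is x(t) = exp(0.05*t) - 1
# (equilibrium at x = -1, so in general x(t) = (x0 + 1)*exp(0.05*t) - 1).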
29 |
30 | t = torch.linspace(0., 25., args.data_size)
31 | u = torch.Tensor(np.ones_like(t)) # external input -- validation
32 | tout, y, x = lsim2(system, u, t, x0)
33 | true_y = torch.FloatTensor(y)
34 | ufunc_val = interpolate.interp1d(t, u.reshape(-1,1), kind='linear', axis=0, bounds_error=False, fill_value='extrapolate')
35 | y0 = torch.FloatTensor(x0)
36 |
37 | def get_batch():
38 | u_in = torch.zeros(args.batch_time,args.batch_size)
39 | true_x = torch.zeros(args.batch_time,args.batch_size)
40 | batch_t = t[:args.batch_time] # (T)
41 | for i in range(args.batch_size):
42 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), 1, replace=False))
43 | u = torch.ones(args.batch_time)+randint(0,10)/5 # external input -- training
44 | tout, y2, x = lsim2(system, u, batch_t, y[s])
45 | tout = torch.tensor(tout)
46 | x = torch.tensor(x)
47 |         u_in[:,i] = u.reshape(args.batch_time)      # 1-D source for column assignment
48 |         true_x[:,i] = x.reshape(args.batch_time)
49 | batch_x0 = true_x[0,:]
50 | return batch_x0, batch_t, true_x, u_in
51 |
52 |
53 | def visualize(true_x, pred_x):
54 |     plt.title('Spring pendulum')
55 | plt.xlabel('t')
56 | plt.ylabel('x,v')
57 | plt.plot(t.numpy(), true_x.numpy())
58 | plt.plot(t.numpy(), pred_x.numpy(),'--')
59 | plt.xlim(t.min(), t.max())
60 | plt.tight_layout()
61 | plt.draw()
62 | plt.pause(0.001)
63 |
64 | class ODEFunc(nn.Module):
65 |
66 | def __init__(self):
67 | super(ODEFunc, self).__init__()
68 | self.linear = nn.Linear(2,1)
69 |
70 | def forward(self, t, x, args):
71 | ufun = args[0]
72 | unew = torch.FloatTensor(ufun(t.detach().numpy()))
73 | in1 = torch.stack((x, unew), dim=1)
74 | out = self.linear(in1).squeeze()
75 | return out
76 |
77 |
78 | if __name__ == '__main__':
79 |
80 | func = ODEFunc()
81 | optimizer = optim.RMSprop(func.parameters(), lr=0.02)
82 | lossb = nn.MSELoss()
83 | for itr in range(1, args.niters + 1):
84 | optimizer.zero_grad()
85 | batch_x0, batch_t, batch_x, batch_u = get_batch()
86 | ufunc = interpolate.interp1d(batch_t, batch_u, kind='linear', axis=0, bounds_error=False, fill_value='extrapolate')
87 | pred_x = odeint(func, batch_x0, batch_t, args=(ufunc,), method='dopri5')
88 | loss = lossb(pred_x, batch_x)
89 | loss.backward()
90 | optimizer.step()
91 | if itr % args.test_freq == 0:
92 | with torch.no_grad():
93 | pred_x = odeint(func, y0, t, args=(ufunc_val,), method='dopri5')
94 | loss = lossb(pred_x.squeeze(), true_y.squeeze())
95 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
96 | visualize(true_y.squeeze(), pred_x.squeeze())
--------------------------------------------------------------------------------
/core/lie_group_util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import torch
4 | sys.path.append('.')
5 |
6 |
7 | from einops import rearrange, repeat
8 | from einops.layers.torch import Rearrange
9 |
10 | from core.lie_alg_util import lie_bracket, vee_so3
11 |
12 | def skew(v):
13 |     # Returns the 3x3 skew-symmetric (hat) matrix of v:
14 |     #   [[ 0,   -v[2],  v[1]],
15 |     #    [ v[2],  0,   -v[0]],
16 |     #    [-v[1],  v[0],  0  ]]
17 |     v2 = v.clone()
18 |     M = torch.zeros((3,3))
19 |     M[0,1] = -v2[2]
20 |     M[0,2] = v2[1]
21 |     M[1,0] = v2[2]
22 |     M[1,2] = -v2[0]
23 |     M[2,0] = -v2[1]
24 |     M[2,1] = v2[0]
25 |     return M
27 |
28 |
29 | def exp_hat_and_so3(w):
30 | I = torch.eye(3)
31 | theta = torch.norm(w)
32 | A = skew(w)
33 | return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta*theta))*torch.matmul(A,A)
34 |
35 |
36 | # def bacth_exp_so3(A):
37 | # # A (B,3,3)
38 | # I = torch.eye(3).to(A.device)
39 | # theta = torch.sqrt(-torch.trace(torch.matmul(A,A))/2.0)
40 | # return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta*theta))*torch.matmul(A,A)
41 |
42 | def batch_trace(A):
43 | return A.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
44 |
45 | def exp_so3(A):
46 | I = torch.eye(3).to(A.device)
47 | # print("A: ",A.shape)
48 | theta = torch.sqrt(-batch_trace(torch.matmul(A,A))/2.0).reshape(-1,1,1)
49 | # print("theta: ", theta.shape)
50 | return I + (torch.sin(theta)/theta)*A + ((1-torch.cos(theta))/(theta**2))*torch.matmul(A,A)
51 |
52 | def log_SO3(R):
53 |
54 | # print("trace: ", torch.trace(R))
55 | # print("R",R)
56 | cos_R = (batch_trace(R)-1)/2.0
57 | # print("cos_R", cos_R)
58 | theta = torch.acos(torch.clamp(cos_R, -1+1e-7, 1-1e-7))
59 | # theta2 = torch.asin(torch.sqrt((3-batch_trace(R))*(1+batch_trace(R)))/2.0)
60 | # print("theta: ", theta)
61 | # print("theta2: ", theta2)
62 |
63 |     if torch.any(torch.isnan(theta)):
64 | raise ValueError("theta is nan")
65 | # return torch.zeros((3,3)).to(R.device)
66 | # if torch.isnan(theta):
67 | # return torch.zeros((3,3)).to(R.device)
68 | # print("theta: ", theta)
69 | # if theta - np.pi < 1e-6:
70 | # return theta/2/(np.pi-theta)*(R-R.T)
71 | # elif theta > np.pi:
72 | # theta = np.pi-theta
73 | # K = (theta/(2*torch.sin(theta)))[:,None,None]*(R-R.transpose(-1,-2))
74 |
75 | return (theta/(2*torch.sin(theta)))[:,None,None]*(R-R.transpose(-1,-2))
76 |
77 |
78 | def BCH_first_order_approx(X,Y):
79 | return X+Y
80 |
81 | def BCH_second_order_approx(X,Y):
82 | return X+Y+1/2*lie_bracket(X,Y)
83 |
84 | def BCH_third_order_approx(X,Y):
85 | return X+Y+1/2*lie_bracket(X,Y)+1/12*lie_bracket(X,lie_bracket(X,Y))-1/12*lie_bracket(Y,lie_bracket(X,Y))
86 |
87 | def BCH_approx(X,Y):
88 | return X+Y+1/2*lie_bracket(X,Y)+1/12*lie_bracket(X,lie_bracket(X,Y))-1/12*lie_bracket(Y,lie_bracket(X,Y))\
89 | -1/24*lie_bracket(Y,lie_bracket(X,lie_bracket(X,Y)))-1/720*lie_bracket(Y,lie_bracket(Y,lie_bracket(Y,lie_bracket(Y,X))))\
90 | -1/720*lie_bracket(X,lie_bracket(X,lie_bracket(X,lie_bracket(X,Y))))
91 |
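# The functions above truncate the Baker-Campbell-Hausdorff (BCH) series
#   log(exp(X)exp(Y)) = X + Y + 1/2 [X,Y] + 1/12 [X,[X,Y]] - 1/12 [Y,[X,Y]] - ...
# at increasing order (BCH_approx keeps selected higher-order terms as well).
# BCH_so3 below instead evaluates the closed-form composition available on
# so(3), written in terms of the rotation angles theta = |x|, phi = |y| and
# the angle delta between the two rotation axes.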
92 | def BCH_so3(X,Y):
93 | x = vee_so3(X)
94 | y = vee_so3(Y)
95 | theta = torch.norm(x)
96 | phi = torch.norm(y)
97 |     delta = torch.acos(x.transpose(-1,0)@y/torch.norm(x)/torch.norm(y)) # angle between x and y
98 | a = torch.sin(theta)*torch.cos(phi/2.0)*torch.cos(phi/2.0)-torch.sin(phi)*torch.sin(theta/2.0)*torch.sin(theta/2.0)*torch.cos(delta)
99 | b = torch.sin(phi)*torch.cos(theta/2.0)*torch.cos(theta/2.0)-torch.sin(theta)*torch.sin(phi/2.0)*torch.sin(phi/2.0)*torch.cos(delta)
100 | c = 1/2.0*torch.sin(theta)*torch.sin(phi)-2.0*torch.sin(theta/2.0)*torch.sin(theta/2.0)*torch.sin(phi/2.0)*torch.sin(phi/2.0)*torch.cos(delta)
101 | d = torch.sqrt(a*a+b*b+2.0*a*b*torch.cos(delta)+c*c*torch.sin(delta)*torch.sin(delta))
102 |
103 | alpha = torch.asin(d)*a/d/theta
104 | beta = torch.asin(d)*b/d/phi
105 | gamma = torch.asin(d)*c/d/theta/phi
106 |
107 | return alpha*X+beta*Y+gamma*lie_bracket(X,Y)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | # matlab
164 | *.asv
165 | *.fig
166 |
167 | weights/
168 | logs/
169 | data/
170 | logs_rebuttal/
171 | logs_archive/
172 | logs_good/
173 | png_ln_ode_demo/
174 | png_neural_ode_demo/
175 | figures/
176 | training_all.txt
177 | eval_results.txt
178 | eval_aug_mlp.txt
179 | training_mlp_aug_equiv.txt
180 |
--------------------------------------------------------------------------------
/matlab/sl3_equivariant_lift.m:
--------------------------------------------------------------------------------
1 | clc;
2 | clear;
3 |
4 | syms v1 v2 v3 v4 v5 v6 v7 v8
5 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real')
6 |
7 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9
8 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real')
9 | %%
10 |
11 | % E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
12 | % E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
13 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
14 | % E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
15 | % E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
16 | % E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
17 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
18 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
19 |
20 | Ekx = [0, 0, 0;0, 0, -1;0, 1, 0];
21 | Eky = [0, 0, 1;0, 0, 0;-1, 0, 0];
22 | Ekz = [0, -1, 0;1, 0, 0;0, 0, 0];
23 |
24 | Ea1 = [1,0,0;0,0,0;0,0,-1];
25 | Ea2 = [0,0,0;0,1,0;0,0,-1];
26 |
27 | Ea3 = [1,0,0;0,-1,0;0,0,0];
28 |
29 | Enx = [0,0,1;0,0,0;0,0,0];
30 | Eny = [0,0,0;0,0,1;0,0,0];
31 | Enz = [0,1,0;0,0,0;0,0,0];
32 |
33 | E1 = Ekx;
34 | E2 = Eky;
35 | E3 = Ekz;
36 | E4 = Ea1;
37 | E5 = Ea2;
38 | E6 = Enx;
39 | E7 = Eny;
40 | E8 = Enz;
41 |
42 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
43 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
44 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
45 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1];
46 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
47 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0];
48 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
49 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
50 |
51 | % E = {E1,E2,E3,E4,E5,E6,E7,E8};
52 | E = {Ekx,Eky,Ekz,Ea1,Ea2,Enx,Eny,Enz};
53 | % E = {Ekx,Eky,Ekz};
54 |
55 | E1_vec = reshape(E1,1,[])';
56 | E2_vec = reshape(E2,1,[])';
57 | E3_vec = reshape(E3,1,[])';
58 | E4_vec = reshape(E4,1,[])';
59 | E5_vec = reshape(E5,1,[])';
60 | E6_vec = reshape(E6,1,[])';
61 | E7_vec = reshape(E7,1,[])';
62 | E8_vec = reshape(E8,1,[])';
63 |
64 |
65 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec];
66 |
67 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8;
68 |
69 |
70 | %% find h
71 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9];
72 | Ad_H_hat = H*x_hat*inv(H);
73 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])';
74 |
75 | % solve least square to obtain x
76 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec;
77 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
78 | [Ad_H_sym,b]=equationsToMatrix(x,var);
79 |
80 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9);
81 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym;
82 |
83 | %% find Ad_Ei
84 | Ad_E = {};
85 | dAd_E = {};
86 | for i=1:size(E,2)
87 | % Find Ad_E_i
88 | H = expm(E{i}); % Using numerical exponential for now
89 | Ad_E{i} = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
90 | dAd_E{i} = logm(Ad_E{i});
91 | end
92 |
93 | %% construct dro_3
94 |
95 | dro_3 = {};
96 | C = [];
97 | for i =1:size(E,2)
98 | dro_3{i} = kron(-E{i}',eye(8))+kron(eye(3),dAd_E{i});
99 | C = [C;dro_3{i}];
100 | end
101 |
102 | %% solve for the null space
103 | [U,S,V] = svd(C);
104 |
105 |
106 | Q = null(C);
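% Each column of Q, reshaped to 8x3, is a candidate lift W satisfying
% W*(H*v) = Ad_H*(W*v) for all H, i.e., an intertwiner from the standard
% representation on R^3 to the adjoint representation on sl(3); the
% numerical check below verifies this for a random H.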
107 |
108 | %%
109 | H = expm(hat_sl3(rand(8,1)));
110 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
111 |
112 | % w = Q*Q'*rand(24,1)*10;
113 | w = Q*Q'*ones(24,1)*10;
114 |
115 | W = reshape(w,[8,3]);
116 |
117 | v_test = [2,3,1]';
118 | H_v_test = H*v_test;
119 |
120 | x_test = W*v_test;
121 | x_H_test = W*H_v_test;
122 | x_ad_test = Ad_test*x_test;
123 | disp([x_H_test,x_ad_test])
124 |
125 | %%
126 | % rank(C,1e-10)
127 |
128 | %% test solving for one h
129 | % % a = E1+E2+E3+E4+E5+E6+E7+E8;
130 | %
131 | % % a = Ekx+Eky+Ekz+Ea1+Ea2+Enx+Eny+Enz;
132 | %
133 | % a = Ekx+Eky+Ekz+Ea1+Ea2+Enx+Eny+Enz;
134 | % % a = Ea1+Ea2;
135 | % % a = E4
136 | % H = expm(a);
137 | % % rnd = 100*rand(3);
138 | % % x=rnd(1);
139 | % % y=rnd(2);
140 | % % z=rnd(3);
141 | % % H = expm(x*Ekx+y*Eky+z*Ekz);
142 | %
143 | % % rnd = 100*randn(3);
144 | % % n1=rnd(1);
145 | % % n2=rnd(2);
146 | % % n3=rnd(3);
147 | % % H = [1 n1 n2; 0 1 n3; 0 0 1];
148 | %
149 | % % rnd = 2*randn(3);
150 | % % a1=rnd(1);
151 | % % a2=rnd(2);
152 | % % H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2];
153 | %
154 | % % H = [0.5, 0,0; 0,2,0;0,0,1/0.5/2];
155 | % Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
156 | %
157 | % D = kron(inv(H'),Ad_test)-eye(24);
158 | % % D = kron(Ad_test,inv(H'))-eye(24);
159 | % % [DU,DS,DV] = svd(D);
160 | %
161 | % DQ = null(D);
162 | %
163 | % rank(D)
164 | %
165 | % wd = DQ*DQ'*ones(24,1);
166 | %
167 | % Wd = reshape(wd,[8,3]);
168 | %
169 | % v_test = [2,3,1]';
170 | % H_v_test = H*v_test;
171 | %
172 | % x_test = Wd*v_test;
173 | % x_H_test = Wd*H_v_test;
174 | % x_ad_test = Ad_test*x_test;
175 | % disp([x_H_test,x_ad_test])
176 | % disp(x_H_test - x_ad_test)
177 |
178 | % rank([E1;E2;E3;E4;E5;E6;E7;E8])
179 |
180 | %%
181 |
--------------------------------------------------------------------------------
/experiment/platonic_solid_cls_test.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | import time
12 | from scipy.spatial.transform import Rotation
13 |
14 | import torch
15 | from torch import nn
16 | import torch.optim as optim
17 | from torch.utils.tensorboard import SummaryWriter
18 | from torch.utils.data import Dataset, DataLoader, IterableDataset
19 |
20 | from core.lie_neurons_layers import *
21 | from experiment.platonic_solid_cls_layers import *
22 | from data_gen.gen_platonic_solids import *
23 |
24 |
25 | def test_perspective(model, test_loader, criterion, config, device):
26 | model.eval()
27 | hat_layer = HatLayer(algebra_type='sl3').to(device)
28 | rots = random_sample_rotations(config['num_rotations'], config['rotation_factor'],device)
29 | with torch.no_grad():
30 | loss_sum = 0.0
31 | num_correct = 0
32 | num_correct_non_conj = 0
33 |         for i, samples in tqdm(enumerate(test_loader, start=0)):
34 |
35 | x = samples[0].to(device)
36 | y = samples[1].to(device)
37 |
38 | x_hat = hat_layer(x)
39 | x = rearrange(x,'b n f k -> b f k n')
40 |
41 | output = model(x)
42 | _, prediction_non_conj = torch.max(output,1)
43 | num_correct_non_conj += (prediction_non_conj==y).sum().item()
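            # Equivariance check: conjugating the sl(3) inputs by a rotation,
            # x_rot^ = R x^ R^{-1}, realizes the adjoint action of R in
            # SO(3) < SL(3); an invariant classifier should keep its label.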
44 | for r in range(config['num_rotations']):
45 | cur_rot = rots[r,:,:]
46 | x_rot_hat = torch.matmul(cur_rot, torch.matmul(x_hat, torch.inverse(cur_rot)))
47 | x_rot = rearrange(vee_sl3(x_rot_hat),'b n f k -> b f k n')
48 | output_rot = model(x_rot)
49 |
50 | _, prediction = torch.max(output_rot,1)
51 | num_correct += (prediction==y).sum().item()
52 |
53 | loss = criterion(output_rot, y)
54 | loss_sum += loss.item()
55 |
56 | loss_avg = loss_sum/config['num_test']*config['batch_size']/config['num_rotations']
57 | acc_avg = num_correct/config['num_test']/config['num_rotations']
58 | acc_avg_non_conj = num_correct_non_conj/config['num_test']/1.0
59 | return loss_avg, acc_avg, acc_avg_non_conj
60 |
61 | def random_sample_rotations(num_rotation, rotation_factor: float = 1.0, device='cpu') -> torch.Tensor:
62 | r = np.zeros((num_rotation, 3, 3))
63 | for n in range(num_rotation):
64 | # angle_z, angle_y, angle_x
65 |         euler = np.random.rand(3) * np.pi * 2 / rotation_factor # (0, 2 * pi / rotation_factor)
66 | r[n,:,:] = Rotation.from_euler('zyx', euler).as_matrix()
67 | return torch.from_numpy(r).type('torch.FloatTensor').to(device)
68 |
69 |
70 |
71 |
72 | def main():
73 | # torch.autograd.set_detect_anomaly(True)
74 |
75 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76 | print('Using ', device)
77 |
78 | parser = argparse.ArgumentParser(description='Train the network')
79 | parser.add_argument('--training_config', type=str,
80 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/platonic_solid_cls/testing_param.yaml')
81 | args = parser.parse_args()
82 |
83 | # load yaml file
84 | config = yaml.safe_load(open(args.training_config))
85 |
86 |
87 | test_set = PlatonicDataset(config['num_test'])
88 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
89 | shuffle=config['shuffle'])
90 |
91 | if config['model_type'] == "LN_relu_bracket":
92 | model = LNReluBracketPlatonicSolidClassifier(3).to(device)
93 | elif config['model_type'] == "LN_relu":
94 | model = LNReluPlatonicSolidClassifier(3).to(device)
95 | elif config['model_type'] == "LN_bracket":
96 | model = LNBracketPlatonicSolidClassifier(3).to(device)
97 | elif config['model_type'] == "MLP":
98 | model = MLP(288).to(device)
99 | elif config['model_type'] == "LN_bracket_no_residual":
100 | model = LNBracketNoResidualConnectPlatonicSolidClassifier(3).to(device)
101 |
102 | print("Using model: ", config['model_type'])
103 | print("total number of parameters: ", sum(p.numel() for p in model.parameters()))
104 |
105 | checkpoint = torch.load(config['model_path'])
106 | model.load_state_dict(checkpoint['model_state_dict'],strict=False)
107 |
108 | criterion = nn.CrossEntropyLoss().to(device)
109 | test_loss, test_acc, test_acc_non_conj = test_perspective(model, test_loader, criterion, config, device)
110 | print("test loss: ", test_loss)
111 | print("test with conjugate acc: ", test_acc)
112 | print("test without conjugate acc: ", "{:.4f}".format(test_acc_non_conj))
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
--------------------------------------------------------------------------------
/data_gen/gen_sl3_inv_data.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.linalg import expm
7 |
8 | from core.lie_neurons_layers import *
9 |
10 |
11 | def invariant_function(x1, x2):
12 | return torch.sin(torch.trace(x1@x1))+torch.cos(torch.trace(x2@x2))\
13 | -torch.pow(torch.trace(x2@x2), 3)/2.0+torch.det(x1@x2)+torch.exp(torch.trace(x1@x1))
14 |
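# invariant_function is adjoint-invariant by construction: the trace and the
# determinant are conjugation-invariant, tr((H x H^{-1})^k) = tr(x^k) and
# det((H x1 H^{-1})(H x2 H^{-1})) = det(x1 x2), hence
# f(Ad_H x1, Ad_H x2) = f(x1, x2) for every H.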
15 |
16 | if __name__ == "__main__":
17 | data_saved_path = "data/sl3_inv_data/"
18 | data_name = "sl3_inv_10000_s_05_augmented"
19 | gen_augmented_training_data = True
20 | num_training = 5000
21 | num_testing = 10000
22 | num_conjugate = 1
23 | rnd_scale = 0.5
24 |
25 | train_data = {}
26 | test_data = {}
27 |
28 | hat_layer = HatLayer(algebra_type='sl3')
29 |
30 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
31 | .reshape(1, 1, 8, num_training)
32 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
33 | .reshape(1, 1, 8, num_training)
34 |
35 | # conjugate transformation
36 | h = torch.Tensor(np.random.uniform(-rnd_scale,
37 | rnd_scale, (num_conjugate, num_training, 8)))
38 | H = torch.linalg.matrix_exp(hat_layer(h))
39 | # conjugate x1
40 | x1_hat = hat_layer(x1.transpose(2, -1))
41 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
42 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c')
43 |
44 | # conjugate x2
45 | x2_hat = hat_layer(x2.transpose(2, -1))
46 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
47 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c')
48 |
49 | inv_output = torch.zeros((1, num_training, 1))
50 | # compute invariant function
51 | for n in range(num_training):
52 | inv_output[0, n, 0] = invariant_function(
53 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])
54 |
55 | # print("--------------invariant function output: ------------------")
56 | # print(invariant_function(x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]))
57 | # print(invariant_function(
58 | # conj_x1_hat[0, 0, n, :, :], conj_x2_hat[0, 0, n, :, :]))
59 |
60 | # print(x1.shape)
61 | # print(conj_x1.shape)
62 |     if gen_augmented_training_data:
63 |         # augmented training set: append the conjugated copies to the originals
64 | train_data['x1'] = torch.cat((x1.reshape(8, num_training),conj_x1.reshape(8, num_training*num_conjugate)),dim=1).numpy()
65 | train_data['x2'] = torch.cat((x2.reshape(8, num_training),conj_x2.reshape(8, num_training*num_conjugate)),dim=1).numpy()
66 | train_data['x1_conjugate'] = conj_x1.reshape(8, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy()
67 | train_data['x2_conjugate'] = conj_x2.reshape(8, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy()
68 | train_data['y'] = inv_output.reshape(1, num_training).repeat(1,num_conjugate+1).numpy()
69 | else:
70 | train_data['x1'] = x1.numpy().reshape(8, num_training)
71 | train_data['x2'] = x2.numpy().reshape(8, num_training)
72 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_training, num_conjugate)
73 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_training, num_conjugate)
74 | train_data['y'] = inv_output.numpy().reshape(1, num_training)
75 |
76 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data)
77 |
78 |
79 | '''
80 | Generate testing data
81 | '''
82 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
83 | .reshape(1, 1, 8, num_testing)
84 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
85 | .reshape(1, 1, 8, num_testing)
86 |
87 | # conjugate transformation
88 | h = torch.Tensor(np.random.uniform(-rnd_scale,
89 | rnd_scale, (num_conjugate, num_testing, 8)))
90 | H = torch.linalg.matrix_exp(hat_layer(h))
91 | # conjugate x1
92 | x1_hat = hat_layer(x1.transpose(2, -1))
93 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
94 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c')
95 |
96 | # conjugate x2
97 | x2_hat = hat_layer(x2.transpose(2, -1))
98 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
99 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c')
100 | inv_output = torch.zeros((1, num_testing, 1))
101 | # compute invariant function
102 | for n in range(num_testing):
103 | inv_output[0, n, 0] = invariant_function(
104 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])
105 |
106 | # for i in range(num_conjugate):
107 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :]))
108 |
109 | test_data['x1'] = x1.numpy().reshape(8, num_testing)
110 | test_data['x2'] = x2.numpy().reshape(8, num_testing)
111 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_testing, num_conjugate)
112 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_testing, num_conjugate)
113 | test_data['y'] = inv_output.numpy().reshape(1, num_testing)
114 |
115 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data)
116 |
117 | print("Done! Data saved to: \n", data_saved_path +
118 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz")
119 |
--------------------------------------------------------------------------------
/experiment/sl3_inv_test.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 | import time
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_neurons_layers import *
20 | from experiment.sl3_inv_layers import *
21 | from data_loader.sl3_inv_data_loader import *
22 |
23 |
24 | def test(model, test_loader, criterion, config, device):
25 | model.eval()
26 | with torch.no_grad():
27 | loss_sum = 0.0
28 | for i, sample in tqdm(enumerate(test_loader, start=0)):
29 | x = sample['x'].to(device)
30 | y = sample['y'].to(device)
31 |
32 | output = model(x)
33 |
34 | loss = criterion(output, y)
35 | loss_sum += loss.item()
36 |
37 | loss_avg = loss_sum/len(test_loader)
38 |
39 | return loss_avg
40 |
41 |
42 | def test_invariance(model, test_loader, criterion, config, device):
43 | model.eval()
44 | with torch.no_grad():
45 | loss_sum = 0.0
46 | loss_non_conj_sum = 0.0
47 | diff_output_sum = 0.0
48 | loss_all = []
49 | loss_non_conj_all = []
50 |
51 | for i, sample in tqdm(enumerate(test_loader, start=0)):
52 | x = sample['x'].to(device)
53 | x_conj = sample['x_conjugate'].to(device)
54 | y = sample['y'].to(device)
55 |
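            # Invariance metric: an invariant model satisfies f(x) = f(Ad_H x),
            # so diff_output below should be near zero for every conjugated
            # copy of the input.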
56 | output_x = model(x)
57 | loss_non_conj = criterion(output_x, y)
58 | loss_non_conj_sum += loss_non_conj.item()
59 | loss_non_conj_all.append(loss_non_conj.item())
60 | # print(output_x)
61 | # print(x_conj.shape)
62 | for j in range(x_conj.shape[1]):
63 | x_conj_j = x_conj[:, j, :, :, :]
64 | output_conj = model(x_conj_j)
65 | diff_output = output_x - output_conj
66 | loss = criterion(output_conj, y)
67 | loss_all.append(loss.item())
68 | loss_sum += loss.item()
69 | diff_output_sum += torch.sum(torch.abs(diff_output))
70 | # print(output_conj)
71 |
72 | # print("diff", diff_output[0,:])
73 | # print("conj_out", output_conj[0,:])
74 | # print("out",output_x[0,:])
75 | # print("y", y[0,:])
76 | # print(loss.item())
77 | # print("----------------------")
78 | # print(loss_all)
79 | # print(torch.Tensor(loss_all).shape)
80 | # loss_avg2 = torch.mean(torch.Tensor(loss_all))
81 | # print("loss_avg 2: ", loss_avg2)
82 | loss_std = torch.std(torch.Tensor(loss_all))
83 | loss_non_conj_std = torch.std(torch.Tensor(loss_non_conj_all))
84 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1]
85 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1]
86 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader)
87 |
88 | return loss_avg, loss_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg
89 |
90 | def main():
91 | # torch.autograd.set_detect_anomaly(True)
92 |
93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94 | print('Using ', device)
95 |
96 | parser = argparse.ArgumentParser(description='Train the network')
97 | parser.add_argument('--test_config', type=str,
98 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_inv/testing_param.yaml')
99 | args = parser.parse_args()
100 |
101 | # load yaml file
102 | config = yaml.safe_load(open(args.test_config))
103 |
104 | test_set = sl3InvDataSet2Input(config['test_data_path'], device=device)
105 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
106 | shuffle=config['shuffle'])
107 |
108 |
109 | if config['model_type'] == "LN_relu_bracket":
110 | model = SL3InvariantReluBracketLayers(2).to(device)
111 | elif config['model_type'] == "LN_relu":
112 | model = SL3InvariantReluLayers(2).to(device)
113 | elif config['model_type'] == "LN_bracket":
114 | model = SL3InvariantBracketLayers(2).to(device)
115 | elif config['model_type'] == "MLP":
116 | model = MLP(16).to(device)
117 | elif config['model_type'] == "LN_bracket_no_residual":
118 | model = SL3InvariantBracketNoResidualConnectLayers(2).to(device)
119 |
120 | print("Using model: ", config['model_type'])
121 | print("total number of parameters: ", sum(p.numel() for p in model.parameters()))
122 | checkpoint = torch.load(config['model_path'])
123 | model.load_state_dict(checkpoint['model_state_dict'],strict=False)
124 |
125 | criterion = nn.MSELoss().to(device)
126 | # test_loss = test(model, test_loader, criterion, config, device)
127 | test_loss_inv, test_loss_inv_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg = test_invariance(model, test_loader, criterion, config, device)
128 |
129 |     print("test loss: ", test_loss_inv)
130 |     print("test loss std: ", test_loss_inv_std)
132 | print("avg diff output: ", diff_output_avg)
133 | print("loss non conj: ", loss_non_conj_avg)
134 | print("loss non conj std: ", loss_non_conj_std)
135 |
136 | if __name__ == "__main__":
137 | main()
138 |
--------------------------------------------------------------------------------
/data_gen/gen_sp4_inv_data.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.linalg import expm
7 |
8 | from core.lie_neurons_layers import *
9 |
10 |
11 | def invariant_function(x1, x2):
12 | return torch.sin(torch.trace(x1@x1))+torch.cos(torch.trace(x2@x2))\
13 | -torch.pow(torch.trace(x2@x2), 3)/2.0+torch.det(x1@x2)+torch.exp(torch.trace(x1@x1))
14 |
15 |
16 | if __name__ == "__main__":
17 | data_saved_path = "data/sp4_inv_data/"
18 | data_name = "sp4_inv_10000_s_05_augmented"
19 | gen_augmented_training_data = True
20 | num_training = 5000
21 | num_testing = 10000
22 | num_conjugate = 1
23 | rnd_scale = 0.5
24 |
25 | train_data = {}
26 | test_data = {}
27 |
28 | hat_layer = HatLayer(algebra_type='sp4')
29 |
30 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_training)))\
31 | .reshape(1, 1, 10, num_training)
32 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_training)))\
33 | .reshape(1, 1, 10, num_training)
34 |
35 | # conjugate transformation
36 | h = torch.Tensor(np.random.uniform(-rnd_scale,
37 | rnd_scale, (num_conjugate, num_training, 10)))
38 | H = torch.linalg.matrix_exp(hat_layer(h))
39 | # conjugate x1
40 | x1_hat = hat_layer(x1.transpose(2, -1))
41 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
42 | conj_x1 = rearrange(vee_sp4(conj_x1_hat), 'b c t l -> b l t c')
43 |
44 | # conjugate x2
45 | x2_hat = hat_layer(x2.transpose(2, -1))
46 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
47 | conj_x2 = rearrange(vee_sp4(conj_x2_hat), 'b c t l -> b l t c')
48 |
49 | inv_output = torch.zeros((1, num_training, 1))
50 | # compute invariant function
51 | for n in range(num_training):
52 | inv_output[0, n, 0] = invariant_function(
53 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])
54 |
55 | # print("--------------invariant function output: ------------------")
56 | # print(invariant_function(x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :]))
57 | # print(invariant_function(
58 | # conj_x1_hat[0, 0, n, :, :], conj_x2_hat[0, 0, n, :, :]))
59 |
60 | # print(x1.shape)
61 | # print(conj_x1.shape)
62 |     if gen_augmented_training_data:
63 |         # augmented training set: append the conjugated copies to the originals
64 | train_data['x1'] = torch.cat((x1.reshape(10, num_training),conj_x1.reshape(10, num_training*num_conjugate)),dim=1).numpy()
65 | train_data['x2'] = torch.cat((x2.reshape(10, num_training),conj_x2.reshape(10, num_training*num_conjugate)),dim=1).numpy()
66 | train_data['x1_conjugate'] = conj_x1.reshape(10, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy()
67 | train_data['x2_conjugate'] = conj_x2.reshape(10, num_training, num_conjugate).repeat(1,num_conjugate+1,1).numpy()
68 | train_data['y'] = inv_output.reshape(1, num_training).repeat(1,num_conjugate+1).numpy()
69 | else:
70 | train_data['x1'] = x1.numpy().reshape(10, num_training)
71 | train_data['x2'] = x2.numpy().reshape(10, num_training)
72 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(10, num_training, num_conjugate)
73 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(10, num_training, num_conjugate)
74 | train_data['y'] = inv_output.numpy().reshape(1, num_training)
75 |
76 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data)
77 |
78 |
79 | '''
80 | Generate testing data
81 | '''
82 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_testing)))\
83 | .reshape(1, 1, 10, num_testing)
84 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 10, num_testing)))\
85 | .reshape(1, 1, 10, num_testing)
86 |
87 | # conjugate transformation
88 | h = torch.Tensor(np.random.uniform(-rnd_scale,
89 | rnd_scale, (num_conjugate, num_testing, 10)))
90 | H = torch.linalg.matrix_exp(hat_layer(h))
91 | # conjugate x1
92 | x1_hat = hat_layer(x1.transpose(2, -1))
93 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
94 | conj_x1 = rearrange(vee_sp4(conj_x1_hat), 'b c t l -> b l t c')
95 |
96 | # conjugate x2
97 | x2_hat = hat_layer(x2.transpose(2, -1))
98 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
99 | # print(conj_x2_hat.shape)
100 | conj_x2 = rearrange(vee_sp4(conj_x2_hat), 'b c t l -> b l t c')
101 | inv_output = torch.zeros((1, num_testing, 1))
102 | # compute invariant function
103 | for n in range(num_testing):
104 | inv_output[0, n, 0] = invariant_function(
105 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :])
106 |
107 | # for i in range(num_conjugate):
108 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :]))
109 |
110 | test_data['x1'] = x1.numpy().reshape(10, num_testing)
111 | test_data['x2'] = x2.numpy().reshape(10, num_testing)
112 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(10, num_testing, num_conjugate)
113 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(10, num_testing, num_conjugate)
114 | test_data['y'] = inv_output.numpy().reshape(1, num_testing)
115 |
116 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data)
117 |
118 | print("Done! Data saved to: \n", data_saved_path +
119 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz")
120 |
--------------------------------------------------------------------------------
/data_loader/sl3_inv_data_loader.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from scipy.linalg import expm
8 | from tqdm import tqdm
9 |
10 |
11 | from einops import rearrange, repeat
12 | from einops.layers.torch import Rearrange
13 |
14 | from core.lie_neurons_layers import *
15 |
16 |
17 | class sl3InvDataSet2Input(Dataset):
18 | def __init__(self, data_path, device='cuda'):
19 | data = np.load(data_path)
20 | _, num_points = data['x1'].shape
21 | _,_,num_conjugate = data['x1_conjugate'].shape
22 | self.x1 = rearrange(torch.from_numpy(data['x1']).type(
23 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
24 | self.x2 = rearrange(torch.from_numpy(data['x2']).type(
25 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
26 | self.x = torch.cat((self.x1, self.x2), dim=1)
27 |
28 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type(
29 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
30 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type(
31 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
32 | self.x_conjugate = torch.cat(
33 | (self.x1_conjugate, self.x2_conjugate), dim=2)
34 | self.y = torch.from_numpy(data['y']).type(
35 | 'torch.FloatTensor').to(device).reshape(num_points, 1)
36 |
37 | self.num_data = self.x1.shape[0]
38 |
39 | def __len__(self):
40 | return self.num_data
41 |
42 | def __getitem__(self, idx):
43 | if torch.is_tensor(idx):
44 | idx = idx.tolist()
45 |
46 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :], 'x': self.x[idx, :, :, :],
47 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :],
48 | 'x_conjugate': self.x_conjugate[:,idx, :, :, :], 'y': self.y[idx, :]}
49 | return sample
50 |
51 |
52 | class sl3InvDataSet5Input(Dataset):
53 | def __init__(self, data_path, device='cuda'):
54 | data = np.load(data_path)
55 | _, num_points = data['x1'].shape
56 | _,_,num_conjugate = data['x1_conjugate'].shape
57 | self.x1 = rearrange(torch.from_numpy(data['x1']).type(
58 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
59 | self.x2 = rearrange(torch.from_numpy(data['x2']).type(
60 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
61 | self.x3 = rearrange(torch.from_numpy(data['x3']).type(
62 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
63 | self.x4 = rearrange(torch.from_numpy(data['x4']).type(
64 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
65 | self.x5 = rearrange(torch.from_numpy(data['x5']).type(
66 | 'torch.FloatTensor').to(device), 'k n -> n 1 k 1')
67 | self.x = torch.cat((self.x1, self.x2, self.x3, self.x4, self.x5), dim=1)
68 |
69 | self.x1_conjugate = rearrange(torch.from_numpy(data['x1_conjugate']).type(
70 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
71 | self.x2_conjugate = rearrange(torch.from_numpy(data['x2_conjugate']).type(
72 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
73 | self.x3_conjugate = rearrange(torch.from_numpy(data['x3_conjugate']).type(
74 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
75 | self.x4_conjugate = rearrange(torch.from_numpy(data['x4_conjugate']).type(
76 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
77 | self.x5_conjugate = rearrange(torch.from_numpy(data['x5_conjugate']).type(
78 | 'torch.FloatTensor').to(device),'k n c -> c n 1 k 1')
79 | self.x_conjugate = torch.cat(
80 | (self.x1_conjugate, self.x2_conjugate,self.x3_conjugate,self.x4_conjugate,self.x5_conjugate), dim=2)
81 | self.y = torch.from_numpy(data['y']).type(
82 | 'torch.FloatTensor').to(device).reshape(num_points, 1)
83 |
84 | self.num_data = self.x1.shape[0]
85 |
86 | def __len__(self):
87 | return self.num_data
88 |
89 | def __getitem__(self, idx):
90 | if torch.is_tensor(idx):
91 | idx = idx.tolist()
92 |
93 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :],'x3': self.x3[idx, :, :, :],\
94 | 'x4': self.x4[idx, :, :, :],'x5': self.x5[idx, :, :, :], 'x': self.x[idx, :, :, :],
95 | 'x1_conjugate': self.x1_conjugate[:,idx, :, :, :], 'x2_conjugate': self.x2_conjugate[:,idx, :, :, :],\
96 | 'x3_conjugate': self.x3_conjugate[:,idx, :, :, :],'x4_conjugate': self.x4_conjugate[:,idx, :, :, :],\
97 |                   'x5_conjugate': self.x5_conjugate[:,idx, :, :, :], 'x_conjugate': self.x_conjugate[:,idx, :, :, :],\
98 | 'y': self.y[idx, :]}
99 | return sample
100 |
101 | if __name__ == "__main__":
102 |
103 |     dataset = sl3InvDataSet5Input("data/sl3_inv_5_input_data/sl3_inv_1000_s_05_train_data.npz")
104 | 
105 |     print(dataset.x1.shape)
106 |     print(dataset.x2.shape)
107 |     print(dataset.x1_conjugate.shape)
108 |     print(dataset.x2_conjugate.shape)
109 |     print(dataset.x.shape)
110 |     print(dataset.x_conjugate.shape)
111 |     print(dataset.y.shape)
112 |     for i, samples in tqdm(enumerate(dataset, start=0)):
113 | input_data = samples['x']
114 | y = samples['y']
115 | print(y)
116 | # print(input_data.shape)
117 |
--------------------------------------------------------------------------------
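A typical consumption pattern for the loaders above, shown for `sl3InvDataSet2Input`; the .npz path is a placeholder for a file produced by `data_gen/gen_sl3_inv_data.py`, and the batch shapes follow the rearrange patterns in `__init__` (k = 8 for sl(3)):

    from torch.utils.data import DataLoader

    train_set = sl3InvDataSet2Input("data/sl3_inv_data/sl3_inv_train_data.npz",  # placeholder path
                                    device='cpu')
    loader = DataLoader(train_set, batch_size=32, shuffle=True)

    for batch in loader:
        x = batch['x']                 # [32, 2, 8, 1]: two sl(3) vectors per sample
        x_conj = batch['x_conjugate']  # [32, num_conjugate, 2, 8, 1] after default collation
        y = batch['y']                 # [32, 1]: scalar invariant target
        break
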
/experiment/sp4_inv_test.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_neurons_layers import *
20 | from experiment.sp4_inv_layers import *
21 | from data_loader.sp4_inv_data_loader import *
22 |
23 |
24 | def test(model, test_loader, criterion, config, device):
25 | model.eval()
26 | with torch.no_grad():
27 | loss_sum = 0.0
28 | for i, sample in tqdm(enumerate(test_loader, start=0)):
29 | x = sample['x'].to(device)
30 | y = sample['y'].to(device)
31 |
32 | output = model(x)
33 |
34 | loss = criterion(output, y)
35 | loss_sum += loss.item()
36 |
37 | loss_avg = loss_sum/len(test_loader)
38 |
39 | return loss_avg
40 |
41 |
42 | def test_invariance(model, test_loader, criterion, config, device):
43 | model.eval()
44 | with torch.no_grad():
45 | loss_sum = 0.0
46 | loss_non_conj_sum = 0.0
47 | diff_output_sum = 0.0
48 | loss_all = []
49 | loss_non_conj_all = []
50 |
51 | for i, sample in tqdm(enumerate(test_loader, start=0)):
52 | x = sample['x'].to(device)
53 | x_conj = sample['x_conjugate'].to(device)
54 | y = sample['y'].to(device)
55 |
56 | output_x = model(x)
57 | loss_non_conj = criterion(output_x, y)
58 | loss_non_conj_sum += loss_non_conj.item()
59 | loss_non_conj_all.append(loss_non_conj.item())
60 | # print(output_x)
61 | # print(x_conj.shape)
62 | for j in range(x_conj.shape[1]):
63 | x_conj_j = x_conj[:, j, :, :, :]
64 | output_conj = model(x_conj_j)
65 | diff_output = output_x - output_conj
66 | loss = criterion(output_conj, y)
67 | loss_all.append(loss.item())
68 | loss_sum += loss.item()
69 | diff_output_sum += torch.sum(torch.abs(diff_output))
70 | # print(output_conj)
71 |
72 | # print("diff", diff_output[0,:])
73 | # print("conj_out", output_conj[0,:])
74 | # print("out",output_x[0,:])
75 | # print("y", y[0,:])
76 | # print(loss.item())
77 | # print("----------------------")
78 | # print(loss_all)
79 | # print(torch.Tensor(loss_all).shape)
80 | # loss_avg2 = torch.mean(torch.Tensor(loss_all))
81 | # print("loss_avg 2: ", loss_avg2)
82 | loss_std = torch.std(torch.Tensor(loss_all))
83 | loss_non_conj_std = torch.std(torch.Tensor(loss_non_conj_all))
84 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1]
85 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1]
86 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader)
87 |
88 | return loss_avg, loss_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg
89 |
90 | def main():
91 | # torch.autograd.set_detect_anomaly(True)
92 |
93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94 | print('Using ', device)
95 |
96 |     parser = argparse.ArgumentParser(description='Test the network')
97 | parser.add_argument('--test_config', type=str,
98 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sp4_inv/testing_param.yaml')
99 | args = parser.parse_args()
100 |
101 | # load yaml file
102 | config = yaml.safe_load(open(args.test_config))
103 |
104 | test_set = sp4InvDataSet(config['test_data_path'], device=device)
105 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
106 | shuffle=config['shuffle'])
107 |
108 |
109 | if config['model_type'] == "LN_relu_bracket":
110 | model = SP4InvariantReluBracketLayers(2).to(device)
111 | elif config['model_type'] == "LN_relu":
112 | model = SP4InvariantReluLayers(2).to(device)
113 | elif config['model_type'] == "LN_bracket":
114 | model = SP4InvariantBracketLayers(2).to(device)
115 | elif config['model_type'] == "MLP":
116 | model = MLP(20).to(device)
117 | elif config['model_type'] == "MLP512":
118 | model = MLP512(20).to(device)
119 | elif config['model_type'] == "LN_bracket_no_residual":
120 | model = SP4InvariantBracketNoResidualConnectLayers(2).to(device)
121 | else:
122 | raise ValueError("model type not supported")
123 |
124 | print("Using model: ", config['model_type'])
125 | print("total number of parameters: ", sum(p.numel() for p in model.parameters()))
126 | checkpoint = torch.load(config['model_path'])
127 | model.load_state_dict(checkpoint['model_state_dict'],strict=False)
128 |
129 | criterion = nn.MSELoss().to(device)
130 | # test_loss = test(model, test_loader, criterion, config, device)
131 | test_loss_inv, test_loss_inv_std, loss_non_conj_avg, loss_non_conj_std, diff_output_avg = test_invariance(model, test_loader, criterion, config, device)
132 |
133 | print("test_loss type:",type(test_loss_inv))
134 | print("test loss: ", test_loss_inv)
135 | print("test loss std", test_loss_inv_std)
136 | print("avg diff output: ", diff_output_avg)
137 | print("loss non conj: ", loss_non_conj_avg)
138 | print("loss non conj std: ", loss_non_conj_std)
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
--------------------------------------------------------------------------------
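One portability note on the checkpoint restore in `main()`: `torch.load(config['model_path'])` maps tensors back to the device they were saved from, so a GPU-trained checkpoint fails to load on a CPU-only machine. A defensive variant, reusing the names already defined in the script above:

    checkpoint = torch.load(config['model_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("restored epoch:", checkpoint['epoch'], "train loss:", checkpoint['loss'])
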
/data_loader/so3_bch_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | import sys # nopep8
3 | sys.path.append('.') # nopep8
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | from scipy.linalg import expm
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | from einops import rearrange, repeat
14 | from einops.layers.torch import Rearrange
15 |
16 | from core.lie_neurons_layers import *
17 | from core.lie_group_util import *
18 |
19 | class so3BchDataSet(Dataset):
20 | def __init__(self, data_path, device='cuda'):
21 | data = np.load(data_path)
22 | num_points, _ = data['x1'].shape
23 |
24 | self.x1 = rearrange(torch.from_numpy(data['x1']).type(
25 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1')
26 | self.x2 = rearrange(torch.from_numpy(data['x2']).type(
27 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1')
28 | self.x = torch.cat((self.x1, self.x2), dim=1) # n 2 k 1
29 |
30 | self.y = torch.from_numpy(data['y']).type(
31 | 'torch.FloatTensor').to(device) # [N,3]
32 | # print(self.y.shape)
33 | self.num_data = self.x1.shape[0]
34 |
35 | def __len__(self):
36 | return self.num_data
37 |
38 | def __getitem__(self, idx):
39 | if torch.is_tensor(idx):
40 | idx = idx.tolist()
41 |
42 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :],
43 | 'x': self.x[idx, :, :, :], 'y': self.y[idx, :]}
44 | return sample
45 |
46 | class so3BchTestDataSet(Dataset):
47 | '''
48 |     Test dataset that additionally contains adjoint-augmented information
49 | '''
50 | def __init__(self, data_path, device='cuda'):
51 | data = np.load(data_path)
52 | num_points, _ = data['x1'].shape
53 |
54 | self.x1 = rearrange(torch.from_numpy(data['x1']).type(
55 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1')
56 | self.x2 = rearrange(torch.from_numpy(data['x2']).type(
57 | 'torch.FloatTensor').to(device), 'n k -> n 1 k 1')
58 | self.x = torch.cat((self.x1, self.x2), dim=1) # n 2 k 1
59 | self.x1_conj = rearrange(torch.from_numpy(data['x1_conjugate']).type(
60 | 'torch.FloatTensor').to(device), 'n c k -> c n 1 k 1')
61 | self.x2_conj = rearrange(torch.from_numpy(data['x2_conjugate']).type(
62 | 'torch.FloatTensor').to(device), 'n c k -> c n 1 k 1')
63 | self.x_conj = torch.cat((self.x1_conj, self.x2_conj), dim=2) # c n 2 k 1
64 | self.y = torch.from_numpy(data['y']).type(
65 | 'torch.FloatTensor').to(device) # [N,3]
66 | self.y_conj = rearrange(torch.from_numpy(data['y_conj']).type(
67 | 'torch.FloatTensor').to(device), 'n c k -> c n k') # [c, N,3]
68 | self.R = rearrange(torch.from_numpy(data['R_aug']).type(
69 | 'torch.FloatTensor').to(device), 'n c k1 k2 -> c n k1 k2') # [c, N,3,3]
70 | # print(self.y.shape)
71 | self.num_data = self.x1.shape[0]
72 |
73 | def __len__(self):
74 | return self.num_data
75 |
76 | def __getitem__(self, idx):
77 | if torch.is_tensor(idx):
78 | idx = idx.tolist()
79 |
80 | sample = {'x1': self.x1[idx, :, :, :], 'x2': self.x2[idx, :, :, :],\
81 | 'x': self.x[idx, :, :, :], 'y': self.y[idx, :],\
82 | 'x1_conj': self.x1_conj[:,idx, :, :, :], 'x2_conj': self.x2_conj[:,idx, :, :, :],\
83 | 'x_conj': self.x_conj[:,idx,:,:,:],'y_conj': self.y_conj[:,idx, :], \
84 | 'R': self.R[:,idx, :, :]
85 | }
86 | return sample
87 |
88 |
89 | if __name__ == "__main__":
90 |
91 |     dataset = so3BchDataSet(
92 |         "data/so3_bch_data/so3_bch_10000_4_train_data.npz")
93 | 
94 |     hat = HatLayer(algebra_type='so3').to('cuda')
95 |     # print(dataset.x1.shape)
96 |     # print(dataset.x2.shape)
97 |     # print(dataset.x.shape)
98 |     # print(dataset.y.shape)
99 | 
100 |     num_mismatch = 0
101 |     sum_inf = 0
102 |     sum_nan = 0
103 |     sum_y_nan = 0
104 |     sum_y_inf = 0
105 |     v3_list = torch.zeros((dataset.num_data, 3))
106 |     for i, samples in tqdm(enumerate(dataset, start=0)):
107 | input_data = samples['x'].to('cuda')
108 | R1 = exp_so3(hat(input_data[0, :, :].squeeze(-1)))
109 | R2 = exp_so3(hat(input_data[1, :, :].squeeze(-1)))
110 | v3 = vee(log_SO3(torch.matmul(R1,R2)),algebra_type='so3')
111 | v3_list[i,:] = v3
112 | # print("v1", input_data[0, :, :].squeeze(-1))
113 | # print("v2", input_data[1, :, :].squeeze(-1))
114 | # print("v3", v3)
115 | # print("norm v3", torch.norm(v3))
116 | y = samples['y'].to('cuda')
117 |
118 | # print(torch.trace(hat(y)))
119 | # print(float(torch.norm(v3)))
120 | if(torch.isinf(input_data).any()):
121 | sum_inf += 1
122 |
123 | if(torch.isnan(input_data).any()):
124 | sum_nan += 1
125 |
126 | if(torch.isinf(y).any()):
127 | sum_y_inf += 1
128 |
129 | if(torch.isnan(y).any()):
130 | sum_y_nan += 1
131 |
132 | # print("norm", torch.norm(v3-y))
133 |
134 | if(torch.norm(v3-y) > 1e-3):
135 | # print("\nv3", v3)
136 | # print("y", y)
137 | # print("diff", v3-y)
138 | # print("norm", torch.norm(v3-y))
139 | # print("-------------")
140 |             num_mismatch += 1
141 |
142 | # diff_norm = torch.norm(v3-y)
143 |
144 | # print(y)
145 | # print(input_data.shape)
146 |
147 |     # plot a histogram of the first component of v3
148 | plt.hist(v3_list[:,0].cpu().numpy(), bins=100)
149 | plt.show()
150 |
151 |     print("error > 1e-3", num_mismatch)
152 | print("input inf num", sum_inf)
153 | print("input nan num", sum_nan)
154 | print("y inf num", sum_y_inf)
155 | print("y nan num", sum_y_nan)
--------------------------------------------------------------------------------
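The sanity loop in `__main__` recomputes y = (log(exp(x̂₁) exp(x̂₂)))ˇ with the repository's `exp_so3`/`log_SO3` helpers. The same check can be reproduced with SciPy alone; this standalone sketch assumes the standard so(3) hat/vee convention:

    import numpy as np
    from scipy.linalg import expm, logm

    def hat(v):
        return np.array([[0.0, -v[2], v[1]],
                         [v[2], 0.0, -v[0]],
                         [-v[1], v[0], 0.0]])

    def vee(M):
        return np.array([M[2, 1], M[0, 2], M[1, 0]])

    v1 = np.array([0.3, -0.2, 0.1])
    v2 = np.array([-0.1, 0.4, 0.2])
    v3 = vee(np.real(logm(expm(hat(v1)) @ expm(hat(v2)))))
    print(v3)  # the target y that gen_so3_bch.py would store for this (v1, v2)
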
/data_gen/gen_so3_bch.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import scipy.linalg
6 | import torch
7 | import math
8 |
9 | from core.lie_alg_util import *
10 | from core.lie_group_util import *
11 |
12 | if __name__ == "__main__":
13 |
14 | data_saved_path = "data/so3_bch_data/"
15 | data_name = "so3_bch_10000_augmented"
16 | gen_augmented_training_data = False
17 |
18 | num_training = 5000
19 | num_testing = 5000
20 | num_conjugate = 1
21 |
22 | train_data = {}
23 | test_data = {}
24 | so3_hatlayer = HatLayer(algebra_type='so3')
25 |
26 | def gen_random_rotation_vector():
27 | v = torch.rand(1,3)
28 | v = v/torch.norm(v)
29 | phi = (math.pi-1e-6)*torch.rand(1)
30 | v = phi*v
31 | return v
32 |
33 | x1 = torch.zeros((num_training,3))
34 | x2 = torch.zeros((num_training,3))
35 | y = torch.zeros((num_training,3))
36 |
37 | # if we want to generate adjoint augmented training data
38 | if(gen_augmented_training_data):
39 | x1_conj = torch.zeros((num_training, num_conjugate,3))
40 | x2_conj = torch.zeros((num_training, num_conjugate,3))
41 | y_conj = torch.zeros((num_training, num_conjugate,3))
42 | R_aug = torch.zeros((num_training, num_conjugate,3,3))
43 |
44 | # generate training data
45 | for i in range(num_training):
46 | # generate random v1, v2
47 | v1 = gen_random_rotation_vector()
48 | v2 = gen_random_rotation_vector()
49 |
50 | K1 = so3_hatlayer(v1)
51 | K2 = so3_hatlayer(v2)
52 |
53 | R1 = exp_so3(K1[0,:,:])
54 | R2 = exp_so3(K2[0,:,:])
55 |
56 | R3 = torch.matmul(R1,R2)
57 |
58 | v3 = vee(log_SO3(R3), algebra_type='so3')
59 |
60 | # v3 = vee(BCH_approx(K1[0,:,:], K2[0,:,:]), algebra_type='so3')
61 | if(torch.norm(v3) > math.pi):
62 | print("----------output bigger than pi---------")
63 | print("norm v1", torch.norm(v1))
64 | print("norm v2", torch.norm(v2))
65 | print("norm v3", torch.norm(v3))
66 | print("v1",v1)
67 | print("v2",v2)
68 | print("v3",v3)
69 | print("R1",R1)
70 | print("R2",R2)
71 | print("R3",R3)
72 |
73 | x1[i,:] = v1
74 | x2[i,:] = v2
75 | y[i,:] = v3
76 |
77 | if(gen_augmented_training_data):
78 | for j in range(num_conjugate):
79 | v4 = gen_random_rotation_vector()
80 | K4 = so3_hatlayer(v4)
81 | R4 = exp_so3(K4[0,:,:])
82 |
83 | R1_conj = R4@R1@R4.T
84 | R2_conj = R4@R2@R4.T
85 | R3_conj = R1_conj@R2_conj
86 |
87 | v1_conj = vee(log_SO3(R1_conj), algebra_type='so3')
88 | v2_conj = vee(log_SO3(R2_conj), algebra_type='so3')
89 | v3_conj = vee(log_SO3(R3_conj), algebra_type='so3')
90 |
91 | x1_conj[i,j,:] = v1_conj
92 | x2_conj[i,j,:] = v2_conj
93 | y_conj[i,j,:] = v3_conj
94 | R_aug[i,j,:,:] = R4
95 |
96 | if(gen_augmented_training_data):
97 | # concatenate the conjugate data
98 | train_data['x1'] = torch.cat((x1,x1_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy()
99 | train_data['x2'] = torch.cat((x2,x2_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy()
100 | train_data['y'] = torch.cat((y,y_conj.reshape(num_training*num_conjugate,3)),dim=0).numpy()
101 | train_data['R_aug'] = R_aug.numpy()
102 | else:
103 | train_data['x1'] = x1.numpy()
104 | train_data['x2'] = x2.numpy()
105 | train_data['y'] = y.numpy()
106 |
107 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data)
108 |
109 |
110 |
111 | # generate testing data
112 | x1 = torch.zeros((num_testing,3))
113 | x2 = torch.zeros((num_testing,3))
114 | y = torch.zeros((num_testing,3))
115 |
116 | x1_conj = torch.zeros((num_testing, num_conjugate,3))
117 | x2_conj = torch.zeros((num_testing, num_conjugate,3))
118 | y_conj = torch.zeros((num_testing, num_conjugate,3))
119 | R_aug = torch.zeros((num_testing, num_conjugate,3,3))
120 |
121 | for i in range(num_testing):
122 | # generate random v1, v2
123 | v1 = gen_random_rotation_vector()
124 | v2 = gen_random_rotation_vector()
125 |
126 | K1 = so3_hatlayer(v1)
127 | K2 = so3_hatlayer(v2)
128 |
129 | R1 = exp_so3(K1[0,:,:])
130 | R2 = exp_so3(K2[0,:,:])
131 |
132 | R3 = torch.matmul(R1,R2)
133 | v3 = vee(log_SO3(R3), algebra_type='so3')
134 |
135 | # v3 = vee(BCH_approx(K1[0,:,:], K2[0,:,:]), algebra_type='so3')
136 |
137 | x1[i,:] = v1
138 | x2[i,:] = v2
139 | y[i,:] = v3
140 |
141 | for j in range(num_conjugate):
142 | v4 = gen_random_rotation_vector()
143 | K4 = so3_hatlayer(v4)
144 | R4 = exp_so3(K4[0,:,:])
145 |
146 | R1_conj = R4@R1@R4.T
147 | R2_conj = R4@R2@R4.T
148 | R3_conj = R1_conj@R2_conj
149 |
150 | v1_conj = vee(log_SO3(R1_conj), algebra_type='so3')
151 | v2_conj = vee(log_SO3(R2_conj), algebra_type='so3')
152 | v3_conj = vee(log_SO3(R3_conj), algebra_type='so3')
153 |
154 | x1_conj[i,j,:] = v1_conj
155 | x2_conj[i,j,:] = v2_conj
156 | y_conj[i,j,:] = v3_conj
157 | R_aug[i,j,:,:] = R4
158 |
159 | test_data['x1'] = x1.numpy()
160 | test_data['x2'] = x2.numpy()
161 | test_data['x1_conjugate'] = x1_conj.numpy()
162 | test_data['x2_conjugate'] = x2_conj.numpy()
163 | test_data['y'] = y.numpy()
164 | test_data['y_conj'] = y_conj.numpy()
165 | test_data['R_aug'] = R_aug.numpy()
166 |
167 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data)
168 |
169 | print("Done! Data saved to: \n", data_saved_path +
170 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz")
--------------------------------------------------------------------------------
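The targets generated above are the exact Baker–Campbell–Hausdorff composition v₃ = (log(exp(v̂₁) exp(v̂₂)))ˇ. For reference, the series that the commented-out `BCH_approx` call would truncate, written for Lie-algebra elements x, y with bracket [x, y] = xy − yx:

    \log\!\left(e^{x}e^{y}\right)
      = x + y + \tfrac{1}{2}[x,y]
      + \tfrac{1}{12}\bigl([x,[x,y]] - [y,[x,y]]\bigr)
      - \tfrac{1}{24}\,[y,[x,[x,y]]] + \cdots
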
/experiment/sl3_equiv_test.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_alg_util import *
20 | from core.lie_neurons_layers import *
21 | from experiment.sl3_equiv_layers import *
22 | from data_loader.sl3_equiv_data_loader import *
23 |
24 |
25 | def test(model, test_loader, criterion, config, device):
26 | model.eval()
27 | with torch.no_grad():
28 | loss_sum = 0.0
29 | for i, sample in tqdm(enumerate(test_loader, start=0)):
30 | x = sample['x'].to(device)
31 | y = sample['y'].to(device)
32 |
33 | output = model(x)
34 |
35 | loss = criterion(output, y)
36 | loss_sum += loss.item()
37 |
38 | loss_avg = loss_sum/len(test_loader)
39 |
40 | return loss_avg
41 |
42 |
43 | def test_equivariance(model, test_loader, criterion, config, device):
44 | model.eval()
45 | hat_layer = HatLayer(algebra_type='sl3').to(device)
46 | with torch.no_grad():
47 | loss_sum = 0.0
48 | loss_non_conj_sum = 0.0
49 | diff_output_sum = 0.0
50 | for i, sample in tqdm(enumerate(test_loader, start=0)):
51 | x = sample['x'].to(device) # [B, 5, 8, 1]
52 | x_conj = sample['x_conjugate'].to(device) # [B, C, 5, 8, 1]
53 | y = sample['y'].to(device) # [B, 8]
54 | H = sample['H'].to(device) # [B, C, 3, 3]
55 | y_conj = sample['y_conj']
56 |
57 | output_x = model(x) # [B, 8]
58 | output_x_hat = hat_layer(output_x) # [B, 3, 3]
59 | loss_non_conj_sum += criterion(output_x, y).item()
60 |
61 | for j in range(x_conj.shape[1]):
62 | x_conj_j = x_conj[:, j, :, :, :] # [B, 5, 8, 1]
63 | H_j = H[:, j, :, :] # [B, 3, 3]
64 | conj_output = model(x_conj_j) # [B, 8]
65 | # print('output hat', output_x_hat.shape)
66 | # print('H_j', H_j.shape)
67 | output_then_conj_hat = torch.matmul(H_j, torch.matmul(output_x_hat, torch.inverse(H_j)))
68 | # print('output then conj hat', output_then_conj_hat.shape)
69 | output_then_conj = vee_sl3(output_then_conj_hat)
70 |
71 | diff_output = output_then_conj - conj_output
72 | # print(output_then_conj)
73 |
74 | conj_output_hat = hat_layer(conj_output)
75 | conj_output_hat_conj_back = torch.matmul(torch.inverse(H_j), torch.matmul(conj_output_hat, H_j))
76 | conj_output_conj_back = vee_sl3(conj_output_hat_conj_back)
77 |
78 |
79 | # y_hat = hat_layer(y)
80 | # conj_y_hat = torch.matmul(H_j, torch.matmul(y_hat, torch.inverse(H_j)))
81 | # conj_y = vee_sl3(conj_y_hat)
82 | loss = criterion(conj_output_conj_back, y)
83 | # print("conj_out", conj_output[0,:])
84 | # print("y_conj", y_conj[0, j, :])
85 | # print("diff_out_out_conj", diff_output[0,:])
86 | # print("out",output_x[0,:])
87 | # print("y", y[0,:])
88 | # print("diff_out",output_x[0,:] - y[0,:])
89 | # print(loss.item())
90 | # print("----------------------")
91 | loss_sum += loss.item()
92 | diff_output_sum += torch.sum(torch.abs(diff_output))
93 |
94 | # print('diff_output: ', diff_output)
95 |
96 | loss_avg = loss_sum/len(test_loader)/x_conj.shape[1]
97 | diff_output_avg = diff_output_sum/len(test_loader.dataset)/x_conj.shape[1]/x_conj.shape[3]
98 | loss_non_conj_avg = loss_non_conj_sum/len(test_loader)
99 |
100 | return loss_avg, loss_non_conj_avg, diff_output_avg
101 |
102 | def main():
103 | # torch.autograd.set_detect_anomaly(True)
104 |
105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106 | print('Using ', device)
107 |
108 |     parser = argparse.ArgumentParser(description='Test the network')
109 | parser.add_argument('--test_config', type=str,
110 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_equiv/testing_param.yaml')
111 | args = parser.parse_args()
112 |
113 | # load yaml file
114 | config = yaml.safe_load(open(args.test_config))
115 |
116 | test_set = sl3EquivDataSetLieBracket2Input(config['test_data_path'], device=device)
117 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
118 | shuffle=config['shuffle'])
119 |
120 | if config['model_type'] == "LN_relu_bracket":
121 | model = SL3EquivariantReluBracketLayers(2).to(device)
122 | elif config['model_type'] == "LN_relu":
123 | model = SL3EquivariantReluLayers(2).to(device)
124 | elif config['model_type'] == "LN_bracket":
125 | model = SL3EquivariantBracketLayers(2).to(device)
126 | elif config['model_type'] == "MLP":
127 | model = MLP(16).to(device)
128 | elif config['model_type'] == "LN_bracket_no_residual":
129 | model = SL3EquivariantBracketNoResidualConnectLayers(2).to(device)
130 |
131 | print("Using model: ", config['model_type'])
132 | print("total number of parameters: ", sum(p.numel() for p in model.parameters()))
133 |
134 | # model = SL3InvariantLayersTest(2).to(device)
135 | checkpoint = torch.load(config['model_path'])
136 | model.load_state_dict(checkpoint['model_state_dict'],strict=False)
137 |
138 | criterion = nn.MSELoss().to(device)
139 | test_loss_equiv, loss_non_conj_avg, diff_output_avg = test_equivariance(model, test_loader, criterion, config, device)
140 | print("test_loss type:",type(test_loss_equiv))
141 | # print("avg diff output type: ", diff_output_avg.dtype)
142 | print("test loss: ", test_loss_equiv)
143 | print("avg diff output: ", diff_output_avg)
144 | print("loss non conj avg: ", loss_non_conj_avg)
145 |
146 | if __name__ == "__main__":
147 | main()
148 |
--------------------------------------------------------------------------------
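`test_equivariance` verifies f(H x̂ H⁻¹) = H f(x)̂ H⁻¹ in two directions: it conjugates the network output forward (`output_then_conj`) and conjugates the conjugated-input output back (`conj_output_conj_back`) before comparing against y. A self-contained illustration of the property itself, with the identity map standing in for the network (any H = exp(A) with A traceless has det H = 1, hence H ∈ SL(3)):

    import torch

    def model(x):  # identity map: a trivially equivariant stand-in
        return x

    def traceless(M):
        return M - torch.eye(3) * torch.trace(M) / 3.0

    H = torch.linalg.matrix_exp(0.3 * traceless(torch.randn(3, 3)))
    Hinv = torch.inverse(H)
    x = traceless(torch.randn(3, 3))

    lhs = model(H @ x @ Hinv)  # conjugate the input, then apply the map
    rhs = H @ model(x) @ Hinv  # apply the map, then conjugate the output
    print(torch.allclose(lhs, rhs))
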
/playground/neural_ode_demo.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import os
5 | import argparse
6 | import time
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from core.lie_neurons_layers import *
13 |
14 | parser = argparse.ArgumentParser('ODE demo')
15 | parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5')
16 | parser.add_argument('--data_size', type=int, default=1000)
17 | parser.add_argument('--batch_time', type=int, default=10)
18 | parser.add_argument('--batch_size', type=int, default=20)
19 | parser.add_argument('--niters', type=int, default=2000)
20 | parser.add_argument('--test_freq', type=int, default=20)
21 | parser.add_argument('--viz', action='store_true')
22 | parser.add_argument('--gpu', type=int, default=0)
23 | parser.add_argument('--adjoint', action='store_true')
24 | args = parser.parse_args()
25 |
26 | if args.adjoint:
27 | from torchdiffeq import odeint_adjoint as odeint
28 | else:
29 | from torchdiffeq import odeint
30 |
31 | device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
32 |
33 | true_y0 = torch.tensor([[2., 0.]]).to(device)
34 | t = torch.linspace(0., 25., args.data_size).to(device)
35 | true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)
36 |
37 |
38 | class Lambda(nn.Module):
39 |
40 | def forward(self, t, y):
41 | return torch.mm(y**3, true_A)
42 |
43 |
44 | with torch.no_grad():
45 | true_y = odeint(Lambda(), true_y0, t, method='dopri5')
46 |
47 |
48 | def get_batch():
49 | s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
50 | batch_y0 = true_y[s] # (M, D)
51 | batch_t = t[:args.batch_time] # (T)
52 | batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D)
53 | return batch_y0.to(device), batch_t.to(device), batch_y.to(device)
54 |
55 |
56 | def makedirs(dirname):
57 | if not os.path.exists(dirname):
58 | os.makedirs(dirname)
59 |
60 |
61 | if args.viz:
62 | makedirs('png_neural_ode_demo')
63 | import matplotlib.pyplot as plt
64 | fig = plt.figure(figsize=(12, 4), facecolor='white')
65 | ax_traj = fig.add_subplot(131, frameon=False)
66 | ax_phase = fig.add_subplot(132, frameon=False)
67 | ax_vecfield = fig.add_subplot(133, frameon=False)
68 | plt.show(block=False)
69 |
70 |
71 | def visualize(true_y, pred_y, odefunc, itr):
72 |
73 | if args.viz:
74 |
75 | ax_traj.cla()
76 | ax_traj.set_title('Trajectories')
77 | ax_traj.set_xlabel('t')
78 | ax_traj.set_ylabel('x,y')
79 | ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-')
80 | ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--')
81 | ax_traj.set_xlim(t.cpu().min(), t.cpu().max())
82 | ax_traj.set_ylim(-2, 2)
83 | ax_traj.legend()
84 |
85 | ax_phase.cla()
86 | ax_phase.set_title('Phase Portrait')
87 | ax_phase.set_xlabel('x')
88 | ax_phase.set_ylabel('y')
89 | ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')
90 | ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--')
91 | ax_phase.set_xlim(-2, 2)
92 | ax_phase.set_ylim(-2, 2)
93 |
94 | ax_vecfield.cla()
95 | ax_vecfield.set_title('Learned Vector Field')
96 | ax_vecfield.set_xlabel('x')
97 | ax_vecfield.set_ylabel('y')
98 |
99 | y, x = np.mgrid[-2:2:21j, -2:2:21j]
100 | dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()
101 | mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
102 | dydt = (dydt / mag)
103 | dydt = dydt.reshape(21, 21, 2)
104 |
105 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
106 | ax_vecfield.set_xlim(-2, 2)
107 | ax_vecfield.set_ylim(-2, 2)
108 |
109 | fig.tight_layout()
110 | plt.savefig('png_neural_ode_demo/{:03d}'.format(itr))
111 | plt.draw()
112 | plt.pause(0.001)
113 |
114 |
115 | class ODEFunc(nn.Module):
116 |
117 | def __init__(self):
118 | super(ODEFunc, self).__init__()
119 |
120 | self.net = nn.Sequential(
121 | nn.Linear(2, 50),
122 | nn.Tanh(),
123 | nn.Linear(50, 2),
124 | )
125 |
126 | for m in self.net.modules():
127 | if isinstance(m, nn.Linear):
128 | nn.init.normal_(m.weight, mean=0, std=0.1)
129 | nn.init.constant_(m.bias, val=0)
130 |
131 | def forward(self, t, y):
132 | # print(y.shape)
133 | return self.net(y**3)
134 |
135 |
136 | class RunningAverageMeter(object):
137 | """Computes and stores the average and current value"""
138 |
139 | def __init__(self, momentum=0.99):
140 | self.momentum = momentum
141 | self.reset()
142 |
143 | def reset(self):
144 | self.val = None
145 | self.avg = 0
146 |
147 | def update(self, val):
148 | if self.val is None:
149 | self.avg = val
150 | else:
151 | self.avg = self.avg * self.momentum + val * (1 - self.momentum)
152 | self.val = val
153 |
154 |
155 | if __name__ == '__main__':
156 |
157 | ii = 0
158 |
159 | func = ODEFunc().to(device)
160 |
161 | optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
162 | end = time.time()
163 |
164 | time_meter = RunningAverageMeter(0.97)
165 |
166 | loss_meter = RunningAverageMeter(0.97)
167 |
168 | for itr in range(1, args.niters + 1):
169 | optimizer.zero_grad()
170 | batch_y0, batch_t, batch_y = get_batch()
171 | pred_y = odeint(func, batch_y0, batch_t).to(device)
172 | loss = torch.mean(torch.abs(pred_y - batch_y))
173 | loss.backward()
174 | optimizer.step()
175 |
176 | time_meter.update(time.time() - end)
177 | loss_meter.update(loss.item())
178 |
179 | if itr % args.test_freq == 0:
180 | with torch.no_grad():
181 | pred_y = odeint(func, true_y0, t)
182 | loss = torch.mean(torch.abs(pred_y - true_y))
183 | print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
184 | visualize(true_y, pred_y, func, ii)
185 | ii += 1
186 |
187 | end = time.time()
188 |
--------------------------------------------------------------------------------
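The demo's ground-truth field is dy/dt = y³A (elementwise cube), which has no convenient closed form, but the plain linear case dy/dt = yA does: y(t) = y₀ exp(tA). That makes a quick sanity check of `odeint` possible (requires `torchdiffeq`, as above):

    import torch
    from torchdiffeq import odeint

    A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
    y0 = torch.tensor([[2.0, 0.0]])
    t = torch.linspace(0.0, 1.0, 11)

    num = odeint(lambda t, y: y @ A, y0, t)  # numerical solution, [T, 1, 2]
    ref = torch.stack([y0 @ torch.linalg.matrix_exp(ti * A) for ti in t])
    print(torch.allclose(num, ref, atol=1e-4))
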
/experiment/sl3_inv_train.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_neurons_layers import *
20 | from experiment.sl3_inv_layers import *
21 | from data_loader.sl3_inv_data_loader import *
22 |
23 |
24 | def init_writer(config):
25 | writer = SummaryWriter(
26 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description'])
27 | writer.add_text("train_data_path: ", config['train_data_path'])
28 | writer.add_text("model_save_path: ", config['model_save_path'])
29 | writer.add_text("log_writer_path: ", config['log_writer_path'])
30 | writer.add_text("shuffle: ", str(config['shuffle']))
31 | writer.add_text("batch_size: ", str(config['batch_size']))
32 | writer.add_text("init_lr: ", str(config['initial_learning_rate']))
33 | writer.add_text("num_epochs: ", str(config['num_epochs']))
34 |
35 | return writer
36 |
37 |
38 | def test(model, test_loader, criterion, config, device):
39 | model.eval()
40 | with torch.no_grad():
41 | loss_sum = 0.0
42 | for i, sample in tqdm(enumerate(test_loader, start=0)):
43 | x = sample['x'].to(device)
44 | y = sample['y'].to(device)
45 |
46 | output = model(x)
47 |
48 | loss = criterion(output, y)
49 | loss_sum += loss.item()
50 |
51 | loss_avg = loss_sum/len(test_loader)
52 |
53 | return loss_avg
54 |
55 |
56 | def train(model, train_loader, test_loader, config, device='cpu'):
57 |
58 | writer = init_writer(config)
59 |
60 | # create criterion
61 | criterion = nn.MSELoss().to(device)
62 | # criterion = nn.L1Loss().to(device)
63 | optimizer = optim.Adam(model.parameters(
64 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate'])
65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate'])
66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs'])
67 |
68 | # if config['resume_training']:
69 | # checkpoint = torch.load(config['resume_model_path'])
70 | # model.load_state_dict(checkpoint['model_state_dict'])
71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
72 | # start_epoch = checkpoint['epoch']
73 | # else:
74 | start_epoch = 0
75 |
76 | best_loss = float("inf")
77 | for epoch in range(start_epoch, config['num_epochs']):
78 | running_loss = 0.0
79 | loss_sum = 0.0
80 | model.train()
81 | optimizer.zero_grad()
82 | for i, sample in tqdm(enumerate(train_loader, start=0)):
83 | x = sample['x'].to(device)
84 | y = sample['y'].to(device)
85 |
86 | output = model(x)
87 |
88 | loss = criterion(output, y)
89 | loss.backward()
90 |
91 |             # Gradient accumulation (currently disabled): re-enable the check below to update
92 |             # the weights only every config['update_every_batch'] iterations, simulating a larger batch size.
93 |             # if (i+1) % config['update_every_batch'] == 0:
94 | optimizer.step()
95 | optimizer.zero_grad()
96 |
97 | # cur_training_loss_history.append(loss.item())
98 | running_loss += loss.item()
99 | loss_sum += loss.item()
100 |
101 | if i % config['print_freq'] == 0:
102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" %
103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq']))
104 | running_loss = 0.0
105 |
106 |
107 | # scheduler.step()
108 |
109 |
110 | train_loss = loss_sum/len(train_loader)
111 |
112 | test_loss = test(
113 | model, test_loader, criterion, config, device)
114 |
115 | # log down info in tensorboard
116 | writer.add_scalar('training loss', train_loss, epoch)
117 | writer.add_scalar('test loss', test_loss, epoch)
118 |
119 | # if we achieve best val loss, save the model
120 | if test_loss < best_loss:
121 | best_loss = test_loss
122 |
123 | state = {'epoch': epoch,
124 | 'model_state_dict': model.state_dict(),
125 | 'optimizer_state_dict': optimizer.state_dict(),
126 | 'loss': train_loss,
127 | 'test loss': test_loss}
128 |
129 | torch.save(state, config['model_save_path'] +
130 | '_best_test_loss_acc.pt')
131 | print("------------------------------")
132 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" %
133 | (epoch, config['num_epochs'], train_loss, test_loss))
134 |
135 | # save model
136 | state = {'epoch': epoch,
137 | 'model_state_dict': model.state_dict(),
138 | 'optimizer_state_dict': optimizer.state_dict(),
139 | 'loss': train_loss,
140 | 'test loss': test_loss}
141 |
142 | torch.save(state, config['model_save_path']+'_last_epo.pt')
143 |
144 | writer.close()
145 |
146 |
147 | def main():
148 | # torch.autograd.set_detect_anomaly(True)
149 |
150 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
151 | print('Using ', device)
152 |
153 | parser = argparse.ArgumentParser(description='Train the network')
154 | parser.add_argument('--training_config', type=str,
155 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_inv/training_param.yaml')
156 | args = parser.parse_args()
157 |
158 | # load yaml file
159 | config = yaml.safe_load(open(args.training_config))
160 |
161 |
162 | training_set = sl3InvDataSet2Input(config['train_data_path'], device=device)
163 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'],
164 | shuffle=config['shuffle'])
165 |
166 | test_set = sl3InvDataSet2Input(config['test_data_path'], device=device)
167 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
168 | shuffle=config['shuffle'])
169 |
170 |
171 | if config['model_type'] == "LN_relu_bracket":
172 | model = SL3InvariantReluBracketLayers(2).to(device)
173 | elif config['model_type'] == "LN_relu":
174 | model = SL3InvariantReluLayers(2).to(device)
175 | elif config['model_type'] == "LN_bracket":
176 | model = SL3InvariantBracketLayers(2).to(device)
177 | elif config['model_type'] == "MLP":
178 | model = MLP(16).to(device)
179 | elif config['model_type'] == "LN_bracket_no_residual":
180 | model = SL3InvariantBracketNoResidualConnectLayers(2).to(device)
181 |
182 | train(model, train_loader, test_loader, config, device)
183 |
184 |
185 | if __name__ == "__main__":
186 | main()
187 |
--------------------------------------------------------------------------------
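With the `update_every_batch` check re-enabled, the inner loop becomes standard gradient accumulation. A minimal sketch of that pattern, reusing the names defined in `train()` (the `accum` value is illustrative; the loss is divided by `accum` so the accumulated gradient matches the large-batch equivalent):

    accum = 4  # simulate a batch accum times larger
    optimizer.zero_grad()
    for i, sample in enumerate(train_loader):
        loss = criterion(model(sample['x'].to(device)), sample['y'].to(device))
        (loss / accum).backward()  # gradients add up across iterations
        if (i + 1) % accum == 0:
            optimizer.step()
            optimizer.zero_grad()
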
/experiment/sp4_inv_train.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_neurons_layers import *
20 | from experiment.sp4_inv_layers import *
21 | from data_loader.sp4_inv_data_loader import *
22 |
23 |
24 | def init_writer(config):
25 | writer = SummaryWriter(
26 | config['log_writer_path']+"_"+str(time.localtime()), comment=config['model_description'])
27 | writer.add_text("train_data_path: ", config['train_data_path'])
28 | writer.add_text("model_save_path: ", config['model_save_path'])
29 | writer.add_text("log_writer_path: ", config['log_writer_path'])
30 | writer.add_text("shuffle: ", str(config['shuffle']))
31 | writer.add_text("batch_size: ", str(config['batch_size']))
32 | writer.add_text("init_lr: ", str(config['initial_learning_rate']))
33 | writer.add_text("num_epochs: ", str(config['num_epochs']))
34 |
35 | return writer
36 |
37 |
38 | def test(model, test_loader, criterion, config, device):
39 | model.eval()
40 | with torch.no_grad():
41 | loss_sum = 0.0
42 | for i, sample in tqdm(enumerate(test_loader, start=0)):
43 | x = sample['x'].to(device)
44 | y = sample['y'].to(device)
45 |
46 | output = model(x)
47 |
48 | loss = criterion(output, y)
49 | loss_sum += loss.item()
50 |
51 | loss_avg = loss_sum/len(test_loader)
52 |
53 | return loss_avg
54 |
55 |
56 | def train(model, train_loader, test_loader, config, device='cpu'):
57 |
58 | writer = init_writer(config)
59 |
60 | # create criterion
61 | criterion = nn.MSELoss().to(device)
62 | # criterion = nn.L1Loss().to(device)
63 | optimizer = optim.Adam(model.parameters(
64 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate'])
65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate'])
66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs'])
67 |
68 | # if config['resume_training']:
69 | # checkpoint = torch.load(config['resume_model_path'])
70 | # model.load_state_dict(checkpoint['model_state_dict'])
71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
72 | # start_epoch = checkpoint['epoch']
73 | # else:
74 | start_epoch = 0
75 |
76 | best_loss = float("inf")
77 | for epoch in range(start_epoch, config['num_epochs']):
78 | running_loss = 0.0
79 | loss_sum = 0.0
80 | model.train()
81 | optimizer.zero_grad()
82 | for i, sample in tqdm(enumerate(train_loader, start=0)):
83 | x = sample['x'].to(device)
84 | y = sample['y'].to(device)
85 |
86 | output = model(x)
87 |
88 | loss = criterion(output, y)
89 | loss.backward()
90 |
91 |             # Gradient accumulation (currently disabled): re-enable the check below to update
92 |             # the weights only every config['update_every_batch'] iterations, simulating a larger batch size.
93 |             # if (i+1) % config['update_every_batch'] == 0:
94 | optimizer.step()
95 | optimizer.zero_grad()
96 |
97 | # cur_training_loss_history.append(loss.item())
98 | running_loss += loss.item()
99 | loss_sum += loss.item()
100 |
101 | if i % config['print_freq'] == 0:
102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" %
103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq']))
104 | running_loss = 0.0
105 |
106 |
107 | # scheduler.step()
108 |
109 |
110 | train_loss = loss_sum/len(train_loader)
111 |
112 | test_loss = test(
113 | model, test_loader, criterion, config, device)
114 |
115 | # log down info in tensorboard
116 | writer.add_scalar('training loss', train_loss, epoch)
117 | writer.add_scalar('test loss', test_loss, epoch)
118 |
119 | # if we achieve best val loss, save the model
120 | if test_loss < best_loss:
121 | best_loss = test_loss
122 |
123 | state = {'epoch': epoch,
124 | 'model_state_dict': model.state_dict(),
125 | 'optimizer_state_dict': optimizer.state_dict(),
126 | 'loss': train_loss,
127 | 'test loss': test_loss}
128 |
129 | torch.save(state, config['model_save_path'] +
130 | '_best_test_loss_acc.pt')
131 | print("------------------------------")
132 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" %
133 | (epoch, config['num_epochs'], train_loss, test_loss))
134 |
135 | # save model
136 | state = {'epoch': epoch,
137 | 'model_state_dict': model.state_dict(),
138 | 'optimizer_state_dict': optimizer.state_dict(),
139 | 'loss': train_loss,
140 | 'test loss': test_loss}
141 |
142 | torch.save(state, config['model_save_path']+'_last_epo.pt')
143 |
144 | writer.close()
145 |
146 |
147 | def main():
148 | # torch.autograd.set_detect_anomaly(True)
149 |
150 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
151 | print('Using ', device)
152 |
153 | parser = argparse.ArgumentParser(description='Train the network')
154 | parser.add_argument('--training_config', type=str,
155 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sp4_inv/training_param.yaml')
156 | args = parser.parse_args()
157 |
158 | # load yaml file
159 | config = yaml.safe_load(open(args.training_config))
160 |
161 |
162 | training_set = sp4InvDataSet(config['train_data_path'], device=device)
163 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'],
164 | shuffle=config['shuffle'])
165 |
166 | test_set = sp4InvDataSet(config['test_data_path'], device=device)
167 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
168 | shuffle=config['shuffle'])
169 |
170 |
171 | if config['model_type'] == "LN_relu_bracket":
172 | model = SP4InvariantReluBracketLayers(2).to(device)
173 | elif config['model_type'] == "LN_relu":
174 | model = SP4InvariantReluLayers(2).to(device)
175 | elif config['model_type'] == "LN_bracket":
176 | model = SP4InvariantBracketLayers(2).to(device)
177 | elif config['model_type'] == "MLP":
178 | model = MLP(20).to(device)
179 | elif config['model_type'] == "MLP512":
180 | model = MLP512(20).to(device)
181 | elif config['model_type'] == "LN_bracket_no_residual":
182 | model = SP4InvariantBracketNoResidualConnectLayers(2).to(device)
183 |
184 | train(model, train_loader, test_loader, config, device)
185 |
186 |
187 | if __name__ == "__main__":
188 | main()
189 |
--------------------------------------------------------------------------------
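The commented-out resume block in `train()` sketches how a run would restart from a saved `state` dict. Made concrete, with keys matching the `torch.save(state, ...)` calls above and `map_location` added so the checkpoint also loads on a different device:

    checkpoint = torch.load(config['model_save_path'] + '_last_epo.pt',
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # resume after the last finished epoch
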
/experiment/so3_bch_layers.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import os
5 | import copy
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from einops import rearrange, repeat
12 | from einops.layers.torch import Rearrange
13 |
14 |
15 | from core.lie_alg_util import *
16 | from core.lie_neurons_layers import *
17 | from core.vn_layers import *
18 |
19 |
20 | class SO3EquivariantVNReluLayers(nn.Module):
21 | def __init__(self, in_channels):
22 | super(SO3EquivariantVNReluLayers, self).__init__()
23 | feat_dim = 1024
24 | share_nonlinearity = False
25 | leaky_relu = True
26 | self.ln_fc = VNLinearAndLeakyReLU(
27 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, dim=4)
28 | self.ln_fc2 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4)
29 | self.ln_fc3 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4)
30 | self.ln_fc4 = VNLinearAndLeakyReLU(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, dim=4)
31 |
32 | self.fc_final = nn.Linear(feat_dim, 1, bias=False)
33 |
34 | def forward(self, x):
35 | '''
36 | x input of shape [B, F, 3, 1]
37 | '''
38 | x = self.ln_fc(x)
39 | x = self.ln_fc2(x)
40 | x = self.ln_fc3(x)
41 | x = self.ln_fc4(x)
42 |
43 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F]
44 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3]
45 | return x_out
46 |
47 | class SO3EquivariantReluLayers(nn.Module):
48 | def __init__(self, in_channels):
49 | super(SO3EquivariantReluLayers, self).__init__()
50 | feat_dim = 1024
51 | share_nonlinearity = False
52 | leaky_relu = True
53 | self.ln_fc = LNLinearAndKillingRelu(
54 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3')
55 | self.ln_fc2 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3')
56 | self.ln_fc3 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3')
57 | self.ln_fc4 = LNLinearAndKillingRelu(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity,leaky_relu=leaky_relu, algebra_type='so3')
58 |
59 | self.fc_final = nn.Linear(feat_dim, 1, bias=False)
60 |
61 | def forward(self, x):
62 | '''
63 | x input of shape [B, F, 3, 1]
64 | '''
65 | x = self.ln_fc(x) # [B, F, 3, 1]
66 | x = self.ln_fc2(x) # [B, F, 3, 1]
67 | x = self.ln_fc3(x)
68 | x = self.ln_fc4(x)
69 |
70 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F]
71 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3]
72 | return x_out
73 |
74 |
75 | class SO3EquivariantBracketLayers(nn.Module):
76 | def __init__(self, in_channels):
77 | super(SO3EquivariantBracketLayers, self).__init__()
78 | feat_dim = 1024
79 | share_nonlinearity = False
80 | self.ln_fc = LNLinearAndLieBracket(in_channels, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
81 | self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
82 | self.ln_fc3 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
83 | self.ln_fc4 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
84 | # self.ln_fc5 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
85 | # self.ln_fc6 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
86 | # self.ln_fc7 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
87 |
88 | self.fc_final = nn.Linear(feat_dim, 1, bias=False)
89 |
90 | def forward(self, x):
91 | '''
92 | x input of shape [B, F, 3, 1]
93 | '''
94 | x = self.ln_fc(x) # [B, F, 3, 1]
95 | x = self.ln_fc2(x) # [B, F, 3, 1]
96 | x = self.ln_fc3(x)
97 | x = self.ln_fc4(x)
98 | # x = self.ln_fc5(x)
99 | # x = self.ln_fc6(x)
100 | # x = self.ln_fc7(x)
101 |
102 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F]
103 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3]
104 | return x_out
105 |
106 | class SO3EquivariantReluBracketLayers(nn.Module):
107 | def __init__(self, in_channels):
108 | super(SO3EquivariantReluBracketLayers, self).__init__()
109 | feat_dim = 1024
110 | share_nonlinearity = False
111 | leaky_relu = True
112 | self.ln_fc = LNLinearAndKillingRelu(
113 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3')
114 | self.ln_fc2 = LNLinearAndKillingRelu(
115 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity, leaky_relu=leaky_relu, algebra_type='so3')
116 | self.ln_fc_bracket = LNLinearAndLieBracket(in_channels, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
117 | self.ln_fc_bracket2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity, algebra_type='so3')
118 |
119 | self.fc_final = nn.Linear(feat_dim, 1, bias=False)
120 |
121 | def forward(self, x):
122 | '''
123 | x input of shape [B, F, 3, 1]
124 | '''
125 |
126 | x = self.ln_fc_bracket(x) # [B, F, 3, 1]
127 | x = self.ln_fc(x) # [B, F, 3, 1]
128 | x = self.ln_fc_bracket2(x)
129 | x = self.ln_fc2(x)
130 |
131 | x = torch.permute(x, (0, 3, 2, 1)) # [B, 1, 3, F]
132 | x_out = rearrange(self.fc_final(x), 'b 1 k 1 -> b k') # [B, 3]
133 | # x_out = rearrange(self.ln_pooling(x), 'b 1 k 1 -> b k') # [B, F, 1, 1]
134 | return x_out
135 |
136 |
137 | class MLP(nn.Module):
138 | def __init__(self, in_channels):
139 | super(MLP, self).__init__()
140 | feat_dim = 1024
141 | self.fc = nn.Linear(in_channels, feat_dim)
142 | self.relu = nn.ReLU()
143 | self.fc2 = nn.Linear(feat_dim, feat_dim)
144 | self.fc3 = nn.Linear(feat_dim, feat_dim)
145 |
146 | self.fc4 = nn.Linear(feat_dim, feat_dim)
147 | self.fc_final = nn.Linear(feat_dim, 3)
148 |
149 | def forward(self, x):
150 | '''
151 | x input of shape [B, F, 3, 1]
152 | '''
153 | B, F, _, _ = x.shape
154 | x = torch.reshape(x, (B, -1))
155 | x = self.fc(x)
156 | x = self.relu(x)
157 | x = self.fc2(x)
158 | x = self.relu(x)
159 | x = self.fc3(x)
160 | x = self.relu(x)
161 | x = self.fc4(x)
162 | x = self.relu(x)
163 | x_out = torch.reshape(self.fc_final(x), (B, 3)) # [B, 3]
164 |
165 | return x_out
166 |
--------------------------------------------------------------------------------
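All regression heads in this file share one interface: stacked so(3) features of shape [B, F, 3, 1] in, a single so(3) vector [B, 3] out (the MLP baseline flattens instead). A shape check, assuming the repository root is on `sys.path` as in the training scripts:

    import sys; sys.path.append('.')  # repo convention: run from the repository root
    import torch
    from experiment.so3_bch_layers import SO3EquivariantReluLayers, MLP

    x = torch.randn(4, 2, 3, 1)  # batch of 4, two so(3) tangent vectors each
    print(SO3EquivariantReluLayers(2)(x).shape)  # torch.Size([4, 3])
    print(MLP(6)(x).shape)                       # flattened input: 2*3*1 = 6
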
/experiment/platonic_solid_cls_layers.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import os
5 | import copy
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from einops import rearrange, repeat
12 | from einops.layers.torch import Rearrange
13 |
14 | from core.lie_alg_util import *
15 | from core.lie_neurons_layers import *
16 |
17 |
18 | class LNPlatonicSolidClassifier(nn.Module):
19 | def __init__(self, in_channels):
20 | super(LNPlatonicSolidClassifier, self).__init__()
21 | feat_dim = 256
22 | share_nonlinearity = False
23 | self.ln_fc = LNLinearAndKillingRelu(
24 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity)
25 | self.ln_fc2 = LNLinearAndKillingRelu(
26 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity)
27 | self.ln_fc3 = LNLinearAndKillingRelu(
28 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity)
29 | self.ln_pooling = LNMaxPool(
30 | feat_dim, abs_killing_form=False) # [B, F, 8, 1]
31 | self.ln_inv = LNInvariant(feat_dim, method='self_killing')
32 | self.fc_final = nn.Linear(feat_dim, 3, bias=False)
33 |
34 | def forward(self, x):
35 | '''
36 | x input of shape [B, F, 8, 1]
37 | '''
38 | x = self.ln_fc(x) # [B, F, 8, N]
39 | # x = self.ln_fc2(x) # [B, F, 8, N]
40 | # x = self.ln_fc3(x) # [B, F, 8, N]
41 | x = self.ln_pooling(x) # [B, F, 8, 1]
42 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1]
43 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F]
44 | x_out = rearrange(self.fc_final(x_inv),
45 | 'b 1 1 cls -> b cls') # [B, cls]
46 |
47 | return x_out
48 |
49 |
50 | class LNReluPlatonicSolidClassifier(nn.Module):
51 | def __init__(self, in_channels):
52 | super(LNReluPlatonicSolidClassifier, self).__init__()
53 | feat_dim = 256
54 | share_nonlinearity = False
55 | self.ln_fc = LNLinearAndKillingRelu(
56 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity)
57 | self.ln_pooling = LNMaxPool(
58 | feat_dim, abs_killing_form=False) # [B, F, 8, 1]
59 | self.ln_inv = LNInvariant(feat_dim, method='self_killing')
60 | self.fc_final = nn.Linear(feat_dim, 3, bias=False)
61 |
62 | def forward(self, x):
63 | '''
64 | x input of shape [B, F, 8, 1]
65 | '''
66 | x = self.ln_fc(x) # [B, F, 8, N]
67 | x = self.ln_pooling(x) # [B, F, 8, 1]
68 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1]
69 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F]
70 | x_out = rearrange(self.fc_final(x_inv),
71 | 'b 1 1 cls -> b cls') # [B, cls]
72 |
73 | return x_out
74 |
75 |
76 | class LNBracketPlatonicSolidClassifier(nn.Module):
77 | def __init__(self, in_channels):
78 | super(LNBracketPlatonicSolidClassifier, self).__init__()
79 | feat_dim = 256
80 | share_nonlinearity = False
81 | self.ln_fc = LNLinearAndLieBracket(
82 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity)
83 | # self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity)
84 | self.ln_pooling = LNMaxPool(
85 | feat_dim, abs_killing_form=False) # [B, F, 8, 1]
86 | self.ln_inv = LNInvariant(feat_dim, method='self_killing')
87 | self.fc_final = nn.Linear(feat_dim, 3, bias=False)
88 |
89 | def forward(self, x):
90 | '''
91 | x input of shape [B, F, 8, 1]
92 | '''
93 | x = self.ln_fc(x) # [B, F, 8, N]
94 | x = self.ln_pooling(x) # [B, F, 8, 1]
95 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1]
96 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F]
97 | x_out = rearrange(self.fc_final(x_inv),
98 | 'b 1 1 cls -> b cls') # [B, cls]
99 |
100 | return x_out
101 |
102 |
103 | class LNReluBracketPlatonicSolidClassifier(nn.Module):
104 | def __init__(self, in_channels):
105 | super(LNReluBracketPlatonicSolidClassifier, self).__init__()
106 | feat_dim = 256
107 | share_nonlinearity = False
108 | self.ln_fc = LNLinearAndKillingRelu(
109 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity)
110 | self.ln_fc2 = LNLinearAndLieBracket(
111 | feat_dim, feat_dim, share_nonlinearity=share_nonlinearity)
112 | self.ln_pooling = LNMaxPool(
113 | feat_dim, abs_killing_form=False) # [B, F, 8, 1]
114 | self.ln_inv = LNInvariant(feat_dim, method='self_killing')
115 | self.fc_final = nn.Linear(feat_dim, 3, bias=False)
116 |
117 | def forward(self, x):
118 | '''
119 | x input of shape [B, F, 8, 1]
120 | '''
121 | x = self.ln_fc(x) # [B, F, 8, N]
122 | x = self.ln_fc2(x) # [B, F, 8, N]
123 | x = self.ln_pooling(x) # [B, F, 8, 1]
124 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1]
125 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F]
126 | x_out = rearrange(self.fc_final(x_inv),
127 | 'b 1 1 cls -> b cls') # [B, cls]
128 |
129 | return x_out
130 |
131 |
132 | class LNBracketNoResidualConnectPlatonicSolidClassifier(nn.Module):
133 | def __init__(self, in_channels):
134 | super(LNBracketNoResidualConnectPlatonicSolidClassifier, self).__init__()
135 | feat_dim = 256
136 | share_nonlinearity = False
137 | self.ln_fc = LNLinearAndLieBracketNoResidualConnect(
138 | in_channels, feat_dim, share_nonlinearity=share_nonlinearity)
139 | # self.ln_fc2 = LNLinearAndLieBracket(feat_dim, feat_dim,share_nonlinearity=share_nonlinearity)
140 | self.ln_pooling = LNMaxPool(
141 | feat_dim, abs_killing_form=False) # [B, F, 8, 1]
142 | self.ln_inv = LNInvariant(feat_dim, method='self_killing')
143 | self.fc_final = nn.Linear(feat_dim, 3, bias=False)
144 |
145 | def forward(self, x):
146 | '''
147 | x input of shape [B, F, 8, 1]
148 | '''
149 | x = self.ln_fc(x) # [B, F, 8, N]
150 | x = self.ln_pooling(x) # [B, F, 8, 1]
151 | x_inv = self.ln_inv(x).unsqueeze(-1) # [B, F, 1, 1]
152 | x_inv = torch.permute(x_inv, (0, 3, 2, 1)) # [B, 1, 1, F]
153 | x_out = rearrange(self.fc_final(x_inv),
154 | 'b 1 1 cls -> b cls') # [B, cls]
155 |
156 | return x_out
157 |
158 |
159 | class MLP(nn.Module):
160 | def __init__(self, in_channels):
161 | super(MLP, self).__init__()
162 | feat_dim = 256
163 | self.fc = nn.Linear(in_channels, feat_dim)
164 | self.fc2 = nn.Linear(feat_dim, feat_dim)
165 | self.fc3 = nn.Linear(feat_dim, feat_dim)
166 | self.relu = nn.ReLU()
167 | self.fc_final = nn.Linear(feat_dim, 3)
168 |
169 | def forward(self, x):
170 | '''
171 | x input of shape [B, F, 8, N]
172 | '''
173 | B, F, _, _ = x.shape
174 | x = torch.reshape(x, (B, -1))
175 | x = self.fc(x)
176 | x = self.relu(x)
177 | x = self.fc2(x)
178 | x = self.relu(x)
179 | x = self.fc3(x)
180 | x = self.relu(x)
181 | x_out = torch.reshape(self.fc_final(x), (B, 3)) # [B, cls]
182 |
183 | return x_out
184 |
--------------------------------------------------------------------------------
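The four classifier variants above differ only in their equivariant backbone; each ends with an invariant readout, so the logits should not change when the input sl(3) features are conjugated. A minimal smoke test, sketched below under the assumption that HatLayer and vee_sl3 are importable from core.lie_neurons_layers as in the experiment scripts (shapes and tolerance are illustrative, not repo code):

import torch
from core.lie_neurons_layers import HatLayer, vee_sl3
from experiment.platonic_solid_cls_layers import LNBracketPlatonicSolidClassifier

model = LNBracketPlatonicSolidClassifier(in_channels=3).eval()
hat_layer = HatLayer(algebra_type='sl3')

x = torch.randn(4, 3, 8, 1)                # [B, F, 8, 1] input features
h = 0.5 * torch.randn(4, 1, 1, 8)          # random sl(3) tangent vectors
H = torch.linalg.matrix_exp(hat_layer(h))  # one SL(3) element per batch entry

x_hat = hat_layer(x.transpose(2, -1))      # [B, F, 1, 3, 3]
x_conj = vee_sl3(H @ x_hat @ torch.inverse(H)).transpose(2, -1)

with torch.no_grad():
    print(torch.allclose(model(x), model(x_conj), atol=1e-4))  # expected: True
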
/experiment/sl3_equiv_train.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | import torch
14 | from torch import nn
15 | import torch.optim as optim
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from core.lie_neurons_layers import *
20 | from experiment.sl3_equiv_layers import *
21 | from data_loader.sl3_equiv_data_loader import *
22 |
23 |
24 | def init_writer(config):
25 | writer = SummaryWriter(
26 |         config['log_writer_path']+"_"+time.strftime("%Y%m%d_%H%M%S"), comment=config['model_description'])
27 | writer.add_text("train_data_path: ", config['train_data_path'])
28 | writer.add_text("model_save_path: ", config['model_save_path'])
29 | writer.add_text("log_writer_path: ", config['log_writer_path'])
30 | writer.add_text("shuffle: ", str(config['shuffle']))
31 | writer.add_text("batch_size: ", str(config['batch_size']))
32 | writer.add_text("init_lr: ", str(config['initial_learning_rate']))
33 | writer.add_text("num_epochs: ", str(config['num_epochs']))
34 |
35 | return writer
36 |
37 |
38 | def test(model, test_loader, criterion, config, device):
39 | model.eval()
40 | with torch.no_grad():
41 | loss_sum = 0.0
42 | for i, sample in tqdm(enumerate(test_loader, start=0)):
43 | x = sample['x'].to(device)
44 | y = sample['y'].to(device)
45 |
46 | output = model(x)
47 |
48 | loss = criterion(output, y)
49 | loss_sum += loss.item()
50 |
51 | loss_avg = loss_sum/len(test_loader)
52 |
53 | return loss_avg
54 |
55 |
56 | def train(model, train_loader, test_loader, config, device='cpu'):
57 |
58 | writer = init_writer(config)
59 |
60 | # create criterion
61 | criterion = nn.MSELoss().to(device)
62 | # criterion = nn.L1Loss().to(device)
63 | optimizer = optim.Adam(model.parameters(
64 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate'])
65 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate'])
66 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs'])
67 |
68 | # if config['resume_training']:
69 | # checkpoint = torch.load(config['resume_model_path'])
70 | # model.load_state_dict(checkpoint['model_state_dict'])
71 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
72 | # start_epoch = checkpoint['epoch']
73 | # else:
74 | start_epoch = 0
75 |
76 | best_loss = float("inf")
77 | for epoch in range(start_epoch, config['num_epochs']):
78 | running_loss = 0.0
79 | loss_sum = 0.0
80 | model.train()
81 | optimizer.zero_grad()
82 | for i, sample in tqdm(enumerate(train_loader, start=0)):
83 | x = sample['x'].to(device)
84 | y = sample['y'].to(device)
85 |
86 | output = model(x)
87 |
88 | loss = criterion(output, y)
89 | loss.backward()
90 |
91 |             # Gradient accumulation is currently disabled: uncommenting the
92 |             # check below steps only every config['update_every_batch'] iterations to simulate a larger batch size.
93 | # if (i+1) % config['update_every_batch'] == 0:
94 | optimizer.step()
95 | optimizer.zero_grad()
96 |
97 | # cur_training_loss_history.append(loss.item())
98 | running_loss += loss.item()
99 | loss_sum += loss.item()
100 |
101 | if i % config['print_freq'] == 0:
102 | print("epoch %d / %d, iteration %d / %d, loss: %.8f" %
103 | (epoch, config['num_epochs'], i, len(train_loader), running_loss/config['print_freq']))
104 | running_loss = 0.0
105 |
106 |
107 | # scheduler.step()
108 |
109 | # train_top1, train_top5, _ = validate(train_loader, model, criterion, config, device)
110 |
111 | train_loss = loss_sum/len(train_loader)
112 |
113 | test_loss = test(
114 | model, test_loader, criterion, config, device)
115 |
116 | # log down info in tensorboard
117 | writer.add_scalar('training loss', train_loss, epoch)
118 | writer.add_scalar('test loss', test_loss, epoch)
119 |
120 | # if we achieve best val loss, save the model
121 | if test_loss < best_loss:
122 | best_loss = test_loss
123 |
124 | state = {'epoch': epoch,
125 | 'model_state_dict': model.state_dict(),
126 | 'optimizer_state_dict': optimizer.state_dict(),
127 | 'loss': train_loss,
128 | 'test loss': test_loss}
129 |
130 | torch.save(state, config['model_save_path'] +
131 | '_best_test_loss_acc.pt')
132 | print("------------------------------")
133 | # print("Finished epoch %d / %d, training top 1 acc: %.4f, training top 5 acc: %.4f, \
134 | # validation top1 acc: %.4f, validation top 5 acc: %.4f" %\
135 | # (epoch, config['num_epochs'], train_top1, train_top5, val_top1, val_top5))
136 | print("Finished epoch %d / %d, train loss: %.4f test loss: %.4f" %
137 | (epoch, config['num_epochs'], train_loss, test_loss))
138 |
139 | # save model
140 | state = {'epoch': epoch,
141 | 'model_state_dict': model.state_dict(),
142 | 'optimizer_state_dict': optimizer.state_dict(),
143 | 'loss': train_loss,
144 | 'test loss': test_loss}
145 |
146 | torch.save(state, config['model_save_path']+'_last_epo.pt')
147 |
148 | writer.close()
149 |
150 |
151 | def main():
152 | # torch.autograd.set_detect_anomaly(True)
153 |
154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155 | print('Using ', device)
156 |
157 | parser = argparse.ArgumentParser(description='Train the network')
158 | parser.add_argument('--training_config', type=str,
159 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/sl3_equiv/training_param.yaml')
160 | args = parser.parse_args()
161 |
162 | # load yaml file
163 | config = yaml.safe_load(open(args.training_config))
164 |
165 | # 5 input
166 | training_set = sl3EquivDataSetLieBracket2Input(
167 | config['train_data_path'], device=device)
168 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'],
169 | shuffle=config['shuffle'])
170 |
171 | test_set = sl3EquivDataSetLieBracket2Input(
172 | config['test_data_path'], device=device)
173 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
174 | shuffle=config['shuffle'])
175 |
176 | if config['model_type'] == "LN_relu_bracket":
177 | model = SL3EquivariantReluBracketLayers(2).to(device)
178 | elif config['model_type'] == "LN_relu":
179 | model = SL3EquivariantReluLayers(2).to(device)
180 | elif config['model_type'] == "LN_bracket":
181 | model = SL3EquivariantBracketLayers(2).to(device)
182 | elif config['model_type'] == "MLP":
183 | model = MLP(16).to(device)
184 | elif config['model_type'] == "LN_bracket_no_residual":
185 | model = SL3EquivariantBracketNoResidualConnectLayers(2).to(device)
186 |
187 | train(model, train_loader, test_loader, config, device)
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
192 |
--------------------------------------------------------------------------------
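For reference, here is a hypothetical sketch of the training_param.yaml this script expects, written as the dict that yaml.safe_load would return. Every key below is read somewhere in the script; the paths and values are placeholders, not the repo's actual configuration:

config = {
    'train_data_path': 'data/sl3_equiv_data/train.npz',   # placeholder path
    'test_data_path': 'data/sl3_equiv_data/test.npz',     # placeholder path
    'model_save_path': 'weights/sl3_equiv',               # checkpoint prefix
    'log_writer_path': 'logs/sl3_equiv',                  # tensorboard run prefix
    'model_description': 'sl(3) equivariant regression',
    'model_type': 'LN_relu_bracket',  # LN_relu | LN_bracket | MLP | LN_bracket_no_residual
    'shuffle': True,
    'batch_size': 128,
    'initial_learning_rate': 1.0e-3,
    'weight_decay_rate': 0.0,
    'num_epochs': 100,
    'print_freq': 100,
}
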
/matlab/sl3_equivariant_lift_find_bound.asv:
--------------------------------------------------------------------------------
1 | clc;
2 | clear;
3 |
4 | compute_K = false;
5 | num_K = 1000;
6 | rnd_scale_K = 100;
7 |
8 | compute_A = false;
9 | num_A = 1000;
10 | rnd_scale_A = 100;
11 |
12 |
13 | compute_N = false;
14 | num_N = 1000;
15 | rnd_scale_N = 100;
16 |
17 | compute_KN = true;
18 | num_KN = 1000;
19 | rnd_scale_KN = 100;
20 |
21 | syms v1 v2 v3 v4 v5 v6 v7 v8
22 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real')
23 |
24 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9
25 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real')
26 | %%
27 |
28 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
29 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
30 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
31 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
32 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
33 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
34 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
35 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
36 |
37 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0];
38 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0];
39 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0];
40 |
41 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
42 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
43 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
44 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1];
45 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
46 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0];
47 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
48 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
49 |
50 | E = {E1,E2,E3,E4,E5,E6,E7,E8};
51 |
52 | E1_vec = reshape(E1,1,[])';
53 | E2_vec = reshape(E2,1,[])';
54 | E3_vec = reshape(E3,1,[])';
55 | E4_vec = reshape(E4,1,[])';
56 | E5_vec = reshape(E5,1,[])';
57 | E6_vec = reshape(E6,1,[])';
58 | E7_vec = reshape(E7,1,[])';
59 | E8_vec = reshape(E8,1,[])';
60 |
61 |
62 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec];
63 |
64 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8;
65 |
66 |
67 | %% find h
68 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9];
69 | Ad_H_hat = H*x_hat*inv(H);
70 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])';
71 |
72 | % solve least square to obtain x
73 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec;
74 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
75 | [Ad_H_sym,b]=equationsToMatrix(x,var);
76 |
77 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9);
78 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym;
79 |
80 | %% K
81 | if compute_K
82 | rnd = rnd_scale_K*rand(num_K,3);
83 | tol = 1e-8;
84 | K_results = zeros(num_K,4);
85 | K_color = zeros(num_K,3);
86 | K_norm = zeros(num_K,1);
87 | for i=1:num_K
88 | H = expm(hat_so3(rnd(i,:)));
89 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
90 | D = kron(inv(H'),Ad_test)-eye(24);
91 |
92 | K_results(i,1) = rnd(i,1);
93 | K_results(i,2) = rnd(i,2);
94 | K_results(i,3) = rnd(i,3);
95 | K_results(i,4) = rank(D,tol);
96 | K_norm(i,1) = norm(K_results(i,1:3));
97 | if K_results(i,4)<24
98 | K_results(i,5) = 0;
99 | K_color(i,:) = [0, 0.4470, 0.7410];
100 | elseif K_results(i,4) == 24
101 | % H
102 | K_results(i,5) = 1;
103 | K_color(i,:) = [0.8500, 0.3250, 0.0980];
104 | end
105 | end
106 |
107 | K = figure(1);
108 | S = repmat(25,num_K,1);
109 | scatter3(K_results(:,1),K_results(:,2),K_results(:,3),S,K_color(:,:),'filled');
110 | xlabel("x")
111 | ylabel("y")
112 | zlabel("z")
113 |
114 | K2 = figure(2);
115 | scatter(K_norm,K_results(:,5),S,K_color,'filled');
116 | end
117 | %%
118 | % idx = K_results(:,5) == 1;
119 | % no_sol_K = K_results(idx,:);
120 | % no_sol_norm = K_norm(idx,:);
121 | % no_sol_norm_mod_pi = mod(no_sol_norm,pi);
122 | % K3 = figure(3);
123 | % histogram(no_sol_norm_mod_pi,314);
124 |
125 | %% A
126 | if compute_A
127 | rnd = rnd_scale_A*rand(num_A,2);
128 | tol = 1e-8;
129 | A_results = zeros(num_A,4);
130 | A_color = zeros(num_A,3);
131 | A_norm = zeros(num_A,1);
132 | for i = 1:num_A
133 | a1 = rnd(i,1);
134 | % a2 = rnd(i,2);
135 | a2 = 1/a1;
136 | H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2];
137 | % H = [a1,0,0; 0,1/a1,0;0,0,1];
138 |
139 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
140 | D = kron(inv(H'),Ad_test)-eye(24);
141 |
142 | A_results(i,1) = a1;
143 | A_results(i,2) = a2;
144 | A_results(i,4) = rank(D,tol);
145 | A_norm(i,1) = norm(A_results(i,1:2));
146 | if A_results(i,4)<24
147 | A_results(i,5) = 0;
148 | A_color(i,:) = [0, 0.4470, 0.7410];
149 | elseif A_results(i,4) == 24
150 | % H
151 | A_results(i,5) = 1;
152 | A_color(i,:) = [0.8500, 0.3250, 0.0980];
153 | end
154 | end
155 |
156 | A_figure1 = figure(1);
157 | S = repmat(25,num_A,1);
158 | scatter(A_results(:,1),A_results(:,2),S,A_color(:,:),'filled');
159 |     xlabel("a1")
160 |     ylabel("a2")
161 |
162 | title("Solution existence for A")
163 |
164 | A_figure2 = figure(2);
165 | scatter(A_norm,A_results(:,5),S,A_color,'filled');
166 |
167 |
168 | idx_A = A_results(:,5) == 1;
169 | no_sol_A = A_results(idx_A,:);
170 | end
171 |
172 | %% N
173 | if compute_N
174 | rnd = rnd_scale_N*rand(num_N,3);
175 | tol = 1e-8;
176 | N_results = zeros(num_N,4);
177 | N_color = zeros(num_N,3);
178 | N_norm = zeros(num_N,1);
179 | for i = 1:num_N
180 | n1 = rnd(i,1);
181 | n2 = rnd(i,2);
182 | n3 = rnd(i,3);
183 | H = [1 n1 n2; 0 1 n3; 0 0 1];
184 |
185 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
186 | D = kron(inv(H'),Ad_test)-eye(24);
187 |
188 | N_results(i,1) = n1;
189 | N_results(i,2) = n2;
190 | N_results(i,3) = n3;
191 | N_results(i,4) = rank(D,tol);
192 |         N_norm(i,1) = norm(N_results(i,1:3));
193 | if N_results(i,4)<24
194 | N_results(i,5) = 0;
195 | N_color(i,:) = [0, 0.4470, 0.7410];
196 | elseif N_results(i,4) == 24
197 | % H
198 | N_results(i,5) = 1;
199 | N_color(i,:) = [0.8500, 0.3250, 0.0980];
200 | end
201 | end
202 |
203 | N_figure1 = figure(1);
204 | S = repmat(25,num_N,1);
205 | scatter3(N_results(:,1),N_results(:,2),N_results(:,3),S,N_color(:,:),'filled');
206 | xlabel("x")
207 | ylabel("y")
208 | zlabel("z")
209 | title("Solution existence for N")
210 |
211 | N_figure2 = figure(2);
212 | scatter(N_norm,N_results(:,5),S,N_color,'filled');
213 |
214 |
215 | idx_N = N_results(:,5) == 1;
216 | no_sol_N = N_results(idx_N,:);
217 | end
218 |
219 | %% KN
220 | if compute_KN
221 | rnd = rnd_scale_KN*rand(num_KN,3);
222 | rnd2 = rnd_scale_KN*rand(num_KN,3);
223 | tol = 1e-8;
224 | KN_results = zeros(num_KN,4);
225 | KN_color = zeros(num_KN,3);
226 | KN_norm = zeros(num_KN,1);
227 | for i = 1:num_KN
228 | n1 = rnd(i,1);
229 | n2 = rnd(i,2);
230 | n3 = rnd(i,3);
231 | N = [1 n1 n2; 0 1 n3; 0 0 1];
232 | K = expm(hat_so3(rnd2(i,:)));
233 |
234 | H = K*N;
235 |
236 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
237 | D = kron(inv(H'),Ad_test)-eye(24);
238 |
239 | KN_results(i,1) = n1;
240 | KN_results(i,2) = n2;
241 | KN_results(i,3) = n3;
242 | KN_results(i,4) = rank(D,tol);
243 | KN_results(i,6) = rnd2(i,1);
244 | KN_results(i,7) = rnd2(i,2);
245 | KN_results(i,8) = rnd2(i,3);
246 |         KN_norm(i,1) = norm(KN_results(i,1:3));
247 | if KN_results(i,4)<24
248 | KN_results(i,5) = 0;
249 | KN_color(i,:) = [0, 0.4470, 0.7410];
250 |         elseif KN_results(i,4) == 24
251 | % H
252 | KN_results(i,5) = 1;
253 | KN_color(i,:) = [0.8500, 0.3250, 0.0980];
254 | end
255 | end
256 |
257 | KN_figure1 = figure(1);
258 |     S = repmat(25,num_KN,1);
259 | scatter3(KN_results(:,1),KN_results(:,2),KN_results(:,3),S,KN_color(:,:),'filled');
260 | xlabel("x")
261 | ylabel("y")
262 | zlabel("z")
263 |     title("Solution existence for KN")
264 |
265 |     KN_figure2 = figure(2);
266 |     scatter(KN_norm,KN_results(:,5),S,KN_color,'filled');
267 |
268 |
269 |     idx_KN = KN_results(:,5) == 1;
270 |     no_sol_KN = KN_results(idx_KN,:);
271 | end
--------------------------------------------------------------------------------
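A note on what this script (and the cleaned-up .m version that follows) actually tests, stated in our notation rather than taken from the repo: an equivariant lift is a matrix W in R^{8x3} satisfying Ad(H) W = W H, equivalently Ad(H) W H^{-1} = W. Column-major vectorization turns this into a null-space problem,

    \big( H^{-\top} \otimes \mathrm{Ad}(H) - I_{24} \big)\, \mathrm{vec}(W) = 0,

which is exactly D = kron(inv(H'), Ad_test) - eye(24) in the code. A nontrivial lift exists iff rank(D) < 24; the scatter plots color rank-deficient samples blue and full-rank samples orange, and the no_sol_* arrays collect the full-rank ("no solution") cases.
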
/matlab/sl3_equivariant_lift_find_bound.m:
--------------------------------------------------------------------------------
1 | clc;
2 | clear;
3 |
4 | compute_K = false;
5 | num_K = 1000;
6 | rnd_scale_K = 100;
7 |
8 | compute_A = false;
9 | num_A = 1000;
10 | rnd_scale_A = 100;
11 |
12 |
13 | compute_N = false;
14 | num_N = 1000;
15 | rnd_scale_N = 100;
16 |
17 | compute_KN = true;
18 | num_KN = 1000;
19 | rnd_scale_KN = 1;
20 |
21 | syms v1 v2 v3 v4 v5 v6 v7 v8
22 | assumeAlso([v1 v2 v3 v4 v5 v6 v7 v8],'real')
23 |
24 | syms h1 h2 h3 h4 h5 h6 h7 h8 h9
25 | assumeAlso([h1 h2 h3 h4 h5 h6 h7 h8 h9],'real')
26 | %%
27 |
28 | E1 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
29 | E2 = [0, 1, 0; 1, 0, 0; 0, 0, 0];
30 | E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
31 | E4 = [1, 0, 0; 0, 1, 0; 0, 0, -2];
32 | E5 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
33 | E6 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
34 | E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
35 | E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
36 |
37 | Ex = [0, 0, 0;0, 0, -1;0, 1, 0];
38 | Ey = [0, 0, 1;0, 0, 0;-1, 0, 0];
39 | Ez = [0, -1, 0;1, 0, 0;0, 0, 0];
40 |
41 | % E1 = [0, 0, 1; 0, 0, 0; 0, 0, 0];
42 | % E2 = [0, 0, 0; 0, 0, 1; 0, 0, 0];
43 | % E3 = [0, -1, 0; 1, 0, 0; 0, 0, 0];
44 | % E4 = [0, 0, 0; 0, 0, 0; 0, 0, -1];
45 | % E5 = [1, 0, 0; 0, -1, 0; 0, 0, 0];
46 | % E6 = [0, 1, 0; 0, 0, 0; 0, 0, 0];
47 | % E7 = [0, 0, 0; 0, 0, 0; 1, 0, 0];
48 | % E8 = [0, 0, 0; 0, 0, 0; 0, 1, 0];
49 |
50 | E = {E1,E2,E3,E4,E5,E6,E7,E8};
51 |
52 | E1_vec = reshape(E1,1,[])';
53 | E2_vec = reshape(E2,1,[])';
54 | E3_vec = reshape(E3,1,[])';
55 | E4_vec = reshape(E4,1,[])';
56 | E5_vec = reshape(E5,1,[])';
57 | E6_vec = reshape(E6,1,[])';
58 | E7_vec = reshape(E7,1,[])';
59 | E8_vec = reshape(E8,1,[])';
60 |
61 |
62 | E_vec = [E1_vec, E2_vec, E3_vec, E4_vec, E5_vec, E6_vec, E7_vec, E8_vec];
63 |
64 | x_hat = v1*E1+v2*E2+v3*E3+v4*E4+v5*E5+v6*E6+v7*E7+v8*E8;
65 |
66 |
67 | %% find h
68 | H = [h1,h2,h3;h4,h5,h6;h7,h8,h9];
69 | Ad_H_hat = H*x_hat*inv(H);
70 | Ad_H_hat_vec = reshape(Ad_H_hat,1,[])';
71 |
72 | % solve least square to obtain x
73 | x = inv(E_vec'*E_vec)*(E_vec')*Ad_H_hat_vec;
74 | var = [v1,v2,v3,v4,v5,v6,v7,v8];
75 | [Ad_H_sym,b]=equationsToMatrix(x,var);
76 |
77 | syms f(h1,h2,h3,h4,h5,h6,h7,h8,h9);
78 | f(h1,h2,h3,h4,h5,h6,h7,h8,h9) = Ad_H_sym;
79 |
80 | %% K
81 | if compute_K
82 | rnd = rnd_scale_K*rand(num_K,3);
83 | tol = 1e-8;
84 | K_results = zeros(num_K,4);
85 | K_color = zeros(num_K,3);
86 | K_norm = zeros(num_K,1);
87 | for i=1:num_K
88 | H = expm(hat_so3(rnd(i,:)));
89 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
90 | D = kron(inv(H'),Ad_test)-eye(24);
91 |
92 | K_results(i,1) = rnd(i,1);
93 | K_results(i,2) = rnd(i,2);
94 | K_results(i,3) = rnd(i,3);
95 | K_results(i,4) = rank(D,tol);
96 | K_norm(i,1) = norm(K_results(i,1:3));
97 | if K_results(i,4)<24
98 | K_results(i,5) = 0;
99 | K_color(i,:) = [0, 0.4470, 0.7410];
100 | elseif K_results(i,4) == 24
101 | % H
102 | K_results(i,5) = 1;
103 | K_color(i,:) = [0.8500, 0.3250, 0.0980];
104 | end
105 | end
106 |
107 | K = figure(1);
108 | S = repmat(25,num_K,1);
109 | scatter3(K_results(:,1),K_results(:,2),K_results(:,3),S,K_color(:,:),'filled');
110 | xlabel("x")
111 | ylabel("y")
112 | zlabel("z")
113 |
114 | K2 = figure(2);
115 | scatter(K_norm,K_results(:,5),S,K_color,'filled');
116 | end
117 | %%
118 | % idx = K_results(:,5) == 1;
119 | % no_sol_K = K_results(idx,:);
120 | % no_sol_norm = K_norm(idx,:);
121 | % no_sol_norm_mod_pi = mod(no_sol_norm,pi);
122 | % K3 = figure(3);
123 | % histogram(no_sol_norm_mod_pi,314);
124 |
125 | %% A
126 | if compute_A
127 | rnd = rnd_scale_A*rand(num_A,2);
128 | tol = 1e-8;
129 | A_results = zeros(num_A,4);
130 | A_color = zeros(num_A,3);
131 | A_norm = zeros(num_A,1);
132 | for i = 1:num_A
133 | a1 = rnd(i,1);
134 | % a2 = rnd(i,2);
135 | a2 = 1/a1;
136 | H = [a1, 0,0; 0,a2,0;0,0,1/a1/a2];
137 | % H = [a1,0,0; 0,1/a1,0;0,0,1];
138 |
139 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
140 | D = kron(inv(H'),Ad_test)-eye(24);
141 |
142 | A_results(i,1) = a1;
143 | A_results(i,2) = a2;
144 | A_results(i,4) = rank(D,tol);
145 | A_norm(i,1) = norm(A_results(i,1:2));
146 | if A_results(i,4)<24
147 | A_results(i,5) = 0;
148 | A_color(i,:) = [0, 0.4470, 0.7410];
149 | elseif A_results(i,4) == 24
150 | % H
151 | A_results(i,5) = 1;
152 | A_color(i,:) = [0.8500, 0.3250, 0.0980];
153 | end
154 | end
155 |
156 | A_figure1 = figure(1);
157 | S = repmat(25,num_A,1);
158 | scatter(A_results(:,1),A_results(:,2),S,A_color(:,:),'filled');
159 |     xlabel("a1")
160 |     ylabel("a2")
161 |
162 | title("Solution existence for A")
163 |
164 | A_figure2 = figure(2);
165 | scatter(A_norm,A_results(:,5),S,A_color,'filled');
166 |
167 |
168 | idx_A = A_results(:,5) == 1;
169 | no_sol_A = A_results(idx_A,:);
170 | end
171 |
172 | %% N
173 | if compute_N
174 | rnd = rnd_scale_N*rand(num_N,3);
175 | tol = 1e-8;
176 | N_results = zeros(num_N,4);
177 | N_color = zeros(num_N,3);
178 | N_norm = zeros(num_N,1);
179 | for i = 1:num_N
180 | n1 = rnd(i,1);
181 | n2 = rnd(i,2);
182 | n3 = rnd(i,3);
183 | H = [1 n1 n2; 0 1 n3; 0 0 1];
184 |
185 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
186 | D = kron(inv(H'),Ad_test)-eye(24);
187 |
188 | N_results(i,1) = n1;
189 | N_results(i,2) = n2;
190 | N_results(i,3) = n3;
191 | N_results(i,4) = rank(D,tol);
192 | N_norm(i,1) = norm(N_results(i,1:3));
193 | if N_results(i,4)<24
194 | N_results(i,5) = 0;
195 | N_color(i,:) = [0, 0.4470, 0.7410];
196 | elseif N_results(i,4) == 24
197 | % H
198 | N_results(i,5) = 1;
199 | N_color(i,:) = [0.8500, 0.3250, 0.0980];
200 | end
201 | end
202 |
203 | N_figure1 = figure(1);
204 | S = repmat(25,num_N,1);
205 | scatter3(N_results(:,1),N_results(:,2),N_results(:,3),S,N_color(:,:),'filled');
206 | xlabel("x")
207 | ylabel("y")
208 | zlabel("z")
209 | title("Solution existence for N")
210 |
211 | N_figure2 = figure(2);
212 | scatter(N_norm,N_results(:,5),S,N_color,'filled');
213 |
214 |
215 | idx_N = N_results(:,5) == 1;
216 | no_sol_N = N_results(idx_N,:);
217 | end
218 |
219 | %% KN
220 | if compute_KN
221 | rnd = rnd_scale_KN*rand(num_KN,3);
222 | rnd2 = rnd_scale_KN*rand(num_KN,3);
223 | tol = 1e-8;
224 | KN_results = zeros(num_KN,4);
225 | KN_color = zeros(num_KN,3);
226 | KN_norm = zeros(num_KN,1);
227 | for i = 1:num_KN
228 | n1 = rnd(i,1);
229 | n2 = rnd(i,2);
230 | n3 = rnd(i,3);
231 | N = [1 n1 n2; 0 1 n3; 0 0 1];
232 | K = expm(hat_so3(rnd2(i,:)));
233 |
234 | H = K*N;
235 |
236 | Ad_test = double(f(H(1,1),H(1,2),H(1,3),H(2,1),H(2,2),H(2,3),H(3,1),H(3,2),H(3,3)));
237 | D = kron(inv(H'),Ad_test)-eye(24);
238 |
239 | KN_results(i,1) = n1;
240 | KN_results(i,2) = n2;
241 | KN_results(i,3) = n3;
242 | KN_results(i,4) = rank(D,tol);
243 | KN_results(i,6) = rnd2(i,1);
244 | KN_results(i,7) = rnd2(i,2);
245 | KN_results(i,8) = rnd2(i,3);
246 | KN_norm(i,1) = norm(KN_results(i,1:3));
247 | if KN_results(i,4)<24
248 | KN_results(i,5) = 0;
249 | KN_color(i,:) = [0, 0.4470, 0.7410];
250 | elseif KN_results(i,4) == 24
251 | % H
252 | KN_results(i,5) = 1;
253 | KN_color(i,:) = [0.8500, 0.3250, 0.0980];
254 | end
255 | end
256 |
257 | KN_figure1 = figure(1);
258 |     S = repmat(25,num_KN,1);
259 | scatter3(KN_results(:,1),KN_results(:,2),KN_results(:,3),S,KN_color(:,:),'filled');
260 | xlabel("x")
261 | ylabel("y")
262 | zlabel("z")
263 | title("Solution existence for KN")
264 |
265 |     KN_figure2 = figure(2);
266 | scatter(KN_norm,KN_results(:,5),S,KN_color,'filled');
267 |
268 |
269 | idx_KN = KN_results(:,5) == 1;
270 | no_sol_KN = KN_results(idx_KN,:);
271 | end
--------------------------------------------------------------------------------
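To double-check the symbolic pipeline numerically, here is a hypothetical NumPy sketch (not repo code) that rebuilds Ad(H) in the same sl(3) basis by least squares and evaluates the rank criterion for one random rotation, mirroring the K branch above. The hat map below is an assumption chosen to match the repo's so(3) vee convention:

import numpy as np
from scipy.linalg import expm

# sl(3) basis E1..E8, copied from the MATLAB scripts
E = [np.array(m, dtype=float) for m in [
    [[1, 0, 0], [0, -1, 0], [0, 0, 0]],
    [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
    [[0, -1, 0], [1, 0, 0], [0, 0, 0]],
    [[1, 0, 0], [0, 1, 0], [0, 0, -2]],
    [[0, 0, 1], [0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [1, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
]]
E_vec = np.stack([Ei.reshape(-1) for Ei in E], axis=1)  # 9 x 8 basis matrix

def adjoint_matrix(H):
    # column i holds the basis coordinates of H @ E_i @ H^{-1} (least squares)
    Hinv = np.linalg.inv(H)
    cols = [np.linalg.lstsq(E_vec, (H @ Ei @ Hinv).reshape(-1), rcond=None)[0]
            for Ei in E]
    return np.stack(cols, axis=1)  # 8 x 8

def hat_so3(w):
    # standard so(3) hat map (assumed consistent with the repo's hat_so3.m)
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])

K = expm(hat_so3(np.random.rand(3)))  # random rotation, as in the K branch
D = np.kron(np.linalg.inv(K.T), adjoint_matrix(K)) - np.eye(24)
print(np.linalg.matrix_rank(D, tol=1e-8))  # < 24 means a nontrivial lift exists
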
/data_gen/gen_sl3_inv_5_input_data.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.linalg import expm
7 |
8 |
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 |
12 | from core.lie_neurons_layers import *
13 |
14 |
15 | def invariant_function(x1, x2, x3, x4, x5):
16 | return torch.sin(torch.trace(x1@x2@x3))+torch.cos(torch.trace(x3@x4@x5))\
17 | -torch.pow(torch.trace(x5@x1), 6)/2.0+torch.det(x3@x2)+torch.exp(torch.trace(x4@x1))\
18 | +torch.trace(x1@x2@x3@x4@x5)
19 |
20 |
21 | if __name__ == "__main__":
22 | data_saved_path = "data/sl3_inv_5_input_data/"
23 | data_name = "sl3_inv_1000_s_05_test"
24 | num_training = 1000
25 | num_testing = 1000
26 | num_conjugate = 500
27 | rnd_scale = 0.5
28 |
29 | train_data = {}
30 | test_data = {}
31 |
32 | hat_layer = HatLayer(algebra_type='sl3')
33 |
34 | x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
35 | .reshape(1, 1, 8, num_training)
36 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
37 | .reshape(1, 1, 8, num_training)
38 | x3 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
39 | .reshape(1, 1, 8, num_training)
40 | x4 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
41 | .reshape(1, 1, 8, num_training)
42 | x5 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_training)))\
43 | .reshape(1, 1, 8, num_training)
44 |
45 | # conjugate transformation
46 | h = torch.Tensor(np.random.uniform(-rnd_scale,
47 | rnd_scale, (num_conjugate, num_training, 8)))
48 | H = torch.linalg.matrix_exp(hat_layer(h))
49 | # conjugate x1
50 | x1_hat = hat_layer(x1.transpose(2, -1))
51 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
52 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c')
53 |
54 | # conjugate x2
55 | x2_hat = hat_layer(x2.transpose(2, -1))
56 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
57 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c')
58 |
59 | # conjugate x3
60 | x3_hat = hat_layer(x3.transpose(2, -1))
61 | conj_x3_hat = torch.matmul(H, torch.matmul(x3_hat, torch.inverse(H)))
62 | conj_x3 = rearrange(vee_sl3(conj_x3_hat), 'b c t l -> b l t c')
63 |
64 | # conjugate x4
65 | x4_hat = hat_layer(x4.transpose(2, -1))
66 | conj_x4_hat = torch.matmul(H, torch.matmul(x4_hat, torch.inverse(H)))
67 | conj_x4 = rearrange(vee_sl3(conj_x4_hat), 'b c t l -> b l t c')
68 |
69 | # conjugate x5
70 | x5_hat = hat_layer(x5.transpose(2, -1))
71 | conj_x5_hat = torch.matmul(H, torch.matmul(x5_hat, torch.inverse(H)))
72 | conj_x5 = rearrange(vee_sl3(conj_x5_hat), 'b c t l -> b l t c')
73 |
74 |
75 |
76 | inv_output = torch.zeros((1, num_training, 1))
77 | # compute invariant function
78 | for n in range(num_training):
79 | inv_output[0, n, 0] = invariant_function(
80 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :],x3_hat[0, 0, n, :, :],\
81 | x4_hat[0, 0, n, :, :],x5_hat[0, 0, n, :, :])
82 |
83 | # print("-------------------------------")
84 | # print(inv_output[0,n,0])
85 | # for i in range(num_conjugate):
86 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :],\
87 | # conj_x3_hat[0, i, n, :, :],conj_x4_hat[0, i, n, :, :],\
88 | # conj_x5_hat[0, i, n, :, :]))
89 |
90 | train_data['x1'] = x1.numpy().reshape(8, num_training)
91 | train_data['x2'] = x2.numpy().reshape(8, num_training)
92 | train_data['x3'] = x3.numpy().reshape(8, num_training)
93 | train_data['x4'] = x4.numpy().reshape(8, num_training)
94 | train_data['x5'] = x5.numpy().reshape(8, num_training)
95 | train_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_training, num_conjugate)
96 | train_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_training, num_conjugate)
97 | train_data['x3_conjugate'] = conj_x3.numpy().reshape(8, num_training, num_conjugate)
98 | train_data['x4_conjugate'] = conj_x4.numpy().reshape(8, num_training, num_conjugate)
99 | train_data['x5_conjugate'] = conj_x5.numpy().reshape(8, num_training, num_conjugate)
100 | train_data['y'] = inv_output.numpy().reshape(1, num_training)
101 |
102 | np.savez(data_saved_path + data_name + "_train_data.npz", **train_data)
103 |
104 |
105 | '''
106 | Generate testing data
107 | '''
108 |     x1 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
109 |         .reshape(1, 1, 8, num_testing)
110 | x2 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
111 | .reshape(1, 1, 8, num_testing)
112 | x3 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
113 | .reshape(1, 1, 8, num_testing)
114 | x4 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
115 | .reshape(1, 1, 8, num_testing)
116 | x5 = torch.Tensor(np.random.uniform(-rnd_scale, rnd_scale, (1, 8, num_testing)))\
117 | .reshape(1, 1, 8, num_testing)
118 |
119 | # conjugate transformation
120 | h = torch.Tensor(np.random.uniform(-rnd_scale,
121 | rnd_scale, (num_conjugate, num_testing, 8)))
122 | H = torch.linalg.matrix_exp(hat_layer(h))
123 | # conjugate x1
124 | x1_hat = hat_layer(x1.transpose(2, -1))
125 | conj_x1_hat = torch.matmul(H, torch.matmul(x1_hat, torch.inverse(H)))
126 | conj_x1 = rearrange(vee_sl3(conj_x1_hat), 'b c t l -> b l t c')
127 |
128 | # conjugate x2
129 | x2_hat = hat_layer(x2.transpose(2, -1))
130 | conj_x2_hat = torch.matmul(H, torch.matmul(x2_hat, torch.inverse(H)))
131 | conj_x2 = rearrange(vee_sl3(conj_x2_hat), 'b c t l -> b l t c')
132 |
133 | # conjugate x3
134 | x3_hat = hat_layer(x3.transpose(2, -1))
135 | conj_x3_hat = torch.matmul(H, torch.matmul(x3_hat, torch.inverse(H)))
136 | conj_x3 = rearrange(vee_sl3(conj_x3_hat), 'b c t l -> b l t c')
137 |
138 | # conjugate x4
139 | x4_hat = hat_layer(x4.transpose(2, -1))
140 | conj_x4_hat = torch.matmul(H, torch.matmul(x4_hat, torch.inverse(H)))
141 | conj_x4 = rearrange(vee_sl3(conj_x4_hat), 'b c t l -> b l t c')
142 |
143 | # conjugate x5
144 | x5_hat = hat_layer(x5.transpose(2, -1))
145 | conj_x5_hat = torch.matmul(H, torch.matmul(x5_hat, torch.inverse(H)))
146 | conj_x5 = rearrange(vee_sl3(conj_x5_hat), 'b c t l -> b l t c')
147 |
148 | inv_output = torch.zeros((1, num_testing, 1))
149 | # compute invariant function
150 | for n in range(num_testing):
151 | inv_output[0, n, 0] = invariant_function(
152 | x1_hat[0, 0, n, :, :], x2_hat[0, 0, n, :, :],x3_hat[0, 0, n, :, :],\
153 | x4_hat[0, 0, n, :, :],x5_hat[0, 0, n, :, :])
154 |
155 | # print("-------------------------------")
156 | # print(inv_output[0,n,0])
157 | # for i in range(num_conjugate):
158 | # print(invariant_function(conj_x1_hat[0, i, n, :, :],conj_x2_hat[0, i, n, :, :],\
159 | # conj_x3_hat[0, i, n, :, :],conj_x4_hat[0, i, n, :, :],\
160 | # conj_x5_hat[0, i, n, :, :]))
161 |
162 | test_data['x1'] = x1.numpy().reshape(8, num_testing)
163 | test_data['x2'] = x2.numpy().reshape(8, num_testing)
164 | test_data['x3'] = x3.numpy().reshape(8, num_testing)
165 | test_data['x4'] = x4.numpy().reshape(8, num_testing)
166 | test_data['x5'] = x5.numpy().reshape(8, num_testing)
167 | test_data['x1_conjugate'] = conj_x1.numpy().reshape(8, num_testing, num_conjugate)
168 | test_data['x2_conjugate'] = conj_x2.numpy().reshape(8, num_testing, num_conjugate)
169 | test_data['x3_conjugate'] = conj_x3.numpy().reshape(8, num_testing, num_conjugate)
170 | test_data['x4_conjugate'] = conj_x4.numpy().reshape(8, num_testing, num_conjugate)
171 | test_data['x5_conjugate'] = conj_x5.numpy().reshape(8, num_testing, num_conjugate)
172 | test_data['y'] = inv_output.numpy().reshape(1, num_testing)
173 |
174 | np.savez(data_saved_path + data_name + "_test_data.npz", **test_data)
175 |
176 | print("Done! Data saved to: \n", data_saved_path +
177 | data_name + "_train_data.npz\n", data_saved_path + data_name + "_test_data.npz")
178 |
--------------------------------------------------------------------------------
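A hypothetical spot-check (not part of the repo) that the saved dataset is self-consistent: the invariant function evaluated on one conjugated 5-tuple should match both the original tuple and the stored label y. It assumes HatLayer accepts a bare 8-vector and that the file below was produced by the script above:

import numpy as np
import torch
from core.lie_neurons_layers import HatLayer
from data_gen.gen_sl3_inv_5_input_data import invariant_function

data = np.load("data/sl3_inv_5_input_data/sl3_inv_1000_s_05_test_test_data.npz")
hat_layer = HatLayer(algebra_type='sl3')
n, c = 0, 0  # sample index, conjugation index

xs = [torch.from_numpy(data[f"x{k}"][:, n]).float() for k in range(1, 6)]
xcs = [torch.from_numpy(data[f"x{k}_conjugate"][:, n, c]).float()
       for k in range(1, 6)]

f_orig = invariant_function(*[hat_layer(x) for x in xs])
f_conj = invariant_function(*[hat_layer(x) for x in xcs])
print(float(f_orig), float(f_conj), float(data["y"][0, n]))  # all approx. equal
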
/script/evaluate_multiple.bash:
--------------------------------------------------------------------------------
1 | proj_dir="/home/justin/code/LieNeurons/"
2 | config_path="config/"
3 | num_experiment=4
4 | GREEN='\033[0;32m'
5 | ORANGE='\033[0;33m'
6 | NC='\033[0m' # No Color
7 | export proj_dir
8 |
9 | # ******************************************************************************
10 | # *******************************Evaluations ***********************************
11 | # ******************************************************************************
12 | # invariant tasks
13 | # mlp
14 | # for iter in $(seq 1 $num_experiment); do
15 | # echo -e "--------------------%% Running invariant task: mlp $iter %%--------------------"
16 | # export iter
17 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_mlp_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml"
18 | # yq e -i '.model_type = "MLP"' $config_path"sl3_inv/testing_param.yaml"
19 | # python experiment/sl3_inv_test.py
20 | # done
21 | # # invariant tasks
22 | # # LN-LR
23 | # for iter in $(seq 1 $num_experiment); do
24 | # echo -e "--------------------%% Running invariant task: LN-LR $iter %%--------------------"
25 | # export iter
26 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_relu_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml"
27 | # yq e -i '.model_type = "LN_relu"' $config_path"sl3_inv/testing_param.yaml"
28 | # python experiment/sl3_inv_test.py
29 | # done
30 | # # invariant tasks
31 | # # LN-LB
32 | # for iter in $(seq 1 $num_experiment); do
33 | # echo -e "--------------------%% Running invariant task: LN-LB $iter %%--------------------"
34 | # export iter
35 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_bracket_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml"
36 | # yq e -i '.model_type = "LN_bracket"' $config_path"sl3_inv/testing_param.yaml"
37 | # python experiment/sl3_inv_test.py
38 | # done
39 |
40 | # # invariant tasks
41 | # # LN-LR+LB
42 | # for iter in $(seq 1 $num_experiment); do
43 | # echo -e "--------------------%% Running invariant task: LN-LR+LB $iter %%--------------------"
44 | # export iter
45 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_rb_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml"
46 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"sl3_inv/testing_param.yaml"
47 | # python experiment/sl3_inv_test.py
48 | # done
49 |
50 | # # mlp augmented
51 | # for iter in $(seq 1 $num_experiment); do
52 | # echo -e "--------------------%% Running invariant task: mlp augmented $iter %%--------------------"
53 | # export iter
54 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_inv_mlp_augmented_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_inv/testing_param.yaml"
55 | # yq e -i '.model_type = "MLP"' $config_path"sl3_inv/testing_param.yaml"
56 | # python experiment/sl3_inv_test.py
57 | # done
58 |
59 | # # ============================================================================================================================================================================================
60 | # # equivariant task
61 | # # mlp
62 | # for iter in $(seq 1 $num_experiment); do
63 | # echo -e "--------------------%% Running equivariant task: mlp $iter %%--------------------"
64 | # export iter
65 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_mlp_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml"
66 | # yq e -i '.model_type = "MLP"' $config_path"sl3_equiv/testing_param.yaml"
67 | # python experiment/sl3_equiv_test.py
68 | # done
69 |
70 | # # equivariant task
71 | # # LN-LR
72 | # for iter in $(seq 1 $num_experiment); do
73 | # echo -e "--------------------%% Running equivariant task: LN-LR $iter %%--------------------"
74 | # export iter
75 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_relu_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml"
76 | # yq e -i '.model_type = "LN_relu"' $config_path"sl3_equiv/testing_param.yaml"
77 | # python experiment/sl3_equiv_test.py
78 | # done
79 |
80 | # # equivariant task
81 | # # LN-LB
82 | # for iter in $(seq 1 $num_experiment); do
83 | # echo -e "--------------------%% Running equivariant task: LN-LB $iter %%--------------------"
84 | # export iter
85 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_bracket_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml"
86 | # yq e -i '.model_type = "LN_bracket"' $config_path"sl3_equiv/testing_param.yaml"
87 | # python experiment/sl3_equiv_test.py
88 | # done
89 |
90 | # # equivariant task
91 | # # LN-LR+LB
92 | # for iter in $(seq 1 $num_experiment); do
93 | # echo -e "--------------------%% Running equivariant task: LN-LR+LB $iter %%--------------------"
94 | # export iter
95 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_rb_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml"
96 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"sl3_equiv/testing_param.yaml"
97 | # python experiment/sl3_equiv_test.py
98 | # done
99 |
100 | # equivariant task
101 | # mlp augmented
102 | # for iter in $(seq 1 $num_experiment); do
103 | # echo -e "--------------------%% Running equivariant task: mlp augmented $iter %%--------------------"
104 | # export iter
105 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_equiv_mlp_augmented_"+env(iter)+"_best_test_loss_acc.pt"' $config_path"sl3_equiv/testing_param.yaml"
106 | # yq e -i '.model_type = "MLP"' $config_path"sl3_equiv/testing_param.yaml"
107 | # python experiment/sl3_equiv_test.py
108 | # done
109 |
110 | # ============================================================================================================================================================================================
111 | # classification task
112 |
113 | # mlp augmented
114 | for iter in $(seq 1 $num_experiment); do
115 | echo -e "--------------------%% Running classification task: mlp augmentation $iter %%--------------------"
116 | export iter
117 | yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_mlp_aug_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml"
118 | yq e -i '.model_type = "MLP"' $config_path"platonic_solid_cls/testing_param.yaml"
119 | python experiment/platonic_solid_cls_test.py
120 | done
121 |
122 |
123 | # mlp
124 | # for iter in $(seq 1 $num_experiment); do
125 | # echo -e "--------------------%% Running classification task: mlp $iter %%--------------------"
126 | # export iter
127 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_mlp_"+env(iter)+"_best_loss_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml"
128 | # yq e -i '.model_type = "MLP"' $config_path"platonic_solid_cls/testing_param.yaml"
129 | # python experiment/platonic_solid_cls_test.py
130 | # done
131 |
132 | # # classification task
133 | # # LN-LR
134 | # for iter in $(seq 1 $num_experiment); do
135 | # echo -e "--------------------%% Running classification task: LN-LR $iter %%--------------------"
136 | # export iter
137 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_relu_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml"
138 | # yq e -i '.model_type = "LN_relu"' $config_path"platonic_solid_cls/testing_param.yaml"
139 | # python experiment/platonic_solid_cls_test.py
140 | # done
141 |
142 | # # classification task
143 | # # LN-LB
144 | # for iter in $(seq 1 $num_experiment); do
145 | # echo -e "--------------------%% Running classification task: LN-LB $iter %%--------------------"
146 | # export iter
147 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_bracket_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml"
148 | # yq e -i '.model_type = "LN_bracket"' $config_path"platonic_solid_cls/testing_param.yaml"
149 | # python experiment/platonic_solid_cls_test.py
150 | # done
151 | # # classification task
152 | # # LN-LR+LB
153 | # for iter in $(seq 1 $num_experiment); do
154 | # echo -e "--------------------%% Running classification task: LN-LR+LB $iter %%--------------------"
155 | # export iter
156 | # yq e -i '.model_path = strenv(proj_dir)+"weights/rebuttal_cls_rb_"+env(iter)+"_best_test_acc.pt"' $config_path"platonic_solid_cls/testing_param.yaml"
157 | # yq e -i '.model_type = "LN_relu_bracket"' $config_path"platonic_solid_cls/testing_param.yaml"
158 | # python experiment/platonic_solid_cls_test.py
159 | # done
--------------------------------------------------------------------------------
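For machines without yq, a hypothetical Python equivalent (sketch only) of the active loop above: rewrite the testing config in place, then launch the evaluation script once per trained model:

import subprocess
import yaml

proj_dir = "/home/justin/code/LieNeurons/"
cfg_path = "config/platonic_solid_cls/testing_param.yaml"
for i in range(1, 5):  # num_experiment = 4
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    cfg['model_path'] = proj_dir + f"weights/rebuttal_cls_mlp_aug_{i}_best_test_acc.pt"
    cfg['model_type'] = "MLP"
    with open(cfg_path, 'w') as f:
        yaml.safe_dump(cfg, f)
    subprocess.run(["python", "experiment/platonic_solid_cls_test.py"], check=True)
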
/experiment/platonic_solid_cls_train.py:
--------------------------------------------------------------------------------
1 | import sys # nopep8
2 | sys.path.append('.') # nopep8
3 |
4 | import argparse
5 | import os
6 | import time
7 | import yaml
8 | import numpy as np
9 | from tqdm import tqdm
10 | import matplotlib.pyplot as plt
11 |
12 | from scipy.spatial.transform import Rotation
13 |
14 | import torch
15 | from torch import nn
16 | import torch.optim as optim
17 | from torch.utils.tensorboard import SummaryWriter
18 | from torch.utils.data import Dataset, DataLoader, IterableDataset
19 |
20 | from core.lie_neurons_layers import *
21 | from experiment.platonic_solid_cls_layers import *
22 | from data_gen.gen_platonic_solids import *
23 |
24 |
25 | def init_writer(config):
26 | writer = SummaryWriter(
27 |         config['log_writer_path']+"_"+time.strftime("%Y%m%d_%H%M%S"), comment=config['model_description'])
28 | # writer.add_text("train_data_path: ", config['train_data_path'])
29 | writer.add_text("model_save_path: ", config['model_save_path'])
30 | writer.add_text("log_writer_path: ", config['log_writer_path'])
31 | writer.add_text("shuffle: ", str(config['shuffle']))
32 | writer.add_text("batch_size: ", str(config['batch_size']))
33 | writer.add_text("init_lr: ", str(config['initial_learning_rate']))
34 | writer.add_text("num_train: ", str(config['num_train']))
35 |
36 | return writer
37 |
38 | def random_sample_rotations(num_rotation, rotation_factor: float = 1.0, device='cpu') -> torch.Tensor:
39 | r = np.zeros((num_rotation, 3, 3))
40 | for n in range(num_rotation):
41 | # angle_z, angle_y, angle_x
42 |         euler = np.random.rand(3) * np.pi * 2 / rotation_factor  # (0, 2 * pi / rotation_factor)
43 | r[n,:,:] = Rotation.from_euler('zyx', euler).as_matrix()
44 | return torch.from_numpy(r).type('torch.FloatTensor').to(device)
45 |
46 | def test(model, test_loader, criterion, config, device):
47 | model.eval()
48 | with torch.no_grad():
49 | loss_sum = 0.0
50 | num_correct = 0
51 | for iter, samples in tqdm(enumerate(test_loader, start=0)):
52 |
53 | x = samples[0].to(device)
54 | y = samples[1].to(device)
55 |
56 | x = rearrange(x,'b n f k -> b f k n')
57 | output = model(x)
58 |
59 | _, prediction = torch.max(output,1)
60 | num_correct += (prediction==y).sum().item()
61 |
62 | loss = criterion(output, y)
63 | loss_sum += loss.item()
64 |
65 |         loss_avg = loss_sum/config['num_test']*config['batch_size']  # mean loss per batch (num_test/batch_size batches)
66 | acc_avg = num_correct/config['num_test']
67 |
68 | return loss_avg, acc_avg
69 |
70 |
71 | def train(model, train_loader, test_loader, config, device='cpu'):
72 |
73 | writer = init_writer(config)
74 |
75 | # create criterion
76 | criterion = nn.CrossEntropyLoss().to(device)
77 | optimizer = optim.Adam(model.parameters(
78 | ), lr=config['initial_learning_rate'], weight_decay=config['weight_decay_rate'])
79 | # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['learning_rate_decay_rate'])
80 | # scheduler = optim.lr_scheduler.LinearLR(optimizer,total_iters=config['num_epochs'])
81 |
82 | if config['resume_training']:
83 | checkpoint = torch.load(config['resume_model_path'])
84 | model.load_state_dict(checkpoint['model_state_dict'])
85 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
86 |         start_iter = checkpoint['iter']  # checkpoints written below store 'iter', not 'epoch'
87 | else:
88 | start_iter = 0
89 |
90 | best_loss = float("inf")
91 | best_acc = 0.0
92 | running_loss = 0.0
93 | loss_sum = 0.0
94 | hat_layer = HatLayer(algebra_type='sl3').to(device)
95 |
96 | for iter, samples in tqdm(enumerate(train_loader, start=0)):
97 |
98 | model.train()
99 | optimizer.zero_grad()
100 |
101 | x = samples[0].to(device)
102 | cls = samples[1].to(device)
103 | if config['train_augmentation']:
104 | rot = random_sample_rotations(1, config['rotation_factor'],device)
105 | x_hat = hat_layer(x)
106 | x_rot_hat = torch.matmul(rot, torch.matmul(x_hat, torch.inverse(rot)))
107 | x = vee_sl3(x_rot_hat)
108 |
109 | x = rearrange(x,'b n f k -> b f k n')
110 |
111 | output = model(x)
112 | # print(output)
113 | loss = criterion(output, cls)
114 | loss.backward()
115 |
116 |         # Gradient accumulation is currently disabled: uncommenting the
117 |         # check below steps only every config['update_every_batch'] iterations to simulate a larger batch size.
118 | # if (i+1) % config['update_every_batch'] == 0:
119 | optimizer.step()
120 | optimizer.zero_grad()
121 |
122 | # cur_training_loss_history.append(loss.item())
123 | running_loss += loss.item()
124 | loss_sum += loss.item()
125 |
126 | # if iter % config['print_freq'] == 0:
127 | # print("iteration %d / %d, loss: %.8f" %
128 | # (iter, config['num_train'], running_loss/config['print_freq']))
129 | # running_loss = 0.0
130 |
131 |
132 | # scheduler.step()
133 |
134 | # train_top1, train_top5, _ = validate(train_loader, model, criterion, config, device)
135 |
136 | train_loss = loss_sum/(iter+1)
137 |
138 | test_loss, test_acc = test(
139 | model, test_loader, criterion, config, device)
140 |
141 | # log down info in tensorboard
142 | writer.add_scalar('training loss', train_loss, iter)
143 | writer.add_scalar('test loss', test_loss, iter)
144 | writer.add_scalar('test acc', test_acc, iter)
145 |
146 | # if we achieve best val loss, save the model
147 | if test_loss < best_loss:
148 | best_loss = test_loss
149 |
150 | state = {'iter': iter,
151 | 'model_state_dict': model.state_dict(),
152 | 'optimizer_state_dict': optimizer.state_dict(),
153 | 'loss': train_loss,
154 | 'test loss': test_loss,
155 | 'test acc': test_acc}
156 |
157 | torch.save(state, config['model_save_path'] +
158 | '_best_test_loss.pt')
159 |
160 | if test_acc > best_acc:
161 | best_acc = test_acc
162 |
163 | state = {'iter': iter,
164 | 'model_state_dict': model.state_dict(),
165 | 'optimizer_state_dict': optimizer.state_dict(),
166 | 'loss': train_loss,
167 | 'test loss': test_loss,
168 | 'test acc': test_acc}
169 |
170 | torch.save(state, config['model_save_path'] +
171 | '_best_test_acc.pt')
172 | print("------------------------------")
173 | # print("Finished epoch %d / %d, training top 1 acc: %.4f, training top 5 acc: %.4f, \
174 | # validation top1 acc: %.4f, validation top 5 acc: %.4f" %\
175 | # (epoch, config['num_epochs'], train_top1, train_top5, val_top1, val_top5))
176 | print("Finished iteration %d / %d, train loss: %.4f test loss: %.4f test acc: %.4f" %
177 | (iter, config['num_train']/config['batch_size'], train_loss, test_loss, test_acc))
178 |
179 | # save model
180 | state = {'iter': iter,
181 | 'model_state_dict': model.state_dict(),
182 | 'optimizer_state_dict': optimizer.state_dict(),
183 | 'loss': train_loss,
184 | 'test loss': test_loss}
185 |
186 | torch.save(state, config['model_save_path']+'_last_iter.pt')
187 |
188 | writer.close()
189 |
190 |
191 | def main():
192 | # torch.autograd.set_detect_anomaly(True)
193 |
194 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
195 | print('Using ', device)
196 |
197 | parser = argparse.ArgumentParser(description='Train the network')
198 | parser.add_argument('--training_config', type=str,
199 | default=os.path.dirname(os.path.abspath(__file__))+'/../config/platonic_solid_cls/training_param.yaml')
200 | args = parser.parse_args()
201 |
202 | # load yaml file
203 | config = yaml.safe_load(open(args.training_config))
204 |
205 | # create dataset and dataloader
206 | training_set = PlatonicDataset(config['num_train'])
207 | train_loader = DataLoader(dataset=training_set, batch_size=config['batch_size'],
208 | shuffle=config['shuffle'])
209 |
210 | test_set = PlatonicDataset(config['num_test'])
211 | test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'],
212 | shuffle=config['shuffle'])
213 |
214 | if config['model_type'] == "LN_relu_bracket":
215 | model = LNReluBracketPlatonicSolidClassifier(3).to(device)
216 | elif config['model_type'] == "LN_relu":
217 | model = LNReluPlatonicSolidClassifier(3).to(device)
218 | elif config['model_type'] == "LN_bracket":
219 | model = LNBracketPlatonicSolidClassifier(3).to(device)
220 | elif config['model_type'] == "MLP":
221 | model = MLP(288).to(device)
222 | elif config['model_type'] == "LN_bracket_no_residual":
223 | model = LNBracketNoResidualConnectPlatonicSolidClassifier(3).to(device)
224 |
225 | train(model, train_loader, test_loader, config, device)
226 |
227 |
228 | if __name__ == "__main__":
229 | main()
230 |
--------------------------------------------------------------------------------
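A hypothetical standalone illustration (not repo code) of the augmentation branch inside train(): every sl(3) feature is conjugated by one random rotation before being fed to the model. The [B, n, F, 8] input layout is assumed from how train() rearranges x:

import torch
from scipy.spatial.transform import Rotation
from core.lie_neurons_layers import HatLayer, vee_sl3

hat_layer = HatLayer(algebra_type='sl3')
x = torch.randn(2, 12, 3, 8)  # [B, n, F, 8] vertex features (assumed layout)
rot = torch.from_numpy(Rotation.random().as_matrix()).float()  # one random rotation
x_aug = vee_sl3(rot @ hat_layer(x) @ torch.inverse(rot))       # conjugate every feature
print(x_aug.shape)  # torch.Size([2, 12, 3, 8]), same as x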