├── avae
│   ├── __init__.py
│   ├── decoders
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── spatial.py
│   │   └── differentiable.py
│   ├── encoders
│   │   ├── __init__.py
│   │   └── base.py
│   ├── settings.py
│   ├── base.py
│   ├── models.py
│   ├── utils_learning.py
│   ├── cyc_annealing.py
│   ├── loss.py
│   └── evaluate.py
├── tests
│   ├── __init__.py
│   ├── testdata_mrc
│   │   ├── __init__.py
│   │   ├── classes.csv
│   │   ├── 1b23_12-78-36.mrcs
│   │   ├── 1b23_18-6-203.mrcs
│   │   ├── 1dkg_71-5-304.mrcs
│   │   ├── 1b23_12-321-215.mrcs
│   │   ├── 1b23_12-78-36-1.mrcs
│   │   ├── 1b23_12-78-36-2.mrcs
│   │   ├── 1b23_12-78-36-3.mrcs
│   │   ├── 1b23_17-12-216.mrcs
│   │   ├── 1b23_17-122-253.mrcs
│   │   ├── 1b23_18-6-203-1.mrcs
│   │   ├── 1b23_18-6-203-2.mrcs
│   │   ├── 1b23_18-6-203-3.mrcs
│   │   ├── 1b23_19-63-199.mrcs
│   │   ├── 1b23_3-143-280.mrcs
│   │   ├── 1dfo_10-119-219.mrcs
│   │   ├── 1dfo_11-314-163.mrcs
│   │   ├── 1dfo_14-53-194.mrcs
│   │   ├── 1dfo_39-131-304.mrcs
│   │   ├── 1dfo_50-101-13.mrcs
│   │   ├── 1dfo_50-188-93.mrcs
│   │   ├── 1dfo_50-98-117.mrcs
│   │   ├── 1dkg_39-193-315.mrcs
│   │   ├── 1dkg_48-309-129.mrcs
│   │   ├── 1dkg_54-251-337.mrcs
│   │   ├── 1dkg_61-295-21.mrcs
│   │   ├── 1dkg_66-55-313.mrcs
│   │   ├── 1dkg_70-345-201.mrcs
│   │   ├── 1e3p_154-148-97.mrcs
│   │   ├── 1e3p_158-30-224.mrcs
│   │   ├── 1e3p_165-45-181.mrcs
│   │   ├── 1e3p_175-257-94.mrcs
│   │   ├── 1e3p_180-68-53.mrcs
│   │   ├── 1e3p_194-99-344.mrcs
│   │   ├── 1e3p_200-14-187.mrcs
│   │   ├── 1b23_12-321-215-1.mrcs
│   │   ├── 1b23_12-321-215-2.mrcs
│   │   ├── 1b23_12-321-215-3.mrcs
│   │   ├── 1b23_17-12-216-1.mrcs
│   │   ├── 1b23_17-12-216-2.mrcs
│   │   ├── 1b23_17-12-216-3.mrcs
│   │   ├── 1b23_17-122-253-1.mrcs
│   │   ├── 1b23_17-122-253-2.mrcs
│   │   ├── 1b23_17-122-253-3.mrcs
│   │   ├── 1b23_19-63-199-1.mrcs
│   │   ├── 1b23_19-63-199-2.mrcs
│   │   ├── 1b23_19-63-199-3.mrcs
│   │   ├── 1b23_3-143-280-1.mrcs
│   │   ├── 1b23_3-143-280-2.mrcs
│   │   ├── 1b23_3-143-280-3.mrcs
│   │   ├── 1dfo_10-119-219-1.mrcs
│   │   ├── 1dfo_10-119-219-2.mrcs
│   │   ├── 1dfo_10-119-219-3.mrcs
│   │   ├── 1dfo_10-119-219-4.mrcs
│   │   ├── 1dkg_48-309-129-1.mrcs
│   │   ├── 1dkg_48-309-129-2.mrcs
│   │   ├── 1dkg_48-309-129-3.mrcs
│   │   ├── 1e3p_154-148-97-1.mrcs
│   │   ├── 1e3p_154-148-97-2.mrcs
│   │   ├── 1e3p_154-148-97-3.mrcs
│   │   └── test
│   │       ├── 1u22_19-28-86.mrcs
│   │       ├── 1u22_2-252-51.mrcs
│   │       ├── 1u22_8-0-98.mrcs
│   │       ├── 1dkg_39-193-315.mrcs
│   │       ├── 1dkg_66-55-313.mrcs
│   │       ├── 1dkg_70-345-201.mrcs
│   │       ├── 1fcj_12-315-152.mrcs
│   │       ├── 1fcj_14-179-232.mrcs
│   │       ├── 1fcj_3-312-192.mrcs
│   │       ├── 1fcj_36-356-203.mrcs
│   │       ├── 1fcj_44-141-202.mrcs
│   │       ├── 1fcj_47-208-183.mrcs
│   │       ├── 1fcj_61-226-58.mrcs
│   │       ├── 1u22_18-34-195.mrcs
│   │       ├── 1u22_20-234-11.mrcs
│   │       ├── 1u22_3-333-222.mrcs
│   │       ├── 1u22_7-151-203.mrcs
│   │       ├── 1b23_12-321-215-1.mrcs
│   │       ├── 1b23_12-321-215-2.mrcs
│   │       └── 1b23_12-321-215-3.mrcs
│   ├── testdata_npy
│   │   ├── __init__.py
│   │   ├── classes.csv
│   │   ├── 2_0_img.npy
│   │   ├── 2_1_img.npy
│   │   ├── 2_2_img.npy
│   │   ├── 2_3_img.npy
│   │   ├── 2_4_img.npy
│   │   ├── 2_5_img.npy
│   │   ├── 2_6_img.npy
│   │   ├── 2_8_img.npy
│   │   ├── 2_9_img.npy
│   │   ├── 5_0_img.npy
│   │   ├── 5_10_img.npy
│   │   ├── 5_1_img.npy
│   │   ├── 5_2_img.npy
│   │   ├── 5_3_img.npy
│   │   ├── 5_4_img.npy
│   │   ├── 5_5_img.npy
│   │   ├── 5_6_img.npy
│   │   ├── 5_7_img.npy
│   │   ├── 5_9_img.npy
│   │   ├── a_0_img.npy
│   │   ├── a_10_img.npy
│   │   ├── a_1_img.npy
│   │   ├── a_2_img.npy
│   │   ├── a_3_img.npy
│   │   ├── a_4_img.npy
│   │   ├── a_5_img.npy
│   │   ├── a_6_img.npy
│   │   ├── a_7_img.npy
│   │   ├── a_8_img.npy
│   │   ├── d_0_img.npy
│   │   ├── d_1_img.npy
│   │   ├── d_2_img.npy
│   │   ├── d_3_img.npy
│   │   ├── d_4_img.npy
│   │   ├── d_5_img.npy
│   │   ├── d_6_img.npy
│   │   ├── d_7_img.npy
│   │   ├── d_8_img.npy
│   │   ├── d_9_img.npy
│   │   ├── e_0_img.npy
│   │   ├── e_10_img.npy
│   │   ├── e_1_img.npy
│   │   ├── e_2_img.npy
│   │   ├── e_3_img.npy
│   │   ├── e_4_img.npy
│   │   ├── e_5_img.npy
│   │   ├── e_6_img.npy
│   │   ├── e_7_img.npy
│   │   ├── e_8_img.npy
│   │   ├── e_9_img.npy
│   │   ├── i_10_img.npy
│   │   ├── i_1_img.npy
│   │   ├── i_2_img.npy
│   │   ├── i_3_img.npy
│   │   ├── i_4_img.npy
│   │   ├── i_5_img.npy
│   │   ├── i_6_img.npy
│   │   ├── i_7_img.npy
│   │   ├── i_8_img.npy
│   │   ├── i_9_img.npy
│   │   ├── j_0_img.npy
│   │   ├── j_10_img.npy
│   │   ├── j_1_img.npy
│   │   ├── j_2_img.npy
│   │   ├── j_3_img.npy
│   │   ├── j_4_img.npy
│   │   ├── j_5_img.npy
│   │   ├── j_6_img.npy
│   │   ├── j_7_img.npy
│   │   ├── j_9_img.npy
│   │   ├── l_0_img.npy
│   │   ├── l_10_img.npy
│   │   ├── l_1_img.npy
│   │   ├── l_2_img.npy
│   │   ├── l_3_img.npy
│   │   ├── l_4_img.npy
│   │   ├── l_5_img.npy
│   │   ├── l_6_img.npy
│   │   ├── l_7_img.npy
│   │   ├── l_8_img.npy
│   │   ├── l_9_img.npy
│   │   ├── s_0_img.npy
│   │   ├── s_10_img.npy
│   │   ├── s_1_img.npy
│   │   ├── s_2_img.npy
│   │   ├── s_3_img.npy
│   │   ├── s_4_img.npy
│   │   ├── s_5_img.npy
│   │   ├── s_6_img.npy
│   │   ├── s_8_img.npy
│   │   ├── s_9_img.npy
│   │   ├── u_0_img.npy
│   │   ├── u_10_img.npy
│   │   ├── u_1_img.npy
│   │   ├── u_2_img.npy
│   │   ├── u_3_img.npy
│   │   ├── u_4_img.npy
│   │   ├── u_5_img.npy
│   │   ├── u_7_img.npy
│   │   ├── u_8_img.npy
│   │   ├── u_9_img.npy
│   │   ├── v_0_img.npy
│   │   ├── v_10_img.npy
│   │   ├── v_1_img.npy
│   │   ├── v_2_img.npy
│   │   ├── v_3_img.npy
│   │   ├── v_5_img.npy
│   │   ├── v_6_img.npy
│   │   ├── v_7_img.npy
│   │   ├── v_8_img.npy
│   │   ├── v_9_img.npy
│   │   ├── x_0_img.npy
│   │   ├── x_10_img.npy
│   │   ├── x_2_img.npy
│   │   ├── x_3_img.npy
│   │   ├── x_4_img.npy
│   │   ├── x_5_img.npy
│   │   ├── x_6_img.npy
│   │   ├── x_7_img.npy
│   │   ├── x_8_img.npy
│   │   ├── x_9_img.npy
│   │   └── test
│   │       ├── 2_10_img.npy
│   │       ├── 2_7_img.npy
│   │       ├── 5_8_img.npy
│   │       ├── a_9_img.npy
│   │       ├── d_10_img.npy
│   │       ├── i_0_img.npy
│   │       ├── j_8_img.npy
│   │       ├── s_7_img.npy
│   │       ├── u_6_img.npy
│   │       ├── v_4_img.npy
│   │       └── x_1_img.npy
│   ├── test_utils.py
│   ├── test_loss.py
│   ├── test_data.py
│   ├── test_config.py
│   ├── test_vis.py
│   └── test_train_eval_pipeline.py
├── tools
│   ├── __init__.py
│   ├── gamma_array.npy
│   ├── slurm_run.sh
│   ├── augment_mrcs.py
│   └── create_subtomo.py
├── configs
│   ├── __init__.py
│   ├── avae-create_subtomo_test_config.yml
│   └── avae-test-config.yml
├── tutorials
│   ├── mnist_data
│   │   ├── classes_mnist.csv
│   │   ├── affinity_mnist.csv
│   │   └── mnist_config.yml
│   ├── README.md
│   └── mnist_saver.py
├── images
│   ├── Napari1.png
│   └── Napari2.png
├── requirements.txt
├── .github
│   └── workflows
│       ├── tests.yml
│       └── checks.yml
├── LICENSE.md
├── .pre-commit-config.yaml
├── .gitignore
├── scripts
│   ├── README.md
│   ├── run_napari_model_view.py
│   └── run_create_subtomo.py
└── pyproject.toml

/avae/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/configs/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/avae/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/avae/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/tests/testdata_mrc/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/testdata_npy/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/tests/testdata_npy/classes.csv:
--------------------------------------------------------------------------------
1 | a,e,i,j,s,u,v,d,2
2 | 
--------------------------------------------------------------------------------

/tests/testdata_mrc/classes.csv:
--------------------------------------------------------------------------------
1 | 1b23,1dfo,1dkg,1e3p
2 | 
--------------------------------------------------------------------------------

/tutorials/mnist_data/classes_mnist.csv:
--------------------------------------------------------------------------------
1 | 0,1,2,3,4,5,6,7,8
2 | 
--------------------------------------------------------------------------------

[Binary files omitted. For every binary asset listed in the tree above (images/*.png, tools/gamma_array.npy, and the .npy and .mrcs test fixtures under tests/), the dump replaces the file contents with a download link of the form https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/<path>.]
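
The test fixtures follow a naming convention that is visible in the tree: tests/testdata_npy holds <class>_<index>_img.npy arrays whose filename prefix is the class label, classes.csv is a single comma-separated row of labels to include (note the 5, l, and x prefixes present on disk are not listed), and the test/ subdirectory holds held-out evaluation images. Below is a minimal sketch of reading such a directory under those assumptions; it uses only NumPy, and load_split is a hypothetical helper written for illustration here, not part of the avae package.

# Minimal sketch: group the *_img.npy fixtures by their filename class prefix.
# Assumes the <class>_<index>_img.npy convention seen in tests/testdata_npy;
# load_split() is a hypothetical helper, not avae API.
from pathlib import Path

import numpy as np


def load_split(data_dir: str) -> dict[str, list[np.ndarray]]:
    """Load all *_img.npy arrays in data_dir, keyed by class label."""
    root = Path(data_dir)
    # classes.csv is one comma-separated row of class labels.
    classes = (root / "classes.csv").read_text().strip().split(",")
    images: dict[str, list[np.ndarray]] = {c: [] for c in classes}
    for path in sorted(root.glob("*_img.npy")):
        label = path.name.split("_", 1)[0]  # e.g. "a" from "a_0_img.npy"
        if label in images:  # classes.csv acts as a whitelist of labels
            images[label].append(np.load(path))
    return images


train_images = load_split("tests/testdata_npy")
print({label: len(arrs) for label, arrs in train_images.items()})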
-------------------------------------------------------------------------------- /tests/testdata_mrc/test/1fcj_3-312-192.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1fcj_3-312-192.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1fcj_36-356-203.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1fcj_36-356-203.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1fcj_44-141-202.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1fcj_44-141-202.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1fcj_47-208-183.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1fcj_47-208-183.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1fcj_61-226-58.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1fcj_61-226-58.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1u22_18-34-195.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1u22_18-34-195.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1u22_20-234-11.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1u22_20-234-11.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1u22_3-333-222.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1u22_3-333-222.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1u22_7-151-203.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1u22_7-151-203.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1b23_12-321-215-1.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1b23_12-321-215-1.mrcs -------------------------------------------------------------------------------- /tests/testdata_mrc/test/1b23_12-321-215-2.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1b23_12-321-215-2.mrcs -------------------------------------------------------------------------------- 
/tests/testdata_mrc/test/1b23_12-321-215-3.mrcs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccpem/affinity-vae/HEAD/tests/testdata_mrc/test/1b23_12-321-215-3.mrcs -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair 2 | click 3 | matplotlib 4 | mrcfile 5 | numpy 6 | pandas 7 | pillow 8 | pyyaml 9 | requests 10 | scikit-image 11 | scikit-learn 12 | scipy 13 | tensorboard 14 | torch 15 | torchvision 16 | umap-learn 17 | -------------------------------------------------------------------------------- /avae/encoders/base.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | # Abstract Encoder 9 | class AbstractEncoder(nn.Module, ABC): 10 | @abstractmethod 11 | def forward(self, x): 12 | pass 13 | -------------------------------------------------------------------------------- /avae/decoders/base.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | # Abstract Decoder 9 | class AbstractDecoder(nn.Module, ABC): 10 | @abstractmethod 11 | def forward(self, x: torch.Tensor, x_pose: torch.Tensor) -> torch.Tensor: 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from avae.utils_learning import dims_after_pooling 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "n_pools, expected", [(0, 64), (1, 32), (2, 16), (3, 8)] 8 | ) 9 | def test_dims_after_pooling_ndim(n_pools, expected): 10 | """Test dimension calculation after 2x2 pooling op.""" 11 | start = 64 12 | after_pool = dims_after_pooling(start, n_pools) 13 | assert after_pool == expected 14 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | build: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | - uses: actions/setup-python@v5 10 | with: 11 | python-version: '3.10.4' 12 | - name: Install dependencies 13 | run: | 14 | python -m pip install --upgrade pip 15 | - name: Test with pytest 16 | run: | 17 | python -m pip install -e ".[test]" 18 | python -m pytest -s -W ignore 19 | -------------------------------------------------------------------------------- /avae/settings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | 4 | ## visualisation configurations 5 | 6 | VIS_LOS = False 7 | VIS_ACC = False 8 | VIS_REC = False 9 | VIS_EMB = False 10 | VIS_INT = False 11 | VIS_DIS = False 12 | VIS_POS = False 13 | VIS_CYC = False 14 | VIS_AFF = False 15 | VIS_HIS = False 16 | VIS_SIM = False 17 | VIS_DYN = False 18 | VIS_POSE_CLASS = None 19 | VIS_Z_N_INT = None 20 | VIS_FORMAT = "png" 21 | 22 | FREQ_ACC = 0 23 | FREQ_REC = 0 24 | FREQ_EMB = 0 25 | FREQ_INT = 0 26 | FREQ_DIS = 0 27 | FREQ_POS = 0 28 | FREQ_SIM = 0 29 | FREQ_EVAL = 0 30 | FREQ_STA = 0 31 | 32 | # logging settings 33 | 34 | 
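# Note: the VIS_* switches and FREQ_* intervals above are module-level globals that are presumably read by the training and visualisation code at run time (a frequency of 0 apparently meaning the corresponding output is never produced).
# The timestamp below is used to tag the output files of the current run.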
35 | date_time_run = datetime.now().strftime("%H_%M_%d_%m_%Y") 36 | logging.basicConfig( 37 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 38 | ) 39 | -------------------------------------------------------------------------------- /tutorials/mnist_data/affinity_mnist.csv: -------------------------------------------------------------------------------- 1 | 0,1,2,3,4,5,6,7,8,9 2 | 1.00E+00,-5.00E-01,-2.25E-01,-3.50E-01,-5.38E-02,-4.37E-03,-1.97E-01,1.29E-01,-1.46E-01,-3.96E-02 3 | -5.00E-01,1.00E+00,-5.00E-01,-1.00E-01,-1.00E-01,-1.00E-01,-1.00E-01,0.00E+00,-1.00E-01,-1.00E-01 4 | -2.25E-01,-5.00E-01,1.00E+00,-4.43E-01,-2.07E-01,1.50E-01,-1.36E-01,-2.13E-01,-2.09E-01,-2.48E-01 5 | -3.50E-01,-1.00E-01,-4.43E-01,1.00E+00,-5.00E-01,8.14E-02,-2.42E-03,1.35E-01,6.53E-01,1.31E-01 6 | -5.38E-02,-1.00E-01,-2.07E-01,-5.00E-01,1.00E+00,-5.00E-01,-2.00E-01,2.66E-01,0.00E+00,5.48E-01 7 | -4.37E-03,-1.00E-01,1.50E-01,8.14E-02,-5.00E-01,1.00E+00,-5.00E-01,-2.00E-01,5.00E-02,1.00E-01 8 | -1.97E-01,-1.00E-01,-1.36E-01,-2.42E-03,-2.00E-01,-5.00E-01,1.00E+00,-5.00E-01,3.30E-01,8.00E-01 9 | 1.29E-01,0.00E+00,-2.13E-01,1.35E-01,2.66E-01,-2.00E-01,-5.00E-01,1.00E+00,-5.00E-01,2.11E-01 10 | -1.46E-01,-1.00E-01,-2.09E-01,6.53E-01,0.00E+00,5.00E-02,3.30E-01,-5.00E-01,1.00E+00,2.00E-01 11 | -3.96E-02,-1.00E-01,-2.48E-01,1.31E-01,5.48E-01,1.00E-01,8.00E-01,2.11E-01,2.00E-01,1.00E+00 12 | -------------------------------------------------------------------------------- /tools/slurm_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH --account=vjgo8416-ms-img-pc # this is the account that will be charged for the job, change to your own account 3 | #SBATCH --qos=turing # use the turing partition, change to your preferred partition 4 | #SBATCH --gres=gpu:1 # use 1 GPU 5 | #SBATCH --time 8:00:00 # maximum walltime 6 | #SBATCH --cpus-per-gpu=36 # use 36 CPU cores per GPU 7 | #SBATCH --ntasks-per-node=1 # use 1 task per node 8 | #SBATCH --mem=0 # use all available memory on the node 9 | #SBATCH --nodes=1 # use 1 node 10 | 11 | # Load baskerville environment, created as described in main README 12 | source /path/to/environment/pyenv_affinity/bin/activate 13 | 14 | # Load the required modules 15 | module purge 16 | module load baskerville 17 | module load bask-apps/live 18 | 19 | module load NCCL/2.12.12-GCCcore-11.3.0-CUDA-11.7.0 20 | module load PyTorch/2.0.1-foss-2022a-CUDA-11.7.0 21 | module load torchvision/0.15.2-foss-2022a-CUDA-11.7.0 22 | 23 | 24 | # debugging flags (optional) 25 | export NCCL_DEBUG=INFO 26 | export PYTHONFAULTHANDLER=1 27 | 28 | 29 | srun python /path/to/affinity-vae/run.py --config_file path/to/avae-config_file --new_out 30 | -------------------------------------------------------------------------------- /configs/avae-create_subtomo_test_config.yml: -------------------------------------------------------------------------------- 1 | ################################## 2 | # This input file contains the parameters required for running 3 | # run_create_subtomo.py 4 | # This run script can preprocess a full tomogram by applying various filters, such as: 5 | # normalisation 6 | # adding gaussian noise for benchmarking purposes 7 | # applying a bandpass filter given low and high frequency thresholds 8 | # creating augmentations of the voxels, rotated between a minimum and maximum input angle 9 | # applying padding 10 | # Not implemented yet: crop after rotation 11 | # Not implemented yet:
gaussian blur 12 | 13 | #### Data parameters 14 | input_path: /Users/mfamili/work/test_affinity_merge/test_create_subtomo_avae 15 | annot_path: /Users/mfamili/work/test_affinity_merge/test_create_subtomo_avae 16 | output_path: /Users/mfamili/work/test_affinity_merge/test_create_subtomo_avae/subtomos 17 | 18 | #### 19 | datatype: mrc 20 | vox_size: [32, 32, 32] 21 | 22 | ### Filters 23 | bandpass: False 24 | low_freq: 0 25 | high_freq: 15 26 | gaussian_blur: True 27 | normalise: True 28 | add_noise: True 29 | noise_int: 3 30 | 31 | #### augmentation 32 | augment: 3 33 | aug_th_min: -20 34 | aug_th_max: 20 -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | name: checks 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: checkout source 10 | uses: actions/checkout@v3 11 | 12 | - name: set up python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.10.4' 16 | 17 | - name: set PY 18 | run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV 19 | 20 | - name: cache python environment and pre-commit hooks 21 | uses: actions/cache@v2 22 | with: 23 | path: | 24 | ${{ env.pythonLocation }} 25 | ~/.cache/pre-commit 26 | key: | 27 | pre-commit-${{ env.PY }}-${{ hashFiles('.pre-commit-config.yaml') }} 28 | - name: install dependencies 29 | run: pip install pre-commit 30 | 31 | - name: remove executable bit from .pre-commit-config.yaml 32 | run: chmod -x .pre-commit-config.yaml 33 | 34 | - name: Install pre-commit hooks 35 | run: pre-commit install 36 | 37 | # This will run on all files in the repo, not just those that have been 38 | # committed. Once formatting has been applied once globally, from then on 39 | # the files being changed by pre-commit should be just those that are 40 | # being committed - provided that people are using the pre-commit hook to 41 | # format their code. 42 | - name: run pre-commit 43 | run: pre-commit run --all-files --color always 44 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Alan Turing Institute All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors 16 | may be used to endorse or promote products derived from this software without 17 | specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 23 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 26 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | 5 | repos: 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.12.0 8 | hooks: 9 | - id: isort 10 | args: ["--profile", "black", "--filter-files"] 11 | 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: "v4.4.0" 14 | hooks: 15 | - id: check-docstring-first 16 | - id: check-executables-have-shebangs 17 | - id: check-merge-conflict 18 | - id: check-toml 19 | - id: end-of-file-fixer 20 | - id: mixed-line-ending 21 | args: [--fix=lf] 22 | - id: requirements-txt-fixer 23 | - id: trailing-whitespace 24 | args: [--markdown-linebreak-ext=md] 25 | - id: check-case-conflict 26 | - id: check-merge-conflict 27 | - id: check-symlinks 28 | - id: check-yaml 29 | - id: debug-statements 30 | - id: end-of-file-fixer 31 | - id: name-tests-test 32 | args: ["--pytest-test-first"] 33 | - id: requirements-txt-fixer 34 | 35 | - repo: https://github.com/psf/black 36 | rev: 22.3.0 37 | hooks: 38 | - id: black 39 | #- repo: https://github.com/pre-commit/mirrors-prettier 40 | #rev: "v3.0.3" 41 | # hooks: 42 | # - id: prettier 43 | # types_or: [yaml, markdown, html, css, scss, javascript, json] 44 | # args: [--prose-wrap=always] 45 | 46 | - repo: https://github.com/pre-commit/mirrors-mypy 47 | rev: 'v1.7.1' # Use the sha / tag you want to point at 48 | hooks: 49 | - id: mypy 50 | language_version: python3.10 51 | args: [--ignore-missing-imports] 52 | additional_dependencies: ['types-PyYAML'] 53 | -------------------------------------------------------------------------------- /avae/base.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import logging 3 | 4 | import torch 5 | 6 | from avae.decoders.base import AbstractDecoder 7 | from avae.encoders.base import AbstractEncoder 8 | 9 | 10 | class SpatialDims(enum.IntEnum): 11 | TWO = 2 12 | THREE = 3 13 | 14 | 15 | # Abstract AffinityVAE 16 | class AbstractAffinityVAE(torch.nn.Module): 17 | def __init__( 18 | self, encoder: AbstractEncoder, decoder: AbstractDecoder, **kwargs 19 | ) -> None: 20 | super().__init__() 21 | self.encoder = encoder 22 | self.decoder = decoder 23 | 24 | def forward( 25 | self, x: torch.Tensor 26 | ) -> tuple[ 27 | torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor 28 | ]: 29 | mu, log_var, pose = self.encoder(x) 30 | z = self.reparameterize(mu, log_var) 31 | x = self.decoder(z, pose) 32 | return x, mu, log_var, z, pose 33 | 34 | def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor): 35 | raise NotImplementedError( 36 | "Reparameterize method must be implemented in child class." 
37 | ) 38 | 39 | 40 | def set_layer_dim( 41 | ndim: SpatialDims | int, 42 | ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]: 43 | if ndim == SpatialDims.TWO: 44 | return torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d 45 | elif ndim == SpatialDims.THREE: 46 | return torch.nn.Conv3d, torch.nn.ConvTranspose3d, torch.nn.BatchNorm3d 47 | else: 48 | logging.error("Data must be 2D or 3D.") 49 | exit(1) 50 | 51 | 52 | def dims_after_pooling(start: int, n_pools: int) -> int: 53 | """Calculate the size of a layer after n pooling ops. 54 | 55 | Parameters 56 | ---------- 57 | start: int 58 | The size of the layer before pooling. 59 | n_pools: int 60 | The number of pooling operations. 61 | 62 | Returns 63 | ------- 64 | int 65 | The size of the layer after pooling. 66 | 67 | 68 | """ 69 | return start // (2**n_pools) 70 | -------------------------------------------------------------------------------- /tutorials/mnist_data/mnist_config.yml: -------------------------------------------------------------------------------- 1 | #### Data parameters 2 | 3 | datapath: mnist_data/images_train 4 | affinity: mnist_data/affinity_mnist.csv 5 | classes: mnist_data/classes_mnist.csv 6 | restart: False 7 | limit: 1000 8 | split: 20 9 | datatype: npy 10 | 11 | #### Data processing parameters 12 | shift_mean: False 13 | gaussian_blur: False 14 | normalise: False 15 | rescale: False 16 | 17 | #### Network parameters 18 | 19 | depth: 4 20 | channels: 8 21 | latent_dims: 16 22 | pose_dims: 1 23 | 24 | #### Training parameters 25 | 26 | no_val_drop: true 27 | epochs: 1000 28 | batch: 128 29 | learning: 0.0001 30 | gpu: true 31 | freq_sta: 10 # frequency of state saves 32 | 33 | #### Loss parameters 34 | 35 | beta: 0.0001 36 | gamma: 0.001 37 | loss_fn: MSE 38 | 39 | #### Cyclical annealing parameters for beta 40 | beta_min: 0 41 | beta_cycle: 5 42 | beta_ratio: 0.5 43 | cyc_method_beta: flat 44 | 45 | #### Cyclical annealing parameters for gamma 46 | gamma_min: 0 47 | gamma_cycle: 5 48 | gamma_ratio: 0.5 49 | cyc_method_gamma: flat 50 | 51 | #### Evaluation parameters 52 | eval: false 53 | freq_eval: 10 # frequency of test evaluation (if test present) 54 | 55 | #### Visualization parameters 56 | vis_los: false # loss on/off (every epoch from epoch 2) 57 | vis_acc: false # confusion matrices and F1 scores on/off (frequency controlled) 58 | vis_rec: false # reconstructions on/off (frequency controlled) 59 | vis_emb: false # latent embeddings on/off (frequency controlled) 60 | vis_int: false # latent interpolations on/off (frequency controlled) 61 | vis_dis: false # latent disentanglement on/off (frequency controlled) 62 | vis_pos: false # pose disentanglement on/off (frequency controlled) 63 | vis_pose_class: 1,2 # comma separated string eg. 
1,2 64 | vis_cyc: false # beta for cyclic annealing on/off (once per run) 65 | vis_aff: false # affinity matrix on/off (once per run) 66 | vis_his: false # class distribution histogram on/off (once per run) 67 | vis_sim: false # latent space similarity matrix on/off 68 | vis_all: true # sets all above to on/off 69 | dynamic: true # dynamic embedding visualisations on/off (freq same as freq_emb) 70 | 71 | freq_acc: 10 # frequency of accuracy visualisation 72 | freq_rec: 10 # frequency of reconstruction visualisation 73 | freq_emb: 10 # frequency of latent embedding visualisation 74 | freq_int: 10 # frequency of latent interpolation visualisation 75 | freq_dis: 10 # frequency of latent disentanglement visualisation 76 | freq_pos: 10 # frequency of pose disentanglement visualisation 77 | freq_sim: 10 # frequency of similarity matrix visualisation 78 | freq_all: 10 # sets all above to the same value (inc freq_sta and freq_eval) 79 | 80 | #### debug parameters 81 | debug: false 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # not change default config 132 | configs/avae-test-config.yml 133 | 134 | # remove mac specific files 135 | .DS_Store 136 | 137 | # remove pycharm specific files 138 | .idea/ 139 | affinity-vae.code-workspace 140 | .vscode/launch.json 141 | .vscode/settings.json 142 | .gitignore 143 | -------------------------------------------------------------------------------- /configs/avae-test-config.yml: -------------------------------------------------------------------------------- 1 | #### Data parameters 2 | 3 | datapath: tests/testdata_npy 4 | affinity: tests/testdata_npy/affinity_an.csv 5 | classes: tests/testdata_npy/classes.csv 6 | restart: False 7 | limit: 1000 8 | split: 20 9 | datatype: npy 10 | 11 | #### Data processing parameters 12 | shift_mean: False 13 | gaussian_blur: False 14 | normalise: False 15 | rescale: False 16 | 17 | #### Network parameters 18 | 19 | depth: 3 20 | channels: 64 21 | filters: [8, 16, 32, 64] 22 | latent_dims: 8 23 | pose_dims: 1 24 | bnorm_encoder: true 25 | bnorm_decoder: true 26 | 27 | model: a 28 | 29 | #### Training parameters 30 | 31 | no_val_drop: true 32 | epochs: 1000 33 | batch: 128 34 | learning: 0.001 35 | gpu: true 36 | freq_sta: 10 # frequency of state saves 37 | 38 | #### Loss parameters 39 | 40 | beta: 1 41 | gamma: 2 42 | loss_fn: MSE 43 | klreduction: mean 44 | 45 | #### Cyclical annealing parameters for beta 46 | beta_min: 0 47 | beta_cycle: 5 48 | beta_ratio: 0.5 49 | cyc_method_beta : cycle_sigmoid 50 | 51 | #### Cyclical annealing parameters for gamma 52 | gamma_min: 0 53 | gamma_cycle: 5 54 | gamma_ratio: 0.5 55 | cyc_method_gamma : cycle_sigmoid 56 | 57 | #### Evaluation parameters 58 | eval: false 59 | freq_eval: 10 # frequency of test evaluation (if test present) 60 | 61 | #### Visualization parameters 62 | vis_los: false # loss on/off (every epoch from epoch 2) 63 | vis_acc: false # confusion matrices and F1 scores on/off (frequency controlled) 64 | vis_rec: false # reconstructions on/off (frequency controlled) 65 | vis_emb: false # latent embeddings on/off (frequency controlled) 66 | vis_int: false # latent interpolations on/off (frequency controlled) 67 | vis_dis: false # latent disentanglement on/off (frequency controlled) 68 | vis_pos: false # pose disentanglement on/off (frequency controlled) 69 | vis_pose_class: 1,2 # comma separated string eg. 
1,2 70 | vis_cyc: false # beta for cyclic annealing on/off (once per run) 71 | vis_aff: false # affinity matrix on/off (once per run) 72 | vis_his: false # class distribution histogram on/off (once per run) 73 | vis_sim: false # latent space similarity matrix on/off 74 | vis_all: true # sets all above to on/off 75 | dynamic: true # dynamic embedding visualisations on/off (freq same as freq_emb) 76 | 77 | freq_acc: 10 # frequency of accuracy visualisation 78 | freq_rec: 10 # frequency of reconstruction visualisation 79 | freq_emb: 10 # frequency of latent embedding visualisation 80 | freq_int: 10 # frequency of latent interpolation visualisation 81 | freq_dis: 10 # frequency of latent disentanglement visualisation 82 | freq_pos: 10 # frequency of pose disentanglement visualisation 83 | freq_sim: 10 # frequency of similarity matrix visualisation 84 | freq_all: 10 # sets all above to the same value (inc freq_sta and freq_eval) 85 | 86 | #### debug parameters 87 | debug: false 88 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | from avae.decoders.decoders import Decoder 9 | from avae.encoders.encoders import Encoder 10 | from avae.loss import AVAELoss 11 | from avae.models import AffinityVAE as avae 12 | from avae.utils_learning import set_device 13 | from tests import testdata_mrc 14 | 15 | torch.manual_seed(0) 16 | 17 | 18 | class LossTest(unittest.TestCase): 19 | def setUp(self) -> None: 20 | """Set up the affinity lookup, the loss and a 3D VAE for testing.""" 21 | 22 | self._orig_dir = os.getcwd() 23 | self.test_data = os.path.dirname(testdata_mrc.__file__) 24 | os.chdir(self.test_data) 25 | 26 | self.affinity = pd.read_csv("affinity_fsc_10.csv").to_numpy( 27 | dtype=np.float32 28 | ) 29 | device = set_device(False) 30 | self.loss = AVAELoss( 31 | device=device, 32 | beta=[1], 33 | gamma=[1], 34 | lookup_aff=self.affinity, 35 | recon_fn="MSE", 36 | ) 37 | self.encoder_3d = Encoder( 38 | capacity=8, 39 | depth=4, 40 | input_shape=(64, 64, 64), 41 | latent_dims=16, 42 | pose_dims=3, 43 | bnorm=True, 44 | ) 45 | 46 | self.decoder_3d = Decoder( 47 | capacity=8, 48 | depth=4, 49 | input_shape=(64, 64, 64), 50 | latent_dims=16, 51 | pose_dims=3, 52 | ) 53 | 54 | self.vae = avae(self.encoder_3d, self.decoder_3d) 55 | 56 | def tearDown(self): 57 | os.chdir(self._orig_dir) 58 | 59 | def test_loss_instantiation(self): 60 | """Test instantiation of the loss.""" 61 | 62 | assert isinstance(self.loss, AVAELoss) 63 | 64 | def test_loss(self): 65 | 66 | x = torch.randn(14, 1, 64, 64, 64) 67 | 68 | x_hat, lat_mu, lat_logvar, lat, lat_pose = self.vae(x) 69 | total_loss, recon_loss, kldivergence, affin_loss = self.loss( 70 | x, 71 | x_hat, 72 | lat_mu, 73 | lat_logvar, 74 | 0, 75 | batch_aff=torch.ones(14, dtype=torch.int), 76 | ) 77 | 78 | self.assertGreaterEqual(total_loss.detach().numpy().item(0), 1.1171) 79 | self.assertGreater(recon_loss.detach().numpy().item(0), 1) 80 | self.assertGreater(recon_loss, kldivergence) 81 | self.assertGreater(total_loss, affin_loss) 82 | 83 | def test_loss_bvae(self): 84 | 85 | x = torch.randn(14, 1, 64, 64, 64) 86 | 87 | self.loss = AVAELoss( 88 | torch.device("cpu"), 89 | [1], 90 | [0], 91 | lookup_aff=self.affinity, 92 | recon_fn="MSE", 93 | ) 94 | 95 | x_hat, lat_mu, lat_logvar, _, _ = self.vae(x) 96 | _, _, _, affin_loss = self.loss( 97 | x, 98 | x_hat, 99 |
lat_mu, 100 | lat_logvar, 101 | 0, 102 | batch_aff=torch.ones(14, dtype=torch.int), 103 | ) 104 | self.assertEqual(affin_loss, 0) 105 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Useful scripts for AffinityVAE 2 | 3 | ## Run Napari plugin on AffinityVAE trained models 4 | 5 | We can inspect the latent space of the AffinityVAE models and the trained 6 | decoder using a Napari plugin. To do so, we need to install Napari and an older 7 | version of pydantic (less than 2.0.0) to avoid conflicts with the napari 8 | library. 9 | 10 | For this you need to be at the root of the repository and run the following 11 | commands: 12 | 13 | ```bash 14 | python -m venv env_napari 15 | source env_napari/bin/activate 16 | python -m pip install --upgrade pip 17 | python -m pip install -e ."[napari]" 18 | ``` 19 | 20 | > **Important note**: This installation should only be used for running the 21 | > Napari plugin. If you want to train/run the AffinityVAE model, you'll need to 22 | > reinstall the package using the "all" option (e.g. ```python -m pip install -e ."[all]"```). The reason for 23 | > this is that the napari library requires an older version of pydantic (1.10.0) 24 | > and the AffinityVAE model requires pydantic > 2.0. 25 | 26 | After setting up the environment, you can use the `--help` flag in the command 27 | line to see the different options: 28 | 29 | ```bash 30 | 31 | python scripts/run_napari_model_view.py --help 32 | 33 | Output: 34 | 35 | usage: run_napari_model_view.py [-h] --model_file MODEL_FILE --meta_file META_FILE [--manifold MANIFOLD] [--pose_dims POSE_DIMS] [--latent_dims LATENT_DIMS] 36 | 37 | options: 38 | -h, --help show this help message and exit 39 | --model_file MODEL_FILE 40 | Path to model state file. 41 | --meta_file META_FILE 42 | Path to the meta file. 43 | --manifold MANIFOLD Manifold to use for latent space. This can be either `umap` or `load`. 44 | --pose_dims POSE_DIMS 45 | Number of pose dimensions. This will overwrite the internal model value. 46 | --latent_dims LATENT_DIMS 47 | Number of latent dimensions. This will overwrite the internal model value. 48 | ``` 49 | 50 | For minimal usage, you can run the Napari plugin just providing the model state 51 | and meta files from a run: 52 | 53 | ```bash 54 | python scripts/run_napari_model_view.py --model_file <model_file> --meta_file <meta_file> 55 | ``` 56 | 57 | In this case the pose and latent dimensions will be read from the model file and 58 | the manifold will be created using the `umap` library. 59 | 60 | **Using "umap" and "load" manifold options**: 61 | 62 | Using the `umap` option can be slow, as the manifold is created on the fly and 63 | inverted every time you click on a point on the embedding map. If you want to use 64 | the precomputed manifold from the AffinityVAE run, you can use the `load` option 65 | (by adding the flag `--manifold "load"` to the command line). This will use the 66 | embedding variables from the meta file to create the manifold, and it will map 67 | back to the latent space by finding the closest point in the data. 68 | This is much faster than using the `umap` option and is a good option for quick 69 | debugging/exploring. However, this option will not allow you to explore unseen 70 | regions of the latent space, only the available data.
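For instance, a quick exploratory session on the artefacts of a finished run might look like the following (a sketch only: the `states/` file names here are illustrative placeholders, not fixed output names of the code):

```bash
python scripts/run_napari_model_view.py \
    --model_file states/avae_model.pt \
    --meta_file states/meta.pkl \
    --manifold load
```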
71 | 72 | **Example of usage** 73 | 74 | Here you can see an example of the interface for the Napari plugin running on a 75 | model trained on the SHREC protein dataset. You can interact with the interface by 76 | clicking on the different points on the embedding map and seeing the corresponding 77 | reconstruction and latent space values. You can also scroll through the 78 | different poses and the different latent dimensions and look at the resulting 79 | reconstructions using the available sliders at the top of the interface. 80 | ![NapariUMAP](../images/Napari1.png) ![NapariLatents](../images/Napari2.png) 81 | -------------------------------------------------------------------------------- /scripts/run_napari_model_view.py: -------------------------------------------------------------------------------- 1 | import napari 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | from avae.napari import GenerativeAffinityVAEWidget 7 | 8 | 9 | def setup_napari(): 10 | """Set up the napari viewer.""" 11 | BOX_SIZE = 32 12 | viewer = napari.Viewer() 13 | viewer.dims.ndisplay = 2 14 | viewer.axes.visible = True 15 | viewer.add_image( 16 | np.zeros((BOX_SIZE, BOX_SIZE), dtype=np.float32), 17 | colormap="inferno", 18 | rendering="iso", 19 | name="Reconstruction", 20 | depiction="volume", 21 | ) 22 | return viewer 23 | 24 | 25 | def load_model(model_fn, meta_fn, device="cpu"): 26 | """Load the model and meta data.""" 27 | print("Loading model") 28 | checkpoint = torch.load(model_fn, map_location=torch.device(device)) 29 | model = checkpoint["model_class_object"] 30 | 31 | model.load_state_dict(checkpoint["model_state_dict"]) 32 | 33 | meta_df = pd.read_pickle(meta_fn) 34 | 35 | model.to(device) 36 | model.eval() 37 | return model, meta_df 38 | 39 | 40 | def run_napari(model_fn, meta_fn, ldim=None, pdim=None, manifold="umap"): 41 | """Run the napari viewer for the model.""" 42 | 43 | viewer = setup_napari() 44 | 45 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 46 | model, meta_df = load_model(model_fn, meta_fn, device=device) 47 | 48 | if pdim is not None: 49 | pose_dims = pdim 50 | else: 51 | try: 52 | pose_dims = model.encoder.pose_fc.out_features 53 | except AttributeError: 54 | raise AttributeError( 55 | "Model does not have pose attributes, please specify manually." 56 | ) 57 | 58 | if ldim is not None: 59 | lat_dims = ldim 60 | else: 61 | try: 62 | lat_dims = model.encoder.mu.out_features 63 | except AttributeError: 64 | raise AttributeError( 65 | "Model does not have latent attributes, please specify manually." 66 | ) 67 | 68 | widget = GenerativeAffinityVAEWidget( 69 | viewer, 70 | model, 71 | device=device, 72 | meta_df=meta_df, 73 | pose_dims=pose_dims, 74 | latent_dims=lat_dims, 75 | manifold=manifold, 76 | ) 77 | viewer.window.add_dock_widget(widget, name="AffinityVAE") 78 | napari.run() 79 | 80 | 81 | if __name__ == "__main__": 82 | import argparse 83 | 84 | parser = argparse.ArgumentParser() 85 | 86 | parser.add_argument( 87 | "--model_file", help="Path to model state file.", required=True 88 | ) 89 | parser.add_argument( 90 | "--meta_file", help="Path to the meta file.", required=True 91 | ) 92 | parser.add_argument( 93 | "--manifold", 94 | help="Manifold to use for latent space. This can be either `umap` or `load`.", 95 | required=False, 96 | default="umap", 97 | ) 98 | parser.add_argument( 99 | "--pose_dims", 100 | help="Number of pose dimensions.
This will overwrite the internal model value.", 101 | default=None, type=int, 102 | required=False, 103 | ) 104 | parser.add_argument( 105 | "--latent_dims", 106 | help="Number of latent dimensions. This will overwrite the internal model value.", 107 | default=None, type=int, 108 | required=False, 109 | ) 110 | 111 | args = parser.parse_args() 112 | 113 | model_fn = args.model_file 114 | meta_fn = args.meta_file 115 | pdim = args.pose_dims 116 | ldim = args.latent_dims 117 | manifold = args.manifold 118 | 119 | if manifold != "umap" and manifold != "load": 120 | raise ValueError("Manifold must be either \"umap\" or \"load\".") 121 | 122 | print("Running napari viewer with model: {}".format(model_fn)) 123 | run_napari(model_fn, meta_fn, ldim, pdim, manifold) 124 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import unittest 5 | 6 | import lightning as lt 7 | import numpy as np 8 | 9 | from avae.data import load_data 10 | from tests import testdata_mrc 11 | 12 | 13 | class DataTest(unittest.TestCase): 14 | def setUp(self) -> None: 15 | """Set up data and output directories.""" 16 | self._orig_dir = os.getcwd() 17 | self.test_data = os.path.dirname(testdata_mrc.__file__) 18 | self.test_dir = tempfile.mkdtemp(prefix="avae_") 19 | self.fabric = lt.Fabric() 20 | self.fabric.launch() 21 | 22 | # Change to test directory 23 | os.chdir(self.test_dir) 24 | 25 | def tearDown(self): 26 | os.chdir(self._orig_dir) 27 | if os.path.exists(self.test_dir): 28 | shutil.rmtree(self.test_dir) 29 | 30 | def test_load_eval_data(self): 31 | """Test loading evaluation data.""" 32 | 33 | sh = 32 34 | 35 | shutil.copytree( 36 | os.path.join(self.test_data, "test"), 37 | os.path.join(self.test_dir, "eval"), 38 | ) 39 | 40 | out, data_dim = load_data( 41 | "./eval", 42 | datatype="mrc", 43 | lim=None, 44 | batch_s=32, 45 | eval=True, 46 | gaussian_blur=True, 47 | normalise=True, 48 | shift_min=True, 49 | rescale=sh, 50 | fabric=self.fabric, 51 | ) 52 | print(os.getcwd()) 53 | 54 | # test load_data 55 | assert len(out) == 1 56 | eval_data = out 57 | # assert isinstance(eval_data, lt.fast.DataLoader) 58 | 59 | # test ProteinDataset 60 | eval_batch = list(eval_data)[0] 61 | xs, ys, aff, meta = eval_batch 62 | assert len(xs) == len(ys) == len(aff) 63 | assert np.all( 64 | aff.numpy() == 0 65 | ) # this is expected only for eval without affinity 66 | 67 | assert xs[0].shape[-1] == sh 68 | 69 | def test_load_train_data(self): 70 | """Test loading training data.""" 71 | shutil.copytree(self.test_data, os.path.join(self.test_dir, "train")) 72 | 73 | sh = 32 74 | 75 | out = load_data( 76 | "./train", 77 | datatype="mrc", 78 | lim=None, 79 | splt=30, 80 | batch_s=16, 81 | no_val_drop=True, 82 | eval=False, 83 | affinity_path="./train/affinity_fsc_10.csv", 84 | gaussian_blur=True, 85 | normalise=True, 86 | shift_min=True, 87 | rescale=sh, 88 | fabric=self.fabric, 89 | ) 90 | 91 | # test load_data 92 | assert len(out) == 5 93 | train_data, val_data, test_data, lookup, data_dim = out 94 | assert len(train_data) >= len(val_data) 95 | # assert isinstance(train_data, lt.fabric.DataLoader) 96 | 97 | # test ProteinDataset 98 | train_batch = list(train_data)[0] 99 | xs, ys, aff, meta = train_batch 100 | assert len(xs) == len(ys) == len(aff) 101 | assert len(np.unique(aff.numpy())) == 4 102 | assert xs[0].shape[-1] == sh 103 | 104 | # test affinity matrix 105 |
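# (a hedged note: the lookup is presumably the class-by-class affinity matrix,
# so it should be square with unit self-affinity on the diagonal — that is
# exactly what the assertions below check)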
assert isinstance(lookup, np.ndarray) 106 | assert len(lookup.shape) == 2 107 | assert lookup.shape[0] == lookup.shape[1] 108 | assert lookup[0][0] == 1 109 | 110 | data_0 = list(out[0])[0][0][0][0][0].detach().numpy() 111 | data_1 = list(out[0])[0][0][0][0][0].detach().numpy() 112 | 113 | # test reproducibility 114 | out_2, _, _, _, _ = load_data( 115 | "./train", 116 | datatype="mrc", 117 | lim=None, 118 | splt=30, 119 | batch_s=16, 120 | no_val_drop=True, 121 | eval=False, 122 | affinity_path="./train/affinity_fsc_10.csv", 123 | gaussian_blur=True, 124 | normalise=True, 125 | shift_min=True, 126 | rescale=sh, 127 | fabric=self.fabric, 128 | ) 129 | data_0_1 = list(out_2)[0][0][0][0][0].detach().numpy() 130 | data_1_1 = list(out_2)[0][0][0][0][0].detach().numpy() 131 | 132 | assert data_0.all() == data_0_1.all() 133 | assert data_1.all() == data_1_1.all() 134 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.2", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "affinivae" 7 | version = "0.0.1" 8 | description = "Affinity-VAE" 9 | authors = [{name = "Alan R. Lowe", email = "alowe@turing.ac.uk"}] 10 | license = {text = "BSD 3-Clause"} 11 | classifiers = [ 12 | "Programming Language :: Python", 13 | "Programming Language :: Python :: 3 :: Only", 14 | "Programming Language :: Python :: 3.10", 15 | "Topic :: Scientific/Engineering", 16 | "Topic :: Scientific/Engineering :: Bio-Informatics", 17 | "Topic :: Scientific/Engineering :: Image Recognition", 18 | "Operating System :: POSIX", 19 | "Operating System :: Unix", 20 | "Operating System :: MacOS", 21 | ] 22 | requires-python = ">=3.10" 23 | 24 | dependencies = [ 25 | "altair", 26 | "click", 27 | "mrcfile", 28 | "matplotlib", 29 | "numpy", 30 | "scipy", 31 | "pillow", 32 | "pandas", 33 | "pyyaml", 34 | "requests", 35 | "scikit-image", 36 | "scikit-learn", 37 | "tensorboard", 38 | "umap-learn", 39 | "caked@git+https://github.com/alan-turing-institute/caked/" # to be changed when caked is in PyPI. 40 | ] 41 | 42 | [project.readme] 43 | file = "README.md" 44 | content-type = "text/markdown" 45 | 46 | [project.urls] 47 | Homepage = "https://github.com/alan-turing-institute/affinity-vae" 48 | 49 | [tool.hatch] 50 | version.path = "avae/__init__.py" 51 | envs.default.dependencies = [ 52 | "pytest", 53 | "pytest-cov", 54 | ] 55 | 56 | 57 | 58 | [project.optional-dependencies] 59 | all = ["torch", 60 | "torchvision", 61 | "pydantic>2", 62 | "lightning", 63 | ] 64 | 65 | test = [ 66 | "affinivae[all]", 67 | "pytest >=6", 68 | "pytest-cov >=3", 69 | "lightning", 70 | ] 71 | 72 | baskerville = [ 73 | "pydantic>2", 74 | "lightning", 75 | 76 | ] 77 | napari = [ 78 | "napari[all]", 79 | "torch", 80 | "torchvision", 81 | "pydantic<2", # the latest napari release is incompatible with pydantic>2; the napari demo doesn't use pydantic.
82 | ] 83 | 84 | [tool.pytest.ini_options] 85 | minversion = "6.0" 86 | addopts = ["-ra", "--showlocals"]#, "--strict-markers", "--strict-config"] 87 | xfail_strict = true 88 | filterwarnings = [ # NEED TO FIX THESE WARNINGS ISSUE X 89 | "error", 90 | "ignore:jsonschema.RefResolver is deprecated as of v4.18.0, in favor of the:DeprecationWarning", 91 | "ignore:The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator.", 92 | "ignore:Tensorflow not installed; ParametricUMAP will be unavailable", 93 | "ignore:pkg_resources is deprecated as an API", 94 | "ignore:Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`", 95 | "ignore:The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later.", 96 | "ignore:Maximum iterations (500) reached and the optimization hasn't converged yet", # FIX THESE WARNINGS ISSUE x 97 | "ignore:Could not initialize NNPACK!.", 98 | "ignore:Mean of empty slice.", 99 | "ignore:invalid value encountered in divide", 100 | "ignore:Degrees of freedom <= 0 for slice", 101 | "ignore:Scoring failed. The score on this train-test partition for these parameters will be set to nan.", 102 | "ignore:One or more of the test scores are non-finite:", 103 | ] 104 | 105 | log_cli_level = "INFO" 106 | testpaths = [ 107 | "tests", 108 | ] 109 | [tool.setuptools] 110 | include-package-data = true 111 | license-files = ["LICENSE.md"] 112 | 113 | [tool.setuptools.packages] 114 | find = {namespaces = false} 115 | 116 | [tool.tox] 117 | isolated_build = "true" 118 | 119 | [tool.tox.envlist] 120 | extend-ignore = [ 121 | "PLR", # Design related pylint codes 122 | "E501", # Line too long 123 | ] 124 | 125 | src = ["avae"] 126 | unfixable = [ 127 | "F841", # Would remove unused variables 128 | ] 129 | exclude = [] 130 | flake8-unused-arguments.ignore-variadic-names = true 131 | 132 | [tool.black] 133 | target-version = ['py38', 'py39', 'py310'] 134 | line-length = 79 135 | skip-string-normalization = true 136 | include = '\.pyi?$' 137 | exclude = ''' 138 | ( 139 | /( 140 | \.eggs 141 | | \.git 142 | | \.hg 143 | | \.mypy_cache 144 | | \.tox 145 | | \.venv 146 | | _build 147 | | build 148 | | dist 149 | | examples 150 | )/ 151 | ) 152 | ''' 153 | 154 | [tool.isort] 155 | profile = "black" 156 | line_length = 79 157 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import tempfile 4 | import unittest 5 | 6 | import avae.settings as settings 7 | import configs 8 | from avae.config import AffinityConfig, load_config_params, write_config_file 9 | from tests import testdata_mrc, testdata_npy 10 | 11 | 12 | class ConfigTest(unittest.TestCase): 13 | def setUp(self) -> None: 14 | """Setup data and output directories.""" 15 | self._orig_dir = os.getcwd() 16 | self.temp_dir = tempfile.TemporaryDirectory() 17 | self.config = os.path.join( 18 | os.path.dirname(configs.__file__), "avae-test-config.yml" 19 | ) 20 | self.data_local = { 21 | "datapath": os.path.dirname(testdata_mrc.__file__), 22 | "affinity": os.path.join( 23 | os.path.dirname(testdata_mrc.__file__), "affinity_fsc_10.csv" 24 | ), 25 | "classes": os.path.join( 26 | os.path.dirname(testdata_npy.__file__), "classes.csv" 27 | ), 28 | "split": 5, 29 | "epochs": 100, 30 | "vis_sim": True, 31 | } 32 | 33 | self.datapath_local_missing = { 34 | "affinity": os.path.join( 35 | os.path.dirname(testdata_mrc.__file__), 
"affinity_fsc_10.csv" 36 | ), 37 | "classes": os.path.join( 38 | os.path.dirname(testdata_npy.__file__), "classes.csv" 39 | ), 40 | "split": 5, 41 | "epochs": 100, 42 | } 43 | 44 | self.data_local_fail_wrong_label = { 45 | "datapath": os.path.dirname(testdata_mrc.__file__), 46 | "affinity": os.path.join( 47 | os.path.dirname(testdata_mrc.__file__), "affinity_fsc_10.csv" 48 | ), 49 | "classes": os.path.join( 50 | os.path.dirname(testdata_npy.__file__), "classes.csv" 51 | ), 52 | "collect_meta": True, 53 | "batch": "25", 54 | } 55 | self.data_local_fail_wrong_type = { 56 | "batch": "25", 57 | "split": "test", 58 | } 59 | 60 | self.default_model = AffinityConfig() 61 | 62 | def tearDown(self): 63 | os.chdir(self._orig_dir) 64 | self.temp_dir.cleanup() 65 | 66 | def test_validate_config(self): 67 | 68 | data = load_config_params( 69 | config_file=self.config, local_vars=self.data_local 70 | ) 71 | 72 | self.assertEqual( 73 | len(data.items()), len(self.default_model.model_dump().items()) 74 | ) 75 | self.assertEqual(data["batch"], 128) 76 | self.assertEqual(data["affinity"], self.data_local['affinity']) 77 | self.assertEqual(data["new_out"], self.default_model.new_out) 78 | self.assertEqual(data["split"], self.data_local['split']) 79 | self.assertEqual(data["epochs"], self.data_local['epochs']) 80 | self.assertEqual(data["vis_sim"], self.data_local['vis_sim']) 81 | 82 | data_config_only = load_config_params(config_file=self.config) 83 | self.assertEqual(data_config_only["epochs"], 1000) 84 | 85 | data_local_data_only = load_config_params(local_vars=self.data_local) 86 | self.assertEqual(data_local_data_only["epochs"], 100) 87 | self.assertEqual( 88 | data_local_data_only["vis_all"], self.default_model.vis_all 89 | ) 90 | 91 | def test_validate_config_fail(self): 92 | 93 | with self.assertRaises(ValueError): 94 | load_config_params(local_vars=self.data_local_fail_wrong_label) 95 | 96 | with self.assertRaises(TypeError): 97 | load_config_params(local_vars=self.data_local_fail_wrong_type) 98 | 99 | with self.assertRaises(ValueError): 100 | load_config_params(local_vars=self.datapath_local_missing) 101 | 102 | # wrong input for classifier 103 | self.data_local['classifier'] = 'LS' 104 | with self.assertRaises(TypeError): 105 | load_config_params(local_vars=self.data_local) 106 | 107 | def test_write_config_file(self): 108 | 109 | data = load_config_params( 110 | config_file=self.config, local_vars=self.data_local 111 | ) 112 | os.chdir(self.temp_dir.name) 113 | 114 | write_config_file(settings.date_time_run, data) 115 | 116 | files = glob.glob( 117 | self.temp_dir.name 118 | + "/configs/*" 119 | + settings.date_time_run 120 | + "*.yaml" 121 | ) 122 | 123 | data_from_output = load_config_params(config_file=files[0]) 124 | 125 | self.assertEqual(len(files), 1) 126 | 127 | data['datapath'] = data_from_output['datapath'] 128 | data['affinity'] = data_from_output['affinity'] 129 | data['classes'] = data_from_output['classes'] 130 | 131 | self.assertEqual(data_from_output, data) 132 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # AffinityVAE Tutorial 2 | 3 | ## Running AffinityVAE on the MNIST dataset 4 | 5 | ### 1. 
Install AffinityVAE 6 | To start, clone the AffinityVAE repository and switch to the `affinity-vae` directory. 7 | As described in the [README](../README.md), AffinityVAE can be installed using 8 | pip: 9 | 10 | ```bash 11 | python -m venv env 12 | source env/bin/activate 13 | python -m pip install --upgrade pip 14 | python -m pip install -e ."[all]" 15 | ``` 16 | 17 | ### 2. Download the MNIST dataset 18 | 19 | The MNIST dataset can be downloaded from: 20 | https://figshare.com/articles/dataset/mnist_pkl_gz/13303457 21 | 22 | In this tutorial, we assume that the dataset is downloaded to this `tutorials` 23 | directory. Make sure that the downloaded file is indeed `mnist.pkl.gz` and that 24 | it hasn't been decompressed by your browser (this can happen in Safari). 25 | 26 | ### 3. Create the MNIST dataset for AffinityVAE 27 | 28 | We need an MNIST dataset with augmentation and rotations, saved with a 29 | structure that AffinityVAE can read (a subdirectory for training data, and 30 | another for testing). This can be done by running the following command: 31 | 32 | ```bash 33 | python mnist_saver.py --mnist_file mnist.pkl.gz --output_path . 34 | ``` 35 | 36 | Here, the first argument is the path to the downloaded MNIST dataset, and the 37 | second argument is the path to the directory where the processed dataset will be 38 | saved. In this example, we are saving the dataset in the current directory under 39 | [mnist_data](mnist_data). There you can also find a visualization of a few 40 | random samples of the dataset in the `mnist_examples.png` file for validation. 41 | 42 | Affinity-VAE is a highly configurable model (to see the full list of 43 | configurables run `python /absolute/path/to/affinity-vae/run.py --help`). In this 44 | tutorial we will use a YAML file [mnist_config.yml](mnist_data/mnist_config.yml) 45 | to configure our run. 46 | 47 | Now we are ready to run AffinityVAE on the MNIST dataset. For this we need the 48 | data divided into training and test sets, a configuration file 49 | (`mnist_config.yml`), and the files `classes_mnist.csv` (the list of labels 50 | to use in training) and `affinity_mnist.csv` (the affinity matrix for the classes 51 | defined above), which are provided in this tutorial. All these files can be found 52 | in the [mnist_data](mnist_data) directory. 53 | 54 | The first thing you need to do is modify the `mnist_config.yml` file to point to 55 | the absolute paths of the data, classes, and affinity files. Currently, the 56 | paths in the config file are relative to this directory, but this can cause issues 57 | when running the code from a different directory. To modify the paths, open the 58 | `mnist_config.yml` file and change the following lines: 59 | 60 | ```yaml 61 | datapath: /absolute/path/to/mnist_data/images_train/ 62 | classes: /absolute/path/to/mnist_data/classes_mnist.csv 63 | affinity: /absolute/path/to/mnist_data/affinity_mnist.csv 64 | ``` 65 | 66 | In general, we recommend working with absolute paths to avoid issues. 67 | 68 | To train AffinityVAE on our rotated MNIST dataset, run the following command: 69 | 70 | ```bash 71 | python /absolute/path/to/affinity-vae/run.py --config_file /absolute/path/to/mnist_data/mnist_config.yml --new_out 72 | ``` 73 | 74 | This will train AffinityVAE on the MNIST dataset and save the results in a new 75 | timestamped directory created by the `--new_out` flag. In this case, the run is 76 | configured by the `mnist_config.yml` file. 
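For orientation, the kind of entries you will find in such a config file looks roughly like the sketch below. The values here are illustrative placeholders (apart from `epochs: 1000`, which matches the training length mentioned later), not the tuned settings shipped in [mnist_config.yml](mnist_data/mnist_config.yml); always defer to the provided file:

```yaml
datapath: /absolute/path/to/mnist_data/images_train/
classes: /absolute/path/to/mnist_data/classes_mnist.csv
affinity: /absolute/path/to/mnist_data/affinity_mnist.csv
epochs: 1000       # the shipped config trains for 1000 epochs
latent_dims: 8     # illustrative value
pose_dims: 1       # illustrative value
beta: 1            # weight of the KL divergence term (illustrative)
gamma: 1           # weight of the affinity term (illustrative)
```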
77 | 78 | You can also configure the run by passing arguments to the `run.py` script as 79 | shown in the main README file and in the following example: 80 | 81 | ```bash 82 | python path/to/affinity-vae/run.py --config_file /absolute/path/to/mnist_data/mnist_config.yml --beta 0.1 --gamma 0.01 --lr 0.001 --epochs 200 --new_out 83 | ``` 84 | 85 | Here, the command-line arguments override the values in the config file. 86 | 87 | The config file provided here has optimal parameters for the MNIST dataset, so 88 | we recommend using it as is. 89 | 90 | Once the training finishes, you can evaluate the model on unseen data (the test 91 | set) by stepping into the new directory and running the commands below. The outputs from the test run will be saved in the existing plots directory created during the training run, and will be tagged with `evl_` in the filename. 92 | 93 | ```bash 94 | cd path/to/new_out 95 | python path/to/affinity-vae/run.py --config_file /absolute/path/to/mnist_data/mnist_config.yml --datapath /absolute/path/to/mnist_data/images_test/ --eval 96 | ``` 97 | 98 | _Note_: During training we've left the class `9` out, so we can use it for 99 | evaluation and see where it fits in the affinity-organised latent space. 100 | 101 | You can also restart training from a checkpoint by running: 102 | 103 | ```bash 104 | cd path/to/new_out 105 | python path/to/affinity-vae/run.py --config_file /absolute/path/to/mnist_data/mnist_config.yml --restart --epochs 2000 --datapath /absolute/path/to/mnist_data/images_train/ 106 | ``` 107 | 108 | Here `--epochs` is set to 2000 to continue training for 1000 extra epochs (assuming 109 | the initial training ran for the 1000 epochs defined in the config file). 110 | 111 | ## Outputs of the training and evaluation runs 112 | 113 | The training run will create a directory with the following structure: 114 | 115 | ``` 116 | new_out 117 | ├── configs # copy of the config file used for the run, for reproducibility 118 | ├── logs # run logs 119 | ├── plots # plots and data of the training and evaluation metrics 120 | ├── latents # html files with the latent space of the training and test sets; these files can be very large, so we recommend only generating them at evaluation time (using the --dynamic flag) 121 | ├── states # checkpoints of the models and the training latent space, used for evaluation or to restart training 122 | 123 | 124 | ``` 125 | -------------------------------------------------------------------------------- /tests/test_vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import unittest 4 | 5 | import torch 6 | 7 | from avae import config, settings 8 | from tests import testdata_mrc 9 | from tests.test_train_eval_pipeline import helper_train_eval 10 | 11 | torch.random.manual_seed(10) 12 | random.seed(10) 13 | 14 | 15 | class VisPipelineTest(unittest.TestCase): 16 | """Test pipeline with isolated visualisations.""" 17 | 18 | def setUp(self) -> None: 19 | """Setup data and output directories.""" 20 | self.test_data = os.path.dirname(testdata_mrc.__file__) 21 | 22 | self.data_params = { # only specify if differs from default 23 | # data 24 | "datapath": self.test_data, 25 | "affinity": os.path.join(self.test_data, "affinity_fsc_10.csv"), 26 | "classes": os.path.join(self.test_data, "classes.csv"), 27 | "split": 10, 28 | # preprocess 29 | "rescale": 32, 30 | "gaussian_blur": True, 31 | "normalise": True, 32 | "shift_min": True, 33 | # model 34 | "epochs": 1, 35 | "batch": 
25, 36 | "model": "u", 37 | "channels": 3, 38 | "depth": 4, 39 | "latent_dims": 8, 40 | "pose_dims": 3, 41 | "learning": 0.03, 42 | "beta": 1, 43 | "gamma": 1, 44 | # vis 45 | "vis_all": False, 46 | "freq_all": 1, 47 | "vis_format": "png", 48 | } 49 | 50 | self.data = config.load_config_params(local_vars=self.data_params) 51 | config.setup_visualisation_config(self.data) 52 | 53 | def test_accuracy(self): 54 | settings.VIS_ACC = True 55 | 56 | _, n_plots, _, _ = helper_train_eval( 57 | self.data, eval=False, nolat=True, nostate=True 58 | ) 59 | 60 | self.assertEqual(n_plots, 11) 61 | # confusion val and train + val and train norm, f1, f1 val and f1 train 62 | 63 | def test_loss(self): 64 | self.data["epochs"] = 2 65 | settings.VIS_LOS = True 66 | 67 | __, n_plots, _, _ = helper_train_eval( 68 | self.data, eval=False, nolat=True, nostate=True 69 | ) 70 | 71 | self.assertEqual(n_plots, 3) # loss, total loss, train loss 72 | 73 | def test_recon(self): 74 | settings.VIS_REC = True 75 | 76 | _, n_plots, _, _ = helper_train_eval( 77 | self.data, eval=False, nolat=True, nostate=True 78 | ) 79 | 80 | self.assertEqual(n_plots, 7) 81 | # recon in and out for train and val + 3D + reconstructions dir 82 | 83 | def test_similarity(self): 84 | settings.VIS_SIM = True 85 | 86 | _, n_plots, _, _ = helper_train_eval( 87 | self.data, eval=False, nolat=True, nostate=True 88 | ) 89 | 90 | self.assertEqual(n_plots, 2) # recon and val 91 | 92 | def test_embedding(self): 93 | settings.VIS_EMB = True 94 | 95 | _, n_plots, _, _ = helper_train_eval( 96 | self.data, eval=False, nolat=True, nostate=True 97 | ) 98 | 99 | self.assertEqual(n_plots, 4) # tsne and umap for lat+pose 100 | 101 | def test_dyn_embedding(self): 102 | settings.VIS_EMB = True 103 | settings.VIS_DYN = True 104 | 105 | _, n_plots, n_latent, _ = helper_train_eval( 106 | self.data, eval=False, nostate=True 107 | ) 108 | 109 | self.assertEqual(n_plots, 4) 110 | self.assertEqual(n_latent, 2) # tsne and umap 111 | 112 | def test_latent_disentanglement(self): 113 | settings.VIS_DIS = True 114 | lat_pose = [(0, 0), (3, 0), (0, 3), (3, 3)] 115 | 116 | for l, p in lat_pose: 117 | self.data["latent_dims"] = l 118 | self.data["pose_dims"] = p 119 | 120 | _, n_plots, _, _ = helper_train_eval( 121 | self.data, eval=False, nolat=True, nostate=True 122 | ) 123 | 124 | self.assertEqual(n_plots, 1) # in future 3 + per class 125 | 126 | def test_pose_disentanglement(self): 127 | settings.VIS_POS = True 128 | self.data["pose_dims"] = 0 129 | 130 | _, n_plots, _, _ = helper_train_eval( 131 | self.data, eval=False, nolat=True, nostate=True 132 | ) 133 | 134 | self.assertEqual(n_plots, 0) 135 | 136 | class_out = [(None, 1), ("1b23", 2), ("1b23,1dkg", 3)] 137 | 138 | for c, out in class_out: 139 | self.data["pose_dims"] = 3 140 | settings.VIS_POSE_CLASS = c 141 | 142 | _, n_plots, _, _ = helper_train_eval( 143 | self.data, eval=False, nolat=True, nostate=True 144 | ) 145 | 146 | self.assertEqual(n_plots, out) # in future 3 + per class 147 | 148 | def test_interpolation(self): 149 | settings.VIS_INT = True 150 | lat_pose = [(0, 0), (3, 0), (0, 3), (3, 3)] 151 | 152 | for l, p in lat_pose: 153 | self.data["latent_dims"] = l 154 | self.data["pose_dims"] = p 155 | 156 | _, n_plots, _, _ = helper_train_eval( 157 | self.data, eval=False, nolat=True, nostate=True 158 | ) 159 | 160 | self.assertEqual(n_plots, 1) # in future 3 161 | 162 | def test_affinity(self): 163 | settings.VIS_AFF = True 164 | 165 | _, n_plots, _, _ = helper_train_eval( 166 | self.data, eval=False, nolat=True, 
nostate=True 167 | ) 168 | 169 | self.assertEqual(n_plots, 1) 170 | 171 | def test_distribution(self): 172 | settings.VIS_HIS = True 173 | 174 | _, n_plots, _, _ = helper_train_eval( 175 | self.data, eval=False, nolat=True, nostate=True 176 | ) 177 | 178 | self.assertEqual(n_plots, 2) # train and val 179 | 180 | def test_cyc_variables(self): 181 | settings.VIS_CYC = True 182 | 183 | _, n_plots, _, _ = helper_train_eval( 184 | self.data, eval=False, nolat=True, nostate=True 185 | ) 186 | 187 | self.assertEqual(n_plots, 2) # beta and gamma 188 | -------------------------------------------------------------------------------- /tutorials/mnist_saver.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import random 4 | import sys 5 | 6 | import _pickle as cPickle 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import pandas as pd 10 | from PIL import Image 11 | from scipy.ndimage import rotate 12 | 13 | 14 | def load_mnist(path): 15 | 16 | f = gzip.open(path, 'rb') 17 | if sys.version_info < (3,): 18 | data = cPickle.load(f) 19 | else: 20 | data = cPickle.load(f, encoding='bytes') 21 | f.close() 22 | return data 23 | 24 | 25 | def augmentation(mol, a, aug_th_min, aug_th_max): 26 | mol = np.array(mol) 27 | deg_per_rot = 5 28 | angle = np.random.randint(aug_th_min, aug_th_max, size=(len(mol.shape),)) 29 | for ax in range(angle.size): 30 | theta = angle[ax] * deg_per_rot * a 31 | axes = (ax, (ax + 1) % angle.size) 32 | mol = rotate(mol, theta, axes=axes, order=0, reshape=False) 33 | return mol 34 | 35 | 36 | def make_dirs(path_list): 37 | for path in path_list: 38 | if not os.path.exists(path): 39 | os.makedirs(path) 40 | 41 | 42 | def make_containing_dirs(path_list): 43 | for path in path_list: 44 | dir_name = os.path.dirname(path) 45 | if not os.path.exists(dir_name): 46 | os.makedirs(dir_name) 47 | 48 | 49 | def padding(array, xx, yy, zz=None): 50 | """ 51 | :param array: numpy array 52 | :param xx: desired height 53 | :param yy: desirex width 54 | :return: padded array 55 | """ 56 | 57 | h = array.shape[0] 58 | w = array.shape[1] 59 | if zz is not None: 60 | z = array.shape[2] 61 | 62 | a = (xx - h) // 2 63 | aa = xx - a - h 64 | 65 | b = (yy - w) // 2 66 | bb = yy - b - w 67 | 68 | if zz is not None: 69 | c = (zz - z) // 2 70 | cc = zz - c - z 71 | 72 | return np.pad( 73 | array, pad_width=((a, aa), (b, bb), (c, cc)), mode='constant' 74 | ) 75 | 76 | else: 77 | return np.pad(array, pad_width=((a, aa), (b, bb)), mode='constant') 78 | 79 | 80 | class SaverMNIST: 81 | def __init__( 82 | self, 83 | image_train_path, 84 | image_test_path, 85 | csv_train_path, 86 | csv_test_path, 87 | image_shape=(32, 32), 88 | rotation_angle=180, 89 | data=None, 90 | ): 91 | 92 | self._image_format = '.npy' 93 | 94 | self.store_image_paths = [image_train_path, image_test_path] 95 | self.store_csv_paths = [csv_train_path, csv_test_path] 96 | self.image_shape = image_shape 97 | self.rotation_angle = rotation_angle 98 | 99 | make_dirs(self.store_image_paths) 100 | make_containing_dirs(self.store_csv_paths) 101 | 102 | # Load MNIST dataset 103 | self.data = data 104 | 105 | def run(self): 106 | for collection, store_image_path, store_csv_path in zip( 107 | self.data, self.store_image_paths, self.store_csv_paths 108 | ): 109 | labels_list = [] 110 | paths_list = [] 111 | 112 | for index, (image, label) in enumerate( 113 | zip(collection[0], collection[1]) 114 | ): 115 | im = Image.fromarray(image) 116 | width, height = im.size 117 | 
image_name = str(label) + '_' + str(index) + self._image_format 118 | image = np.array(image) 119 | 120 | angle = np.random.randint( 121 | -self.rotation_angle, 122 | +self.rotation_angle, 123 | size=(len(image.shape),), 124 | ) 125 | for ax in range(angle.size): 126 | theta = angle[ax] 127 | axes = (ax, (ax + 1) % angle.size) 128 | image = rotate( 129 | image, theta, axes=axes, order=0, reshape=False 130 | ) 131 | 132 | image = padding( 133 | image, self.image_shape[0], self.image_shape[1] 134 | ) 135 | 136 | # Build save path 137 | save_path = os.path.join(store_image_path, image_name) 138 | # im.save(save_path, dpi=(300, 300)) 139 | np.save(save_path, image) 140 | 141 | labels_list.append(label) 142 | paths_list.append(save_path) 143 | 144 | df = pd.DataFrame( 145 | {'image_paths': paths_list, 'labels': labels_list} 146 | ) 147 | 148 | df.to_csv(store_csv_path) 149 | 150 | 151 | if __name__ == '__main__': 152 | 153 | import argparse 154 | 155 | parser = argparse.ArgumentParser() 156 | 157 | # command line arguments for the input file and the output directory 158 | parser.add_argument("--mnist_file", help="Path to mnist.pkl.gz file") 159 | parser.add_argument("--output_path", help="Output path to save data") 160 | 161 | args = parser.parse_args() 162 | 163 | mnist_path = args.mnist_file 164 | output_path = args.output_path 165 | 166 | data_path = os.path.join(output_path, 'mnist_data') 167 | 168 | data = load_mnist(mnist_path) 169 | 170 | mnist_saver = SaverMNIST( 171 | data=data, 172 | image_train_path=data_path + '/images_train', 173 | image_test_path=data_path + '/images_test', 174 | csv_train_path=data_path + '/train.csv', 175 | csv_test_path=data_path + '/test.csv', 176 | ) 177 | 178 | # Write files into disk 179 | mnist_saver.run() 180 | 181 | examples_train = pd.read_csv(data_path + "/train.csv") 182 | examples_test = pd.read_csv(data_path + "/test.csv") 183 | 184 | array_train_plots = [] 185 | array_test_plots = [] 186 | 187 | for i in range(3): 188 | array_train_plots.append( 189 | np.load( 190 | examples_train['image_paths'][ 191 | random.randint(0, examples_train.shape[0] - 1) 192 | ] 193 | ) 194 | ) 195 | array_test_plots.append( 196 | np.load( 197 | examples_test['image_paths'][ 198 | random.randint(0, examples_test.shape[0] - 1) 199 | ] 200 | ) 201 | ) 202 | 203 | # create figure 204 | fig, axis = plt.subplots(2, 3) 205 | axis[0, 0].imshow(array_train_plots[0]) 206 | axis[0, 1].imshow(array_train_plots[1]) 207 | axis[0, 2].imshow(array_train_plots[2]) 208 | axis[1, 0].imshow(array_test_plots[0]) 209 | axis[1, 1].imshow(array_test_plots[1]) 210 | axis[1, 2].imshow(array_test_plots[2]) 211 | plt.savefig(data_path + "/mnist_examples.png") 212 | -------------------------------------------------------------------------------- /avae/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from avae.base import AbstractAffinityVAE 9 | from avae.decoders.decoders import Decoder, DecoderA, DecoderB 10 | from avae.decoders.differentiable import GaussianSplatDecoder 11 | from avae.encoders.encoders import Encoder, EncoderA, EncoderB 12 | 13 | from .base import SpatialDims 14 | 15 | 16 | def set_layer_dim( 17 | ndim: SpatialDims | int, 18 | ) -> tuple[nn.Module, nn.Module, nn.Module]: 19 | if ndim == SpatialDims.TWO: 20 | return nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d 21 | elif ndim == SpatialDims.THREE: 22 | return nn.Conv3d, 
nn.ConvTranspose3d, nn.BatchNorm3d 23 | else: 24 | logging.error("Data must be 2D or 3D.") 25 | exit(1) 26 | 27 | 28 | def dims_after_pooling(start: int, n_pools: int) -> int: 29 | """Calculate the size of a layer after n pooling ops. 30 | 31 | Parameters 32 | ---------- 33 | start: int 34 | The size of the layer before pooling. 35 | n_pools: int 36 | The number of pooling operations. 37 | 38 | Returns 39 | ------- 40 | int 41 | The size of the layer after pooling. 42 | 43 | 44 | """ 45 | return start // (2**n_pools) 46 | 47 | 48 | def set_device(gpu: bool) -> torch.device: 49 | """Set the torch device to use for training and inference. 50 | 51 | Parameters 52 | ---------- 53 | gpu: bool 54 | If True, the model will be trained on GPU. 55 | 56 | Returns 57 | ------- 58 | device: torch.device 59 | 60 | """ 61 | device = torch.device( 62 | "cuda:0" if gpu and torch.cuda.is_available() else "cpu" 63 | ) 64 | if gpu and device == "cpu": 65 | logging.warning( 66 | "\n\nWARNING: no GPU available, running on CPU instead.\n" 67 | ) 68 | return device 69 | 70 | 71 | def build_model( 72 | model_type: str, 73 | input_shape: tuple, 74 | channels: int, 75 | depth: int, 76 | lat_dims: int, 77 | pose_dims: int, 78 | bnorm_encoder: bool, 79 | bnorm_decoder: bool, 80 | n_splats: int, 81 | gsd_conv_layers: int, 82 | device: torch.device, 83 | filters: list | None = None, 84 | ): 85 | """Create the AffinityVAE model. 86 | 87 | Parameters 88 | ---------- 89 | model_type : str 90 | The type of model to create. Must be one of : a, b, u or gsd. 91 | input_shape : tuple 92 | The size of the input. 93 | channels : int 94 | The number of channels in the model. 95 | depth : int 96 | The depth of the model. 97 | lat_dims : int 98 | The number of latent dimensions. 99 | pose_dims : int 100 | The number of pose dimensions. 101 | bnorm_encoder : bool 102 | Whether to use batch normalisation in the encoder. 103 | bnorm_decoder : bool 104 | Whether to use batch normalisation in the decoder. 105 | n_splats : int 106 | The number of splats in the Gaussian Splat Decoder. 107 | gsd_conv_layers : int 108 | The number of convolutional layers in the Gaussian Splat Decoder. 109 | device : torch.device 110 | The device to use for training and inference. 111 | filters : list or None 112 | The filters to use in the model. 
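
    Returns
    -------
    vae : AffinityVAE
        The assembled model wrapping the selected encoder and decoder.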
113 | 114 | """ 115 | 116 | if filters is not None: 117 | filters = np.array( 118 | np.array(filters).replace(" ", "").split(","), dtype=np.int64 119 | ) 120 | 121 | if model_type == "a": 122 | encoder = EncoderA( 123 | input_shape, 124 | channels, 125 | depth, 126 | lat_dims, 127 | pose_dims, 128 | bnorm=bnorm_encoder, 129 | ) 130 | decoder = DecoderA( 131 | input_shape, 132 | channels, 133 | depth, 134 | lat_dims, 135 | pose_dims, 136 | bnorm=bnorm_decoder, 137 | ) 138 | elif model_type == "b": 139 | encoder = EncoderB(input_shape, channels, depth, lat_dims, pose_dims) 140 | decoder = DecoderB(input_shape, channels, depth, lat_dims, pose_dims) 141 | elif model_type == "u": 142 | encoder = Encoder( 143 | input_shape=input_shape, 144 | capacity=channels, 145 | filters=filters, 146 | depth=depth, 147 | latent_dims=lat_dims, 148 | pose_dims=pose_dims, 149 | bnorm=bnorm_encoder, 150 | ) 151 | decoder = Decoder( 152 | input_shape=input_shape, 153 | capacity=channels, 154 | filters=filters, 155 | depth=depth, 156 | latent_dims=lat_dims, 157 | pose_dims=pose_dims, 158 | bnorm=bnorm_decoder, 159 | ) 160 | elif model_type == "gsd": 161 | encoder = EncoderA( 162 | input_shape, 163 | channels, 164 | depth, 165 | lat_dims, 166 | pose_dims, 167 | bnorm=bnorm_encoder, 168 | ) 169 | decoder = GaussianSplatDecoder( 170 | input_shape, 171 | n_splats=n_splats, 172 | latent_dims=lat_dims, 173 | output_channels=gsd_conv_layers, 174 | device=device, 175 | pose_dims=pose_dims, 176 | ) 177 | else: 178 | raise ValueError( 179 | "Invalid model type", 180 | model_type, 181 | "must be one of : a, b, u or gsd", 182 | ) 183 | 184 | vae = AffinityVAE(encoder, decoder) 185 | 186 | return vae 187 | 188 | 189 | # 190 | # Concrete implementation of the AffinityVAE 191 | class AffinityVAE(AbstractAffinityVAE): 192 | def __init__(self, encoder, decoder): 193 | super(AffinityVAE, self).__init__(encoder, decoder) 194 | self.encoder = encoder 195 | self.decoder = decoder 196 | 197 | if self.encoder.pose != self.decoder.pose: 198 | logging.error("Encoder and decoder pose must be the same.") 199 | raise RuntimeError("Encoder and decoder pose must be the same.") 200 | 201 | self.pose = self.encoder.pose 202 | 203 | def forward(self, x): 204 | # encode 205 | if self.pose: 206 | latent_mu, latent_logvar, latent_pose = self.encoder(x) 207 | else: 208 | latent_mu, latent_logvar = self.encoder(x) 209 | latent_pose = None 210 | # reparametrise 211 | latent = self.reparametrise(latent_mu, latent_logvar) 212 | # decode 213 | x_recon = self.decoder(latent, latent_pose) # pose set to None if pd=0 214 | 215 | return x_recon, latent_mu, latent_logvar, latent, latent_pose 216 | 217 | def reparametrise(self, mu, log_var): 218 | if self.training: 219 | std = torch.exp(0.5 * log_var) 220 | eps = torch.randn_like(std) 221 | return eps * std + mu 222 | else: 223 | return mu 224 | -------------------------------------------------------------------------------- /scripts/run_create_subtomo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import warnings 4 | from datetime import datetime 5 | 6 | import click 7 | import yaml 8 | 9 | from tools.create_subtomo import create_subtomo 10 | 11 | if not os.path.exists("../logs"): 12 | os.mkdir("../logs") 13 | dt_name = datetime.now().strftime("%H_%M_%d_%m_%Y") 14 | logging.basicConfig( 15 | level=logging.INFO, 16 | format="%(asctime)s [%(levelname)s] %(message)s", 17 | handlers=[ 18 | logging.FileHandler("logs/avae_run_log_" + dt_name + 
".log"), 19 | logging.StreamHandler(), 20 | ], 21 | ) 22 | 23 | 24 | class ConvertStrToList(click.Option): 25 | def type_cast_value(self, ctx, value) -> list: 26 | try: 27 | value = str(value) 28 | assert value.count("[") == 1 and value.count("]") == 1 29 | list_as_str = value.replace('"', "'").split("[")[1].split("]")[0] 30 | list_of_items = [ 31 | item.strip().strip("'") for item in list_as_str.split() 32 | ] 33 | return list_of_items 34 | except Exception: 35 | raise click.BadParameter(value) 36 | 37 | 38 | @click.command(name="Subtomogram creator") 39 | @click.option("--config_file", type=click.Path(exists=True)) 40 | @click.option( 41 | "--input_path", 42 | "-ip", 43 | type=str, 44 | default=None, 45 | help="Path to the folder containing the full tomogram/image.", 46 | ) 47 | @click.option( 48 | "--output_path", 49 | "-op", 50 | type=str, 51 | default=None, 52 | help="Path to the folder for output subtomograms", 53 | ) 54 | @click.option( 55 | "--annot_path", 56 | "-op", 57 | type=str, 58 | default=None, 59 | help="path to the folder containing the name of the particles and their x,y,z coordinates", 60 | ) 61 | @click.option( 62 | "--datatype", 63 | "-dtype", 64 | type=str, 65 | default=None, 66 | help="Type of the data: mrc, npy", 67 | ) 68 | @click.option( 69 | "--vox_size", 70 | "-vs", 71 | cls=ConvertStrToList, 72 | default=[], 73 | help="size of each subtomogram voxel given as a list where vox_size: [x,y,x]", 74 | ) 75 | @click.option( 76 | "--bandpass", 77 | "-bp", 78 | type=bool, 79 | default=None, 80 | is_flag=True, 81 | help="Apply band pass", 82 | ) 83 | @click.option( 84 | "--low_freq", 85 | "-lf", 86 | type=float, 87 | default=None, 88 | help="Lower frequency threshold for the band pass filter", 89 | ) 90 | @click.option( 91 | "--high_freq", 92 | "-hf", 93 | type=float, 94 | default=None, 95 | help="higher frequency threshold for the band pass filter", 96 | ) 97 | @click.option( 98 | "--gaussian_blur", 99 | "-gb", 100 | type=bool, 101 | default=None, 102 | is_flag=True, 103 | help="Applying gaussian bluring to the image data which should help removing noise. 
The minimum and maximum for this is hardcoded.", 104 | ) 105 | @click.option( 106 | "--normalise", 107 | "-nrm", 108 | type=bool, 109 | default=None, 110 | is_flag=True, 111 | help="Normalise data", 112 | ) 113 | @click.option( 114 | "--add_noise", 115 | "-n", 116 | type=bool, 117 | default=None, 118 | is_flag=True, 119 | help="Add noise to images, this can be used for benchmarking", 120 | ) 121 | @click.option( 122 | "--noise_int", 123 | "-ni", 124 | type=int, 125 | default=None, 126 | help="noise intensity", 127 | ) 128 | @click.option( 129 | "--padding", 130 | "-ni", 131 | type=list, 132 | default=None, 133 | help="size of padding boxes", 134 | ) 135 | @click.option( 136 | "--augment", 137 | "-aug", 138 | type=int, 139 | default=None, 140 | help="perform the given number of aumentation on each voxel before saving", 141 | ) 142 | @click.option( 143 | "--aug_th_min", 144 | "-atmin", 145 | type=int, 146 | default=None, 147 | help="The minimum value of the augmentation range in degrees", 148 | ) 149 | @click.option( 150 | "--aug_th_max", 151 | "-atmax", 152 | type=int, 153 | default=None, 154 | help="The maximum value of the augmentation range in degrees", 155 | ) 156 | def run( 157 | config_file, 158 | input_path, 159 | output_path, 160 | datatype, 161 | annot_path, 162 | vox_size, 163 | bandpass, 164 | low_freq=0, 165 | high_freq=15, 166 | gaussian_blur=False, 167 | normalise=False, 168 | add_noise=False, 169 | noise_int=0, 170 | padding=None, 171 | augment=False, 172 | aug_num=5, 173 | aug_th_min=-45, 174 | aug_th_max=45, 175 | ): 176 | 177 | warnings.simplefilter("ignore", FutureWarning) 178 | # read config file and command line arguments and assign to local variables that are used in the rest of the code 179 | local_vars = locals().copy() 180 | print(local_vars) 181 | 182 | if config_file is not None: 183 | with open(config_file, "r") as f: 184 | logging.info("Reading submission configuration file" + config_file) 185 | data = yaml.load(f, Loader=yaml.FullLoader) 186 | # returns JSON object as 187 | 188 | for key, val in local_vars.items(): 189 | if ( 190 | val is not None 191 | and isinstance(val, (int, float, bool, str)) 192 | or data.get(key) is None 193 | ): 194 | logging.warning( 195 | "Command line argument " 196 | + key 197 | + " is overwriting config file value to: " 198 | + str(val) 199 | ) 200 | data[key] = val 201 | else: 202 | logging.info( 203 | "Setting " 204 | + key 205 | + " to config file value: " 206 | + str(data[key]) 207 | ) 208 | else: 209 | # if no config file is provided, use command line arguments 210 | data = local_vars 211 | 212 | logging.info( 213 | "Saving final submission config file to: " 214 | + "avae_final_config" 215 | + dt_name 216 | + ".yaml" 217 | ) 218 | file = open("avae_final_config" + dt_name + ".yaml", "w") 219 | yaml.dump(data, file) 220 | file.close() 221 | logging.info("YAML File saved!") 222 | 223 | create_subtomo( 224 | input_path=data["input_path"], 225 | output_path=data["output_path"], 226 | datatype=data["datatype"], 227 | annot_path=data["annot_path"], 228 | vox_size=data["vox_size"], 229 | bandpass=data["bandpass"], 230 | low_freq=data["low_freq"], 231 | high_freq=data["high_freq"], 232 | gaussian_blur=data["gaussian_blur"], 233 | normalise=data["normalise"], 234 | add_noise=data["add_noise"], 235 | noise_int=data["noise_int"], 236 | padding=data["padding"], 237 | augment=data["augment"], 238 | aug_th_min=data["aug_th_min"], 239 | aug_th_max=data["aug_th_max"], 240 | ) 241 | 242 | 243 | if __name__ == "__main__": 244 | run() 245 | 
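# A minimal invocation sketch (all paths below are hypothetical and the voxel
# size is illustrative; note that ConvertStrToList splits the bracketed list
# on whitespace, so quote it as shown):
#
#   python scripts/run_create_subtomo.py \
#       --input_path /path/to/tomograms \
#       --annot_path /path/to/annotations \
#       --output_path /path/to/subtomos \
#       --datatype mrc \
#       --vox_size "[32 32 32]" \
#       --normalise --gaussian_blur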
-------------------------------------------------------------------------------- /avae/decoders/spatial.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class SpatialDims(enum.IntEnum): 8 | TWO = 2 9 | THREE = 3 10 | 11 | 12 | class CartesianAxes(enum.Enum): 13 | """Set of Cartesian axes as 3D vectors.""" 14 | 15 | X = (1, 0, 0) 16 | Y = (0, 1, 0) 17 | Z = (0, 0, 1) 18 | 19 | def as_tensor(self) -> torch.Tensor: 20 | return torch.tensor(self.value, dtype=torch.float32) 21 | 22 | 23 | def axis_angle_to_quaternion( 24 | axis_angles: torch.Tensor, *, normalize: bool = True 25 | ) -> torch.Tensor: 26 | """Convert an axis angle to a rotation quaternion representation. 27 | 28 | Parameters 29 | ---------- 30 | axis_angles : tensor 31 | An (N, 4) tensor specifying the axis angle representations for a batch. 32 | The order is (theta, x_hat, y_hat, z_hat). 33 | normalize : bool, default = True 34 | Whether to normalize the axes to a unit vector (recommended) 35 | 36 | Returns 37 | ------- 38 | quaternions : tensor 39 | A (N, 4) tensor specifying the quaternion representations of the axis 40 | angles. The order is (q0, q1, q2, q3) where q0 is the real part, and 41 | (q1, q2, a3) are the imaginary parts. 42 | 43 | Notes 44 | ----- 45 | 46 | """ 47 | theta = axis_angles[:, 0].unsqueeze(-1) 48 | axis = axis_angles[:, 1:] 49 | 50 | if axis.shape[-1] not in (SpatialDims.THREE,): 51 | raise ValueError("Axis must be specified in three dimensions.") 52 | 53 | axis = torch.nn.functional.normalize(axis, dim=1) if normalize else axis 54 | 55 | real = torch.cos(theta / 2) 56 | imag = axis * torch.sin(theta / 2) 57 | 58 | quaternion = torch.concat([real, imag], axis=-1) 59 | 60 | return quaternion 61 | 62 | 63 | def quaternion_to_rotation_matrix(quaternions: torch.Tensor) -> torch.Tensor: 64 | """Convert quaternion forms to rotation matrices. 65 | 66 | Parameters 67 | ---------- 68 | quaternions : tensor 69 | A (N, 4) tensor specifying the quaternion representations of the axis 70 | angles. The order is (q0, q1, q2, q3). 71 | 72 | Returns 73 | ------- 74 | rotation_matrices : tensor 75 | An (N, 3, 3) tensor specifying rotation matrices for each quaternion. 76 | """ 77 | # extract the real and imaginary parts of the quaternions 78 | q0, q1, q2, q3 = torch.unbind(quaternions, dim=-1) 79 | 80 | # calculate the components of the rotation matrix 81 | R00 = q0**2 + q1**2 - q2**2 - q3**2 82 | R01 = 2 * (q1 * q2 - q0 * q3) 83 | R02 = 2 * (q1 * q3 + q0 * q2) 84 | R10 = 2 * (q1 * q2 + q0 * q3) 85 | R11 = q0**2 - q1**2 + q2**2 - q3**2 86 | R12 = 2 * (q2 * q3 - q0 * q1) 87 | R20 = 2 * (q1 * q3 - q0 * q2) 88 | R21 = 2 * (q2 * q3 + q0 * q1) 89 | R22 = q0**2 - q1**2 - q2**2 + q3**2 90 | 91 | # stack the components into the rotation matrix 92 | rotation_matrices = torch.stack( 93 | [ 94 | torch.stack([R00, R01, R02], axis=-1), 95 | torch.stack([R10, R11, R12], axis=-1), 96 | torch.stack([R20, R21, R22], axis=-1), 97 | ], 98 | axis=-1, 99 | ) 100 | 101 | return rotation_matrices 102 | 103 | 104 | class RotatedCoordinates(torch.nn.Module): 105 | """Creates a homogeneous grid of rotated coordinates. 106 | 107 | Parameters 108 | ---------- 109 | shape : tuple 110 | A tuple describing the output shape of the image data. Can be 2- or 3- 111 | dimensional. 112 | default_axis : CartesianAxes 113 | A default cartesian axis to use for rotation if the pose is provided by 114 | a rotation only. 
Default is Z, equivalent to a typical image rotation 115 | about the central axis. 116 | 117 | Notes 118 | ----- 119 | Uses a quaternion representation: 120 | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation 121 | """ 122 | 123 | def __init__( 124 | self, 125 | shape: Tuple[int], 126 | *, 127 | default_axis: CartesianAxes = CartesianAxes.Z, 128 | device: torch.device = torch.device("cpu"), 129 | ): 130 | super().__init__() 131 | 132 | if len(shape) not in (SpatialDims.TWO, SpatialDims.THREE): 133 | raise ValueError("Only 2D or 3D rotations are currently supported") 134 | 135 | grids = torch.meshgrid( 136 | *[torch.linspace(-1, 1, sz) for sz in shape], 137 | indexing="xy", 138 | ) 139 | 140 | # add all zeros for z- if we have a 2d grid 141 | if len(shape) == SpatialDims.TWO: 142 | grids += ( 143 | torch.zeros_like( 144 | grids[0], 145 | ), 146 | ) 147 | 148 | self.coords = ( 149 | torch.stack([torch.ravel(grid) for grid in grids], axis=0) 150 | .unsqueeze(0) 151 | .to(device) 152 | ) 153 | 154 | self.grids = torch.stack(grids, axis=0).unsqueeze(0).to(device) 155 | self._shape = shape 156 | self._ndim = len(shape) 157 | # self._default_axis = torch.tensor(default_axis.value, dtype=torch.float32) 158 | self._default_axis = default_axis.as_tensor() 159 | 160 | def forward(self, pose: torch.Tensor) -> torch.Tensor: 161 | """Forward pass of spatial coordinate rotation. 162 | 163 | Parameters 164 | ---------- 165 | pose : tensor 166 | An (N, D) tensor describing the angle and axis to rotate each grid 167 | by. D can either be 1 (i.e. just angle, assuming z-axis) or 4 (i.e. 168 | the angle and an axis to rotate around). 169 | 170 | Returns 171 | ------- 172 | rotated_grids : tensor 173 | An (N, D, H, W) tensor, equivalent to a rotated meshgrid operation, 174 | where D is the axis dimension. 175 | """ 176 | batch_size = pose.shape[0] 177 | 178 | # in the case where the encoded pose only has one dimension, we need to 179 | # use the pose as a rotation about the z-axis 180 | if pose.shape[-1] == 1: 181 | pose = torch.concat( 182 | [ 183 | pose, 184 | torch.tile(self._default_axis, (batch_size, 1)), 185 | ], 186 | axis=-1, 187 | ) 188 | 189 | # convert axis angles to quaternions 190 | assert pose.shape[-1] == 4, pose.shape 191 | quaternions = axis_angle_to_quaternion(pose, normalize=True) 192 | 193 | # convert the quaternions to rotation matrices 194 | # NOTE(arl): we should probably use rotation matrices OR quaternions 195 | # converting between them is not necessary 196 | rotation_matrices = quaternion_to_rotation_matrix(quaternions) 197 | 198 | # rotate the 3D points using the rotation matrices 199 | rotated_coords = torch.matmul( 200 | rotation_matrices, 201 | self.coords, 202 | ) 203 | # use only the required spatial dimensions 204 | rotated_coords = rotated_coords[:, : self._ndim, :] 205 | 206 | # now create the equivalent of the rotated xy grid 207 | return rotated_coords.reshape((batch_size, self._ndim, *self._shape)) 208 | -------------------------------------------------------------------------------- /avae/utils_learning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | 4 | import lightning 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | from avae.loss import AVAELoss 10 | from avae.vis import format 11 | 12 | 13 | def set_device(gpu: bool) -> torch.device: 14 | """Set the torch device to use for training and inference. 
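Falls back to the CPU, with a warning, when a GPU is requested but CUDA is not available.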
15 | 16 | Parameters 17 | ---------- 18 | gpu: bool 19 | If True, the model will be trained on GPU. 20 | 21 | Returns 22 | ------- 23 | device: torch.device 24 | 25 | """ 26 | device = torch.device( 27 | "cuda" if gpu and torch.cuda.is_available() else "cpu" 28 | ) 29 | if gpu and device == "cpu": 30 | logging.warning( 31 | "\n\nWARNING: no GPU available, running on CPU instead.\n" 32 | ) 33 | return device 34 | 35 | 36 | def dims_after_pooling(start: int, n_pools: int) -> int: 37 | """Calculate the size of a layer after n pooling ops. 38 | 39 | Parameters 40 | ---------- 41 | start: int 42 | The size of the layer before pooling. 43 | n_pools: int 44 | The number of pooling operations. 45 | 46 | Returns 47 | ------- 48 | int 49 | The size of the layer after pooling. 50 | 51 | 52 | """ 53 | return start // (2**n_pools) 54 | 55 | 56 | def pass_batch( 57 | fabric: lightning.Fabric, 58 | vae: torch.nn.Module, 59 | batch: list, 60 | b: int, 61 | batches: int, 62 | e: int = 1, 63 | epochs: int = 1, 64 | history: list = [], 65 | loss: AVAELoss | None = None, 66 | optimizer: typing.Any = None, 67 | beta: list[float] | None = None, 68 | ) -> tuple[ 69 | torch.Tensor, 70 | torch.Tensor, 71 | torch.Tensor, 72 | torch.Tensor, 73 | torch.Tensor, 74 | torch.Tensor, 75 | list, 76 | ]: 77 | """Passes a batch through the affinity VAE model epoch and computes the loss. 78 | 79 | Parameters 80 | ---------- 81 | device: torch.device 82 | Device to use for training. 83 | vae: torch.nn.Module 84 | Affinity VAE model class. 85 | batch: list 86 | List of batches with data and labels. 87 | b: int 88 | Batch number. 89 | batches: int 90 | Total number of batches. 91 | e: int 92 | Epoch number. 93 | epochs: int 94 | Total number of epochs. 95 | history: list 96 | List of training losses. 97 | loss: avae.loss.AVAELoss 98 | Loss function class. 99 | optimizer: torch.optim 100 | Optimizer. 101 | beta: float 102 | Beta parameter for affinity-VAE. 103 | 104 | Returns 105 | ------- 106 | x: torch.Tensor 107 | Input data. 108 | x_hat: torch.Tensor 109 | Reconstructed data. 110 | lat_mu: torch.Tensor 111 | Latent mean. 112 | lat_logvar: torch.Tensor 113 | Latent log variance. 114 | lat: torch.Tensor 115 | Latent representation. 116 | lat_pose: torch.Tensor 117 | Latent pose. 118 | history: list 119 | List of training losses. 120 | 121 | 122 | """ 123 | if bool(history == []) ^ bool(loss is None): 124 | raise RuntimeError( 125 | "When validating, both 'loss' and 'history' parameters must be " 126 | "present in 'pass_batch' function." 127 | ) 128 | if bool(e is None) ^ bool(epochs is None): 129 | raise RuntimeError( 130 | "Function 'pass_batch' expects both 'e' and 'epoch' parameters." 131 | ) 132 | if e is None and epochs is None: 133 | e = 1 134 | epochs = 1 135 | 136 | # to device 137 | x = batch[0] 138 | x = x.to(fabric.device) 139 | aff = batch[2] 140 | aff = aff.to(fabric.device) 141 | 142 | # forward 143 | x = x.to(torch.float32) 144 | x_hat, lat_mu, lat_logvar, lat, lat_pose = vae(x) 145 | if loss is not None: 146 | history_loss = loss(x, x_hat, lat_mu, lat_logvar, e, batch_aff=aff) 147 | 148 | if beta is None: 149 | raise RuntimeError( 150 | "Please pass beta value to pass_batch function." 
151 | ) 152 | 153 | # record loss 154 | for i in range(len(history[-1])): 155 | history[-1][i] += history_loss[i].item() 156 | logging.debug( 157 | "Epoch: [%d/%d] | Batch: [%d/%d] | Loss: %f | Recon: %f | " 158 | "KLdiv: %f | Affin: %f | Beta: %f" 159 | % (e + 1, epochs, b + 1, batches, *history_loss, beta[e]) 160 | ) 161 | 162 | # backwards 163 | if optimizer is not None: 164 | fabric.backward(history_loss[0]) 165 | optimizer.step() 166 | optimizer.zero_grad() 167 | 168 | return x, x_hat, lat_mu, lat_logvar, lat, lat_pose, history 169 | 170 | 171 | def add_meta( 172 | data_dim: int, 173 | meta_df: pd.DataFrame, 174 | batch_meta: dict, 175 | x_hat: torch.Tensor, 176 | latent_mu: torch.Tensor, 177 | lat_pose: torch.Tensor, 178 | latent_logvar: torch.Tensor, 179 | mode: str = "trn", 180 | ) -> pd.DataFrame: 181 | """ 182 | Created meta data about data and training. 183 | 184 | Parameters 185 | ---------- 186 | data_dim: int 187 | Dimensions of the data. 188 | meta_df: pd.DataFrame 189 | Dataframe containing meta data, to which new data is added. 190 | batch_meta: dict 191 | Meta data about the batch. 192 | x_hat: torch.Tensor 193 | Reconstructed data. 194 | latent_mu: torch.Tensor 195 | Latent mean. 196 | lat_pose: torch.Tensor 197 | Latent pose. 198 | lat_logvar: torch.Tensor 199 | Latent logvar. 200 | mode: str 201 | Data category on training (either 'trn', 'val' or 'test'). 202 | 203 | Returns 204 | ------- 205 | meta_df: pd.DataFrame 206 | Dataframe containing meta data. 207 | 208 | """ 209 | batch_meta = { 210 | k: v.to(device='cpu', non_blocking=True) if hasattr(v, 'to') else v 211 | for k, v in batch_meta.items() 212 | } 213 | 214 | meta = pd.DataFrame(batch_meta) 215 | 216 | meta["mode"] = mode 217 | meta["image"] += format(x_hat, data_dim) 218 | for d in range(latent_mu.shape[-1]): 219 | meta[f"lat{d}"] = np.array(latent_mu[:, d].cpu().detach().numpy()) 220 | for d in range(latent_logvar.shape[-1]): 221 | lat_std = np.exp(0.5 * latent_logvar[:, d].cpu().detach().numpy()) 222 | meta[f"std-{d}"] = np.array(lat_std) 223 | if lat_pose is not None: 224 | for d in range(lat_pose.shape[-1]): 225 | meta[f"pos{d}"] = np.array(lat_pose[:, d].cpu().detach().numpy()) 226 | meta_df = pd.concat( 227 | [meta_df, meta], ignore_index=False 228 | ) # ignore index doesn't overwrite 229 | return meta_df 230 | 231 | 232 | def configure_optimiser( 233 | opt_method: str, model: torch.nn.Module, learning_rate: float 234 | ): 235 | """ 236 | Configure the optimiser for the training. 237 | 238 | Parameters 239 | ---------- 240 | opt_method : str 241 | Optimisation method. 242 | model : torch.nn.Module 243 | Model to be trained. 244 | learning_rate : float 245 | Learning rate for the optimiser. 246 | 247 | Returns 248 | ------- 249 | optimizer : torch.optim 250 | Optimiser for the training. 
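
    Examples
    --------
    A minimal usage sketch (the linear model below is a hypothetical
    stand-in, purely for illustration):

    >>> model = torch.nn.Linear(8, 2)
    >>> optimizer = configure_optimiser("adam", model, 1e-3)
    >>> isinstance(optimizer, torch.optim.Adam)
    True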
251 | """ 252 | if opt_method == "adam": 253 | optimizer = torch.optim.Adam( 254 | params=model.parameters(), lr=learning_rate # , weight_decay=1e-5 255 | ) 256 | elif opt_method == "sgd": 257 | optimizer = torch.optim.SGD( 258 | params=model.parameters(), lr=learning_rate # , weight_decay=1e-5 259 | ) 260 | elif opt_method == "asgd": 261 | optimizer = torch.optim.ASGD( 262 | params=model.parameters(), lr=learning_rate # , weight_decay=1e-5 263 | ) 264 | else: 265 | raise ValueError( 266 | "Invalid optimisation method", 267 | opt_method, 268 | "must be adam, sgd or asgd; if you have other methods in mind, they can easily be added here", 269 | ) 270 | 271 | return optimizer 272 | -------------------------------------------------------------------------------- /avae/cyc_annealing.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def configure_annealing( 7 | epochs: int, 8 | value_max: float, 9 | value_min: float, 10 | cyc_method: str, 11 | n_cycle: int, 12 | ratio: float, 13 | cycle_load: str | None = None, 14 | ): 15 | """ 16 | This function is used to configure the annealing of the beta and gamma values. 17 | It creates an array of values that oscillate between a maximum and minimum value 18 | for a defined number of cycles. This is used for gamma and beta in the loss term. 19 | The function also allows for the loading of a pre-existing array of beta or gamma values. 20 | 21 | Parameters 22 | ---------- 23 | epochs: int 24 | Number of epochs in training 25 | value_max: float 26 | Maximum value of the beta or gamma 27 | value_min: float 28 | Minimum value of the beta or gamma 29 | cyc_method : str 30 | The method for constructing the cyclical mixing weight: 31 | - flat : constant weight (regular beta-VAE) 32 | - cycle_linear 33 | - cycle_sigmoid 34 | - cycle_cosine 35 | - ramp 36 | - delta (a 'mixed' schedule is also available, see cyc_annealing) 37 | n_cycle: int 38 | Number of cycles of the variable to oscillate between min and max 39 | during the epochs 40 | ratio: float 41 | Ratio of increase during ramping 42 | cycle_load: str | None 43 | Path to a file containing the beta or gamma array 44 | 45 | Returns 46 | ------- 47 | cycle_arr: np.ndarray 48 | Array of beta or gamma values 49 | 50 | """ 51 | if value_max == 0 and cyc_method != "flat" and cycle_load is None: 52 | raise RuntimeError( 53 | "The maximum value for beta is set to 0, it is not possible to " 54 | "oscillate between a maximum and minimum. Please choose the flat method for " 55 | "cyc_method_beta" 56 | ) 57 | 58 | if cycle_load is None: 59 | # If a path for loading the beta array is not provided, 60 | # create it given the input 61 | cycle_arr = ( 62 | cyc_annealing( 63 | epochs, 64 | cyc_method, 65 | n_cycle=n_cycle, 66 | ratio=ratio, 67 | ).var 68 | * (value_max - value_min) 69 | + value_min 70 | ) 71 | else: 72 | cycle_arr = np.load(cycle_load) 73 | if len(cycle_arr) != epochs: 74 | raise RuntimeError( 75 | f"The length of the beta array loaded from file is {len(cycle_arr)} but the number of epochs specified in the input is {epochs}.\n" 76 | "These two values should be the same." 77 | ) 78 | 79 | return cycle_arr 80 | 81 | 82 | class cyc_annealing: 83 | 84 | """ 85 | This class builds an array whose value changes between a minimum 86 | and maximum for a defined number of cycles. This is used for gamma and beta in the loss term. 
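For example, with n_epoch=100, cyc_method='cycle_linear', n_cycle=4 and ratio=0.5, each 25-epoch cycle ramps linearly from 0 to 1 over roughly its first half and then holds at 1 for the remainder of the cycle.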
87 | 88 | Parameters 89 | ---------- 90 | n_epochs : number of epochs in training 91 | cyc_method : The method for constructing the cyclical mixing weight 92 | - Flat : regular beta-vae 93 | - Linear 94 | - Sigmoid 95 | - Cosine 96 | - ramp 97 | - delta 98 | start : The starting point (min) 99 | stop : The starting point (min) 100 | n_cycle : Number of cycles of the variable to oscillate between min and max 101 | during the epochs 102 | ratio : ratio of increase during ramping 103 | """ 104 | 105 | def __init__( 106 | self, 107 | n_epoch: int, 108 | cyc_method: str = "flat", 109 | n_cycle: int = 4, 110 | ratio: float = 0.5, 111 | ): 112 | 113 | self.n_epoch = n_epoch 114 | 115 | # The start and stop control where in each the cycle, the sigmoid function starts and stops 116 | # This is for the moment hard coded in, as the feature does not contribute to the outcome of the model for our purpose 117 | self.start = 0 118 | self.stop = 1 119 | self.n_cycle = n_cycle 120 | self.ratio = ratio 121 | 122 | if cyc_method == "flat": 123 | self.var = self._frange_flat() 124 | 125 | elif cyc_method == "cycle_linear": 126 | self.var = self._frange_cycle_linear() 127 | 128 | elif cyc_method == "cycle_sigmoid": 129 | self.var = self._frange_cycle_sigmoid() 130 | 131 | elif cyc_method == "cycle_cosine": 132 | self.var = self._frange_cycle_cosine() 133 | 134 | elif cyc_method == "ramp": 135 | self.var = self._frange_ramp() 136 | 137 | elif cyc_method == "delta": 138 | self.var = self._frange_delta() 139 | 140 | elif cyc_method == "mixed": 141 | self.var = self._frange_mixed() 142 | 143 | else: 144 | raise RuntimeError( 145 | "Select a valid method for cyclical method for your variable. Available options are : " 146 | "flat, cycle_linear, cycle_sigmoid, cycle_cosine, ramp, delta, mixed" 147 | ) 148 | 149 | def _frange_flat(self): 150 | L = np.ones(self.n_epoch) 151 | return L 152 | 153 | def _frange_cycle_linear(self): 154 | L = np.ones(self.n_epoch) 155 | period = self.n_epoch / self.n_cycle 156 | step = (self.stop - self.start) / ( 157 | period * self.ratio 158 | ) # linear schedule 159 | 160 | for c in range(self.n_cycle): 161 | 162 | v, i = self.start, 0 163 | while v <= self.stop and (int(i + c * period) < self.n_epoch): 164 | L[int(i + c * period)] = v 165 | v += step 166 | i += 1 167 | return L 168 | 169 | def _frange_cycle_sigmoid(self): 170 | L = np.ones(self.n_epoch) 171 | period = self.n_epoch / self.n_cycle 172 | step = (self.stop - self.start) / ( 173 | period * self.ratio 174 | ) # step is in [0,1] 175 | 176 | # transform into [-6, 6] for plots: v*12.-6. 
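# (each step maps v in [0, 1] through sigmoid(12*v - 6), so a cycle sweeps
# smoothly from ~0.002 up to ~0.998; entries beyond the ramp keep the value
# 1 from the np.ones initialisation above)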
177 | 178 | for c in range(self.n_cycle): 179 | 180 | v, i = self.start, 0 181 | while v <= self.stop: 182 | L[int(i + c * period)] = 1.0 / ( 183 | 1.0 + np.exp(-(v * 12.0 - 6.0)) 184 | ) 185 | v += step 186 | i += 1 187 | return L 188 | 189 | # function = 1 − cos(a), where a scans from 0 to pi/2 190 | 191 | def _frange_cycle_cosine(self): 192 | L = np.ones(self.n_epoch) 193 | period = self.n_epoch / self.n_cycle 194 | step = (self.stop - self.start) / ( 195 | period * self.ratio 196 | ) # step is in [0,1] 197 | 198 | # transform into [0, pi] for plots: 199 | 200 | for c in range(self.n_cycle): 201 | 202 | v, i = self.start, 0 203 | while v <= self.stop: 204 | L[int(i + c * period)] = 0.5 - 0.5 * math.cos(v * math.pi) 205 | v += step 206 | i += 1 207 | return L 208 | 209 | def _frange_ramp(self): 210 | L = np.ones(self.n_epoch) 211 | v, i = self.start, 0 212 | period = self.n_epoch / self.n_cycle 213 | 214 | step = (self.stop - self.start) / (period * self.ratio) 215 | while v <= self.stop: 216 | L[i] = v 217 | v += step 218 | i += 1 219 | return L 220 | 221 | def _frange_delta(self): 222 | L = np.zeros(self.n_epoch) 223 | period = self.n_epoch / (self.n_cycle + 1) 224 | 225 | for n in range(self.n_cycle + 1): 226 | if n % 2 == 1: 227 | L[int(period) * n : int(period) * (n + 1)] = self.stop 228 | else: 229 | L[int(period) * n : int(period) * (n + 1)] = self.start 230 | return L 231 | 232 | def _frange_mixed(self): 233 | 234 | L = np.ones(self.n_epoch) 235 | on = 300 236 | off = 600 237 | L[0:on] = 0 238 | 239 | period = (off - on) / self.n_cycle 240 | step = (self.stop - self.start) / ( 241 | period * self.ratio 242 | ) # step is in [0,1] 243 | 244 | # transform into [-6, 6] for plots: v*12.-6. 245 | 246 | for c in range(self.n_cycle): 247 | v, i = self.start, on 248 | while v <= self.stop: 249 | L[int(i + c * period)] = 1.0 / ( 250 | 1.0 + np.exp(-(v * 12.0 - 6.0)) 251 | ) 252 | v += step 253 | i += 1 254 | return L 255 | -------------------------------------------------------------------------------- /tests/test_train_eval_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import tempfile 5 | import unittest 6 | 7 | import torch 8 | 9 | from avae import config, settings 10 | from run import run_pipeline 11 | from tests import testdata_mrc, testdata_npy 12 | 13 | # fixing random seeds so we dont get fail on mrc tests 14 | torch.random.manual_seed(10) 15 | random.seed(10) 16 | 17 | 18 | class TrainEvalTest(unittest.TestCase): 19 | def setUp(self) -> None: 20 | """Test instantiation of the pipeline.""" 21 | self.testdata_mrc = os.path.dirname(testdata_mrc.__file__) 22 | self.testdata_npy = os.path.dirname(testdata_npy.__file__) 23 | 24 | self.data_params = { 25 | "datapath": self.testdata_mrc, 26 | "datatype": "mrc", 27 | "split": 10, 28 | "batch": 25, 29 | "no_val_drop": True, 30 | "affinity": os.path.join(self.testdata_mrc, "affinity_fsc_10.csv"), 31 | "classes": os.path.join(self.testdata_mrc, "classes.csv"), 32 | "dynamic": True, 33 | "epochs": 5, 34 | "channels": 3, 35 | "depth": 4, 36 | "latent_dims": 8, 37 | "pose_dims": 3, 38 | "learning": 0.03, 39 | "beta_min": 0, 40 | "beta": 1, 41 | "beta_cycle": 1, 42 | "cyc_method_beta": "flat", 43 | "gamma_min": 0, 44 | "gamma": 1, 45 | "cyc_method_gamma": "flat", 46 | "loss_fn": "MSE", 47 | "gaussian_blur": True, 48 | "normalise": True, 49 | "shift_min": True, 50 | "rescale": 32, 51 | "tensorboard": True, 52 | "classifier": "NN", 53 | 
"opt_method": "adam", 54 | "freq_all": 5, 55 | "vis_all": True, 56 | "vis_format": "png", 57 | } 58 | 59 | self.data = config.load_config_params(local_vars=self.data_params) 60 | config.setup_visualisation_config(self.data) 61 | 62 | def test_model_a_mrc(self): 63 | self.data["model"] = "a" 64 | settings.VIS_POSE_CLASS = "1b23,1dkg" 65 | 66 | ( 67 | n_dir_train, 68 | n_plots_train, 69 | n_latent_train, 70 | n_states_train, 71 | n_plots_eval, 72 | n_latent_eval, 73 | n_states_eval, 74 | ) = helper_train_eval(self.data) 75 | 76 | self.assertEqual(n_dir_train, 4) 77 | self.assertEqual(n_plots_train, 37) 78 | self.assertEqual(n_latent_train, 2) 79 | self.assertEqual(n_states_train, 2) 80 | 81 | self.assertEqual(n_plots_eval, 60) 82 | self.assertEqual(n_latent_eval, 4) 83 | self.assertEqual(n_states_eval, 3) 84 | 85 | def test_model_b_mrc(self): 86 | self.data["model"] = "b" 87 | settings.VIS_POSE_CLASS = "1b23,1dkg" 88 | 89 | ( 90 | n_dir_train, 91 | n_plots_train, 92 | n_latent_train, 93 | n_states_train, 94 | n_plots_eval, 95 | n_latent_eval, 96 | n_states_eval, 97 | ) = helper_train_eval(self.data) 98 | 99 | self.assertEqual(n_dir_train, 4) 100 | self.assertEqual(n_plots_train, 37) 101 | self.assertEqual(n_latent_train, 2) 102 | self.assertEqual(n_states_train, 2) 103 | self.assertEqual(n_plots_eval, 60) 104 | self.assertEqual(n_latent_eval, 4) 105 | self.assertEqual(n_states_eval, 3) 106 | 107 | def test_model_a_npy(self): 108 | self.data["model"] = "a" 109 | self.data["datatype"] = "npy" 110 | self.data["datapath"] = self.testdata_npy 111 | settings.VIS_POSE_CLASS = "2,i" 112 | 113 | self.data["affinity"] = os.path.join( 114 | self.testdata_npy, "affinity_an.csv" 115 | ) 116 | self.data["classes"] = os.path.join(self.testdata_npy, "classes.csv") 117 | ( 118 | n_dir_train, 119 | n_plots_train, 120 | n_latent_train, 121 | n_states_train, 122 | n_plots_eval, 123 | n_latent_eval, 124 | n_states_eval, 125 | ) = helper_train_eval(self.data) 126 | 127 | self.assertEqual(n_dir_train, 4) 128 | self.assertEqual(n_plots_train, 35) 129 | self.assertEqual(n_latent_train, 2) 130 | self.assertEqual(n_states_train, 2) 131 | self.assertEqual(n_plots_eval, 57) 132 | self.assertEqual(n_latent_eval, 4) 133 | self.assertEqual(n_states_eval, 3) 134 | 135 | def test_model_b_npy(self): 136 | self.data["model"] = "b" 137 | self.data["datatype"] = "npy" 138 | self.data["datapath"] = self.testdata_npy 139 | settings.VIS_POSE_CLASS = "2,i" 140 | 141 | self.data["affinity"] = os.path.join( 142 | self.testdata_npy, "affinity_an.csv" 143 | ) 144 | self.data["classes"] = os.path.join(self.testdata_npy, "classes.csv") 145 | ( 146 | n_dir_train, 147 | n_plots_train, 148 | n_latent_train, 149 | n_states_train, 150 | n_plots_eval, 151 | n_latent_eval, 152 | n_states_eval, 153 | ) = helper_train_eval(self.data) 154 | 155 | self.assertEqual(n_dir_train, 4) 156 | self.assertEqual(n_plots_train, 35) 157 | self.assertEqual(n_latent_train, 2) 158 | self.assertEqual(n_states_train, 2) 159 | self.assertEqual(n_plots_eval, 57) 160 | self.assertEqual(n_latent_eval, 4) 161 | self.assertEqual(n_states_eval, 3) 162 | 163 | def test_model_nopose(self): 164 | self.data["model"] = "u" 165 | self.data["pose_dims"] = 0 166 | 167 | ( 168 | n_dir_train, 169 | n_plots_train, 170 | n_latent_train, 171 | n_states_train, 172 | n_plots_eval, 173 | n_latent_eval, 174 | n_states_eval, 175 | ) = helper_train_eval(self.data) 176 | 177 | self.assertEqual(n_dir_train, 4) 178 | self.assertEqual(n_plots_train, 32) 179 | self.assertEqual(n_latent_train, 2) 
180 | self.assertEqual(n_states_train, 2) 181 | 182 | self.assertEqual(n_plots_eval, 52) 183 | self.assertEqual(n_latent_eval, 4) 184 | self.assertEqual(n_states_eval, 3) 185 | 186 | def test_model_nogamma(self): 187 | self.data["model"] = "u" 188 | self.data["gamma"] = 0 189 | 190 | ( 191 | n_dir_train, 192 | n_plots_train, 193 | n_latent_train, 194 | n_states_train, 195 | n_plots_eval, 196 | n_latent_eval, 197 | n_states_eval, 198 | ) = helper_train_eval(self.data) 199 | 200 | self.assertEqual(n_dir_train, 4) 201 | self.assertEqual(n_plots_train, 35) 202 | self.assertEqual(n_latent_train, 2) 203 | self.assertEqual(n_states_train, 2) 204 | self.assertEqual(n_plots_eval, 56) 205 | self.assertEqual(n_latent_eval, 4) 206 | self.assertEqual(n_states_eval, 3) 207 | 208 | 209 | def helper_train_eval( 210 | data, eval=True, noplot=False, nolat=False, nostate=False 211 | ): 212 | temp_dir = tempfile.TemporaryDirectory(prefix='avae-') 213 | os.chdir(temp_dir.name) 214 | 215 | if eval: 216 | eval = [not eval, eval] 217 | else: 218 | eval = [eval] 219 | ret = [] 220 | 221 | # run training 222 | for e in eval: 223 | data["eval"] = e 224 | if data["eval"]: 225 | data["datapath"] = os.path.join(data["datapath"], "test") 226 | 227 | run_pipeline(data) 228 | 229 | n_plots, n_latent, n_states = (0, 0, 0) 230 | n_dir = len(next(os.walk(temp_dir.name))[1]) 231 | if os.path.exists(os.path.join(temp_dir.name, "plots")): 232 | n_plots = ( 233 | len(os.listdir(os.path.join(temp_dir.name, "plots"))) 234 | if not noplot 235 | else None 236 | ) 237 | if os.path.exists(os.path.join(temp_dir.name, "latents")): 238 | n_latent = ( 239 | len(os.listdir(os.path.join(temp_dir.name, "latents"))) 240 | if not nolat 241 | else None 242 | ) 243 | if os.path.exists(os.path.join(temp_dir.name, "states")): 244 | n_states = ( 245 | len(os.listdir(os.path.join(temp_dir.name, "states"))) 246 | if not nostate 247 | else None 248 | ) 249 | 250 | ret.extend([n_plots, n_latent, n_states]) 251 | ret.insert(0, n_dir) 252 | 253 | shutil.rmtree(temp_dir.name) 254 | 255 | return tuple(ret) 256 | 257 | 258 | if __name__ == "__main__": 259 | unittest.main() 260 | -------------------------------------------------------------------------------- /avae/loss.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | 6 | class AffinityLoss: 7 | """Affinity loss based on pre-calculated shape similarity. 8 | 9 | Parameters 10 | ---------- 11 | lookup : np.ndarray (M, M) 12 | A square symmetric matrix where each column and row is the index of an 13 | object from the training set, consisting of M different objects. The 14 | value at (i, j) is a scalar value encoding the shape similarity between 15 | objects i and j, pre-calculated using some shape (or other) metric. The 16 | identity of the matrix should be 1 since these objects are the same 17 | shape. The affinity similarity should be normalized to the range 18 | (-1, 1). 19 | 20 | Notes 21 | ----- 22 | The final loss is calculated using L1-norm. This could be changed, e.g. 23 | L2-norm. Not sure what the best one is yet. 24 | """ 25 | 26 | def __init__(self, lookup: torch.Tensor, device: torch.device): 27 | self.device = device 28 | self.lookup = torch.tensor(lookup).to(device) 29 | self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-8) 30 | self.l1loss = torch.nn.L1Loss() 31 | 32 | def __call__( 33 | self, y_true: torch.Tensor, y_pred: torch.Tensor 34 | ) -> torch.Tensor: 35 | """Return the affinity loss. 
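For a mini-batch of N samples, all N*(N-1)/2 unordered index pairs are formed; the cosine similarity of each pair of latent vectors is then compared, via an L1 loss, against the corresponding ground-truth affinity from the lookup table.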
36 | 37 | Parameters 38 | ---------- 39 | y_true : torch.Tensor (N, ) 40 | A vector of length N giving, for each object in the mini-batch, the 41 | index representing its identity. These indices should 42 | correspond to the rows and columns of the `lookup` table. 43 | y_pred : torch.Tensor (N, latent_dims) 44 | An array of latent encodings of the N objects. 45 | 46 | Returns 47 | ------- 48 | loss : torch.Tensor 49 | The affinity loss. 50 | """ 51 | # first calculate the affinity, for the real classes 52 | c = ( 53 | torch.combinations(y_true, r=2, with_replacement=False) 54 | .to(self.device) 55 | .long() 56 | ) 57 | affinity = self.lookup[c[:, 0], c[:, 1]].to(self.device) 58 | 59 | # now calculate the latent similarity 60 | z_id = torch.tensor(list(range(y_pred.shape[0]))) 61 | c = torch.combinations(z_id, r=2, with_replacement=False) 62 | latent_similarity = self.cos(y_pred[c[:, 0], :], y_pred[c[:, 1], :]) 63 | loss = self.l1loss(latent_similarity, affinity) 64 | return loss 65 | 66 | 67 | class AVAELoss: 68 | """AffinityVAE loss consisting of reconstruction loss, beta-parametrised 69 | latent regularisation loss and 70 | gamma-parametrised affinity regularisation loss. Reconstruction loss 71 | should be Mean Squared Error for real-valued 72 | data and Binary Cross-Entropy for binary data. Latent 73 | regularisation loss is KL Divergence. Affinity 74 | regularisation is defined in the AffinityLoss class. 75 | 76 | Parameters 77 | ---------- 78 | device : torch.device 79 | Device used to calculate the loss. 80 | beta: list 81 | Beta parameter list defining the weight on the latent regularisation term, indexed by epoch. 82 | gamma : list 83 | Gamma parameter list defining the weight on the affinity regularisation term, 84 | indexed by epoch. Only used if lookup_aff is present. 85 | lookup_aff : np.ndarray [M, M] 86 | A square symmetric matrix where each column and row is the index of an 87 | object class from the training set, consisting of M different classes. 88 | recon_fn : 'MSE' or 'BCE' 89 | Function used for reconstruction loss. BCE uses Binary 90 | Cross-Entropy for binary data and MSE uses Mean 91 | Squared Error for real-valued data. 92 | klred : 'mean' or 'sum' Reduction applied over the latent dimensions when computing the KL divergence term, default = 'mean'. 93 | """ 94 | 95 | def __init__( 96 | self, 97 | device: torch.device, 98 | beta: list[float], 99 | gamma: list[float], 100 | lookup_aff: torch.Tensor | None = None, 101 | recon_fn: str = "MSE", 102 | klred: str = "mean", 103 | ): 104 | self.device = device 105 | self.recon_fn = recon_fn 106 | self.klred = klred 107 | self.beta = beta 108 | 109 | self.affinity_loss = None 110 | self.gamma = gamma 111 | 112 | if lookup_aff is not None and max(gamma) != 0: 113 | self.affinity_loss = AffinityLoss(lookup_aff, device) 114 | 115 | elif lookup_aff is None and max(gamma) != 0: 116 | raise RuntimeError( 117 | "Affinity matrix is needed to compute Affinity loss" 118 | ". Although you've set gamma, you have not provided the --af/" 119 | "--affinity parameter." 120 | ) 121 | elif lookup_aff is not None and max(gamma) == 0: 122 | logging.warning( 123 | "\n\nWARNING: You provided an affinity matrix but gamma is 0. Unless " 124 | "you provide a non-zero gamma, affinity will be ignored and you're " 125 | "running a vanilla beta-VAE.\n" 126 | ) 127 | self.affinity_loss = None 128 | 129 | def __call__( 130 | self, 131 | x: torch.Tensor, 132 | recon_x: torch.Tensor, 133 | mu: torch.Tensor, 134 | logvar: torch.Tensor, 135 | epoch: int, 136 | batch_aff: torch.Tensor | None = None, 137 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 138 | """Return the aVAE loss.
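        A hypothetical usage sketch (shapes invented for illustration; with
        gamma fixed at zero no affinity lookup is required):

        >>> loss_fn = AVAELoss(
        ...     torch.device("cpu"), beta=[1.0] * 10, gamma=[0.0] * 10
        ... )
        >>> x = torch.rand(4, 1, 32, 32)
        >>> recon_x = torch.rand(4, 1, 32, 32)
        >>> mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
        >>> total, recon, kld, affin = loss_fn(x, recon_x, mu, logvar, epoch=0)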
139 | 140 | Parameters 141 | ---------- 142 | x : torch.Tensor (N, CH, Z, Y, X) 143 | Mini-batch of inputs, where N stands for the number of samples in 144 | the mini-batch, CH stands for number of 145 | channels and X, Y, Z define input dimensions. 146 | recon_x : torch.Tensor (N, CH, Z, Y, X) 147 | Mini-batch of outputs, where N stands for the number of samples in 148 | the mini-batch, CH stands for number of 149 | channels and X, Y, Z define input dimensions. 150 | mu : torch.Tensor (N, latent_dims) 151 | Mini-batch of encoder outputs representing latent means, where N 152 | stands for the number of samples in the 153 | mini-batch and 'latent_dims' defines the number of latent 154 | dimensions. 155 | logvar : torch.Tensor (N, latent_dims) 156 | Mini-batch of encoder outputs representing latent log of the 157 | variance, where N stands for the number of 158 | samples in the mini-batch and 'latent_dims' defines the number of 159 | latent dimensions. 160 | batch_aff : torch.Tensor (N, ) 161 | Optional, must be present if AVAELoss was 162 | initialised with 'lookup_aff' parameter containing affinity 163 | lookup matrix. A vector of length N giving, for each object in the 164 | mini-batch, the index representing the identity of the object's 165 | class. These indices should correspond to the rows and columns 166 | of the `lookup_aff` table. 167 | 168 | Returns 169 | ------- 170 | total_loss : torch.Tensor 171 | Combined reconstruction, latent regularisation and affinity loss. 172 | recon_loss : torch.Tensor 173 | Reconstruction loss. 174 | kldivergence : torch.Tensor 175 | Non-weighted KL Divergence loss. 176 | affin_loss : torch.Tensor 177 | Non-weighted affinity loss. 178 | """ 179 | if self.affinity_loss is not None and batch_aff is None: 180 | raise RuntimeError( 181 | "aVAE loss function requires affinity ids for the batch." 182 | ) 183 | 184 | # recon loss (prediction first, target second) 185 | if self.recon_fn == "BCE": 186 | recon_loss = torch.nn.functional.binary_cross_entropy( 187 | recon_x, x, reduction="mean" 188 | ) 189 | elif self.recon_fn == "MSE": 190 | recon_loss = torch.nn.functional.mse_loss( 191 | recon_x, x, reduction="mean" 192 | ) 193 | else: 194 | raise RuntimeError( 195 | "AffinityVAE loss requires 'BCE' or 'MSE' for 'loss_fn' " 196 | "parameter." 197 | ) 198 | 199 | # kldiv loss 200 | if self.klred == "mean": 201 | kldivergence = -0.5 * torch.mean( 202 | 1 + logvar - mu.pow(2) - logvar.exp() 203 | ) 204 | elif self.klred == "sum": 205 | kldivergence = torch.mean( 206 | -0.5 207 | * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), axis=1), 208 | axis=0, 209 | ) 210 | else: 211 | raise RuntimeError( 212 | "AffinityVAE loss requires 'mean' or 'sum' for 'klreduction' " 213 | "parameter." 214 | ) 215 | 216 | # affinity loss 217 | affin_loss = torch.Tensor([0]).to(self.device) 218 | if self.affinity_loss is not None: 219 | affin_loss = self.affinity_loss(batch_aff, mu) 220 | 221 | # total loss 222 | total_loss = ( 223 | recon_loss 224 | + self.beta[epoch] * kldivergence 225 | + self.gamma[epoch] * affin_loss 226 | ) 227 | 228 | return total_loss, recon_loss, kldivergence, affin_loss 229 | -------------------------------------------------------------------------------- /tools/augment_mrcs.py: -------------------------------------------------------------------------------- 1 | """ 2 | For each class representative, augment the data with 3 | rotation and translation.
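Typical invocation (hypothetical path):

    python augment_mrcs.py --data /path/to/class_representatives

The script reads every map in the given folder and, as configured at the
bottom of this file, writes randomly rotated copies to a sibling output
folder.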
4 | """ 5 | 6 | import argparse 7 | import os 8 | from pathlib import Path 9 | 10 | import mrcfile 11 | import numpy as np 12 | from scipy import ndimage 13 | from scipy.ndimage import zoom 14 | from tqdm import tqdm 15 | 16 | 17 | def rescale(path_core, path_rescaled): 18 | path_rescaled = os.path.join( 19 | path_core.parent, path_core.stem + "_rescaled" 20 | ) 21 | if not os.path.exists(path_rescaled): 22 | os.mkdir(path_rescaled) 23 | 24 | for x in list(path_core.iterdir()): 25 | print("Rescaling", x.stem) 26 | 27 | try: 28 | with mrcfile.open(Path(x)) as mrc: 29 | nx, ny, nz = mrc.header.nx, mrc.header.ny, mrc.header.nz 30 | new_x = zoom(mrc.data, (64 / nz, 64 / ny, 64 / nx)) 31 | new_x = (new_x - np.min(new_x)) / np.ptp(new_x) 32 | 33 | x_range, y_range, z_range = [ 34 | sum( 35 | [ 36 | np.any(z_slice > 0.01) 37 | for z_slice in np.rollaxis(new_x, dim) 38 | ] 39 | ) 40 | for dim in [2, 1, 0] 41 | ] 42 | 43 | new_x = np.where(new_x < 0.01, 0, new_x) 44 | 45 | with mrcfile.new( 46 | os.path.join(path_rescaled, x.stem + ".mrc"), 47 | overwrite=True, 48 | ) as mrc: 49 | mrc.set_data(new_x) 50 | mrc.header.label[1] = f"ratio_z_x={str(z_range/x_range)}" 51 | mrc.header.label[2] = f"ratio_z_y={str(z_range/y_range)}" 52 | mrc.header.nlabl = 3 53 | 54 | except ValueError: 55 | print(f"Problem with file: {x}") 56 | 57 | except IsADirectoryError: 58 | continue 59 | 60 | 61 | def rotate_the_pokemino_1_axis( 62 | array, axes=(0, 1), theta=None, order=1, s=None 63 | ): 64 | """Rotates the Pokemino using scipy.ndimage.rotate""" 65 | 66 | assert ( 67 | type(axes) == tuple 68 | and len(axes) == 2 69 | and all([type(i) == int for i in axes]) 70 | and all([i in range(0, 3) for i in axes]) 71 | ), "Incorrect axes parameter: pass a tuple of 2 axes." 72 | 73 | if not theta: 74 | np.random.seed(s) 75 | theta = np.random.choice([i for i in range(0, 360)]) 76 | else: 77 | assert (isinstance(theta, int)) and theta in range( 78 | 0, 360 79 | ), "Error: Pokemino3D.rotate_the_brick requires the value for theta in range <0, 360>." 
80 | array = ndimage.rotate( 81 | array, angle=theta, axes=axes, order=order, reshape=False 82 | ) 83 | 84 | return array, theta 85 | 86 | 87 | def rotate_the_pokemino_3_axes( 88 | array, theta_x=None, theta_y=None, theta_z=None 89 | ): 90 | 91 | array, z_rot = rotate_the_pokemino_1_axis( 92 | array, axes=(1, 0), theta=theta_x 93 | ) 94 | array, x_rot = rotate_the_pokemino_1_axis( 95 | array, axes=(2, 1), theta=theta_y 96 | ) 97 | array, y_rot = rotate_the_pokemino_1_axis( 98 | array, axes=(0, 2), theta=theta_z 99 | ) 100 | 101 | return array, x_rot, y_rot, z_rot 102 | 103 | 104 | def shift_block(array, sh, axis=1): 105 | 106 | assert len(array.shape) == 3, "Pass a 3D array" 107 | assert type(sh) is int, "Pass a valid shift" 108 | assert sh < array.shape[axis], "You're asking for too much of a shift" 109 | 110 | if sh == 0: 111 | return array 112 | 113 | if sh < 0: 114 | if axis == 0: 115 | return np.concatenate( 116 | (array[-sh:, :, :], np.zeros_like(array[:-sh, :, :])), axis=0 117 | ) 118 | if axis == 1: 119 | return np.concatenate( 120 | (array[:, -sh:, :], np.zeros_like(array[:, :-sh, :])), axis=1 121 | ) 122 | if axis == 2: 123 | return np.concatenate( 124 | (array[:, :, -sh:], np.zeros_like(array[:, :, :-sh])), axis=2 125 | ) 126 | 127 | if sh > 0: 128 | if axis == 0: 129 | return np.concatenate( 130 | (np.zeros_like(array[-sh:, :, :]), array[:-sh, :, :]), axis=0 131 | ) 132 | if axis == 1: 133 | return np.concatenate( 134 | (np.zeros_like(array[:, -sh:, :]), array[:, :-sh, :]), axis=1 135 | ) 136 | if axis == 2: 137 | return np.concatenate( 138 | (np.zeros_like(array[:, :, -sh:]), array[:, :, :-sh]), axis=2 139 | ) 140 | 141 | 142 | def add_label(mrc, param_name, param_value): 143 | 144 | label_count = mrc.header.nlabl 145 | mrc.header.label[label_count] = f"{param_name}={str(param_value)}" 146 | mrc.header.nlabl += 1 147 | 148 | 149 | def read_rotate_translate_save_mrc( 150 | src_path, output_path, mrcs_list, n_pokeminos, nrot, ntrans 151 | ): 152 | 153 | if not os.path.exists(output_path): 154 | os.mkdir(output_path) 155 | 156 | for i in tqdm(range(n_pokeminos)): 157 | 158 | # pick from the passed list, not the module-level global 159 | protein = np.random.choice(mrcs_list) 160 | meta = [] 161 | new_mrc = mrcfile.open(Path(src_path, f"{protein}.mrc")).data 162 | 163 | if nrot == 1: 164 | new_mrc, theta_x = rotate_the_pokemino_1_axis(new_mrc) 165 | meta.append(str(theta_x)) 166 | 167 | elif nrot == 3: 168 | new_mrc, theta_x, theta_y, theta_z = rotate_the_pokemino_3_axes( 169 | new_mrc 170 | ) 171 | meta.append(str(theta_x)) 172 | meta.append(str(theta_y)) 173 | meta.append(str(theta_z)) 174 | 175 | x_range, y_range, z_range = [ 176 | [np.any(plane > 0.01) for plane in np.rollaxis(new_mrc, dim)] 177 | for dim in [0, 1, 2] 178 | ] 179 | 180 | if ntrans >= 1: 181 | # Shifting in x 182 | shift_x = int( 183 | np.random.choice( 184 | range( 185 | -x_range.index(True) + 1, x_range[::-1].index(True) - 1 186 | ) 187 | ) 188 | ) 189 | new_mrc = shift_block(new_mrc, shift_x, axis=0) 190 | meta.append(str(shift_x)) 191 | 192 | if ntrans >= 2: 193 | # Shifting in y 194 | shift_y = int( 195 | np.random.choice( 196 | range( 197 | -y_range.index(True) + 1, y_range[::-1].index(True) - 1 198 | ) 199 | ) 200 | ) 201 | new_mrc = shift_block(new_mrc, shift_y, axis=1) 202 | meta.append(str(shift_y)) 203 | 204 | if ntrans >= 3: 205 | # Shifting in z 206 | shift_z = int( 207 | np.random.choice( 208 | range( 209 | -z_range.index(True) + 1, z_range[::-1].index(True) - 1 210 | ) 211 | ) 212 | ) 213 | new_mrc = shift_block(new_mrc, shift_z, axis=2) 214 |
meta.append(str(shift_z)) 215 | 216 | meta = "-".join(meta) 217 | with mrcfile.new( 218 | Path(output_path / f"{protein}_{meta}.mrc"), overwrite=True 219 | ) as mrc: 220 | 221 | mrc.set_data(new_mrc) 222 | 223 | """if i == 0: 224 | mrc.print_header()""" 225 | 226 | 227 | parser = argparse.ArgumentParser() 228 | parser.add_argument("--data") 229 | args = parser.parse_args() 230 | src_path = Path(args.data) 231 | 232 | # rescale 233 | # output_path = Path(os.path.join(src_path.parent, src_path.stem+"_rescaled/")) 234 | # rescale(src_path, output_path) 235 | 236 | # random.seed(0) 237 | # mrcs = random.sample([x.stem for x in list(src_path.iterdir())], k=20) 238 | mrcs = [x.stem for x in list(src_path.iterdir())] 239 | n_pokeminos = 10000 240 | 241 | # 1rot 242 | # output_path = Path(os.path.join(src_path.parent, src_path.stem + "_1rot/")) 243 | # read_rotate_translate_save_mrc(src_path = src_path, output_path = output_path, mrcs_list = mrcs, n_pokeminos = n_pokeminos, nrot = 1, ntrans = 0) 244 | 245 | # 3rot 246 | output_path = Path( 247 | os.path.join(src_path.parent, src_path.stem + "_3rot_10000/") 248 | ) 249 | read_rotate_translate_save_mrc( 250 | src_path=src_path, 251 | output_path=output_path, 252 | mrcs_list=mrcs, 253 | n_pokeminos=n_pokeminos, 254 | nrot=3, 255 | ntrans=0, 256 | ) 257 | 258 | # 1trans 259 | # output_path = Path(os.path.join(src_path.parent, src_path.stem+"_1trans/")) 260 | # read_rotate_translate_save_mrc(src_path = src_path, output_path = output_path, mrcs_list = mrcs, n_pokeminos = n_pokeminos, nrot = 0, ntrans = 1) 261 | 262 | # 3trans 263 | # output_path = Path(os.path.join(src_path.parent, src_path.stem+"_3trans/")) 264 | # read_rotate_translate_save_mrc(src_path = src_path, output_path = output_path, mrcs_list = mrcs, n_pokeminos = n_pokeminos, nrot = 0, ntrans = 3) 265 | 266 | # 1rot + 1trans 267 | # output_path = Path(os.path.join(src_path.parent, src_path.stem+"_1rot_1trans/")) 268 | # read_rotate_translate_save_mrc(src_path = src_path, output_path = output_path, mrcs_list = mrcs, n_pokeminos = n_pokeminos, nrot = 1, ntrans = 1) 269 | 270 | # 3rot + 3trans 271 | # output_path = Path(os.path.join(src_path.parent, src_path.stem+"_3rot_3trans/")) 272 | # read_rotate_translate_save_mrc(src_path = src_path, output_path = output_path, mrcs_list = mrcs, n_pokeminos = n_pokeminos, nrot = 3, ntrans = 3) 273 | -------------------------------------------------------------------------------- /avae/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import lightning as lt 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | from . import settings, vis 10 | from .data import load_data 11 | from .utils import accuracy, latest_file 12 | from .utils_learning import add_meta 13 | 14 | 15 | def evaluate( 16 | datapath: str, 17 | datatype: str, 18 | state: str | None, 19 | meta: str | None, 20 | lim: int | None, 21 | splt: int, 22 | batch_s: int, 23 | classes: str | None, 24 | use_gpu: bool, 25 | gaussian_blur: bool, 26 | normalise: bool, 27 | shift_min: bool, 28 | rescale: bool, 29 | classifier: str, 30 | ): 31 | """Function for evaluating the model. Loads the data and model and runs the evaluation. Saves the results of the 32 | evaluation in the plots and latents directories. 33 | 34 | Parameters 35 | ---------- 36 | datapath: str 37 | Path to the data directory.
38 | datatype: str 39 | data file formats : mrc, npy 40 | state: str 41 | Path to the model state file to be used for evaluation/resume. 42 | meta: str 43 | Path to the meta file to be used for evaluation/resume. 44 | lim: int 45 | Limit the number of samples to load. 46 | splt: int 47 | Percentage of data to be used for validation. 48 | batch_s: int 49 | Batch size. 50 | classes: str 51 | Path to a CSV file listing the classes to be selected from the data for the training and validation set. 52 | use_gpu: bool 53 | If True, the model will be evaluated on GPU. 54 | gaussian_blur: bool 55 | If True, Gaussian blurring is applied to the input before being passed to the model. 56 | This is added as a way to remove noise from the input data. 57 | normalise: bool 58 | If True, the input data is normalised before being passed to the model. 59 | shift_min: bool 60 | If True, the input data is shifted to have a minimum value of 0 and max 1. rescale: bool If True, the input data is rescaled. 61 | classifier: str 62 | The method to use for classification in the latent space. Can be a neural network (NN), k-nearest neighbours (KNN) or logistic regression (LR). 63 | 64 | 65 | """ 66 | fabric = lt.Fabric() 67 | fabric.launch() 68 | # ############################### DATA ############################### 69 | tests, data_dim = load_data( 70 | datapath=datapath, 71 | datatype=datatype, 72 | lim=lim, 73 | splt=splt, 74 | batch_s=batch_s, 75 | eval=True, 76 | gaussian_blur=gaussian_blur, 77 | normalise=normalise, 78 | shift_min=shift_min, 79 | rescale=rescale, 80 | fabric=fabric, 81 | ) 82 | 83 | # ############################### MODEL ############################### 84 | device = fabric.device 85 | 86 | if state is None: 87 | if not os.path.exists("states"): 88 | raise RuntimeError( 89 | "There are no existing model states saved or provided via the state flag in config; unable to evaluate.
90 | ) 91 | else: 92 | state = latest_file("states", ".pt") 93 | state = os.path.join("states", state) 94 | 95 | s = os.path.basename(state) 96 | fname = s.split(".")[0].split("_") 97 | dshape = list(tests)[0][0].shape[2:] 98 | pose_dims = int(fname[-1]) 99 | 100 | logging.info("Loading model from: {}".format(state)) 101 | checkpoint = torch.load(state) 102 | vae = checkpoint["model_class_object"] 103 | vae.load_state_dict(checkpoint["model_state_dict"]) 104 | vae = fabric.setup(vae) 105 | 106 | # ########################## EVALUATE ################################ 107 | 108 | if meta is None: 109 | metas = latest_file("states", ".pkl") 110 | meta = os.path.join("states", metas) 111 | 112 | logging.info("Loading metadata from: {}".format(meta)) 113 | meta_df = pd.read_pickle(meta) 114 | 115 | # create holders for latent spaces and labels 116 | x_test, y_test, c_test = [], [], [] 117 | p_test = None 118 | 119 | if pose_dims != 0: 120 | p_test = [] 121 | 122 | logging.debug("Batch: [0/%d]" % (len(tests))) 123 | 124 | vae.eval() 125 | for batch_number, (t, label, aff, meta_data) in enumerate(tests): 126 | 127 | # get data in the right device 128 | t, aff = t.to(device), aff.to(device) 129 | t = t.to(torch.float32) 130 | 131 | # forward 132 | t_hat, t_mu, t_logvar, tlat, tlat_pose = vae(t) 133 | 134 | x_test.extend(t_mu.cpu().detach().numpy()) # store latents 135 | c_test.extend(t_logvar.cpu().detach().numpy()) 136 | # if labels are present save them, otherwise fill with a "test" placeholder 137 | try: 138 | y_test.extend(label) 139 | except IndexError: 140 | y_test.extend(np.full(shape=len(t), fill_value="test")) 141 | if tlat_pose is not None: 142 | p_test.extend(tlat_pose.cpu().detach().numpy()) 143 | 144 | meta_df = add_meta( 145 | data_dim, 146 | meta_df, 147 | meta_data, 148 | t_hat, 149 | t_mu, 150 | tlat_pose, 151 | tlat, 152 | mode="evl", 153 | ) 154 | 155 | logging.debug("Batch: [%d/%d]" % (batch_number + 1, len(tests))) 156 | logging.info("Batch: [%d/%d]" % (batch_number + 1, len(tests))) 157 | 158 | # ########################## VISUALISE ################################ 159 | if classes is not None: 160 | classes_list = pd.read_csv(classes).columns.tolist() 161 | else: 162 | classes_list = [] 163 | # visualise reconstructions - last batch 164 | if settings.VIS_REC: 165 | vis.recon_plot(t, t_hat, y_test, data_dim, mode="evl") 166 | 167 | # visualise latent disentanglement 168 | if settings.VIS_DIS: 169 | vis.latent_disentamglement_plot( 170 | dshape, 171 | x_test, 172 | vae, 173 | device, 174 | poses=p_test, 175 | mode="_eval", 176 | ) 177 | 178 | # visualise pose disentanglement 179 | if pose_dims != 0 and settings.VIS_POS: 180 | vis.pose_disentanglement_plot( 181 | dshape, 182 | x_test, 183 | p_test, 184 | vae, 185 | device, 186 | mode="_eval", 187 | ) 188 | 189 | if pose_dims != 0 and settings.VIS_POSE_CLASS: 190 | vis.pose_class_disentanglement_plot( 191 | dshape, 192 | x_test, 193 | y_test, 194 | settings.VIS_POSE_CLASS, 195 | p_test, 196 | vae, 197 | device, 198 | mode="_eval", 199 | ) 200 | # visualise interpolations 201 | if settings.VIS_INT: 202 | vis.interpolations_plot( 203 | dshape, 204 | x_test, 205 | np.ones(len(x_test)), 206 | vae, 207 | device, 208 | poses=p_test, 209 | mode="_eval", 210 | ) 211 | 212 | # visualise embeddings 213 | if settings.VIS_EMB: 214 | vis.latent_embed_plot_umap( 215 | x_test, np.array(y_test), classes_list, "_eval" 216 | ) 217 | vis.latent_embed_plot_tsne( 218 | x_test, np.array(y_test), classes_list, "_eval" 219 | ) 220 | 221 | if settings.VIS_SIM: 222 |
vis.latent_space_similarity_plot( 223 | x_test, np.array(y_test), mode="_eval", classes_order=classes_list 224 | ) 225 | 226 | # ############################# Predict ############################# 227 | # get training latent space from metadata for comparison and accuracy estimation 228 | latents_training = meta_df[meta_df["mode"] == "trn"][ 229 | [col for col in meta_df if col.startswith("lat")] 230 | ].to_numpy() 231 | latents_training_id = meta_df[meta_df["mode"] == "trn"]["id"] 232 | 233 | if settings.VIS_DYN: 234 | # merge img and rec into one image for display in altair 235 | meta_df["image"] = meta_df["image"].apply(vis.merge) 236 | vis.dyn_latentembed_plot(meta_df, 0, embedding="umap", mode="_eval") 237 | vis.dyn_latentembed_plot(meta_df, 0, embedding="tsne", mode="_eval") 238 | 239 | # visualise embeddings 240 | if settings.VIS_EMB: 241 | vis.latent_embed_plot_umap( 242 | np.concatenate([x_test, latents_training]), 243 | np.concatenate([np.array(y_test), np.array(latents_training_id)]), 244 | classes_list, 245 | "_train_eval_comparison", 246 | ) 247 | vis.latent_embed_plot_tsne( 248 | np.concatenate([x_test, latents_training]), 249 | np.concatenate([np.array(y_test), np.array(latents_training_id)]), 250 | classes_list, 251 | "_train_eval_comparison", 252 | ) 253 | 254 | # visualise accuracy 255 | (train_acc, val_acc, val_acc_selected, ypred_train, ypred_val,) = accuracy( 256 | latents_training, 257 | np.array(latents_training_id), 258 | x_test, 259 | np.array(y_test), 260 | classifier=classifier, 261 | ) 262 | logging.info( 263 | "------------------->>> Accuracy: Train: %f | Val : %f | Val with unseen labels: %f\n" 264 | % (train_acc, val_acc_selected, val_acc) 265 | ) 266 | vis.accuracy_plot( 267 | np.array(latents_training_id), 268 | ypred_train, 269 | y_test, 270 | ypred_val, 271 | classes, 272 | mode="_eval", 273 | ) 274 | vis.f1_plot( 275 | np.array(latents_training_id), 276 | ypred_train, 277 | y_test, 278 | ypred_val, 279 | classes, 280 | mode="_eval", 281 | ) 282 | logging.info("Saving meta files with evaluation data.") 283 | 284 | metas = os.path.basename(meta) 285 | # save metadata with evaluation data 286 | meta_df.to_pickle( 287 | os.path.join("states", metas.split(".")[0] + "_eval.pkl") 288 | ) 289 | -------------------------------------------------------------------------------- /tools/create_subtomo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import mrcfile 5 | import numpy as np 6 | import pandas as pd 7 | from scipy.ndimage import rotate 8 | 9 | 10 | def create_subtomo( 11 | input_path, 12 | output_path, 13 | datatype, 14 | annot_path, 15 | vox_size=[32, 32, 32], 16 | bandpass=False, 17 | low_freq=0, 18 | high_freq=15, 19 | gaussian_blur=False, 20 | normalise=False, 21 | add_noise=False, 22 | noise_int=0, 23 | padding=None, 24 | augment=None, 25 | aug_th_min=-45, 26 | aug_th_max=45, 27 | ): 28 | """Function to extract subtomograms (or 2D sub-images) from annotated full tomograms/images. For every particle 29 | listed in the annotation files, a voxel of size vox_size is cropped around its coordinates, optionally 30 | band-pass filtered, noised, normalised, augmented and padded, and then 31 | saved as an individual file in the output folder. 32 | 33 | Parameters 34 | ---------- 35 | input_path: str 36 | Path to the folder containing the full tomogram/image.
37 | output_path: str 38 | Path to the folder for output subtomograms 39 | datatype : str 40 | data file formats : mrc, npy 41 | annot_path : str 42 | path to the folder containing the name of the particles and their x,y,z coordinates; 43 | file names should be identical to those of the full tomograms, with a .txt extension; 44 | the format of the txt files is: 'class','z','y','x','cx','cy','cz' 45 | vox_size: list 46 | size of each subtomogram voxel given as a list where vox_size: [x,y,z] 47 | bandpass: bool 48 | Apply band pass filter to the full tomogram before extracting the subtomograms 49 | low_freq: float 50 | Lower frequency threshold for the band pass filter 51 | high_freq: float 52 | Higher frequency threshold for the band pass filter 53 | gaussian_blur: bool 54 | if True, Gaussian blurring would be applied to the input to remove noise 55 | (currently not implemented). 56 | normalise: bool 57 | If True, the input data is normalised before being passed to the model. 58 | add_noise: bool 59 | Add noise to images, this can be used for benchmarking 60 | noise_int: int 61 | noise intensity 62 | padding: list 63 | size of padding boxes 64 | augment: int 65 | perform the given number of augmentations on each voxel before saving 66 | aug_th_min: int 67 | The minimum value of the augmentation range in degrees 68 | aug_th_max: int 69 | The maximum value of the augmentation range in degrees 70 | """ 71 | 72 | a = 0 73 | # the name of all full tomograms 74 | file_list = [f for f in os.listdir(input_path) if "." + datatype in f] 75 | print("List of tomograms provided:", file_list) 76 | # the name of all annotations for tomograms 77 | annot_list = [f for f in os.listdir(annot_path) if ".txt" in f] 78 | print("List of annotation files provided:", annot_list) 79 | 80 | if len(annot_list) != len(file_list): 81 | raise RuntimeError( 82 | "Number of full tomograms does not match the number of annotation files" 83 | ) 84 | 85 | for t, tomo in enumerate(file_list): 86 | print("#####################################") 87 | print("#####################################") 88 | print(f"{t}: segmenting tomogram {tomo}") 89 | print("#####################################") 90 | print("#####################################") 91 | 92 | data = read_mrc(os.path.join(input_path, tomo)) 93 | if bandpass: 94 | print("creating band pass filter") 95 | bandpass_filter_image = bandpass_filter( 96 | data.shape, low_freq, high_freq 97 | ) 98 | print("Applying band pass filter") 99 | data = apply_bandpass_filter(data, bandpass_filter_image) 100 | 101 | if add_noise: 102 | data = add_g_noise(data, noise_int) 103 | 104 | if gaussian_blur: 105 | print("gaussian blur is not implemented yet") 106 | data = data 107 | 108 | if normalise: 109 | data = normalisiation(data) 110 | 111 | particle_df = particles_GT( 112 | os.path.join(annot_path, f"{tomo[:-4]}.txt"), output_path 113 | ) 114 | 115 | particle_df = delete_detection_in_edges( 116 | particle_df, data.shape, vox_size 117 | ) 118 | 119 | for index, protein in particle_df.iterrows(): 120 | name = protein["class"] 121 | s = [int(n / 2) for n in vox_size] 122 | mol = data[ 123 | protein["x"] - s[0] : protein["x"] + s[0], 124 | protein["y"] - s[1] : protein["y"] + s[1], 125 | protein["z"] - s[2] : protein["z"] + s[2], 126 | ] 127 | for a in range(augment): 128 | mol = augmentation(mol, a, aug_th_min, aug_th_max) 129 | if padding is not None: 130 | mol = padding_mol(mol, padding) 131 | 132 | mol_file_name = ( 133 |
f"{name}_{tomo[:-4]}_{index}_Th{a}_n{noise_int}.mrc" 134 | ) 135 | print("saving:", mol_file_name) 136 | mrcfile.write( 137 | os.path.join(output_path, mol_file_name), 138 | mol, 139 | overwrite=True, 140 | ) 141 | 142 | 143 | def normalisiation(x): 144 | x = (x - x.min()) / (x.max() - x.min()) 145 | return x 146 | 147 | 148 | def augmentation(mol, a, aug_th_min, aug_th_max): 149 | mol = np.array(mol) 150 | deg_per_rot = 5 151 | angle = np.random.randint(aug_th_min, aug_th_max, size=(len(mol.shape),)) 152 | for ax in range(angle.size): 153 | theta = angle[ax] * deg_per_rot * a 154 | axes = (ax, (ax + 1) % angle.size) 155 | mol = rotate(mol, theta, axes=axes, order=0, reshape=False) 156 | return mol 157 | 158 | 159 | def bandpass_filter(image_size, bp_low, bp_high): 160 | 161 | bandpass_filter = np.zeros((image_size), dtype=np.float32) 162 | if len(image_size) == 2: 163 | for u in range(image_size[0]): 164 | for v in range(image_size[1]): 165 | D = np.sqrt( 166 | (u - image_size[0] / 2) ** 2 + (v - image_size[1] / 2) ** 2 167 | ) 168 | if D <= bp_low: 169 | bandpass_filter[u, v] = 1 170 | elif D >= bp_high: 171 | bandpass_filter[u, v] = 1 172 | elif len(image_size) == 3: 173 | for u in range(image_size[0]): 174 | for v in range(image_size[1]): 175 | for w in range(image_size[2]): 176 | D = np.sqrt( 177 | (u - image_size[0] / 2) ** 2 178 | + (v - image_size[1] / 2) ** 2 179 | + (w - image_size[2] / 2) ** 2 180 | ) 181 | 182 | if D <= bp_low: 183 | bandpass_filter[u, v, w] = 1 184 | elif D >= bp_high: 185 | bandpass_filter[u, v, w] = 1 186 | return bandpass_filter 187 | 188 | 189 | def apply_bandpass_filter(image, bandpass_filter): 190 | F = np.fft.fftn(image) 191 | Fshift = np.fft.fftshift(F) 192 | Gshift = Fshift * bandpass_filter 193 | G = np.fft.ifftshift(Gshift) 194 | filtered_image = abs(np.fft.ifftn(G)) 195 | return filtered_image.astype("float32") 196 | 197 | 198 | def add_g_noise(input_a, scale): 199 | # Gaussian distribution parameters 200 | mean = 0 201 | var = 0.1 202 | sigma = var**0.5 203 | 204 | gaussian = np.random.normal( 205 | mean, sigma, (input_a.shape[0], input_a.shape[1], input_a.shape[1]) 206 | ) 207 | 208 | output_a = input_a + gaussian * scale 209 | return output_a.astype(np.float32) 210 | 211 | 212 | def read_mrc(path): 213 | """ 214 | Takes a path and read a mrc file convert the data to np array 215 | """ 216 | warnings.simplefilter( 217 | "ignore" 218 | ) # to mute some warnings produced when opening the tomos 219 | with mrcfile.open(path, mode="r+", permissive=True) as mrc: 220 | mrc.update_header_from_data() 221 | 222 | mrc.header.map = mrcfile.constants.MAP_ID 223 | mrc = mrc.data 224 | 225 | with mrcfile.open(path) as mrc: 226 | data = np.array(mrc.data) 227 | return data 228 | 229 | 230 | def particles_GT(annot, output_path): 231 | 232 | labels = np.loadtxt(annot, dtype="str") 233 | 234 | labels = np.reshape(labels, (-1, 7)) 235 | df = pd.DataFrame(labels) 236 | df.columns = ["class", "z", "y", "x", "cx", "cy", "cz"] 237 | 238 | df = df.astype({"x": "int"}) 239 | df = df.astype({"y": "int"}) 240 | df = df.astype({"z": "int"}) 241 | proteins = df["class"].unique() 242 | 243 | df.drop(df[df["class"] == "vesicle"].index, inplace=True) 244 | 245 | with open(f"{output_path}/classes.csv", "w") as f: 246 | f.write(",".join(proteins)) 247 | return df 248 | 249 | 250 | def delete_detection_in_edges(df, data_shape, vox_size): 251 | # drop the particles that are appearing in the edges 252 | df.drop(df[df["x"] > data_shape[0] - vox_size[0]].index, inplace=True) 253 | 
df.drop(df[df["y"] > data_shape[1] - vox_size[1]].index, inplace=True) 254 | df.drop(df[df["z"] > data_shape[2] - vox_size[2]].index, inplace=True) 255 | df.drop(df[df["x"] < vox_size[0]].index, inplace=True) 256 | df.drop(df[df["y"] < vox_size[1]].index, inplace=True) 257 | df.drop(df[df["z"] < vox_size[2]].index, inplace=True) 258 | return df 259 | 260 | 261 | def padding_mol(array, padding_size): 262 | """ 263 | :param array: numpy array 264 | :param xx: desired height 265 | :param yy: desirex width 266 | :return: padded array 267 | """ 268 | 269 | h = array.shape[0] 270 | w = array.shape[1] 271 | z = array.shape[2] 272 | 273 | xx = padding_size[0] 274 | yy = padding_size[1] 275 | zz = padding_size[2] 276 | 277 | a = (xx - h) // 2 278 | aa = xx - a - h 279 | 280 | b = (yy - w) // 2 281 | bb = yy - b - w 282 | 283 | c = (zz - z) // 2 284 | cc = zz - c - z 285 | 286 | return np.pad( 287 | array, pad_width=((a, aa), (b, bb), (c, cc)), mode="constant" 288 | ) 289 | -------------------------------------------------------------------------------- /avae/decoders/differentiable.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | from avae.decoders.base import AbstractDecoder 6 | from avae.decoders.spatial import ( 7 | CartesianAxes, 8 | SpatialDims, 9 | axis_angle_to_quaternion, 10 | quaternion_to_rotation_matrix, 11 | ) 12 | 13 | 14 | class GaussianSplatRenderer(torch.nn.Module): 15 | """Perform gaussian splatting.""" 16 | 17 | def __init__( 18 | self, 19 | shape: Tuple[int], 20 | *, 21 | device: torch.device = torch.device("cpu"), 22 | ): 23 | super().__init__() 24 | 25 | self._shape = shape 26 | self._ndim = len(shape) 27 | self.device = device 28 | 29 | if len(shape) not in (SpatialDims.TWO, SpatialDims.THREE): 30 | raise ValueError("Only 2D or 3D rotations are currently supported") 31 | 32 | grids = torch.meshgrid( 33 | *[torch.linspace(-1, 1, sz) for sz in shape], 34 | indexing="xy", 35 | ) 36 | 37 | # add all zeros for z- if we have a 2d grid 38 | if len(shape) == SpatialDims.TWO: 39 | grids += ( 40 | torch.zeros_like( 41 | grids[0], 42 | ), 43 | ) 44 | 45 | self.coords = ( 46 | torch.stack([torch.ravel(grid) for grid in grids], axis=0) 47 | .transpose(0, 1) 48 | .unsqueeze(0) 49 | .to(self.device) 50 | ) 51 | 52 | def forward( 53 | self, 54 | splats: torch.Tensor, 55 | weights: torch.Tensor, 56 | sigmas: torch.Tensor, 57 | *, 58 | splat_sigma_range: Tuple[float, float] = (0.0, 1.0), 59 | ) -> torch.Tensor: 60 | """Render the Gaussian splats in an image volume. 61 | 62 | Parameters 63 | ---------- 64 | splats : tensor 65 | An (N, D, 3) tensor specifying the X,Y,Z coordinates of the D 66 | gaussians for the minibatch of N images. 67 | weights : tensor 68 | An (N, D, 1) tensor specifying the weights of the D gaussians for 69 | the minibatch of N images.In the range of 0 to 1. 70 | sigmas : tensor 71 | An (N, D, 1) tensor specifying the standard deviations of the D 72 | gaussians for the minibatch of N images. In the range of 0 to 1. 73 | splat_sigma_range : tuple 74 | The minimum and maximum values for sigma. Final sigma is calculated 75 | as sigmas * (max_sigma - min_sigma) + min_sigma. 76 | 77 | Returns 78 | ------- 79 | x : tensor 80 | The rendered image volume. 81 | 82 | Notes 83 | ----- 84 | This isn't very memory efficient since the GMM is evaluated for every 85 | voxel in the output. This means an (N, M*M*M, D) matrix, where M is the 86 | dimensions of the image volume (e.g. 
32x32x32) and D is the number of 87 | gaussians (e.g. 1024). This leads to a matrix of 32 x 32768 x 1024 88 | for a minibatch of 32 volumes. 89 | """ 90 | if torch.cuda.device_count() == 0 and self.device.type == 'cuda': 91 | self.device = torch.device('cpu') 92 | 93 | # scale the sigma values 94 | min_sigma, max_sigma = splat_sigma_range 95 | sigmas = sigmas * (max_sigma - min_sigma) + min_sigma 96 | 97 | # transpose keeping batch intact 98 | # coords_t = torch.swapaxes(self.coords, 1, 2) 99 | splats = splats.to(self.device) 100 | splats_t = torch.swapaxes(splats, 1, 2) 101 | splats_t = splats_t.to(self.device) 102 | self.coords = self.coords.to(self.device) 103 | 104 | # calculate D^2 for all combinations of voxel and gaussian 105 | D_squared = torch.sum( 106 | self.coords[:, :, None, :] ** 2 + splats_t[:, None, :, :] ** 2, 107 | axis=-1, 108 | ) - 2 * torch.matmul(self.coords, splats) 109 | # scale the gaussians 110 | sigmas = 2.0 * sigmas[:, None, :] ** 2 111 | sigmas = sigmas.to(self.device) 112 | weights = weights.to(self.device) 113 | 114 | # now splat the gaussians 115 | x = torch.sum( 116 | weights[:, None, :] * torch.exp(-D_squared / sigmas), 117 | axis=-1, 118 | ) 119 | 120 | return x.reshape((-1, *self._shape)).unsqueeze(1) 121 | 122 | 123 | class SoftStep(torch.nn.Module): 124 | """Soft (differentiable) step function in the range of 0-1.""" 125 | 126 | def __init__(self, *, k: float = 1.0): 127 | super().__init__() 128 | self.k = k 129 | 130 | def forward(self, x: torch.Tensor) -> torch.Tensor: 131 | return 1.0 / (1.0 + torch.exp(-self.k * x)) 132 | 133 | 134 | class GaussianSplatDecoder(AbstractDecoder): 135 | """Differentiable Gaussian splat decoder. 136 | 137 | Parameters 138 | ---------- 139 | shape : tuple 140 | A tuple describing the output shape of the image data. Can be 2- or 3- 141 | dimensional. For example: (32, 32, 32) 142 | n_splats : int 143 | The number of Gaussians in the mixture model. 144 | latent_dims : int 145 | The dimensions of the latent representation. 146 | output_channels : int, optional 147 | The number of output channels in the final image volume. If not 148 | supplied, this defaults to 0 and no final convolution is applied. If 149 | it is non-zero, additional convolutions are applied to the GMM render. 150 | splat_sigma_range : tuple[float, float] 151 | The minimum and maximum sigma values for each splat. Useful to control 152 | the resolution of the final render. 153 | default_axis : CartesianAxes 154 | A default cartesian axis to use for rotation if the pose is provided by 155 | a rotation only. Default is Z, equivalent to a typical image rotation 156 | about the central axis. 157 | 158 | Notes 159 | ----- 160 | Takes the latent code and pose estimate to generate a planar or volumetric 161 | image. The code is used to position N symmetric gaussians in the image 162 | volume which are then rotated by an explicit rotation transform. These are 163 | rendered as an image by evaluating the list of gaussians as a GMM. 164 | 165 | The renderer is differentiable and can therefore be used during training.
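    Examples
    --------
    A minimal sketch (sizes invented for illustration; with pose_dims=1 the
    pose is a single rotation angle about `default_axis`):

    >>> decoder = GaussianSplatDecoder(
    ...     (32, 32), n_splats=64, latent_dims=8, pose_dims=1
    ... )
    >>> z = torch.randn(4, 8)
    >>> theta = torch.randn(4, 1)
    >>> x = decoder(z, theta)  # -> (4, 1, 32, 32) rendered batch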
166 | """ 167 | 168 | def __init__( 169 | self, 170 | shape: Tuple[int], 171 | *, 172 | n_splats: int = 128, 173 | latent_dims: int = 8, 174 | output_channels: int = 0, 175 | splat_sigma_range: Tuple[float, float] = (0.02, 0.1), 176 | default_axis: CartesianAxes = CartesianAxes.Z, 177 | device: torch.device = torch.device("cpu"), 178 | pose_dims: int, 179 | ): 180 | super().__init__() 181 | 182 | self._device = device 183 | # centroids should be in the range of (-1, 1) 184 | self.centroids = torch.nn.Sequential( 185 | torch.nn.Linear(latent_dims, n_splats * 3), 186 | torch.nn.Tanh(), 187 | ) 188 | # weights are effectively whether a splat is used or not 189 | # use a soft step function to make this `binary` (but differentiable) 190 | # NOTE(arl): not sure if this really makes any difference 191 | self.weights = torch.nn.Sequential( 192 | torch.nn.Linear(latent_dims, n_splats), 193 | torch.nn.Tanh(), 194 | SoftStep(k=10.0), 195 | ) 196 | # sigma ends up being scaled by `splat_sigma_range` 197 | self.sigmas = torch.nn.Sequential( 198 | torch.nn.Linear(latent_dims, n_splats), 199 | torch.nn.Sigmoid(), 200 | ) 201 | # now set up the differentiable renderer 202 | self.configure_renderer( 203 | shape, 204 | splat_sigma_range=splat_sigma_range, 205 | default_axis=default_axis, 206 | device=device, 207 | ) 208 | 209 | self._device = device 210 | self._ndim = len(shape) 211 | self._output_channels = output_channels 212 | """Decode the splats to retrieve the coordinates, weights and sigmas.""" 213 | if pose_dims not in (1, 4): 214 | raise ValueError( 215 | "Pose needs to be either a single angle rotation about the " 216 | "`default_axis` or a full angle-axis representation in 3D. " 217 | ) 218 | self.pose = not (pose_dims == 0) 219 | 220 | # add a final convolutional decoder to generate an image if the number 221 | # of output channels has been provided 222 | if output_channels != 0: 223 | conv = ( 224 | torch.nn.Conv2d 225 | if self._ndim == SpatialDims.TWO 226 | else torch.nn.Conv3d 227 | ) 228 | self._decoder = torch.nn.Sequential( 229 | conv(1, 32, 3, padding="same"), 230 | torch.nn.ReLU(), 231 | conv(32, 32, 3, padding="same"), 232 | torch.nn.ReLU(), 233 | conv(32, output_channels, 3, padding="same"), 234 | ) 235 | 236 | def configure_renderer( 237 | self, 238 | shape: Tuple[int], 239 | *, 240 | splat_sigma_range: Tuple[float, float] = (0.02, 0.1), 241 | default_axis: CartesianAxes = CartesianAxes.Z, 242 | device: torch.device = torch.device("cpu"), 243 | ) -> None: 244 | """Reconfigure the renderer. 245 | 246 | Notes 247 | ----- 248 | This might be useful to do once a model is trained. For example, one 249 | could change the resolution of the rendered image by changing the 250 | `shape` of the output. 
251 | """ 252 | self._shape = shape 253 | self._default_axis = default_axis.as_tensor() 254 | self._splatter = GaussianSplatRenderer( 255 | shape, 256 | device=device, 257 | ) 258 | self._splat_sigma_range = splat_sigma_range 259 | 260 | def decode_splats( 261 | self, z: torch.Tensor, pose: torch.Tensor 262 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 263 | 264 | # predict the centroids for the splats 265 | splats = self.centroids(z).view(z.shape[0], 3, -1) 266 | weights = self.weights(z) 267 | sigmas = self.sigmas(z) 268 | 269 | # get the batch size 270 | batch_size = z.shape[0] 271 | 272 | # in the case where the encoded pose only has one dimension, we need to 273 | # use the pose as a rotation about the z-axis 274 | if pose.shape[-1] == 1: 275 | pose = torch.concat( 276 | [ 277 | pose, 278 | torch.tile(self._default_axis, (batch_size, 1)), 279 | ], 280 | axis=-1, 281 | ) 282 | 283 | # convert axis angles to quaternions 284 | assert pose.shape[-1] == 4, pose.shape 285 | quaternions = axis_angle_to_quaternion(pose, normalize=True) 286 | 287 | # convert the quaternions to rotation matrices 288 | rotation_matrices = quaternion_to_rotation_matrix(quaternions) 289 | 290 | # rotate the 3D points using the rotation matrices 291 | rotated_splats = torch.matmul( 292 | rotation_matrices, 293 | splats, 294 | ) 295 | 296 | # use only the required spatial dimensions (batch, ndim, samples) 297 | # rotated_splats = rotated_splats[:, : self._ndim, :] 298 | 299 | return rotated_splats, weights, sigmas 300 | 301 | def forward( 302 | self, 303 | z: torch.Tensor, 304 | pose: torch.Tensor, 305 | *, 306 | use_final_convolution: bool = True, 307 | ) -> torch.Tensor: 308 | """Decode the latents to an image volume given an explicit transform. 309 | 310 | Parameters 311 | ---------- 312 | z : tensor 313 | An (N, D) tensor specifying the D dimensional latent encodings for 314 | the minibatch of N images. 315 | pose : tensor 316 | An (N, 1 | 4) tensor specifying the pose in terms of a single 317 | rotation (assumed around the z-axis) or a full axis-angle rotation. 318 | use_final_convolution: bool 319 | Whether to apply the final convolutional layers to recover the image. 320 | This can be useful to inspect the underlying structure in a trained 321 | model. 322 | 323 | Returns 324 | ------- 325 | x : tensor 326 | The decoded image from the latents and pose. 327 | """ 328 | 329 | # decode the splats from the latents and pose 330 | splats, weights, sigmas = self.decode_splats(z, pose) 331 | 332 | x = self._splatter( 333 | splats, weights, sigmas, splat_sigma_range=self._splat_sigma_range 334 | ) 335 | # if we're doing a final convolution, do it here 336 | if ( 337 | self._output_channels is not None 338 | and self._output_channels != 0 339 | and use_final_convolution 340 | ): 341 | x = self._decoder(x) 342 | return x 343 | --------------------------------------------------------------------------------