├── experiments ├── include │ ├── vae3d │ │ ├── fid.yaml │ │ ├── adversarial.yaml │ │ ├── mse.yaml │ │ └── README.md │ ├── vae2d │ │ ├── fid.yaml │ │ ├── adversarial.yaml │ │ ├── images │ │ │ └── MNIST_VAE_ResNetVAE2d_FID_49.png │ │ ├── mse.yaml │ │ └── README.md │ ├── vae4d │ │ ├── images │ │ │ └── example.gif │ │ ├── mse.yaml │ │ └── README.md │ └── vae1d │ │ ├── hparams.yaml │ │ ├── mse.yaml │ │ └── README.md ├── la5c │ ├── include │ │ └── hparams.yaml │ ├── classification │ │ ├── bilingual_hparams.yaml │ │ └── bilingual.yaml │ └── README.md ├── mnist │ ├── include │ │ ├── pg_schedule.yaml │ │ ├── dataset.yaml │ │ ├── common_vae.yaml │ │ └── classification_common.yaml │ ├── vae │ │ ├── fid.yaml │ │ ├── mse.yaml │ │ ├── pg_fid.yaml │ │ └── pg_mse.yaml │ └── classification │ │ ├── basic.yaml │ │ ├── embed_fid.yaml │ │ ├── embed_mse.yaml │ │ ├── sandwich_fid.yaml │ │ ├── sandwich_mse.yaml │ │ └── comparison.yaml ├── deeplesion │ ├── include │ │ └── hparams.yaml │ ├── images │ │ ├── initial_localization.png │ │ └── DeepLesion_Basic_localize_lesions_0.png │ ├── vae │ │ ├── halfres_mse_hparams.yaml │ │ ├── fid.yaml │ │ ├── mse_hparams.yaml │ │ ├── halfres_mse.yaml │ │ └── mse.yaml │ └── localization │ │ ├── pg.yaml │ │ ├── gaussian_hparams.yaml │ │ ├── halfres_hparams.yaml │ │ ├── halfres.yaml │ │ ├── basic_hparams.yaml │ │ ├── gaussian.yaml │ │ └── basic.yaml ├── grasp_and_lift_eeg │ ├── images │ │ ├── data_example.png │ │ ├── training_acc.jpg │ │ ├── validation_acc.jpg │ │ ├── balanced_train_loss.png │ │ ├── balanced_val_accuracy_small.png │ │ └── balanced_train_accuracy_small.png │ ├── include │ │ ├── hparams.yaml │ │ └── common_vae.yaml │ ├── vae │ │ ├── mse.yaml │ │ └── hparams.yaml │ ├── classification │ │ ├── halfres_hparams.yaml │ │ ├── balanced_hparams.yaml │ │ ├── kernelsize=9_hparams.yaml │ │ ├── kernelsize=9.yaml │ │ ├── basic_hparams.yaml │ │ ├── halfres.yaml │ │ ├── balanced.yaml │ │ └── basic.yaml │ ├── rl │ │ └── ppo.yaml │ └── augment │ │ └── augment.yaml ├── rsna-intracranial │ ├── images │ │ ├── overfitting.jpg │ │ ├── training-dynamics.jpg │ │ └── RSNA_HalfRes_classifier2d_20000.jpg │ ├── vae │ │ ├── fid.yaml │ │ └── mse.yaml │ ├── include │ │ ├── hparams.yaml │ │ ├── dataset.yaml │ │ └── vae_common.yaml │ └── classification │ │ ├── halfres_hparams.yaml │ │ ├── vae_classifier.yaml │ │ ├── basic_hparams.yaml │ │ ├── halfres.yaml │ │ └── basic.yaml ├── doom │ ├── vae │ │ └── mse.yaml │ ├── include │ │ └── common_vae.yaml │ └── README.md ├── graphics │ ├── images │ │ └── NeuralGBuffer_plot2d_55000.png │ ├── gbuffer │ │ └── basic.yaml │ └── README.md ├── trends-fmri │ ├── vae │ │ └── mse.yaml │ ├── README.md │ ├── include │ │ └── common.yaml │ └── regression │ │ └── basic.yaml ├── forrestgump │ ├── classification │ │ ├── linear_hparams.yaml │ │ ├── nonlinear_hparams.yaml │ │ ├── linear.yaml │ │ ├── nonlinear.yaml │ │ ├── unaligned_hparams.yaml │ │ └── unaligned.yaml │ └── include │ │ └── hparams.yaml ├── ffhq │ ├── vae │ │ ├── fid.yaml │ │ └── mse.yaml │ └── include │ │ └── common.yaml └── cq500 │ ├── vae │ ├── fid.yaml │ └── mse.yaml │ ├── classification │ └── basic.yaml.bak │ ├── include │ └── common.yaml │ └── README.md ├── data ├── .gitignore ├── cow_texture.png ├── fonts │ └── arial.ttf └── cow.mtl ├── images └── banner.png ├── docker_entrypoint.sh ├── .gitignore ├── scripts ├── build_docker.sh └── tensorboard.sh ├── .dockerignore ├── src ├── merge_strategy.py ├── copy_weights.py ├── models │ ├── maxpool4d.py │ ├── batchnorm4d.py │ ├── encoder_wrapper.py │ ├── 
upscale2d.py │ ├── resnet_gaussian_localizer2d.py │ ├── augmenter.py │ ├── renderer.py │ ├── resnet_augmenter1d.py │ ├── resnet_embed2d.py │ ├── resnet_regressor4d.py │ ├── resnet4d.py │ ├── __init__.py │ ├── resnet_classifier3d.py │ ├── resnet_sandwich2d.py │ ├── util.py │ ├── classifier.py │ ├── resnet_pg_localizer2d.py │ ├── resnet_rl2d.py │ ├── resnet_rl1d.py │ ├── resnet_renderer2d.py │ ├── resnet_localizer2d.py │ ├── resnet_classifier2d.py │ ├── resnet_classifier1d.py │ ├── resnet2d.py │ ├── resnet1d.py │ ├── resnet3d.py │ ├── resnet_vae1d.py │ ├── base.py │ ├── localizer.py │ ├── resnet_vae3d.py │ ├── resnet_vae4d.py │ └── iou.py ├── tune_test.py ├── dataset │ ├── reference.py │ ├── video.py │ ├── dicom_util.py │ ├── batch_video.py │ ├── forrestgump_converter.py │ ├── cq500.py │ ├── la5c.py │ ├── toy_neural_graphics.py │ └── trends_fmri.py ├── linear_warmup.py ├── niitest.py ├── decord_test.py ├── vae_classifier.py ├── plot_stat_map.py ├── remove-corrupt-dcm.py ├── plot_test.py ├── draw_eeg.py ├── load_config.py ├── env.py ├── video_grid.py ├── verify_dataset.py ├── resize-video.py ├── augmentation.py ├── draw_boxes.py ├── compiler.py ├── localization.py ├── render_teapot.py ├── classification.py └── neural_gbuffer.py ├── pyproject.toml ├── requirements.txt ├── .github └── workflows │ └── lint.yml ├── LICENSE-MIT └── Dockerfile /experiments/include/vae3d/fid.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | rsna-intracranial/ 2 | crx8/ -------------------------------------------------------------------------------- /experiments/include/vae3d/adversarial.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/images/banner.png -------------------------------------------------------------------------------- /data/cow_texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/data/cow_texture.png -------------------------------------------------------------------------------- /data/fonts/arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/data/fonts/arial.ttf -------------------------------------------------------------------------------- /docker_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | cd "$(dirname "$0")" 4 | python src/main.py $@ 5 | -------------------------------------------------------------------------------- /experiments/include/vae2d/fid.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/include/vae2d/mse.yaml 2 | 3 | exp_params: 4 | fid_weight: 1.0 5 | -------------------------------------------------------------------------------- /experiments/la5c/include/hparams.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | optimizer: 3 | lr: 4 | 
grid_search: [0.0005, 0.005, 0.02, 0.05, 0.1] 5 | -------------------------------------------------------------------------------- /experiments/include/vae2d/adversarial.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/include/vae2d/mse.yaml 2 | 3 | exp_params: 4 | adversarial_weight: 1.0 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | *.mp4 4 | *.part 5 | logs/ 6 | logs2/ 7 | logs_remote/ 8 | ray_results/ 9 | saves/ 10 | .venv/ -------------------------------------------------------------------------------- /experiments/mnist/include/pg_schedule.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | progressive_growing: 3 | - 0 # 7x7 4 | - 256 # 14x14 5 | - 1024 # 28x18 6 | -------------------------------------------------------------------------------- /experiments/include/vae4d/images/example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/include/vae4d/images/example.gif -------------------------------------------------------------------------------- /experiments/deeplesion/include/hparams.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | optimizer: 3 | lr: 4 | uniform: 5 | lower: 0.0000001 6 | upper: 0.0003 7 | -------------------------------------------------------------------------------- /experiments/deeplesion/images/initial_localization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/deeplesion/images/initial_localization.png -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/data_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/data_example.png -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/training_acc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/training_acc.jpg -------------------------------------------------------------------------------- /experiments/rsna-intracranial/images/overfitting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/rsna-intracranial/images/overfitting.jpg -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/validation_acc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/validation_acc.jpg -------------------------------------------------------------------------------- /experiments/doom/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - 
experiments/include/vae2d/mse.yaml 3 | - experiments/doom/include/common_vae.yaml 4 | 5 | logging_params: 6 | name: "Doom_VAE_MSE" 7 | -------------------------------------------------------------------------------- /experiments/graphics/images/NeuralGBuffer_plot2d_55000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/graphics/images/NeuralGBuffer_plot2d_55000.png -------------------------------------------------------------------------------- /experiments/rsna-intracranial/images/training-dynamics.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/rsna-intracranial/images/training-dynamics.jpg -------------------------------------------------------------------------------- /scripts/build_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | cd "$(dirname "$0")"/.. 4 | IMAGE=thavlik/machine-learning-portfolio:latest 5 | docker build -t $IMAGE . 6 | docker push $IMAGE -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/balanced_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/balanced_train_loss.png -------------------------------------------------------------------------------- /experiments/mnist/vae/fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/fid.yaml 3 | - experiments/mnist/include/common_vae.yaml 4 | 5 | logging_params: 6 | name: "MNIST_VAE_FID" 7 | -------------------------------------------------------------------------------- /experiments/mnist/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | - experiments/mnist/include/common_vae.yaml 4 | 5 | logging_params: 6 | name: "MNIST_VAE_MSE" 7 | -------------------------------------------------------------------------------- /experiments/mnist/vae/pg_fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/vae/fid.yaml 3 | - experiments/mnist/include/pg_schedule.yaml 4 | 5 | logging_params: 6 | name: "MNIST_VAE_PG_FID" 7 | -------------------------------------------------------------------------------- /experiments/mnist/vae/pg_mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/vae/mse.yaml 3 | - experiments/mnist/include/pg_schedule.yaml 4 | 5 | logging_params: 6 | name: "MNIST_VAE_PG_MSE" 7 | -------------------------------------------------------------------------------- /experiments/trends-fmri/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae4d/mse.yaml 3 | - experiments/trends-fmri/include/common.yaml 4 | 5 | logging_params: 6 | name: "TReNDS_VAE_MSE" 7 | -------------------------------------------------------------------------------- /experiments/include/vae2d/images/MNIST_VAE_ResNetVAE2d_FID_49.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/include/vae2d/images/MNIST_VAE_ResNetVAE2d_FID_49.png -------------------------------------------------------------------------------- /experiments/deeplesion/images/DeepLesion_Basic_localize_lesions_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/deeplesion/images/DeepLesion_Basic_localize_lesions_0.png -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/balanced_val_accuracy_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/balanced_val_accuracy_small.png -------------------------------------------------------------------------------- /experiments/mnist/classification/basic.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/classification_common.yaml 3 | 4 | entrypoint: classification 5 | 6 | logging_params: 7 | name: "MNIST_Basic" 8 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git/ 2 | .gitignore 3 | .vscode/ 4 | __pycache__/ 5 | data/ 6 | README.md 7 | logs/ 8 | logs2/ 9 | logs_remote/ 10 | ray_results/ 11 | saves/ 12 | *.mp4 13 | *.part 14 | *.gif 15 | .venv/ -------------------------------------------------------------------------------- /data/cow.mtl: -------------------------------------------------------------------------------- 1 | newmtl material_1 2 | map_Kd cow_texture.png 3 | 4 | # Test colors 5 | 6 | Ka 1.000 1.000 1.000 # white 7 | Kd 1.000 1.000 1.000 # white 8 | Ks 0.000 0.000 0.000 # black 9 | Ns 10.0 10 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/images/balanced_train_accuracy_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/grasp_and_lift_eeg/images/balanced_train_accuracy_small.png -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/include/hparams.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | save_weights: 3 | every_n_steps: 0 4 | optimizer: 5 | lr: 6 | grid_search: [0.0000001, 0.000001, 0.00001, 0.0001, 0.001] 7 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/images/RSNA_HalfRes_classifier2d_20000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thavlik/machine-learning-portfolio/HEAD/experiments/rsna-intracranial/images/RSNA_HalfRes_classifier2d_20000.jpg -------------------------------------------------------------------------------- /experiments/rsna-intracranial/vae/fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/fid.yaml 3 | - experiments/rsna-intracranial/include/vae_common.yaml 4 | 5 | logging_params: 
6 | name: "RSNA_VAE_FID" 7 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | - experiments/rsna-intracranial/include/common_vae.yaml 4 | 5 | logging_params: 6 | name: "RSNA_VAE_MSE" 7 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae1d/mse.yaml 3 | - experiments/grasp_and_lift_eeg/include/common_vae.yaml 4 | 5 | logging_params: 6 | name: 'GraspLiftEEG_VAE_MSE' 7 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/vae/hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 8 4 | 5 | experiment: 6 | - experiments/grasp_and_lift_eeg/vae/mse.yaml 7 | - experiments/include/vae1d/hparams.yaml 8 | 9 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/include/hparams.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | dropout: 3 | grid_search: [0.15, 0.2, 0.25] 4 | 5 | exp_params: 6 | optimizer: 7 | lr: 8 | grid_search: [0.0005, 0.001, 0.01, 0.015, 0.02] 9 | -------------------------------------------------------------------------------- /experiments/deeplesion/vae/halfres_mse_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/vae/mse_hparams.yaml 2 | 3 | experiment: 4 | - experiments/deeplesion/vae/halfres_mse.yaml 5 | - experiments/deeplesion/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/pg.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/localization/basic.yaml 2 | 3 | model_params: 4 | arch: ResNetPGLocalizer2d 5 | name: ResNetPGLocalizer2d 6 | 7 | logging_params: 8 | name: "DeepLesion_PG" 9 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/gaussian_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/localization/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/deeplesion/localization/gaussian.yaml 5 | - experiments/deeplesion/include/hparams.yaml 6 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/halfres_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/localization/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/deeplesion/localization/halfres.yaml 5 | - experiments/deeplesion/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/linear_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/forrestgump/classification/unaligned_hparams.yaml 2 | 3 | experiment: 4 | - 
experiments/forrestgump/classification/linear.yaml 5 | - experiments/forrestgump/include/hparams.yaml 6 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/nonlinear_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/forrestgump/classification/unaligned_hparams.yaml 2 | 3 | experiment: 4 | - experiments/forrestgump/classification/nonlinear.yaml 5 | - experiments/forrestgump/include/hparams.yaml 6 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/linear.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/forrestgump/classification/unaligned.yaml 2 | 3 | exp_params: 4 | data: 5 | training: 6 | alignment: linear 7 | 8 | logging_params: 9 | name: "ForrestGump_Conv3d_Linear" 10 | 11 | -------------------------------------------------------------------------------- /src/merge_strategy.py: -------------------------------------------------------------------------------- 1 | from deepmerge import Merger 2 | 3 | strategy = Merger([(list, "override"), (dict, "merge")], ["override"], 4 | ["override"]) 5 | 6 | 7 | def deep_merge(base: dict, next_: dict) -> dict: 8 | return strategy.merge(base, next_) 9 | -------------------------------------------------------------------------------- /src/copy_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | new = torch.load('../new.pt') 4 | old = torch.load('../old.pt') 5 | 6 | for k in new: 7 | if k in old: 8 | new[k] = old[k].view(new[k].shape) 9 | print(f'Copied {k}') 10 | 11 | torch.save(new, '../new.pt') 12 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/nonlinear.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/forrestgump/classification/unaligned.yaml 2 | 3 | exp_params: 4 | data: 5 | training: 6 | alignment: nonlinear 7 | 8 | logging_params: 9 | name: "ForrestGump_Conv3d_Nonlinear" 10 | 11 | -------------------------------------------------------------------------------- /experiments/include/vae1d/hparams.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | optimizer: 3 | lr: 4 | grid_search: 5 | - 0.00001 6 | - 0.00005 7 | - 0.0001 8 | - 0.0002 9 | batch_size: 10 | grid_search: 11 | - 8 12 | - 16 13 | - 32 14 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/classification/halfres_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/rsna-intracranial/classification/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/rsna-intracranial/classification/halfres.yaml 5 | - experiments/rsna-intracranial/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/ffhq/vae/fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/fid.yaml 3 | - experiments/ffhq/include/common.yaml 4 | 5 | exp_params: 6 | plot: 7 | params: 8 | title: "${model}, FFHQ, FID loss, Epoch ${epoch}" 9 | 10 | logging_params: 11 | name: "FFHQ_VAE_FID" 
-------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/halfres_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/grasp_and_lift_eeg/classification/halfres.yaml 5 | - experiments/grasp_and_lift_eeg/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/cq500/vae/fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/fid.yaml 3 | - experiments/cq500/include/common.yaml 4 | 5 | exp_params: 6 | plot: 7 | params: 8 | title: "${model}, CQ500, FID loss, Epoch ${epoch}" 9 | 10 | logging_params: 11 | name: "CQ500_VAE_FID" -------------------------------------------------------------------------------- /experiments/cq500/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | - experiments/cq500/include/common.yaml 4 | 5 | exp_params: 6 | plot: 7 | params: 8 | title: "${model}, CQ500, MSE loss, Epoch ${epoch}" 9 | 10 | logging_params: 11 | name: "CQ500_VAE_MSE" -------------------------------------------------------------------------------- /experiments/ffhq/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | - experiments/ffhq/include/common.yaml 4 | 5 | exp_params: 6 | plot: 7 | params: 8 | title: "${model}, FFHQ, MSE loss, Epoch ${epoch}" 9 | 10 | logging_params: 11 | name: "FFHQ_VAE_MSE" -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/balanced_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/grasp_and_lift_eeg/classification/balanced.yaml 5 | - experiments/grasp_and_lift_eeg/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/kernelsize=9_hparams.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic_hparams.yaml 2 | 3 | experiment: 4 | - experiments/grasp_and_lift_eeg/classification/kernelsize=9.yaml 5 | - experiments/grasp_and_lift_eeg/include/hparams.yaml 6 | 7 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/classification/vae_classifier.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/rsna-intracranial/include/vae_common.yaml 3 | 4 | entrypoint: vae_classifier2d 5 | 6 | classifier_params: 7 | hidden_dim: [128, 64, 32] 8 | 9 | logging_params: 10 | name: "RSNA_Basic" 11 | 12 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/kernelsize=9.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic.yaml 2 | 3 | model_params: 4 | name: 'ResNetClassifier1d_KernelSize=9' 5 | kernel_size: 9 6 | 
padding: 4 7 | 8 | logging_params: 9 | name: "GraspLift_KernelSize=9" 10 | -------------------------------------------------------------------------------- /src/models/maxpool4d.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | 5 | class MaxPool4d(nn.MaxPool3d): 6 | 7 | def __init__(self) -> None: 8 | super(MaxPool4d, self).__init__() 9 | 10 | def forward(self, input): 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /experiments/deeplesion/vae/fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/fid.yaml 3 | - experiments/deeplesion/include/common.yaml 4 | 5 | exp_params: 6 | plot: 7 | params: 8 | title: "${model}, DeepLesion, FID loss, Epoch ${epoch}" 9 | 10 | logging_params: 11 | name: "DeepLesionVAE_FID" -------------------------------------------------------------------------------- /experiments/deeplesion/vae/mse_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 4 4 | 5 | num_train_steps: 50_000 6 | 7 | num_val_steps: 2_000 8 | 9 | randomize_seed: true 10 | 11 | experiment: 12 | - experiments/deeplesion/vae/mse.yaml 13 | - experiments/deeplesion/include/hparams.yaml 14 | -------------------------------------------------------------------------------- /src/models/batchnorm4d.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.batchnorm import _BatchNorm 2 | 3 | 4 | class BatchNorm4d(_BatchNorm): 5 | 6 | def _check_input_dim(self, input): 7 | if input.dim() != 6: 8 | raise ValueError('expected 6D input (got {}D input)'.format( 9 | input.dim())) 10 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/halfres.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/localization/basic.yaml 2 | 3 | exp_params: 4 | batch_size: 16 5 | optimizer: 6 | lr: 5.0e-06 7 | weight_decay: 1.0e-05 8 | data: 9 | training: 10 | lod: 1 11 | 12 | logging_params: 13 | name: "DeepLesion_HalfRes" 14 | -------------------------------------------------------------------------------- /experiments/la5c/classification/bilingual_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 3 4 | 5 | num_train_steps: 10_000 6 | 7 | num_val_steps: 50 8 | 9 | randomize_seed: true 10 | 11 | experiment: 12 | - experiments/la5c/classification/bilingual.yaml 13 | - experiments/la5c/include/hparams.yaml 14 | -------------------------------------------------------------------------------- /experiments/mnist/include/dataset.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | data: 3 | name: reference 4 | training: 5 | name: MNIST 6 | params: 7 | root: E:/ 8 | download: true 9 | train: true 10 | validation: 11 | params: 12 | train: false 13 | batch_size: 64 14 | warmup_steps: 256 -------------------------------------------------------------------------------- /experiments/deeplesion/localization/basic_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 32 4 | 
5 | num_train_steps: 4_000 6 | 7 | num_val_steps: 200 8 | 9 | randomize_seed: true 10 | 11 | experiment: 12 | - experiments/deeplesion/localization/basic.yaml 13 | - experiments/deeplesion/include/hparams.yaml 14 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/gaussian.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/localization/basic.yaml 2 | 3 | exp_params: 4 | optimizer: 5 | lr: 0.01 6 | 7 | model_params: 8 | arch: ResNetGaussianLocalizer2d 9 | name: ResNetGaussianLocalizer2d 10 | dropout: 0.2 11 | 12 | logging_params: 13 | name: "DeepLesion_Gaussian" 14 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/unaligned_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 3 4 | 5 | num_train_steps: 2_500 6 | 7 | num_val_steps: 300 8 | 9 | randomize_seed: true 10 | 11 | experiment: 12 | - experiments/forrestgump/classification/unaligned.yaml 13 | - experiments/forrestgump/include/hparams.yaml 14 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/include/dataset.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | data: 3 | name: rsna-intracranial 4 | split: 0.7 5 | loader: 6 | pin_memory: false 7 | num_workers: 0 8 | drop_last: true 9 | training: 10 | root: /data/rsna-ich 11 | download: false 12 | use_gzip: true 13 | warmup_steps: 256 14 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/classification/basic_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | num_samples: 4 4 | 5 | num_train_steps: 20_000 6 | 7 | num_val_steps: 1_000 8 | 9 | randomize_seed: true 10 | 11 | experiment: 12 | - experiments/rsna-intracranial/classification/basic.yaml 13 | - experiments/rsna-intracranial/include/hparams.yaml 14 | -------------------------------------------------------------------------------- /experiments/trends-fmri/README.md: -------------------------------------------------------------------------------- 1 | These experiments utilize the [TReNDS Neuroimaging](https://www.kaggle.com/c/trends-assessment-prediction/) dataset from Kaggle. 2 | 3 | ## Results 4 | (TODO: insert picture of validation results) 5 | 6 | ## TODO 7 | - Download data from spaces 8 | - Write dataset class for 3D MRI data 9 | - Visualize fMRI data in validation step -------------------------------------------------------------------------------- /experiments/mnist/include/common_vae.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/dataset.yaml 3 | 4 | model_params: 5 | hidden_dims: [64, 64] 6 | 7 | # Only ten classes of digits, with minimal variation 8 | # within each class. 
9 | latent_dim: 10 10 | 11 | trainer_params: 12 | max_epochs: 100 13 | log_every_n_steps: 20 14 | 15 | -------------------------------------------------------------------------------- /experiments/mnist/classification/embed_fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/classification_common.yaml 3 | 4 | entrypoint: classification_embed2d 5 | 6 | base_experiment: experiments/mnist/vae/fid.yaml 7 | 8 | model_params: 9 | arch: 'ResNetEmbed2d' 10 | name: 'ResNetEmbed2d' 11 | 12 | logging_params: 13 | name: "MNIST_EmbedFID" 14 | 15 | -------------------------------------------------------------------------------- /experiments/mnist/classification/embed_mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/classification_common.yaml 3 | 4 | entrypoint: classification_embed2d 5 | 6 | base_experiment: experiments/mnist/vae/mse.yaml 7 | 8 | model_params: 9 | arch: 'ResNetEmbed2d' 10 | name: 'ResNetEmbed2d' 11 | 12 | logging_params: 13 | name: "MNIST_EmbedMSE" 14 | 15 | -------------------------------------------------------------------------------- /experiments/deeplesion/vae/halfres_mse.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/deeplesion/vae/mse.yaml 2 | 3 | model_params: 4 | dropout: 0.1 5 | hidden_dims: [128, 256, 512, 256, 128] 6 | 7 | exp_params: 8 | batch_size: 12 9 | optimizer: 10 | lr: 0.02 11 | data: 12 | training: 13 | lod: 1 14 | 15 | logging_params: 16 | name: "DeepLesionVAE_HalfRes" 17 | -------------------------------------------------------------------------------- /experiments/mnist/classification/sandwich_fid.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/classification_common.yaml 3 | 4 | entrypoint: classification_sandwich2d 5 | 6 | base_experiment: experiments/mnist/vae/fid.yaml 7 | 8 | model_params: 9 | arch: 'ResNetSandwich2d' 10 | name: 'ResNetSandwich2d' 11 | 12 | logging_params: 13 | name: "MNIST_SandwichFID" 14 | 15 | -------------------------------------------------------------------------------- /experiments/mnist/classification/sandwich_mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/mnist/include/classification_common.yaml 3 | 4 | entrypoint: classification_sandwich2d 5 | 6 | base_experiment: experiments/mnist/vae/mse.yaml 7 | 8 | model_params: 9 | arch: 'ResNetSandwich2d' 10 | name: 'ResNetSandwich2d' 11 | 12 | logging_params: 13 | name: "MNIST_SandwichMSE" 14 | 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # We must configure isort to recognize torch as a standard library 2 | # so that it moves torch imports to the top of the import list, before 3 | # any third-party imports. This is necessary because decord is imported 4 | # in the entrypoint (as to configure the bridge) but it doesn't import 5 | # torch properly. 
6 | 7 | [tool.isort] 8 | extra_standard_library = ["torch"] -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/basic_hparams.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: hparam_search 2 | 3 | metric: train/loss 4 | 5 | num_samples: 3 6 | 7 | num_train_steps: 100_000 8 | 9 | num_val_steps: 5_000 10 | 11 | randomize_seed: true 12 | 13 | experiment: 14 | - experiments/grasp_and_lift_eeg/classification/basic.yaml 15 | - experiments/grasp_and_lift_eeg/include/hparams.yaml 16 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/halfres.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic.yaml 2 | 3 | model_params: 4 | hidden_dims: [512, 1024, 1024, 512] 5 | dropout: 0.3 6 | 7 | exp_params: 8 | batch_size: 64 9 | optimizer: 10 | lr: 0.001 11 | data: 12 | training: 13 | lod: 1 14 | 15 | logging_params: 16 | name: "GraspLift_HalfRes" 17 | -------------------------------------------------------------------------------- /experiments/forrestgump/include/hparams.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | optimizer: 3 | lr: 4 | grid_search: [0.02, 0.05, 0.08] 5 | data: 6 | training: 7 | offset: 8 | # Offset all labels this many seconds into the future 9 | # A previous hyperparameter search determined offset=0.0 10 | # performed worse than any non-zero offset. 11 | grid_search: [2, 4, 6, 8] 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | opencv-python 3 | torchvision 4 | pytorch-lightning>=1.0.7 5 | test-tube 6 | youtube-dl 7 | boto3 8 | h5py 9 | nilearn 10 | pylibjpeg 11 | pylibjpeg-libjpeg 12 | matplotlib 13 | lz4 14 | ray[tune] 15 | gymnasium 16 | dm-tree 17 | opencv-python 18 | pydicom 19 | scikit-image 20 | plotly==4.12.0 21 | deepmerge 22 | decord 23 | requests 24 | visdom 25 | seaborn 26 | grpcio 27 | grpcio-tools -------------------------------------------------------------------------------- /experiments/cq500/classification/basic.yaml.bak: -------------------------------------------------------------------------------- 1 | entrypoint: basic_prediction 2 | 3 | model_params: 4 | name: 'CQ500BasicPrediction' 5 | hidden_dims: [256, 512] 6 | 7 | exp_params: 8 | learning_rate: 0.0002 9 | batch_size: 128 10 | 11 | trainer_params: 12 | max_epochs: 10000000 13 | check_val_every_n_epoch: 200 14 | 15 | logging_params: 16 | name: "CQ500BasicPrediction" 17 | 18 | manual_seed: 1498 19 | -------------------------------------------------------------------------------- /experiments/include/vae3d/mse.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: vae3d 2 | 3 | model_params: 4 | arch: 'ResNetVAE3d' 5 | name: 'ResNetVAE3d' 6 | latent_dim: 512 7 | hidden_dims: [32, 64, 32] 8 | 9 | exp_params: 10 | optimizer: 11 | lr: 0.0002 12 | weight_decay: 0.00001 13 | batch_size: 128 14 | 15 | trainer_params: 16 | max_epochs: 10000000 17 | log_every_n_steps: 200 18 | check_val_every_n_epoch: 200 19 | 20 | manual_seed: 1498 21 | -------------------------------------------------------------------------------- /experiments/include/vae4d/mse.yaml: 
-------------------------------------------------------------------------------- 1 | entrypoint: vae4d 2 | 3 | model_params: 4 | arch: 'ResNetVAE4d' 5 | name: 'ResNetVAE4d' 6 | latent_dim: 64 7 | hidden_dims: [32, 64, 32] 8 | 9 | exp_params: 10 | optimizer: 11 | lr: 0.0002 12 | weight_decay: 0.00001 13 | batch_size: 128 14 | 15 | trainer_params: 16 | max_epochs: 10000000 17 | log_every_n_steps: 200 18 | check_val_every_n_epoch: 200 19 | 20 | manual_seed: 1498 21 | -------------------------------------------------------------------------------- /experiments/ffhq/include/common.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | pooling: max 3 | 4 | exp_params: 5 | batch_size: 32 6 | data: 7 | name: reference 8 | training: 9 | name: ImageFolder 10 | params: 11 | root: E:/thumbnails128x128 12 | plot: 13 | params: 14 | scaling: 2.0 15 | suptitle: 16 | y: 0.78 17 | 18 | trainer_params: 19 | max_epochs: 10 20 | log_every_n_steps: 20 21 | check_val_every_n_epoch: 1 22 | -------------------------------------------------------------------------------- /experiments/include/vae1d/mse.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: vae1d 2 | 3 | model_params: 4 | arch: 'ResNetVAE1d' 5 | name: 'ResNetVAE1d' 6 | latent_dim: 16 7 | hidden_dims: [256, 512, 512, 256] 8 | 9 | exp_params: 10 | optimizer: 11 | lr: 0.0002 12 | weight_decay: 0.00001 13 | batch_size: 32 14 | warmup_steps: 512 15 | 16 | trainer_params: 17 | max_epochs: 10000000 18 | log_every_n_steps: 200 19 | check_val_every_n_epoch: 1 20 | 21 | manual_seed: 1498 22 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/classification/halfres.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/rsna-intracranial/classification/basic.yaml 2 | 3 | model_params: 4 | hidden_dims: [256, 192, 128, 64] 5 | load_weights: logs/RSNA_HalfRes/version_0/checkpoints/epoch=1-step=37639.ckpt 6 | dropout: 0.15 7 | 8 | exp_params: 9 | batch_size: 12 10 | optimizer: 11 | lr: 0.015 12 | data: 13 | training: 14 | lod: 1 15 | 16 | logging_params: 17 | name: "RSNA_HalfRes" 18 | -------------------------------------------------------------------------------- /experiments/mnist/include/classification_common.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/mnist/include/dataset.yaml 2 | 3 | entrypoint: classification 4 | 5 | model_params: 6 | arch: 'ResNetClassifier2d' 7 | name: 'ResNetClassifier2d' 8 | num_classes: 10 9 | hidden_dims: [64, 64] 10 | 11 | exp_params: 12 | optimizer: 13 | lr: 0.0002 14 | weight_decay: 0.00001 15 | 16 | trainer_params: 17 | max_epochs: 8 18 | log_every_n_steps: 20 19 | check_val_every_n_epoch: 1 20 | 21 | manual_seed: 1498 22 | -------------------------------------------------------------------------------- /experiments/doom/include/common_vae.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | 4 | model_params: 5 | hidden_dims: [64, 128, 64] 6 | pooling: max 7 | 8 | exp_params: 9 | data: 10 | name: video 11 | training: 12 | dir: E:/doom 13 | width: 320 14 | height: 240 15 | limit: 3 16 | validation: 17 | dir: E:/doom 18 | batch_size: 32 19 | warmup_steps: 256 20 | 21 | trainer_params: 22 | max_epochs: 100 23 | log_every_n_steps: 20 24 | check_val_every_n_epoch: 
1 25 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/include/vae_common.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/rsna-intracranial/include/dataset.yaml 2 | 3 | model_params: 4 | pooling: max 5 | hidden_dims: [64, 128, 128, 128, 64] 6 | 7 | exp_params: 8 | batch_size: 8 9 | plot: 10 | fn: dcm 11 | params: 12 | title: "${model}, RSNA Intracranial Hemorrhage CT Slices, Epoch ${epoch}" 13 | scaling: 2.0 14 | suptitle: 15 | y: 0.78 16 | 17 | trainer_params: 18 | max_epochs: 5 19 | log_every_n_steps: 20 20 | check_val_every_n_epoch: 1 -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/balanced.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/grasp_and_lift_eeg/classification/basic.yaml 2 | 3 | exp_params: 4 | optimizer: 5 | lr: 0.00005 6 | batch_size: 16 7 | data: 8 | balanced: 9 | labels: 10 | - [0, 0, 0, 0, 0, 0] 11 | - [1, 0, 0, 0, 0, 0] 12 | - [0, 1, 0, 0, 0, 0] 13 | - [0, 0, 1, 0, 0, 0] 14 | - [0, 0, 0, 1, 0, 0] 15 | - [0, 0, 0, 0, 1, 0] 16 | - [0, 0, 0, 0, 0, 1] 17 | 18 | logging_params: 19 | name: "GraspLift_Balanced" 20 | -------------------------------------------------------------------------------- /src/tune_test.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from ray import tune 3 | from ray.tune.integration.torch import DistributedTrainableCreator 4 | from ray.util.sgd.torch import is_distributed_trainable 5 | 6 | ray.init() 7 | 8 | 9 | def my_trainable(config, checkpoint_dir=None): 10 | if is_distributed_trainable(): 11 | pass 12 | 13 | 14 | trainable = DistributedTrainableCreator( 15 | my_trainable, 16 | use_gpu=True, 17 | num_workers=4, 18 | num_cpus_per_worker=1, 19 | ) 20 | 21 | config = {} 22 | 23 | tune.run( 24 | trainable, 25 | resources_per_trial=None, 26 | config=config, 27 | ) 28 | -------------------------------------------------------------------------------- /src/dataset/reference.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from torchvision import datasets 4 | from torchvision.transforms import ToTensor 5 | 6 | 7 | class ReferenceDataset(data.Dataset): 8 | """ Reference dataset that proxies a torchvision stock dataset 9 | """ 10 | 11 | def __init__(self, name: str, params: dict): 12 | super(ReferenceDataset, self).__init__() 13 | self.ds = getattr(datasets, name)(**params, transform=ToTensor()) 14 | 15 | def __getitem__(self, index): 16 | return self.ds[index] 17 | 18 | def __len__(self): 19 | return len(self.ds) 20 | -------------------------------------------------------------------------------- /src/models/encoder_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from typing import List 3 | 4 | 5 | class EncoderWrapper(nn.Module): 6 | 7 | def __init__(self, latent_dim: int, layers: nn.Module, mu: nn.Module, 8 | var: nn.Module): 9 | super(EncoderWrapper, self).__init__() 10 | self.latent_dim = latent_dim 11 | self.layers = layers 12 | self.mu = mu 13 | self.var = var 14 | 15 | def forward(self, input: Tensor) -> List[Tensor]: 16 | x = self.layers(input) 17 | mu = self.mu(x) 18 | var = self.var(x) 19 | return [mu, var] 20 | 
-------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint with yapf & isort 2 | run-name: Enforce consistent formatting with yapf & isort 3 | on: [push] 4 | jobs: 5 | Check-Formatting: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Checkout repository 9 | uses: actions/checkout@v4 10 | - name: Install dependencies 11 | run: sudo apt-get update && sudo apt-get install -y python3-yapf python3-isort 12 | - name: Check formatting with yapf 13 | run: python3 -m yapf -rd src/ 14 | - name: Check formatting with isort 15 | run: python3 -m isort --only-modified --diff --check-only src/ 16 | -------------------------------------------------------------------------------- /experiments/cq500/include/common.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | pooling: max 3 | hidden_dims: [64, 128, 128, 128, 64] 4 | 5 | exp_params: 6 | batch_size: 8 7 | data: 8 | name: cq500 9 | loader: 10 | num_workers: 0 11 | drop_last: true 12 | training: 13 | dir: E:/cq500 14 | validation: 15 | #FIXME it will look like it's crazy effective this way lol 16 | dir: E:/cq500 17 | plot: 18 | fn: dcm 19 | params: 20 | scaling: 2.0 21 | suptitle: 22 | y: 0.78 23 | 24 | trainer_params: 25 | max_epochs: 20 26 | log_every_n_steps: 20 27 | check_val_every_n_epoch: 1 -------------------------------------------------------------------------------- /src/linear_warmup.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class LinearWarmup(_LRScheduler): 6 | 7 | def __init__(self, optimizer: Optimizer, lr: float, num_steps: int, *args, 8 | **kwargs): 9 | self._lr = lr 10 | self._num_steps = num_steps 11 | super().__init__(optimizer, *args, **kwargs) 12 | 13 | def get_lr(self): 14 | lr_scale = min(1.0, 15 | float(self._step_count + 1) / float(self._num_steps)) 16 | lr = lr_scale * self._lr 17 | return [lr] * len(self.optimizer.param_groups) 18 | -------------------------------------------------------------------------------- /experiments/cq500/README.md: -------------------------------------------------------------------------------- 1 | These experiments utilize the [CQ500 dataset from qure.ai](http://headctstudy.qure.ai/dataset). 2 | 3 | ## Results 4 | (TODO: insert picture of validation results) 5 | 6 | ## TODO 7 | - Download data from spaces 8 | - Write dataset class 9 | - Visualize results 10 | 11 | ## License 12 | ### Data 13 | CQ500 is licensed under [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/). 14 | 15 | ### Code 16 | Apache 2.0 / MIT dual-license. Please contact me if this is somehow not permissive enough and we'll add whatever free license is necessary for your project. 17 | -------------------------------------------------------------------------------- /scripts/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | cd "$(dirname "$0")"/.. 
4 | logdir() { 5 | if [ -f "/etc/wsl.conf" ]; then 6 | echo "$(wslpath -a logs)" 7 | else 8 | echo "$(pwd)/logs" 9 | fi 10 | } 11 | name=hometb 12 | docker kill $name || true 2>/dev/null 13 | docker container stop $name || true 2>/dev/null 14 | docker container rm $name || true 2>/dev/null 15 | docker run \ 16 | -d \ 17 | --rm \ 18 | --name $name \ 19 | -p 6006:6006 \ 20 | -v $(logdir):/logs \ 21 | tensorflow/tensorflow:latest-jupyter \ 22 | tensorboard \ 23 | --logdir /logs \ 24 | --host 0.0.0.0 \ 25 | --port 6006 26 | -------------------------------------------------------------------------------- /experiments/include/vae2d/mse.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: vae2d 2 | 3 | model_params: 4 | arch: 'ResNetVAE2d' 5 | name: 'ResNetVAE2d' 6 | latent_dim: 32 7 | hidden_dims: [64, 128, 128, 64] 8 | 9 | exp_params: 10 | optimizer: 11 | lr: 0.0002 12 | weight_decay: 0.00001 13 | batch_size: 32 14 | warmup_steps: 512 15 | plot: 16 | fn: 'plot2d' 17 | batch_size: 16 18 | sample_every_n_steps: 10_000 19 | params: 20 | title: "Epoch ${epoch}" 21 | scaling: 1.0 22 | rows: 4 23 | cols: 4 24 | suptitle: 25 | y: 0.82 26 | 27 | trainer_params: 28 | max_epochs: 100 29 | log_every_n_steps: 200 30 | check_val_every_n_epoch: 1 31 | 32 | manual_seed: 1498 33 | -------------------------------------------------------------------------------- /src/niitest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import nibabel as nib 5 | import nilearn as nl 6 | import nilearn.plotting as nlplt 7 | import numpy as np 8 | 9 | img = nl.image.load_img( 10 | 'E:/openneuro/ds000030-download/sub-10159/anat/sub-10159_T1w.nii.gz') 11 | img = nl.image.load_img( 12 | 'E:/openneuro/ds000113-download/sub-01/ses-forrestgump/func/sub-01_ses-forrestgump_task-forrestgump_acq-dico_run-02_bold.nii.gz' 13 | ) 14 | #img = nl.image.load_img('E:/openneuro/ds003151-download/sub-173/ses-hormoneabsent/func/sub-173_ses-hormoneabsent_task-nback_run-2_sbref.nii.gz') 15 | #img = nl.image.load_img('E:/sub-16_ses-mri_task-facerecognition_run-01_bold.nii') 16 | print(img) 17 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/include/common_vae.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | output_activation: relu 3 | hidden_dims: [128, 256, 128] 4 | 5 | exp_params: 6 | batch_size: 64 7 | data: 8 | name: grasp-and-lift-eeg 9 | loader: 10 | num_workers: 0 11 | drop_last: true 12 | training: 13 | root: /data/grasp-and-lift-eeg-detection 14 | train: true 15 | num_samples: 8192 16 | validation: 17 | train: false 18 | plot: 19 | fn: eeg 20 | sample_every_n_steps: 1000 21 | batch_size: 4 22 | params: 23 | width: 6000 24 | height: 4000 25 | 26 | trainer_params: 27 | max_epochs: 10 28 | log_every_n_steps: 200 29 | check_val_every_n_epoch: 1 30 | -------------------------------------------------------------------------------- /experiments/trends-fmri/include/common.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | data: 3 | name: trends-fmri 4 | loader: 5 | num_workers: 0 6 | drop_last: true 7 | training: 8 | path: E:/trends-fmri/train 9 | validation: 10 | path: E:/trends-fmri/test 11 | plot: 12 | - fn: fmri_prob_atlas 13 | batch_size: 16 14 | params: 15 | bg_img: 'E:/trends-fmri/ch2better.nii' 16 | mask_path: 
'E:/trends-fmri/fMRI_mask.nii' 17 | rows: 4 18 | cols: 4 19 | scaling: 3.0 20 | dpi: 330 21 | suptitle: 22 | y: 0.91 23 | title: '${model}, fMRI Original (top) vs. Reconstruction (bottom), Epoch ${epoch}' 24 | - fn: fmri_stat_map_video 25 | params: {} 26 | -------------------------------------------------------------------------------- /experiments/include/vae3d/README.md: -------------------------------------------------------------------------------- 1 | # 3D Variational Autoencoders 2 | Adding another dimension to the classic experiment, the 3D VAE is capable of handling voxel and video data. 3 | 4 | ## Flavors 5 | Several derivative VAE experiments are available in the separate yaml files: 6 | - **mse.yaml** is the original experiment with Mean Square Error and [KLD](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) loss 7 | - **adversarial.yaml** leverages adversarial regularization inspired by [Goodfellow 2014](https://arxiv.org/abs/1406.2661) 8 | 9 | ## Compatible Datasets 10 | This experiment can be applied to the following datasets: 11 | - video data 12 | - any kind of sMRI 13 | 14 | ## Results 15 | (TODO: insert picture of validation results) -------------------------------------------------------------------------------- /experiments/mnist/classification/comparison.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: comparison 2 | 3 | name: MNIST_Comparison 4 | 5 | # A list of experiments to compare 6 | series: 7 | - experiments/mnist/classification/basic.yaml 8 | - experiments/mnist/classification/embed_mse.yaml 9 | - experiments/mnist/classification/embed_fid.yaml 10 | - experiments/mnist/classification/sandwich_mse.yaml 11 | - experiments/mnist/classification/sandwich_fid.yaml 12 | 13 | # Run each experiment N times and average the results 14 | num_samples: 3 15 | 16 | plot: 17 | # A list of columns from metrics.csv, each one getting 18 | # a separate output image. 
19 | metrics: 20 | - loss 21 | - accuracy 22 | - avg_val_loss 23 | - avg_val_acc 24 | 25 | width: 1920 26 | height: 1080 27 | 28 | -------------------------------------------------------------------------------- /experiments/trends-fmri/regression/basic.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: regression4d 2 | 3 | model_params: 4 | arch: 'ResNetRegressor4d' 5 | name: 'ResNetRegressor4d' 6 | hidden_dims: [64, 64] 7 | dropout: 0.1 8 | 9 | exp_params: 10 | manual_seed: 6602 11 | data: 12 | name: trends-fmri 13 | training: 14 | root: /data/trends-fmri 15 | loader: 16 | pin_memory: false 17 | num_workers: 0 18 | batch_size: 2 19 | warmup_steps: 512 20 | optimizer: 21 | lr: 0.00005 22 | weight_decay: 0.000005 23 | 24 | trainer_params: 25 | max_epochs: 1000000 26 | log_every_n_steps: 500 27 | val_check_interval: 10000 28 | limit_val_batches: 2048 29 | 30 | logging_params: 31 | save_dir: "logs/" 32 | name: "TReNDS_Basic" 33 | 34 | manual_seed: 6602 35 | -------------------------------------------------------------------------------- /experiments/forrestgump/classification/unaligned.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: classification 2 | 3 | model_params: 4 | arch: 'ResNetClassifier3d' 5 | name: 'ResNetClassifier3d' 6 | num_classes: 2 7 | hidden_dims: [32, 32, 32, 32] 8 | dropout: 0.2 9 | 10 | exp_params: 11 | loss_params: 12 | objective: mse 13 | baseline_accuracy: 0.7084 14 | batch_size: 2 15 | optimizer: 16 | lr: 0.025 17 | weight_decay: 0.00001 18 | data: 19 | name: forrestgump 20 | split: 0.6 21 | loader: 22 | pin_memory: false 23 | num_workers: 0 24 | drop_last: true 25 | training: 26 | root: /data/openneuro/ds000113-download 27 | offset: 4.0 28 | warmup_steps: 32 29 | 30 | trainer_params: 31 | log_every_n_steps: 50 32 | check_val_every_n_epoch: 1 33 | 34 | manual_seed: 1498 35 | 36 | logging_params: 37 | name: "ForrestGump_Conv3d_Unaligned" 38 | 39 | -------------------------------------------------------------------------------- /experiments/include/vae4d/README.md: -------------------------------------------------------------------------------- 1 | # 4D Variational Autoencoders 2 | Functional magnetic resonance imaging (fMRI) data is nontrivial to handle due to its very high dimensionality. Additionally, there are only a few ways to interpret model performance. It is customary to visualize 4D data as animated 3D data.
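For illustration, here is a minimal sketch of that convention (this is not the plotting code used by these experiments; the input path, threshold, and output name are placeholders): each timepoint of a 4D NIfTI is rendered as a 3D stat map with nilearn, and the frames are stitched into an animated GIF.

```python
# Sketch only: visualize a 4D volume as animated 3D data.
import nilearn.image as nli
import nilearn.plotting as nlplt
from PIL import Image

img4d = nli.load_img('bold.nii.gz')              # 4D image, shape (x, y, z, t)
frame_paths = []
for t, vol in enumerate(nli.iter_img(img4d)):    # one 3D volume per timepoint
    out = f'frame_{t:04d}.png'
    nlplt.plot_stat_map(vol, threshold=3, colorbar=False, output_file=out)
    frame_paths.append(out)

frames = [Image.open(p) for p in frame_paths]
frames[0].save('example.gif', save_all=True, append_images=frames[1:],
               duration=200, loop=0)             # ~5 frames per second
```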
3 | 4 | ![fMRI example](images/example.gif) 5 | 6 | ## Flavors 7 | Several derivative VAE experiments are available in the separate yaml files: 8 | - **mse.yaml** is the original experiment with Mean Square Error and [KLD](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) loss 9 | - **adversarial.yaml** leverages adversarial regularization inspired by [Goodfellow 2014](https://arxiv.org/abs/1406.2661) 10 | 11 | ## Compatible Datasets 12 | This experiment can be applied to the following datasets: 13 | - [TReNDS Neuroimaging](https://www.kaggle.com/c/trends-assessment-prediction/) 14 | -------------------------------------------------------------------------------- /src/models/upscale2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Upscale2d(nn.Module): 6 | 7 | def __init__(self, factor=2, gain=1): 8 | super(Upscale2d, self).__init__() 9 | self.gain = gain 10 | self.factor = factor 11 | 12 | def forward(self, x): 13 | if self.gain != 1: 14 | x = x * self.gain 15 | if self.factor > 1: 16 | shape = x.shape 17 | x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 18 | 1).expand(-1, -1, -1, self.factor, -1, self.factor) 19 | x = x.contiguous().view(shape[0], shape[1], self.factor * shape[2], 20 | self.factor * shape[3]) 21 | return x 22 | 23 | 24 | if __name__ == '__main__': 25 | assert Upscale2d()(torch.randn(32, 1, 28, 26 | 28)).shape == torch.Size([32, 1, 56, 56]) 27 | -------------------------------------------------------------------------------- /experiments/include/vae1d/README.md: -------------------------------------------------------------------------------- 1 | # 1D Variational Autoencoders 2 | Just like image data, time series can be transformed by convolutional networks. 3 | 4 | ## Flavors 5 | Several derivative VAE experiments are available in the separate yaml files.
6 | - **mse.yaml** is the original experiment with Mean Square Error and [KLD](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) loss 7 | - **fid.yaml** is identical to mse.yaml, but [FID(original, reconstruction)](https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance) is included in loss 8 | - **adversarial.yaml** leverages adversarial regularization inspired by [Goodfellow 2014](https://arxiv.org/abs/1406.2661) 9 | 10 | ## Compatible Datasets 11 | This experiment can be applied to the following datasets: 12 | - [Grasp-and-Lift EEG Detection](https://www.kaggle.com/c/grasp-and-lift-eeg-detection) 13 | 14 | ## Results 15 | (TODO: insert picture of validation results) -------------------------------------------------------------------------------- /experiments/la5c/classification/bilingual.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: classification 2 | 3 | model_params: 4 | arch: 'ResNetClassifier3d' 5 | name: 'ResNetClassifier3d' 6 | num_classes: 1 7 | pooling: max 8 | hidden_dims: [10, 18, 18, 10] 9 | dropout: 0.2 10 | 11 | exp_params: 12 | loss_params: 13 | objective: bce 14 | baseline_accuracy: 0.6212121212121212 15 | batch_size: 2 16 | optimizer: 17 | lr: 0.03 18 | weight_decay: 0.00001 19 | data: 20 | name: la5c 21 | split: 0.6 22 | loader: 23 | pin_memory: false 24 | num_workers: 7 25 | drop_last: true 26 | training: 27 | root: /data/openneuro/ds000030-download 28 | phenotypes: 29 | - language/bilingual 30 | exclude_na: true 31 | warmup_steps: 32 32 | 33 | trainer_params: 34 | log_every_n_steps: 20 35 | check_val_every_n_epoch: 1 36 | gpus: 1 37 | 38 | manual_seed: 1498 39 | 40 | logging_params: 41 | name: "LA5c_Bilingual" 42 | 43 | -------------------------------------------------------------------------------- /experiments/graphics/gbuffer/basic.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: neural_gbuffer 2 | 3 | model_params: 4 | arch: ResNetRenderer2d 5 | name: ResNetRenderer2d 6 | hidden_dims: [512, 1024, 2048, 1024, 512] 7 | output_layer: fc 8 | 9 | exp_params: 10 | loss_params: 11 | fid_weight: 1.0 12 | optimizer: 13 | lr: 0.00001 14 | weight_decay: 0.00001 15 | data: 16 | name: toy-neural-graphics 17 | training: 18 | dir: data/ 19 | rasterization_settings: 20 | image_size: 128 21 | blur_radius: 0.0 22 | faces_per_pixel: 1 23 | loader: 24 | num_workers: 0 25 | drop_last: true 26 | batch_size: 4 27 | warmup_steps: 256 28 | plot: 29 | fn: 'plot2d' 30 | batch_size: 16 31 | sample_every_n_steps: 5_000 32 | params: 33 | rows: 4 34 | cols: 4 35 | 36 | trainer_params: 37 | max_epochs: 100_000 38 | log_every_n_steps: 20 39 | check_val_every_n_epoch: 100_000 40 | 41 | logging_params: 42 | name: "BasicNeuralRenderer" 43 | 44 | -------------------------------------------------------------------------------- /src/models/resnet_gaussian_localizer2d.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from torch.distributions import Normal 3 | 4 | from .resnet_localizer2d import ResNetLocalizer2d 5 | 6 | 7 | class ResNetGaussianLocalizer2d(ResNetLocalizer2d): 8 | 9 | def __init__(self, kappa: float = 0.05, **kwargs) -> None: 10 | super().__init__(**kwargs) 11 | self.kappa = kappa 12 | std_dev = [nn.Linear(self.prediction.in_features, 4)] 13 | if self.batch_normalize: 14 | std_dev.append(nn.BatchNorm1d(4)) 15 | std_dev.append(self.activation) 16 | self.std_dev = nn.Sequential(*std_dev) 
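        # Note: self.prediction and self.output come from the parent
        # ResNetLocalizer2d and produce the mean bounding box; the std_dev head
        # built above predicts a per-coordinate spread. forward() then samples
        # the final box from N(mu, std_dev * kappa) with rsample(), so the
        # sampling step remains differentiable.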
17 | 18 | def forward(self, x: Tensor) -> Tensor: 19 | y = self.layers(x) 20 | mu = self.output(self.prediction(y)) 21 | std_dev = self.std_dev(y) 22 | pred = reparameterize_normal(mu, std_dev * self.kappa) 23 | return pred 24 | 25 | 26 | def reparameterize_normal(mu: Tensor, std_dev: Tensor) -> Tensor: 27 | return Normal(mu, std_dev).rsample() 28 | -------------------------------------------------------------------------------- /experiments/deeplesion/vae/mse.yaml: -------------------------------------------------------------------------------- 1 | include: 2 | - experiments/include/vae2d/mse.yaml 3 | 4 | model_params: 5 | pooling: max 6 | dropout: 0.1 7 | hidden_dims: [64, 128, 256, 128, 64] 8 | 9 | exp_params: 10 | warmup_steps: 256 11 | save_weights: 12 | every_n_steps: 10_000 13 | local: {} 14 | batch_size: 4 15 | optimizer: 16 | lr: 0.01 17 | weight_decay: 0.00001 18 | data: 19 | name: deeplesion 20 | split: 0.7 21 | flatten_labels: true 22 | loader: 23 | num_workers: 0 24 | drop_last: true 25 | training: 26 | root: /data/deeplesion 27 | only_positives: true 28 | include_label: false 29 | components: 30 | - bounding_boxes 31 | visdom: 32 | host: https://visdom.foldy.dev 33 | port: 443 34 | env: deeplesion 35 | plot: 36 | params: 37 | img_filter: apply_softwindow 38 | 39 | trainer_params: 40 | max_epochs: 10_000_000 41 | log_every_n_steps: 20 42 | check_val_every_n_epoch: 100 43 | 44 | logging_params: 45 | name: "DeepLesionVAE_MSE" 46 | -------------------------------------------------------------------------------- /src/models/augmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch import Tensor, nn 4 | from typing import List 5 | 6 | 7 | class Augmenter(nn.Module): 8 | 9 | def __init__(self, name: str) -> None: 10 | super(Augmenter, self).__init__() 11 | self.name = name 12 | 13 | @abstractmethod 14 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 15 | raise NotImplementedError 16 | 17 | def loss_function(self, 18 | x: Tensor, 19 | constraint: nn.Module, 20 | alpha: float = 1.0) -> dict: 21 | t = self.forward(x) 22 | co = constraint(x) 23 | ct = constraint(t) 24 | td = torch.pow(ct - co, 2).mean() # lower is better 25 | ud = torch.pow(x - t, 2).mean() # higher is better 26 | loss = td - ud * alpha 27 | return { 28 | 'loss': loss, 29 | 'TransformedDelta': td.detach().cpu(), 30 | 'UntransformedDelta': ud.detach().cpu() 31 | } 32 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Thomas Havlik 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/rl/ppo.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: rl2d 2 | 3 | algorithm: PPO 4 | 5 | run_params: 6 | local_dir: "ray_results/" 7 | 8 | checkpoint_at_end: true 9 | checkpoint_freq: 5000 10 | 11 | config: 12 | # https://docs.ray.io/en/master/rllib-training.html#common-parameters 13 | rollout_fragment_length: 128 14 | sgd_minibatch_size: 128 15 | train_batch_size: 128 16 | env_config: 17 | name: TimeSeriesDetector 18 | low: -20000.0 19 | high: 20000.0 20 | observation_length: 8192 21 | channels: 32 22 | num_event_classes: 6 23 | data: 24 | name: grasp-and-lift-eeg 25 | params: 26 | dir: E:/grasp-and-lift-eeg-detection/train 27 | num_gpus: 1.0 28 | num_gpus_per_worker: 0 29 | num_workers: 0 # TODO: fix https://github.com/ray-project/ray/issues/7583 30 | num_envs_per_worker: 4 31 | model: 32 | custom_model: ResNetRL1d 33 | custom_model_config: 34 | num_samples: 8192 35 | channels: 32 36 | hidden_dims: [128, 256, 256, 128] 37 | pooling: max 38 | lr: 0.00001 39 | 40 | manual_seed: 1498 41 | -------------------------------------------------------------------------------- /experiments/deeplesion/localization/basic.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: localization2d 2 | 3 | model_params: 4 | arch: ResNetLocalizer2d 5 | name: ResNetLocalizer2d 6 | pooling: max 7 | batch_normalize: true 8 | hidden_dims: [64, 128, 256, 512, 512, 512] 9 | output_activation: sigmoid 10 | 11 | exp_params: 12 | scheduler: 13 | warmup_steps: 256 14 | reduce_lr_on_plateau: 15 | factor: 0.1 16 | patience: 1 17 | threshold: 1.0e-04 18 | save_weights: 19 | every_n_steps: 10_000 20 | local: {} 21 | loss_params: 22 | objective: cbiou+dbiou+gbiou 23 | batch_size: 10 24 | optimizer: 25 | lr: 0.00005 26 | weight_decay: 1.0e-07 27 | data: 28 | name: deeplesion 29 | split: 0.3 30 | flatten_labels: true 31 | loader: 32 | num_workers: 4 33 | drop_last: true 34 | training: 35 | root: /opt/data/deeplesion 36 | only_positives: true 37 | components: 38 | - bounding_boxes 39 | plot: 40 | fn: localize_lesions 41 | sample_every_n_steps: 5_000 42 | batch_size: 9 43 | params: {} 44 | 45 | trainer_params: 46 | max_epochs: 10_000_000 47 | log_every_n_steps: 10 48 | check_val_every_n_epoch: 1 49 | 50 | logging_params: 51 | name: "DeepLesion_Basic" 52 | -------------------------------------------------------------------------------- /experiments/doom/README.md: -------------------------------------------------------------------------------- 1 | Variational autoencoders (VAE) and other experiments trained on DOOM 1/2 gameplay videos 2 | 3 | ## Motivation 4 | Latent representations and unsupervised pretraining boost data efficiency on more challenging supervised [1] and reinforcement learning tasks [2]. 
The goal of this project is to provide both the Doom and machine learning communities with: 5 | - High quality datasets comprised of Doom gameplay 6 | - Various ready-to-run experiments 7 | - Suitable boilerplate for derivative projects 8 | 9 | ## The Data 10 | Gameplay videos are sourced from YouTube with permission. Special thanks to the following creators for their contributions to the community and this dataset - these individuals are the lifeblood of the Doom community: 11 | - [Timothy Brown](https://www.youtube.com/user/mArt1And00m3r11339) 12 | - [decino](https://www.youtube.com/c/decino) 13 | - [Zero Master](https://www.youtube.com/channel/UCiVZWY9LmrJFOg3hWGjyBbw) 14 | 15 | The dataset is compiled by taking the original videos (~167 Gb worth) and re-encoding them as 320x240 @ 15fps videos (18 Gb, reduction of 80-90%). This allows data loaders to retrieve random frames at training speeds (128-1024 frames/sec) whereas the raw 1080p video can only be sampled at ~3 frames/sec (i7 7800X @ 3.5 GHz). 16 | 17 | ## TODO 18 | - Dataset compiler 19 | - ~~Doom gameplay video links~~ -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/augment/augment.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: augment 2 | 3 | model_params: 4 | arch: 'ResNetAugmenter1d' 5 | name: 'ResNetAugmenter1d' 6 | hidden_dims: [512, 1024, 512] 7 | 8 | 9 | constraint_params: 10 | arch: 'ResNetClassifier1d' 11 | name: 'ResNetClassifier1d' 12 | hidden_dims: [512, 1024, 1024, 512] 13 | num_classes: 6 14 | kernel_size: 9 15 | padding: 4 16 | logits_only: true 17 | load_weights: /data/eeg/version_7/checkpoints/last.ckpt 18 | 19 | exp_params: 20 | save_weights: 21 | every_n_steps: 10_000 22 | local: {} 23 | data: 24 | name: grasp-and-lift-eeg 25 | training: 26 | root: /data/grasp-and-lift-eeg-detection 27 | num_samples: 2048 28 | last_label_only: true 29 | subjects: [1, 2, 3, 4, 5, 6, 7, 8, 9] 30 | validation: 31 | subjects: [10, 11, 12] 32 | loader: 33 | pin_memory: false 34 | num_workers: 0 35 | loss_params: 36 | alpha: 0.000001 37 | batch_size: 16 38 | warmup_steps: 512 39 | optimizer: 40 | lr: 0.00005 41 | weight_decay: 0.000005 42 | 43 | trainer_params: 44 | max_epochs: 1000000 45 | log_every_n_steps: 500 46 | val_check_interval: 20000 47 | limit_val_batches: 5000 48 | 49 | logging_params: 50 | save_dir: "logs/" 51 | name: "GraspLift_Augment" 52 | 53 | -------------------------------------------------------------------------------- /src/decord_test.py: -------------------------------------------------------------------------------- 1 | import decord 2 | from decord import VideoLoader, VideoReader, cpu 3 | 4 | decord.bridge.set_bridge('torch') 5 | 6 | path = 'E:/doom/_9zaLSmRgGc.mp4' 7 | 8 | vl = VideoLoader( 9 | [ 10 | path, 11 | #'E:/doom/_BHunyDleDQ.mp4', 12 | ], 13 | ctx=[cpu(0)], 14 | shape=(10, 320, 240, 3), 15 | interval=0, 16 | skip=5, 17 | shuffle=1) 18 | ex = vl.next() 19 | vr = VideoReader(path, ctx=cpu(0)) 20 | # a file like object works as well, for in-memory decoding 21 | with open(path, 'rb') as f: 22 | vr = VideoReader(f, ctx=cpu(0)) 23 | print('video frames:', len(vr)) 24 | # 1. 
the simplest way is to directly access frames 25 | for i in range(len(vr)): 26 | # the video reader will handle seeking and skipping in the most efficient manner 27 | frame = vr[i] 28 | print(frame.shape) 29 | 30 | # To get multiple frames at once, use get_batch 31 | # this is the efficient way to obtain a long list of frames 32 | frames = vr.get_batch([1, 3, 5, 7, 9]) 33 | print(frames.shape) 34 | # (5, 240, 320, 3) 35 | # duplicate frame indices will be accepted and handled internally to avoid duplicate decoding 36 | frames2 = vr.get_batch([1, 2, 3, 2, 3, 4, 3, 4, 5]) 37 | print(frames2.shape) 38 | # (9, 240, 320, 3) 39 | 40 | # 2. you can do cv2 style reading as well 41 | # skip 100 frames 42 | vr.skip_frames(100) 43 | # seek to start 44 | vr.seek(0) 45 | batch = vr.next() 46 | print('frame shape:', batch.shape) 47 | -------------------------------------------------------------------------------- /experiments/include/vae2d/README.md: -------------------------------------------------------------------------------- 1 | # 2D Variational Autoencoders 2 | My favorite unsupervised modeling task. This experiment embeds 2D imagery in a compact latent space by modeling its principle components as a multivariate gaussian, a la [Kingma & Welling 2013](https://arxiv.org/abs/1312.6114). 3 | 4 | ![Vae2D Banner](images/vae2d_banner.png) 5 | 6 | ## Flavors 7 | Several derivatives VAE experiments are available in the separate yaml files. 8 | - **mse.yaml** is the original experiment with [Mean Squared Error](https://en.wikipedia.org/wiki/Mean_squared_error) and [KLD](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) loss, from which all other experiments are derived 9 | - **fid.yaml** includes [FID(original, reconstruction)](https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance) as part of the loss function 10 | - **adversarial.yaml** leverages adversarial regularization inspired by [Goodfellow 2014](https://arxiv.org/abs/1406.2661) 11 | 12 | ## Compatible Datasets 13 | This experiment can be applied to the following datasets: 14 | - [RSNA Intracranial Hemorrhage Prediction](https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection) 15 | - [CQ500](http://headctstudy.qure.ai/dataset) 16 | - [DeepLesion](https://www.nih.gov/news-events/news-releases/) 17 | - [torchvision datasets](https://pytorch.org/docs/stable/torchvision/datasets.html) 18 | 19 | ## Results 20 | Reconstruction is on the left, original is on the right. 21 | ### MNIST 22 | ![MNIST_VAE_ResNetVAE2d_FID_49.png](images/MNIST_VAE_ResNetVAE2d_FID_49.png) -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray-ml:latest-gpu 2 | USER root 3 | RUN apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \ 4 | && apt-get update \ 5 | && apt-get install -y \ 6 | chromium-browser \ 7 | fonts-liberation \ 8 | xvfb \ 9 | poppler-utils \ 10 | libxss1 \ 11 | libnss3 \ 12 | libnss3-dev \ 13 | libgdk-pixbuf2.0-dev \ 14 | libgtk-3-dev \ 15 | libxss-dev \ 16 | libasound2 \ 17 | libgtk2.0-dev \ 18 | zlib1g-dev \ 19 | libgl1-mesa-dev \ 20 | nodejs \ 21 | npm \ 22 | nano \ 23 | htop \ 24 | && rm -rf /var/lib/apt/lists/* \ 25 | && apt-get clean \ 26 | && npm install -g \ 27 | orca \ 28 | vtop 29 | RUN echo 'alias watchsmi="watch -n 0.5 nvidia-smi"' >> /root/.bashrc 30 | WORKDIR /app 31 | COPY requirements.txt . 
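# The pinned torch/torchvision/torchaudio wheels installed below are CUDA 11.0
# builds pulled from the official torch_stable index; they are installed before
# the rest of requirements.txt.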
32 | #RUN conda install cudatoolkit=10.1 33 | RUN pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 34 | #RUN pip install https://download.pytorch.org/whl/cu111/torch-1.7.1%2Bcu101-cp37-cp37m-linux_x86_64.whl 35 | #RUN pip install https://download.pytorch.org/whl/cu111/torchvision-0.8.1%2Bcu101-cp37-cp37m-linux_x86_64.whl 36 | RUN pip install awscli --force-reinstall --upgrade --ignore-installed 37 | RUN pip install 'git+https://github.com/thavlik/nonechucks.git' 38 | RUN pip install -r requirements.txt 39 | WORKDIR /machine-learning-portfolio 40 | COPY . . 41 | CMD ["./docker_entrypoint.sh"] 42 | -------------------------------------------------------------------------------- /src/vae_classifier.py: -------------------------------------------------------------------------------- 1 | from models import Classifier 2 | from vae import VAEExperiment 3 | 4 | 5 | class VAEClassifierExperiment(VAEExperiment): 6 | 7 | def __init__(self, classifier: Classifier, **kwargs) -> None: 8 | super().__init__(**kwargs) 9 | self.classifier = classifier 10 | 11 | def training_step(self, batch, batch_idx): 12 | # First move to device so it doesn't happen twice 13 | real_img, labels = batch 14 | real_img = real_img.to(self.curr_device) 15 | batch = (real_img, labels) 16 | loss, results = super().training_step_raw(batch, batch_idx, 17 | optimizer_idx) 18 | z = results[4] # Check BaseVAE.loss_function() 19 | prediction = self.classifier(z) 20 | classifier_loss = self.classifier.loss_function(prediction, labels) 21 | loss['loss'] += classifier_loss 22 | loss['Classifier_Loss'] = classifier_loss 23 | return loss 24 | 25 | def validation_step(self, batch, batch_idx): 26 | real_img, labels = batch 27 | real_img = real_img.to(self.curr_device) 28 | batch = (real_img, labels) 29 | loss, results = super().validation_step_raw(batch, batch_idx, 30 | optimizer_idx) 31 | z = results[4] # Check BaseVAE.loss_function() 32 | prediction = self.classifier(z) 33 | classifier_loss = self.classifier.loss_function(prediction, labels) 34 | loss['loss'] += classifier_loss 35 | loss['Classifier_Loss'] = classifier_loss 36 | return loss 37 | -------------------------------------------------------------------------------- /src/plot_stat_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import nibabel as nib 5 | import nilearn as nl 6 | import nilearn.plotting as nlplt 7 | import numpy as np 8 | 9 | from dataset.trends_fmri import TReNDSfMRIDataset, load_subject 10 | 11 | base_path = '/data/trends-fmri' 12 | smri_filename = os.path.join(base_path, 'ch2better.nii') 13 | subject_filename = os.path.join(base_path, 'fMRI_test/10228.mat') 14 | 15 | ds = TReNDSfMRIDataset(base_path) 16 | 17 | subject_niimg = load_subject(subject_filename, ds.mask) 18 | grid_size = int(np.ceil(np.sqrt(subject_niimg.shape[0]))) 19 | fig, axes = plt.subplots(grid_size, 20 | grid_size, 21 | figsize=(grid_size * 10, grid_size * 10)) 22 | [axi.set_axis_off() for axi in axes.ravel()] 23 | row = -1 24 | for i, cur_img in enumerate(nl.image.iter_img(subject_niimg)): 25 | col = i % grid_size 26 | if col == 0: 27 | row += 1 28 | nlplt.plot_stat_map(cur_img, 29 | bg_img=smri_filename, 30 | title="IC %d" % i, 31 | axes=axes[row, col], 32 | threshold=3, 33 | colorbar=False) 34 | plt.show() 35 | 36 | img = nl.image.new_img_like(ds.mask, 37 | ds[0][0].numpy(), 38 | affine=ds.mask.affine, 39 | 
copy_header=True) 40 | nlplt.plot_prob_atlas(img, 41 | bg_img=smri_filename, 42 | view_type='filled_contours', 43 | draw_cross=False, 44 | threshold='auto') 45 | plt.show() 46 | -------------------------------------------------------------------------------- /experiments/rsna-intracranial/classification/basic.yaml: -------------------------------------------------------------------------------- 1 | include: experiments/rsna-intracranial/include/dataset.yaml 2 | 3 | entrypoint: classification 4 | 5 | model_params: 6 | arch: 'ResNetClassifier2d' 7 | name: 'ResNetClassifier2d' 8 | num_classes: 6 9 | hidden_dims: [128, 128, 128, 128] 10 | pooling: max 11 | 12 | exp_params: 13 | loss_params: 14 | objective: mse 15 | batch_size: 4 16 | optimizer: 17 | lr: 0.0002 18 | weight_decay: 0.00001 19 | #visdom: 20 | # host: https://visdom.foldy.dev 21 | # port: 443 22 | # env: rsna-ich 23 | plot: 24 | fn: classifier2d 25 | sample_every_n_steps: 10_000 26 | examples_per_class: 6 27 | classes: 28 | - name: Control 29 | labels: [0, 0, 0, 0, 0, 0] 30 | all: true 31 | baseline: 0.2 32 | - name: Epidural 33 | labels: [1, 0, 0, 0, 0, 0] 34 | all: false 35 | baseline: 0.2 36 | - name: Intraparenchymal 37 | labels: [0, 1, 0, 0, 0, 0] 38 | all: false 39 | baseline: 0.2 40 | - name: Intraventricular 41 | labels: [0, 0, 1, 0, 0, 0] 42 | all: false 43 | baseline: 0.2 44 | - name: Subarachnoid 45 | labels: [0, 0, 0, 1, 0, 0] 46 | all: false 47 | baseline: 0.2 48 | - name: Subdural 49 | labels: [0, 0, 0, 0, 1, 0] 50 | all: false 51 | baseline: 0.2 52 | - name: Any 53 | labels: [0, 0, 0, 0, 0, 1] 54 | all: false 55 | baseline: 0.2 56 | params: 57 | img_filter: ct 58 | 59 | trainer_params: 60 | max_epochs: 8 61 | log_every_n_steps: 20 62 | check_val_every_n_epoch: 1 63 | 64 | manual_seed: 1498 65 | 66 | logging_params: 67 | name: "RSNA_Basic" 68 | 69 | -------------------------------------------------------------------------------- /src/remove-corrupt-dcm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | 6 | import pydicom 7 | 8 | from dataset import normalized_dicom_pixels 9 | 10 | parser = argparse.ArgumentParser(description='Remove corrupted DICOM') 11 | parser.add_argument('--dir', 12 | '-d', 13 | dest="dir", 14 | metavar='DIR', 15 | help='path to directory containing dcm files', 16 | default='E:/cq500') 17 | parser.add_argument('--log_interval', 18 | dest="log_interval", 19 | metavar='LOG_INTERVAL', 20 | help='how often to print updates to stdout', 21 | default=1000) 22 | args = parser.parse_args() 23 | 24 | print(f'Removing corrupted DICOM files at {args.dir}') 25 | 26 | start = time.time() 27 | 28 | files = [ 29 | os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(args.dir)) 30 | for f in fn if f.endswith('.dcm') 31 | ] 32 | 33 | print(f'Processing {len(files)} files...') 34 | 35 | num_removed = 0 36 | for i, file in enumerate(files): 37 | try: 38 | x = pydicom.dcmread(file, stop_before_pixels=False) 39 | x = normalized_dicom_pixels(x) 40 | if x.shape != (1, 512, 512): 41 | raise ValueError('wrong shape') 42 | except: 43 | path = os.path.join(args.dir, file) 44 | os.remove(path) 45 | num_removed += 1 46 | print(f'Removed corrupted {path}: {sys.exc_info()}') 47 | if i > 0 and i % args.log_interval == 0: 48 | print(f'Processed {i}/{len(files)}') 49 | 50 | elapsed = time.time() - start 51 | print(f'Removed {num_removed} corrupt examples in {elapsed} seconds') 52 | 
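# Example invocation (flags are defined by the argparse config above; the
# dataset directory shown here is only an example):
#   python src/remove-corrupt-dcm.py --dir /data/cq500 --log_interval 500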
-------------------------------------------------------------------------------- /src/models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | from typing import List 6 | 7 | from .inception import InceptionV3 8 | 9 | 10 | class BaseRenderer(nn.Module): 11 | 12 | def __init__(self, 13 | name: str, 14 | enable_fid: bool, 15 | fid_blocks: List[int] = [3]) -> None: 16 | super(BaseRenderer, self).__init__() 17 | self.name = name 18 | self.enable_fid = enable_fid 19 | if enable_fid: 20 | self.inception = InceptionV3(fid_blocks, use_fid_inception=True) 21 | 22 | @abstractmethod 23 | def decode(self, world_matrix: Tensor, **kwargs) -> Tensor: 24 | raise NotImplementedError 25 | 26 | def forward(self, world_matrix: Tensor, **kwargs) -> List[Tensor]: 27 | return self.decode(torch.flatten(world_matrix, start_dim=1)) 28 | 29 | def loss_function(self, 30 | recons: Tensor, 31 | orig: Tensor, 32 | fid_weight: float = 1.0) -> dict: 33 | recons_loss = F.mse_loss(recons, orig) 34 | 35 | loss = recons_loss 36 | 37 | result = {'loss': loss, 'Reconstruction_Loss': recons_loss} 38 | 39 | if self.enable_fid: 40 | fid_loss = self.fid(orig, recons).sum() 41 | result['FID_Loss'] = fid_loss 42 | result['loss'] += fid_loss * fid_weight 43 | 44 | return result 45 | 46 | def fid(self, a: Tensor, b: Tensor) -> Tensor: 47 | a = self.inception(a) 48 | b = self.inception(b) 49 | fid = [torch.mean((x - y)**2).unsqueeze(0) for x, y in zip(a, b)] 50 | fid = torch.cat(fid, dim=0) 51 | return fid 52 | -------------------------------------------------------------------------------- /src/plot_test.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | import plotly.express as px 8 | import plotly.graph_objects as go 9 | from PIL import Image, ImageDraw, ImageFont 10 | from plotly.subplots import make_subplots 11 | 12 | from plot import plot_comparison 13 | 14 | img = Image.open('experiments/rsna-intracranial/images/img.png') 15 | font_path = os.path.join('data', 'fonts', 'arial.ttf') 16 | font = ImageFont.truetype(font_path, 24) 17 | draw = ImageDraw.Draw(img) 18 | draw.text((0, 0), "Sample Text", (255, 255, 255), font=font) 19 | img.show() 20 | 21 | fig = plt.figure(figsize=(1, 3), dpi=80) 22 | ax = fig.add_subplot(111) 23 | ax.imshow(img) 24 | ax.text(0, 25 | 0, 26 | str("Label"), 27 | horizontalalignment="left", 28 | verticalalignment="top") 29 | buf = io.BytesIO() 30 | data = fig.savefig(buf, format="png") 31 | buf.seek(0) 32 | img = Image.open(buf) 33 | plt.show() 34 | 35 | d1 = ImageDraw.Draw(img) 36 | f = d1.getfont() 37 | 38 | d1.text((0, 0), "Hello, TutorialsPoint!", fill=(255, 0, 0)) 39 | img.show() 40 | """ 41 | steps = 128 42 | items = [ 43 | ['sin(x)', [np.sin(2*np.pi*x/steps) * 0.5 + 0.5 44 | for x in range(steps)]], 45 | ['cos(x)', [np.cos(2*np.pi*x/steps) * 0.5 + 0.5 46 | for x in range(steps)]] 47 | ] 48 | fig = go.Figure() 49 | for name, y in items: 50 | x = np.arange(len(y)) 51 | fig.add_trace(go.Scatter(x=x, y=y, 52 | mode='lines', 53 | name=name)) 54 | fig.update_layout( 55 | title='test'.title(), 56 | xaxis_title="Epoch", 57 | yaxis_title="loss".title(), 58 | font=dict( 59 | size=18, 60 | ) 61 | ) 62 | fig.write_image('out.png') 63 | """ 64 | 
-------------------------------------------------------------------------------- /src/draw_eeg.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | import numpy as np 4 | import plotly.graph_objects as go 5 | from plotly.subplots import make_subplots 6 | 7 | from dataset import GraspAndLiftEEGDataset 8 | 9 | 10 | def eeg(orig: Tensor, 11 | out_path: str, 12 | width: int, 13 | height: int, 14 | line_opacity: float = 0.7, 15 | layout_params: dict = {}): 16 | batch_size, num_channels, num_samples = orig.shape 17 | cols = batch_size 18 | fig = make_subplots(rows=num_channels, cols=cols) 19 | #if layout_params is not None: 20 | # fig.update_layout(**layout_params) 21 | i = 0 22 | n = min(cols, batch_size) 23 | x = np.arange(num_samples) / num_samples 24 | for col in range(cols): 25 | if i >= n: 26 | break 27 | for channel in range(num_channels): 28 | yo = orig[i, channel, :] 29 | fig.add_trace(go.Scatter( 30 | x=x, 31 | y=yo, 32 | mode='lines', 33 | opacity=line_opacity, 34 | line=dict( 35 | color='red', 36 | width=2, 37 | ), 38 | ), 39 | row=channel + 1, 40 | col=col + 1) 41 | i += 1 42 | fig.write_image(out_path + '.png', width=width, height=height) 43 | 44 | 45 | ds = GraspAndLiftEEGDataset('/data/grasp-and-lift-eeg-detection') 46 | eeg(ds[0][0][:8].unsqueeze(0), 47 | 'out', 48 | 2048, 49 | 1024, 50 | layout_params=dict( 51 | xaxis=go.layout.XAxis( 52 | visible=False, 53 | showticklabels=False, 54 | ), 55 | yaxis=go.layout.YAxis( 56 | visible=False, 57 | showticklabels=False, 58 | ), 59 | )) 60 | -------------------------------------------------------------------------------- /src/models/resnet_augmenter1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Size, Tensor, nn 3 | from typing import List 4 | 5 | from .augmenter import Augmenter 6 | from .resnet1d import BasicBlock1d 7 | 8 | 9 | class ResNetAugmenter1d(Augmenter): 10 | 11 | def __init__(self, 12 | name: str, 13 | hidden_dims: List[int], 14 | input_shape: Size, 15 | load_weights: str = None, 16 | kernel_size: int = 3, 17 | padding: int = 1) -> None: 18 | super().__init__(name=name) 19 | self.num_samples = input_shape[1] 20 | self.channels = input_shape[0] 21 | self.hidden_dims = hidden_dims.copy() 22 | modules = [] 23 | in_features = self.channels 24 | for h_dim in hidden_dims: 25 | modules.append( 26 | BasicBlock1d(in_features, 27 | h_dim, 28 | kernel_size=kernel_size, 29 | padding=padding)) 30 | in_features = h_dim 31 | modules.append( 32 | BasicBlock1d(in_features, 33 | self.channels, 34 | kernel_size=kernel_size, 35 | padding=padding)) 36 | self.layers = nn.Sequential(*modules) 37 | if load_weights is not None: 38 | new = self.state_dict() 39 | old = torch.load(load_weights)['state_dict'] 40 | for k, v in new.items(): 41 | ok = f'model.{k}' 42 | if ok in old: 43 | new[k] = old[ok].cpu() 44 | print(f'Loaded weights for layer {k}') 45 | self.load_state_dict(new) 46 | 47 | def forward(self, x: Tensor) -> Tensor: 48 | return self.layers(x) 49 | -------------------------------------------------------------------------------- /src/models/resnet_embed2d.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from typing import List 3 | 4 | from .base import reparameterize 5 | from .classifier import Classifier 6 | from .encoder_wrapper import EncoderWrapper 7 | from .resnet2d import BasicBlock2d 8 | 9 | 10 | class 
ResNetEmbed2d(Classifier): 11 | 12 | def __init__(self, 13 | name: str, 14 | hidden_dims: List[int], 15 | width: int, 16 | height: int, 17 | channels: int, 18 | num_classes: int, 19 | encoder: EncoderWrapper, 20 | dropout: float = 0.4, 21 | pooling: str = None) -> None: 22 | super(ResNetEmbed2d, self).__init__(name=name) 23 | self.width = width 24 | self.height = height 25 | self.channels = channels 26 | self.hidden_dims = hidden_dims.copy() 27 | self.encoder = encoder 28 | 29 | self.decoder = nn.Linear(encoder.latent_dim, hidden_dims[0] * 4) 30 | modules = [] 31 | in_features = hidden_dims[0] 32 | for h_dim in hidden_dims: 33 | modules.append(BasicBlock2d(in_features, h_dim)) 34 | in_features = h_dim 35 | self.hidden_layers = nn.Sequential(*modules) 36 | self.output_layer = nn.Sequential( 37 | nn.Dropout(dropout), 38 | nn.Linear(in_features * 4, num_classes), 39 | nn.BatchNorm1d(num_classes), 40 | nn.Sigmoid(), 41 | ) 42 | 43 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 44 | mu, log_var = self.encoder(input) 45 | z = reparameterize(mu, log_var) 46 | y = self.decoder(z) 47 | y = y.view(y.shape[0], self.hidden_dims[-1], 2, 2) 48 | y = self.hidden_layers(y) 49 | y = y.reshape(y.shape[0], -1) 50 | y = self.output_layer(y) 51 | return y 52 | -------------------------------------------------------------------------------- /src/models/resnet_regressor4d.py: -------------------------------------------------------------------------------- 1 | from torch import Size, Tensor, nn 2 | from typing import List 3 | 4 | from .resnet4d import BasicBlock4d 5 | from .util import get_activation 6 | 7 | 8 | class ResNetRegressor4d(nn.Module): 9 | 10 | def __init__(self, 11 | name: str, 12 | hidden_dims: List[int], 13 | input_shape: Size, 14 | output_features: int, 15 | dropout: float = 0.3, 16 | batch_normalize: bool = False, 17 | output_activation: str = 'sigmoid') -> None: 18 | super().__init__(name=name) 19 | self.width = input_shape[4] 20 | self.height = input_shape[3] 21 | self.depth = input_shape[2] 22 | self.frames = input_shape[1] 23 | self.channels = input_shape[0] 24 | self.batch_normalize = batch_normalize 25 | self.hidden_dims = hidden_dims.copy() 26 | modules = [] 27 | in_features = self.channels 28 | for h_dim in hidden_dims: 29 | modules.append(BasicBlock4d(in_features, h_dim)) 30 | in_features = h_dim 31 | self.layers = nn.Sequential( 32 | *modules, 33 | nn.Flatten(), 34 | nn.Dropout(p=dropout), 35 | ) 36 | in_features = hidden_dims[ 37 | -1] * self.width * self.height * self.depth * self.frames 38 | self.activation = get_activation(output_activation) 39 | self.prediction = nn.Linear(in_features, output_features) 40 | if batch_normalize: 41 | self.output = nn.Sequential( 42 | nn.BatchNorm1d(output_features), 43 | self.activation, 44 | ) 45 | else: 46 | self.output = self.activation 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | x = self.layers(x) 50 | x = self.prediction(x) 51 | x = self.output(x) 52 | return x 53 | -------------------------------------------------------------------------------- /src/load_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | 4 | import yaml 5 | 6 | from merge_strategy import deep_merge 7 | 8 | 9 | def load_config(path: Union[str, List[str]]) -> Union[dict, List[dict]]: 10 | paths = path if type(path) == list else [path] 11 | result = {} 12 | for path in paths: 13 | if os.path.isdir(path): 14 | # If a directory is passed, it's the same as having 15 
| # a yaml with a 'series' list of all the files in 16 | # the directory. 17 | configs = [] 18 | for f in os.listdir(path): 19 | if os.path.basename(f) == 'include': 20 | # Include folders do not contain any files 21 | # that can be executed directly. 22 | continue 23 | fp = os.path.join(path, f) 24 | if not os.path.isdir(fp) and not f.endswith('.yaml'): 25 | # Ignore non-yaml files like README.md 26 | continue 27 | fc = load_config(fp) 28 | if type(fc) == list and len(fc) == 0: 29 | # Exclude empty directories 30 | continue 31 | configs.append(fc) 32 | return configs 33 | with open(path, 'r') as f: 34 | config = yaml.safe_load(f) 35 | if 'include' in config: 36 | # Recursively deep merge all the includes 37 | includes = config['include'] 38 | if type(includes) is not list: 39 | includes = [includes] 40 | merged = {} 41 | for include in includes: 42 | merged = deep_merge(merged, load_config(include)) 43 | # Merge this config file in last 44 | config = deep_merge(merged, config) 45 | # Remove include directive now that merge has occured 46 | del config['include'] 47 | result = deep_merge(result, config) 48 | return result 49 | -------------------------------------------------------------------------------- /src/dataset/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | from decord import VideoReader, cpu 6 | 7 | 8 | class VideoDataset(data.Dataset): 9 | 10 | def __init__(self, dir: str, width: int, height: int, limit: int = None): 11 | super(VideoDataset, self).__init__() 12 | videos = [] 13 | for f in os.listdir(dir): 14 | if limit is not None and len(videos) >= limit: 15 | break 16 | if f.endswith('.mp4'): 17 | videos.append(f) 18 | self.vr = [ 19 | VideoReader(os.path.join(dir, f), 20 | ctx=cpu(0), 21 | width=width, 22 | height=height) for f in videos 23 | ] 24 | n = 0 25 | for vr in self.vr: 26 | n += len(vr) 27 | self.n = n 28 | 29 | def __getitem__(self, index): 30 | cur = 0 31 | for vr in self.vr: 32 | end = cur + len(vr) 33 | if index >= end: 34 | cur = end 35 | continue 36 | index -= cur 37 | x = vr[index] 38 | x = torch.transpose(x, 0, -1) 39 | x = torch.transpose(x, 1, -1) 40 | x = x.float() 41 | return (x, []) 42 | raise ValueError(f'failed to seek index {index}') 43 | 44 | def __len__(self): 45 | return self.n 46 | 47 | 48 | if __name__ == '__main__': 49 | import time 50 | batch_size = 16 51 | num_frames = 1 52 | width = 320 53 | height = 240 54 | ds = VideoDataset(dir='E:/doom', width=width, height=height, limit=10) 55 | indices = torch.randint(low=0, high=len(ds), 56 | size=(batch_size, 1)).squeeze() 57 | start = time.time() 58 | batch = [ds[int(i)] for i in indices] 59 | delta = time.time() - start 60 | print( 61 | f'Loaded {batch_size} examples in {delta} seconds ({delta/batch_size}) seconds avg)' 62 | ) 63 | print('ok') 64 | -------------------------------------------------------------------------------- /experiments/graphics/README.md: -------------------------------------------------------------------------------- 1 | # Neural Rendering (work in progress) 2 | Graphics rendering pipelines are becoming exponentially more complicated. Generative adversarial networks (GANs) are able to produce realistic imagery ([Goodfellow 2014](https://papers.nips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf), [Karras et al 2019](https://arxiv.org/abs/1912.04958)), providing an alternative means of achieving computer graphics indistinguishable from reality. 
3 | 4 | My interest in AI graphics is motivated by the goal of seeing the technology put to use in surgical training tools. While the applications are innumerable, surgery simulators with differentiable patient models - allowing the educator to gradually increase the difficulty of a case - are particularly interesting to me. 5 | 6 | ## The Basics 7 | The most basic neural rendering experiment attempts to reproduce the pixels drawn by a standard rasterization- based renderer according to a bounded transform. 8 | 9 | In this first experiment, the model is given a 4x4 transformation matrix as input, and is tasked with rendering the mesh with no further variation. For simplicity's sake, [mean squared error (MSE)](https://en.wikipedia.org/wiki/Mean_squared_error) was used to calculate reconstruction loss. 10 | 11 | > ![](images/NeuralGBuffer_plot2d_55000.png) 12 | ***Target image is on the left, model output is on the right.*** 13 | 14 | Unsurprisingly, use of MSE without a progressive growing strategy yields only blurry messes. Further work is required to simplify image synthesis early in training. 15 | 16 | ## TODO 17 | - implement progressive growing 18 | - implement FID loss 19 | 20 | ## License 21 | ### Data 22 | Cow model is part of [PyTorch3D](https://github.com/facebookresearch/pytorch3d) and is [licensed under BSD](https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/LICENSE). 23 | 24 | ### Code 25 | [Apache 2.0](../../LICENSE-Apache) / [MIT](../../LICENSE-MIT) dual-license. Please contact me if this is somehow not permissive enough and we'll add whatever free license is necessary for your project. 26 | -------------------------------------------------------------------------------- /src/dataset/dicom_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | from torchvision.transforms import Resize, ToPILImage, ToTensor 5 | 6 | 7 | def raw_dicom_pixels(ds): 8 | signed = ds.PixelRepresentation == 1 9 | slope = ds.RescaleSlope 10 | intercept = ds.RescaleIntercept 11 | x = ds.pixel_array 12 | x = np.frombuffer(x, dtype='int16' if signed else 'uint16') 13 | x = np.array(x, dtype='float32') 14 | x = x * slope + intercept 15 | x = torch.Tensor(x) 16 | #x = x.clamp(0.0, 1.0) 17 | # TODO: fix normalization 18 | x = x.view(1, 512, 512) 19 | return x 20 | 21 | 22 | def normalized_dicom_pixels(ds): 23 | signed = ds.PixelRepresentation == 1 24 | slope = float(ds.RescaleSlope) 25 | intercept = float(ds.RescaleIntercept) 26 | x = ds.pixel_array 27 | if ds.BitsStored == 12 and not signed and int(intercept) > -100: 28 | # see: https://www.kaggle.com/jhoward/cleaning-the-data-for-rapid-prototyping-fastai 29 | x += 1000 30 | px_mode = 4096 31 | x[x >= px_mode] = x[x >= px_mode] - px_mode 32 | intercept -= 1000 33 | x = np.frombuffer(x, dtype='int16' if signed else 'uint16') 34 | x = np.array(x, dtype='float32') 35 | x = x * slope + intercept 36 | x = torch.Tensor(x) 37 | if x.numel() != 512 * 512: 38 | #dim = torch.sqrt(torch.Tensor([x.numel()])) 39 | #if dim.floor() != dim.ceil(): 40 | # raise ValueError('Non-square number of input elements ' 41 | # f'got {x.numel()} (dcm header reports {ds.Rows}x{ds.Columns})') 42 | #dim = dim.int().item() 43 | if ds.Columns * ds.Rows != x.numel(): 44 | raise ValueError( 45 | f'dimensions {ds.Rows}x{ds.Columns} does not match numel {x.numel()}' 46 | ) 47 | x = x.view(1, ds.Columns, ds.Rows) 48 | x = ToPILImage()(x) 49 | x = Resize((512, 512))(x) 50 | x = ToTensor()(x) 51 | 
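        # At this point any study that was not already 512x512 has been resized
        # via PIL, so the function always returns a 1x512x512 tensor below.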
#print(f'Successfully resized from {ds.Rows}x{ds.Columns}') 52 | x = x.view(1, 512, 512) 53 | return x 54 | -------------------------------------------------------------------------------- /src/models/resnet4d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .batchnorm4d import BatchNorm4d 5 | from .conv4d import Conv4d 6 | 7 | 8 | class BasicBlock4d(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock4d, self).__init__() 13 | self.conv1 = Conv4d(in_planes, 14 | planes, 15 | kernel_size=3, 16 | stride=stride, 17 | padding=1, 18 | bias=False) 19 | self.bn1 = BatchNorm4d(planes) 20 | self.conv2 = Conv4d(planes, 21 | planes, 22 | kernel_size=3, 23 | stride=1, 24 | padding=1, 25 | bias=False) 26 | self.bn2 = BatchNorm4d(planes) 27 | 28 | if stride != 1 or in_planes != self.expansion * planes: 29 | self.shortcut = nn.Sequential( 30 | Conv4d(in_planes, 31 | self.expansion * planes, 32 | kernel_size=1, 33 | stride=stride, 34 | bias=False), BatchNorm4d(self.expansion * planes)) 35 | else: 36 | self.shortcut = None 37 | 38 | def forward(self, x): 39 | out = F.relu(self.bn1(self.conv1(x))) 40 | out = self.bn2(self.conv2(out)) 41 | if self.shortcut is not None: 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class TransposeBasicBlock4d(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, in_planes, planes, stride=1): 51 | super(TransposeBasicBlock4d, self).__init__() 52 | raise NotImplementedError 53 | 54 | def forward(self, x): 55 | out = F.relu(self.bn1(self.conv1(x))) 56 | out = self.bn2(self.conv2(out)) 57 | if self.shortcut is not None: 58 | out += self.shortcut(x) 59 | out = F.relu(out) 60 | return out 61 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ray.rllib.models import ModelCatalog 2 | 3 | from .base import BaseVAE 4 | from .classifier import Classifier 5 | from .localizer import Localizer 6 | from .renderer import BaseRenderer 7 | from .resnet_augmenter1d import ResNetAugmenter1d 8 | from .resnet_classifier1d import ResNetClassifier1d 9 | from .resnet_classifier2d import ResNetClassifier2d 10 | from .resnet_classifier3d import ResNetClassifier3d 11 | from .resnet_embed2d import ResNetEmbed2d 12 | from .resnet_gaussian_localizer2d import ResNetGaussianLocalizer2d 13 | from .resnet_localizer2d import ResNetLocalizer2d 14 | from .resnet_renderer2d import ResNetRenderer2d 15 | from .resnet_rl1d import ResNetRL1d 16 | from .resnet_rl2d import ResNetRL2d 17 | from .resnet_sandwich2d import ResNetSandwich2d 18 | from .resnet_vae1d import ResNetVAE1d 19 | from .resnet_vae2d import ResNetVAE2d 20 | from .resnet_vae3d import ResNetVAE3d 21 | from .resnet_vae4d import ResNetVAE4d 22 | 23 | # ModelCatalog.register_custom_model("ResNetRL1d", ResNetRL1d) 24 | # ModelCatalog.register_custom_model("ResNetRL2d", ResNetRL2d) 25 | 26 | models = { 27 | 'ResNetAugmenter1d': ResNetAugmenter1d, 28 | 'ResNetClassifier1d': ResNetClassifier1d, 29 | 'ResNetClassifier2d': ResNetClassifier2d, 30 | 'ResNetClassifier3d': ResNetClassifier3d, 31 | 'ResNetEmbed2d': ResNetEmbed2d, 32 | 'ResNetSandwich2d': ResNetSandwich2d, 33 | 'ResNetGaussianLocalizer2d': ResNetGaussianLocalizer2d, 34 | 'ResNetLocalizer2d': ResNetLocalizer2d, 35 | 'ResNetRenderer2d': 
ResNetRenderer2d, 36 | 'ResNetRL1d': ResNetRL1d, 37 | 'ResNetRL2d': ResNetRL2d, 38 | 'ResNetVAE1d': ResNetVAE1d, 39 | 'ResNetVAE2d': ResNetVAE2d, 40 | 'ResNetVAE3d': ResNetVAE3d, 41 | 'ResNetVAE4d': ResNetVAE4d, 42 | } 43 | 44 | 45 | def create_model(arch: str, **kwargs): 46 | if arch not in models: 47 | raise ValueError(f'unknown model architecture "{arch}" ' 48 | f'valid options are {models}') 49 | try: 50 | model = models[arch](**kwargs) 51 | except: 52 | print(f'failed to create model "{arch}"') 53 | raise 54 | return model 55 | -------------------------------------------------------------------------------- /src/models/resnet_classifier3d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Size, Tensor, nn 3 | from torch.nn import functional as F 4 | from typing import List 5 | 6 | from .classifier import Classifier 7 | from .resnet3d import BasicBlock3d 8 | from .util import get_pooling3d 9 | 10 | 11 | class ResNetClassifier3d(Classifier): 12 | 13 | def __init__(self, 14 | name: str, 15 | hidden_dims: List[int], 16 | input_shape: Size, 17 | num_classes: int, 18 | dropout: float = 0.4, 19 | pooling: str = None) -> None: 20 | super().__init__(name=name, num_classes=num_classes) 21 | self.width = input_shape[3] 22 | self.height = input_shape[2] 23 | self.depth = input_shape[1] 24 | self.channels = input_shape[0] 25 | self.dropout = dropout 26 | self.hidden_dims = hidden_dims.copy() 27 | if pooling is not None: 28 | pool_fn = get_pooling3d(pooling) 29 | modules = [] 30 | in_features = self.channels 31 | for h_dim in hidden_dims: 32 | modules.append(BasicBlock3d(in_features, h_dim)) 33 | if pooling is not None: 34 | modules.append(pool_fn(2)) 35 | in_features = h_dim 36 | self.layers = nn.Sequential(*modules) 37 | in_features = hidden_dims[-1] * self.width * self.height * self.depth 38 | if pooling is not None: 39 | in_features /= 8**len(hidden_dims) 40 | if abs(in_features - ceil(in_features)) > 0: 41 | raise ValueError( 42 | 'noninteger number of features - perhaps there is too much pooling?' 
43 | ) 44 | in_features = int(in_features) 45 | self.output = nn.Sequential( 46 | nn.Linear(in_features, num_classes), 47 | nn.BatchNorm1d(num_classes), 48 | nn.Sigmoid(), 49 | ) 50 | 51 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 52 | y = self.layers(input) 53 | y = y.reshape((input.shape[0], -1)) 54 | y = F.dropout(y, p=self.dropout) 55 | y = self.output(y) 56 | return y 57 | -------------------------------------------------------------------------------- /experiments/grasp_and_lift_eeg/classification/basic.yaml: -------------------------------------------------------------------------------- 1 | entrypoint: classification 2 | 3 | model_params: 4 | arch: 'ResNetClassifier1d' 5 | name: 'ResNetClassifier1d' 6 | hidden_dims: [128, 256, 128] 7 | num_classes: 6 8 | 9 | exp_params: 10 | manual_seed: 6602 11 | warmup_steps: 512 12 | reduce_lr_on_plateau: 13 | factor: 0.1 14 | patience: 3 15 | threshold: 1.0e-04 16 | save_weights: 17 | every_n_steps: 10_000 18 | local: {} 19 | data: 20 | name: grasp-and-lift-eeg 21 | training: 22 | root: "C:/data/grasp-and-lift-eeg" #/data/grasp-and-lift-eeg-detection 23 | num_samples: 1024 24 | last_label_only: true 25 | subjects: [1, 2, 3, 4, 5, 6, 7, 8, 9] 26 | validation: 27 | subjects: [10, 11, 12] 28 | loader: 29 | pin_memory: false 30 | num_workers: 0 31 | batch_size: 32 32 | optimizer: 33 | lr: 0.00005 34 | weight_decay: 0.0000001 35 | plot: 36 | fn: classifier1d_multicolumn 37 | sample_every_n_steps: 20_000 38 | examples_per_class: 6 39 | classes: 40 | - name: Control 41 | labels: [0, 0, 0, 0, 0, 0] 42 | all: true 43 | - name: HandStart 44 | labels: [1, 0, 0, 0, 0, 0] 45 | all: true 46 | - name: FirstDigitTouch 47 | labels: [0, 1, 0, 0, 0, 0] 48 | all: true 49 | - name: BothStartLoadPhase 50 | labels: [0, 0, 1, 0, 0, 0] 51 | all: true 52 | - name: LiftOff 53 | labels: [0, 0, 0, 1, 0, 0] 54 | all: true 55 | - name: Replace 56 | labels: [0, 0, 0, 0, 1, 0] 57 | all: true 58 | - name: BothReleased 59 | labels: [0, 0, 0, 0, 0, 1] 60 | all: true 61 | params: 62 | width: 1024 63 | height: 256 64 | indicator_thickness: 12 65 | line_opacity: 0.3 66 | layout_params: 67 | showlegend: false 68 | margin: 69 | l: 0 70 | r: 0 71 | b: 0 72 | t: 0 73 | layout_params: 74 | showlegend: false 75 | 76 | trainer_params: 77 | max_epochs: 1_000_000 78 | log_every_n_steps: 100 79 | val_check_interval: 20_000 80 | limit_val_batches: 2_000 81 | 82 | logging_params: 83 | save_dir: "logs/" 84 | name: "GraspLift_Basic" 85 | 86 | manual_seed: 6602 87 | -------------------------------------------------------------------------------- /src/models/resnet_sandwich2d.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from typing import List 3 | 4 | from .base import reparameterize 5 | from .classifier import Classifier 6 | from .encoder_wrapper import EncoderWrapper 7 | from .resnet2d import BasicBlock2d 8 | 9 | 10 | class ResNetSandwich2d(Classifier): 11 | 12 | def __init__(self, 13 | name: str, 14 | hidden_dims: List[int], 15 | width: int, 16 | height: int, 17 | channels: int, 18 | num_classes: int, 19 | encoder: EncoderWrapper, 20 | sandwich_layers: List[nn.Module], 21 | dropout: float = 0.4, 22 | pooling: str = None) -> None: 23 | super(ResNetSandwich2d, self).__init__(name=name) 24 | self.width = width 25 | self.height = height 26 | self.channels = channels 27 | self.hidden_dims = hidden_dims.copy() 28 | self.sandwich_layers = sandwich_layers 29 | self.encoder = encoder 30 | 31 | self.decoder = 
nn.Linear(encoder.latent_dim, hidden_dims[0] * 4) 32 | modules = [] 33 | for h_dim, (layer, features) in zip(hidden_dims, sandwich_layers): 34 | modules.append(layer) 35 | in_features = features 36 | modules.append(BasicBlock2d(in_features, h_dim)) 37 | in_features = h_dim 38 | self.hidden_layers = nn.Sequential(*modules) 39 | self.output_layer = nn.Sequential( 40 | nn.Dropout(dropout), 41 | nn.Linear(in_features * 4, num_classes), 42 | nn.BatchNorm1d(num_classes), 43 | nn.Sigmoid(), 44 | ) 45 | 46 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 47 | mu, log_var = self.encoder(input) 48 | z = reparameterize(mu, log_var) 49 | y = self.decoder(z) 50 | y = y.view(y.shape[0], self.hidden_dims[-1], 2, 2) 51 | y = self.hidden_layers(y) 52 | y = y.reshape(y.shape[0], -1) 53 | y = self.output_layer(y) 54 | return y 55 | 56 | def set_sandwich_frozen(self, frozen: bool): 57 | self.encoder.requires_grad = frozen 58 | for layer, _ in self.sandwich_layers: 59 | layer.requires_grad = frozen 60 | -------------------------------------------------------------------------------- /src/dataset/batch_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from decord import VideoLoader, cpu 5 | 6 | 7 | class BatchVideoDataLoader(object): 8 | 9 | def __init__(self, 10 | dir: str, 11 | batch_size: int, 12 | num_frames: int, 13 | width: int, 14 | height: int, 15 | interval: int = 0, 16 | skip: int = 0, 17 | shuffle: int = 1, 18 | limit: int = None): 19 | super(BatchVideoDataLoader, self).__init__() 20 | self.batch_size = batch_size 21 | videos = [ 22 | os.path.join(dir, f) for f in os.listdir(dir) if f.endswith('.mp4') 23 | ] 24 | if limit is not None: 25 | videos = videos[:limit] 26 | self.vl = VideoLoader(videos, 27 | ctx=[cpu(0)], 28 | shape=(num_frames, width, height, 3), 29 | interval=interval, 30 | skip=skip, 31 | shuffle=shuffle) 32 | 33 | def __next__(self): 34 | x, y = [], [] 35 | for _ in range(self.batch_size): 36 | a, b = self.vl.next() 37 | # [F, W, H, C] -> [F, C, H, W] 38 | a = torch.transpose(a, 1, 3) 39 | x.append(a.unsqueeze(0)) 40 | y.append(b.unsqueeze(0)) 41 | x = torch.cat(x, dim=0).squeeze().float() 42 | y = torch.cat(y, dim=0).squeeze() 43 | return x, y 44 | 45 | def __len__(self): 46 | return len(self.vl) 47 | 48 | def __iter__(self): 49 | self.vl.__iter__() 50 | return self 51 | 52 | 53 | if __name__ == '__main__': 54 | import time 55 | num_frames = 1 56 | width = 320 57 | height = 240 58 | start = time.time() 59 | ds = BatchVideoDataLoader(dir='E:/doom', 60 | batch_size=8, 61 | num_frames=num_frames, 62 | width=width, 63 | height=height, 64 | limit=10) 65 | for x, y in ds: 66 | pass 67 | delta = time.time() - start 68 | print(f'Loaded in {delta} seconds') 69 | print('ok') 70 | -------------------------------------------------------------------------------- /src/models/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .maxpool4d import MaxPool4d 5 | 6 | pool1d = { 7 | 'max': nn.MaxPool1d, 8 | 'min': nn.MaxPool1d, 9 | 'avg': nn.AvgPool1d, 10 | } 11 | 12 | pool2d = { 13 | 'max': nn.MaxPool2d, 14 | 'min': nn.MaxPool2d, 15 | 'avg': nn.AvgPool2d, 16 | } 17 | 18 | pool3d = { 19 | 'max': nn.MaxPool3d, 20 | 'min': nn.MaxPool3d, 21 | 'avg': nn.AvgPool3d, 22 | } 23 | 24 | pool4d = { 25 | 'max': MaxPool4d, 26 | } 27 | 28 | 29 | def get_pooling1d(name: str) -> nn.Module: 30 | if name not in pool1d: 31 | raise 
ValueError(f'Unknown 1D pool function "{name}", ' 32 | f'valid options are {pool1d}') 33 | return pool1d[name] 34 | 35 | 36 | def get_pooling2d(name: str) -> nn.Module: 37 | if name not in pool2d: 38 | raise ValueError(f'Unknown 2D pooling function "{name}", ' 39 | f'valid options are {pool2d}') 40 | return pool2d[name] 41 | 42 | 43 | def get_pooling3d(name: str) -> nn.Module: 44 | if name not in pool3d: 45 | raise ValueError(f'Unknown 3D pooling function "{name}", ' 46 | f'valid options are {pool3d}') 47 | return pool3d[name] 48 | 49 | 50 | def get_pooling4d(name: str) -> nn.Module: 51 | if name not in pool4d: 52 | raise ValueError(f'Unknown 4D pooling function "{name}", ' 53 | f'valid options are {pool4d}') 54 | return pool4d[name] 55 | 56 | 57 | act_options = { 58 | 'sigmoid': nn.Sigmoid, 59 | 'tanh': nn.Tanh, 60 | 'leaky-relu': nn.LeakyReLU, 61 | 'relu': nn.ReLU, 62 | } 63 | 64 | 65 | def get_activation(name: str) -> nn.Module: 66 | if name not in act_options: 67 | raise ValueError(f'Unknown activation function "{name}"') 68 | return act_options[name]() 69 | 70 | 71 | def reparameterize(mu: Tensor, logvar: Tensor) -> Tensor: 72 | """ 73 | Reparameterization trick to sample from N(mu, var) from 74 | N(0,1). 75 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 76 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 77 | :return: (Tensor) [B x D] 78 | """ 79 | std = torch.exp(0.5 * logvar) 80 | eps = torch.randn_like(std) 81 | return eps * std + mu 82 | -------------------------------------------------------------------------------- /src/models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | from typing import List 6 | 7 | 8 | class Classifier(nn.Module): 9 | 10 | def __init__(self, name: str, num_classes: int) -> None: 11 | super(Classifier, self).__init__() 12 | self.name = name 13 | self.num_classes = num_classes 14 | 15 | @abstractmethod 16 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 17 | raise NotImplementedError 18 | 19 | def loss_function(self, 20 | prediction: Tensor, 21 | target: Tensor, 22 | objective: str = 'bce', 23 | baseline_accuracy: float = None) -> dict: 24 | result = {} 25 | if objective == 'nll': 26 | result['loss'] = F.nll_loss(prediction, target) 27 | elif objective == 'mse': 28 | result['loss'] = F.mse_loss(prediction, target) 29 | elif objective == 'bce': 30 | result['loss'] = F.binary_cross_entropy(prediction, target) 31 | else: 32 | raise ValueError(f'Objective "{objective}" not implemented') 33 | 34 | any_acc = torch.sum(torch.round(prediction), dim=1).clamp( 35 | 0, 1).int() == torch.sum(target, dim=1).clamp(0, 1).int() 36 | any_acc = any_acc.float().mean() 37 | result['accuracy/any'] = any_acc 38 | 39 | avg_acc = torch.round(prediction).int() == target 40 | avg_acc = avg_acc.float().mean() 41 | result['accuracy/avg'] = avg_acc 42 | 43 | if baseline_accuracy is not None: 44 | result['rel_acc/avg'] = (avg_acc - baseline_accuracy) / ( 45 | 1.0 - baseline_accuracy) 46 | result['rel_acc/any'] = (any_acc - baseline_accuracy) / ( 47 | 1.0 - baseline_accuracy) 48 | 49 | num_classes = target.shape[1] 50 | 51 | for i in range(num_classes): 52 | acc = torch.round(prediction[:, i]).int() == target[:, i] 53 | acc = acc.float().mean() 54 | result[f'accuracy/class_{i}'] = acc 55 | if baseline_accuracy is not None: 56 | result[f'rel_acc/class_{i}'] = 
(acc - baseline_accuracy) / ( 57 | 1.0 - baseline_accuracy) 58 | 59 | return result 60 | -------------------------------------------------------------------------------- /src/models/resnet_pg_localizer2d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Size, Tensor, nn 3 | from typing import List 4 | 5 | from .localizer import Localizer 6 | from .resnet2d import BasicBlock2d 7 | from .util import get_activation, get_pooling2d 8 | 9 | 10 | class ResNetPGLocalizer2d(Localizer): 11 | 12 | def __init__(self, 13 | name: str, 14 | hidden_dims: List[int], 15 | input_shape: Size, 16 | dropout: float = 0.4, 17 | pooling: str = None, 18 | batch_normalize: bool = False, 19 | output_activation: str = 'sigmoid') -> None: 20 | super().__init__(name=name) 21 | self.width = input_shape[2] 22 | self.height = input_shape[1] 23 | self.channels = input_shape[0] 24 | self.batch_normalize = batch_normalize 25 | self.hidden_dims = hidden_dims.copy() 26 | if pooling is not None: 27 | pool_fn = get_pooling2d(pooling) 28 | modules = [] 29 | in_features = self.channels 30 | for h_dim in hidden_dims: 31 | modules.append(BasicBlock2d(in_features, h_dim)) 32 | if pooling is not None: 33 | modules.append(pool_fn(2)) 34 | in_features = h_dim 35 | self.layers = nn.Sequential( 36 | *modules, 37 | nn.Flatten(), 38 | nn.Dropout(p=dropout), 39 | ) 40 | in_features = hidden_dims[-1] * self.width * self.height 41 | if pooling is not None: 42 | in_features /= 4**len(hidden_dims) 43 | if abs(in_features - ceil(in_features)) > 0: 44 | raise ValueError( 45 | 'noninteger number of features - perhaps there is too much pooling?' 46 | ) 47 | in_features = int(in_features) 48 | self.activation = get_activation(output_activation) 49 | self.prediction = nn.Linear(in_features, 4) 50 | if batch_normalize: 51 | self.output = nn.Sequential( 52 | nn.BatchNorm1d(4), 53 | self.activation, 54 | ) 55 | else: 56 | self.output = self.activation 57 | 58 | def forward(self, x: Tensor, lod: int = 0) -> Tensor: 59 | x = self.layers(x) 60 | x = self.prediction(x) 61 | x = self.output(x) 62 | return x 63 | -------------------------------------------------------------------------------- /src/models/resnet_rl2d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import nn 3 | from typing import List 4 | 5 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 6 | 7 | from .resnet2d import BasicBlock2d 8 | from .util import get_pooling2d 9 | 10 | 11 | class ResNetRL2d(TorchModelV2, nn.Module): 12 | 13 | def __init__(self, 14 | obs_space, 15 | action_space, 16 | num_outputs: int, 17 | model_config: dict, 18 | name: str, 19 | width: int, 20 | height: int, 21 | channels: int, 22 | hidden_dims: List[int], 23 | pooling: str = None, 24 | dropout: float = 0.4): 25 | nn.Module.__init__(self) 26 | TorchModelV2.__init__(self, obs_space, action_space, num_outputs, 27 | model_config, name) 28 | 29 | if pooling is not None: 30 | pool_fn = get_pooling2d(pooling) 31 | 32 | modules = [] 33 | in_features = channels 34 | for h_dim in hidden_dims: 35 | modules.append(BasicBlock2d(in_features, h_dim)) 36 | if pooling is not None: 37 | modules.append(pool_fn(2)) 38 | in_features = h_dim 39 | self.layers = nn.Sequential( 40 | *modules, 41 | nn.Flatten(), 42 | nn.Dropout(p=dropout), 43 | ) 44 | if pooling is not None: 45 | in_features /= 4**len(hidden_dims) 46 | if abs(in_features - ceil(in_features)) > 0: 47 | raise 
ValueError( 48 | 'noninteger number of features - perhaps there is too much pooling?' 49 | ) 50 | in_features = int(in_features) 51 | self.output = nn.Sequential( 52 | nn.Linear(in_features, action_space.n), 53 | nn.Sigmoid(), 54 | ) 55 | self.value_out = nn.Sequential( 56 | nn.Linear(in_features, 1), 57 | nn.Sigmoid(), 58 | ) 59 | 60 | def forward(self, input_dict, state, seq_lens): 61 | obs = input_dict['obs'].to(self.device) 62 | x = self.layers(obs) 63 | model_out = self.output(x) 64 | self._value_out = self.value_out(x) 65 | return model_out, state 66 | 67 | def value_function(self): 68 | return self._value_out.flatten() 69 | -------------------------------------------------------------------------------- /src/env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Env, spaces 3 | 4 | from dataset import get_dataset 5 | 6 | 7 | class TimeSeriesDetector(Env): 8 | 9 | def __init__(self, config: dict): 10 | super(TimeSeriesDetector, self).__init__() 11 | self.observation_length = config['observation_length'] 12 | self.num_channels = config['channels'] 13 | self.action_stride = config.get('action_stride', 0) 14 | num_actions = config['num_event_classes'] + 1 # abstain 15 | self.action_space = spaces.Discrete(num_actions) 16 | self.observation_space = spaces.Box( 17 | low=config['low'], 18 | high=config['high'], 19 | shape=(self.num_channels, self.observation_length), 20 | dtype=np.float32, 21 | ) 22 | # Load the data 23 | self.ds = get_dataset(**config['data']) 24 | 25 | def step(self, action: int): 26 | reward = 0.0 27 | y = self.y[:, self.current_step:self.current_step + 28 | self.observation_length] 29 | #loss = F.nll_loss(prediction, target) 30 | 31 | self.current_step += 1 32 | if action != 0: 33 | self.current_step += self.action_stride 34 | obs = self.get_observation() 35 | done = self.current_step >= self.x.shape[1] - self.observation_length 36 | info = self.get_info_dict() 37 | return obs, reward, done, info 38 | 39 | def reset(self): 40 | self.current_step = 0 41 | i = np.random.randint(0, len(self.ds)) 42 | self.x, self.y = self.ds[i] 43 | if self.x.shape[1] < self.observation_length: 44 | raise ValueError( 45 | f'Example {i} is shorter ({self.x.shape[1]}) ' 46 | f'than the observation length ({self.observation_length})') 47 | return self.get_observation() 48 | 49 | def get_observation(self): 50 | return self.x[:, self.current_step:self.current_step + 51 | self.observation_length].numpy() 52 | 53 | def get_info_dict(self) -> dict: 54 | return dict(current_step=self.current_step) 55 | 56 | 57 | envs = { 58 | 'TimeSeriesDetector': TimeSeriesDetector, 59 | } 60 | 61 | 62 | def get_env(name: str): 63 | if name not in envs: 64 | raise ValueError(f'Environment "{name}" not found ' 65 | f'valid options are {envs}') 66 | return envs[name] 67 | -------------------------------------------------------------------------------- /src/models/resnet_rl1d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import nn 3 | from typing import List 4 | 5 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 6 | 7 | from .resnet1d import BasicBlock1d 8 | from .util import get_pooling1d 9 | 10 | 11 | class ResNetRL1d(TorchModelV2, nn.Module): 12 | 13 | def __init__(self, 14 | obs_space, 15 | action_space, 16 | num_outputs: int, 17 | model_config: dict, 18 | name: str, 19 | num_samples: int, 20 | channels: int, 21 | hidden_dims: List[int], 22 
| pooling: str = None, 23 | dropout: float = 0.4): 24 | nn.Module.__init__(self) 25 | TorchModelV2.__init__(self, obs_space, action_space, num_outputs, 26 | model_config, name) 27 | if pooling is not None: 28 | pool_fn = get_pooling1d(pooling) 29 | modules = [] 30 | in_features = channels 31 | for h_dim in hidden_dims: 32 | modules.append(BasicBlock1d(in_features, h_dim)) 33 | if pooling is not None: 34 | modules.append(pool_fn(2)) 35 | in_features = h_dim 36 | self.layers = nn.Sequential( 37 | *modules, 38 | nn.Dropout(p=dropout), 39 | nn.Flatten(), 40 | ) 41 | in_features = hidden_dims[-1] * num_samples 42 | if pooling is not None: 43 | in_features /= 2**len(hidden_dims) 44 | if abs(in_features - ceil(in_features)) > 0: 45 | raise ValueError( 46 | 'noninteger number of features - perhaps there is too much pooling?' 47 | ) 48 | in_features = int(in_features) 49 | self.output = nn.Sequential( 50 | nn.Linear(in_features, action_space.n), 51 | nn.Sigmoid(), 52 | ) 53 | self.value_out = nn.Sequential( 54 | nn.Linear(in_features, 1), 55 | nn.ReLU(), 56 | ) 57 | 58 | def forward(self, input_dict, state, seq_lens): 59 | obs = input_dict['obs'] #.to(self.layers.device) 60 | x = self.layers(obs) 61 | model_out = self.output(x) 62 | self._value_out = self.value_out(x) 63 | return model_out, state 64 | 65 | def value_function(self): 66 | return self._value_out.flatten() 67 | -------------------------------------------------------------------------------- /experiments/la5c/README.md: -------------------------------------------------------------------------------- 1 | # LA5c Study (work in progress) 2 | These experiments utilize the [LA5c Study](https://openneuro.org/datasets/ds000030/versions/1.0.0) from the [Preprocessed Consortium for Neuropsychiatric Phenomics dataset](https://f1000research.com/articles/6-1262/v2). 265 participants completed extensive psychometric and neuroimaging examinations. A relatively simple modeling task was chosen: predict participants' questionnaire answers given their T1w MRI. 3 | 4 | ## Results 5 | TODO: pick interesting questions we want to model with T1w 6 | 7 | TODO: show a case where overtly visible differences in brain structure were exploited to infer correct answers 8 | 9 | ## Materials & Methods 10 | ### Choice of Questions 11 | The dataset comes with hundreds of self-report answers for an amalgam of questionnaires. Participants' answers are modeled by the structural T1-weighted MRI of each participant. 
TODO: explain process behind selecting questions 12 | 13 | TODO: show which questions were selected 14 | 15 | ### Experiment Files 16 | | File | Input Size (CxDxHxW) | Notes 17 | | -------------------------------------------------------------------------------- | --------------------- | ------ 18 | | [classification/bilingual.yaml](classification/bilingual.yaml) | 1x176x256x256 | Model bilingual Y/N in terms of T1w 19 | | [classification/bilingual_hparams.yaml](classification/bilingual_hparams.yaml) | 1x176x256x256 | Hyperparameter search for `bilingual.yaml` 20 | 21 | ### Source Files 22 | | File | Notes 23 | | ------------------------------------------------------------------------ | ------------------------------ 24 | | [src/classification.py](/src/classification.py) | Base classification experiment 25 | | [src/dataset/la5c.py](/src/dataset/la5c.py) | LA5c dataset class 26 | | [src/models/resnet_classifier3d.py](/src/models/resnet_classifier3d.py) | 3D ResNet classifier model 27 | 28 | ## License 29 | ### Data 30 | LA5c is released under the [Creative Commons Zero (CC0)](https://creativecommons.org/choose/zero/) license. 31 | 32 | ### Code 33 | [Apache 2.0](../../LICENSE-Apache) / [MIT](../../LICENSE-MIT) dual-license. Please contact me if this is somehow not permissive enough and we'll add whatever free license is necessary for your project. 34 | -------------------------------------------------------------------------------- /src/video_grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import numpy as np 5 | from decord import VideoReader, bridge 6 | from PIL import Image 7 | 8 | bridge.set_bridge('torch') 9 | 10 | input = '/mnt/e/320x240' 11 | rows = 4 12 | cols = 4 13 | fps = 12 14 | num_seconds = 8 15 | num_frames = fps * num_seconds 16 | limit = rows * cols 17 | files = sorted(f for f in os.listdir(input) if f.endswith('1.mp4')) 18 | indices = np.random.randint(0, len(files), size=(1, limit)).squeeze() 19 | files = [files[idx] for idx in indices] 20 | print(f'Using videos {files}') 21 | videos = [VideoReader(os.path.join(input, f)) for f in files] 22 | frames = np.array([[vr.next().numpy() for _ in range(num_frames)] 23 | for vr in videos]) 24 | for frame in range(num_frames): 25 | i = 0 26 | complete_frame = [] 27 | for _ in range(rows): 28 | items = [] 29 | for _ in range(cols): 30 | frame_data = frames[i][frame] 31 | items.append(frame_data) 32 | i += 1 33 | items = np.concatenate(items, axis=1) 34 | complete_frame.append(items) 35 | complete_frame = np.concatenate(complete_frame, axis=0) 36 | img = Image.fromarray(complete_frame) 37 | img = img.resize((img.width // 2, img.height // 2)) 38 | out_path = 'frame_{0:04}.png'.format(frame) 39 | img.save(out_path) 40 | print(f'Wrote {out_path}') 41 | out_path = 'output' 42 | print(f'Encoding video to {out_path}.mp4') 43 | cmd = f"ffmpeg -r {fps} -s {img.width}x{img.height} -i frame_%04d.png -crf 25 -pix_fmt yuv420p {out_path}.mp4" 44 | print(f'Running {cmd}') 45 | proc = subprocess.run(['bash', '-c', cmd], capture_output=True) 46 | if proc.returncode != 0: 47 | msg = 'expected exit code 0 from ffmpeg, got exit code {}: {}'.format( 48 | proc.returncode, proc.stdout.decode('unicode_escape')) 49 | if proc.stderr: 50 | msg += ' ' + proc.stderr.decode('unicode_escape') 51 | raise ValueError(msg) 52 | cmd = f"ffmpeg -i {out_path}.mp4 {out_path}.gif" 53 | print(f'Running {cmd}') 54 | proc = subprocess.run(['bash', '-c', cmd], capture_output=True) 55 | if 
proc.returncode != 0: 56 | msg = 'expected exit code 0 from ffmpeg, got exit code {}: {}'.format( 57 | proc.returncode, proc.stdout.decode('unicode_escape')) 58 | if proc.stderr: 59 | msg += ' ' + proc.stderr.decode('unicode_escape') 60 | raise ValueError(msg) 61 | print(f'Wrote gif to {out_path}.gif') 62 | [os.remove(f'frame_{i:04}.png') for i in range(num_frames)] 63 | -------------------------------------------------------------------------------- /src/models/resnet_renderer2d.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from typing import List 3 | 4 | import numpy as np 5 | 6 | from .renderer import BaseRenderer 7 | from .resnet2d import TransposeBasicBlock2d 8 | from .util import get_activation 9 | 10 | 11 | class ResNetRenderer2d(BaseRenderer): 12 | 13 | def __init__(self, 14 | name: str, 15 | hidden_dims: List[int], 16 | width: int, 17 | height: int, 18 | channels: int, 19 | enable_fid: bool = True, 20 | output_activation: str = 'sigmoid') -> None: 21 | super().__init__(name=name, enable_fid=enable_fid) 22 | self.width = width 23 | self.height = height 24 | self.channels = channels 25 | self.hidden_dims = hidden_dims 26 | 27 | # Decoder 28 | self.decoder_input = nn.Linear(16, hidden_dims[0] * 4) 29 | modules = [] 30 | in_features = hidden_dims[0] 31 | for h_dim in hidden_dims: 32 | layer = TransposeBasicBlock2d(in_features, h_dim) 33 | modules.append(layer) 34 | in_features = h_dim 35 | self.decoder = nn.Sequential(*modules) 36 | num_lods = int(np.min([np.log(width), np.log(height)]) / np.log(2) - 2) 37 | activation = get_activation(output_activation) 38 | self.initial_output = nn.Sequential( 39 | TransposeBasicBlock2d(in_features, 4 * 4 * 3 // 4), 40 | activation, 41 | ) 42 | output_layers = [] 43 | for _ in range(num_lods): 44 | output_layers.append( 45 | nn.Sequential( 46 | TransposeBasicBlock2d(3, 128), 47 | TransposeBasicBlock2d(128, 3), 48 | activation, 49 | )) 50 | self.output_layers = nn.ModuleList(output_layers) 51 | 52 | def decode(self, 53 | world_matrix: Tensor, 54 | lod: int = 0, 55 | alpha: float = 0.0, 56 | **kwargs) -> Tensor: 57 | x = self.decoder_input(world_matrix) 58 | x = x.view(x.shape[0], self.hidden_dims[0], 2, 2) 59 | x = self.decoder(x) 60 | x = self.initial_output(x) 61 | x = x.view(x.shape[0], 3, 4, 4) 62 | for i, layer in enumerate(self.output_layers[:lod]): 63 | a = nn.Upsample(scale_factor=2)(x) 64 | b = layer(a) 65 | x = b if i < lod - 1 else a.lerp(b, alpha) 66 | return x 67 | -------------------------------------------------------------------------------- /src/models/resnet_localizer2d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Size, Tensor, nn 3 | from typing import List 4 | 5 | from .localizer import Localizer 6 | from .resnet2d import BasicBlock2d 7 | from .util import get_activation, get_pooling2d 8 | 9 | 10 | class ResNetLocalizer2d(Localizer): 11 | 12 | def __init__(self, 13 | name: str, 14 | hidden_dims: List[int], 15 | input_shape: Size, 16 | pooling: str = None, 17 | batch_normalize: bool = False, 18 | output_activation: str = 'sigmoid') -> None: 19 | super().__init__(name=name) 20 | self.width = input_shape[2] 21 | self.height = input_shape[1] 22 | self.channels = input_shape[0] 23 | self.batch_normalize = batch_normalize 24 | self.hidden_dims = hidden_dims.copy() 25 | if pooling is not None: 26 | pool_fn = get_pooling2d(pooling) 27 | modules = [] 28 | in_features = self.channels 29 | for h_dim 
in hidden_dims: 30 | modules.append(BasicBlock2d(in_features, h_dim)) 31 | if pooling is not None: 32 | modules.append(pool_fn(2)) 33 | in_features = h_dim 34 | self.layers = nn.Sequential( 35 | *modules, 36 | nn.Flatten(), 37 | ) 38 | in_features = hidden_dims[-1] * self.width * self.height 39 | if pooling is not None: 40 | in_features /= 4**len(hidden_dims) 41 | if abs(in_features - ceil(in_features)) > 0: 42 | raise ValueError( 43 | 'noninteger number of features - perhaps there is too much pooling?' 44 | ) 45 | in_features = int(in_features) 46 | self.activation = get_activation(output_activation) 47 | self.prediction = nn.Linear(in_features, 4) 48 | self.output = nn.Sequential( 49 | nn.BatchNorm1d(4), 50 | self.activation, 51 | ) if batch_normalize else self.activation 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | x = self.layers(x) 55 | x = self.prediction(x) 56 | x = self.output(x) 57 | # We are going to enforce the invariant x1 <= x2 && y1 <= y2 by 58 | # predicting the width and height instead of x2 and y2 directly. 59 | # Here we change it back to the original format. 60 | ax = x.clone() 61 | ax[:, 2] = x[:, 0] + x[:, 2] 62 | ax[:, 3] = x[:, 1] + x[:, 3] 63 | return ax 64 | -------------------------------------------------------------------------------- /src/verify_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from multiprocessing import cpu_count 4 | from torch.utils.data import DataLoader 5 | 6 | from tqdm import tqdm 7 | 8 | from dataset import get_dataset 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='test dataset') 12 | parser.add_argument('--dataset', 13 | dest="dataset", 14 | metavar='DATASET', 15 | help='dataset name', 16 | default='deeplesion') 17 | parser.add_argument('--root', 18 | dest="root", 19 | metavar='ROOT', 20 | help='dataset root dir', 21 | default=None) 22 | parser.add_argument('--num-workers', 23 | dest="num_workers", 24 | metavar='NUM_WORKERS', 25 | help='number of worker processes for the data loader', 26 | default=cpu_count()) 27 | args = parser.parse_args() 28 | 29 | opts = { 30 | # 'rsna-intracranial': { 31 | # 'root': args.root or 'E:/rsna-intracranial', 32 | # 'download': False, 33 | # 'use_gzip': False, 34 | # }, 35 | 'rsna-intracranial': { 36 | 'root': '/data/rsna-ich', 37 | 'download': True, 38 | 'use_gzip': True, 39 | }, 40 | 'deeplesion': { 41 | 'root': args.root or '/data/deeplesion', 42 | 'download': True, 43 | }, 44 | } 45 | 46 | print(f'Verifying {args.dataset} with {args.num_workers} workers') 47 | 48 | for train in [False, True]: 49 | ds = get_dataset(args.dataset, 50 | opts[args.dataset], 51 | train=train, 52 | safe=False) 53 | loader = DataLoader(ds, 54 | batch_size=1, 55 | num_workers=args.num_workers, 56 | pin_memory=False) 57 | n = len(loader) 58 | it = iter(loader) 59 | bad_indices = [] 60 | for i in tqdm(range(n)): 61 | try: 62 | batch = next(it) 63 | except KeyboardInterrupt: 64 | raise 65 | except: 66 | bad_indices.append(i) 67 | print(f'Encountered bad index ({i}):') 68 | print(sys.exc_info()) 69 | print('Bad indices:') 70 | print(bad_indices) 71 | -------------------------------------------------------------------------------- /src/models/resnet_classifier2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import ceil 3 | from torch import Size, Tensor, nn 4 | from typing import List 5 | 6 | from .classifier import Classifier 7 | 
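# ResNetClassifier2d stacks BasicBlock2d stages (optionally interleaved with pooling), flattens, applies dropout, and finishes with Linear -> BatchNorm1d -> Sigmoid so each of the num_classes outputs is an independent probability; load_weights copies matching layers from a checkpoint saved under the 'classifier.' prefix.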
from .resnet2d import BasicBlock2d 8 | from .util import get_pooling2d 9 | 10 | 11 | class ResNetClassifier2d(Classifier): 12 | 13 | def __init__(self, 14 | name: str, 15 | hidden_dims: List[int], 16 | input_shape: Size, 17 | num_classes: int, 18 | load_weights: str = None, 19 | dropout: float = 0.4, 20 | pooling: str = None) -> None: 21 | super().__init__(name=name, num_classes=num_classes) 22 | self.width = input_shape[2] 23 | self.height = input_shape[1] 24 | self.channels = input_shape[0] 25 | self.hidden_dims = hidden_dims.copy() 26 | if pooling is not None: 27 | pool_fn = get_pooling2d(pooling) 28 | modules = [] 29 | in_features = self.channels 30 | for h_dim in hidden_dims: 31 | modules.append(BasicBlock2d(in_features, h_dim)) 32 | if pooling is not None: 33 | modules.append(pool_fn(2)) 34 | in_features = h_dim 35 | self.layers = nn.Sequential( 36 | *modules, 37 | nn.Flatten(), 38 | nn.Dropout(p=dropout), 39 | ) 40 | in_features = hidden_dims[-1] * self.width * self.height 41 | if pooling is not None: 42 | in_features /= 4**len(hidden_dims) 43 | if abs(in_features - ceil(in_features)) > 0: 44 | raise ValueError( 45 | 'noninteger number of features - perhaps there is too much pooling?' 46 | ) 47 | in_features = int(in_features) 48 | self.output = nn.Sequential( 49 | nn.Linear(in_features, num_classes), 50 | nn.BatchNorm1d(num_classes), 51 | nn.Sigmoid(), 52 | ) 53 | if load_weights is not None: 54 | new = self.state_dict() 55 | old = torch.load(load_weights)['state_dict'] 56 | for k, v in new.items(): 57 | ok = f'classifier.{k}' 58 | if ok in old: 59 | new[k] = old[ok].cpu() 60 | print(f'Loaded weights for layer {k}') 61 | self.load_state_dict(new) 62 | 63 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 64 | y = self.layers(input) 65 | y = self.output(y) 66 | return y 67 | -------------------------------------------------------------------------------- /src/models/resnet_classifier1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import ceil 3 | from torch import Size, Tensor, nn 4 | from typing import List 5 | 6 | from .classifier import Classifier 7 | from .resnet1d import BasicBlock1d 8 | from .util import get_pooling1d 9 | 10 | 11 | class ResNetClassifier1d(Classifier): 12 | 13 | def __init__(self, 14 | name: str, 15 | hidden_dims: List[int], 16 | input_shape: Size, 17 | num_classes: int, 18 | load_weights: str = None, 19 | dropout: float = 0.0, 20 | pooling: str = None, 21 | kernel_size: int = 3, 22 | padding: int = 1, 23 | logits_only: bool = False) -> None: 24 | super().__init__(name=name, num_classes=num_classes) 25 | self.num_samples = input_shape[1] 26 | self.channels = input_shape[0] 27 | self.dropout = nn.Dropout(dropout) 28 | self.hidden_dims = hidden_dims.copy() 29 | self.logits_only = logits_only 30 | if pooling is not None: 31 | pool_fn = get_pooling1d(pooling) 32 | modules = [] 33 | in_features = self.channels 34 | for h_dim in hidden_dims: 35 | modules.append( 36 | BasicBlock1d(in_features, 37 | h_dim, 38 | kernel_size=kernel_size, 39 | padding=padding)) 40 | if pooling is not None: 41 | modules.append(pool_fn(2)) 42 | in_features = h_dim 43 | self.layers = nn.Sequential(*modules) 44 | in_features = hidden_dims[-1] * self.num_samples 45 | if pooling is not None: 46 | in_features /= 2**len(hidden_dims) 47 | if abs(in_features - ceil(in_features)) > 0: 48 | raise ValueError( 49 | 'noninteger number of features - perhaps there is too much pooling?' 
50 | ) 51 | in_features = int(in_features) 52 | self.output = nn.Linear(in_features, num_classes) 53 | if load_weights is not None: 54 | new = self.state_dict() 55 | old = torch.load(load_weights)['state_dict'] 56 | for k, v in new.items(): 57 | ok = f'model.{k}' 58 | if ok in old: 59 | new[k] = old[ok].cpu() 60 | print(f'Loaded weights for layer {k}') 61 | self.load_state_dict(new) 62 | 63 | def forward(self, x: Tensor) -> Tensor: 64 | x = self.layers(x) 65 | if self.logits_only: 66 | return x 67 | x = x.reshape(x.shape[0], -1) 68 | x = self.dropout(x) 69 | x = self.output(x) 70 | x = torch.sigmoid(x) 71 | return x 72 | -------------------------------------------------------------------------------- /src/resize-video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import time 5 | 6 | parser = argparse.ArgumentParser(description='resize videos') 7 | parser.add_argument('--input', 8 | '-i', 9 | dest="input", 10 | metavar='INPUT', 11 | help='path to dir with mp4 files', 12 | default='E:/doom') 13 | parser.add_argument('--output', 14 | '-o', 15 | dest="output", 16 | metavar='OUTPUT', 17 | help='output dir', 18 | default='E:/doom-processed') 19 | parser.add_argument('--width', 20 | dest="width", 21 | metavar='WIDTH', 22 | help='output resolution x', 23 | default=320) 24 | parser.add_argument('--height', 25 | dest="height", 26 | metavar='HEIGHT', 27 | help='output resolution y', 28 | default=240) 29 | parser.add_argument('--skip-frames', 30 | dest="skip_frames", 31 | metavar='SKIP_FRAMES', 32 | help='number of frames to skip', 33 | default=1) 34 | args = parser.parse_args() 35 | denom = args.skip_frames + 1 36 | 37 | if not os.path.exists(args.output): 38 | os.makedirs(args.output) 39 | 40 | files = sorted([f for f in os.listdir(args.input) if f.endswith('.mp4')]) 41 | print(f'Processing {len(files)} files') 42 | total_in = 0 43 | total_out = 0 44 | for i, file in enumerate(files): 45 | start = time.time() 46 | input = os.path.join(args.input, file).replace('\\', '/') 47 | output = os.path.join(args.output, file).replace('\\', '/') 48 | in_size = os.path.getsize(input) 49 | total_in += in_size 50 | in_size //= 1000 * 1000 51 | print(f'[{i+1}/{len(files)}] Processing {input} ({in_size} MiB)') 52 | cmd = f"ffmpeg -i $(wslpath {input}) -s {args.width}x{args.height} -y -c:a copy -an -vf select='not(mod(n\\,{denom})), setpts={1.0/denom}*PTS' $(wslpath {output})" 53 | proc = subprocess.run(['bash', '-c', cmd], capture_output=True) 54 | if proc.returncode != 0: 55 | msg = 'expected exit code 0 from ffmpeg, got exit code {}: {}'.format( 56 | proc.returncode, proc.stdout.decode('unicode_escape')) 57 | if proc.stderr: 58 | msg += ' ' + proc.stderr.decode('unicode_escape') 59 | raise ValueError(msg) 60 | delta = time.time() - start 61 | out_size = os.path.getsize(output) 62 | total_out += out_size 63 | out_size //= 1000 * 1000 64 | pct = (1.0 - out_size / in_size) * 100 65 | print( 66 | f'[{i+1}/{len(files)}] Wrote {output} in {delta} seconds ({out_size} MiB, {int(pct)}% reduction)' 67 | ) 68 | reduction = int((1.0 - total_out / total_in) * 100) 69 | print(f'Success, total reduction of {reduction}%') 70 | -------------------------------------------------------------------------------- /src/models/resnet2d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | 
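# 3x3 convolution that preserves spatial size at stride 1 (padding=1); bias is omitted because every conv here is immediately followed by BatchNorm2d.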
return nn.Conv2d(in_planes, 7 | out_planes, 8 | kernel_size=3, 9 | stride=stride, 10 | padding=1, 11 | bias=False) 12 | 13 | 14 | class BasicBlock2d(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock2d, self).__init__() 19 | self.conv1 = conv3x3(in_planes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, 27 | self.expansion * planes, 28 | kernel_size=1, 29 | stride=stride, 30 | bias=False), nn.BatchNorm2d(self.expansion * planes)) 31 | else: 32 | self.shortcut = None 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | if self.shortcut is not None: 38 | out += self.shortcut(x) 39 | out = F.relu(out) 40 | return out 41 | 42 | 43 | class TransposeBasicBlock2d(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(TransposeBasicBlock2d, self).__init__() 48 | self.conv1 = nn.ConvTranspose2d(in_planes, 49 | planes, 50 | kernel_size=3, 51 | stride=stride, 52 | padding=1, 53 | bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.ConvTranspose2d(planes, 56 | planes, 57 | kernel_size=3, 58 | stride=1, 59 | padding=1, 60 | bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | 63 | if stride != 1 or in_planes != self.expansion * planes: 64 | self.shortcut = nn.Sequential( 65 | nn.ConvTranspose2d(in_planes, 66 | self.expansion * planes, 67 | kernel_size=1, 68 | stride=stride, 69 | bias=False), nn.BatchNorm2d(planes)) 70 | else: 71 | self.shortcut = None 72 | 73 | def forward(self, x): 74 | out = F.relu(self.bn1(self.conv1(x))) 75 | out = self.bn2(self.conv2(out)) 76 | if self.shortcut is not None: 77 | out += self.shortcut(x) 78 | out = F.relu(out) 79 | return out 80 | -------------------------------------------------------------------------------- /src/models/resnet1d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BasicBlock1d(nn.Module): 6 | 7 | def __init__(self, in_planes, planes, stride=1, kernel_size=3, padding=1): 8 | super(BasicBlock1d, self).__init__() 9 | self.conv1 = nn.Conv1d(in_planes, 10 | planes, 11 | kernel_size=kernel_size, 12 | stride=stride, 13 | padding=padding, 14 | bias=False) 15 | self.bn1 = nn.BatchNorm1d(planes) 16 | self.conv2 = nn.Conv1d(planes, 17 | planes, 18 | kernel_size=kernel_size, 19 | stride=1, 20 | padding=padding, 21 | bias=False) 22 | self.bn2 = nn.BatchNorm1d(planes) 23 | 24 | if stride != 1 or in_planes != planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv1d(in_planes, 27 | planes, 28 | kernel_size=1, 29 | stride=stride, 30 | bias=False), nn.BatchNorm1d(planes)) 31 | else: 32 | self.shortcut = None 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | if self.shortcut is not None: 38 | out += self.shortcut(x) 39 | out = F.relu(out) 40 | return out 41 | 42 | 43 | class TransposeBasicBlock1d(nn.Module): 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(TransposeBasicBlock1d, self).__init__() 47 | self.conv1 = nn.ConvTranspose1d(in_planes, 48 | planes, 49 | kernel_size=3, 50 | stride=stride, 51 | padding=1, 52 | bias=False) 53 | self.bn1 = nn.BatchNorm1d(planes) 
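# Only the first transposed conv may change resolution (via stride); the second runs at stride 1 to refine features, mirroring BasicBlock1d above.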
54 | self.conv2 = nn.ConvTranspose1d(planes, 55 | planes, 56 | kernel_size=3, 57 | stride=1, 58 | padding=1, 59 | bias=False) 60 | self.bn2 = nn.BatchNorm1d(planes) 61 | 62 | if stride != 1 or in_planes != planes: 63 | self.shortcut = nn.Sequential( 64 | nn.ConvTranspose1d(in_planes, 65 | planes, 66 | kernel_size=1, 67 | stride=stride, 68 | bias=False), nn.BatchNorm1d(planes)) 69 | else: 70 | self.shortcut = None 71 | 72 | def forward(self, x): 73 | out = F.relu(self.bn1(self.conv1(x))) 74 | out = self.bn2(self.conv2(out)) 75 | if self.shortcut is not None: 76 | out += self.shortcut(x) 77 | out = F.relu(out) 78 | return out 79 | -------------------------------------------------------------------------------- /src/augmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.parameter import Parameter 5 | from torch.utils.data import Dataset 6 | from typing import Iterator 7 | 8 | from base_experiment import BaseExperiment 9 | from dataset import get_example_shape 10 | from models import create_model 11 | from plot import get_plot_fn 12 | 13 | 14 | class AugmentationExperiment(BaseExperiment): 15 | 16 | def __init__(self, config: dict, enable_tune: bool = False, **kwargs): 17 | super().__init__(config=config, enable_tune=enable_tune, **kwargs) 18 | input_shape = get_example_shape(config['exp_params']['data']) 19 | self.constraint = create_model(**config['constraint_params'], 20 | input_shape=input_shape) 21 | self.constraint.requires_grad = False 22 | self.constraint.eval() 23 | self.model = create_model(**config['model_params'], 24 | input_shape=input_shape) 25 | 26 | def sample_images(self, plot: dict, batch: Tensor): 27 | transformed = self.model(batch) 28 | out_path = os.path.join( 29 | self.logger.save_dir, self.logger.name, 30 | f"version_{self.logger.version}", 31 | f"{self.logger.name}_{plot['fn']}_{self.global_step}") 32 | fn = get_plot_fn(plot['fn']) 33 | fn(x=batch, 34 | y=transformed, 35 | out_path=out_path, 36 | vis=self.visdom(), 37 | **plot['params']) 38 | 39 | def get_val_batches(self, dataset: Dataset) -> list: 40 | val_batches = [] 41 | n = len(dataset) 42 | for plot in self.plots: 43 | indices = torch.randint(low=0, 44 | high=n, 45 | size=(plot['batch_size'], 1)).squeeze() 46 | batch = [dataset[i][0] for i in indices] 47 | val_batches.append(batch) 48 | return val_batches 49 | 50 | def training_step(self, batch, batch_idx): 51 | real_img, _ = batch 52 | self.curr_device = self.device 53 | real_img = real_img.to(self.curr_device) 54 | train_loss = self.model.loss_function(real_img, 55 | constraint=self.constraint, 56 | **self.params.get( 57 | 'loss_params', {})) 58 | self.log_train_step(train_loss) 59 | return train_loss 60 | 61 | def validation_step(self, batch, batch_idx): 62 | real_img, _ = batch 63 | self.curr_device = self.device 64 | real_img = real_img.to(self.curr_device) 65 | val_loss = self.model.loss_function(real_img, 66 | constraint=self.constraint, 67 | **self.params.get( 68 | 'loss_params', {})) 69 | self.log_val_step(val_loss) 70 | return val_loss 71 | 72 | def trainable_parameters(self) -> Iterator[Parameter]: 73 | return self.model.parameters() 74 | -------------------------------------------------------------------------------- /src/draw_boxes.py: -------------------------------------------------------------------------------- 1 | # Source: https://www.kaggle.com/kmader/deeplesion-overview 2 | import os 3 | 4 | import matplotlib.pyplot as plt 
5 | import numpy as np 6 | import pandas as pd 7 | from matplotlib.patches import Rectangle 8 | from skimage.io import imread 9 | from skimage.transform import resize 10 | 11 | 12 | def read_hu(x): 13 | return resize(imread(x).astype(np.float32) - 32768, (512, 512)) 14 | 15 | 16 | def create_boxes(in_row): 17 | box_list = [] 18 | for (start_x, start_y, end_x, end_y) in in_row['bbox']: 19 | box_list += [ 20 | Rectangle((start_x, start_y), np.abs(end_x - start_x), 21 | np.abs(end_y - start_y)) 22 | ] 23 | return box_list 24 | 25 | 26 | def create_segmentation(in_img, in_row): 27 | yy, xx = np.meshgrid(range(in_img.shape[0]), 28 | range(in_img.shape[1]), 29 | indexing='ij') 30 | out_seg = np.zeros_like(in_img) 31 | for (start_x, start_y, end_x, end_y) in in_row['bbox']: 32 | c_seg = (xx < end_x) & (xx > start_x) & (yy < end_y) & (yy > start_y) 33 | out_seg += c_seg 34 | return np.clip(out_seg, 0, 1).astype(np.float32) 35 | 36 | 37 | root = 'E:\\deeplesion' 38 | base_img_dir = os.path.join(root, 'Images_png') 39 | patient_df = pd.read_csv(os.path.join(root, 'DL_info.csv')) 40 | patient_df = patient_df[:99] 41 | patient_df['bbox'] = patient_df['Bounding_boxes'].map( 42 | lambda x: np.reshape([float(y) for y in x.split(',')], (-1, 4))) 43 | patient_df['kaggle_path'] = patient_df.apply( 44 | lambda c_row: os.path.join( 45 | base_img_dir, '{Patient_index:06d}_{Study_index:02d}_{Series_ID:02d}'. 46 | format(**c_row), '{Key_slice_index:03d}.png'.format(**c_row)), 1) 47 | """ 48 | _, test_row = next(patient_df.sample(1, random_state=0).iterrows()) 49 | fig, ax1 = plt.subplots(1, 1, figsize=(10, 10)) 50 | c_img = read_hu(test_row['kaggle_path']) 51 | ax1.imshow(c_img, vmin=-1200, vmax=600, cmap='gray') 52 | ax1.add_collection(PatchCollection( 53 | create_boxes(test_row), alpha=0.25, facecolor='red')) 54 | ax1.set_title('{Patient_age}-{Patient_gender}'.format(**test_row)) 55 | """ 56 | 57 | 58 | def apply_softwindow(x): 59 | return (255 * plt.cm.gray(0.5 * np.clip((x - 50) / 350, -1, 1) + 60 | 0.5)[:, :, :3]).astype(np.uint8) 61 | 62 | 63 | fig, m_axs = plt.subplots(3, 1, figsize=(10, 15)) 64 | 65 | for ax1, (_, c_row) in zip(m_axs, 66 | patient_df.sample(50, random_state=0).iterrows()): 67 | c_img = read_hu(c_row['kaggle_path']) 68 | #ax1.imshow(c_img, vmin=-1200, vmax=600, cmap='gray') 69 | #ax1.add_collection(PatchCollection( 70 | # create_boxes(c_row), alpha=0.25, facecolor='red')) 71 | #ax1.set_title('{Patient_age}-{Patient_gender}'.format(**c_row)) 72 | #ax1.axis('off') 73 | c_segs = create_segmentation(c_img, c_row).astype(int) 74 | #ax1.imshow(mark_boundaries(image=apply_softwindow(c_img), 75 | # label_img=c_segs, 76 | # color=(0, 1, 0), 77 | # mode='thick')) 78 | ax1.imshow(apply_softwindow(c_img)) 79 | ax1.set_title('Segmentation Map') 80 | 81 | plt.show() 82 | -------------------------------------------------------------------------------- /src/models/resnet3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BasicBlock3d(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, in_planes, planes, stride=1): 9 | super(BasicBlock3d, self).__init__() 10 | self.conv1 = nn.Conv3d(in_planes, 11 | planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=1, 15 | bias=False) 16 | self.bn1 = nn.BatchNorm3d(planes) 17 | self.conv2 = nn.Conv3d(planes, 18 | planes, 19 | kernel_size=3, 20 | stride=1, 21 | padding=1, 22 | bias=False) 23 | self.bn2 = nn.BatchNorm3d(planes) 24 | 25 | if 
stride != 1 or in_planes != self.expansion * planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv3d(in_planes, 28 | self.expansion * planes, 29 | kernel_size=1, 30 | stride=stride, 31 | bias=False), nn.BatchNorm3d(self.expansion * planes)) 32 | else: 33 | self.shortcut = None 34 | 35 | def forward(self, x): 36 | out = F.relu(self.bn1(self.conv1(x))) 37 | out = self.bn2(self.conv2(out)) 38 | if self.shortcut is not None: 39 | out += self.shortcut(x) 40 | out = F.relu(out) 41 | return out 42 | 43 | 44 | class TransposeBasicBlock3d(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, in_planes, planes, stride=1): 48 | super(TransposeBasicBlock3d, self).__init__() 49 | self.conv1 = nn.ConvTranspose3d(in_planes, 50 | planes, 51 | kernel_size=3, 52 | stride=stride, 53 | padding=1, 54 | bias=False) 55 | self.bn1 = nn.BatchNorm3d(planes) 56 | self.conv2 = nn.ConvTranspose3d(planes, 57 | planes, 58 | kernel_size=3, 59 | stride=1, 60 | padding=1, 61 | bias=False) 62 | self.bn2 = nn.BatchNorm3d(planes) 63 | 64 | if stride != 1 or in_planes != self.expansion * planes: 65 | self.shortcut = nn.Sequential( 66 | nn.ConvTranspose3d(in_planes, 67 | self.expansion * planes, 68 | kernel_size=1, 69 | stride=stride, 70 | bias=False), nn.BatchNorm3d(planes)) 71 | else: 72 | self.shortcut = None 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = self.bn2(self.conv2(out)) 77 | if self.shortcut is not None: 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | return out 81 | -------------------------------------------------------------------------------- /src/dataset/forrestgump_converter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from math import ceil 4 | 5 | import nilearn as nl 6 | import nilearn.plotting 7 | import numpy as np 8 | 9 | 10 | def convert_forrest_gump(root: str, 11 | alignment: str = 'raw', 12 | max_chunk_samples: int = 16): 13 | if alignment == 'raw': 14 | data_dir = root 15 | identifier = 'acq-raw' 16 | elif alignment == 'linear': 17 | data_dir = os.path.join(root, 'derivatives', 18 | 'linear_anatomical_alignment') 19 | identifier = 'rec-dico7Tad2grpbold7Tad' 20 | elif alignment == 'nonlinear': 21 | data_dir = os.path.join(root, 'derivatives', 22 | 'non-linear_anatomical_alignment') 23 | identifier = 'rec-dico7Tad2grpbold7TadNL' 24 | else: 25 | raise ValueError(f"unknown alignment value '{alignment}'") 26 | subjects = [ 27 | f for f in os.listdir(data_dir) 28 | if f.startswith('sub-') and len(f) == 6 and int(f[len('sub-'):]) <= 20 29 | ] 30 | num_frames = 3599 31 | out_dir = os.path.join(root, 'converted', alignment) 32 | try: 33 | os.makedirs(out_dir) 34 | except: 35 | pass 36 | metadata = { 37 | 'alignment': alignment, 38 | 'max_chunk_samples': max_chunk_samples, 39 | 'subjects': {}, 40 | } 41 | for subject in subjects: 42 | print(f'Converting {alignment}/{subject}') 43 | subj_no = int(subject[4:]) - 1 44 | frame_no = 0 45 | frames = None 46 | for run in range(8): 47 | filename = f'{subject}_ses-forrestgump_task-forrestgump_{identifier}_run-0{run+1}_bold.nii.gz' 48 | filename = os.path.join(data_dir, subject, 'ses-forrestgump', 49 | 'func', filename) 50 | img = nl.image.load_img(filename) 51 | img = img.get_data() 52 | img = np.transpose(img, (3, 2, 0, 1)) 53 | frames = img if frames is None else np.concatenate( 54 | (frames, img), axis=0) 55 | if frames.shape[0] != num_frames: 56 | print( 57 | f'WARNING: {subject} has {len(frames)} frames, expected {num_frames}' 58 | ) 
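# Write this subject's concatenated run to disk as .npy chunks of at most max_chunk_samples frames; per-chunk lengths are recorded in metadata.json so a reader can index frames without loading the whole run.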
59 | try: 60 | os.makedirs(os.path.join(out_dir, subject)) 61 | except: 62 | pass 63 | num_chunks = ceil(frames.shape[0] / max_chunk_samples) 64 | chunks = [] 65 | for i in range(num_chunks): 66 | a = max_chunk_samples * i 67 | b = min(a + max_chunk_samples, frames.shape[0]) 68 | out_path = os.path.join(out_dir, subject, subject + f'_{i}') 69 | chunk = frames[a:b, ...] 70 | chunks.append(chunk.shape[0]) 71 | np.save(out_path, chunk) 72 | metadata['subjects'][subject] = { 73 | 'num_frames': frames.shape[0], 74 | 'chunks': chunks, 75 | } 76 | with open(os.path.join(out_dir, 'metadata.json'), 'w') as f: 77 | f.write(json.dumps(metadata)) 78 | print('Finished converting') 79 | 80 | 81 | if __name__ == '__main__': 82 | convert_forrest_gump('/data/openneuro/ds000113-download', 83 | alignment='linear') 84 | -------------------------------------------------------------------------------- /src/dataset/cq500.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | import boto3 6 | import pydicom 7 | 8 | CQ500_HEADER = 'name,Category,R1:ICH,R1:IPH,R1:IVH,R1:SDH,R1:EDH,R1:SAH,R1:BleedLocation-Left,R1:BleedLocation-Right,R1:ChronicBleed,R1:Fracture,R1:CalvarialFracture,R1:OtherFracture,R1:MassEffect,R1:MidlineShift,R2:ICH,R2:IPH,R2:IVH,R2:SDH,R2:EDH,R2:SAH,R2:BleedLocation-Left,R2:BleedLocation-Right,R2:ChronicBleed,R2:Fracture,R2:CalvarialFracture,R2:OtherFracture,R2:MassEffect,R2:MidlineShift,R3:ICH,R3:IPH,R3:IVH,R3:SDH,R3:EDH,R3:SAH,R3:BleedLocation-Left,R3:BleedLocation-Right,R3:ChronicBleed,R3:Fracture,R3:CalvarialFracture,R3:OtherFracture,R3:MassEffect,R3:MidlineShift\n' 9 | 10 | 11 | def load_labels(path: str) -> dict: 12 | labels = {} 13 | with open(path, 'r') as f: 14 | hdr = f.readline() 15 | if hdr != CQ500_HEADER: 16 | raise ValueError('bad header') 17 | for line in f: 18 | parts = line.strip().split(',') 19 | idx = parts[0][9:] 20 | labels[idx] = torch.Tensor([int(b) for b in parts[2:]]) 21 | return labels 22 | 23 | 24 | class CQ500Dataset(data.Dataset): 25 | 26 | def __init__(self, 27 | root: str, 28 | download: bool = True, 29 | use_gzip: bool = True, 30 | s3_bucket: str = 'cq500', 31 | s3_endpoint: str = 'https://nyc3.digitaloceanspaces.com', 32 | delete_after_use: bool = False): 33 | super().__init__() 34 | self.root = root 35 | self.download = download 36 | self.use_gzip = use_gzip 37 | self.s3_bucket = s3_bucket 38 | self.s3_endpoint = s3_endpoint 39 | labels_csv_path = os.path.join(root, 'reads.csv') 40 | if download: 41 | if not os.path.exists(root): 42 | os.makedirs(root) 43 | if not os.path.exists(labels_csv_path) or os.path.getsize( 44 | labels_csv_path) == 0: 45 | s3 = boto3.resource('s3', endpoint_url=s3_endpoint) 46 | bucket = s3.Bucket(s3_bucket) 47 | with open(labels_csv_path, 'w') as f: 48 | obj = bucket.Object('reads.csv') 49 | obj.download_fileobj(f) 50 | elif not os.path.exists(labels_csv_path): 51 | raise ValueError( 52 | f'with download == False, {labels_csv_path} does not exist') 53 | self.labels = load_labels(labels_csv_path) 54 | 55 | def __getitem__(self, index): 56 | path = '' 57 | if self.use_gzip: 58 | path += '.gz' 59 | if not os.path.exists(path) or os.path.getsize(path) == 0: 60 | key = '' 61 | s3 = boto3.resource('s3', endpoint_url=self.s3_endpoint) 62 | bucket = s3.Bucket(self.s3_bucket) 63 | with open(path, 'wb') as f: 64 | obj = bucket.Object(key) 65 | obj.download_fileobj(f) 66 | 67 | ds = pydicom.dcmread(self.files[index], stop_before_pixels=False) 68 
| data = raw_dicom_pixels(ds) 69 | return (data, []) 70 | 71 | def __len__(self): 72 | return len(self.labels) 73 | 74 | 75 | if __name__ == '__main__': 76 | ds = CQ500Dataset('E:/cq500') 77 | print(ds[0].shape) 78 | print(ds[1].shape) 79 | -------------------------------------------------------------------------------- /src/models/resnet_vae1d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Tensor, nn 3 | from typing import List 4 | 5 | from .base import BaseVAE 6 | from .resnet1d import BasicBlock1d, TransposeBasicBlock1d 7 | from .util import get_activation, get_pooling1d 8 | 9 | 10 | class ResNetVAE1d(BaseVAE): 11 | 12 | def __init__(self, 13 | name: str, 14 | latent_dim: int, 15 | hidden_dims: List[int], 16 | num_samples: int, 17 | channels: int, 18 | dropout: float = 0.4, 19 | pooling: str = None, 20 | output_activation: str = 'tanh') -> None: 21 | super(ResNetVAE1d, self).__init__(name=name, latent_dim=latent_dim) 22 | self.num_samples = num_samples 23 | self.channels = channels 24 | self.hidden_dims = hidden_dims.copy() 25 | 26 | if pooling is not None: 27 | pool_fn = get_pooling1d(pooling) 28 | 29 | # Encoder 30 | modules = [] 31 | in_features = channels 32 | for h_dim in hidden_dims: 33 | modules.append(BasicBlock1d(in_features, h_dim)) 34 | if pooling is not None: 35 | modules.append(pool_fn(2)) 36 | in_features = h_dim 37 | self.encoder = nn.Sequential( 38 | *modules, 39 | nn.Flatten(), 40 | nn.Dropout(p=dropout), 41 | ) 42 | in_features = hidden_dims[-1] * num_samples 43 | if pooling is not None: 44 | in_features /= 2**len(hidden_dims) 45 | if abs(in_features - ceil(in_features)) > 0: 46 | raise ValueError( 47 | 'noninteger number of features - perhaps there is too much pooling?' 
48 | ) 49 | in_features = int(in_features) 50 | self.mu = nn.Sequential( 51 | nn.Linear(in_features, latent_dim), 52 | nn.BatchNorm1d(latent_dim), 53 | nn.ReLU(), 54 | ) 55 | self.var = nn.Sequential( 56 | nn.Linear(in_features, latent_dim), 57 | nn.BatchNorm1d(latent_dim), 58 | nn.ReLU(), 59 | ) 60 | 61 | hidden_dims.reverse() 62 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[0]) 63 | modules = [] 64 | in_features = hidden_dims[0] 65 | for h_dim in hidden_dims: 66 | modules.append(TransposeBasicBlock1d(in_features, h_dim)) 67 | in_features = h_dim 68 | self.decoder = nn.Sequential( 69 | *modules, 70 | nn.Conv1d(hidden_dims[-1], 71 | num_samples * channels, 72 | kernel_size=3, 73 | padding=1), 74 | get_activation(output_activation), 75 | ) 76 | 77 | def encode(self, input: Tensor) -> List[Tensor]: 78 | if input.shape[-2:] != (self.channels, self.num_samples): 79 | raise ValueError('wrong input shape') 80 | x = self.encoder(input) 81 | mu = self.mu(x) 82 | var = self.var(x) 83 | return [mu, var] 84 | 85 | def decode(self, z: Tensor, **kwargs) -> Tensor: 86 | x = self.decoder_input(z) 87 | x = x.view(x.shape[0], self.hidden_dims[-1], -1) 88 | x = self.decoder(x) 89 | x = x.view(x.shape[0], self.channels, self.num_samples) 90 | return x 91 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | from typing import Any, List 6 | 7 | from .util import reparameterize 8 | 9 | 10 | class BaseVAE(nn.Module): 11 | 12 | def __init__(self, name: str, latent_dim: int) -> None: 13 | super(BaseVAE, self).__init__() 14 | self.name = name 15 | self.latent_dim = latent_dim 16 | 17 | @abstractmethod 18 | def encode(self, input: Tensor) -> List[Tensor]: 19 | raise NotImplementedError 20 | 21 | @abstractmethod 22 | def decode(self, input: Tensor, **kwargs) -> Any: 23 | raise NotImplementedError 24 | 25 | def get_sandwich_layers(self) -> List[nn.Module]: 26 | raise NotImplementedError 27 | 28 | @abstractmethod 29 | def get_encoder(self) -> List[nn.Module]: 30 | raise NotImplementedError 31 | 32 | def forward(self, x: Tensor, **kwargs) -> List[Tensor]: 33 | mu, log_var = self.encode(x) 34 | z = reparameterize(mu, log_var) 35 | y = self.decode(z, **kwargs) 36 | return [y, x, mu, log_var, z] 37 | 38 | def sample(self, num_samples: int, current_device: int, 39 | **kwargs) -> Tensor: 40 | """ 41 | Samples from the latent space and return the corresponding 42 | image space map. 
43 | :param num_samples: (Int) Number of samples 44 | :param current_device: (Int) Device to run the model 45 | :return: (Tensor) 46 | """ 47 | z = torch.randn(num_samples, self.latent_dim) 48 | z = z.to(current_device) 49 | samples = self.decode(z) 50 | return samples 51 | 52 | def generate(self, x: Tensor, **kwargs) -> Tensor: 53 | """ 54 | Given an input image x, returns the reconstructed image 55 | :param x: (Tensor) [B x C x H x W] 56 | :return: (Tensor) [B x C x H x W] 57 | """ 58 | 59 | return self.forward(x)[0] 60 | 61 | def loss_function(self, 62 | recons: Tensor, 63 | input: Tensor, 64 | mu: Tensor, 65 | log_var: Tensor, 66 | z: Tensor, 67 | objective: str = 'default', 68 | beta: float = 1.0, 69 | gamma: float = 1.0, 70 | target_capacity: float = 25.0, 71 | **kwargs) -> dict: 72 | recons_loss = F.mse_loss(recons, input) 73 | 74 | result = {'loss': recons_loss, 'Reconstruction_Loss': recons_loss} 75 | 76 | kld_loss = torch.mean( 77 | -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), 78 | dim=0) 79 | result['KLD_Loss'] = kld_loss 80 | 81 | if objective == 'default': 82 | # O.G. beta loss term applied directly to KLD 83 | result['loss'] += beta * kld_loss 84 | elif objective == 'controlled_capacity': 85 | # Use controlled capacity increase from 86 | # https://arxiv.org/pdf/1804.03599.pdf 87 | capacity_loss = torch.abs(kld_loss - target_capacity) 88 | result['Capacity_Loss'] = capacity_loss 89 | result['loss'] += gamma * capacity_loss 90 | else: 91 | raise ValueError(f'unknown objective "{objective}"') 92 | 93 | return result 94 | -------------------------------------------------------------------------------- /src/models/localizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | from typing import Optional 6 | 7 | from torchvision.ops import (complete_box_iou_loss, distance_box_iou_loss, 8 | generalized_box_iou_loss) 9 | 10 | 11 | class Localizer(nn.Module): 12 | """ Base class for a model that carries out localization. 13 | """ 14 | 15 | def __init__(self, name: str) -> None: 16 | super().__init__() 17 | self.name = name 18 | 19 | @abstractmethod 20 | def forward(self, x: Tensor, **kwargs) -> Tensor: 21 | raise NotImplementedError 22 | 23 | def loss_function(self, 24 | pred_params: Tensor, 25 | targ_params: Tensor, 26 | objective: Optional[str] = 'cbiou+dbiou') -> dict: 27 | # Sanity check to ensure that the parameters are valid BBs. 28 | assert (pred_params[:, 0] <= pred_params[:, 2]).all(), \ 29 | "Predicted BBs are invalid." 30 | assert (pred_params[:, 1] <= pred_params[:, 3]).all(), \ 31 | "Predicted BBs are invalid." 32 | assert (targ_params[:, 0] <= targ_params[:, 2]).all(), \ 33 | "Target BBs are invalid." 34 | assert (targ_params[:, 1] <= targ_params[:, 3]).all(), \ 35 | "Target BBs are invalid." 36 | # Calculate various loss metrics. Some may go unused. 37 | losses = dict(mse=F.mse_loss(pred_params, targ_params), 38 | dbiou=distance_box_iou_loss(pred_params, 39 | targ_params, 40 | reduction='sum'), 41 | cbiou=complete_box_iou_loss(pred_params, 42 | targ_params, 43 | reduction='sum'), 44 | gbiou=generalized_box_iou_loss(pred_params, 45 | targ_params, 46 | reduction='sum')) 47 | # Calculate the total loss based on the objectives. 48 | loss = torch.zeros(1) 49 | for obj in objective.split('+'): 50 | assert obj in losses, f"Objective '{obj}' not recognized." 
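# Sum only the requested terms, e.g. objective='cbiou+dbiou' adds the complete-IoU and distance-IoU losses, while plain 'mse' falls back to coordinate regression.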
51 | loss += losses[obj] 52 | return {'loss': loss, **{f'{k}_Loss': v for k, v in losses.items()}} 53 | 54 | 55 | def bb_intersection_over_union(boxA, boxB): 56 | # Source: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/ 57 | # determine the (x, y)-coordinates of the intersection rectangle 58 | xA = torch.max(boxA[:, 0], boxB[:, 0]) 59 | yA = torch.max(boxA[:, 1], boxB[:, 1]) 60 | xB = torch.min(boxA[:, 2], boxB[:, 2]) 61 | yB = torch.min(boxA[:, 3], boxB[:, 3]) 62 | # compute the area of intersection rectangle 63 | zeros = torch.zeros(xA.shape) 64 | interArea = torch.max(zeros, xB - xA + 1) * torch.max(zeros, yB - yA + 1) 65 | # compute the area of both the prediction and ground-truth 66 | # rectangles 67 | boxAArea = (boxA[:, 2] - boxA[:, 0] + 1) * (boxA[:, 3] - boxA[:, 1] + 1) 68 | boxBArea = (boxB[:, 2] - boxB[:, 0] + 1) * (boxB[:, 3] - boxB[:, 1] + 1) 69 | # compute the intersection over union by taking the intersection 70 | # area and dividing it by the sum of prediction + ground-truth 71 | # areas - the interesection area 72 | iou = interArea / (boxAArea + boxBArea - interArea).float() 73 | # return the intersection over union value 74 | return iou 75 | -------------------------------------------------------------------------------- /src/models/resnet_vae3d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Tensor, nn 3 | from typing import List 4 | 5 | from .base import BaseVAE 6 | from .resnet3d import BasicBlock3d, TransposeBasicBlock3d 7 | from .util import get_activation, get_pooling3d 8 | 9 | 10 | class ResNetVAE3d(BaseVAE): 11 | 12 | def __init__( 13 | self, 14 | name: str, 15 | latent_dim: int, 16 | hidden_dims: List[int], 17 | width: int, 18 | height: int, 19 | depth: int, 20 | channels: int, 21 | dropout: float = 0.4, 22 | enable_fid: bool = False, # per-frame FID, for video 23 | pooling: str = None, 24 | output_activation: str = 'sigmoid') -> None: 25 | super(ResNetVAE3d, self).__init__(name=name, latent_dim=latent_dim) 26 | self.width = width 27 | self.height = height 28 | self.depth = depth 29 | self.channels = channels 30 | self.hidden_dims = hidden_dims.copy() 31 | 32 | if pooling is not None: 33 | pool_fn = get_pooling3d(pooling) 34 | 35 | # Encoder 36 | modules = [] 37 | in_features = channels 38 | for h_dim in hidden_dims: 39 | modules.append(BasicBlock3d(in_features, h_dim)) 40 | if pooling is not None: 41 | modules.append(pool_fn(2)) 42 | in_features = h_dim 43 | self.encoder = nn.Sequential( 44 | *modules, 45 | nn.Flatten(), 46 | nn.Dropout(p=dropout), 47 | ) 48 | in_features = hidden_dims[-1] * width * height * depth 49 | if pooling is not None: 50 | in_features /= 8**len(hidden_dims) 51 | if abs(in_features - ceil(in_features)) > 0: 52 | raise ValueError( 53 | 'noninteger number of features - perhaps there is too much pooling?' 
54 | ) 55 | in_features = int(in_features) 56 | self.mu = nn.Sequential( 57 | nn.Linear(in_features, latent_dim), 58 | nn.BatchNorm1d(latent_dim), 59 | nn.ReLU(), 60 | ) 61 | self.var = nn.Sequential( 62 | nn.Linear(in_features, latent_dim), 63 | nn.BatchNorm1d(latent_dim), 64 | nn.ReLU(), 65 | ) 66 | 67 | # Decoder 68 | hidden_dims.reverse() 69 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 8) 70 | modules = [] 71 | in_features = hidden_dims[0] 72 | for h_dim in hidden_dims: 73 | modules.append(TransposeBasicBlock3d(in_features, h_dim)) 74 | in_features = h_dim 75 | self.decoder = nn.Sequential( 76 | *modules, 77 | nn.Conv3d(hidden_dims[-1], 78 | width * height * depth * channels // 8, 79 | kernel_size=3, 80 | padding=1), 81 | get_activation(output_activation), 82 | ) 83 | 84 | def encode(self, input: Tensor) -> List[Tensor]: 85 | if input.shape[-4:] != (self.channels, self.depth, self.height, 86 | self.width): 87 | raise ValueError('wrong input shape') 88 | x = self.encoder(input) 89 | mu = self.mu(x) 90 | var = self.var(x) 91 | return [mu, var] 92 | 93 | def decode(self, z: Tensor, **kwargs) -> Tensor: 94 | x = self.decoder_input(z) 95 | x = x.view(x.shape[0], self.hidden_dims[0], 2, 2, 2) 96 | x = self.decoder(x) 97 | x = x.view(x.shape[0], self.channels, self.depth, self.height, 98 | self.width) 99 | return x 100 | -------------------------------------------------------------------------------- /src/models/resnet_vae4d.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch import Tensor, nn 3 | from typing import List 4 | 5 | from .base import BaseVAE 6 | from .conv4d import Conv4d 7 | from .resnet4d import BasicBlock4d 8 | from .util import get_activation, get_pooling4d 9 | 10 | 11 | class ResNetVAE4d(BaseVAE): 12 | 13 | def __init__(self, 14 | name: str, 15 | latent_dim: int, 16 | hidden_dims: List[int], 17 | width: int, 18 | height: int, 19 | depth: int, 20 | frames: int, 21 | channels: int, 22 | dropout: float = 0.4, 23 | pooling: str = None, 24 | output_activation: str = 'sigmoid') -> None: 25 | super(ResNetVAE4d, self).__init__(name=name, latent_dim=latent_dim) 26 | self.width = width 27 | self.height = height 28 | self.depth = depth 29 | self.frames = frames 30 | self.channels = channels 31 | self.hidden_dims = hidden_dims.copy() 32 | 33 | if pooling is not None: 34 | pool_fn = get_pooling4d(pooling) 35 | 36 | # Encoder 37 | modules = [] 38 | in_features = channels 39 | for h_dim in hidden_dims: 40 | modules.append(BasicBlock4d(in_features, h_dim)) 41 | if pooling is not None: 42 | modules.append(pool_fn(2)) 43 | in_features = h_dim 44 | self.encoder = nn.Sequential( 45 | *modules, 46 | nn.Flatten(), 47 | nn.Dropout(p=dropout), 48 | ) 49 | in_features = hidden_dims[-1] * width * height * depth * frames 50 | if pooling is not None: 51 | in_features /= 16**len(hidden_dims) 52 | if abs(in_features - ceil(in_features)) > 0: 53 | raise ValueError( 54 | 'noninteger number of features - perhaps there is too much pooling?' 
55 | ) 56 | in_features = int(in_features) 57 | self.mu = nn.Sequential( 58 | nn.Linear(in_features, latent_dim), 59 | nn.BatchNorm1d(latent_dim), 60 | nn.ReLU(), 61 | ) 62 | self.var = nn.Sequential( 63 | nn.Linear(in_features, latent_dim), 64 | nn.BatchNorm1d(latent_dim), 65 | nn.ReLU(), 66 | ) 67 | 68 | # Decode 69 | hidden_dims.reverse() 70 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 16) 71 | modules = [] 72 | in_features = hidden_dims[0] 73 | for h_dim in hidden_dims: 74 | modules.append(BasicBlock4d(in_features, h_dim)) 75 | in_features = h_dim 76 | self.decoder = nn.Sequential( 77 | *modules, 78 | Conv4d(hidden_dims[-1], 79 | width * height * depth * frames * channels // 16, 80 | kernel_size=3, 81 | padding=1), 82 | get_activation(output_activation), 83 | ) 84 | 85 | def encode(self, input: Tensor) -> List[Tensor]: 86 | if input.shape[-5:] != (self.channels, self.frames, self.depth, 87 | self.height, self.width): 88 | raise ValueError('wrong input shape') 89 | x = self.encoder(input) 90 | mu = self.mu(x) 91 | var = self.var(x) 92 | return [mu, var] 93 | 94 | def decode(self, z: Tensor, **kwargs) -> Tensor: 95 | x = self.decoder_input(z) 96 | x = x.view(x.shape[0], self.hidden_dims[0], 2, 2, 2, 2) 97 | x = self.decoder(x) 98 | x = x.view(x.shape[0], self.channels, self.frames, self.depth, 99 | self.height, self.width) 100 | return x 101 | -------------------------------------------------------------------------------- /src/compiler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | 6 | import youtube_dl 7 | 8 | parser = argparse.ArgumentParser(description='Youtube dataset compiler') 9 | parser.add_argument( 10 | '--input', 11 | '-i', 12 | dest="input", 13 | metavar='INPUT', 14 | help='path to text file containing youtube video or playlist links', 15 | default='../dataset/doom/full.txt') 16 | parser.add_argument('--output', 17 | '-o', 18 | dest="output", 19 | metavar='OUTPUT', 20 | help='output file path', 21 | default='../dataset/doom/compiled.json') 22 | parser.add_argument('--download', 23 | dest="download", 24 | metavar='DOWNLOAD', 25 | help='download videos if true', 26 | default=False) 27 | parser.add_argument('--cache_dir', 28 | dest="cache_dir", 29 | metavar='CACHE_DIR', 30 | help='video download path', 31 | default='E:/cache') 32 | parser.add_argument( 33 | '--clean', 34 | dest="clean", 35 | metavar='CLEAN', 36 | help='remove missing videos from output records (do not download)', 37 | default=False) 38 | args = parser.parse_args() 39 | 40 | completed = [] 41 | videos = [] 42 | 43 | completed_path = os.path.join(os.path.dirname(args.output), '.completed.txt') 44 | 45 | try: 46 | with open(completed_path) as f: 47 | completed = json.loads(f.read()) 48 | except: 49 | pass 50 | 51 | if os.path.exists(args.output): 52 | with open(args.output) as f: 53 | videos = json.loads(f.read()) 54 | 55 | 56 | def write_completed(): 57 | with open(completed_path, 'w') as f: 58 | f.write(json.dumps(completed)) 59 | 60 | 61 | def write_videos(): 62 | with open(args.output, 'w') as f: 63 | f.write(json.dumps(videos)) 64 | 65 | 66 | def process_video(video, ydl, download): 67 | videos.append({ 68 | k: video[k] 69 | for k in [ 70 | 'id', 'ext', 'vcodec', 'uploader_id', 'channel_id', 'duration', 71 | 'width', 'height', 'fps' 72 | ] 73 | }) 74 | id = video['id'] 75 | path = os.path.join(args.cache_dir, id + '.mp4') 76 | if os.path.exists(path): 77 | print(f'{id} already 
downloaded') 78 | elif download: 79 | try: 80 | ydl.extract_info( 81 | f'https://youtube.com/watch?v={id}', 82 | download=True, 83 | ) 84 | except: 85 | print(f'Failed to download {id}: {sys.exc_info()}') 86 | 87 | 88 | if args.clean: 89 | new_videos = [ 90 | video for video in videos 91 | if os.path.exists(os.path.join(args.cache_dir, video['id'] + '.mp4')) 92 | ] 93 | print(f'Removed {len(videos)-len(new_videos)} videos') 94 | videos = new_videos 95 | write_videos() 96 | sys.exit(0) 97 | 98 | with open(args.input, "r") as f: 99 | lines = [line.strip() for line in f] 100 | 101 | with youtube_dl.YoutubeDL({ 102 | 'verbose': True, 103 | 'outtmpl': args.cache_dir + '/%(id)s.%(ext)s', 104 | }) as ydl: 105 | for line in lines: 106 | if line in completed: 107 | print(f'Skipping {line}') 108 | continue 109 | result = ydl.extract_info( 110 | line, 111 | download=False, 112 | ) 113 | if 'entries' in result: 114 | # It is a playlist 115 | for video in result['entries']: 116 | process_video(video, ydl, args.download) 117 | else: 118 | # Just a single video 119 | process_video(result, ydl, args.download) 120 | write_videos() 121 | completed.append(line) 122 | write_completed() 123 | 124 | # TODO: resample videos 125 | -------------------------------------------------------------------------------- /src/localization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.parameter import Parameter 5 | from torch.utils.data import Dataset 6 | from typing import Iterator 7 | 8 | from base_experiment import BaseExperiment 9 | from dataset import get_example_shape 10 | from models import create_model 11 | from plot import get_plot_fn 12 | 13 | 14 | class LocalizationExperiment(BaseExperiment): 15 | 16 | def __init__(self, config: dict, enable_tune: bool = False, **kwargs): 17 | super().__init__(config=config, enable_tune=enable_tune, **kwargs) 18 | exp_params = config['exp_params'] 19 | input_shape = get_example_shape(exp_params['data']) 20 | localizer = create_model(**config['model_params'], 21 | input_shape=input_shape) 22 | self.localizer = localizer 23 | 24 | def trainable_parameters(self) -> Iterator[Parameter]: 25 | return self.localizer.parameters() 26 | 27 | def sample_images(self, plot: dict, batch: list): 28 | test_input = [] 29 | pred_params = [] 30 | target_params = [] 31 | for item in batch: 32 | x, target_label, target_param = item 33 | x = x.unsqueeze(0) 34 | test_input.append(x) 35 | pred_param = self.localizer(x.to(self.curr_device)) 36 | pred_params.append(pred_param.detach().cpu()) 37 | target_params.append(target_param.unsqueeze(0)) 38 | test_input = torch.cat(test_input, dim=0).cpu() 39 | pred_params = torch.cat(pred_params, dim=0) 40 | target_params = torch.cat(target_params, dim=0) 41 | 42 | # Extensionless output path (let plotting function choose extension) 43 | out_path = os.path.join( 44 | self.logger.save_dir, self.logger.name, 45 | f"version_{self.logger.version}", 46 | f"{self.logger.name}_{plot['fn']}_{self.global_step}") 47 | fn = get_plot_fn(plot['fn']) 48 | image = fn(test_input=test_input, 49 | pred_params=pred_params, 50 | target_params=target_params, 51 | out_path=out_path, 52 | **plot['params']) 53 | self.logger.experiment.add_image(plot['fn'], image, self.global_step) 54 | vis = self.visdom() 55 | if vis is not None: 56 | vis.image(image, win=plot['fn']) 57 | 58 | def training_step(self, batch, batch_idx): 59 | real_img, targ_labels, targ_params = batch 60 | 
self.curr_device = self.device 61 | real_img = real_img.to(self.curr_device) 62 | pred_params = self.localizer(real_img).cpu() 63 | train_loss = self.localizer.loss_function( 64 | pred_params, targ_params.cpu(), 65 | **self.params.get('loss_params', {})) 66 | self.log_train_step(train_loss) 67 | return train_loss 68 | 69 | def validation_step(self, batch, batch_idx): 70 | real_img, targ_labels, targ_params = batch 71 | self.curr_device = self.device 72 | real_img = real_img.to(self.curr_device) 73 | pred_params = self.localizer(real_img).cpu() 74 | val_loss = self.localizer.loss_function( 75 | pred_params, targ_params.cpu(), 76 | **self.params.get('loss_params', {})) 77 | self.log_val_step(val_loss) 78 | return val_loss 79 | 80 | def get_val_batches(self, dataset: Dataset) -> list: 81 | val_batches = [] 82 | for plot in self.plots: 83 | batch = [ 84 | get_positive_example(dataset) 85 | for _ in range(plot['batch_size']) 86 | ] 87 | for _, label, _ in batch: 88 | assert torch.is_nonzero(label) 89 | val_batches.append(batch) 90 | return val_batches 91 | 92 | 93 | def get_positive_example(ds): 94 | try: 95 | return ds.get_positive_example() 96 | except: 97 | return get_positive_example(ds.dataset) 98 | -------------------------------------------------------------------------------- /src/render_teapot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | import matplotlib.pyplot as plt 6 | import pytorch3d 7 | from pytorch3d.io import load_obj, load_objs_as_meshes 8 | from pytorch3d.renderer import (DirectionalLights, FoVPerspectiveCameras, 9 | Materials, MeshRasterizer, MeshRenderer, 10 | PointLights, RasterizationSettings, 11 | SoftPhongShader, TexturesUV, TexturesVertex, 12 | look_at_view_transform) 13 | from pytorch3d.structures import Meshes 14 | from pytorch3d.transforms import (Rotate, Transform3d, Translate, 15 | random_rotations) 16 | from pytorch3d.vis.plotly_vis import (AxisArgs, plot_batch_individually, 17 | plot_scene) 18 | from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib 19 | 20 | device = torch.device("cuda") 21 | 22 | mesh = load_objs_as_meshes(['data/cow.obj'], device=device) 23 | 24 | # Define the settings for rasterization and shading. Here we set the output image to be of size 25 | # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 26 | # and blur_radius=0.0. We also set bin_size and max_faces_per_bin to None which ensure that 27 | # the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 28 | # explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 29 | # the difference between naive and coarse-to-fine rasterization. 30 | raster_settings = RasterizationSettings( 31 | image_size=512, 32 | blur_radius=0.0, 33 | faces_per_pixel=1, 34 | ) 35 | 36 | # Place a point light in front of the object. As mentioned above, the front of the cow is facing the 37 | # -z direction. 38 | lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]]) 39 | 40 | n = 3 41 | 42 | zfar = 100.0 43 | 44 | for i in range(n): 45 | t = i / n 46 | # Initialize a camera. 47 | # With world coordinates +Y up, +X left and +Z in, the front of the cow is facing the -Z direction. 48 | # So we move the camera by 180 in the azimuth direction so it is facing the front of the cow. 
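# In this script the camera is actually left at the origin (dist, elev and azim
# all zero); each iteration instead applies a random scale, a random rotation and
# a translation unprojected from a random NDC point to the mesh itself.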
49 | R, T = look_at_view_transform(0, 0, 0) 50 | cameras = FoVPerspectiveCameras(device=device, R=R, T=T, zfar=zfar) 51 | smin = 0.1 52 | smax = 2.0 53 | srange = smax - smin 54 | scale = (torch.rand(1).squeeze() * srange + smin).item() 55 | 56 | # Generate a random NDC coordinate https://pytorch3d.org/docs/cameras 57 | x, y, d = torch.rand(3) 58 | x = x * 2.0 - 1.0 59 | y = y * 2.0 - 1.0 60 | trans = torch.Tensor([x, y, d]).to(device) 61 | trans = cameras.unproject_points(trans.unsqueeze(0), 62 | world_coordinates=False, 63 | scaled_depth_input=True)[0] 64 | rot = random_rotations(1)[0].to(device) 65 | 66 | transform = Transform3d() \ 67 | .scale(scale) \ 68 | .compose(Rotate(rot)) \ 69 | .translate(*trans) 70 | 71 | # TODO: transform mesh 72 | # Create a phong renderer by composing a rasterizer and a shader. The textured phong shader will 73 | # interpolate the texture uv coordinates for each vertex, sample from a texture image and 74 | # apply the Phong lighting model 75 | renderer = MeshRenderer(rasterizer=MeshRasterizer( 76 | cameras=cameras, raster_settings=raster_settings), 77 | shader=SoftPhongShader( 78 | device=device, 79 | cameras=cameras, 80 | lights=lights, 81 | )) 82 | images = renderer(mesh.scale_verts(scale), 83 | R=rot.unsqueeze(0), 84 | T=trans.unsqueeze(0)) 85 | plt.figure(figsize=(10, 10)) 86 | plt.imshow(images[0, ..., :3].cpu().numpy()) 87 | plt.grid("off") 88 | plt.axis("off") 89 | plt.show() 90 | plt.close('all') 91 | -------------------------------------------------------------------------------- /src/dataset/la5c.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from torch import Tensor 4 | from typing import List 5 | 6 | import nilearn as nl 7 | import numpy as np 8 | 9 | 10 | class LA5cDataset(data.Dataset): 11 | """ UCLA Consortium for Neuropsychiatric Phenomics LA5c Study 12 | 13 | https://openneuro.org/datasets/ds000030/ 14 | 15 | Args: 16 | root: Path to download directory, e.g. /data/ds000030-download 17 | 18 | phenotypes: List of phenotype paths for labels. The file composes the 19 | first part of the path, and the column name from the tsv file is the 20 | second part. e.g. 'saps/saps17' loads the first data column from 21 | phenotype/saps.tsv 22 | 23 | exclude_na: Exclude subjects with any answers listed as N/A 24 | 25 | Reference: 26 | Gorgolewski KJ, Durnez J and Poldrack RA. Preprocessed Consortium for 27 | Neuropsychiatric Phenomics dataset [version 2; peer review: 2 approved]. 
28 | F1000Research 2017, 6:1262 (https://doi.org/10.12688/f1000research.11964.2) 29 | """ 30 | 31 | def __init__(self, 32 | root: str, 33 | phenotypes: List[str] = ['language/bilingual'], 34 | exclude_na: bool = True): 35 | super(LA5cDataset, self).__init__() 36 | self.root = root 37 | self.subjects = [f for f in os.listdir(root) if f.startswith('sub-')] 38 | labels = {} 39 | for phenotype in phenotypes: 40 | parts = phenotype.split('/') 41 | tsv_path = os.path.join(root, 'phenotype', parts[0] + '.tsv') 42 | with open(tsv_path, 'r') as f: 43 | columns = f.readline().strip().split('\t') 44 | col_no = None 45 | for i, column in enumerate(columns): 46 | if column == parts[1]: 47 | col_no = i 48 | break 49 | if col_no is None: 50 | raise ValueError( 51 | f'unable to find metric {parts[1]} in {tsv_path}') 52 | cols = [line.strip().split('\t') for line in f] 53 | for col in cols: 54 | sub = col[0] 55 | if sub not in self.subjects: 56 | continue 57 | label = col[col_no] 58 | if label == 'n/a' and exclude_na: 59 | self.subjects.remove(sub) 60 | continue 61 | elif label == 'N': 62 | label = 0 63 | elif label == 'Y': 64 | label = 1 65 | else: 66 | try: 67 | label = float(label) 68 | except: 69 | pass 70 | if sub in labels: 71 | labels[sub].append(label) 72 | else: 73 | labels[sub] = [label] 74 | self.labels = labels 75 | 76 | # Sanity check, make sure all examples have labels 77 | for sub in self.subjects: 78 | if sub not in self.labels: 79 | raise ValueError(f'subject {sub} has data but no labels') 80 | 81 | def __getitem__(self, index): 82 | sub = self.subjects[index] 83 | labels = Tensor(self.labels[sub]) 84 | path = os.path.join(self.root, sub, 'anat', f'{sub}_T1w.nii.gz') 85 | img = nl.image.load_img(path) 86 | img = Tensor(np.asanyarray(img.dataobj)) 87 | if img.shape != (176, 256, 256): 88 | raise ValueError(f'invalid shape {img.shape}') 89 | img = img.unsqueeze(0) 90 | return (img, labels) 91 | 92 | def __len__(self): 93 | return len(self.subjects) 94 | 95 | 96 | if __name__ == '__main__': 97 | ds = LA5cDataset(root='/data/openneuro/ds000030-download') 98 | print(ds[0][1]) 99 | print(ds[len(ds) - 1][0].shape) 100 | -------------------------------------------------------------------------------- /src/dataset/toy_neural_graphics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | import numpy as np 6 | from pytorch3d.io import load_obj, load_objs_as_meshes 7 | from pytorch3d.renderer import (DirectionalLights, FoVPerspectiveCameras, 8 | HardFlatShader, Materials, MeshRasterizer, 9 | MeshRenderer, PointLights, 10 | RasterizationSettings, SoftPhongShader, 11 | TexturesUV, TexturesVertex, 12 | look_at_view_transform) 13 | from pytorch3d.transforms import (Rotate, Transform3d, Translate, 14 | random_rotation) 15 | 16 | 17 | class ToyNeuralGraphicsDataset(data.Dataset): 18 | 19 | def __init__(self, 20 | dir: str, 21 | rasterization_settings: dict, 22 | znear: float = 1.0, 23 | zfar: float = 1000.0, 24 | scale_min: float = 0.5, 25 | scale_max: float = 2.0, 26 | device: str = 'cuda'): 27 | super(ToyNeuralGraphicsDataset, self).__init__() 28 | device = torch.device(device) 29 | self.device = device 30 | self.scale_min = scale_min 31 | self.scale_max = scale_max 32 | self.scale_range = scale_max - scale_min 33 | objs = [ 34 | os.path.join(dir, f) for f in os.listdir(dir) if f.endswith('.obj') 35 | ] 36 | self.meshes = load_objs_as_meshes(objs, device=device) 37 | R, T = look_at_view_transform(0, 
0, 0) 38 | self.cameras = FoVPerspectiveCameras(R=R, 39 | T=T, 40 | znear=znear, 41 | zfar=zfar, 42 | device=device) 43 | self.renderer = MeshRenderer(rasterizer=MeshRasterizer( 44 | cameras=self.cameras, 45 | raster_settings=RasterizationSettings(**rasterization_settings), 46 | ), 47 | shader=HardFlatShader( 48 | device=device, 49 | cameras=self.cameras, 50 | )) 51 | 52 | def get_random_transform(self): 53 | scale = (torch.rand(1).squeeze() * self.scale_range + 54 | self.scale_min).item() 55 | 56 | rot = random_rotation() 57 | 58 | x, y, d = torch.rand(3) 59 | x = x * 2.0 - 1.0 60 | y = y * 2.0 - 1.0 61 | trans = torch.Tensor([x, y, d]) 62 | trans = self.cameras.unproject_points( 63 | trans.unsqueeze(0).to(self.device), 64 | world_coordinates=False, 65 | scaled_depth_input=True)[0].cpu() 66 | return scale, rot, trans 67 | 68 | def __getitem__(self, index): 69 | index %= len(self.meshes) 70 | scale, rot, trans = self.get_random_transform() 71 | transform = Transform3d() \ 72 | .scale(scale) \ 73 | .compose(Rotate(rot)) \ 74 | .translate(*trans) \ 75 | .get_matrix() \ 76 | .squeeze() 77 | mesh = self.meshes[index].scale_verts(scale) 78 | pixels = self.renderer(mesh, 79 | R=rot.unsqueeze(0).to(self.device), 80 | T=trans.unsqueeze(0).to(self.device)) 81 | pixels = pixels[0, ..., :3].transpose(0, -1) 82 | return (pixels, [transform.to(self.device)]) 83 | 84 | def __len__(self): 85 | return len(self.meshes) * 1024 86 | 87 | 88 | if __name__ == '__main__': 89 | import matplotlib.pyplot as plt 90 | ds = ToyNeuralGraphicsDataset('data/', 91 | rasterization_settings=dict( 92 | image_size=256, 93 | blur_radius=0.0, 94 | faces_per_pixel=1)) 95 | image, labels = ds[0] 96 | plt.figure(figsize=(10, 10)) 97 | plt.imshow(image.cpu().numpy()) 98 | plt.grid("off") 99 | plt.axis("off") 100 | plt.show() 101 | plt.close('all') 102 | -------------------------------------------------------------------------------- /src/dataset/trends_fmri.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | import h5py 6 | import nilearn as nl 7 | import numpy as np 8 | 9 | 10 | def load_subject(filename: str, mask_niimg): 11 | """ 12 | Load a subject saved in .mat format with 13 | the version 7.3 flag. Return the subject 14 | niimg, using a mask niimg as a template 15 | for nifti headers. 
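    The spatial maps are read from the 'SM_feature' dataset stored inside the .mat file.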
16 | 17 | Args: 18 | filename the .mat filename for the subject data 19 | mask_niimg niimg object the mask niimg object used for nifti headers 20 | """ 21 | subject_data = None 22 | with h5py.File(filename, 'r') as f: 23 | subject_data = f['SM_feature'][()] 24 | # It's necessary to reorient the axes, since h5py flips axis order 25 | subject_data = np.moveaxis(subject_data, [0, 1, 2, 3], [3, 2, 1, 0]) 26 | subject_niimg = nl.image.new_img_like(mask_niimg, 27 | subject_data, 28 | affine=mask_niimg.affine, 29 | copy_header=True) 30 | return subject_niimg 31 | 32 | 33 | TRENDS_HEADER = "Id,age,domain1_var1,domain1_var2,domain2_var1,domain2_var2\n" 34 | 35 | 36 | def load_scores(path: str): 37 | scores = {} 38 | with open(path, 'r') as f: 39 | hdr = f.readline() 40 | if hdr != TRENDS_HEADER: 41 | raise ValueError("bad header") 42 | for line in f: 43 | parts = line.strip().split(',') 44 | if len(parts) == 0: 45 | continue 46 | scores[parts[0]] = [float(p) for p in parts[1:] if len(p) > 0] 47 | return scores 48 | 49 | 50 | class TReNDSfMRIDataset(data.Dataset): 51 | 52 | def __init__(self, root: str, train: bool = True): 53 | super(TReNDSfMRIDataset, self).__init__() 54 | self.root = root 55 | self.mat_dir = os.path.join(root, 56 | 'fMRI_train' if train else 'fMRI_test') 57 | self.files = os.listdir(self.mat_dir) 58 | self.mask = nl.image.load_img(os.path.join(root, 'fMRI_mask.nii')) 59 | self.scores = load_scores(os.path.join(root, 'train_scores.csv')) 60 | 61 | def __getitem__(self, index): 62 | file = self.files[index] 63 | path = os.path.join(self.mat_dir, file) 64 | data = load_subject(path, self.mask).get_data() 65 | data = torch.Tensor(data) 66 | scores = self.scores[file[:-4]] 67 | scores = torch.Tensor(scores) 68 | return data, scores 69 | 70 | def __len__(self): 71 | return len(self.files) 72 | 73 | 74 | if __name__ == '__main__': 75 | import matplotlib.pyplot as plt 76 | import nilearn.plotting as nlplt 77 | base_path = 'E:\\trends-fmri' 78 | ds = TReNDSfMRIDataset(base_path) 79 | print(ds[0][0].shape) 80 | smri_filename = os.path.join(base_path, 'ch2better.nii') 81 | subject_filename = os.path.join(base_path, 'fMRI_test/10228.mat') 82 | mask_niimg = ds.mask 83 | 84 | subject_niimg = load_subject(subject_filename, mask_niimg) 85 | 86 | grid_size = int(np.ceil(np.sqrt(subject_niimg.shape[0]))) 87 | fig, axes = plt.subplots(grid_size, 88 | grid_size, 89 | figsize=(grid_size * 10, grid_size * 10)) 90 | [axi.set_axis_off() for axi in axes.ravel()] 91 | row = -1 92 | for i, cur_img in enumerate(nl.image.iter_img(subject_niimg)): 93 | col = i % grid_size 94 | if col == 0: 95 | row += 1 96 | nlplt.plot_stat_map(cur_img, 97 | bg_img=smri_filename, 98 | title="IC %d" % i, 99 | axes=axes[row, col], 100 | threshold=3, 101 | colorbar=False) 102 | plt.show() 103 | 104 | print("Image shape is %s" % (str(subject_niimg.shape))) 105 | num_components = subject_niimg.shape[-1] 106 | print("Detected {num_components} spatial maps".format( 107 | num_components=num_components)) 108 | nlplt.plot_prob_atlas(subject_niimg, 109 | bg_img=smri_filename, 110 | view_type='filled_contours', 111 | draw_cross=False, 112 | title='All %d spatial maps' % num_components, 113 | threshold='auto') 114 | nlplt.show() 115 | -------------------------------------------------------------------------------- /src/models/iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', allow_neg=False): 5 | """Calculate the ious 
between each bbox of bboxes1 and bboxes2. 6 | 7 | Args: 8 | bboxes1(ndarray): shape (n, 4) 9 | bboxes2(ndarray): shape (k, 4) 10 | mode(str): iou (intersection over union) or iof (intersection 11 | over foreground) 12 | 13 | Returns: 14 | ious(ndarray): shape (n, k) 15 | """ 16 | 17 | assert mode in ['iou', 'iof'] 18 | 19 | bboxes1 = bboxes1.astype(np.float32) 20 | bboxes2 = bboxes2.astype(np.float32) 21 | rows = bboxes1.shape[0] 22 | cols = bboxes2.shape[0] 23 | ious = np.zeros((rows, cols), dtype=np.float32) 24 | if rows * cols == 0: 25 | return ious 26 | exchange = False 27 | if bboxes1.shape[0] > bboxes2.shape[0]: 28 | bboxes1, bboxes2 = bboxes2, bboxes1 29 | ious = np.zeros((cols, rows), dtype=np.float32) 30 | exchange = True 31 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - 32 | bboxes1[:, 1] + 1) 33 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - 34 | bboxes2[:, 1] + 1) 35 | for i in range(bboxes1.shape[0]): 36 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 37 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 38 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 39 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 40 | if not allow_neg: 41 | overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( 42 | y_end - y_start + 1, 0) 43 | else: 44 | overlap = (x_end - x_start + 1) * (y_end - y_start + 1) 45 | flag = np.ones(overlap.shape) 46 | flag[x_end - x_start + 1 < 0] = -1. 47 | flag[y_end - y_start + 1 < 0] = -1. 48 | overlap = flag * np.abs(overlap) 49 | 50 | if mode == 'iou': 51 | union = area1[i] + area2 - overlap 52 | else: 53 | union = area1[i] if not exchange else area2 54 | ious[i, :] = overlap / union 55 | if exchange: 56 | ious = ious.T 57 | return ious 58 | 59 | 60 | def giou(bboxes1, bboxes2): 61 | """Calculate the gious between each bbox of bboxes1 and bboxes2. 
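    Following Rezatofighi et al. (2019), the GIoU of two boxes is their IoU minus
    the fraction of the smallest enclosing box that is not covered by their union.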
62 | 63 | Args: 64 | bboxes1(ndarray): shape (n, 4) 65 | bboxes2(ndarray): shape (k, 4) 66 | 67 | Returns: 68 | gious(ndarray): shape (n, k) 69 | """ 70 | 71 | bboxes1 = bboxes1.astype(np.float32) 72 | bboxes2 = bboxes2.astype(np.float32) 73 | rows = bboxes1.shape[0] 74 | cols = bboxes2.shape[0] 75 | ious = np.zeros((rows, cols), dtype=np.float32) 76 | if rows * cols == 0: 77 | return ious 78 | exchange = False 79 | if bboxes1.shape[0] > bboxes2.shape[0]: 80 | bboxes1, bboxes2 = bboxes2, bboxes1 81 | ious = np.zeros((cols, rows), dtype=np.float32) 82 | exchange = True 83 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - 84 | bboxes1[:, 1] + 1) 85 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - 86 | bboxes2[:, 1] + 1) 87 | for i in range(bboxes1.shape[0]): 88 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 89 | x_min = np.minimum(bboxes1[i, 0], bboxes2[:, 0]) 90 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 91 | y_min = np.minimum(bboxes1[i, 1], bboxes2[:, 1]) 92 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 93 | x_max = np.maximum(bboxes1[i, 2], bboxes2[:, 2]) 94 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 95 | y_max = np.maximum(bboxes1[i, 3], bboxes2[:, 3]) 96 | 97 | overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( 98 | y_end - y_start + 1, 0) 99 | closure = np.maximum(x_max - x_min + 1, 0) * np.maximum( 100 | y_max - y_min + 1, 0) 101 | 102 | union = area1[i] + area2 - overlap 103 | # note: equals plain IoU when the enclosing box coincides with the union 104 | ious[i, :] = overlap / union - (closure - union) / closure 105 | if exchange: 106 | ious = ious.T 107 | return ious 108 | 109 | 110 | if __name__ == '__main__': 111 | bbox_overlaps(np.array([[0, 0, 100, 100]] * 3), 112 | np.array([[0, 200, 100, 300]] * 3), 113 | allow_neg=True) 114 | -------------------------------------------------------------------------------- /src/classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import Tensor 4 | from torch.nn.parameter import Parameter 5 | from torch.utils.data import Dataset 6 | from typing import Iterator, List 7 | 8 | from base_experiment import BaseExperiment 9 | from dataset import get_example_shape 10 | from models import create_model 11 | from plot import get_plot_fn, get_random_example_with_label 12 | 13 | 14 | class ClassificationExperiment(BaseExperiment): 15 | 16 | def __init__(self, config: dict, enable_tune: bool = False, **kwargs): 17 | super().__init__(config=config, enable_tune=enable_tune, **kwargs) 18 | self.classifier = create_model(**config['model_params'], 19 | input_shape=get_example_shape( 20 | config['exp_params']['data'])) 21 | 22 | def sample_images(self, plot: dict, batches: List[Tensor]): 23 | test_input = [] 24 | predictions = [] 25 | targets = [] 26 | for class_batch in batches: 27 | class_input = [] 28 | class_predictions = [] 29 | class_targets = [] 30 | for x, y in class_batch: 31 | x = x.unsqueeze(0) 32 | class_input.append(x) 33 | x = self.classifier(x.to(self.curr_device)).detach().cpu() 34 | class_predictions.append(x) 35 | class_targets.append(y.unsqueeze(0)) 36 | class_input = torch.cat(class_input, dim=0) 37 | test_input.append(class_input.unsqueeze(0)) 38 | predictions.append( 39 | torch.cat(class_predictions, dim=0).unsqueeze(0)) 40 | targets.append(torch.cat(class_targets, dim=0).unsqueeze(0)) 41 | 42 | test_input = torch.cat(test_input, dim=0) 43 | targets = torch.cat(targets, dim=0) 44 | predictions = torch.cat(predictions, dim=0) 45 | 46 | # 
Extensionless output path (let plotting function choose extension) 47 | out_path = os.path.join( 48 | self.logger.save_dir, self.logger.name, 49 | f"version_{self.logger.version}", 50 | f"{self.logger.name}_{plot['fn']}_{self.global_step}") 51 | fn = get_plot_fn(plot['fn']) 52 | image = fn(test_input=test_input, 53 | targets=targets, 54 | predictions=predictions, 55 | classes=plot['classes'], 56 | out_path=out_path, 57 | **plot['params']) 58 | self.logger.experiment.add_image(plot['fn'], image, self.global_step) 59 | vis = self.visdom() 60 | if vis is not None: 61 | vis.image(image, win=plot['fn']) 62 | 63 | def training_step(self, batch, batch_idx): 64 | real_img, labels = batch 65 | self.curr_device = self.device 66 | real_img = real_img.to(self.curr_device) 67 | y = self.classifier(real_img) 68 | train_loss = self.classifier.loss_function( 69 | y.cpu(), labels.cpu(), **self.params.get('loss_params', {})) 70 | self.log_train_step(train_loss) 71 | return train_loss 72 | 73 | def validation_step(self, batch, batch_idx): 74 | real_img, labels = batch 75 | self.curr_device = self.device 76 | real_img = real_img.to(self.curr_device) 77 | y = self.classifier(real_img) 78 | val_loss = self.classifier.loss_function( 79 | y.cpu(), labels.cpu(), **self.params.get('loss_params', {})) 80 | self.log_val_step(val_loss) 81 | return val_loss 82 | 83 | def trainable_parameters(self) -> Iterator[Parameter]: 84 | return self.classifier.parameters() 85 | 86 | def get_val_batches(self, dataset: Dataset) -> list: 87 | val_batches = [] 88 | for plot in self.plots: 89 | classes = plot['classes'] 90 | examples_per_class = plot['examples_per_class'] 91 | class_batches = [] 92 | for obj in classes: 93 | batch = [] 94 | class_indices = [] 95 | for _ in range(examples_per_class): 96 | idx = get_random_example_with_label(dataset, 97 | torch.Tensor( 98 | obj['labels']), 99 | all_=obj['all'], 100 | exclude=class_indices) 101 | batch.append(dataset[idx]) 102 | class_indices.append(idx) 103 | class_batches.append(batch) 104 | val_batches.append(class_batches) 105 | return val_batches 106 | -------------------------------------------------------------------------------- /src/neural_gbuffer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.parameter import Parameter 6 | from torch.utils.data import DataLoader 7 | from typing import Iterator 8 | 9 | import pytorch_lightning as pl 10 | 11 | from dataset import get_dataset 12 | from merge_strategy import deep_merge 13 | from models import BaseRenderer, create_model 14 | from plot import get_plot_fn 15 | 16 | 17 | class NeuralGBufferExperiment(pl.LightningModule): 18 | 19 | def __init__(self, model: BaseRenderer, params: dict) -> None: 20 | super().__init__() 21 | self.model = model 22 | self.params = params 23 | self.curr_device = None 24 | plots = self.params['plot'] 25 | if type(plots) is not list: 26 | plots = [plots] 27 | self.plots = plots 28 | 29 | def training_step(self, batch, batch_idx): 30 | self.curr_device = self.device 31 | orig, labels = batch 32 | recons = self.model(*labels) 33 | train_loss = self.model.loss_function( 34 | recons, orig, **self.params.get('loss_params', {})) 35 | self.logger.experiment.log( 36 | {key: val.item() 37 | for key, val in train_loss.items()}) 38 | if self.global_step > 0: 39 | for plot, val_indices in zip(self.plots, self.val_indices): 40 | if self.global_step % plot['sample_every_n_steps'] == 0: 41 | 
self.sample_images(plot, val_indices) 42 | return train_loss 43 | 44 | def validation_step(self, batch, batch_idx): 45 | orig, labels = batch 46 | recons = self.model(*labels) 47 | val_loss = self.model.loss_function( 48 | recons, orig, **self.params.get('loss_params', {})) 49 | return val_loss 50 | 51 | def trainable_parameters(self) -> Iterator[Parameter]: 52 | return self.model.parameters() 53 | 54 | def train_dataloader(self): 55 | dataset = get_dataset(self.params['data']['name'], 56 | self.params['data'].get('training', {})) 57 | self.num_train_imgs = len(dataset) 58 | return DataLoader(dataset, 59 | batch_size=self.params['batch_size'], 60 | shuffle=True, 61 | **self.params['data'].get('loader', {})) 62 | 63 | def val_dataloader(self): 64 | ds_params = deep_merge(self.params['data'].get('training', {}).copy(), 65 | self.params['data'].get('validation', {})) 66 | dataset = get_dataset(self.params['data']['name'], ds_params) 67 | 68 | self.sample_dataloader = DataLoader( 69 | dataset, 70 | batch_size=self.params['batch_size'], 71 | shuffle=False, 72 | **self.params['data'].get('loader', {})) 73 | self.num_val_imgs = len(self.sample_dataloader) 74 | n = len(dataset) 75 | self.val_indices = [ 76 | torch.randint(low=0, high=n, 77 | size=(plot['batch_size'], 1)).squeeze() 78 | for plot in self.plots 79 | ] 80 | return self.sample_dataloader 81 | 82 | def sample_images(self, plot: dict, val_indices: Tensor): 83 | revert = self.training 84 | if revert: 85 | self.eval() 86 | test_input = [] 87 | recons = [] 88 | 89 | batch = [self.sample_dataloader.dataset[int(i)] for i in val_indices] 90 | for x, transform in batch: 91 | test_input.append(x.unsqueeze(0)) 92 | out = self.model(*[a.unsqueeze(0) for a in transform]) 93 | recons.append(out) 94 | test_input = torch.cat(test_input, dim=0) 95 | recons = torch.cat(recons, dim=0) 96 | # Extensionless output path (let plotting function choose extension) 97 | out_path = os.path.join( 98 | self.logger.save_dir, self.logger.name, 99 | f"version_{self.logger.version}", 100 | f"{self.logger.name}_{plot['fn']}_{self.global_step}") 101 | orig = test_input.data.cpu() 102 | recons = recons.data.cpu() 103 | fn = get_plot_fn(plot['fn']) 104 | fn(orig=orig, 105 | recons=recons, 106 | model_name=self.model.name, 107 | epoch=self.current_epoch, 108 | out_path=out_path, 109 | **plot['params']) 110 | gc.collect() 111 | if revert: 112 | self.train() 113 | 114 | 115 | def neural_gbuffer(config: dict, run_args: dict) -> pl.LightningModule: 116 | exp_params = config['exp_params'] 117 | image_size = exp_params['data']['training']['rasterization_settings'][ 118 | 'image_size'] 119 | model = create_model(**config['model_params'], 120 | width=image_size, 121 | height=image_size, 122 | channels=3, 123 | enable_fid=True) 124 | return NeuralGBufferExperiment(model, exp_params) 125 | --------------------------------------------------------------------------------