├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_affectnet_base.yaml ├── config_affectnet_large.yaml ├── config_vitasd_base.yaml ├── config_vitasd_base_attonly.yaml ├── config_vitasd_large.yaml └── config_vitasd_small.yaml ├── datasets ├── __init__.py ├── affectnet.py └── autism_dataset.py ├── eval.py ├── lib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── augment.cpython-38.pyc │ ├── sngp.cpython-38.pyc │ └── utils.cpython-38.pyc ├── augment.py ├── ood_evaluator.py ├── pos_embed.py ├── sngp.py └── utils.py ├── lightning_logs ├── ViTASD-B │ └── version_1 │ │ ├── config.yaml │ │ └── hparams.yaml ├── ViTASD-L │ └── version_1 │ │ ├── config.yaml │ │ └── hparams.yaml └── ViTASD-S │ └── version_1 │ ├── config.yaml │ └── hparams.yaml ├── models ├── __init__.py └── vitasd.py ├── requirement.txt ├── runs ├── figures │ ├── attention.png │ └── framework.png └── vis │ ├── cf_matrix.png │ └── histogram.png ├── tools ├── __init__.py └── ls_timm_models.py ├── train.py ├── train_affectnet.py └── visualization_attention.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xu(Iroh) Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICASSP 2023] ViTASD: Robust Vision Transformer Baselines for Autism Spectrum Disorder Facial Diagnosis 2 | ### Official PyTorch Implementation 3 | Shenzhen Children's Hospital 4 | New York University 5 | 6 | 7 | ### Abstract 8 | Autism spectrum disorder (ASD) is a lifelong neurodevelopmental disorder with very high prevalence around the world. Research progress in the field of ASD facial analysis in pediatric patients has been hindered due to a lack of well-established baselines. In this paper, we propose the use of the Vision Transformer (ViT) for the computational analysis of pediatric ASD. The presented model, known as ViTASD, distills knowledge from large facial expression datasets and offers model structure transferability. Specifically, ViTASD employs a vanilla ViT to extract features from patients' face images and adopts a lightweight decoder with a Gaussian Process layer to enhance the robustness for ASD analysis. Extensive experiments conducted on standard ASD facial analysis benchmarks show that our method outperforms all of the representative approaches in ASD facial analysis, while the ViTASD-L achieves a new state-of-the-art. 9 | 10 | 11 | ![Attention for ASD Children](./runs/figures/attention.png) 12 | 13 | ## Dataset 14 | 15 | Publicly available datasets were analyzed in this study. The original data page can be found at: [Kaggle](https://www.kaggle.com/cihan063/autism-image-data). 
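After downloading, note that `AutismDatasetModule` (see `datasets/autism_dataset.py`) reads the images with `timm.data.ImageDataset`, so it expects per-class folders under `training/` and `testing/` splits. A layout sketch under the module's default `data_root` is shown below; the split and class folder names are taken from the module's defaults and class map, so adjust them if your download is organized differently:

```
datasets/Kaggle_AutismDataset2/
├── training/
│   ├── Autistic/
│   └── Non_Autistic/
└── testing/
    ├── Autistic/
    └── Non_Autistic/
```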
The author has updated the dataset to a new [Google Drive](https://drive.google.com/drive/folders/1XQU0pluL0m3TIlXqntano12d68peMb8A) location. 16 | 17 | 18 | Other useful datasets for computer vision in Autism Spectrum Disorder detection: 19 | 20 | [DE-ENIGMA Dataset](https://de-enigma.eu/database/) 21 | [Saliency4ASD dataset](https://saliency4asd.ls2n.fr/datasets/) 22 | 23 | We will extend this research to these datasets in the future. We are also working to build a new benchmark for ASD facial diagnosis using new datasets collected from children in Shenzhen. Any news about this benchmark will be posted to this GitHub repo until we publish the competition. This project will create a completely non-profit platform for early ASD intervention around the world. 24 | 25 | ## Model 26 | 27 | ![Network Architecture](./runs/figures/framework.png) 28 | 29 | 30 | ## Pre-training on the AffectNet Dataset 31 | 32 | ``` 33 | python train_affectnet.py fit -c ./configs/config_affectnet_base.yaml 34 | ``` 35 | 36 | ``` 37 | python train_affectnet.py fit -c ./configs/config_affectnet_large.yaml 38 | ``` 39 | 40 | 41 | ## Training 42 | 43 | ``` 44 | python train.py fit -c ./configs/config_vitasd_small.yaml 45 | ``` 46 | 47 | ``` 48 | python train.py fit -c ./configs/config_vitasd_base.yaml 49 | ``` 50 | 51 | ``` 52 | python train.py fit -c ./configs/config_vitasd_large.yaml 53 | ``` 54 | 55 | ### Monitoring the training (replace 'X' with S, B, or L) 56 | 57 | ``` 58 | tensorboard --logdir=./lightning_logs/ViTASD-'X' 59 | ``` 60 | 61 | 62 | ## Evaluation 63 | Pending; a preliminary evaluation script is provided in `eval.py` (set its `checkpoint_path` and `hparams_file` before running). 64 | -------------------------------------------------------------------------------- /configs/config_affectnet_base.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: AffectNet-B 9 | version: 1 10 | default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/AffectNet-B 16 | save_top_k: 2 17 | monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 50 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode:
max_size_cycle 69 | model: 70 | num_classes: 8 71 | attn_only: false 72 | opt: adamw 73 | weight_decay: 0.05 74 | sched: cosine 75 | lr: 4.0e-05 76 | warmup_lr: 1.0e-06 77 | min_lr: 1.0e-06 78 | warmup_epochs: 5 79 | cooldown_epochs: 0 80 | smoothing: 0.0 81 | batch_size: &bs 64 82 | model: deit3_base_patch16_224 83 | mixup: 0.8 84 | cutmix: 1.0 85 | data: 86 | batch_size: *bs 87 | num_workers: 4 88 | three_augment: true 89 | ckpt_path: null 90 | -------------------------------------------------------------------------------- /configs/config_affectnet_large.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: AffectNet-L 9 | version: 1 10 | default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/AffectNet-L 16 | save_top_k: 2 17 | monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 100 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode: max_size_cycle 69 | model: 70 | num_classes: 8 71 | attn_only: false 72 | opt: adamw 73 | weight_decay: 0.05 74 | sched: cosine 75 | lr: 4.0e-05 76 | warmup_lr: 1.0e-06 77 | min_lr: 1.0e-06 78 | warmup_epochs: 5 79 | cooldown_epochs: 0 80 | smoothing: 0.0 81 | batch_size: &bs 32 82 | model: deit3_large_patch16_224 83 | mixup: 0.8 84 | cutmix: 1.0 85 | data: 86 | batch_size: *bs 87 | num_workers: 4 88 | three_augment: true 89 | ckpt_path: null 90 | -------------------------------------------------------------------------------- /configs/config_vitasd_base.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-B 9 | version: 3 10 | default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/ViTASD-B 16 | save_top_k: 3 17 | 
monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 300 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode: max_size_cycle 69 | model: 70 | num_classes: 2 71 | attn_only: false 72 | opt: adamw 73 | weight_decay: 0.05 74 | sched: cosine 75 | lr: 1.0e-05 76 | warmup_lr: 1.0e-06 77 | min_lr: 1.0e-06 78 | warmup_epochs: 5 79 | cooldown_epochs: 0 80 | smoothing: 0.0 81 | batch_size: &bs 16 82 | model: deit3_base_patch16_224 83 | mixup: 0.8 84 | cutmix: 1.0 85 | data: 86 | batch_size: *bs 87 | num_workers: 4 88 | three_augment: false 89 | ckpt_path: null 90 | -------------------------------------------------------------------------------- /configs/config_vitasd_base_attonly.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-B 9 | version: 1 10 | default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/ViTASD-B 16 | save_top_k: 3 17 | monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 300 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | 
reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode: max_size_cycle 69 | model: 70 | num_classes: 2 71 | attn_only: true 72 | opt: adamw 73 | weight_decay: 0.05 74 | sched: cosine 75 | lr: 1.0e-05 76 | warmup_lr: 1.0e-06 77 | min_lr: 1.0e-06 78 | warmup_epochs: 5 79 | cooldown_epochs: 0 80 | smoothing: 0.0 81 | batch_size: &bs 16 82 | model: deit3_base_patch16_224 83 | mixup: 0.8 84 | cutmix: 1.0 85 | data: 86 | batch_size: *bs 87 | num_workers: 4 88 | three_augment: false 89 | ckpt_path: null 90 | -------------------------------------------------------------------------------- /configs/config_vitasd_large.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-L 9 | version: 3 10 | default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/ViTASD-L 16 | save_top_k: 3 17 | monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 300 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode: max_size_cycle 69 | model: 70 | opt: adamw 71 | weight_decay: 0.05 72 | sched: cosine 73 | lr: 1.0e-05 74 | warmup_lr: 1.0e-06 75 | min_lr: 1.0e-06 76 | warmup_epochs: 5 77 | cooldown_epochs: 0 78 | batch_size: &bs 16 79 | model: deit3_large_patch16_224 80 | mixup: 0.8 81 | cutmix: 1.0 82 | data: 83 | batch_size: *bs 84 | num_workers: 4 85 | three_augment: false 86 | ckpt_path: null 87 | -------------------------------------------------------------------------------- /configs/config_vitasd_small.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.5 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-S 9 | version: 1 10 | 
default_hp_metric: false 11 | enable_checkpointing: true 12 | callbacks: 13 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: ./lightning_logs/ViTASD-S 16 | save_top_k: 1 17 | monitor: Accuracy/val 18 | mode: max 19 | save_weights_only: true 20 | auto_insert_metric_name: true 21 | default_root_dir: null 22 | gradient_clip_val: null 23 | gradient_clip_algorithm: null 24 | num_nodes: 1 25 | num_processes: null 26 | devices: 4 27 | gpus: null 28 | auto_select_gpus: false 29 | tpu_cores: null 30 | ipus: null 31 | enable_progress_bar: true 32 | overfit_batches: 0.0 33 | track_grad_norm: -1 34 | check_val_every_n_epoch: 1 35 | fast_dev_run: false 36 | accumulate_grad_batches: null 37 | max_epochs: &epochs 300 38 | min_epochs: null 39 | max_steps: -1 40 | min_steps: null 41 | max_time: null 42 | limit_train_batches: null 43 | limit_val_batches: null 44 | limit_test_batches: null 45 | limit_predict_batches: null 46 | val_check_interval: null 47 | log_every_n_steps: 50 48 | accelerator: gpu 49 | strategy: null 50 | sync_batchnorm: false 51 | precision: 32 52 | enable_model_summary: true 53 | weights_save_path: null 54 | num_sanity_val_steps: 2 55 | resume_from_checkpoint: null 56 | profiler: null 57 | benchmark: null 58 | deterministic: null 59 | reload_dataloaders_every_n_epochs: 0 60 | auto_lr_find: false 61 | replace_sampler_ddp: true 62 | detect_anomaly: false 63 | auto_scale_batch_size: false 64 | plugins: null 65 | amp_backend: native 66 | amp_level: null 67 | move_metrics_to_cpu: false 68 | multiple_trainloader_mode: max_size_cycle 69 | model: 70 | opt: adamw 71 | weight_decay: 0.05 72 | sched: cosine 73 | lr: 1.0e-05 74 | warmup_lr: 1.0e-06 75 | min_lr: 1.0e-06 76 | warmup_epochs: 5 77 | cooldown_epochs: 0 78 | batch_size: &bs 16 79 | model: deit3_small_patch16_224 80 | mixup: 0.8 81 | cutmix: 1.0 82 | data: 83 | batch_size: *bs 84 | num_workers: 4 85 | three_augment: false 86 | ckpt_path: null 87 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .autism_dataset import AutismDatasetModule 2 | from .affectnet import AffectNetDataModule 3 | 4 | __all__ = ["AutismDatasetModule", "AffectNetDataModule"] 5 | -------------------------------------------------------------------------------- /datasets/affectnet.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS 3 | from timm.data import ImageDataset, create_transform 4 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 5 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 6 | from torchvision import transforms 7 | from torchvision.transforms.functional import InterpolationMode 8 | 9 | from pathlib import Path 10 | from lib.augment import new_data_aug_generator 11 | 12 | 13 | class AffectNetDataModule(LightningDataModule): 14 | def __init__(self, 15 | batch_size: int = 256, 16 | num_workers: int = 4, 17 | data_root: str = "", 18 | input_size: int = 224, 19 | color_jitter: float = 0.3, 20 | three_augment: bool = True, 21 | src: bool = False, # simple random crop 22 | ): 23 | super(AffectNetDataModule, self).__init__() 24 | self.save_hyperparameters() 25 | self.data_path = Path(self.hparams.data_root) / "imgs" 26 | self.train_transforms = 
self.build_transform(is_train=True) 27 | self.eval_transforms = self.build_transform(is_train=False) 28 | if self.hparams.three_augment: 29 | self.train_transforms = new_data_aug_generator(self.hparams) 30 | self.class_map = {str(i): i for i in range(8)} 31 | 32 | def build_transform(self, is_train): 33 | resize_im = self.hparams.input_size > 32 34 | if is_train: 35 | transform = create_transform(input_size=self.hparams.input_size, 36 | is_training=True, 37 | color_jitter=self.hparams.color_jitter) 38 | if not resize_im: # replace RandomResizedCropAndInterpolation with RandomCrop 39 | transform.transforms[0] = transforms.RandomCrop(self.hparams.input_size, padding=4) 40 | return transform 41 | t = [] 42 | if resize_im: 43 | # int((256 / 224) * args.input_size) (deit crop ratio (256 / 224), deit III crop ratio 1.0) 44 | size = int((1.0) * self.hparams.input_size) 45 | t.append(transforms.Resize(size, 46 | interpolation=InterpolationMode.BICUBIC, )) # to maintain same ratio w.r.t. 224 images 47 | t.append(transforms.CenterCrop(self.hparams.input_size)) 48 | t.append(transforms.ToTensor()) 49 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 50 | return transforms.Compose(t) 51 | 52 | def train_dataloader(self) -> TRAIN_DATALOADERS: 53 | train_dataset = ImageDataset(str(self.data_path / "train"), transform=self.train_transforms, 54 | class_map=self.class_map) 55 | train_sampler = RandomSampler(train_dataset) 56 | return DataLoader(train_dataset, 57 | sampler=train_sampler, 58 | batch_size=self.hparams.batch_size, 59 | num_workers=self.hparams.num_workers, 60 | pin_memory=True, 61 | persistent_workers=True, 62 | drop_last=True) 63 | 64 | def val_dataloader(self) -> EVAL_DATALOADERS: 65 | val_dataset = ImageDataset(str(self.data_path / "val"), transform=self.eval_transforms, 66 | class_map=self.class_map) 67 | val_sampler = SequentialSampler(val_dataset) 68 | return DataLoader(val_dataset, 69 | sampler=val_sampler, 70 | batch_size=int(1.5 * self.hparams.batch_size), 71 | num_workers=self.hparams.num_workers, 72 | pin_memory=True, 73 | persistent_workers=True, 74 | drop_last=False) -------------------------------------------------------------------------------- /datasets/autism_dataset.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS 3 | from timm.data import ImageDataset, create_transform 4 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 5 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 6 | from torchvision import transforms 7 | from torchvision.transforms.functional import InterpolationMode 8 | 9 | from pathlib import Path 10 | from lib.augment import new_data_aug_generator 11 | 12 | 13 | class AutismDatasetModule(LightningDataModule): 14 | def __init__(self, 15 | batch_size: int = 256, 16 | num_workers: int = 4, 17 | data_root: str = "./datasets/Kaggle_AutismDataset2", 18 | input_size: int = 224, 19 | color_jitter: float = 0.3, 20 | three_augment: bool = True, 21 | src: bool = False, # simple random crop 22 | ): 23 | super(AutismDatasetModule, self).__init__() 24 | self.save_hyperparameters() 25 | self.data_path = Path(self.hparams.data_root) 26 | self.train_transforms = self.build_transform(is_train=True) 27 | self.eval_transforms = self.build_transform(is_train=False) 28 | if self.hparams.three_augment: 29 | self.train_transforms = 
new_data_aug_generator(self.hparams) 30 | self.class_map = {'Non_Autistic': 0, 'Autistic': 1} 31 | 32 | def build_transform(self, is_train): 33 | resize_im = self.hparams.input_size > 32 34 | if is_train: 35 | transform = create_transform(input_size=self.hparams.input_size, 36 | is_training=True, 37 | color_jitter=self.hparams.color_jitter) 38 | if not resize_im: # replace RandomResizedCropAndInterpolation with RandomCrop 39 | transform.transforms[0] = transforms.RandomCrop(self.hparams.input_size, padding=4) 40 | return transform 41 | t = [] 42 | if resize_im: 43 | # int((256 / 224) * args.input_size) (deit crop ratio (256 / 224), deit III crop ratio 1.0) 44 | size = int((1.0) * self.hparams.input_size) 45 | t.append(transforms.Resize(size, 46 | interpolation=InterpolationMode.BICUBIC, )) # to maintain same ratio w.r.t. 224 images 47 | t.append(transforms.CenterCrop(self.hparams.input_size)) 48 | t.append(transforms.ToTensor()) 49 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 50 | return transforms.Compose(t) 51 | 52 | def train_dataloader(self) -> TRAIN_DATALOADERS: 53 | train_dataset = ImageDataset(str(self.data_path / "training"), transform=self.train_transforms, 54 | class_map=self.class_map) 55 | train_sampler = RandomSampler(train_dataset) 56 | return DataLoader(train_dataset, 57 | sampler=train_sampler, 58 | batch_size=self.hparams.batch_size, 59 | num_workers=self.hparams.num_workers, 60 | pin_memory=True, 61 | persistent_workers=True, 62 | drop_last=True) 63 | 64 | def val_dataloader(self) -> EVAL_DATALOADERS: 65 | val_dataset = ImageDataset(str(self.data_path / "testing"), transform=self.eval_transforms, 66 | class_map=self.class_map) 67 | val_sampler = SequentialSampler(val_dataset) 68 | return DataLoader(val_dataset, 69 | sampler=val_sampler, 70 | batch_size=4, 71 | num_workers=self.hparams.num_workers, 72 | pin_memory=True, 73 | persistent_workers=True, 74 | drop_last=False) 75 | 76 | def test_dataloader(self) -> EVAL_DATALOADERS: 77 | test_dataset = ImageDataset(str(self.data_path / "testing"), transform=self.eval_transforms, 78 | class_map=self.class_map) 79 | test_sampler = SequentialSampler(test_dataset) 80 | return DataLoader(test_dataset, 81 | sampler=test_sampler, 82 | batch_size=4, 83 | num_workers=self.hparams.num_workers, 84 | pin_memory=True, 85 | persistent_workers=True, 86 | drop_last=False) 87 | 88 | def predict_dataloader(self) -> EVAL_DATALOADERS: 89 | test_dataset = ImageDataset(str(self.data_path), transform=self.eval_transforms, 90 | class_map=self.class_map) 91 | test_sampler = SequentialSampler(test_dataset) 92 | return DataLoader(test_dataset, 93 | sampler=test_sampler, 94 | batch_size=1, 95 | num_workers=self.hparams.num_workers, 96 | pin_memory=True, 97 | persistent_workers=True, 98 | drop_last=False) -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import numpy as np 6 | 7 | from datasets import AutismDatasetModule 8 | from train import ViTASDLM 9 | 10 | from pytorch_lightning import LightningModule 11 | from pytorch_lightning.cli import LightningCLI 12 | from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT 13 | 14 | from torchmetrics import Accuracy, ConfusionMatrix, AUROC 15 | from torch.optim import Optimizer 16 | 17 | from timm.data import Mixup 18 | from timm.models 
import create_model 19 | from timm.optim import create_optimizer_v2 20 | from timm.scheduler import create_scheduler 21 | from timm.scheduler.scheduler import Scheduler 22 | 23 | from pathlib import Path 24 | from typing import Optional 25 | 26 | 27 | auroc = AUROC(num_classes=2) 28 | accuracy = Accuracy() 29 | 30 | 31 | model = ViTASDLM.load_from_checkpoint( 32 | checkpoint_path="", 33 | hparams_file="", 34 | map_location=None, 35 | ) 36 | 37 | @torch.no_grad()  # inference only: no autograd graphs are needed here 38 | def get_predictions(model): 39 | 40 | softmax = nn.Softmax(dim=1) 41 | dataset_module = AutismDatasetModule() 42 | model.eval() 43 | 44 | predictions = [] 45 | labels = [] 46 | 47 | for data, label in iter(dataset_module.test_dataloader()): 48 | 49 | prediction = model(data) 50 | predictions.append(softmax(prediction)) 51 | labels.append(label) 52 | 53 | predictions = torch.cat(predictions) 54 | labels = torch.cat(labels) 55 | true_predictions = [max(a, b) for a, b in predictions.tolist()]  # confidence of the predicted class 56 | return predictions, labels, true_predictions 57 | 58 | in_preds, in_labels, in_true_preds = get_predictions(model) 59 | 60 | print("in auroc:", auroc(in_preds, in_labels)) 61 | print("in accuracy:", accuracy(in_preds, in_labels)) 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/ViTASD/f96561962dbae63fc1e168c081946819e6262e9b/lib/__init__.py -------------------------------------------------------------------------------- /lib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/ViTASD/f96561962dbae63fc1e168c081946819e6262e9b/lib/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/ViTASD/f96561962dbae63fc1e168c081946819e6262e9b/lib/__pycache__/augment.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/sngp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/ViTASD/f96561962dbae63fc1e168c081946819e6262e9b/lib/__pycache__/sngp.cpython-38.pyc -------------------------------------------------------------------------------- /lib/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/ViTASD/f96561962dbae63fc1e168c081946819e6262e9b/lib/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /lib/augment.py: -------------------------------------------------------------------------------- 1 | # Origin: https://github.com/facebookresearch/deit/blob/main/augment.py 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved.
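# Usage sketch: the LightningDataModules in datasets/ build this pipeline via
# new_data_aug_generator(self.hparams) when three_augment=True. The args object
# only needs .input_size, .src, and .color_jitter attributes, so a minimal
# standalone call (with hypothetical values) would look like:
#
#   from types import SimpleNamespace
#   from lib.augment import new_data_aug_generator
#   tfms = new_data_aug_generator(SimpleNamespace(input_size=224, src=False, color_jitter=0.3))
#
# tfms is then a torchvision transforms.Compose usable as ImageDataset(transform=tfms).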
5 | 6 | """ 7 | 3Augment implementation 8 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 9 | and timm DA (https://github.com/rwightman/pytorch-image-models) 10 | """ 11 | import torch 12 | from torchvision import transforms 13 | 14 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor 15 | 16 | import numpy as np 17 | from torchvision import datasets, transforms 18 | import random 19 | 20 | from PIL import ImageFilter, ImageOps 21 | import torchvision.transforms.functional as TF 22 | 23 | 24 | class GaussianBlur(object): 25 | """ 26 | Apply Gaussian Blur to the PIL image. 27 | """ 28 | 29 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 30 | self.prob = p 31 | self.radius_min = radius_min 32 | self.radius_max = radius_max 33 | 34 | def __call__(self, img): 35 | do_it = random.random() <= self.prob 36 | if not do_it: 37 | return img 38 | 39 | img = img.filter( 40 | ImageFilter.GaussianBlur( 41 | radius=random.uniform(self.radius_min, self.radius_max) 42 | ) 43 | ) 44 | return img 45 | 46 | 47 | class Solarization(object): 48 | """ 49 | Apply Solarization to the PIL image. 50 | """ 51 | 52 | def __init__(self, p=0.2): 53 | self.p = p 54 | 55 | def __call__(self, img): 56 | if random.random() < self.p: 57 | return ImageOps.solarize(img) 58 | else: 59 | return img 60 | 61 | 62 | class gray_scale(object): 63 | """ 64 | Apply grayscale conversion to the PIL image. 65 | """ 66 | 67 | def __init__(self, p=0.2): 68 | self.p = p 69 | self.transf = transforms.Grayscale(3) 70 | 71 | def __call__(self, img): 72 | if random.random() < self.p: 73 | return self.transf(img) 74 | else: 75 | return img 76 | 77 | 78 | class horizontal_flip(object): 79 | """ 80 | Apply random horizontal flip to the PIL image.
81 | """ 82 | 83 | def __init__(self, p=0.2, activate_pred=False): 84 | self.p = p 85 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 86 | 87 | def __call__(self, img): 88 | if random.random() < self.p: 89 | return self.transf(img) 90 | else: 91 | return img 92 | 93 | 94 | def new_data_aug_generator(args=None): 95 | img_size = args.input_size 96 | remove_random_resized_crop = args.src 97 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 98 | primary_tfl = [] 99 | scale = (0.08, 1.0) 100 | interpolation = 'bicubic' 101 | if remove_random_resized_crop: 102 | primary_tfl = [ 103 | transforms.Resize(img_size, interpolation=3), 104 | transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'), 105 | transforms.RandomHorizontalFlip() 106 | ] 107 | else: 108 | primary_tfl = [ 109 | RandomResizedCropAndInterpolation(img_size, scale=scale, interpolation=interpolation), 110 | transforms.RandomHorizontalFlip() 111 | ] 112 | 113 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=0.8), 114 | Solarization(p=0.8), 115 | GaussianBlur(p=0.8)])] 116 | 117 | if args.color_jitter is not None and not args.color_jitter == 0: 118 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 119 | 120 | final_tfl = [ 121 | transforms.ToTensor(), 122 | transforms.Normalize( 123 | mean=torch.tensor(mean), 124 | std=torch.tensor(std)) 125 | ] 126 | 127 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) 128 | -------------------------------------------------------------------------------- /lib/ood_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from sklearn.metrics import roc_auc_score, precision_recall_curve, auc 7 | 8 | from datasets import get_dataset 9 | 10 | 11 | class OODEvaluator: 12 | def __init__(self, in_dataset, out_dataset, model, data_root, num_workers, batch_size): 13 | self.in_ds_name = in_dataset 14 | self.out_ds_name = out_dataset 15 | in_dataset = get_dataset(in_dataset + "/test", root=data_root) 16 | out_dataset = get_dataset(out_dataset + "/test", root=data_root) 17 | out_dataset.transform = in_dataset.transform 18 | datasets = [in_dataset, out_dataset] 19 | self.targets = torch.cat([torch.zeros(len(in_dataset)), torch.ones(len(out_dataset))]) 20 | self.concat_dataset = torch.utils.data.ConcatDataset(datasets) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.concat_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True 23 | ) 24 | self.model: torch.nn.Module = model 25 | 26 | def loop_over_datasets(self): 27 | self.model.eval() 28 | with torch.no_grad(): 29 | scores = [] 30 | for x, _ in self.dataloader: 31 | x = x.cuda() 32 | y_pred = F.softmax(self.model(x), dim=1) 33 | uncertainty = -(y_pred * y_pred.log()).sum(1) 34 | scores.append(uncertainty.detach().cpu().numpy()) 35 | 36 | scores = np.concatenate(scores) 37 | return scores 38 | 39 | def get_ood_metrics(self): 40 | scores = self.loop_over_datasets() 41 | auroc = roc_auc_score(y_true=self.targets, y_score=scores) 42 | precision, recall, _ = precision_recall_curve(y_true=self.targets, probas_pred=scores) 43 | aupr = auc(recall, precision) 44 | return auroc, aupr -------------------------------------------------------------------------------- /lib/pos_embed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 
| # -------------------------------------------------------- 6 | # 2D sine-cosine position embedding 7 | # References: 8 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # -------------------------------------------------------- 11 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 12 | """ 13 | grid_size: int of the grid height and width 14 | return: 15 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 16 | """ 17 | grid_h = np.arange(grid_size, dtype=np.float32) 18 | grid_w = np.arange(grid_size, dtype=np.float32) 19 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 20 | grid = np.stack(grid, axis=0) 21 | 22 | grid = grid.reshape([2, 1, grid_size, grid_size]) 23 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 24 | if cls_token: 25 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 26 | return pos_embed 27 | 28 | 29 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 30 | assert embed_dim % 2 == 0 31 | 32 | # use half of dimensions to encode grid_h 33 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 34 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 35 | 36 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 37 | return emb 38 | 39 | 40 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 41 | """ 42 | embed_dim: output dimension for each position 43 | pos: a list of positions to be encoded: size (M,) 44 | out: (M, D) 45 | """ 46 | assert embed_dim % 2 == 0 47 | omega = np.arange(embed_dim // 2, dtype=np.float64) # np.float was removed in NumPy 1.24+; float64 matches the old alias 48 | omega /= embed_dim / 2. 49 | omega = 1.
/ 10000**omega # (D/2,) 50 | 51 | pos = pos.reshape(-1) # (M,) 52 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 53 | 54 | emb_sin = np.sin(out) # (M, D/2) 55 | emb_cos = np.cos(out) # (M, D/2) 56 | 57 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 58 | return emb 59 | 60 | 61 | # -------------------------------------------------------- 62 | # Interpolate position embeddings for high-resolution 63 | # References: 64 | # DeiT: https://github.com/facebookresearch/deit 65 | # -------------------------------------------------------- 66 | def interpolate_pos_embed(model, checkpoint_model): 67 | if 'pos_embed' in checkpoint_model: 68 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 69 | embedding_size = pos_embed_checkpoint.shape[-1] 70 | num_patches = model.patch_embed.num_patches 71 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 72 | # height (== width) for the checkpoint position embedding 73 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 74 | # height (== width) for the new position embedding 75 | new_size = int(num_patches ** 0.5) 76 | # class_token and dist_token are kept unchanged 77 | if orig_size != new_size: 78 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 79 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 80 | # only the position tokens are interpolated 81 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 82 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 83 | pos_tokens = torch.nn.functional.interpolate( 84 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 85 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 86 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 87 | checkpoint_model['pos_embed'] = new_pos_embed 88 | -------------------------------------------------------------------------------- /lib/sngp.py: -------------------------------------------------------------------------------- 1 | # This implementation is based on: https://arxiv.org/abs/2006.10108 2 | # and implementation: https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py 3 | # In particular the full data inverse that avoids momentum hyper-parameters. 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def random_ortho(n, m): 12 | q, _ = torch.linalg.qr(torch.randn(n, m)) 13 | return q 14 | 15 | 16 | class RandomFourierFeatures(nn.Module): 17 | def __init__(self, in_dim, num_random_features, lengthscale=None): 18 | super().__init__() 19 | if lengthscale is None: 20 | lengthscale = math.sqrt(num_random_features / 2) 21 | 22 | self.register_buffer("lengthscale", torch.tensor(lengthscale)) 23 | 24 | if num_random_features <= in_dim: 25 | W = random_ortho(in_dim, num_random_features) 26 | else: 27 | # generate blocks of orthonormal rows which are not necessarily orthonormal 28 | # to each other (e.g. in_dim=512, num_random_features=1280 yields 512-, 512- and 256-column blocks concatenated into a 512x1280 matrix).
29 | dim_left = num_random_features 30 | ws = [] 31 | while dim_left > in_dim: 32 | ws.append(random_ortho(in_dim, in_dim)) 33 | dim_left -= in_dim 34 | ws.append(random_ortho(in_dim, dim_left)) 35 | W = torch.cat(ws, 1) 36 | 37 | # From: https://github.com/google/edward2/blob/d672c93b179bfcc99dd52228492c53d38cf074ba/edward2/tensorflow/initializers.py#L807-L817 38 | feature_norm = torch.randn(W.shape) ** 2 39 | W = W * feature_norm.sum(0).sqrt() 40 | self.register_buffer("W", W) 41 | 42 | b = torch.empty(num_random_features).uniform_(0, 2 * math.pi) 43 | self.register_buffer("b", b) 44 | 45 | def forward(self, x): 46 | k = torch.cos(x @ self.W + self.b) 47 | k = k / self.lengthscale 48 | 49 | return k 50 | 51 | 52 | class Laplace(nn.Module): 53 | def __init__( 54 | self, 55 | feature_extractor, 56 | num_deep_features, 57 | num_gp_features, 58 | normalize_gp_features, 59 | num_random_features, 60 | num_outputs, 61 | num_data, 62 | train_batch_size, 63 | mean_field_factor=None, # required for classification problems 64 | ridge_penalty=1.0, 65 | lengthscale=None, 66 | ): 67 | super().__init__() 68 | self.feature_extractor = feature_extractor 69 | self.mean_field_factor = mean_field_factor 70 | 71 | if num_gp_features > 0: 72 | self.num_gp_features = num_gp_features 73 | self.register_buffer( 74 | "random_matrix", 75 | torch.normal(0, 0.05, (num_gp_features, num_deep_features)), 76 | ) 77 | self.jl = lambda x: nn.functional.linear(x, self.random_matrix) 78 | else: 79 | self.num_gp_features = num_deep_features 80 | self.jl = nn.Identity() 81 | 82 | self.normalize_gp_features = normalize_gp_features 83 | if normalize_gp_features: 84 | self.normalize = nn.LayerNorm(self.num_gp_features) 85 | 86 | self.rff = RandomFourierFeatures( 87 | self.num_gp_features, num_random_features, lengthscale 88 | ) 89 | self.beta = nn.Linear(num_random_features, num_outputs) 90 | 91 | self.ridge_penalty = ridge_penalty 92 | 93 | self.train_batch_size = train_batch_size 94 | self.num_data = num_data 95 | self.register_buffer("seen_data", torch.tensor(0)) 96 | 97 | precision_matrix = torch.eye(num_random_features) * self.ridge_penalty 98 | self.register_buffer("precision_matrix", precision_matrix) 99 | 100 | def reset_precision_matrix(self): 101 | identity = torch.eye( 102 | self.precision_matrix.shape[0], device=self.precision_matrix.device 103 | ) 104 | self.precision_matrix = identity * self.ridge_penalty 105 | self.seen_data = torch.tensor(0) 106 | 107 | def mean_field_logits(self, logits, pred_cov): 108 | # Mean-Field approximation as alternative to MC integration of Gaussian-Softmax 109 | # Based on: https://arxiv.org/abs/2006.07584 110 | 111 | logits_scale = torch.sqrt(1.0 + torch.diag(pred_cov) * self.mean_field_factor) 112 | if self.mean_field_factor > 0: 113 | logits = logits / logits_scale.unsqueeze(-1) 114 | 115 | return logits 116 | 117 | def forward(self, x): 118 | f = self.feature_extractor(x) 119 | f_reduc = self.jl(f) 120 | if self.normalize_gp_features: 121 | f_reduc = self.normalize(f_reduc) 122 | 123 | k = self.rff(f_reduc) 124 | pred = self.beta(k) 125 | self.training = True # NOTE: hardcoded override; it forces the training branch below, so the Laplace covariance path in the else-branch is never taken while this line is in place 126 | 127 | if self.training: 128 | precision_matrix_minibatch: torch.Tensor = k.t() @ k 129 | self.precision_matrix += precision_matrix_minibatch.detach() 130 | self.seen_data += x.shape[0] 131 | else: 132 | # TODO: this is annoying for loading the model later 133 | assert self.seen_data > (self.num_data - self.train_batch_size), "not seen sufficient data" 134 | 135 | # TODO: cache this for efficiency 136 | cov =
torch.inverse(self.precision_matrix) 137 | # cov = torch.linalg.inv(self.precision_matrix) 138 | pred_cov = k @ ((cov @ k.t()) * self.ridge_penalty) 139 | if self.mean_field_factor is None: 140 | return pred, pred_cov 141 | else: 142 | pred = self.mean_field_logits(pred, pred_cov) 143 | 144 | return pred -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 4 | from torchvision import transforms 5 | 6 | _mean = torch.tensor(list(IMAGENET_DEFAULT_MEAN)) 7 | _std = torch.tensor(list(IMAGENET_DEFAULT_STD)) 8 | unnormalize = transforms.Normalize((-_mean / _std).tolist(), (1.0 / _std).tolist()) 9 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-B/version_1/config.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.7 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-B 9 | version: 1 10 | log_graph: false 11 | default_hp_metric: false 12 | prefix: '' 13 | sub_dir: null 14 | agg_key_funcs: null 15 | agg_default_func: null 16 | comment: '' 17 | purge_step: null 18 | max_queue: 10 19 | flush_secs: 120 20 | filename_suffix: '' 21 | enable_checkpointing: true 22 | callbacks: 23 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 24 | init_args: 25 | dirpath: ./lightning_logs/ViTASD-B 26 | filename: null 27 | monitor: Accuracy/val 28 | verbose: false 29 | save_last: null 30 | save_top_k: 3 31 | save_weights_only: true 32 | mode: max 33 | auto_insert_metric_name: true 34 | every_n_train_steps: null 35 | train_time_interval: null 36 | every_n_epochs: null 37 | save_on_train_epoch_end: null 38 | default_root_dir: null 39 | gradient_clip_val: null 40 | gradient_clip_algorithm: null 41 | num_nodes: 1 42 | num_processes: null 43 | devices: 4 44 | gpus: null 45 | auto_select_gpus: false 46 | tpu_cores: null 47 | ipus: null 48 | enable_progress_bar: true 49 | overfit_batches: 0.0 50 | track_grad_norm: -1 51 | check_val_every_n_epoch: 1 52 | fast_dev_run: false 53 | accumulate_grad_batches: null 54 | max_epochs: 300 55 | min_epochs: null 56 | max_steps: -1 57 | min_steps: null 58 | max_time: null 59 | limit_train_batches: null 60 | limit_val_batches: null 61 | limit_test_batches: null 62 | limit_predict_batches: null 63 | val_check_interval: null 64 | log_every_n_steps: 50 65 | accelerator: gpu 66 | strategy: null 67 | sync_batchnorm: false 68 | precision: 32 69 | enable_model_summary: true 70 | weights_save_path: null 71 | num_sanity_val_steps: 2 72 | resume_from_checkpoint: null 73 | profiler: null 74 | benchmark: null 75 | deterministic: null 76 | reload_dataloaders_every_n_epochs: 0 77 | auto_lr_find: false 78 | replace_sampler_ddp: true 79 | detect_anomaly: false 80 | auto_scale_batch_size: false 81 | plugins: null 82 | amp_backend: native 83 | amp_level: null 84 | move_metrics_to_cpu: false 85 | multiple_trainloader_mode: max_size_cycle 86 | model: 87 | batch_size: 16 88 | num_classes: 2 89 | epochs: 300 90 | attn_only: true 91 | smoothing: 0.0 92 | vis_path: ./runs/vis 93 | model: deit3_base_patch16_224 94 | input_size: 224 95 | drop: 0.0 96 | drop_path: 0.05 97 | opt: adamw 98 | weight_decay: 0.05 99 | sched: cosine 100 | lr: 1.0e-05 101 | 
warmup_lr: 1.0e-06 102 | min_lr: 1.0e-06 103 | warmup_epochs: 5 104 | cooldown_epochs: 0 105 | mixup: 0.8 106 | cutmix: 1.0 107 | mixup_prob: 1.0 108 | mixup_switch_prob: 0.5 109 | mixup_mode: batch 110 | data: 111 | batch_size: 16 112 | num_workers: 4 113 | data_root: ./datasets/Kaggle_AutismDataset2 114 | input_size: 224 115 | color_jitter: 0.3 116 | three_augment: false 117 | src: false 118 | ckpt_path: null 119 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-B/version_1/hparams.yaml: -------------------------------------------------------------------------------- 1 | attn_only: true 2 | batch_size: 16 3 | color_jitter: 0.3 4 | cooldown_epochs: 0 5 | cutmix: 1.0 6 | data_root: ./ASD/datasets/Kaggle_AutismDataset2 7 | drop: 0.0 8 | drop_path: 0.05 9 | epochs: 300 10 | input_size: 224 11 | lr: 1.0e-05 12 | min_lr: 1.0e-06 13 | mixup: 0.8 14 | mixup_mode: batch 15 | mixup_prob: 1.0 16 | mixup_switch_prob: 0.5 17 | model: deit3_base_patch16_224 18 | num_classes: 2 19 | num_workers: 4 20 | opt: adamw 21 | sched: cosine 22 | smoothing: 0.0 23 | src: false 24 | three_augment: false 25 | vis_path: ./runs/vis 26 | warmup_epochs: 5 27 | warmup_lr: 1.0e-06 28 | weight_decay: 0.05 29 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-L/version_1/config.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.7 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-L 9 | version: 1 10 | log_graph: false 11 | default_hp_metric: false 12 | prefix: '' 13 | sub_dir: null 14 | agg_key_funcs: null 15 | agg_default_func: null 16 | comment: '' 17 | purge_step: null 18 | max_queue: 10 19 | flush_secs: 120 20 | filename_suffix: '' 21 | enable_checkpointing: true 22 | callbacks: 23 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 24 | init_args: 25 | dirpath: ./lightning_logs/ViTASD-L 26 | filename: null 27 | monitor: Accuracy/val 28 | verbose: false 29 | save_last: null 30 | save_top_k: 1 31 | save_weights_only: true 32 | mode: max 33 | auto_insert_metric_name: true 34 | every_n_train_steps: null 35 | train_time_interval: null 36 | every_n_epochs: null 37 | save_on_train_epoch_end: null 38 | default_root_dir: null 39 | gradient_clip_val: null 40 | gradient_clip_algorithm: null 41 | num_nodes: 1 42 | num_processes: null 43 | devices: 2 44 | gpus: null 45 | auto_select_gpus: false 46 | tpu_cores: null 47 | ipus: null 48 | enable_progress_bar: true 49 | overfit_batches: 0.0 50 | track_grad_norm: -1 51 | check_val_every_n_epoch: 1 52 | fast_dev_run: false 53 | accumulate_grad_batches: null 54 | max_epochs: 300 55 | min_epochs: null 56 | max_steps: -1 57 | min_steps: null 58 | max_time: null 59 | limit_train_batches: null 60 | limit_val_batches: null 61 | limit_test_batches: null 62 | limit_predict_batches: null 63 | val_check_interval: null 64 | log_every_n_steps: 50 65 | accelerator: gpu 66 | strategy: null 67 | sync_batchnorm: false 68 | precision: 32 69 | enable_model_summary: true 70 | weights_save_path: null 71 | num_sanity_val_steps: 2 72 | resume_from_checkpoint: null 73 | profiler: null 74 | benchmark: null 75 | deterministic: null 76 | reload_dataloaders_every_n_epochs: 0 77 | auto_lr_find: false 78 | replace_sampler_ddp: true 79 | detect_anomaly: false 80 | auto_scale_batch_size: false 81 | plugins: null 82 | 
amp_backend: native 83 | amp_level: null 84 | move_metrics_to_cpu: false 85 | multiple_trainloader_mode: max_size_cycle 86 | model: 87 | batch_size: 16 88 | num_classes: 2 89 | epochs: 300 90 | attn_only: false 91 | smoothing: 0.0 92 | vis_path: ./runs/vis 93 | model: deit3_large_patch16_224 94 | input_size: 224 95 | drop: 0.0 96 | drop_path: 0.05 97 | opt: adamw 98 | weight_decay: 0.05 99 | sched: cosine 100 | lr: 1.0e-05 101 | warmup_lr: 1.0e-06 102 | min_lr: 1.0e-06 103 | warmup_epochs: 5 104 | cooldown_epochs: 0 105 | mixup: 0.8 106 | cutmix: 1.0 107 | mixup_prob: 1.0 108 | mixup_switch_prob: 0.5 109 | mixup_mode: batch 110 | data: 111 | batch_size: 16 112 | num_workers: 4 113 | data_root: ./ASD/datasets/Kaggle_AutismDataset2 114 | input_size: 224 115 | color_jitter: 0.3 116 | three_augment: false 117 | src: false 118 | ckpt_path: null 119 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-L/version_1/hparams.yaml: -------------------------------------------------------------------------------- 1 | attn_only: false 2 | batch_size: 16 3 | color_jitter: 0.3 4 | cooldown_epochs: 0 5 | cutmix: 1.0 6 | data_root: ./ASD/datasets/Kaggle_AutismDataset2 7 | drop: 0.0 8 | drop_path: 0.05 9 | epochs: 300 10 | input_size: 224 11 | lr: 1.0e-05 12 | min_lr: 1.0e-06 13 | mixup: 0.8 14 | mixup_mode: batch 15 | mixup_prob: 1.0 16 | mixup_switch_prob: 0.5 17 | model: deit3_large_patch16_224 18 | num_classes: 2 19 | num_workers: 4 20 | opt: adamw 21 | sched: cosine 22 | smoothing: 0.0 23 | src: false 24 | three_augment: false 25 | vis_path: ./runs/vis 26 | warmup_epochs: 5 27 | warmup_lr: 1.0e-06 28 | weight_decay: 0.05 29 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-S/version_1/config.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.7.7 2 | seed_everything: 42 3 | trainer: 4 | logger: 5 | class_path: pytorch_lightning.loggers.TensorBoardLogger 6 | init_args: 7 | save_dir: ./lightning_logs 8 | name: ViTASD-S 9 | version: 1 10 | log_graph: false 11 | default_hp_metric: false 12 | prefix: '' 13 | sub_dir: null 14 | agg_key_funcs: null 15 | agg_default_func: null 16 | comment: '' 17 | purge_step: null 18 | max_queue: 10 19 | flush_secs: 120 20 | filename_suffix: '' 21 | enable_checkpointing: true 22 | callbacks: 23 | class_path: pytorch_lightning.callbacks.ModelCheckpoint 24 | init_args: 25 | dirpath: ./lightning_logs/ViTASD-S 26 | filename: null 27 | monitor: Accuracy/val 28 | verbose: false 29 | save_last: null 30 | save_top_k: 1 31 | save_weights_only: true 32 | mode: max 33 | auto_insert_metric_name: true 34 | every_n_train_steps: null 35 | train_time_interval: null 36 | every_n_epochs: null 37 | save_on_train_epoch_end: null 38 | default_root_dir: null 39 | gradient_clip_val: null 40 | gradient_clip_algorithm: null 41 | num_nodes: 1 42 | num_processes: null 43 | devices: 4 44 | gpus: null 45 | auto_select_gpus: false 46 | tpu_cores: null 47 | ipus: null 48 | enable_progress_bar: true 49 | overfit_batches: 0.0 50 | track_grad_norm: -1 51 | check_val_every_n_epoch: 1 52 | fast_dev_run: false 53 | accumulate_grad_batches: null 54 | max_epochs: 300 55 | min_epochs: null 56 | max_steps: -1 57 | min_steps: null 58 | max_time: null 59 | limit_train_batches: null 60 | limit_val_batches: null 61 | limit_test_batches: null 62 | limit_predict_batches: null 63 | val_check_interval: null 64 | log_every_n_steps: 50 65 | accelerator: gpu 
66 | strategy: null 67 | sync_batchnorm: false 68 | precision: 32 69 | enable_model_summary: true 70 | weights_save_path: null 71 | num_sanity_val_steps: 2 72 | resume_from_checkpoint: null 73 | profiler: null 74 | benchmark: null 75 | deterministic: null 76 | reload_dataloaders_every_n_epochs: 0 77 | auto_lr_find: false 78 | replace_sampler_ddp: true 79 | detect_anomaly: false 80 | auto_scale_batch_size: false 81 | plugins: null 82 | amp_backend: native 83 | amp_level: null 84 | move_metrics_to_cpu: false 85 | multiple_trainloader_mode: max_size_cycle 86 | model: 87 | batch_size: 16 88 | num_classes: 2 89 | epochs: 300 90 | attn_only: false 91 | smoothing: 0.0 92 | vis_path: ./runs/vis 93 | model: deit3_small_patch16_224 94 | input_size: 224 95 | drop: 0.0 96 | drop_path: 0.05 97 | opt: adamw 98 | weight_decay: 0.05 99 | sched: cosine 100 | lr: 1.0e-05 101 | warmup_lr: 1.0e-06 102 | min_lr: 1.0e-06 103 | warmup_epochs: 5 104 | cooldown_epochs: 0 105 | mixup: 0.8 106 | cutmix: 1.0 107 | mixup_prob: 1.0 108 | mixup_switch_prob: 0.5 109 | mixup_mode: batch 110 | data: 111 | batch_size: 16 112 | num_workers: 4 113 | data_root: ./ASD/datasets/Kaggle_AutismDataset2 114 | input_size: 224 115 | color_jitter: 0.3 116 | three_augment: false 117 | src: false 118 | ckpt_path: null 119 | -------------------------------------------------------------------------------- /lightning_logs/ViTASD-S/version_1/hparams.yaml: -------------------------------------------------------------------------------- 1 | attn_only: false 2 | batch_size: 16 3 | color_jitter: 0.3 4 | cooldown_epochs: 0 5 | cutmix: 1.0 6 | data_root: ./ASD/datasets/Kaggle_AutismDataset2 7 | drop: 0.0 8 | drop_path: 0.05 9 | epochs: 300 10 | input_size: 224 11 | lr: 1.0e-05 12 | min_lr: 1.0e-06 13 | mixup: 0.8 14 | mixup_mode: batch 15 | mixup_prob: 1.0 16 | mixup_switch_prob: 0.5 17 | model: deit3_small_patch16_224 18 | num_classes: 2 19 | num_workers: 4 20 | opt: adamw 21 | sched: cosine 22 | smoothing: 0.0 23 | src: false 24 | three_augment: false 25 | vis_path: ./runs/vis 26 | warmup_epochs: 5 27 | warmup_lr: 1.0e-06 28 | weight_decay: 0.05 29 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.vitasd import ViTASD 2 | 3 | __all__ = ['ViTASD'] 4 | -------------------------------------------------------------------------------- /models/vitasd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.models import create_model 5 | from timm.models.vision_transformer import VisionTransformer 6 | from timm.models.layers import PatchEmbed 7 | 8 | from lib.pos_embed import interpolate_pos_embed 9 | 10 | 11 | class ViTASD(nn.Module): 12 | def __init__(self, backbone: str, num_classes, drop_rate, drop_path_rate, input_size): 13 | super(ViTASD, self).__init__() 14 | self.num_classes = num_classes 15 | self.input_size = input_size 16 | 17 | self.backbone: VisionTransformer = create_model( 18 | backbone, 19 | pretrained=True, 20 | num_classes=num_classes, 21 | drop_rate=drop_rate, 22 | drop_path_rate=drop_path_rate, 23 | drop_block_rate=None, 24 | img_size=input_size 25 | ) 26 | self.embed_dim = self.backbone.embed_dim 27 | 28 | def forward(self, x): 29 | x = self.backbone.patch_embed(x) 30 | x = x + self.backbone.pos_embed 31 | x = torch.cat([self.backbone.cls_token.expand(x.shape[0], -1, -1), x], dim=1) 32 | x = 
self.backbone.blocks(x)
33 |         x = self.backbone.norm(x)
34 |         x = x[:, 0]
35 |         x = self.backbone.head(x)
36 |         return x
37 | 
--------------------------------------------------------------------------------
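The `ViTASD` wrapper above is self-contained, so it can be smoke-tested outside of Lightning. Below is a minimal sketch (not part of the repository): it assumes the repo root is on `PYTHONPATH` and that `timm` can download the pretrained DeiT III weights; the hyperparameter values mirror the ViTASD-B config shown earlier.

```python
import torch

from models import ViTASD

# Instantiate the wrapper the same way train.py does (values taken from the
# ViTASD-B config); timm fetches the pretrained DeiT III backbone weights.
model = ViTASD(
    "deit3_base_patch16_224",
    num_classes=2,
    drop_rate=0.0,
    drop_path_rate=0.05,
    input_size=224,
)
model.eval()

# One dummy RGB image at the configured 224x224 resolution.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 2])
```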
"deit3_base_patch16_224", # Name of model to train 38 | input_size: int = 224, # images input size 39 | drop: float = 0.0, # Dropout rate 40 | drop_path: float = 0.05, # Drop path rate 41 | pretrain_path: str = "" 42 | 43 | # Optimizer parameters 44 | opt: str = "adamw", 45 | weight_decay: float = 0.05, 46 | 47 | # Learning rate schedule parameters 48 | sched: str = "cosine", 49 | lr: float = 1e-4, 50 | warmup_lr: float = 1e-6, 51 | min_lr: float = 1e-6, 52 | warmup_epochs: int = 5, # epochs to warmup LR, if scheduler supports 53 | cooldown_epochs: int = 0, # epochs to cooldown LR at min_lr, after cyclic schedule ends 54 | 55 | # Mixup parameters 56 | mixup: float = 0.8, # mixup alpha, mixup enabled if > 0 57 | cutmix: float = 1.0, # cutmix alpha, cutmix enabled if > 0. 58 | mixup_prob: float = 1.0, # Prob of performing mixup or cutmix when either/both is enabled 59 | mixup_switch_prob: float = 0.5, # Prob of switching to cutmix when both mixup and cutmix enabled 60 | mixup_mode: str = "batch", # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 61 | ): 62 | 63 | super(ViTASDLM, self).__init__() 64 | self.save_hyperparameters() 65 | 66 | self.model: torch.nn.Module = ViTASD( 67 | self.hparams.model, 68 | num_classes=self.hparams.num_classes, 69 | drop_rate=self.hparams.drop, 70 | drop_path_rate=self.hparams.drop_path, 71 | input_size=self.hparams.input_size 72 | ) 73 | 74 | if os.path.exists(pretrain_path): 75 | self._load_pretrained(pretrain_path) 76 | 77 | self._init_mixup() 78 | self._init_frozen_params() 79 | self.train_criterion = torch.nn.CrossEntropyLoss() 80 | self.valid_criterion = torch.nn.CrossEntropyLoss() 81 | self.valid_acc = Accuracy() 82 | self.auroc = AUROC(num_classes=2) 83 | self.confusion_matrix = ConfusionMatrix(num_classes=self.hparams.num_classes, normalize='true') 84 | 85 | def _init_mixup(self): 86 | self.mixup_fn = None 87 | mixup_active = self.hparams.mixup > 0 or self.hparams.cutmix > 0. 88 | if mixup_active: 89 | self.mixup_fn = Mixup( 90 | mixup_alpha=self.hparams.mixup, 91 | cutmix_alpha=self.hparams.cutmix, 92 | cutmix_minmax=None, 93 | prob=self.hparams.mixup_prob, 94 | switch_prob=self.hparams.mixup_switch_prob, 95 | mode=self.hparams.mixup_mode, 96 | label_smoothing=self.hparams.smoothing, 97 | num_classes=self.hparams.num_classes 98 | ) 99 | 100 | def _load_pretrained(self, pretrain_path): 101 | checkpoint = torch.load(pretrain_path) 102 | print("Load pre-trained checkpoint from: %s" % pretrain_path) 103 | checkpoint_model = checkpoint['state_dict'] 104 | state_dict = self.model.state_dict() 105 | for k in ['backbone.head.weight', 'backbone.head.bias']: 106 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 107 | print(f"Removing key {k} from pretrained checkpoint") 108 | del checkpoint_model[k] 109 | # interpolate position embedding 110 | interpolate_pos_embed(self.model, checkpoint_model) 111 | self.model.load_state_dict(checkpoint_model, strict=False) 112 | 113 | 114 | def _init_frozen_params(self): 115 | if self.hparams.attn_only: 116 | for name_p, p in self.model.named_parameters(): 117 | if '.attn.' 
118 |                     p.requires_grad = True
119 |                 else:
120 |                     p.requires_grad = False
121 | 
122 |             self.model.backbone.head.weight.requires_grad = True
123 |             self.model.backbone.head.bias.requires_grad = True
124 |             self.model.backbone.pos_embed.requires_grad = True
125 |             for p in self.model.backbone.patch_embed.parameters():
126 |                 p.requires_grad = True
127 | 
128 |     def forward(self, x):
129 |         return self.model(x)
130 | 
131 |     def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
132 |         samples, targets = batch
133 |         if self.mixup_fn is not None:
134 |             samples, targets = self.mixup_fn(samples, targets)
135 |         outputs = self.forward(samples)
136 |         loss = self.train_criterion(outputs, targets)
137 |         loss_value = loss.item()
138 |         self.log('Loss/train', loss_value, sync_dist=True)
139 |         return loss
140 | 
141 |     def validation_step(self, batch, batch_idx) -> STEP_OUTPUT:
142 |         samples, targets = batch
143 |         outputs = self.forward(samples)
144 |         loss = self.valid_criterion(outputs, targets)
145 |         loss_value = loss.item()
146 |         self.valid_acc.update(outputs, targets)
147 |         self.log("Accuracy/val", self.valid_acc, on_step=True, on_epoch=True, sync_dist=True)
148 |         # self.log("AUROC/val", self.auroc(outputs, targets), on_epoch=True, sync_dist=True)
149 |         self.log("Loss/val", loss_value, sync_dist=True)
150 |         return self.valid_acc
151 | 
152 | 
153 |     def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
154 |         samples, targets = batch
155 |         outputs = self.forward(samples)
156 |         self.confusion_matrix.update(outputs, targets)
157 | 
158 |     def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
159 |         opt: Optimizer = self.optimizers()
160 |         self.log("LR", opt.param_groups[0]["lr"], on_epoch=True, sync_dist=True)
161 | 
162 |     def on_test_end(self) -> None:
163 |         self.visualize_confusion_matrix()
164 | 
165 |     def configure_optimizers(self):
166 |         optimizer = create_optimizer_v2(
167 |             self.model,
168 |             opt=self.hparams.opt,
169 |             lr=self.hparams.lr,
170 |             weight_decay=self.hparams.weight_decay,
171 |         )
172 |         scheduler, _ = create_scheduler(self.hparams, optimizer)
173 |         return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
174 | 
175 |     def lr_scheduler_step(self, scheduler: Scheduler, optimizer_idx, metric) -> None:
176 |         scheduler.step(epoch=self.current_epoch)  # timm's scheduler needs the epoch value
177 | 
178 |     def visualize_confusion_matrix(self):
179 |         cf_matrix = self.confusion_matrix.compute().cpu()
180 |         categories = [f'C{i}' for i in range(self.hparams.num_classes)]
181 |         fig, ax = plt.subplots(1)
182 |         sns.heatmap(cf_matrix, annot=True, cmap='Blues', fmt='.2f', xticklabels=categories, yticklabels=categories)
183 |         ax.set_xlabel('Predicted')
184 |         ax.set_ylabel('True Label')
185 |         vis_path = Path(self.hparams.vis_path)
186 |         fig.savefig(str(vis_path / "cf_matrix.png"), dpi=200)
187 | 
188 | 
189 | def cli_main():
190 |     cli = LightningCLI(ViTASDLM,
191 |                        AutismDatasetModule,
192 |                        seed_everything_default=42,
193 |                        trainer_defaults=dict(accelerator='gpu', devices=1),
194 |                        save_config_overwrite=True,
195 |                        )
196 | 
197 | 
198 | if __name__ == "__main__":
199 |     cli_main()
200 | 
--------------------------------------------------------------------------------
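After a run completes, the LightningModule can be restored from a saved checkpoint and queried directly; this is the same `load_from_checkpoint` path that `visualization_attention.py` uses below. A minimal sanity-check sketch follows; the checkpoint filename is a placeholder for whatever the `ModelCheckpoint` callback wrote under `lightning_logs/` for your run.

```python
import torch
import torch.nn.functional as F

from train import ViTASDLM

# Placeholder path: substitute the checkpoint ModelCheckpoint saved for your
# run (it monitors Accuracy/val and keeps only the best epoch's weights).
ckpt = "lightning_logs/ViTASD-B/version_1/checkpoints/best.ckpt"
model = ViTASDLM.load_from_checkpoint(ckpt)
model.eval()

# Push a dummy batch through the restored model and inspect class probabilities.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
    probs = F.softmax(logits, dim=1)
print(probs)
```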
/train_affectnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | 
5 | from datasets import AffectNetDataModule
6 | from models import ViTASD
7 | 
8 | from pytorch_lightning import LightningModule
9 | from pytorch_lightning.cli import LightningCLI
10 | from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT
11 | from pytorch_lightning.callbacks import Callback
12 | 
13 | from torchmetrics import Accuracy, ConfusionMatrix
14 | from torch.optim import Optimizer
15 | 
16 | from timm.data import Mixup
17 | from timm.models import create_model
18 | from timm.optim import create_optimizer_v2
19 | from timm.scheduler import create_scheduler
20 | from timm.scheduler.scheduler import Scheduler
21 | 
22 | from pathlib import Path
23 | from typing import Optional
24 | 
25 | 
26 | class ViTASDLM(LightningModule):
27 |     def __init__(self,
28 |                  batch_size: int = 256,
29 |                  num_classes: int = 8,
30 |                  epochs: int = 100,
31 |                  attn_only: bool = False,
32 |                  smoothing: float = 0.0,  # Label smoothing
33 |                  vis_path: str = "./runs/vis",
34 | 
35 |                  # Model parameters
36 |                  model: str = "deit3_base_patch16_224",  # Name of model to train
37 |                  input_size: int = 224,  # images input size
38 |                  drop: float = 0.0,  # Dropout rate
39 |                  drop_path: float = 0.05,  # Drop path rate
40 |                  pretrain_path: str = "",
41 | 
42 |                  # Optimizer parameters
43 |                  opt: str = "adamw",
44 |                  weight_decay: float = 0.05,
45 | 
46 |                  # Learning rate schedule parameters
47 |                  sched: str = "cosine",
48 |                  lr: float = 1e-4,
49 |                  warmup_lr: float = 1e-6,
50 |                  min_lr: float = 1e-6,
51 |                  warmup_epochs: int = 5,  # epochs to warmup LR, if scheduler supports
52 |                  cooldown_epochs: int = 0,  # epochs to cooldown LR at min_lr, after cyclic schedule ends
53 | 
54 |                  # Mixup parameters
55 |                  mixup: float = 0.8,  # mixup alpha, mixup enabled if > 0
56 |                  cutmix: float = 1.0,  # cutmix alpha, cutmix enabled if > 0.
57 |                  mixup_prob: float = 1.0,  # Prob of performing mixup or cutmix when either/both is enabled
58 |                  mixup_switch_prob: float = 0.5,  # Prob of switching to cutmix when both mixup and cutmix enabled
59 |                  mixup_mode: str = "batch",  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
60 |                  ):
61 | 
62 |         super(ViTASDLM, self).__init__()
63 |         self.save_hyperparameters()
64 | 
65 |         self.model: torch.nn.Module = ViTASD(
66 |             self.hparams.model,
67 |             num_classes=self.hparams.num_classes,
68 |             drop_rate=self.hparams.drop,
69 |             drop_path_rate=self.hparams.drop_path,
70 |             input_size=self.hparams.input_size
71 |         )
72 |         # self.model: torch.nn.Module = ResNetASD(  # ResNetASD is not included in this repository,
73 |         #     self.hparams.model,                   # so the ViTASD model above is the runnable
74 |         #     num_classes=self.hparams.num_classes, # configuration; _init_frozen_params below also
75 |         #     drop_rate=self.hparams.drop,          # assumes a ViT backbone (pos_embed, patch_embed).
76 |         #     drop_path_rate=self.hparams.drop_path,
77 |         #     input_size=self.hparams.input_size
78 |         # )
79 | 
80 | 
81 |         self._init_mixup()
82 |         self._init_frozen_params()
83 |         self.train_criterion = torch.nn.CrossEntropyLoss()
84 |         self.valid_criterion = torch.nn.CrossEntropyLoss()
85 |         self.valid_acc = Accuracy()
86 |         self.confusion_matrix = ConfusionMatrix(num_classes=self.hparams.num_classes, normalize='true')
87 | 
88 |     def _init_mixup(self):
89 |         self.mixup_fn = None
90 |         mixup_active = self.hparams.mixup > 0 or self.hparams.cutmix > 0.
91 |         if mixup_active:
92 |             self.mixup_fn = Mixup(
93 |                 mixup_alpha=self.hparams.mixup,
94 |                 cutmix_alpha=self.hparams.cutmix,
95 |                 cutmix_minmax=None,
96 |                 prob=self.hparams.mixup_prob,
97 |                 switch_prob=self.hparams.mixup_switch_prob,
98 |                 mode=self.hparams.mixup_mode,
99 |                 label_smoothing=self.hparams.smoothing,
100 |                 num_classes=self.hparams.num_classes
101 |             )
102 | 
103 |     def _init_frozen_params(self):
104 |         if self.hparams.attn_only:
105 |             for name_p, p in self.model.named_parameters():
106 |                 if '.attn.' in name_p:
107 |                     p.requires_grad = True
108 |                 else:
109 |                     p.requires_grad = False
110 | 
111 |             self.model.backbone.head.weight.requires_grad = True
112 |             self.model.backbone.head.bias.requires_grad = True
113 |             self.model.backbone.pos_embed.requires_grad = True
114 |             for p in self.model.backbone.patch_embed.parameters():
115 |                 p.requires_grad = True
116 | 
117 |     def forward(self, x):
118 |         return self.model(x)
119 | 
120 |     def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
121 |         samples, targets = batch
122 |         if self.mixup_fn is not None:
123 |             samples, targets = self.mixup_fn(samples, targets)
124 |         outputs = self.forward(samples)
125 |         loss = self.train_criterion(outputs, targets)
126 |         loss_value = loss.item()
127 |         self.log('Loss/train', loss_value, sync_dist=True)
128 | 
129 |         return loss
130 | 
131 |     def validation_step(self, batch, batch_idx) -> STEP_OUTPUT:
132 |         samples, targets = batch
133 |         outputs = self.forward(samples)
134 |         loss = self.valid_criterion(outputs, targets)
135 |         loss_value = loss.item()
136 |         self.valid_acc.update(outputs, targets)
137 |         self.log("Accuracy/val", self.valid_acc, on_step=True, on_epoch=True, sync_dist=True)
138 |         self.log("Loss/val", loss_value, sync_dist=True)
139 | 
140 |         return self.valid_acc
141 | 
142 | 
143 |     def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
144 |         samples, targets = batch
145 |         outputs = self.forward(samples)
146 |         self.confusion_matrix.update(outputs, targets)
147 | 
148 |     def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
149 |         opt: Optimizer = self.optimizers()
150 |         self.log("LR", opt.param_groups[0]["lr"], on_epoch=True, sync_dist=True)
151 | 
152 |     def on_test_end(self) -> None:
153 |         self.visualize_confusion_matrix()
154 | 
155 |     def configure_optimizers(self):
156 |         optimizer = create_optimizer_v2(
157 |             self.model,
158 |             opt=self.hparams.opt,
159 |             lr=self.hparams.lr,
160 |             weight_decay=self.hparams.weight_decay,
161 |         )
162 |         scheduler, _ = create_scheduler(self.hparams, optimizer)
163 |         return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
164 | 
165 |     def lr_scheduler_step(self, scheduler: Scheduler, optimizer_idx, metric) -> None:
166 |         scheduler.step(epoch=self.current_epoch)  # timm's scheduler needs the epoch value
167 | 
168 |     def visualize_confusion_matrix(self):
169 |         cf_matrix = self.confusion_matrix.compute().cpu()
170 |         categories = [f'C{i}' for i in range(self.hparams.num_classes)]
171 |         fig, ax = plt.subplots(1)
172 |         sns.heatmap(cf_matrix, annot=True, cmap='Blues', fmt='.2f', xticklabels=categories, yticklabels=categories)
173 |         ax.set_xlabel('Predicted')
174 |         ax.set_ylabel('True Label')
175 |         vis_path = Path(self.hparams.vis_path)
176 |         fig.savefig(str(vis_path / "cf_matrix.png"), dpi=200)
177 | 
178 | 
179 | def cli_main():
180 |     cli = LightningCLI(ViTASDLM,
181 |                        AffectNetDataModule,
182 |                        seed_everything_default=42,
183 |                        trainer_defaults=dict(accelerator='gpu', devices=1),
184 |                        save_config_overwrite=True,
185 |                        )
186 | 
187 | 
188 | if __name__ == "__main__":
189 |     cli_main()
190 | 
--------------------------------------------------------------------------------
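`train_affectnet.py` mirrors `train.py` but pretrains the backbone on AffectNet's 8 emotion classes; the resulting Lightning checkpoint is what `train.py`'s `pretrain_path` argument is designed to consume. A minimal sketch of that hand-off follows; the checkpoint path is a placeholder, since no trained weights ship with the repository.

```python
from train import ViTASDLM

# Placeholder: the best checkpoint written by a train_affectnet.py run.
affectnet_ckpt = "lightning_logs/AffectNet-B/version_1/checkpoints/best.ckpt"

# If the file exists, _load_pretrained() drops the mismatched 8-way AffectNet
# classifier head, interpolates position embeddings, and warm-starts the
# 2-way ASD model with the remaining backbone weights.
model = ViTASDLM(
    num_classes=2,
    model="deit3_base_patch16_224",
    pretrain_path=affectnet_ckpt,
)
```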
/visualization_attention.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import seaborn as sns
6 | import torch
7 | import torch.nn.functional as F
8 | import torchshow
9 | import types
10 | 
11 | from pathlib import Path
12 | from tqdm import tqdm
13 | 
14 | from datasets import AutismDatasetModule
15 | from lib.utils import unnormalize
16 | from train import ViTASDLM
17 | 
18 | 
19 | def forward(self, x, attn_maps: list):
20 |     B, N, C = x.shape
21 |     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
22 |     q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
23 | 
24 |     attn = (q @ k.transpose(-2, -1)) * self.scale
25 |     attn = attn.softmax(dim=-1)
26 |     attn = self.attn_drop(attn)
27 | 
28 |     # Save the attention map of the first sample's first head
29 |     tensor_attn = attn[0, 0, :, :]
30 |     attn_maps.append(torch.clone(tensor_attn).cpu().numpy())
31 |     ########
32 | 
33 |     x = (attn @ v).transpose(1, 2).reshape(B, N, C)
34 |     x = self.proj(x)
35 |     x = self.proj_drop(x)
36 |     return x
37 | 
38 | 
39 | def vis_attn(mat: np.ndarray, output_path: Path, img=None, heatmap_kwargs=None, savefig_kwargs=None):
40 |     if heatmap_kwargs is None:
41 |         heatmap_kwargs = {}
42 |     if savefig_kwargs is None:
43 |         savefig_kwargs = {}
44 |     # output_path is a file path (e.g. attn_1.png); its parent directory is
45 |     # created in main() before this helper is called
46 | 
47 |     fig, ax = plt.subplots(1)
48 |     heatmap = sns.heatmap(mat, ax=ax, **heatmap_kwargs)
49 |     if img is not None:
50 |         ax.imshow(img, aspect=heatmap.get_aspect(), extent=heatmap.get_xlim() + heatmap.get_ylim(), zorder=1)
51 |     fig.savefig(output_path, dpi=200, **savefig_kwargs)
52 |     ax.clear()
53 | 
54 | 
55 | def main(hparams):
56 |     hparams['output_root'] = Path(hparams['output_root'])
57 |     data_module = AutismDatasetModule(
58 |         batch_size=1,
59 |         data_root=hparams['data_root'],
60 |         color_jitter=0,
61 |         input_size=224,
62 |         three_augment=False
63 |     )
64 |     model: ViTASDLM = ViTASDLM.load_from_checkpoint(hparams['ckpt_path'])
65 |     model.eval()
66 |     loader = data_module.predict_dataloader()
67 |     dataset = loader.dataset  # a timm ImageDataset, which provides .filename()
68 |     attn_maps = []
69 |     for i in range(hparams['num_layers']):
70 |         attn_layer: torch.nn.Module = model.model.backbone.blocks[i].attn
71 |         forward_fn = functools.partial(forward, attn_maps=attn_maps)
72 |         attn_layer.forward = types.MethodType(forward_fn, attn_layer)
73 | 
74 |     image_path = 'imgs/Non_Autistic/011.jpg'  # image to visualize, relative to data_root
75 | 
76 |     it = iter(loader)
77 |     for image_idx in tqdm(range(len(dataset))):
78 |         data_item = next(it)
79 |         cur_path = str(dataset.filename(image_idx, absolute=False)).strip()
80 |         if cur_path == image_path:
81 |             _, cls, number = cur_path.split('/')
82 |             number = number.split('.')[0]
83 |             output_path = hparams['output_root'] / image_path.split('.')[0]
84 |             if not output_path.exists():
85 |                 output_path.mkdir(parents=True)
86 | 
87 |             img, target = data_item
88 |             torchshow.save(img, str(output_path / 'input.jpg'))
89 |             with torch.no_grad():
90 |                 pred = model(img)
91 |             pred = F.softmax(pred[0], dim=0)
92 |             print(pred)
93 |             img = unnormalize(img[0]).permute(1, 2, 0).cpu().numpy()
94 | 
95 |             i = 1
96 |             for mp in attn_maps:
97 |                 # Row 0: attention from the class (CLS) token to every token
98 |                 mat = mp[0, :]
99 |                 heatmap_kwargs = dict(cmap="jet", zorder=2, cbar=False, xticklabels=False, yticklabels=False,
100 |                                       alpha=0.5)
101 |                 savefig_kwargs = dict(bbox_inches='tight', pad_inches=0.01)
102 |                 mat = mat[1: 197].reshape([14, 14])  # drop the CLS entry: 196 patch tokens -> 14 x 14 grid
103 |                 vis_attn(mat, output_path / f'attn_{i}.png',
104 |                          img, heatmap_kwargs, savefig_kwargs)
105 |                 i += 1
106 | 
107 |             attn_maps.clear()
108 | 
109 | 
110 | if __name__ == "__main__":
111 |     parser = argparse.ArgumentParser()
112 |     parser.add_argument('--output_root', type=str, default=r"")
113 |     parser.add_argument('--ckpt_path', type=str, default=r"")
114 |     parser.add_argument('--data_root', type=str, default=r"")
type=str, default=r"") 115 | parser.add_argument('--num_layers', type=int, default=12) 116 | args = vars(parser.parse_args()) 117 | main(args) 118 | --------------------------------------------------------------------------------