├── assets ├── diagram.png └── diagram-larger.png ├── config ├── checkpoint │ ├── scratch.yaml │ └── pretrained.yaml ├── experiment │ ├── mae_cifar10.yaml │ ├── mae_path.yaml │ ├── mae_derma.yaml │ ├── mae_blood.yaml │ ├── mae_imagenet.yaml │ ├── mae_tiny.yaml │ ├── mae_sampling_cifar10.yaml │ ├── mae_sampling_derma.yaml │ ├── mae_sampling_path.yaml │ ├── mae_sampling_blood.yaml │ ├── mae_sampling_tiny.yaml │ ├── pcmae_cifar10_pc.yaml │ ├── pcmae_derma_pc.yaml │ ├── pcmae_imagenet_pc.yaml │ ├── pcmae_path_pc.yaml │ ├── pcmae_tiny_pc.yaml │ ├── pcmae_blood_pc.yaml │ ├── pcmae_blood_pcsampling.yaml │ ├── pcmae_cifar10_pcsampling.yaml │ ├── pcmae_derma_pcsampling.yaml │ ├── pcmae_path_pcsampling.yaml │ └── pcmae_tiny_pcsampling.yaml ├── mode │ └── local.yaml ├── masking │ ├── pc_pc.yaml │ ├── mae.yaml │ ├── mae_sampling.yaml │ └── pc_pcsampling.yaml ├── train_defaults.yaml ├── trainer │ ├── train_eval.yaml │ ├── eval_knn.yaml │ ├── eval_lin.yaml │ ├── eval_mlp.yaml │ ├── eval_transfert.yaml │ ├── eval_fine.yaml │ └── eval_transfert_fine.yaml ├── model │ ├── vit-s.yaml │ ├── vit-t.yaml │ └── vit-b.yaml ├── base │ ├── offline.yaml │ └── online.yaml ├── transformations │ └── mae.yaml ├── dataset │ ├── imagenet.yaml │ ├── tinyimagenet.yaml │ ├── cifar10.yaml │ ├── pathmnist.yaml │ ├── dermamnist.yaml │ └── bloodmnist.yaml └── user │ └── abizeul_euler.yaml ├── src ├── __pycache__ │ ├── utils.cpython-310.pyc │ ├── plotting.cpython-310.pyc │ └── dataloader.cpython-310.pyc ├── model │ ├── __pycache__ │ │ ├── module.cpython-310.pyc │ │ ├── vit_mae.cpython-310.pyc │ │ ├── module_knn.cpython-310.pyc │ │ └── module_lin.cpython-310.pyc │ ├── module_knn.py │ ├── vit_mae.py │ ├── module.py │ ├── module_fine.py │ ├── module_lin.py │ ├── module_mlp.py │ └── module_transfert.py ├── plotting.py ├── dataloader.py └── utils.py ├── scripts ├── train.sh └── eval.sh ├── tools ├── imagenet_val.py └── pca.py ├── requirements.txt ├── main.py └── README.md /assets/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/assets/diagram.png -------------------------------------------------------------------------------- /assets/diagram-larger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/assets/diagram-larger.png -------------------------------------------------------------------------------- /config/checkpoint/scratch.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | checkpoint: 4 | checkpoint_fn: 5 | _target_: torch.load 6 | f: -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/plotting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/__pycache__/plotting.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/model/__pycache__/module.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/vit_mae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/model/__pycache__/vit_mae.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/module_knn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/model/__pycache__/module_knn.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/module_lin.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alicebizeul/pmae/HEAD/src/model/__pycache__/module_lin.cpython-310.pyc -------------------------------------------------------------------------------- /config/experiment/mae_cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: cifar10 5 | - override /model: vit-t 6 | - override /masking: mae -------------------------------------------------------------------------------- /config/experiment/mae_path.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: pathmnist 5 | - override /model: vit-t 6 | - override /masking: mae 7 | -------------------------------------------------------------------------------- /config/experiment/mae_derma.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dermamnist 5 | - override /model: vit-t 6 | - override /masking: mae 7 | -------------------------------------------------------------------------------- /config/mode/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | project: pcmae 4 | 5 | trainer: 6 | num_nodes: 1 7 | devices: 1 8 | accelerator: gpu 9 | strategy: auto 10 | -------------------------------------------------------------------------------- /config/experiment/mae_blood.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: bloodmnist 5 | - override /model: vit-t 6 | - override /masking: mae 7 | 8 | -------------------------------------------------------------------------------- /config/experiment/mae_imagenet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: imagenet 5 | - override /model: vit-b 6 | - override /masking: mae 7 | 8 | -------------------------------------------------------------------------------- /config/experiment/mae_tiny.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: tinyimagenet 5 | - override /model: vit-t 6 | - override /masking: mae 7 | 8 | -------------------------------------------------------------------------------- /config/experiment/mae_sampling_cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: cifar10 5 | - override /model: vit-t 6 | - override /masking: mae_sampling -------------------------------------------------------------------------------- /config/experiment/mae_sampling_derma.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dermamnist 5 | - override /model: vit-t 6 | - override /masking: mae_sampling 7 | -------------------------------------------------------------------------------- /config/experiment/mae_sampling_path.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: pathmnist 5 | - override /model: vit-t 6 | - override /masking: mae_sampling 7 | -------------------------------------------------------------------------------- /config/experiment/mae_sampling_blood.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: bloodmnist 5 | - override /model: vit-t 6 | - override /masking: mae_sampling 7 | 8 | -------------------------------------------------------------------------------- /config/experiment/mae_sampling_tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: tinyimagenet 5 | - override /model: vit-t 6 | - override /masking: mae_sampling 7 | 8 | -------------------------------------------------------------------------------- /config/experiment/pcmae_cifar10_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: cifar10 5 | - override /model: vit-t 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_derma_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dermamnist 5 | - override /model: vit-t 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_imagenet_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: imagenet 5 | - override /model: vit-b 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_path_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: pathmnist 5 | - override /model: vit-t 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | 
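Note: the experiment files in config/experiment only override Hydra groups (dataset, model, masking, transformations); they are composed on top of config/train_defaults.yaml when main.py is launched. As a rough, hypothetical sketch (not a file in this repository), the snippet below shows how one of these experiments could be composed and inspected with Hydra's compose API; it assumes it is run from the repository root so the relative config/ path resolves, and it mirrors the overrides used in scripts/train.sh (user=abizeul_euler, masking.pc_ratio=0.2).

# Hypothetical sketch, not part of the repository: compose an experiment config offline.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.2", config_path="config"):
    cfg = compose(
        config_name="train_defaults.yaml",
        overrides=[
            "user=abizeul_euler",        # replaces the ${oc.env:USER}_personal default entry
            "experiment=pcmae_path_pc",  # any file from config/experiment
            "masking.pc_ratio=0.2",      # same override passed in scripts/train.sh
        ],
    )

# Custom resolvers (decimal_2_percent, convert_str, compute_lr, ...) are registered in
# main.py, so print the config without resolving interpolations here.
print(OmegaConf.to_yaml(cfg.masking))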
-------------------------------------------------------------------------------- /config/experiment/pcmae_tiny_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: tinyimagenet 5 | - override /model: vit-t 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_blood_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: bloodmnist 5 | - override /model: vit-t 6 | - override /masking: pc_pc 7 | - override /transformations: mae 8 | 9 | 10 | -------------------------------------------------------------------------------- /config/experiment/pcmae_blood_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: bloodmnist 5 | - override /model: vit-t 6 | - override /masking: pc_pcsampling 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_cifar10_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: cifar10 5 | - override /model: vit-t 6 | - override /masking: pc_pcsampling 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_derma_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: dermamnist 5 | - override /model: vit-t 6 | - override /masking: pc_pcsampling 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_path_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: pathmnist 5 | - override /model: vit-t 6 | - override /masking: pc_pcsampling 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/experiment/pcmae_tiny_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataset: tinyimagenet 5 | - override /model: vit-t 6 | - override /masking: pc_pcsampling 7 | - override /transformations: mae 8 | 9 | -------------------------------------------------------------------------------- /config/masking/pc_pc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | masking: 4 | type: pc 5 | strategy: pc 6 | pc_ratio: 0.8 7 | pixel_ratio: 0.0 8 | ratio: ${decimal_2_percent:${masking.pc_ratio}} 9 | str_ratio: ${convert_str:${masking.ratio}} -------------------------------------------------------------------------------- /config/masking/mae.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | masking: 4 | type: pixel 5 | strategy: pixel 6 | pixel_ratio: 0.75 7 | pc_ratio: 0.0 8 | ratio: ${decimal_2_percent:${masking.pixel_ratio}} 9 | str_ratio: 
${convert_str:${masking.ratio}} -------------------------------------------------------------------------------- /config/masking/mae_sampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | masking: 4 | type: pixel 5 | strategy: sampling 6 | pixel_ratio: null 7 | pc_ratio: 0.0 8 | ratio: ${decimal_2_percent:${masking.pixel_ratio}} 9 | str_ratio: ${convert_str:${masking.ratio}} -------------------------------------------------------------------------------- /config/masking/pc_pcsampling.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | masking: 4 | type: pc 5 | strategy: sampling_pc 6 | pc_ratio: null 7 | pixel_ratio: 0.0 8 | ratio: ${decimal_2_percent:${masking.pc_ratio}} 9 | str_ratio: ${convert_str:${masking.ratio}} -------------------------------------------------------------------------------- /config/train_defaults.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - base: offline 5 | - user: ${oc.env:USER}_personal 6 | - dataset: tinyimagenet 7 | - masking: mae 8 | - model: vit-t 9 | - mode: local 10 | - transformations: mae 11 | - experiment: mae_tiny 12 | - checkpoint: scratch 13 | - trainer: train_eval -------------------------------------------------------------------------------- /config/checkpoint/pretrained.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | checkpoint: 4 | epoch: 800 5 | epoch_name: ${substract_one:${checkpoint.epoch}} 6 | run_path: ${checkpoint_folders.${model_name}.${data.name}.${masking.type}.${masking.strategy}.${masking.str_ratio}.${data.str_patch_size}} 7 | path: ${base_logs_dir}/${checkpoint.run_path}/checkpoints/epoch=${checkpoint.epoch_name}-train_loss=0.00.ckpt 8 | 9 | checkpoint_fn: 10 | _target_: torch.load 11 | f: ${checkpoint.path} -------------------------------------------------------------------------------- /config/trainer/train_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | trainer: 4 | max_epochs: 800 5 | 6 | evaluator: 7 | max_epochs: 100 8 | 9 | evaluated_epoch: ${trainer.max_epochs} 10 | 11 | pl_module_eval: 12 | _target_: src.model.module_lin.ViTMAE_lin 13 | base_learning_rate: 0.1 14 | betas: 0.9 15 | weight_decay: 0 16 | optimizer_name: sgd 17 | warmup: 10 18 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 19 | eval_type: ${data.eval_type} 20 | eval_fn: ${data.eval_fn} 21 | evaluated_epoch: ${trainer.max_epochs} 22 | 23 | -------------------------------------------------------------------------------- /config/trainer/eval_knn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae_adapted 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 1 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_knn.ViTMAE_knn 15 | evaluated_epoch: ${evaluated_epoch} 16 | k: 17 | _target_: numpy.arange 18 | start: 2 19 | stop: 20 20 | 21 | # overriding so that we get where we were if only just evaluating a run 22 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 23 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 24 | evaluated_epoch: ${checkpoint.epoch} 
25 | -------------------------------------------------------------------------------- /config/trainer/eval_lin.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 100 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_lin.ViTMAE_lin 15 | base_learning_rate: 0.1 16 | betas: 0.9 17 | weight_decay: 0 18 | optimizer_name: sgd 19 | warmup: 10 20 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | evaluated_epoch: ${evaluated_epoch} 24 | 25 | # overriding so that we get where we were if only just evaluating a run 26 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 27 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 28 | evaluated_epoch: ${checkpoint.epoch} 29 | -------------------------------------------------------------------------------- /config/trainer/eval_mlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 100 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_mlp.ViTMAE_mlp 15 | base_learning_rate: 0.1 16 | betas: 0.9 17 | weight_decay: 0 18 | optimizer_name: sgd 19 | warmup: 10 20 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | evaluated_epoch: ${evaluated_epoch} 24 | 25 | # overriding so that we get where we were if only just evaluating a run 26 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 27 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 28 | evaluated_epoch: ${checkpoint.epoch} 29 | -------------------------------------------------------------------------------- /config/trainer/eval_transfert.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 100 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_transfert.ViTMAE_transfert 15 | base_learning_rate: 0.1 16 | betas: 0.9 17 | weight_decay: 0 18 | optimizer_name: sgd 19 | warmup: 10 20 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | evaluated_epoch: ${evaluated_epoch} 24 | 25 | # overriding so that we get where we were if only just evaluating a run 26 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 27 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 28 | evaluated_epoch: ${checkpoint.epoch} 29 | -------------------------------------------------------------------------------- /config/trainer/eval_fine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 100 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_fine.ViTMAE_fine 15 | 
base_learning_rate: 0.001 16 | betas: [0.9,0.999] 17 | weight_decay: 0.05 18 | optimizer_name: adamw_warmup 19 | warmup: 5 20 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | evaluated_epoch: ${evaluated_epoch} 24 | 25 | # overriding so that we get where we were if only just evaluating a run 26 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 27 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 28 | evaluated_epoch: ${checkpoint.epoch} 29 | -------------------------------------------------------------------------------- /config/model/vit-s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | module: 4 | _target_: src.model.vit_mae.ViTMAEForPreTraining 5 | 6 | model_name: "vit-s" 7 | 8 | pl_module: 9 | _target_: src.model.module.ViTMAE 10 | base_learning_rate: 1.5e-4 11 | betas: [0.9, 0.95] 12 | weight_decay: 0.05 13 | optimizer_name: adamw_warmup 14 | warmup: 40 15 | eval_freq: 100 16 | eval_type: ${data.eval_type} 17 | eval_fn: ${data.eval_fn} 18 | eval_logit_fn: ${data.eval_logit_fn} 19 | learning_rate: ${compute_lr:${pl_module.base_learning_rate},${datamodule.batch_size}} 20 | 21 | module_config: 22 | _target_: transformers.ViTMAEConfig 23 | hidden_size: 384 24 | num_attention_head: 6 25 | intermediate_size: 1536 26 | norm_pix_loss: False 27 | attn_implementation: "eager" 28 | mask_ratio: ${masking.pixel_ratio} 29 | patch_size: ${data.patch_size} 30 | image_size: ${data.resolution} 31 | -------------------------------------------------------------------------------- /config/trainer/eval_transfert_fine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /checkpoint: pretrained 5 | - override /transformations: mae 6 | 7 | trainer: 8 | max_epochs: 0 9 | 10 | evaluator: 11 | max_epochs: 100 12 | 13 | pl_module_eval: 14 | _target_: src.model.module_transfert_fine.ViTMAE_transfert_fine 15 | base_learning_rate: 0.1 16 | betas: 0.9 17 | weight_decay: 0 18 | optimizer_name: sgd 19 | warmup: 10 20 | learning_rate: ${compute_lr:${pl_module_eval.base_learning_rate},${datamodule_eval.batch_size}} 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | evaluated_epoch: ${evaluated_epoch} 24 | 25 | # overriding so that we get where we were if only just evaluating a run 26 | logs_dir: ${base_logs_dir}/${checkpoint.run_path} 27 | local_dir: ${base_outputs_dir}/${checkpoint.run_path} 28 | evaluated_epoch: ${checkpoint.epoch} 29 | -------------------------------------------------------------------------------- /config/model/vit-t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | module: 4 | _target_: src.model.vit_mae.ViTMAEForPreTraining 5 | 6 | model_name: "vit-t" 7 | 8 | pl_module: 9 | _target_: src.model.module.ViTMAE 10 | base_learning_rate: 1.5e-4 11 | betas: [0.9, 0.95] 12 | weight_decay: 0.05 13 | optimizer_name: adamw_warmup 14 | warmup: 40 15 | eval_freq: 100 16 | eval_type: ${data.eval_type} 17 | eval_fn: ${data.eval_fn} 18 | eval_logit_fn: ${data.eval_logit_fn} 19 | learning_rate: ${compute_lr:${pl_module.base_learning_rate},${datamodule.batch_size}} 20 | 21 | module_config: 22 | _target_: transformers.ViTMAEConfig 23 | hidden_size: 192 24 | num_attention_head: 3 25 | intermediate_size: 768 26 | norm_pix_loss: False 27 | 
attn_implementation: "eager" 28 | mask_ratio: ${masking.pixel_ratio} 29 | patch_size: ${data.patch_size} 30 | image_size: ${data.resolution} 31 | 32 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # setup environment 4 | conda activate mae 5 | 6 | # ressource specs 7 | NUM_WORKERS=8 8 | TIME=120:00:00 9 | MEM_PER_CPU=12G 10 | MEM_PER_GPU=24G 11 | 12 | # user name 13 | USER_MACHINE=abizeul_euler 14 | 15 | # which experiment and masking ratio to run, see config/experiment 16 | EXPERIMENT=pcmae_cifar10_pc 17 | MASK=0.2 18 | 19 | # for imagenet, ensure to request a GPU with 80GB of RAM 20 | if [ "$EXPERIMENT" == "pcmae_imagenet_pc" ]; then 21 | MEM_PER_GPU=80G 22 | fi 23 | 24 | RUN_TAG=""$EXPERIMENT"_mask_"$MASK"" 25 | NAME="../$RUN_TAG" 26 | JOB="python main.py user=$USER_MACHINE experiment=$EXPERIMENT masking.pc_ratio=$MASK run_tag=$RUN_TAG" 27 | sbatch -o "$NAME" -n 1 --cpus-per-task "$NUM_WORKERS" --mem-per-cpu="$MEM_PER_CPU" --time="$TIME" -p gpu --gpus=1 --gres=gpumem:"$MEM_PER_GPU" --wrap="nvidia-smi;$JOB" 28 | 29 | #module_config.norm_pix_loss=True 30 | -------------------------------------------------------------------------------- /config/model/vit-b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | batch_size: ${divide:${data.batch_size},2} 5 | datamodule_eval: 6 | batch_size: ${divide:${data.batch_size},2} 7 | 8 | model_name: "vit-b" 9 | 10 | module: 11 | _target_: src.model.vit_mae.ViTMAEForPreTraining 12 | 13 | pl_module: 14 | _target_: src.model.module.ViTMAE 15 | base_learning_rate: 1.5e-4 16 | betas: [0.9, 0.95] 17 | weight_decay: 0.05 18 | optimizer_name: adamw_warmup 19 | warmup: 40 20 | eval_freq: 100 21 | eval_type: ${data.eval_type} 22 | eval_fn: ${data.eval_fn} 23 | eval_logit_fn: ${data.eval_logit_fn} 24 | learning_rate: ${compute_lr:${pl_module.base_learning_rate},${datamodule.batch_size}} 25 | 26 | module_config: 27 | _target_: transformers.ViTMAEConfig 28 | hidden_size: 768 29 | num_attention_head: 12 30 | intermediate_size: 1536 31 | norm_pix_loss: False 32 | attn_implementation: "eager" 33 | mask_ratio: ${masking.pixel_ratio} 34 | patch_size: ${data.patch_size} 35 | image_size: ${data.resolution} 36 | 37 | -------------------------------------------------------------------------------- /config/base/offline.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | seed: 0 3 | download: False 4 | 5 | hydra/job_logging: colorlog 6 | hydra/hydra_logging: colorlog 7 | 8 | base_logs_dir: ${base_logs_home}/logs/${oc.env:USER} 9 | base_outputs_dir: ${base_logs_home}/outputs/${oc.env:USER} 10 | run_tag: debug 11 | logs_dir: ${base_logs_dir}/${run_tag}/${now:%Y-%m-%d_%H-%M-%S} 12 | local_dir: ${base_outputs_dir}/${run_tag}/${now:%Y-%m-%d_%H-%M-%S} 13 | checkpoint_dir: ${logs_dir}/checkpoints 14 | 15 | wandb_dir: ${base_logs_home}/wandb 16 | wandb_datadir : ${base_logs_home}/wandb/artifacts 17 | wandb_cachedir: ${base_logs_home}/.cache/wandb 18 | wandb_configdir: ${base_logs_home}/.config/wandb 19 | 20 | wandb: 21 | project: ${project} 22 | notes: null 23 | tags: ${run_tag} 24 | log_model: False 25 | save_code: True 26 | reinit: True 27 | offline: True 28 | # group multi-node runs 29 | group: ${trainer.strategy} 30 | 31 | hydra: 32 | job: 33 | # when exceuting a job change to the 
logs_dir 34 | chdir: True 35 | run: 36 | dir: ${logs_dir} 37 | sweep: 38 | dir: ${logs_dir} 39 | 40 | -------------------------------------------------------------------------------- /config/base/online.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | seed: 0 3 | download: True 4 | 5 | hydra/job_logging: colorlog 6 | hydra/hydra_logging: colorlog 7 | 8 | base_logs_dir: ${base_logs_home}/logs/${oc.env:USER} 9 | base_outputs_dir: ${base_logs_home}/outputs/${oc.env:USER} 10 | run_tag: debug 11 | logs_dir: ${base_logs_dir}/${project}/${run_tag}/${now:%Y-%m-%d_%H-%M-%S} 12 | local_dir: ${base_outputs_dir}/${run_tag}/${now:%Y-%m-%d_%H-%M-%S} 13 | checkpoint_dir: ${logs_dir}/checkpoints 14 | 15 | wandb_dir: ${base_logs_home}/wandb 16 | wandb_datadir : ${base_logs_home}/wandb/artifacts 17 | wandb_cachedir: ${base_logs_home}/.cache/wandb 18 | wandb_configdir: ${base_logs_home}/.config/wandb 19 | 20 | wandb: 21 | project: ${project} 22 | notes: null 23 | tags: ${run_tag} 24 | log_model: True 25 | save_code: True 26 | reinit: True 27 | offline: False 28 | # group multi-node runs 29 | group: ${trainer.strategy} 30 | 31 | hydra: 32 | job: 33 | # when exceuting a job change to the logs_dir 34 | chdir: True 35 | run: 36 | dir: ${logs_dir} 37 | sweep: 38 | dir: ${logs_dir} 39 | 40 | -------------------------------------------------------------------------------- /tools/imagenet_val.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: imagenet_val.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | Description: Order imagenet validation images into class folders 6 | """ 7 | 8 | import os 9 | import shutil 10 | 11 | # Path to the folder where all the images are unzipped 12 | source_folder = '~/ILSVRC2012_img/val' 13 | 14 | # Destination folder where you want to organize the images by class 15 | destination_folder = '~/ILSVRC2012_img/val_fixed' 16 | 17 | # Create the destination folder if it doesn't exist 18 | if not os.path.exists(destination_folder): 19 | os.makedirs(destination_folder) 20 | 21 | # Loop through all files in the source folder 22 | for filename in os.listdir(source_folder): 23 | if filename.endswith(".JPEG"): 24 | # Extract class number from the filename (assuming the format "classNumber_imageID.JPEG") 25 | class_number = filename.split('_')[0] 26 | 27 | # Define the path for the class folder 28 | class_folder = os.path.join(destination_folder, class_number) 29 | 30 | # Create the class folder if it doesn't exist 31 | if not os.path.exists(class_folder): 32 | os.makedirs(class_folder) 33 | 34 | # Move the image to the class folder 35 | src_path = os.path.join(source_folder, filename) 36 | dest_path = os.path.join(class_folder, filename) 37 | shutil.move(src_path, dest_path) 38 | -------------------------------------------------------------------------------- /config/transformations/mae.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | transformation: 4 | train: 5 | _target_: torchvision.transforms.Compose 6 | transforms: ${compose_train_transforms} 7 | val: 8 | _target_: torchvision.transforms.Compose 9 | transforms: ${compose_val_transforms} 10 | 11 | transformation_train_list: 12 | resize: 13 | _target_: torchvision.transforms.RandomResizedCrop 14 | size: 15 | - ${data.height} 16 | - ${data.width} 17 | scale: [0.2,1.0] 18 | interpolation: 3 19 | flip: 20 | _target_: 
torchvision.transforms.RandomHorizontalFlip 21 | tensor: 22 | _target_: torchvision.transforms.ToTensor 23 | normalize: 24 | _target_: src.utils.Normalize 25 | mean: ${data.mean} 26 | std: ${data.std} 27 | 28 | transformation_val_list: 29 | resize: 30 | _target_: torchvision.transforms.Resize 31 | size: 32 | - ${data.height} 33 | - ${data.width} 34 | interpolation: 3 35 | tensor: 36 | _target_: torchvision.transforms.ToTensor 37 | normalize: 38 | _target_: src.utils.Normalize 39 | mean: ${data.mean} 40 | std: ${data.std} 41 | 42 | compose_train_transforms: 43 | - ${transformation_train_list.resize} 44 | - ${transformation_train_list.flip} 45 | - ${transformation_train_list.tensor} 46 | - ${transformation_train_list.normalize} 47 | 48 | compose_val_transforms: 49 | - ${transformation_val_list.resize} 50 | - ${transformation_val_list.tensor} 51 | - ${transformation_val_list.normalize} -------------------------------------------------------------------------------- /tools/pca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: pca.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | Description: Compute pca and store eigenvalues and principal components 6 | """ 7 | 8 | import os 9 | import sys 10 | import glob 11 | import random 12 | 13 | sys.path.append("../") 14 | 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | from PIL import Image 18 | import plotly.graph_objects as go 19 | from plotly.subplots import make_subplots 20 | import sklearn 21 | from sklearn.decomposition import PCA, IncrementalPCA 22 | import torch 23 | from torch.utils.data import DataLoader, Dataset, Subset 24 | import torchvision 25 | import torchvision.transforms as transforms 26 | 27 | random.seed(42) 28 | 29 | resolution=224 30 | name="imagenet" 31 | data_fn = torchvision.datasets.ImageFolder 32 | folder = "~/ILSVRC2012_img/train" 33 | 34 | transform = transforms.Compose([ 35 | transforms.Resize(resolution), 36 | transforms.ToTensor(), 37 | ]) 38 | 39 | trainset = data_fn(root=folder, transform=transform) 40 | 41 | # Create a DataLoader for the subset 42 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=False) 43 | data_iter = iter(trainloader) 44 | images_np, _ = next(data_iter) 45 | 46 | images_np = images_np.numpy() 47 | pca = PCA() # You can adjust the number of components 48 | 49 | # Reshape the images to (num_samples, height * width * channels) 50 | num_samples = images_np.shape[0] 51 | original_shape = images_np.shape 52 | images_np = images_np.reshape(num_samples, -1) 53 | 54 | # Standardize the flattened images 55 | mean, std = np.mean(images_np, axis=0), np.std(images_np, axis=0) 56 | images_np = (images_np - mean) / std 57 | 58 | # Perform PCA on the standardized data 59 | pca.fit(images_np) 60 | 61 | np.save(f'~/pc_matrix_ipca.npy',pca.components_) 62 | np.save(f'~/eigenvalues_ipca.npy',pca.explained_variance_) 63 | np.save(f'~/eigenvalues_ratio_ipca.npy',pca.explained_variance_ratio_) -------------------------------------------------------------------------------- /config/dataset/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: imagenet 5 | resolution: 224 6 | channels: 3 7 | height: ${data.resolution} 8 | width: ${data.resolution} 9 | patch_size: 16 10 | batch_size: 512 # 256 for 40G, 512 for 80G 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | eval_logit_fn: 14 | _target_: torch.nn.Softmax 15 | dim:
-1 16 | eval_type: multiclass 17 | str_patch_size: ${convert_str:${data.patch_size}} 18 | mean: 19 | _target_: torch.tensor 20 | data: 21 | _target_: numpy.load 22 | file: /cluster/scratch/abizeul/imagenet_mean_reshaped.npy #[0.485, 0.456, 0.406] 23 | std: 24 | _target_: torch.tensor 25 | data: 26 | _target_: numpy.load 27 | file: /cluster/scratch/abizeul/imagenet_std_reshaped.npy #[0.229, 0.224, 0.225] 28 | classes: 1000 29 | 30 | datasets: 31 | train: 32 | _target_: torchvision.datasets.ImageFolder 33 | root: ${base_data_home}/ILSVRC2012_img/train 34 | transform: ${transformation.train} 35 | 36 | val: 37 | _target_: torchvision.datasets.ImageFolder 38 | root: ${base_data_home}/ILSVRC2012_img/val/ILSVRC2012 39 | transform: ${transformation.val} 40 | 41 | test: 42 | _target_: torchvision.datasets.ImageFolder 43 | root: ${base_data_home}/ILSVRC2012_img/val/ILSVRC2012 44 | transform: ${transformation.val} 45 | 46 | datamodule: 47 | _target_: src.dataloader.DataModule 48 | batch_size: ${data.batch_size} 49 | num_workers: 8 50 | classes: ${data.classes} 51 | channels: ${data.channels} 52 | resolution: ${data.resolution} 53 | 54 | datamodule_eval: 55 | _target_: src.dataloader.DataModule 56 | batch_size: ${data.batch_size} 57 | num_workers: ${datamodule.num_workers} 58 | classes: ${data.classes} 59 | channels: ${data.channels} 60 | resolution: ${data.resolution} 61 | 62 | 63 | extradata: 64 | pcamodule: 65 | _target_ : numpy.load 66 | file: ${base_data_home}/pc_matrix_pca.npy 67 | eigenratiomodule: 68 | _target_: numpy.load 69 | file: ${base_data_home}/eigenvalues_ratio_pca.npy 70 | -------------------------------------------------------------------------------- /config/dataset/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: tinyimagenet 5 | resolution: 64 6 | channels: 3 7 | height: ${data.resolution} 8 | width: ${data.resolution} 9 | patch_size: 8 10 | batch_size: 512 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | reduction: mean 14 | eval_logit_fn: 15 | _target_: torch.nn.Softmax 16 | dim: -1 17 | eval_type: multiclass 18 | str_patch_size: ${convert_str:${data.patch_size}} 19 | mean: 20 | _target_: torch.tensor 21 | data: 22 | _target_: numpy.load 23 | file: ${datasets.train.root}/mean_reshaped.npy 24 | std: 25 | _target_: torch.tensor 26 | data: 27 | _target_: numpy.load 28 | file: ${datasets.train.root}/std_reshaped.npy 29 | classes: 200 30 | task: 1 31 | 32 | datasets: 33 | train: 34 | _target_: torchvision.datasets.ImageFolder 35 | root: ${base_data_home}/tiny-imagenet-200/train 36 | transform: ${transformation.train} 37 | 38 | val: 39 | _target_: torchvision.datasets.ImageFolder 40 | root: ${base_data_home}/tiny-imagenet-200/val 41 | transform: ${transformation.val} 42 | 43 | test: 44 | _target_: torchvision.datasets.ImageFolder 45 | root: ${base_data_home}/tiny-imagenet-200/val 46 | transform: ${transformation.val} 47 | 48 | datamodule: 49 | _target_: src.dataloader.DataModule 50 | batch_size: ${data.batch_size} 51 | num_workers: 8 52 | classes: ${data.classes} 53 | channels: ${data.channels} 54 | resolution: ${data.resolution} 55 | name: ${data.name} 56 | 57 | datamodule_eval: 58 | _target_: src.dataloader.DataModule 59 | batch_size: ${data.batch_size} 60 | num_workers: ${datamodule.num_workers} 61 | classes: ${data.classes} 62 | channels: ${data.channels} 63 | resolution: ${data.resolution} 64 | name: ${data.name} 65 | 66 | 67 | extradata: 68 | pcamodule: 69 | _target_ : 
numpy.load 70 | file: ${base_data_home}/tiny-imagenet-200/train/pc_matrix.npy 71 | 72 | eigenratiomodule: 73 | _target_: numpy.load 74 | file: ${base_data_home}/tiny-imagenet-200/train/eigenvalues_ratio.npy 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.2.1 2 | aiohappyeyeballs==2.4.3 3 | aiohttp==3.11.6 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | async-timeout==5.0.1 7 | attrs==24.2.0 8 | certifi==2024.8.30 9 | charset-normalizer==3.4.0 10 | click==8.1.7 11 | cloudpickle==3.1.0 12 | contourpy==1.3.1 13 | cycler==0.12.1 14 | docker-pycreds==0.4.0 15 | filelock==3.16.1 16 | fire==0.7.0 17 | fonttools==4.55.0 18 | frozenlist==1.5.0 19 | fsspec==2024.10.0 20 | gitdb==4.0.11 21 | GitPython==3.1.43 22 | huggingface-hub==0.26.2 23 | hydra-core==1.3.2 24 | idna==3.10 25 | imageio==2.36.0 26 | Jinja2==3.1.4 27 | joblib==1.4.2 28 | kaleido==0.2.1 29 | kiwisolver==1.4.7 30 | lazy_loader==0.4 31 | lightning-utilities==0.11.9 32 | MarkupSafe==3.0.2 33 | matplotlib==3.9.2 34 | medmnist==3.0.2 35 | mpmath==1.3.0 36 | multidict==6.1.0 37 | networkx==3.4.2 38 | numpy==2.1.3 39 | nvidia-cublas-cu12==12.1.3.1 40 | nvidia-cuda-cupti-cu12==12.1.105 41 | nvidia-cuda-nvrtc-cu12==12.1.105 42 | nvidia-cuda-runtime-cu12==12.1.105 43 | nvidia-cudnn-cu12==9.1.0.70 44 | nvidia-cufft-cu12==11.0.2.54 45 | nvidia-curand-cu12==10.3.2.106 46 | nvidia-cusolver-cu12==11.4.5.107 47 | nvidia-cusparse-cu12==12.1.0.106 48 | nvidia-nccl-cu12==2.21.5 49 | nvidia-nvjitlink-cu12==12.6.85 50 | nvidia-nvtx-cu12==12.1.105 51 | omegaconf==2.3.0 52 | packaging==24.2 53 | pandas==2.2.3 54 | pillow==11.0.0 55 | platformdirs==4.3.6 56 | plotly==5.24.1 57 | propcache==0.2.0 58 | protobuf==5.28.3 59 | psutil==6.1.0 60 | pyparsing==3.2.0 61 | python-dateutil==2.9.0.post0 62 | pytorch-lightning==2.4.0 63 | pytz==2024.2 64 | PyYAML==6.0.2 65 | regex==2024.11.6 66 | requests==2.32.3 67 | safetensors==0.4.5 68 | scikit-image==0.24.0 69 | scikit-learn==1.5.2 70 | scipy==1.14.1 71 | sentry-sdk==2.18.0 72 | setproctitle==1.3.4 73 | six==1.16.0 74 | smmap==5.0.1 75 | submitit==1.5.2 76 | sympy==1.13.1 77 | tenacity==9.0.0 78 | termcolor==2.5.0 79 | threadpoolctl==3.5.0 80 | tifffile==2024.9.20 81 | tokenizers==0.20.3 82 | torch==2.5.1+cu121 83 | torchmetrics==1.6.0 84 | torchvision==0.20.1+cu121 85 | tqdm==4.67.0 86 | transformers==4.46.3 87 | triton==3.1.0 88 | typing_extensions==4.12.2 89 | tzdata==2024.2 90 | urllib3==2.2.3 91 | wandb==0.18.7 92 | yarl==1.17.2 93 | -------------------------------------------------------------------------------- /config/dataset/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: cifar10 5 | resolution: 32 6 | height: ${data.resolution} 7 | width: ${data.resolution} 8 | channels: 3 9 | batch_size: 512 10 | patch_size: 8 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | reduction: mean 14 | eval_logit_fn: 15 | _target_: torch.nn.Softmax 16 | dim: -1 17 | eval_type: multiclass 18 | str_patch_size: ${convert_str:${data.patch_size}} 19 | mean: 20 | _target_: torch.tensor 21 | data: 22 | _target_: numpy.load 23 | file: ${base_data_home}/cifar-10-batches-py/mean_reshaped.npy 24 | std: 25 | _target_: torch.tensor 26 | data: 27 | _target_: numpy.load 28 | file: ${base_data_home}/cifar-10-batches-py/std_reshaped.npy 29 | classes: 10 30 | task: 1 31 | 32 | 
datasets: 33 | train: 34 | _target_: torchvision.datasets.CIFAR10 35 | root: ${base_data_home} 36 | train: True 37 | download: ${download} 38 | transform: ${transformation.train} 39 | 40 | val: 41 | _target_: torchvision.datasets.CIFAR10 42 | root: ${base_data_home} 43 | train: False 44 | download: ${download} 45 | transform: ${transformation.val} 46 | 47 | test: 48 | _target_: torchvision.datasets.CIFAR10 49 | root: ${base_data_home} 50 | train: False 51 | download: ${download} 52 | transform: ${transformation.val} 53 | 54 | datamodule: 55 | _target_: src.dataloader.DataModule 56 | batch_size: ${data.batch_size} 57 | num_workers: 8 58 | classes: ${data.classes} 59 | channels: ${data.channels} 60 | name: ${data.name} 61 | resolution: 62 | - ${data.height} 63 | - ${data.width} 64 | 65 | datamodule_eval: 66 | _target_: src.dataloader.DataModule 67 | batch_size: ${data.batch_size} 68 | num_workers: ${datamodule.num_workers} 69 | classes: ${data.classes} 70 | channels: ${data.channels} 71 | resolution: ${data.resolution} 72 | name: ${data.name} 73 | 74 | extradata: 75 | pcamodule: 76 | _target_ : numpy.load 77 | file: ${base_data_home}/cifar-10-batches-py/pc_matrix.npy 78 | 79 | eigenratiomodule: 80 | _target_: numpy.load 81 | file: ${base_data_home}/cifar-10-batches-py/eigenvalues_ratio.npy 82 | 83 | -------------------------------------------------------------------------------- /config/dataset/pathmnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: pathmnist 5 | resolution: 64 6 | channels: 3 7 | height: ${data.resolution} 8 | width: ${data.resolution} 9 | patch_size: 8 10 | batch_size: 512 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | reduction: mean 14 | eval_logit_fn: 15 | _target_: torch.nn.Softmax 16 | dim: -1 17 | eval_type: multiclass 18 | str_patch_size: ${convert_str:${data.patch_size}} 19 | mean: 20 | _target_: torch.tensor 21 | data: 22 | _target_: numpy.load 23 | file: ${datasets.train.root}/pathmnist_mean_reshaped.npy 24 | std: 25 | _target_: torch.tensor 26 | data: 27 | _target_: numpy.load 28 | file: ${datasets.train.root}/pathmnist_std_reshaped.npy 29 | classes: 9 30 | task: 1 31 | 32 | datasets: 33 | train: 34 | _target_: medmnist.PathMNIST 35 | root: ${base_data_home}/medmnist 36 | split: train 37 | download: ${download} 38 | size: ${data.resolution} 39 | transform: ${transformation.train} 40 | 41 | val: 42 | _target_: medmnist.PathMNIST 43 | root: ${base_data_home}/medmnist 44 | split: val 45 | download: ${download} 46 | size: ${data.resolution} 47 | transform: ${transformation.train} 48 | 49 | test: 50 | _target_: medmnist.PathMNIST 51 | root: ${base_data_home}/medmnist 52 | split: val 53 | download: ${download} 54 | size: ${data.resolution} 55 | transform: ${transformation.train} 56 | 57 | datamodule: 58 | _target_: src.dataloader.DataModule 59 | batch_size: ${data.batch_size} 60 | num_workers: 8 61 | classes: ${data.classes} 62 | channels: ${data.channels} 63 | resolution: ${data.resolution} 64 | name: ${data.name} 65 | 66 | datamodule_eval: 67 | _target_: src.dataloader.DataModule 68 | batch_size: ${data.batch_size} 69 | num_workers: ${datamodule.num_workers} 70 | classes: ${data.classes} 71 | channels: ${data.channels} 72 | resolution: ${data.resolution} 73 | name: ${data.name} 74 | 75 | 76 | extradata: 77 | pcamodule: 78 | _target_ : numpy.load 79 | file: ${base_data_home}/medmnist/pathmnist_pc_matrix.npy 80 | 81 | eigenratiomodule: 82 | _target_: numpy.load 83 | 
file: ${base_data_home}/medmnist/pathmnist_eigenvalues_ratio.npy 84 | -------------------------------------------------------------------------------- /config/dataset/dermamnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: dermamnist 5 | resolution: 64 6 | channels: 3 7 | height: ${data.resolution} 8 | width: ${data.resolution} 9 | patch_size: 8 10 | batch_size: 512 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | reduction: mean 14 | eval_logit_fn: 15 | _target_: torch.nn.Softmax 16 | dim: -1 17 | eval_type: multiclass 18 | str_patch_size: ${convert_str:${data.patch_size}} 19 | mean: 20 | _target_: torch.tensor 21 | data: 22 | _target_: numpy.load 23 | file: ${datasets.train.root}/dermamnist_mean_reshaped.npy 24 | std: 25 | _target_: torch.tensor 26 | data: 27 | _target_: numpy.load 28 | file: ${datasets.train.root}/dermamnist_std_reshaped.npy 29 | classes: 7 30 | task: 1 31 | 32 | datasets: 33 | train: 34 | _target_: medmnist.DermaMNIST 35 | root: ${base_data_home}/medmnist 36 | split: train 37 | download: ${download} 38 | size: ${data.resolution} 39 | transform: ${transformation.train} 40 | 41 | val: 42 | _target_: medmnist.DermaMNIST 43 | root: ${base_data_home}/medmnist 44 | split: val 45 | download: ${download} 46 | size: ${data.resolution} 47 | transform: ${transformation.train} 48 | 49 | test: 50 | _target_: medmnist.DermaMNIST 51 | root: ${base_data_home}/medmnist 52 | split: val 53 | download: ${download} 54 | size: ${data.resolution} 55 | transform: ${transformation.train} 56 | 57 | datamodule: 58 | _target_: src.dataloader.DataModule 59 | batch_size: ${data.batch_size} 60 | num_workers: 8 61 | classes: ${data.classes} 62 | channels: ${data.channels} 63 | resolution: ${data.resolution} 64 | name: ${data.name} 65 | 66 | datamodule_eval: 67 | _target_: src.dataloader.DataModule 68 | batch_size: ${data.batch_size} 69 | num_workers: ${datamodule.num_workers} 70 | classes: ${data.classes} 71 | channels: ${data.channels} 72 | resolution: ${data.resolution} 73 | name: ${data.name} 74 | 75 | extradata: 76 | pcamodule: 77 | _target_ : numpy.load 78 | file: ${base_data_home}/medmnist/dermamnist_pc_matrix.npy 79 | 80 | eigenratiomodule: 81 | _target_: numpy.load 82 | file: ${base_data_home}/medmnist/dermamnist_eigenvalues_ratio.npy 83 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # setup environment 4 | conda activate mae 5 | 6 | # ressource specs 7 | NUM_WORKERS=8 8 | TIME=4:00:00 9 | MEM_PER_CPU=2G 10 | MEM_PER_GPU=12G 11 | 12 | # user name 13 | USER_MACHINE=abizeul_euler 14 | 15 | # which experiment and masking ratio to run, see config/experiment 16 | EXPERIMENT=pcmae_cifar10_pc 17 | MASK=0.2 18 | EPOCH=800 19 | 20 | # for imagenet, ensure to request a GPU with 80GB of RAM 21 | if [ "$EXPERIMENT" == "pcmae_imagenet_pc" ]; then 22 | MEM_PER_GPU=80G 23 | fi 24 | 25 | # Linear probe 26 | RUN_TAG=""$EXPERIMENT"_mask_"$MASK"_eval_"$EPOCH"_lin" 27 | NAME="../$RUN_TAG" 28 | JOB="python main.py user=abizeul_euler experiment=$EXPERIMENT masking.pc_ratio=$MASK trainer=eval_lin checkpoint=pretrained checkpoint.epoch=$EPOCH run_tag=$RUN_TAG" 29 | sbatch -o "$NAME" -n 1 --cpus-per-task "$NUM_WORKERS" --mem-per-cpu="$MEM_PER_CPU" --time="$TIME" -p gpu --gpus=1 --gres=gpumem:"$MEM_PER_GPU" --wrap="nvidia-smi;$JOB" 30 | 31 | # MLP probe 32 | 
RUN_TAG=""$EXPERIMENT"_mask_"$MASK"_eval_"$EPOCH"_mlp" 33 | NAME="../$RUN_TAG" 34 | JOB="python main.py user=abizeul_euler experiment=$EXPERIMENT masking.pc_ratio=$MASK trainer=eval_mlp checkpoint=pretrained checkpoint.epoch=$EPOCH run_tag=$RUN_TAG" 35 | sbatch -o "$NAME" -n 1 --cpus-per-task "$NUM_WORKERS" --mem-per-cpu="$MEM_PER_CPU" --time="$TIME" -p gpu --gpus=1 --gres=gpumem:"$MEM_PER_GPU" --wrap="nvidia-smi;$JOB" 36 | 37 | # Fine-tuning 38 | RUN_TAG=""$EXPERIMENT"_mask_"$MASK"_eval_"$EPOCH"_fine" 39 | NAME="../$RUN_TAG" 40 | JOB="python main.py user=abizeul_euler experiment=$EXPERIMENT masking.pc_ratio=$MASK trainer=eval_fine checkpoint=pretrained checkpoint.epoch=$EPOCH run_tag=$RUN_TAG" 41 | sbatch -o "$NAME" -n 1 --cpus-per-task "$NUM_WORKERS" --mem-per-cpu="$MEM_PER_CPU" --time="$TIME" -p gpu --gpus=1 --gres=gpumem:"$MEM_PER_GPU" --wrap="nvidia-smi;$JOB" 42 | 43 | # k-Nearest Neighbors 44 | RUN_TAG=""$EXPERIMENT"_mask_"$MASK"_eval_"$EPOCH"_knn" 45 | NAME="../$RUN_TAG" 46 | JOB="python main.py user=abizeul_euler experiment=$EXPERIMENT masking.pc_ratio=$MASK trainer=eval_knn checkpoint=pretrained checkpoint.epoch=$EPOCH run_tag=$RUN_TAG" 47 | sbatch -o "$NAME" -n 1 --cpus-per-task "$NUM_WORKERS" --mem-per-cpu="$MEM_PER_CPU" --time="$TIME" -p gpu --gpus=1 --gres=gpumem:"$MEM_PER_GPU" --wrap="nvidia-smi;$JOB" 48 | 49 | 50 | -------------------------------------------------------------------------------- /config/dataset/bloodmnist.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data: 4 | name: bloodmnist 5 | resolution: 64 6 | height: ${data.resolution} 7 | width: ${data.resolution} 8 | channels: 3 9 | batch_size: 512 10 | patch_size: 8 11 | eval_fn: 12 | _target_: torch.nn.CrossEntropyLoss 13 | reduction: mean 14 | eval_logit_fn: 15 | _target_: torch.nn.Softmax 16 | dim: -1 17 | eval_type: multiclass 18 | str_patch_size: ${convert_str:${data.patch_size}} 19 | mean: 20 | _target_: torch.tensor 21 | data: 22 | _target_: numpy.load 23 | file: ${datasets.train.root}/bloodmnist_mean_reshaped.npy 24 | std: 25 | _target_: torch.tensor 26 | data: 27 | _target_: numpy.load 28 | file: ${datasets.train.root}/bloodmnist_std_reshaped.npy 29 | classes: 8 30 | task: 1 31 | 32 | datasets: 33 | train: 34 | _target_: medmnist.BloodMNIST 35 | root: ${base_data_home}/medmnist 36 | split: train 37 | download: ${download} 38 | size: ${data.resolution} 39 | transform: ${transformation.train} 40 | 41 | val: 42 | _target_: medmnist.BloodMNIST 43 | root: ${base_data_home}/medmnist 44 | split: val 45 | download: ${download} 46 | size: ${data.resolution} 47 | transform: ${transformation.train} 48 | 49 | test: 50 | _target_: medmnist.BloodMNIST 51 | root: ${base_data_home}/medmnist 52 | split: val 53 | download: ${download} 54 | size: ${data.resolution} 55 | transform: ${transformation.train} 56 | 57 | datamodule: 58 | _target_: srd.dataloader.DataModule 59 | batch_size: ${data.batch_size} 60 | num_workers: 8 61 | classes: ${data.classes} 62 | channels: ${data.channels} 63 | resolution: ${data.resolution} 64 | name: ${data.name} 65 | 66 | datamodule_eval: 67 | _target_: src.dataloader.DataModule 68 | batch_size: ${data.batch_size} 69 | num_workers: ${datamodule.num_workers} 70 | classes: ${data.classes} 71 | channels: ${data.channels} 72 | resolution: ${data.resolution} 73 | name: ${data.name} 74 | 75 | extradata: 76 | pcamodule: 77 | _target_ : numpy.load 78 | file: ${base_data_home}/medmnist/bloodmnist_pc_matrix.npy 79 | 80 | 
eigenratiomodule: 81 | _target_: numpy.load 82 | file: ${base_data_home}/medmnist/bloodmnist_eigenvalues_ratio.npy 83 | 84 | -------------------------------------------------------------------------------- /src/plotting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: main.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | import plotly.graph_objs as go 8 | import plotly.io as pio 9 | from plotly.subplots import make_subplots 10 | import matplotlib.pyplot as plt 11 | import os 12 | 13 | def plot_loss(loss,name_loss,save_dir,name_file=""): 14 | # Create a subplot with 1 row and 2 columns 15 | os.makedirs(save_dir, exist_ok=True) 16 | fig = make_subplots(rows=1, cols=2, subplot_titles=('Linear Scale', 'Log Scale')) 17 | 18 | # First subplot (Linear scale) 19 | fig.add_trace( 20 | go.Scatter(y=loss, mode='lines', name='Linear Scale'), 21 | row=1, col=1 22 | ) 23 | 24 | # Second subplot (Log scale) 25 | fig.add_trace( 26 | go.Scatter(y=loss, mode='lines', name='Log Scale'), 27 | row=1, col=2 28 | ) 29 | 30 | # Set the y-axis of the second subplot to log scale 31 | fig.update_yaxes(type="log", row=1, col=2) 32 | 33 | # Set common y-axis label for both plots 34 | fig.update_yaxes(title_text=name_loss, row=1, col=1) 35 | 36 | # Update layout for better presentation 37 | fig.update_layout( 38 | height=500, # Adjust the figure height 39 | width=1000, # Adjust the figure width 40 | title_text="Loss Tracking", 41 | ) 42 | 43 | # Save the figure as a PNG file 44 | output_path = os.path.join(save_dir, f"loss{name_file}.png") 45 | fig.write_image(output_path, scale=2) 46 | 47 | def plot_performance(x,y,save_dir,name=""): 48 | os.makedirs(save_dir, exist_ok=True) 49 | # Create a Plotly figure 50 | fig = go.Figure() 51 | 52 | # Add a line plot 53 | fig.add_trace(go.Scatter( 54 | x=[a+1 for a in list(x)], # X-axis: class labels 55 | y=list(y), # Y-axis: accuracy values 56 | mode='lines+markers', 57 | name="Downstream prediction" 58 | )) 59 | 60 | # Set axis labels and title 61 | fig.update_layout( 62 | title='Model Performance', 63 | xaxis_title='Epochs', 64 | yaxis_title='Accuracy', 65 | font=dict(size=15) 66 | ) 67 | 68 | # Save the figure as a PNG file 69 | output_path = os.path.join(save_dir,f"performance_{name}.png") 70 | fig.write_image(output_path, scale=2) 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /config/user/abizeul_euler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | base_logs_home: /cluster/scratch/abizeul/mae_alice 4 | base_data_home: /cluster/scratch/abizeul 5 | 6 | checkpoint_folders: 7 | vit-t: 8 | tinyimagenet: 9 | pixel: 10 | pixel: 11 | _75: 12 | _8: mae_tiny_pc_0.0_pixel_0.75/2024-08-27_19-14-36 13 | sampling: 14 | _None: 15 | _8: mae_sampling_tiny/2024-09-24_08-25-21 16 | pc: 17 | pc: 18 | _20: 19 | _8: pcmae_tiny_pc_pc_pc_0.2_lossA/2024-12-05_23-10-16 20 | sampling_pc: 21 | _None: 22 | _8: pcmae_tiny_pcsampling_lossA/2024-12-08_16-34-35 23 | cifar10: 24 | pixel: 25 | pixel: 26 | _75: 27 | _8: mae_cifar10_pc_pc_0.2/2025-02-10_10-57-45 28 | sampling: 29 | _None: 30 | _8: mae_sampling_cifar10/2024-09-25_19-25-28 31 | pc: 32 | pc: 33 | _20: 34 | _8: pcmae_cifar10_pc_pc_pc_0.2/2025-02-10_11-41-15 35 | sampling_pc: 36 | _None: 37 | _8: pcmae_cifar10_pcsampling_lossA/2024-12-08_10-35-07 38 | bloodmnist: 39 | pixel: 40 | pixel: 41 | _75: 42 | _8: 
mae_blood_pc_0.0_pixel_0.75/2024-08-27_19-14-36 43 | sampling: 44 | _None: 45 | _8: mae_sampling_blood/2024-09-23_19-22-33 46 | pc: 47 | pc: 48 | _20: 49 | _8: pcmae_blood_pc_pc_pc_0.2_lossA/2024-12-07_17-15-47 50 | sampling_pc: 51 | _None: 52 | _8: pcmae_blood_pcsampling_lossA/2024-12-08_16-34-35 53 | dermamnist: 54 | pixel: 55 | pixel: 56 | _75: 57 | _8: mae_derma_pc_0.0_pixel_0.75/2024-08-27_19-14-57 58 | sampling: 59 | _None: 60 | _8: mae_sampling_derma/2024-09-23_19-22-34 61 | pc: 62 | pc: 63 | _20: 64 | _8: pcmae_derma_pc_pc_pc_0.2_lossA/2024-12-06_04-00-36 65 | sampling_pc: 66 | _None: 67 | _8: pcmae_derma_pcsampling_lossA/2024-12-08_10-35-07 68 | pathmnist: 69 | pixel: 70 | pixel: 71 | _75: 72 | _8: mae_path_pc_0.0_pixel_0.75/2024-08-27_19-15-43 73 | sampling: 74 | _None: 75 | _8: mae_sampling_path/2024-09-24_08-25-21 76 | pc: 77 | pc: 78 | _20: 79 | _8: pcmae_path_pc_pc_pc_0.2_lossA/2024-12-06_06-42-49 80 | sampling_pc: 81 | _None: 82 | _8: pcmae_path_pcsampling_lossA/2024-12-08_16-34-35 83 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: dataloader.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import random 10 | import time 11 | from typing import Optional 12 | 13 | # Third-party library imports 14 | import numpy as np 15 | import pytorch_lightning as pl 16 | import torch 17 | import torch.nn as nn 18 | from torch.utils.data import DataLoader, Dataset 19 | from torchvision import datasets, transforms 20 | from torchvision.transforms.functional import InterpolationMode 21 | import torchvision 22 | 23 | # Hydra imports 24 | from hydra.utils import instantiate 25 | 26 | USER_NAME = os.environ.get("USER") 27 | 28 | class PairedDataset(Dataset): 29 | def __init__(self, dataset, masking, extra_data): 30 | 31 | self.dataset = dataset 32 | self.masking = masking 33 | 34 | self.is_pc_mask = self.masking.type == "pc" 35 | 36 | if self.is_pc_mask: 37 | assert "eigenratiomodule" in extra_data and "pcamodule" in extra_data 38 | self.eigenvalues = np.array(extra_data['eigenratiomodule']) 39 | self.cum_eigenvalues = np.cumsum(self.eigenvalues) 40 | self.pc_mask = None 41 | self.find_threshold = lambda eigenvalues ,ratio: np.argmin(np.abs(np.cumsum(eigenvalues) - ratio)) 42 | self.get_pcs_index = np.arange 43 | else: 44 | self.pc_mask = 0 45 | 46 | def __len__(self): 47 | return len(self.dataset) 48 | 49 | def __getitem__(self, idx): 50 | 51 | # Load the images 52 | img, y = self.dataset[idx] 53 | pc_mask = self.pc_mask 54 | 55 | if self.masking.type == "pc": 56 | 57 | if self.masking.strategy == "sampling_pc": 58 | index = torch.randperm(self.eigenvalues.shape[0]).numpy() 59 | pc_ratio = random.uniform(0.1, 0.9) 60 | threshold = self.find_threshold(self.eigenvalues[index],pc_ratio) 61 | pc_mask = index[:threshold] 62 | 63 | elif self.masking.strategy == "pc": 64 | index = np.random.permutation(self.eigenvalues.shape[0]) 65 | threshold = self.find_threshold(self.eigenvalues[index],self.masking.pc_ratio) 66 | pc_mask = index[:threshold] 67 | 68 | elif self.masking.type == "pixel": 69 | if self.masking.strategy == "sampling": 70 | pc_mask = random.uniform(0.1, 0.9) 71 | 72 | return img, y, pc_mask 73 | 74 | class DataModule(pl.LightningDataModule): 75 | def __init__( 76 | self, 77 | data, 78 | masking, 79 | extra_data =None, 80 | batch_size: int = 512, 81 | 
num_workers: int = 8, 82 | classes: int =10, 83 | channels: int =3, 84 | resolution: int =32, 85 | name: str =None, 86 | ): 87 | super().__init__() 88 | self.batch_size = batch_size 89 | self.num_workers = num_workers 90 | self.num_classes = classes 91 | self.input_channels = channels 92 | self.image_size = resolution 93 | self.masking = masking 94 | self.extra_data = extra_data 95 | self.datasets = data 96 | self.name = name 97 | 98 | def setup(self, stage): 99 | self.train_dataset = PairedDataset( 100 | dataset=self.datasets["train"], 101 | masking=self.masking, 102 | extra_data=self.extra_data 103 | ) 104 | 105 | self.val_dataset = self.datasets["val"] 106 | self.num_val_samples = len(self.val_dataset) 107 | self.test_dataset = self.datasets["test"] 108 | 109 | def collate_fn(self,batch): 110 | """ 111 | Custom collate function to handle variable-sized pc_mask. 112 | Uses the pc_mask of the first sample for the whole batch, so a single set of principal components is masked per batch. 113 | """ 114 | 115 | imgs, labels, pc_masks = zip(*batch) 116 | 117 | imgs = torch.stack(imgs) 118 | labels = torch.tensor(labels) 119 | pc_masks = torch.tensor(pc_masks[0]) # one shared mask per batch 120 | return imgs, labels, pc_masks 121 | 122 | def train_dataloader(self) -> DataLoader: 123 | training_loader = DataLoader( 124 | self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False, num_workers=self.num_workers, collate_fn=self.collate_fn if (self.masking.type == "pc" and self.masking.strategy in ["sampling_pc","pc"]) else None 125 | ) 126 | return training_loader 127 | 128 | def val_dataloader(self): 129 | loader = DataLoader( 130 | self.val_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers 131 | ) 132 | return loader 133 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: main.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | import hydra 8 | from hydra.utils import instantiate 9 | from hydra.core.hydra_config import HydraConfig 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn as nn 15 | import torchvision 16 | import torchvision.datasets 17 | from torchvision.datasets import CIFAR10 18 | import pytorch_lightning as pl 19 | import torch.nn.functional as F 20 | from pytorch_lightning.callbacks import ModelCheckpoint 21 | 22 | import os 23 | import logging 24 | import numpy as np 25 | import random 26 | import matplotlib.pyplot as plt 27 | import csv 28 | import medmnist 29 | import numpy 30 | import transformers 31 | from transformers import ViTMAEConfig 32 | 33 | import src.model 34 | from src.model.module import ViTMAE 35 | from src.model.module_lin import ViTMAE_lin 36 | from src.model.module_knn import ViTMAE_knn 37 | from src.model.vit_mae import ViTMAEForPreTraining 38 | from src.dataloader import DataModule 39 | from src.utils import ( 40 | print_config, 41 | setup_wandb, 42 | get_git_hash, 43 | load_checkpoints, 44 | Normalize 45 | ) 46 | 47 | # Configure logging 48 | log = logging.getLogger(__name__) 49 | git_hash = get_git_hash() 50 | def create_lambda_transform(mean, std): 51 | return torchvision.transforms.Lambda(lambda sample: (sample - mean) / std) 52 | OmegaConf.register_new_resolver('divide', lambda a, b: int(int(a)/b)) 53 | OmegaConf.register_new_resolver('multiply', lambda a, b: int(int(a)*b)) 54 | OmegaConf.register_new_resolver("compute_lr", 
lambda base_lr, batch_size: base_lr * (batch_size / 256)) 55 | OmegaConf.register_new_resolver("decimal_2_percent", lambda decimal: int(100*decimal) if decimal is not None else decimal) 56 | OmegaConf.register_new_resolver("convert_str", lambda number: "_"+str(number)) 57 | OmegaConf.register_new_resolver("substract_one", lambda number: number-1) 58 | OmegaConf.register_new_resolver('to_tuple', lambda a, b, c: (a,b,c)) 59 | OmegaConf.register_new_resolver('as_tuple', lambda *args: tuple(args)) 60 | 61 | # Main function 62 | @hydra.main(version_base="1.2", config_path="config", config_name="train_defaults.yaml") 63 | def main(config: DictConfig) -> None: 64 | 65 | # Setup 66 | print_config(config) 67 | pl.seed_everything(config.seed) 68 | hydra_core_config = HydraConfig.get() 69 | wandb_logger = setup_wandb( 70 | config, log, git_hash, {"job_id": hydra_core_config.job.name} 71 | ) 72 | 73 | # Creating data 74 | datamodule = instantiate( 75 | config.datamodule, 76 | data = config.datasets, 77 | masking = config.masking, 78 | extra_data = config.extradata, 79 | ) 80 | 81 | # Creating model 82 | vit_config = instantiate(config.module_config) 83 | vit = instantiate(config.module,vit_config) 84 | model_train = instantiate( 85 | config.pl_module, 86 | model=vit, 87 | datamodule = datamodule, 88 | save_dir=config.local_dir 89 | ) 90 | model_train = load_checkpoints(model_train, config.checkpoint_fn) 91 | 92 | # Model checkpointing 93 | checkpoint_callback = ModelCheckpoint( 94 | dirpath=config.checkpoint_dir, # Directory in which to save checkpoints 95 | filename='{epoch:02d}-{train_loss:.2f}', # Filename format 96 | save_top_k=-1, # Save all checkpoints 97 | save_weights_only=False, # Save the full model (True for weights only) 98 | every_n_epochs=100 # Save a checkpoint every 100 epochs 99 | ) 100 | 101 | # Running training (with eval on masked data to track behavior/convergence) 102 | trainer_configs = OmegaConf.to_container(config.trainer, resolve=True) 103 | trainer = pl.Trainer( 104 | **trainer_configs, 105 | logger=wandb_logger, 106 | enable_checkpointing = True, 107 | num_sanity_val_steps=0, 108 | callbacks=[checkpoint_callback], 109 | check_val_every_n_epoch=config.pl_module.eval_freq, 110 | ) 111 | print("------------------------- Start Training") 112 | trainer.fit(model_train, datamodule=datamodule) 113 | print("------------------------- End Training") 114 | 115 | 116 | # Final evaluation: original data, no pixel or pc masking, MAE eval protocol 117 | eval_configs = OmegaConf.to_container(config.evaluator, resolve=True) 118 | datamodule = instantiate( 119 | config.datamodule_eval, 120 | masking = {"type":"pixel","strategy":"pixel"}, 121 | data = config.datasets, 122 | ) 123 | 124 | del trainer, vit 125 | for i in range(config.data.task): 126 | model_eval = instantiate( 127 | config=config.pl_module_eval, 128 | model=model_train.model, 129 | datamodule=datamodule, 130 | save_dir=config.local_dir, 131 | task=i 132 | ) 133 | evaluator = pl.Trainer( 134 | **eval_configs, 135 | logger=wandb_logger, 136 | enable_checkpointing = False, 137 | num_sanity_val_steps=0, 138 | check_val_every_n_epoch=1 139 | ) 140 | print(f"------------------------- Start Evaluation: lin probe for task {i}") 141 | evaluator.fit(model_eval, datamodule=datamodule) 142 | print(f"------------------------- End Evaluation: lin probe for task {i}") 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | 148 | -------------------------------------------------------------------------------- /src/model/module_knn.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Module Name: module_knn.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | import csv 12 | 13 | # Third-party library imports 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torchmetrics 21 | from torch import Tensor 22 | from torch.nn.parameter import Parameter 23 | from torchvision.models import resnet18 24 | import wandb 25 | 26 | # Local imports 27 | from ..plotting import plot_loss, plot_performance 28 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 29 | 30 | # Lightning Module definition 31 | class ViTMAE_knn(pl.LightningModule): 32 | def __init__( 33 | self, 34 | model, 35 | k=[3,4,5,6,7,8,9,10,11,12,4,16,18,20], 36 | save_dir: str =None, 37 | evaluated_epoch: int =800, 38 | datamodule: Optional[pl.LightningDataModule] = None, 39 | task :int =0, 40 | ): 41 | super().__init__() 42 | self.model = model # Your base model (e.g., ResNet or any other embedding model) 43 | self.model.config.mask_ratio = 0.0 44 | self.model.vit.embeddings.config.mask_ratio=0.0 45 | self.task=task 46 | 47 | self.k = k # Number of nearest neighbors 48 | 49 | self.datamodule = datamodule 50 | self.num_classes = datamodule.num_classes # Number of classes for classification 51 | 52 | self.online_val_accuracy = torchmetrics.Accuracy( 53 | task="multiclass", num_classes=self.num_classes, top_k=1 54 | ) 55 | self.classifier = nn.Linear(model.config.hidden_size, self.num_classes) 56 | 57 | self.data_embeddings = [] 58 | self.data_labels = [] 59 | 60 | self.save_dir = save_dir 61 | self.performance = {} 62 | self.evaluated_epoch = evaluated_epoch 63 | 64 | def forward(self, x): 65 | return self.model(x) 66 | 67 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int =None): 68 | if stage == "train": 69 | img, y, _ = batch 70 | cls, _ = self.model(img,return_rep=True) 71 | self.data_embeddings.append(cls.detach()) 72 | self.data_labels.append(y) 73 | 74 | return None 75 | else: 76 | # Validation logic 77 | img, y = batch 78 | cls, _ = self.model(img,return_rep=True) 79 | distances = torch.cdist(cls.detach(), self.data_embeddings, p=2) # L2 distance 80 | accuracy_metric = getattr(self, f"online_val_accuracy") 81 | 82 | # Get the indices of the k nearest neighbors 83 | for k in self.k: 84 | _, indices = torch.topk(distances, k=k, dim=1, largest=False) 85 | pred_labels, _ = torch.mode(self.data_labels[indices].squeeze()) 86 | accuracy_metric(pred_labels.squeeze(), y.squeeze()) 87 | 88 | self.log( 89 | f"final_val_accuracy_{self.evaluated_epoch}_knn_k_{k}", 90 | accuracy_metric, 91 | prog_bar=True, 92 | sync_dist=True, 93 | on_epoch=True, 94 | on_step=False, 95 | ) 96 | 97 | if batch_idx == 0 and (k not in list(self.performance.keys())): 98 | self.performance[k]=[] 99 | self.performance[k].append(sum(pred_labels.squeeze()==y.squeeze()).item()) 100 | 101 | return None 102 | 103 | def training_step(self, batch, batch_idx): 104 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 105 | return loss 106 | 107 | def validation_step(self, batch, batch_idx): 108 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 109 | return loss 110 | 111 | def 
test_step(self, batch, batch_idx): 112 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 113 | return loss 114 | 115 | def on_validation_epoch_start(self,): 116 | self.data_embeddings = torch.cat(self.data_embeddings, dim=0) 117 | self.data_labels = torch.cat(self.data_labels, dim=0) 118 | return 119 | 120 | def on_validation_epoch_end(self): 121 | for k in self.k: 122 | self.performance[k] = sum(self.performance[k])/self.datamodule.num_val_samples 123 | if (self.current_epoch+1)%10 == 0: 124 | plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name=f"val_final_{self.evaluated_epoch}_knn") 125 | 126 | def on_fit_end(self): 127 | # Write to a CSV file 128 | with open(os.path.join(self.save_dir,f'performance_final_{self.evaluated_epoch}_knn_task_{self.task}.csv'), 'w', newline='') as csvfile: 129 | writer = csv.writer(csvfile) 130 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 131 | for epoch in list(self.performance.keys()): 132 | # Assuming you have the accuracy for each epoch stored in a list 133 | writer.writerow([epoch, round(100*self.performance[epoch],2)]) 134 | return 135 | 136 | def configure_optimizers(self): 137 | optimizer = torch.optim.AdamW(self.classifier.parameters()) 138 | return [optimizer] 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Principal Masked Autoencoders 2 | 3 | Official PyTorch codebase for **P**rincipal **M**asked **A**uto-**E**ncoders (PMAE) presented in **From Pixels to Components: Eigenvector Masking for Visual Representation Learning** 4 | [\[arXiv\]](http://arxiv.org/abs/2502.06314). 5 | 6 | ## Method 7 | PMAE introduces an alternative approach to pixel masking for visual representation learning by masking principal components instead of pixel patches. This repository builds on top of the Masked Auto-Encoder (MAE, [\[arXiv\]](https://arxiv.org/pdf/2111.06377)) a prominent baseline for Masked Image Modelling (MIM) and replaces the masking of patches of pixels by the masking of principal components. 8 | 9 | ![pmae](https://github.com/alicebizeul/pmae/blob/main/assets/diagram-larger.png) 10 | 11 | ## Code Structure 12 | 13 | ``` 14 | . 15 | ├── assets # assets for the README file 16 | ├── configs # directory in which all experiment '.yaml' configs are stored 17 | ├── scripts # bash scripts to launch training and evaluation 18 | │ ├── train.sh # training script 19 | │ └── eval.sh # evaluation script 20 | ├── src # the package 21 | │ ├── plotting.py # plotting function to training tracking 22 | │ ├── utils.py # helper functions for init of models & opt/loading checkpoint 23 | │ ├── dataset # datasets, data loaders, ... 24 | │ └── model # models, training loops, ... 25 | ├── tools # scripts to compute PCA prior to training 26 | ├── main.py # entrypoint for launch PMAE pretraining locally on your machine 27 | └── requirements.txt # requirements file 28 | ``` 29 | 30 | **Config files:** 31 | Note that all experiment parameters are specified in config files (as opposed to command-line-arguments). See the [config/](config/) directory for example config files. 
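For instance, the configuration composed for a given experiment can be inspected with Hydra's compose API before launching anything. The snippet below is only a sketch: the experiment name is one of the files in [config/experiment](config/experiment), and depending on how ```train_defaults.yaml``` declares its defaults list, the override may need to be written as ```+experiment=...```.

```python
# Minimal sketch (assumption: run from the repository root, with Hydra installed).
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.2", config_path="config"):
    # Compose the training config with one experiment file applied on top.
    cfg = compose(
        config_name="train_defaults.yaml",
        overrides=["experiment=pcmae_cifar10_pc"],  # assumption: may require "+experiment=..." instead
    )

# resolve=False keeps interpolations unresolved, so the custom resolvers
# registered in main.py are not needed just to inspect the composed config.
print(OmegaConf.to_yaml(cfg, resolve=False))
```

Actual runs should still go through ```scripts/train.sh``` (or ```main.py```), which registers the custom resolvers and sets up logging; the snippet above is only meant to sanity-check which dataset, model and masking settings an experiment file pulls in.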
32 | 33 | 34 | ## Installation 35 | 36 | In your environment of choice, install the necessary requirements 37 | 38 | !pip install -r requirements.txt 39 | 40 | Alternatively, install individual packages as follows: 41 | 42 | !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 43 | !pip install pandas numpy pillow scikit-learn scikit-image plotly kaleido matplotlib submitit hydra-core pytorch-lightning imageio medmnist wandb transformers 44 | 45 | Create a config file that suits your machine: 46 | 47 | cd ./config/user 48 | cp abizeul_biomed.yaml myusername_mymachine.yaml 49 | 50 | Adjust the paths in ```myusername_mymachine.yaml``` to point to the directory you would like to use for storage of results and for fetching the data 51 | 52 | Make sure to either compute or download the necessary assets for the dataset you plan to use with PMAE. These include the mean and standard deviation for image normalization, as well as the eigenvalues and eigenvectors. For each dataset, these assets are available on Zenodo for which the link are listed below.ß 53 | 54 | - [CIFAR10](https://zenodo.org/records/14588944?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjY3ODJmNjAwLWM5YTYtNDhiNy1iNDEzLThlYjJjN2RkMzYyMiIsImRhdGEiOnt9LCJyYW5kb20iOiIyODAyNGE5NzU1MGZlNWY2Zjc3NGExMzU1MGUxNTc0ZSJ9.gCq9v8x2srkjjlusAw3zlMFZu31I6dziOrroBiNbRHQsOs7PZadhbClREgeTMRcQZ4DXKxh1sMASIyHcC34k3Q) 55 | - [TinyImageNet](https://zenodo.org/records/14589101?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6Ijk0YzI2NGZhLTZhYTYtNDRiMC04NjIzLTk1MjQxNDc5Njg1YyIsImRhdGEiOnt9LCJyYW5kb20iOiI0OGM1MTlmNTk3MDJiMjk3M2YyNzBjMzc2ZTkzYThhMyJ9.LAlnzb4HCHkhd_CAUTkz9LWptyrnsfDLTzHuFKCXjRAGK77YWXyA3L412aB5r5U77WcltxsetpUGEQCjebOuHg) 56 | - [ImageNet](https://zenodo.org/records/14589122?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjYzYWE5MjEyLTQ3ZDMtNGFjOS1hYmIxLWQ3ZmJmNmEwNDBhNiIsImRhdGEiOnt9LCJyYW5kb20iOiI4YzQzZGU5N2NkYTkzNjdhZDAxYzQzYTFmMWNiZGFmYyJ9.KD1_j1A0ao9GS59rxDALo3Kvj9l5mhrLORf3cWkGUFJGO8ycs0e9STD0dqkAaweiSxYgqD3N8AFSOCbw12rD9Q) 57 | - [BloodMNIST](https://zenodo.org/records/14588621?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjljMzM2YjE3LTg3MTQtNDA2MS1hYzU5LTZhMWY2Y2IwNmE1OSIsImRhdGEiOnt9LCJyYW5kb20iOiJjY2MyYjVhM2ZmMzkxNmIzMWMwNzFlZmE0YTIwNjJmZiJ9.K9eA_KqJFMA5zfHU_lRUbQ-143Jj1M7IjB8nLGY6WShbqKC-g4E7_W96z7YWzf0wB25A-N6Bh0g8nqxxaPTKGA) 58 | - [DermaMNIST](https://zenodo.org/records/14588800?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6ImVkMDExNzU0LWJhODgtNDg0My1iODM0LWViMjg2ZDQ4NDk3MSIsImRhdGEiOnt9LCJyYW5kb20iOiIyZWIzMTY4NjYyNTA0MDRmNjkyNGI1NzI2ODliY2UzMiJ9.Dzkm-d0kba1FYwdW0h4oBav-qhGckbuirAF-Gre_JGJ6S0CTWDRESldO9AATRqwvCPNf7h3qa8i0KYnYZckCXw) 59 | - [PathMNIST](https://zenodo.org/records/14589091?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjdkMzg2NzAxLWMwMGQtNDcxMi05ODRmLTBiNjk5ZTlmNTMyZCIsImRhdGEiOnt9LCJyYW5kb20iOiI2MjdhOGI0ZGI0MjcxM2Q2ZDFjYWYyNjBlNmMxYmM2NCJ9.yD3jRzhdy-vt0PIN-bNcZWSR5Uxz4jDOPvqNE4UeQfKwq3n11gp-YdyVFL-Rv_2eMNbYc3o2euM8iMfQxcNK6A) 60 | 61 | Once files are downloaded and stored on your local machine, make sure to specify their path in the dataset's [config](config/dataset/) (```data.mean.data.file```,```data.std.data.file```,```extradata.pcamodule.file``` and ```extradata.eigenratiomodule.file```). See imagenet's [config](config/dataset/imagenet.yaml) as an example. 62 | 63 | ## Launch Training 64 | To launch experiments, you can find training and evaluation scripts in ```scripts```. 
The following modifications should be made to the ```train.sh``` script to ensure a smooth training on your local machine: 65 | 66 | USER_MACHINE="myusername_mymachine" # the user who runs the experiment 67 | EXPERIMENT="pmae_tiny_pc" # the experiment to run; defines the model, dataset and masking type 68 | MASK=0.2 # the masking ratio to use, default: 0.2 69 | 70 | The full set of pre-defined experiments to choose from can be found in [config/experiment](config/experiment). Note that ```train.sh``` already includes a final evaluation of the representations using a linear probe. 71 | 72 | **Distributed Training:** For distributed training, please use the ```train_distributed.sh``` script instead and adjust the number of GPUs to your own resources. Note that our code relies on PyTorch Lightning for distributed training. 73 | 74 | **Baselines:** To run the MAE baseline in place of PMAE, adjust ```EXPERIMENT``` to ```mae_tiny``` or any other experiment whose name starts with ```mae```. 75 | 76 | **Random Masking:** To run PMAE with randomized masking ratios as presented in the [\[arXiv\]](https://alicebizeul.github.io/assets/pdf/mae.pdf), adjust ```EXPERIMENT``` to ```pmae_tiny_pcsampling``` or any other experiment whose name contains ```pcsampling```. 77 | 78 | ## Launch Evaluation 79 | To evaluate a checkpoint with a linear probe, an MLP probe, k-nearest neighbors, or fine-tuning, use the evaluation script in the ```scripts``` directory. The following modifications should be made to the ```eval.sh``` script to ensure a smooth evaluation on your local machine: 80 | 81 | USER_MACHINE="myusername_mymachine" # the user who runs the experiment 82 | EXPERIMENT="pmae_tiny_pc" # the experiment to run; defines the model, dataset and masking type 83 | EPOCH=800 # the epoch to be evaluated 84 | MASK=0.2 # the masking ratio to use, default: 0.2 85 | 86 | Additionally, ensure the path to the checkpoint you want to evaluate is correctly set in your [user configuration file](config/user/abizeul_euler.yaml); see that file for reference. The specified checkpoint (defined by its path and epoch) will then be evaluated. 87 | -------------------------------------------------------------------------------- /src/model/vit_mae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: vit_mae.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | import torch 8 | from transformers import ViTMAEConfig, ViTMAEPreTrainedModel, ViTMAEModel 9 | from transformers.models.vit_mae.modeling_vit_mae import ViTMAEDecoder, ViTMAEForPreTrainingOutput 10 | from typing import Optional, Set, Tuple, Union 11 | 12 | class ViTMAEForPreTraining(ViTMAEPreTrainedModel): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.config = config 16 | 17 | self.vit = ViTMAEModel(config) # self.vit.embeddings.config 18 | self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches) 19 | 20 | # Initialize weights and apply final processing 21 | self.post_init() 22 | 23 | def get_input_embeddings(self): 24 | return self.vit.embeddings.patch_embeddings 25 | 26 | def delete_decoder(self): 27 | self.decoder = None 28 | return 29 | 30 | def _prune_heads(self, heads_to_prune): 31 | """ 32 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 33 | class PreTrainedModel 34 | """ 35 | for layer, heads in heads_to_prune.items(): 36 | self.encoder.layer[layer].attention.prune_heads(heads) 37 | 38 | def patchify(self, pixel_values, interpolate_pos_encoding: bool = False): 39 | """ 40 | Args: 41 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 42 | Pixel values. 43 | interpolate_pos_encoding (`bool`, *optional*, default `False`): 44 | interpolation flag passed during the forward pass. 45 | 46 | Returns: 47 | `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 48 | Patchified pixel values. 49 | """ 50 | patch_size, num_channels = self.config.patch_size, self.config.num_channels 51 | # sanity checks 52 | if not interpolate_pos_encoding and ( 53 | pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0 54 | ): 55 | raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size") 56 | if pixel_values.shape[1] != num_channels: 57 | raise ValueError( 58 | "Make sure the number of channels of the pixel values is equal to the one set in the configuration" 59 | ) 60 | 61 | # patchify 62 | batch_size = pixel_values.shape[0] 63 | num_patches_h = pixel_values.shape[2] // patch_size 64 | num_patches_w = pixel_values.shape[3] // patch_size 65 | patchified_pixel_values = pixel_values.reshape( 66 | batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size 67 | ) 68 | patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values) 69 | patchified_pixel_values = patchified_pixel_values.reshape( 70 | batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels 71 | ) 72 | return patchified_pixel_values 73 | 74 | def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None): 75 | """ 76 | Args: 77 | patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 78 | Patchified pixel values. 79 | original_image_size (`Tuple[int, int]`, *optional*): 80 | Original image size. 81 | 82 | Returns: 83 | `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: 84 | Pixel values. 
85 | """ 86 | patch_size, num_channels = self.config.patch_size, self.config.num_channels 87 | original_image_size = ( 88 | original_image_size 89 | if original_image_size is not None 90 | else (self.config.image_size, self.config.image_size) 91 | ) 92 | original_height, original_width = original_image_size 93 | num_patches_h = original_height // patch_size 94 | num_patches_w = original_width // patch_size 95 | # sanity check 96 | if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: 97 | raise ValueError( 98 | f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" 99 | ) 100 | 101 | # unpatchify 102 | batch_size = patchified_pixel_values.shape[0] 103 | patchified_pixel_values = patchified_pixel_values.reshape( 104 | batch_size, 105 | num_patches_h, 106 | num_patches_w, 107 | patch_size, 108 | patch_size, 109 | num_channels, 110 | ) 111 | patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) 112 | pixel_values = patchified_pixel_values.reshape( 113 | batch_size, 114 | num_channels, 115 | num_patches_h * patch_size, 116 | num_patches_w * patch_size, 117 | ) 118 | return pixel_values 119 | 120 | def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False, patchify=True): 121 | """ 122 | Args: 123 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 124 | Pixel values. 125 | pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: 126 | Predicted pixel values. 127 | mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 128 | Tensor indicating which patches are masked (1) and which are not (0). 129 | interpolate_pos_encoding (`bool`, *optional*, default `False`): 130 | interpolation flag passed during the forward pass. 131 | 132 | Returns: 133 | `torch.FloatTensor`: Pixel reconstruction loss. 
134 | """ 135 | if patchify: 136 | target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) 137 | else: 138 | target = pixel_values 139 | if self.config.norm_pix_loss: 140 | mean = target.mean(dim=-1, keepdim=True) 141 | var = target.var(dim=-1, keepdim=True) 142 | target = (target - mean) / (var + 1.0e-6) ** 0.5 143 | 144 | loss = (pred - target) ** 2 145 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 146 | 147 | if mask.sum() > 0: 148 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 149 | else: 150 | loss = loss.mean() 151 | 152 | return loss 153 | 154 | def forward( 155 | self, 156 | pixel_values: Optional[torch.FloatTensor] = None, 157 | noise: Optional[torch.FloatTensor] = None, 158 | head_mask: Optional[torch.FloatTensor] = None, 159 | output_attentions: Optional[bool] = None, 160 | output_hidden_states: Optional[bool] = None, 161 | return_rep: Optional[bool] = None, 162 | return_dict: Optional[bool] =None, 163 | interpolate_pos_encoding: bool = False, 164 | ) -> Union[Tuple, ViTMAEForPreTrainingOutput]: 165 | r""" 166 | Returns: 167 | 168 | Examples: 169 | 170 | ```python 171 | >>> from transformers import AutoImageProcessor, ViTMAEForPreTraining 172 | >>> from PIL import Image 173 | >>> import requests 174 | 175 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 176 | >>> image = Image.open(requests.get(url, stream=True).raw) 177 | 178 | >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") 179 | >>> model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base") 180 | 181 | >>> inputs = image_processor(images=image, return_tensors="pt") 182 | >>> outputs = model(**inputs) 183 | >>> loss = outputs.loss 184 | >>> mask = outputs.mask 185 | >>> ids_restore = outputs.ids_restore 186 | ```""" 187 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 188 | 189 | outputs = self.vit( 190 | pixel_values, 191 | noise=noise, 192 | head_mask=head_mask, 193 | output_attentions=output_attentions, 194 | output_hidden_states=output_hidden_states, 195 | return_dict=return_dict, 196 | interpolate_pos_encoding=interpolate_pos_encoding, 197 | ) 198 | 199 | latent = outputs.last_hidden_state 200 | ids_restore = outputs.ids_restore 201 | mask = outputs.mask 202 | 203 | # Mask out encoder embeddings that correspond to padded tokens 204 | if head_mask is not None: 205 | head_mask_per_token = head_mask[0,:,0,0][...,None] 206 | latent = head_mask_per_token*latent 207 | 208 | #self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) 209 | 210 | # if not return_dict: 211 | # output = (logits, mask, ids_restore) + outputs[2:] 212 | # return ((loss,) + output) if loss is not None else output 213 | 214 | if return_rep: 215 | return latent[:,0,:], outputs.attentions 216 | else: 217 | decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) 218 | logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) 219 | 220 | return ViTMAEForPreTrainingOutput( 221 | loss=0, 222 | logits=logits, 223 | mask=mask, 224 | ids_restore=ids_restore, 225 | hidden_states=outputs.hidden_states, 226 | attentions=outputs.attentions, 227 | ), latent[:,0,:] -------------------------------------------------------------------------------- /src/model/module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: 
module.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | 12 | # Third-party library imports 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import pytorch_lightning as pl 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torchmetrics 20 | from torch import Tensor 21 | from torch.nn.parameter import Parameter 22 | from torchvision.models import resnet18 23 | import wandb 24 | 25 | # Local imports 26 | from ..plotting import plot_loss, plot_performance 27 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 28 | 29 | class ViTMAE(pl.LightningModule): 30 | 31 | def __init__( 32 | self, 33 | model, 34 | learning_rate: float = 1e-3, 35 | base_learning_rate: float =1e-3, 36 | weight_decay: float = 0.05, 37 | betas: list =[0.9,0.95], 38 | optimizer_name: str = "adamw", 39 | warmup: int =10, 40 | datamodule: Optional[pl.LightningDataModule] = None, 41 | eval_freq: int =100, 42 | eval_type ="multiclass", 43 | eval_fn =nn.CrossEntropyLoss(), 44 | eval_logit_fn = nn.Softmax(), 45 | save_dir: str =None, 46 | ): 47 | super().__init__() 48 | self.learning_rate = learning_rate 49 | self.weight_decay = weight_decay 50 | self.betas = betas 51 | self.optimizer_name = optimizer_name 52 | self.datamodule = datamodule 53 | self.num_classes = datamodule.num_classes 54 | self.image_size = datamodule.image_size 55 | self.classifierlr = learning_rate 56 | self.warm_up = warmup 57 | self.eval_freq = eval_freq 58 | self.masking = datamodule.masking 59 | 60 | self.model = model 61 | 62 | if self.masking.type == "pc": 63 | self.register_buffer("masking_fn_",torch.Tensor(self.datamodule.extra_data.pcamodule.T)) 64 | elif self.masking.type == "random": 65 | self.register_buffer("masking_fn",nn.Linear()) 66 | 67 | self.classifier = nn.Linear(model.config.hidden_size, self.num_classes) 68 | self.online_classifier_loss = eval_fn 69 | self.online_logit_fn= eval_logit_fn 70 | self.online_train_accuracy = torchmetrics.Accuracy( 71 | task=eval_type, num_classes=self.num_classes, top_k=1 72 | ) 73 | self.online_val_accuracy = torchmetrics.Accuracy( 74 | task=eval_type, num_classes=self.num_classes, top_k=1 75 | ) 76 | self.save_dir = save_dir 77 | self.train_losses = [] 78 | self.avg_train_losses = [] 79 | self.online_losses = [] 80 | self.avg_online_losses = [] 81 | self.performance = {} 82 | 83 | def forward(self, x): 84 | return self.model(self.transformation(x)) 85 | 86 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int = None): 87 | if stage == "train": 88 | img, y, pc_mask = batch 89 | 90 | # mae training 91 | if self.masking.type == "pc": 92 | target = (img.reshape([img.shape[0],-1]) @ self.masking_fn_[:,pc_mask]) 93 | 94 | if self.masking.strategy in ["sampling_pc","pc"]: 95 | indexes = self.indexes.to(self.device) 96 | pc_mask_input = indexes[~torch.isin(indexes,pc_mask)] 97 | 98 | img = ((img.reshape([img.shape[0],-1]) @ self.masking_fn_[:,pc_mask_input])@ self.masking_fn_[:,pc_mask_input].T).reshape(img.shape) 99 | 100 | elif self.masking.type == "pixel": 101 | if self.masking.strategy == "sampling": 102 | self.model.config.mask_ratio = pc_mask 103 | self.model.vit.embeddings.config.mask_ratio=pc_mask 104 | target = img 105 | 106 | outputs, cls = self.model(img,return_rep=False) 107 | reconstruction = self.model.unpatchify(outputs.logits) 108 | 
mask = outputs.mask.unsqueeze(-1).repeat(1, 1, self.model.config.patch_size**2 *3) 109 | mask = self.model.unpatchify(mask) 110 | 111 | if self.masking.type == "pc": 112 | outputs.logits = reconstruction.reshape([img.shape[0],-1]) @ self.masking_fn_[:,pc_mask] 113 | outputs.mask = torch.zeros_like(mask.reshape([mask.shape[0],-1]),device=self.device) 114 | 115 | loss_mae = self.model.forward_loss(target,outputs.logits,outputs.mask,patchify=False if self.masking.type == "pc" else True) 116 | 117 | if (self.current_epoch+1)%self.eval_freq==0 and batch_idx==0: 118 | self.log( 119 | f"{stage}_mae_loss", 120 | loss_mae, 121 | prog_bar=True, 122 | sync_dist=False, 123 | on_step=True, 124 | on_epoch=False 125 | ) 126 | self.train_losses.append(loss_mae.item()) 127 | self.avg_train_losses.append(np.mean(self.train_losses)) 128 | plot_loss(self.avg_train_losses,name_loss="MSE",save_dir=self.save_dir,name_file="_train") 129 | plot_loss(self.avg_online_losses,name_loss="X-Ent",save_dir=self.save_dir,name_file="_train_online_cls") 130 | 131 | if ( 132 | self.model.config.mask_ratio is None 133 | or self.model.config.mask_ratio > 0 134 | ): 135 | save_reconstructed_images((-1*(mask[:10]-1))*img[:10],mask[:10]*img[:10], reconstruction[:10], self.current_epoch+1, self.save_dir,"train") 136 | else: 137 | save_reconstructed_images(img[:10], target[:10], reconstruction[:10], self.current_epoch+1, self.save_dir,"train") 138 | 139 | 140 | # online classifier 141 | logits_cls = self.classifier(cls.detach()) 142 | loss_ce = self.online_classifier_loss(logits_cls.squeeze(),y.squeeze()) 143 | 144 | if (self.current_epoch+1)%self.eval_freq==0 and batch_idx==0: 145 | self.log(f"{stage}_classifier_loss", loss_ce, sync_dist=False, on_step=True, on_epoch=False) 146 | 147 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 148 | accuracy_metric(self.online_logit_fn(logits_cls.squeeze()), y.squeeze()) 149 | self.log( 150 | f"online_{stage}_accuracy", 151 | accuracy_metric, 152 | prog_bar=False, 153 | sync_dist=True, 154 | ) 155 | del logits_cls 156 | 157 | self.online_losses.append(loss_ce.item()) 158 | self.avg_online_losses.append(np.mean(self.online_losses)) 159 | 160 | plot_loss(self.avg_online_losses,name_loss="X-Ent",save_dir=self.save_dir,name_file="_train_online_cls") 161 | 162 | return loss_mae + loss_ce 163 | 164 | else: 165 | img, y = batch 166 | cls, _ = self.model(img,return_rep=True) 167 | logits = self.classifier(cls.detach()) 168 | 169 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 170 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 171 | self.log( 172 | f"online_{stage}_accuracy", 173 | accuracy_metric, 174 | prog_bar=True, 175 | sync_dist=True, 176 | on_epoch=True, 177 | on_step=False, 178 | ) 179 | 180 | if batch_idx == 0: 181 | if self.current_epoch+1 not in list(self.performance.keys()): 182 | self.performance[self.current_epoch+1]=[] 183 | 184 | if len(y.squeeze().shape) > 1: 185 | self.performance[self.current_epoch+1].append(sum(sum(1*((self.online_logit_fn(logits.squeeze())>0.5)==y.squeeze()))).item()) 186 | else: 187 | self.performance[self.current_epoch+1].append(sum(1*(torch.argmax(self.online_logit_fn(logits.squeeze()), dim=-1)==y.squeeze())).item()) 188 | 189 | return None 190 | 191 | def on_validation_epoch_end(self): 192 | self.performance[self.current_epoch+1] = sum(self.performance[self.current_epoch+1])/self.datamodule.num_val_samples 193 | 
plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name="val") 194 | 195 | def training_step(self, batch, batch_idx): 196 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 197 | return loss 198 | 199 | def validation_step(self, batch, batch_idx): 200 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 201 | return loss 202 | 203 | def test_step(self, batch, batch_idx): 204 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 205 | return loss 206 | 207 | def configure_optimizers(self): 208 | def warmup(current_step: int): 209 | return 1 / (10 ** (float(num_warmup_epochs - current_step))) 210 | 211 | if self.optimizer_name == "adamw": 212 | optimizer = torch.optim.AdamW( 213 | self.parameters(), 214 | lr=self.learning_rate, 215 | weight_decay=self.weight_decay, 216 | ) 217 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 218 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 219 | ) 220 | elif self.optimizer_name == "adamw_warmup": 221 | num_warmup_epochs = self.warm_up 222 | optimizer = torch.optim.AdamW( 223 | self.parameters(), 224 | lr=self.learning_rate, 225 | weight_decay=self.weight_decay, 226 | betas=self.betas 227 | ) 228 | 229 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 230 | optimizer, lr_lambda=warmup 231 | ) 232 | 233 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 234 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 235 | ) 236 | 237 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 238 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 239 | ) 240 | 241 | else: 242 | raise ValueError(f"{self.optimizer_name} not supported") 243 | 244 | return [optimizer], [lr_scheduler] 245 | -------------------------------------------------------------------------------- /src/model/module_fine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: module_fine.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | import csv 12 | 13 | # Third-party library imports 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torchmetrics 21 | from torch import Tensor 22 | from torch.nn.parameter import Parameter 23 | from torchvision.models import resnet18 24 | import wandb 25 | 26 | # Local imports 27 | from ..plotting import plot_loss, plot_performance 28 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 29 | 30 | class ViTMAE_fine(pl.LightningModule): 31 | 32 | def __init__( 33 | self, 34 | model, 35 | learning_rate: float = 1e-3, 36 | base_learning_rate: float =1e-3, 37 | weight_decay: float = 0.05, 38 | betas: list =[0.9,0.95], 39 | optimizer_name: str = "adamw", 40 | warmup: int =10, 41 | datamodule: Optional[pl.LightningDataModule] = None, 42 | save_dir: str =None, 43 | evaluated_epoch: int =800, 44 | eval_type ="multiclass", 45 | eval_fn =nn.CrossEntropyLoss(), 46 | eval_logit_fn = nn.Softmax(), 47 | task :int =0, 48 | ): 49 | super().__init__() 50 | self.learning_rate = learning_rate 51 | self.weight_decay = weight_decay 52 | self.betas = betas 53 | self.optimizer_name = optimizer_name 54 | self.datamodule = datamodule 55 | 
self.num_classes = datamodule.num_classes 56 | self.image_size = datamodule.image_size 57 | self.warm_up = warmup 58 | self.evaluated_epoch = evaluated_epoch 59 | self.task = task 60 | 61 | model.delete_decoder() 62 | self.model = model 63 | self.model.config.mask_ratio = 0.0 64 | self.model.vit.embeddings.config.mask_ratio=0.0 65 | self.classifier = nn.Linear(model.config.hidden_size, self.num_classes) 66 | 67 | self.online_classifier_loss = eval_fn 68 | self.online_logit_fn= eval_logit_fn 69 | self.online_train_accuracy = torchmetrics.Accuracy( 70 | task=eval_type, num_classes=self.num_classes, top_k=1 71 | ) 72 | self.online_val_accuracy = torchmetrics.Accuracy( 73 | task=eval_type, num_classes=self.num_classes, top_k=1 74 | ) 75 | self.online_val_f1 = torchmetrics.F1Score( 76 | task=eval_type, num_classes=self.num_classes, top_k=1, average = "macro" 77 | ) 78 | 79 | self.save_dir = save_dir 80 | self.train_losses = [] 81 | self.avg_train_losses = [] 82 | self.performance = {} 83 | self.f1scores = {} 84 | 85 | def forward(self, x): 86 | return self.model(x) 87 | 88 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int =None): 89 | if stage == "train": 90 | img, y, _ = batch 91 | 92 | cls, _ = self.model(img,return_rep=True) 93 | logits = self.classifier(cls) 94 | 95 | loss_ce = self.online_classifier_loss(logits.squeeze(),y.squeeze()) 96 | self.log(f"final_{stage}_classifier_loss_{self.evaluated_epoch}_fine", loss_ce, sync_dist=True) 97 | 98 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 99 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 100 | self.log( 101 | f"final_{stage}_accuracy_{self.evaluated_epoch}_fine", 102 | accuracy_metric, 103 | prog_bar=False, 104 | sync_dist=True, 105 | ) 106 | 107 | self.train_losses.append(loss_ce.item()) 108 | self.avg_train_losses.append(np.mean(self.train_losses)) 109 | 110 | if (self.current_epoch+1)%10==0 and batch_idx==0: 111 | plot_loss(self.avg_train_losses,name_loss="X-Entropy",save_dir=self.save_dir,name_file=f"_eval_train_{self.evaluated_epoch}_fine") 112 | return loss_ce 113 | 114 | else: 115 | img, y = batch 116 | 117 | cls, attentions = self.model(img,return_rep=True,output_attentions=True) 118 | logits = self.classifier(cls.detach()) 119 | 120 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 121 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 122 | self.log( 123 | f"final_{stage}_accuracy_{self.evaluated_epoch}_fine", 124 | accuracy_metric, 125 | prog_bar=True, 126 | sync_dist=True, 127 | on_epoch=True, 128 | on_step=False, 129 | ) 130 | 131 | if batch_idx == 0 and (self.current_epoch+1) not in list(self.performance.keys()): 132 | self.performance[self.current_epoch+1]=[] 133 | if len(y.squeeze().shape) >1: 134 | self.f1scores[self.current_epoch+1]=[] 135 | 136 | if len(y.squeeze().shape) > 1: 137 | f1_metric = getattr(self, f"online_{stage}_f1") 138 | f1_score = f1_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 139 | self.performance[self.current_epoch+1].append(sum(1*((self.online_logit_fn(logits.squeeze())>0.5)==y.squeeze())).detach().cpu().numpy()) 140 | self.f1scores[self.current_epoch+1].append(f1_score.detach().cpu().numpy()) 141 | 142 | else: 143 | self.performance[self.current_epoch+1].append(sum(1*(torch.argmax(logits.squeeze(), dim=-1)==y.squeeze())).item()) 144 | 145 | # check the attention we get at final 146 | if self.current_epoch==0 and batch_idx==0: 147 | attentions = attentions[-1].mean(1) 148 | att_map_cls = 
attentions[:,0,1:] 149 | att_map_spatial = torch.mean(attentions[:,1:,1:],dim=-1) 150 | att_map_cls = att_map_cls.reshape([img.shape[0],int(np.sqrt(att_map_cls.shape[-1])),int(np.sqrt(att_map_cls.shape[-1]))]) 151 | att_map_spatial = att_map_spatial.reshape([img.shape[0],int(np.sqrt(att_map_spatial.shape[-1])),int(np.sqrt(att_map_spatial.shape[-1]))]) 152 | save_attention_maps(img[:10],att_map_cls[:10].unsqueeze(1),att_map_spatial[:10].unsqueeze(1),self.current_epoch+1, self.save_dir,f"eval_{self.evaluated_epoch}_fine") 153 | save_attention_maps_batch(att_map_cls=att_map_cls,att_map_spatial=att_map_spatial,epoch=self.current_epoch+1, output_dir=self.save_dir,name=f"eval_{self.evaluated_epoch}_fine") 154 | 155 | return None 156 | 157 | def on_validation_epoch_end(self): 158 | if isinstance(self.performance[self.current_epoch+1][0],np.ndarray): 159 | self.performance[self.current_epoch+1] = np.array(self.performance[self.current_epoch+1]) 160 | self.f1scores[self.current_epoch+1] = np.array(self.f1scores[self.current_epoch+1]) 161 | 162 | self.performance[self.current_epoch+1] = sum(self.performance[self.current_epoch+1])/self.datamodule.num_val_samples 163 | 164 | if isinstance(self.performance[self.current_epoch+1],np.ndarray): 165 | self.performance[self.current_epoch+1] = np.mean(self.performance[self.current_epoch+1]) 166 | self.f1scores[self.current_epoch+1] = np.mean(self.f1scores[self.current_epoch+1]) 167 | 168 | if (self.current_epoch+1)%10 == 0 and not isinstance(self.performance[self.current_epoch+1],np.ndarray): 169 | plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name=f"val_final_{self.evaluated_epoch}_fine") 170 | 171 | def on_fit_end(self): 172 | # Write to a CSV file 173 | with open(os.path.join(self.save_dir,f'performance_final_{self.evaluated_epoch}_fine_task_{self.task}.csv'), 'w', newline='') as csvfile: 174 | writer = csv.writer(csvfile) 175 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 176 | for epoch in list(self.performance.keys()): 177 | # Assuming you have the accuracy for each epoch stored in a list 178 | writer.writerow([epoch, round(100*self.performance[epoch],2)]) 179 | 180 | if len(self.f1scores.keys()) > 1 : 181 | with open(os.path.join(self.save_dir,f'f1scores_final_{self.evaluated_epoch}_fine_task_{self.task}.csv'), 'w', newline='') as csvfile: 182 | writer = csv.writer(csvfile) 183 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 184 | for epoch in list(self.f1scores.keys()): 185 | # Assuming you have the accuracy for each epoch stored in a list 186 | writer.writerow([epoch, round(100*self.f1scores[epoch],2)]) 187 | return 188 | 189 | def training_step(self, batch, batch_idx): 190 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 191 | return loss 192 | 193 | def validation_step(self, batch, batch_idx): 194 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 195 | return loss 196 | 197 | def test_step(self, batch, batch_idx): 198 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 199 | return loss 200 | 201 | def configure_optimizers(self): 202 | def warmup(current_step: int): 203 | return 1 / (10 ** (float(num_warmup_epochs - current_step))) 204 | 205 | if self.optimizer_name == "adamw": 206 | optimizer = torch.optim.AdamW( 207 | self.parameters(), 208 | lr=self.learning_rate, 209 | weight_decay=self.weight_decay, 210 | ) 211 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 212 | optimizer=optimizer, T_max=self.trainer.max_epochs, 
verbose=False 213 | ) 214 | elif self.optimizer_name == "adamw_warmup": 215 | num_warmup_epochs = self.warm_up 216 | optimizer = torch.optim.AdamW( 217 | self.parameters(), 218 | lr=self.learning_rate, 219 | weight_decay=self.weight_decay, 220 | betas=self.betas 221 | ) 222 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 223 | optimizer, lr_lambda=warmup 224 | ) 225 | 226 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 227 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 228 | ) 229 | 230 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 231 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 232 | ) 233 | elif self.optimizer_name == "sgd": 234 | num_warmup_epochs = self.warm_up 235 | optimizer = torch.optim.SGD( 236 | self.parameters(), 237 | lr=self.learning_rate, 238 | weight_decay=self.weight_decay, 239 | momentum=self.betas 240 | ) 241 | 242 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 243 | optimizer, lr_lambda=warmup 244 | ) 245 | 246 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 247 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 248 | ) 249 | 250 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 251 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 252 | ) 253 | 254 | else: 255 | raise ValueError(f"{self.optimizer_name} not supported") 256 | 257 | return [optimizer], [lr_scheduler] 258 | 259 | -------------------------------------------------------------------------------- /src/model/module_lin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: module_lin.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | import csv 12 | 13 | # Third-party library imports 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torchmetrics 21 | from torch import Tensor 22 | from torch.nn.parameter import Parameter 23 | from torchvision.models import resnet18 24 | import wandb 25 | 26 | # Local imports 27 | from ..plotting import plot_loss, plot_performance 28 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 29 | 30 | class ViTMAE_lin(pl.LightningModule): 31 | 32 | def __init__( 33 | self, 34 | model, 35 | learning_rate: float = 1e-3, 36 | base_learning_rate: float =1e-3, 37 | weight_decay: float = 0.05, 38 | betas: list =[0.9,0.95], 39 | optimizer_name: str = "adamw", 40 | warmup: int =10, 41 | datamodule: Optional[pl.LightningDataModule] = None, 42 | save_dir: str =None, 43 | evaluated_epoch: int =800, 44 | eval_type ="multiclass", 45 | eval_fn =nn.CrossEntropyLoss(), 46 | eval_logit_fn = nn.Softmax(), 47 | task :int =0, 48 | ): 49 | super().__init__() 50 | self.learning_rate = learning_rate 51 | self.weight_decay = weight_decay 52 | self.betas = betas 53 | self.optimizer_name = optimizer_name 54 | self.datamodule = datamodule 55 | self.num_classes = datamodule.num_classes 56 | self.image_size = datamodule.image_size 57 | self.warm_up = warmup 58 | self.evaluated_epoch = evaluated_epoch 59 | self.task = task 60 | 61 | model.delete_decoder() 62 | self.model = model 63 | self.model.config.mask_ratio = 0.0 64 | self.model.vit.embeddings.config.mask_ratio=0.0 65 | self.classifier 
= nn.Linear(model.config.hidden_size, self.num_classes) 66 | 67 | self.online_classifier_loss = eval_fn 68 | self.online_logit_fn= eval_logit_fn 69 | self.online_train_accuracy = torchmetrics.Accuracy( 70 | task=eval_type, num_classes=self.num_classes, top_k=1 71 | ) 72 | self.online_val_accuracy = torchmetrics.Accuracy( 73 | task=eval_type, num_classes=self.num_classes, top_k=1 74 | ) 75 | self.online_val_f1 = torchmetrics.F1Score( 76 | task=eval_type, num_classes=self.num_classes, top_k=1, average = "macro" 77 | ) 78 | 79 | self.save_dir = save_dir 80 | self.train_losses = [] 81 | self.avg_train_losses = [] 82 | self.performance = {} 83 | self.f1scores = {} 84 | 85 | def forward(self, x): 86 | return self.model(x) 87 | 88 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int =None): 89 | if stage == "train": 90 | img, y, _ = batch 91 | if len(y.shape)>1: 92 | y = y[:,self.task] 93 | cls, _ = self.model(img,return_rep=True) 94 | logits = self.classifier(cls.detach()) 95 | 96 | loss_ce = self.online_classifier_loss(logits.squeeze(),y.squeeze()) 97 | self.log(f"final_{stage}_classifier_loss_{self.evaluated_epoch}_lin", loss_ce, sync_dist=True) 98 | 99 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 100 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 101 | self.log( 102 | f"final_{stage}_accuracy_{self.evaluated_epoch}_lin", 103 | accuracy_metric, 104 | prog_bar=False, 105 | sync_dist=True, 106 | ) 107 | 108 | self.train_losses.append(loss_ce.item()) 109 | self.avg_train_losses.append(np.mean(self.train_losses)) 110 | 111 | if (self.current_epoch+1)%10==0 and batch_idx==0: 112 | plot_loss(self.avg_train_losses,name_loss="X-Entropy",save_dir=self.save_dir,name_file=f"_eval_train_{self.evaluated_epoch}_lin") 113 | return loss_ce 114 | 115 | else: 116 | img, y = batch 117 | 118 | if len(y.shape)>1: 119 | y = y[:,self.task] 120 | 121 | cls, attentions = self.model(img,return_rep=True,output_attentions=True) 122 | logits = self.classifier(cls.detach()) 123 | 124 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 125 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 126 | self.log( 127 | f"final_{stage}_accuracy_{self.evaluated_epoch}_lin", 128 | accuracy_metric, 129 | prog_bar=True, 130 | sync_dist=True, 131 | on_epoch=True, 132 | on_step=False, 133 | ) 134 | 135 | if batch_idx == 0 and (self.current_epoch+1) not in list(self.performance.keys()): 136 | self.performance[self.current_epoch+1]=[] 137 | if len(y.squeeze().shape) >1: 138 | self.f1scores[self.current_epoch+1]=[] 139 | 140 | if len(y.squeeze().shape) > 1: 141 | f1_metric = getattr(self, f"online_{stage}_f1") 142 | f1_score = f1_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 143 | self.performance[self.current_epoch+1].append(sum(1*((self.online_logit_fn(logits.squeeze())>0.5)==y.squeeze())).detach().cpu().numpy()) 144 | self.f1scores[self.current_epoch+1].append(f1_score.detach().cpu().numpy()) 145 | 146 | else: 147 | self.performance[self.current_epoch+1].append(sum(1*(torch.argmax(logits.squeeze(), dim=-1)==y.squeeze())).item()) 148 | 149 | # check the attention we get at final 150 | if self.current_epoch==0 and batch_idx==0: 151 | attentions = attentions[-1].mean(1) 152 | att_map_cls = attentions[:,0,1:] 153 | att_map_spatial = torch.mean(attentions[:,1:,1:],dim=-1) 154 | att_map_cls = att_map_cls.reshape([img.shape[0],int(np.sqrt(att_map_cls.shape[-1])),int(np.sqrt(att_map_cls.shape[-1]))]) 155 | att_map_spatial = 
att_map_spatial.reshape([img.shape[0],int(np.sqrt(att_map_spatial.shape[-1])),int(np.sqrt(att_map_spatial.shape[-1]))]) 156 | save_attention_maps(img[:10],att_map_cls[:10].unsqueeze(1),att_map_spatial[:10].unsqueeze(1),self.current_epoch+1, self.save_dir,f"eval_{self.evaluated_epoch}_lin") 157 | save_attention_maps_batch(att_map_cls=att_map_cls,att_map_spatial=att_map_spatial,epoch=self.current_epoch+1, output_dir=self.save_dir,name=f"eval_{self.evaluated_epoch}_lin") 158 | 159 | return None 160 | 161 | def on_validation_epoch_end(self): 162 | if isinstance(self.performance[self.current_epoch+1][0],np.ndarray): 163 | self.performance[self.current_epoch+1] = np.array(self.performance[self.current_epoch+1]) 164 | self.f1scores[self.current_epoch+1] = np.array(self.f1scores[self.current_epoch+1]) 165 | 166 | self.performance[self.current_epoch+1] = sum(self.performance[self.current_epoch+1])/self.datamodule.num_val_samples 167 | 168 | if isinstance(self.performance[self.current_epoch+1],np.ndarray): 169 | self.performance[self.current_epoch+1] = np.mean(self.performance[self.current_epoch+1]) 170 | self.f1scores[self.current_epoch+1] = np.mean(self.f1scores[self.current_epoch+1]) 171 | 172 | if (self.current_epoch+1)%10 == 0 and not isinstance(self.performance[self.current_epoch+1],np.ndarray): 173 | plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name=f"val_final_{self.evaluated_epoch}_lin") 174 | 175 | def on_fit_end(self): 176 | # Write to a CSV file 177 | with open(os.path.join(self.save_dir,f'performance_final_{self.evaluated_epoch}_lin_task_{self.task}.csv'), 'w', newline='') as csvfile: 178 | writer = csv.writer(csvfile) 179 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 180 | for epoch in list(self.performance.keys()): 181 | # Assuming you have the accuracy for each epoch stored in a list 182 | writer.writerow([epoch, round(100*self.performance[epoch],2)]) 183 | 184 | if len(self.f1scores.keys()) > 1 : 185 | with open(os.path.join(self.save_dir,f'f1scores_final_{self.evaluated_epoch}_lin_task_{self.task}.csv'), 'w', newline='') as csvfile: 186 | writer = csv.writer(csvfile) 187 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 188 | for epoch in list(self.f1scores.keys()): 189 | # Assuming you have the accuracy for each epoch stored in a list 190 | writer.writerow([epoch, round(100*self.f1scores[epoch],2)]) 191 | return 192 | 193 | def training_step(self, batch, batch_idx): 194 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 195 | return loss 196 | 197 | def validation_step(self, batch, batch_idx): 198 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 199 | return loss 200 | 201 | def test_step(self, batch, batch_idx): 202 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 203 | return loss 204 | 205 | def configure_optimizers(self): 206 | def warmup(current_step: int): 207 | return 1 / (10 ** (float(num_warmup_epochs - current_step))) 208 | 209 | if self.optimizer_name == "adamw": 210 | optimizer = torch.optim.AdamW( 211 | self.classifier.parameters(), 212 | lr=self.learning_rate, 213 | weight_decay=self.weight_decay, 214 | ) 215 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 216 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 217 | ) 218 | elif self.optimizer_name == "adamw_warmup": 219 | num_warmup_epochs = self.warm_up 220 | optimizer = torch.optim.AdamW( 221 | self.classifier.parameters(), 222 | lr=self.learning_rate, 223 | 
weight_decay=self.weight_decay, 224 | betas=self.betas 225 | ) 226 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 227 | optimizer, lr_lambda=warmup 228 | ) 229 | 230 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 231 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 232 | ) 233 | 234 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 235 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 236 | ) 237 | elif self.optimizer_name == "sgd": 238 | num_warmup_epochs = self.warm_up 239 | optimizer = torch.optim.SGD( 240 | self.classifier.parameters(), 241 | lr=self.learning_rate, 242 | weight_decay=self.weight_decay, 243 | momentum=self.betas 244 | ) 245 | 246 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 247 | optimizer, lr_lambda=warmup 248 | ) 249 | 250 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 251 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 252 | ) 253 | 254 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 255 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 256 | ) 257 | 258 | else: 259 | raise ValueError(f"{self.optimizer_name} not supported") 260 | 261 | return [optimizer], [lr_scheduler] 262 | 263 | -------------------------------------------------------------------------------- /src/model/module_mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: module_mlp.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | import csv 12 | 13 | # Third-party library imports 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torchmetrics 21 | from torch import Tensor 22 | from torch.nn.parameter import Parameter 23 | from torchvision.models import resnet18 24 | import wandb 25 | 26 | # Local imports 27 | from ..plotting import plot_loss, plot_performance 28 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 29 | 30 | class ViTMAE_mlp(pl.LightningModule): 31 | 32 | def __init__( 33 | self, 34 | model, 35 | learning_rate: float = 1e-3, 36 | base_learning_rate: float =1e-3, 37 | weight_decay: float = 0.05, 38 | betas: list =[0.9,0.95], 39 | optimizer_name: str = "adamw", 40 | warmup: int =10, 41 | datamodule: Optional[pl.LightningDataModule] = None, 42 | save_dir: str =None, 43 | evaluated_epoch: int =800, 44 | eval_type ="multiclass", 45 | eval_fn =nn.CrossEntropyLoss(), 46 | eval_logit_fn = nn.Softmax(), 47 | task :int =0, 48 | ): 49 | super().__init__() 50 | self.learning_rate = learning_rate 51 | self.weight_decay = weight_decay 52 | self.betas = betas 53 | self.optimizer_name = optimizer_name 54 | self.datamodule = datamodule 55 | self.num_classes = datamodule.num_classes 56 | self.image_size = datamodule.image_size 57 | self.warm_up = warmup 58 | self.evaluated_epoch = evaluated_epoch 59 | self.task = task 60 | 61 | model.delete_decoder() 62 | self.model = model 63 | self.model.config.mask_ratio = 0.0 64 | self.model.vit.embeddings.config.mask_ratio=0.0 65 | 66 | self.classifier = nn.Sequential( 67 | nn.Linear(model.config.hidden_size, model.config.hidden_size), 68 | nn.ReLU(), 69 | nn.Linear(model.config.hidden_size, model.config.hidden_size), 70 | nn.ReLU(), 71 | 
nn.Linear(model.config.hidden_size, self.num_classes) 72 | ) 73 | 74 | self.online_classifier_loss = eval_fn 75 | self.online_logit_fn= eval_logit_fn 76 | self.online_train_accuracy = torchmetrics.Accuracy( 77 | task=eval_type, num_classes=self.num_classes, top_k=1 78 | ) 79 | self.online_val_accuracy = torchmetrics.Accuracy( 80 | task=eval_type, num_classes=self.num_classes, top_k=1 81 | ) 82 | self.online_val_f1 = torchmetrics.F1Score( 83 | task=eval_type, num_classes=self.num_classes, top_k=1, average = "macro" 84 | ) 85 | 86 | self.save_dir = save_dir 87 | self.train_losses = [] 88 | self.avg_train_losses = [] 89 | self.performance = {} 90 | self.f1scores = {} 91 | 92 | def forward(self, x): 93 | return self.model(x) 94 | 95 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int =None): 96 | if stage == "train": 97 | img, y, _ = batch 98 | if len(y.shape)>1: 99 | y = y[:,self.task] 100 | cls, _ = self.model(img,return_rep=True) 101 | logits = self.classifier(cls.detach()) 102 | 103 | loss_ce = self.online_classifier_loss(logits.squeeze(),y.squeeze()) 104 | self.log(f"final_{stage}_classifier_loss_{self.evaluated_epoch}_mlp", loss_ce, sync_dist=True) 105 | 106 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 107 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 108 | self.log( 109 | f"final_{stage}_accuracy_{self.evaluated_epoch}_mlp", 110 | accuracy_metric, 111 | prog_bar=False, 112 | sync_dist=True, 113 | ) 114 | 115 | self.train_losses.append(loss_ce.item()) 116 | self.avg_train_losses.append(np.mean(self.train_losses)) 117 | 118 | if (self.current_epoch+1)%10==0 and batch_idx==0: 119 | plot_loss(self.avg_train_losses,name_loss="X-Entropy",save_dir=self.save_dir,name_file=f"_eval_train_{self.evaluated_epoch}_mlp") 120 | return loss_ce 121 | 122 | else: 123 | img, y = batch 124 | 125 | if len(y.shape)>1: 126 | y = y[:,self.task] 127 | 128 | cls, attentions = self.model(img,return_rep=True,output_attentions=True) 129 | logits = self.classifier(cls.detach()) 130 | 131 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 132 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 133 | self.log( 134 | f"final_{stage}_accuracy_{self.evaluated_epoch}_mlp", 135 | accuracy_metric, 136 | prog_bar=True, 137 | sync_dist=True, 138 | on_epoch=True, 139 | on_step=False, 140 | ) 141 | 142 | if batch_idx == 0 and (self.current_epoch+1) not in list(self.performance.keys()): 143 | self.performance[self.current_epoch+1]=[] 144 | if len(y.squeeze().shape) >1: 145 | self.f1scores[self.current_epoch+1]=[] 146 | 147 | if len(y.squeeze().shape) > 1: 148 | f1_metric = getattr(self, f"online_{stage}_f1") 149 | f1_score = f1_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 150 | self.performance[self.current_epoch+1].append(sum(1*((self.online_logit_fn(logits.squeeze())>0.5)==y.squeeze())).detach().cpu().numpy()) 151 | self.f1scores[self.current_epoch+1].append(f1_score.detach().cpu().numpy()) 152 | 153 | else: 154 | self.performance[self.current_epoch+1].append(sum(1*(torch.argmax(logits.squeeze(), dim=-1)==y.squeeze())).item()) 155 | 156 | # check the attention we get at final 157 | if self.current_epoch==0 and batch_idx==0: 158 | attentions = attentions[-1].mean(1) 159 | att_map_cls = attentions[:,0,1:] 160 | att_map_spatial = torch.mean(attentions[:,1:,1:],dim=-1) 161 | att_map_cls = att_map_cls.reshape([img.shape[0],int(np.sqrt(att_map_cls.shape[-1])),int(np.sqrt(att_map_cls.shape[-1]))]) 162 | att_map_spatial = 
att_map_spatial.reshape([img.shape[0],int(np.sqrt(att_map_spatial.shape[-1])),int(np.sqrt(att_map_spatial.shape[-1]))]) 163 | save_attention_maps(img[:10],att_map_cls[:10].unsqueeze(1),att_map_spatial[:10].unsqueeze(1),self.current_epoch+1, self.save_dir,f"eval_{self.evaluated_epoch}_mlp") 164 | save_attention_maps_batch(att_map_cls=att_map_cls,att_map_spatial=att_map_spatial,epoch=self.current_epoch+1, output_dir=self.save_dir,name=f"eval_{self.evaluated_epoch}_mlp") 165 | 166 | return None 167 | 168 | def on_validation_epoch_end(self): 169 | if isinstance(self.performance[self.current_epoch+1][0],np.ndarray): 170 | self.performance[self.current_epoch+1] = np.array(self.performance[self.current_epoch+1]) 171 | self.f1scores[self.current_epoch+1] = np.array(self.f1scores[self.current_epoch+1]) 172 | 173 | self.performance[self.current_epoch+1] = sum(self.performance[self.current_epoch+1])/self.datamodule.num_val_samples 174 | 175 | if isinstance(self.performance[self.current_epoch+1],np.ndarray): 176 | self.performance[self.current_epoch+1] = np.mean(self.performance[self.current_epoch+1]) 177 | self.f1scores[self.current_epoch+1] = np.mean(self.f1scores[self.current_epoch+1]) 178 | 179 | if (self.current_epoch+1)%10 == 0 and not isinstance(self.performance[self.current_epoch+1],np.ndarray): 180 | plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name=f"val_final_{self.evaluated_epoch}_mlp") 181 | 182 | def on_fit_end(self): 183 | # Write to a CSV file 184 | with open(os.path.join(self.save_dir,f'performance_final_{self.evaluated_epoch}_mlp_task_{self.task}.csv'), 'w', newline='') as csvfile: 185 | writer = csv.writer(csvfile) 186 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 187 | for epoch in list(self.performance.keys()): 188 | # Assuming you have the accuracy for each epoch stored in a list 189 | writer.writerow([epoch, round(100*self.performance[epoch],2)]) 190 | 191 | if len(self.f1scores.keys()) > 1 : 192 | with open(os.path.join(self.save_dir,f'f1scores_final_{self.evaluated_epoch}_mlp_task_{self.task}.csv'), 'w', newline='') as csvfile: 193 | writer = csv.writer(csvfile) 194 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 195 | for epoch in list(self.f1scores.keys()): 196 | # Assuming you have the accuracy for each epoch stored in a list 197 | writer.writerow([epoch, round(100*self.f1scores[epoch],2)]) 198 | return 199 | 200 | def training_step(self, batch, batch_idx): 201 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 202 | return loss 203 | 204 | def validation_step(self, batch, batch_idx): 205 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 206 | return loss 207 | 208 | def test_step(self, batch, batch_idx): 209 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 210 | return loss 211 | 212 | def configure_optimizers(self): 213 | def warmup(current_step: int): 214 | return 1 / (10 ** (float(num_warmup_epochs - current_step))) 215 | 216 | if self.optimizer_name == "adamw": 217 | optimizer = torch.optim.AdamW( 218 | self.classifier.parameters(), 219 | lr=self.learning_rate, 220 | weight_decay=self.weight_decay, 221 | ) 222 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 223 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 224 | ) 225 | elif self.optimizer_name == "adamw_warmup": 226 | num_warmup_epochs = self.warm_up 227 | optimizer = torch.optim.AdamW( 228 | self.classifier.parameters(), 229 | lr=self.learning_rate, 230 | 
weight_decay=self.weight_decay, 231 | betas=self.betas 232 | ) 233 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 234 | optimizer, lr_lambda=warmup 235 | ) 236 | 237 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 238 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 239 | ) 240 | 241 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 242 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 243 | ) 244 | elif self.optimizer_name == "sgd": 245 | num_warmup_epochs = self.warm_up 246 | optimizer = torch.optim.SGD( 247 | self.classifier.parameters(), 248 | lr=self.learning_rate, 249 | weight_decay=self.weight_decay, 250 | momentum=self.betas 251 | ) 252 | 253 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 254 | optimizer, lr_lambda=warmup 255 | ) 256 | 257 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 258 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 259 | ) 260 | 261 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 262 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 263 | ) 264 | 265 | else: 266 | raise ValueError(f"{self.optimizer_name} not supported") 267 | 268 | return [optimizer], [lr_scheduler] 269 | 270 | -------------------------------------------------------------------------------- /src/model/module_transfert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: module_transfert.py 3 | Author: Alice Bizeul 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | 7 | # Standard library imports 8 | import os 9 | import time 10 | from typing import Any, Dict, List, Optional 11 | import csv 12 | 13 | # Third-party library imports 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torchmetrics 21 | from torch import Tensor 22 | from torch.nn.parameter import Parameter 23 | from torchvision.models import resnet18 24 | import wandb 25 | 26 | # Local imports 27 | from ..plotting import plot_loss, plot_performance 28 | from ..utils import save_attention_maps, save_attention_maps_batch, save_reconstructed_images 29 | 30 | class ViTMAE_transfert(pl.LightningModule): 31 | 32 | def __init__( 33 | self, 34 | model, 35 | learning_rate: float = 1e-3, 36 | base_learning_rate: float =1e-3, 37 | weight_decay: float = 0.05, 38 | betas: list =[0.9,0.95], 39 | optimizer_name: str = "adamw", 40 | warmup: int =10, 41 | datamodule: Optional[pl.LightningDataModule] = None, 42 | save_dir: str =None, 43 | evaluated_epoch: int =800, 44 | eval_type ="multiclass", 45 | eval_fn =nn.CrossEntropyLoss(), 46 | eval_logit_fn = nn.Softmax(), 47 | task :int =0, 48 | ): 49 | super().__init__() 50 | self.learning_rate = learning_rate 51 | self.weight_decay = weight_decay 52 | self.betas = betas 53 | self.optimizer_name = optimizer_name 54 | self.datamodule = datamodule 55 | self.num_classes = datamodule.num_classes 56 | self.image_size = datamodule.image_size 57 | self.warm_up = warmup 58 | self.evaluated_epoch = evaluated_epoch 59 | self.task = task 60 | 61 | model.delete_decoder() 62 | self.model = model 63 | self.model.config.mask_ratio = 0.0 64 | self.model.vit.embeddings.config.mask_ratio=0.0 65 | self.classifier = nn.Linear(model.config.hidden_size, self.num_classes) 66 | 67 | self.online_classifier_loss = eval_fn 68 | self.online_logit_fn= eval_logit_fn 69 | self.online_train_accuracy = 
torchmetrics.Accuracy( 70 | task=eval_type, num_classes=self.num_classes, top_k=1 71 | ) 72 | self.online_val_accuracy = torchmetrics.Accuracy( 73 | task=eval_type, num_classes=self.num_classes, top_k=1 74 | ) 75 | self.online_val_f1 = torchmetrics.F1Score( 76 | task=eval_type, num_classes=self.num_classes, top_k=1, average = "macro" 77 | ) 78 | 79 | self.save_dir = save_dir 80 | self.train_losses = [] 81 | self.avg_train_losses = [] 82 | self.performance = {} 83 | self.f1scores = {} 84 | 85 | def forward(self, x): 86 | return self.model(x) 87 | 88 | def shared_step(self, batch: Tensor, stage: str = "train", batch_idx: int =None): 89 | if stage == "train": 90 | img, y, _ = batch 91 | if len(y.shape)>1: 92 | y = y[:,self.task] 93 | cls, _ = self.model(img,return_rep=True) 94 | logits = self.classifier(cls.detach()) 95 | 96 | loss_ce = self.online_classifier_loss(logits.squeeze(),y.squeeze()) 97 | self.log(f"final_{stage}_classifier_loss_{self.evaluated_epoch}_transfert_{self.datamodule.name}", loss_ce, sync_dist=True) 98 | 99 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 100 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 101 | self.log( 102 | f"final_{stage}_accuracy_{self.evaluated_epoch}_transfert_{self.datamodule.name}", 103 | accuracy_metric, 104 | prog_bar=False, 105 | sync_dist=True, 106 | ) 107 | 108 | self.train_losses.append(loss_ce.item()) 109 | self.avg_train_losses.append(np.mean(self.train_losses)) 110 | 111 | if (self.current_epoch+1)%10==0 and batch_idx==0: 112 | plot_loss(self.avg_train_losses,name_loss="X-Entropy",save_dir=self.save_dir,name_file=f"_eval_train_{self.evaluated_epoch}_transfert_{self.datamodule.name}") 113 | return loss_ce 114 | 115 | else: 116 | img, y = batch 117 | 118 | if len(y.shape)>1: 119 | y = y[:,self.task] 120 | 121 | cls, attentions = self.model(img,return_rep=True,output_attentions=True) 122 | logits = self.classifier(cls.detach()) 123 | 124 | accuracy_metric = getattr(self, f"online_{stage}_accuracy") 125 | accuracy_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 126 | self.log( 127 | f"final_{stage}_accuracy_{self.evaluated_epoch}_transfert_{self.datamodule.name}", 128 | accuracy_metric, 129 | prog_bar=True, 130 | sync_dist=True, 131 | on_epoch=True, 132 | on_step=False, 133 | ) 134 | 135 | if batch_idx == 0 and (self.current_epoch+1) not in list(self.performance.keys()): 136 | self.performance[self.current_epoch+1]=[] 137 | if len(y.squeeze().shape) >1: 138 | self.f1scores[self.current_epoch+1]=[] 139 | 140 | if len(y.squeeze().shape) > 1: 141 | f1_metric = getattr(self, f"online_{stage}_f1") 142 | f1_score = f1_metric(self.online_logit_fn(logits.squeeze()), y.squeeze()) 143 | self.performance[self.current_epoch+1].append(sum(1*((self.online_logit_fn(logits.squeeze())>0.5)==y.squeeze())).detach().cpu().numpy()) 144 | self.f1scores[self.current_epoch+1].append(f1_score.detach().cpu().numpy()) 145 | 146 | else: 147 | self.performance[self.current_epoch+1].append(sum(1*(torch.argmax(logits.squeeze(), dim=-1)==y.squeeze())).item()) 148 | 149 | # check the attention we get at final 150 | if self.current_epoch==0 and batch_idx==0: 151 | attentions = attentions[-1].mean(1) 152 | att_map_cls = attentions[:,0,1:] 153 | att_map_spatial = torch.mean(attentions[:,1:,1:],dim=-1) 154 | att_map_cls = att_map_cls.reshape([img.shape[0],int(np.sqrt(att_map_cls.shape[-1])),int(np.sqrt(att_map_cls.shape[-1]))]) 155 | att_map_spatial = 
att_map_spatial.reshape([img.shape[0],int(np.sqrt(att_map_spatial.shape[-1])),int(np.sqrt(att_map_spatial.shape[-1]))]) 156 | save_attention_maps(img[:10],att_map_cls[:10].unsqueeze(1),att_map_spatial[:10].unsqueeze(1),self.current_epoch+1, self.save_dir,f"eval_{self.evaluated_epoch}_transfert_{self.datamodule.name}") 157 | save_attention_maps_batch(att_map_cls=att_map_cls,att_map_spatial=att_map_spatial,epoch=self.current_epoch+1, output_dir=self.save_dir,name=f"eval_{self.evaluated_epoch}_transfert_{self.datamodule.name}") 158 | 159 | return None 160 | 161 | def on_validation_epoch_end(self): 162 | if isinstance(self.performance[self.current_epoch+1][0],np.ndarray): 163 | self.performance[self.current_epoch+1] = np.array(self.performance[self.current_epoch+1]) 164 | self.f1scores[self.current_epoch+1] = np.array(self.f1scores[self.current_epoch+1]) 165 | 166 | self.performance[self.current_epoch+1] = sum(self.performance[self.current_epoch+1])/self.datamodule.num_val_samples 167 | 168 | if isinstance(self.performance[self.current_epoch+1],np.ndarray): 169 | self.performance[self.current_epoch+1] = np.mean(self.performance[self.current_epoch+1]) 170 | self.f1scores[self.current_epoch+1] = np.mean(self.f1scores[self.current_epoch+1]) 171 | 172 | if (self.current_epoch+1)%10 == 0 and not isinstance(self.performance[self.current_epoch+1],np.ndarray): 173 | plot_performance(list(self.performance.keys()),list(self.performance.values()),self.save_dir,name=f"val_final_{self.evaluated_epoch}_transfert_{self.datamodule.name}") 174 | 175 | def on_fit_end(self): 176 | # Write to a CSV file 177 | with open(os.path.join(self.save_dir,f'performance_final_{self.evaluated_epoch}_transfert_task_{self.task}_{self.datamodule.name}.csv'), 'w', newline='') as csvfile: 178 | writer = csv.writer(csvfile) 179 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 180 | for epoch in list(self.performance.keys()): 181 | # Assuming you have the accuracy for each epoch stored in a list 182 | writer.writerow([epoch, round(100*self.performance[epoch],2)]) 183 | 184 | if len(self.f1scores.keys()) > 1 : 185 | with open(os.path.join(self.save_dir,f'f1scores_final_{self.evaluated_epoch}_transfert_task_{self.task}_{self.datamodule.name}.csv'), 'w', newline='') as csvfile: 186 | writer = csv.writer(csvfile) 187 | writer.writerow(['Eval Epoch', 'Test Accuracy']) 188 | for epoch in list(self.f1scores.keys()): 189 | # Assuming you have the accuracy for each epoch stored in a list 190 | writer.writerow([epoch, round(100*self.f1scores[epoch],2)]) 191 | return 192 | 193 | def training_step(self, batch, batch_idx): 194 | loss = self.shared_step(batch, stage="train", batch_idx=batch_idx) 195 | return loss 196 | 197 | def validation_step(self, batch, batch_idx): 198 | loss = self.shared_step(batch, stage="val", batch_idx=batch_idx) 199 | return loss 200 | 201 | def test_step(self, batch, batch_idx): 202 | loss = self.shared_step(batch, stage="test", batch_idx=batch_idx) 203 | return loss 204 | 205 | def configure_optimizers(self): 206 | def warmup(current_step: int): 207 | return 1 / (10 ** (float(num_warmup_epochs - current_step))) 208 | 209 | if self.optimizer_name == "adamw": 210 | optimizer = torch.optim.AdamW( 211 | self.classifier.parameters(), 212 | lr=self.learning_rate, 213 | weight_decay=self.weight_decay, 214 | ) 215 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 216 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 217 | ) 218 | elif self.optimizer_name == "adamw_warmup": 219 | 
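            # warm-up phase: LambdaLR scales the base learning rate by 10**(step - num_warmup_epochs), growing tenfold each epoch until it reaches self.learning_rate; SequentialLR then hands over to cosine annealing at epoch num_warmup_epochs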
num_warmup_epochs = self.warm_up 220 | optimizer = torch.optim.AdamW( 221 | self.classifier.parameters(), 222 | lr=self.learning_rate, 223 | weight_decay=self.weight_decay, 224 | betas=self.betas 225 | ) 226 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 227 | optimizer, lr_lambda=warmup 228 | ) 229 | 230 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 231 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 232 | ) 233 | 234 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 235 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 236 | ) 237 | elif self.optimizer_name == "sgd": 238 | num_warmup_epochs = self.warm_up 239 | optimizer = torch.optim.SGD( 240 | self.classifier.parameters(), 241 | lr=self.learning_rate, 242 | weight_decay=self.weight_decay, 243 | momentum=self.betas 244 | ) 245 | 246 | warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( 247 | optimizer, lr_lambda=warmup 248 | ) 249 | 250 | train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 251 | optimizer=optimizer, T_max=self.trainer.max_epochs, verbose=False 252 | ) 253 | 254 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 255 | optimizer, [warmup_scheduler, train_scheduler], [num_warmup_epochs] 256 | ) 257 | 258 | else: 259 | raise ValueError(f"{self.optimizer_name} not supported") 260 | 261 | return [optimizer], [lr_scheduler] 262 | 263 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Name: main.py 3 | Author: Alice Bizeul, Some portions of the code below were taken from prior codebases written by Mark Ibrahim and Randall Balestreiro 4 | Ownership: ETH Zürich - ETH AI Center 5 | """ 6 | # Standard library imports 7 | import logging 8 | import os 9 | from pathlib import Path 10 | from typing import Callable, Dict, Iterable, List, Optional, Tuple 11 | import yaml 12 | 13 | # Third-party library imports 14 | import git 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import pytorch_lightning as pl 18 | import submitit 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim.lr_scheduler as lr_scheduler 23 | from omegaconf import DictConfig, OmegaConf 24 | from PIL import Image 25 | from sklearn.decomposition import PCA 26 | from torch.optim.optimizer import Optimizer 27 | from torch.utils.data import DataLoader, Dataset 28 | import torchvision 29 | from torchvision import datasets, transforms 30 | from torchvision.datasets import ImageFolder 31 | from pytorch_lightning.loggers import WandbLogger 32 | from pytorch_lightning.utilities import rank_zero_only 33 | 34 | # Hydra imports 35 | from hydra.utils import instantiate 36 | 37 | def setup_wandb( 38 | config: DictConfig, 39 | log: logging.Logger, 40 | git_hash: str = "", 41 | extra_configs: dict = dict(), 42 | ) -> WandbLogger: 43 | 44 | log_job_info(log) 45 | config_dict = yaml.safe_load(OmegaConf.to_yaml(config, resolve=True)) 46 | job_logs_dir = os.getcwd() 47 | 48 | # increase timeout per wandb folks' suggestion 49 | os.environ["WANDB_INIT_TIMEOUT"] = "60" 50 | os.environ["WANDB_DIR"] = config.wandb_dir 51 | os.environ["WANDB_DATA_DIR"] = config.wandb_datadir 52 | os.environ["WANDB_CACHE_DIR"] = config.wandb_cachedir 53 | os.environ["WANDB_CONFIG_DIR"] = config.wandb_configdir 54 | 55 | config_dict["job_logs_dir"] = job_logs_dir 56 | config_dict["git_hash"] = git_hash 57 | 58 | 
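    # compose the W&B run name from the configured tag plus the module and datamodule class names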
name = ( 59 | config.wandb.tags 60 | + "_" 61 | + config.module._target_.split(".")[-1] 62 | + "_" 63 | + config.datamodule._target_.split(".")[-1] 64 | ) 65 | config_dict.update(extra_configs) 66 | 67 | try: 68 | wandb_logger = WandbLogger( 69 | name=name, 70 | config=config_dict, 71 | settings={"start_method": "fork"}, 72 | **config.wandb, 73 | ) 74 | except Exception as e: 75 | print(f"exception: {e}") 76 | print("starting wandb in offline mode. To sync logs run") 77 | print(f"wandb sync {job_logs_dir}") 78 | os.environ["WANDB_MODE"] = "offline" 79 | wandb_logger = WandbLogger( 80 | name=name, 81 | config=config_dict, 82 | settings={"start_method": "fork"}, 83 | **config.wandb, 84 | ) 85 | return wandb_logger 86 | 87 | def get_git_hash() -> Optional[str]: 88 | try: 89 | repo = git.Repo(search_parent_directories=True) 90 | sha = repo.head.object.hexsha 91 | return sha 92 | except: 93 | print("not able to find git hash") 94 | 95 | 96 | @rank_zero_only 97 | def print_config( 98 | config: DictConfig, 99 | resolve: bool = True, 100 | ) -> None: 101 | """Saves and prints content of DictConfig 102 | Args: 103 | config (DictConfig): Configuration composed by Hydra. 104 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 105 | """ 106 | run_configs = OmegaConf.to_yaml(config, resolve=resolve) 107 | with open("run_configs.yaml", "w") as f: 108 | OmegaConf.save(config=config, f=f) 109 | 110 | 111 | def log_job_info(log: logging.Logger): 112 | """Logs info about the job directory and SLURM job id""" 113 | job_logs_dir = os.getcwd() 114 | log.info(f"Logging to {job_logs_dir}") 115 | job_id = "local" 116 | 117 | try: 118 | job_env = submitit.JobEnvironment() 119 | job_id = job_env.job_id 120 | except RuntimeError: 121 | pass 122 | log.info(f"job id {job_id}") 123 | 124 | 125 | def find_existing_checkpoint(dirpath: str) -> Optional[str]: 126 | """Searches dirpath for an existing model checkpoint. 127 | If found, returns its path. 
128 | """ 129 | ckpts = list(Path(dirpath).rglob("*.ckpt")) 130 | if ckpts: 131 | ckpt = str(ckpts[-1]) 132 | print(f"resuming from existing checkpoint: {ckpt}") 133 | return ckpt 134 | return None 135 | 136 | def load_checkpoints(model, config): 137 | if config.f is not None: 138 | print("------------------ Trying to load checkpoint from",config.f) 139 | try: 140 | model.load_state_dict(instantiate(config)["state_dict"],strict=False) 141 | attempt=1 142 | except: 143 | try: 144 | model.load_state_dict(instantiate(config)["model_state_dict"],strict=False) 145 | attempt = 2 146 | except: 147 | attempt=3 148 | print('------------------ Loaded checkpoint following attempt',attempt," - model is ",type(model)) 149 | return model 150 | 151 | # Define the function to save images and their reconstructions 152 | def save_reconstructed_images(input, target, reconstructed, epoch, output_dir, name): 153 | os.makedirs(output_dir, exist_ok=True) 154 | input_grid = torchvision.utils.make_grid(input[:8].cpu(), nrow=4, normalize=True) 155 | target_grid = torchvision.utils.make_grid(target[:8].cpu(), nrow=4, normalize=True) 156 | reconstructed_grid = torchvision.utils.make_grid(reconstructed[:8].cpu(), nrow=4, normalize=True) 157 | 158 | _, axes = plt.subplots(1, 3, figsize=(15, 5)) 159 | axes[0].imshow(input_grid.permute(1, 2, 0)) 160 | axes[0].set_title('Input Images') 161 | axes[0].axis('off') 162 | 163 | axes[1].imshow(target_grid.permute(1, 2, 0)) 164 | axes[1].set_title('Target Images') 165 | axes[1].axis('off') 166 | 167 | axes[2].imshow(reconstructed_grid.permute(1, 2, 0)) 168 | axes[2].set_title('Reconstructed Images') 169 | axes[2].axis('off') 170 | 171 | plt.savefig(os.path.join(output_dir, f'epoch_{epoch}_{name}.png')) 172 | plt.close() 173 | 174 | # Define the function to save images and their reconstructions 175 | def save_attention_maps(input, attention_cls, attention_spatial, epoch, output_dir, name): 176 | os.makedirs(output_dir, exist_ok=True) 177 | input_grid = torchvision.utils.make_grid(input[:8].cpu(), nrow=4, normalize=True) 178 | cls_grid = torchvision.utils.make_grid(attention_cls[:8].cpu(), nrow=4, normalize=True) 179 | spatial_grid = torchvision.utils.make_grid(attention_spatial[:8].cpu(), nrow=4, normalize=True) 180 | 181 | fig, axes = plt.subplots(1, 3, figsize=(15, 5)) 182 | axes[0].imshow(input_grid.permute(1, 2, 0)) 183 | axes[0].set_title('Input Images') 184 | axes[0].axis('off') 185 | 186 | im1= axes[1].imshow(cls_grid.permute(1, 2, 0),cmap='gray') 187 | axes[1].set_title('CLS Attention Maps') 188 | axes[1].axis('off') 189 | fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04) # Add color bar for CLS Attention Maps 190 | 191 | im2 = axes[2].imshow(spatial_grid.permute(1, 2, 0),cmap='gray') 192 | axes[2].set_title('Average Spatial Attention Maps') 193 | axes[2].axis('off') 194 | fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04) # Add color bar for CLS Attention Maps 195 | 196 | plt.savefig(os.path.join(output_dir, f'epoch_{epoch}_{name}_attention.png')) 197 | plt.close() 198 | 199 | def save_attention_maps_batch(att_map_cls, att_map_spatial, epoch, output_dir, name): 200 | # average over batch 201 | att_map_cls = torch.mean(att_map_cls.detach().cpu(),dim=0) 202 | att_map_spatial = torch.mean(att_map_spatial.detach().cpu(),dim=0) 203 | 204 | fig, axes = plt.subplots(1, 2, figsize=(15, 5)) 205 | im1 = axes[0].imshow(att_map_cls.unsqueeze(0).permute(1, 2, 0),cmap='viridis') 206 | axes[0].set_title('CLS Attention Maps') 207 | axes[0].axis('off') 208 | fig.colorbar(im1, 
ax=axes[0], fraction=0.046, pad=0.04) # Add color bar for CLS Attention Maps 209 | 210 | im2 = axes[1].imshow(att_map_spatial.unsqueeze(0).permute(1, 2, 0),cmap='viridis') 211 | axes[1].set_title('Average Spatial Attention Maps') 212 | axes[1].axis('off') 213 | fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04) # Add color bar for Average Spatial Attention Maps 214 | 215 | plt.savefig(os.path.join(output_dir, f'epoch_{epoch}_{name}_attention_batchavg.png')) 216 | plt.close() 217 | 218 | class PCImageDataset(Dataset): 219 | def __init__(self, folder, pc_path, eigen_path, transform=None): 220 | """ 221 | Initialize the dataset from an image folder and precomputed principal components. 222 | :param folder: Root directory of the image dataset (read with ImageFolder). 223 | :param pc_path: Path to the saved principal-component matrix (.npy). 224 | :param eigen_path: Path to the saved eigenvalues (.npy). 225 | :param transform: Optional transformations to apply to the images. 226 | """ 227 | self.dataset1 = ImageFolder(root=folder) 228 | try: 229 | self.pc_matrix = np.load(pc_path) 230 | self.eigenvalues = np.load(eigen_path) 231 | except: 232 | print(f"Could not load principal components from {pc_path} or {eigen_path}") 233 | self.transform = transform 234 | 235 | def __len__(self): 236 | return len(self.dataset1) 237 | 238 | def __getitem__(self, idx): 239 | 240 | # Load the image 241 | img1, _ = self.dataset1[idx] 242 | 243 | # Apply transformations if provided 244 | if self.transform: 245 | img1 = self.transform(img1) 246 | 247 | 248 | return img1 249 | 250 | class Normalize(torch.nn.Module): 251 | """Normalize a tensor image with mean and standard deviation. 252 | This transform does not support PIL Image. 253 | Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` 254 | channels, this transform will normalize each channel of the input 255 | ``torch.*Tensor`` i.e., 256 | ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` 257 | 258 | .. note:: 259 | This transform acts out of place, i.e., it does not mutate the input tensor. 260 | 261 | Args: 262 | mean (sequence): Sequence of means for each channel. 263 | std (sequence): Sequence of standard deviations for each channel. 264 | 265 | 266 | """ 267 | 268 | def __init__(self, mean, std): 269 | super().__init__() 270 | self.mean = mean 271 | self.std = std 272 | 273 | def forward(self, tensor): 274 | """ 275 | Args: 276 | tensor (Tensor): Tensor image to be normalized. 277 | 278 | Returns: 279 | Tensor: Normalized Tensor image.
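            Note: ``mean`` and ``std`` may be scalars or tensors that broadcast against the input (e.g. shape ``(C, 1, 1)`` for CHW image tensors).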
280 | """ 281 | return (tensor - self.mean)/self.std 282 | 283 | 284 | def __repr__(self) -> str: 285 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})" 286 | 287 | 288 | def get_eigenvalues(data): 289 | pca = PCA() # You can adjust the number of components 290 | 291 | if len(data.shape)!=2: 292 | data = data.reshape(data.shape[0],-1) # flatten to (n_samples, n_features) before fitting PCA 293 | pca.fit(data) 294 | 295 | return pca.explained_variance_ 296 | 297 | class LinearWarmupScheduler: 298 | def __init__(self, optimizer, warmup_epochs, total_epochs, target_lr): 299 | self.optimizer = optimizer 300 | self.warmup_epochs = warmup_epochs 301 | self.total_epochs = total_epochs 302 | self.target_lr = target_lr 303 | self.base_lr = 0.0 304 | self.annealing_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, total_epochs - warmup_epochs, eta_min=0) 305 | 306 | def step(self, epoch): 307 | if epoch < self.warmup_epochs: 308 | lr = self.base_lr + (self.target_lr - self.base_lr) * (epoch / self.warmup_epochs) 309 | for param_group in self.optimizer.param_groups: 310 | param_group['lr'] = lr 311 | else: 312 | self.annealing_scheduler.step(epoch - self.warmup_epochs) 313 | 314 | class Lars(Optimizer): 315 | r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" 316 | <https://arxiv.org/abs/1708.03888>`_. 317 | Code taken from: https://github.com/NUS-HPC-AI-Lab/InfoBatch/blob/master/examples/lars.py 318 | Args: 319 | params (iterable): iterable of parameters to optimize or dicts defining 320 | parameter groups 321 | lr (float, optional): learning rate 322 | momentum (float, optional): momentum factor (default: 0) 323 | eeta (float, optional): LARS coefficient as used in the paper (default: 1e-3) 324 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 325 | """ 326 | 327 | def __init__( 328 | self, 329 | params: Iterable[torch.nn.Parameter], 330 | lr=1e-3, 331 | momentum=0, 332 | eeta=1e-3, 333 | weight_decay=0, 334 | epsilon=0.0 335 | ) -> None: 336 | if not isinstance(lr, float) or lr < 0.0: 337 | raise ValueError("Invalid learning rate: {}".format(lr)) 338 | if momentum < 0.0: 339 | raise ValueError("Invalid momentum value: {}".format(momentum)) 340 | if weight_decay < 0.0: 341 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 342 | if eeta <= 0: 343 | raise ValueError("Invalid eeta value: {}".format(eeta)) 344 | if epsilon < 0: 345 | raise ValueError("Invalid epsilon value: {}".format(epsilon)) 346 | defaults = dict(lr=lr, momentum=momentum, 347 | weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) 348 | 349 | super().__init__(params, defaults) 350 | 351 | def set_decay(self,weight_decay): 352 | for group in self.param_groups: 353 | group['weight_decay'] = weight_decay 354 | 355 | @torch.no_grad() 356 | def step(self, closure=None): 357 | """Performs a single optimization step. 358 | Arguments: 359 | closure (callable, optional): A closure that reevaluates the model 360 | and returns the loss.
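        Each parameter's learning rate is scaled by the trust ratio eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps), clamped to [0, 50]; the weight-decayed gradient is clipped to [-10, 10] before the momentum update is applied.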
361 | """ 362 | loss = None 363 | if closure is not None: 364 | with torch.enable_grad(): 365 | loss = closure() 366 | 367 | for group in self.param_groups: 368 | weight_decay = group['weight_decay'] 369 | momentum = group['momentum'] 370 | eeta = group['eeta'] 371 | lr = group['lr'] 372 | lars = group['lars'] 373 | eps = group['epsilon'] 374 | 375 | for p in group['params']: 376 | if p.grad is None: 377 | continue 378 | decayed_grad = p.grad 379 | scaled_lr = lr 380 | if lars: 381 | w_norm = torch.norm(p) 382 | g_norm = torch.norm(p.grad) 383 | trust_ratio = torch.where( 384 | w_norm > 0 and g_norm > 0, 385 | eeta * w_norm / (g_norm + weight_decay * w_norm + eps), 386 | torch.ones_like(w_norm) 387 | ) 388 | trust_ratio.clamp_(0.0, 50) 389 | scaled_lr *= trust_ratio.item() 390 | if weight_decay != 0: 391 | decayed_grad = decayed_grad.add(p, alpha=weight_decay) 392 | decayed_grad = torch.clamp(decayed_grad, -10.0, 10.0) 393 | 394 | if momentum != 0: 395 | param_state = self.state[p] 396 | if 'momentum_buffer' not in param_state: 397 | buf = param_state['momentum_buffer'] = torch.clone( 398 | decayed_grad).detach() 399 | else: 400 | buf = param_state['momentum_buffer'] 401 | buf.mul_(momentum).add_(decayed_grad) 402 | decayed_grad = buf 403 | 404 | p.add_(decayed_grad, alpha=-scaled_lr) 405 | 406 | return loss 407 | --------------------------------------------------------------------------------
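Usage note (not part of the repository): the following minimal sketch shows how the Lars optimizer and LinearWarmupScheduler defined in src/utils.py could be wired together to train a small probe. It assumes the repository's requirements are installed and the repository root is on PYTHONPATH so that src is importable; the probe dimensions, toy data, and hyperparameter values are placeholders rather than settings used by the codebase.

import torch
import torch.nn as nn
from src.utils import Lars, LinearWarmupScheduler

# toy features/labels standing in for real encoder outputs and targets
x = torch.randn(256, 192)
y = torch.randint(0, 10, (256,))

probe = nn.Linear(192, 10)  # hypothetical linear probe
optimizer = Lars(probe.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = LinearWarmupScheduler(optimizer, warmup_epochs=5, total_epochs=50, target_lr=0.1)

for epoch in range(50):
    scheduler.step(epoch)  # linear warm-up for 5 epochs, then cosine annealing
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(probe(x), y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")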