├── SFM-Finetune ├── modules │ ├── modeling │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── aspp.cpython-36.pyc │ │ │ ├── aspp.cpython-37.pyc │ │ │ ├── decoder.cpython-36.pyc │ │ │ ├── decoder.cpython-37.pyc │ │ │ ├── deeplab.cpython-36.pyc │ │ │ ├── deeplab.cpython-37.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── Unet_models.cpython-36.pyc │ │ │ └── Unet_models.cpython-37.pyc │ │ ├── backbone │ │ │ ├── __pycache__ │ │ │ │ ├── drn.cpython-36.pyc │ │ │ │ ├── drn.cpython-37.pyc │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── resnet.cpython-36.pyc │ │ │ │ ├── resnet.cpython-37.pyc │ │ │ │ ├── xception.cpython-36.pyc │ │ │ │ ├── xception.cpython-37.pyc │ │ │ │ ├── mobilenet.cpython-36.pyc │ │ │ │ └── mobilenet.cpython-37.pyc │ │ │ ├── __init__.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ └── xception.py │ │ ├── sync_batchnorm │ │ │ ├── __pycache__ │ │ │ │ ├── comm.cpython-36.pyc │ │ │ │ ├── comm.cpython-37.pyc │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── batchnorm.cpython-36.pyc │ │ │ │ ├── batchnorm.cpython-37.pyc │ │ │ │ ├── replicate.cpython-36.pyc │ │ │ │ └── replicate.cpython-37.pyc │ │ │ ├── __init__.py │ │ │ ├── replicate.py │ │ │ └── comm.py │ │ ├── decoder.py │ │ ├── aspp.py │ │ └── deeplab.py │ └── __pycache__ │ │ ├── comm.cpython-36.pyc │ │ ├── comm.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── replicate.cpython-36.pyc │ │ └── replicate.cpython-37.pyc ├── util │ ├── __pycache__ │ │ ├── lars.cpython-36.pyc │ │ ├── misc.cpython-36.pyc │ │ ├── misc.cpython-37.pyc │ │ ├── msssim.cpython-36.pyc │ │ ├── msssim.cpython-37.pyc │ │ ├── tools.cpython-36.pyc │ │ ├── tools.cpython-37.pyc │ │ ├── datasets.cpython-36.pyc │ │ ├── datasets.cpython-37.pyc │ │ ├── lr_decay.cpython-36.pyc │ │ ├── lr_decay.cpython-37.pyc │ │ ├── lr_sched.cpython-36.pyc │ │ ├── lr_sched.cpython-37.pyc │ │ ├── metrics.cpython-36.pyc │ │ ├── pos_embed.cpython-36.pyc │ │ └── pos_embed.cpython-37.pyc │ ├── lr_sched.py │ ├── crop.py │ ├── lars.py │ ├── lr_decay.py │ ├── metrics.py │ ├── tools.py │ ├── pos_embed.py │ ├── msssim.py │ ├── pos_embedtest.py │ ├── datasets.py │ └── misc.py ├── finetune-Facies.sh ├── Application │ ├── finetune-Denoise.sh │ ├── finetune-Reflect.sh │ ├── finetune-Salt.sh │ ├── finetune-Interpolation.sh │ └── README.md ├── README.md ├── submitit_finetune.py ├── models_Regression.py ├── engine_finetune.py ├── models_mae.py └── models_Segmentation.py ├── requirements.txt ├── assert ├── Network.png └── SeismicPretrainedModel.png ├── SFM-Pretrain ├── slurmjob.sh ├── submit-train.sh ├── train.sh ├── README.md ├── engine_pretrain.py ├── submitit_pretrain.py ├── main_pretrain.py └── models_mae.py ├── Data ├── README-Facies.md ├── README-Denoise.md ├── README-Geobody.md ├── README-Inversion.md ├── README-Interpolation.md └── README-Pretrain.md └── README.md /SFM-Finetune/modules/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.3.2 2 | scikit-learn 3 | matplotlib 4 | tensorboard 5 | -------------------------------------------------------------------------------- /assert/Network.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/assert/Network.png -------------------------------------------------------------------------------- /assert/SeismicPretrainedModel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/assert/SeismicPretrainedModel.png -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/lars.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/lars.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/msssim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/msssim.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/msssim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/msssim.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/comm.cpython-37.pyc 
-------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/lr_decay.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/lr_decay.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/lr_decay.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/lr_decay.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/lr_sched.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/lr_sched.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/lr_sched.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/lr_sched.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/pos_embed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/pos_embed.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/util/__pycache__/pos_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/util/__pycache__/pos_embed.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/SFM-Finetune/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/aspp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/aspp.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/aspp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/aspp.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/decoder.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/deeplab.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/deeplab.cpython-36.pyc 
-------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/deeplab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/deeplab.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/Unet_models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/Unet_models.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/__pycache__/Unet_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/__pycache__/Unet_models.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/drn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/drn.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/drn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/drn.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/resnet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/xception.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/xception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/xception.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/backbone/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shenghanlin/SeismicFoundationModel/HEAD/SFM-Finetune/modules/modeling/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /SFM-Pretrain/slurmjob.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J SFM 3 | #SBATCH -p GPU-8A100 4 | #SBATCH --cpus-per-task=15 5 | #SBATCH --gres=gpu:4 6 | #SBATCH -N 1 7 | #SBATCH -t 30:00 8 | #SBATCH --qos=gpu_8a100 9 | ./train.sh 10 | -------------------------------------------------------------------------------- /SFM-Pretrain/submit-train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python submitit_pretrain.py \ 3 | --job_dir './output_base_gpu4/' \ 4 | --batch_size 580 \ 5 | --accum_iter 4 \ 6 | --model mae_vit_base_patch16D4d256 \ 7 | --mask_ratio 0.75 \ 8 | --epochs 1600 \ 9 | --warmup_epochs 40 \ 10 | --blr 1.5e-4 --weight_decay 0.05 \ 11 | --data_path '../Data/Pretrain/mae_data_more/' 12 | 13 | -------------------------------------------------------------------------------- /SFM-Pretrain/train.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \ 2 | --batch_size 580 \ 3 | --accum_iter 4 \ 4 | --model mae_vit_base_patch16D4d256 \ 5 | --mask_ratio 0.75 \ 6 | --epochs 1600 \ 7 | --warmup_epochs 40 \ 8 | --blr 1.5e-4 --weight_decay 0.05 \ 9 | --data_path '../Data/Pretrain/mae_data_more/' \ 10 | --output_dir './output_model/' \ 11 | --log_dir './output_model/' -------------------------------------------------------------------------------- /Data/README-Facies.md: -------------------------------------------------------------------------------- 1
| # 🌟 Seismic Facies 2 | 3 | Data link [DatFile] 4 | 5 | The folder `Facies` contains 117 768*768 seismic images (`seismic`) and the corresponding labels (`label`). The first 100 images are used as the training set and the last 17 as the validation set. 6 | 7 | All dat files are float32 binary files. 8 | 9 |
10 |
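Since every dat file is a raw float32 array, it can be read directly with NumPy. A minimal loading sketch — the filename `0.dat` is only an assumed example, not a name confirmed by this repo:

```python
import numpy as np

# Hypothetical example: read one 768*768 float32 image and its facies label.
# The exact filenames under Facies/seismic and Facies/label are assumptions.
seismic = np.fromfile('Facies/seismic/0.dat', dtype=np.float32).reshape(768, 768)
label = np.fromfile('Facies/label/0.dat', dtype=np.float32).reshape(768, 768)
print(seismic.shape, label.min(), label.max())
```

The same pattern applies to the other task folders; only the image size changes (224*224 for the other tasks).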
11 | # License 12 | This project is released under the [MIT license](LICENSE). 13 | 14 | -------------------------------------------------------------------------------- /Data/README-Denoise.md: -------------------------------------------------------------------------------- 1 | # 🌟 Seismic Denoise 2 | 3 | Data link [DatFile] 4 | 5 | The folder `Denoise` contains 2000 224*224 seismic images (`seismic`) and the corresponding labels (`label`). In addition, it contains 4000 224*224 field seismic images (`field`) for validation. 6 | 7 | All dat files are float32 binary files. 8 | 9 |
10 |
11 | # License 12 | This project is released under the [MIT license](LICENSE). 13 | 14 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /Data/README-Geobody.md: -------------------------------------------------------------------------------- 1 | # 🌟 Seismic Geobody (Salt) 2 | 3 | Data link [DatFile] 4 | 5 | The folder `Geobody` contains 4000 224*224 seismic images (`seismic`) and the corresponding labels (`label`). The first 3500 images are used as the training set and the last 500 as the validation set. 6 | 7 | All dat files are float32 binary files. 8 | 9 |
10 |
11 | # License 12 | This project is released under the [MIT license](LICENSE). 13 | 14 | -------------------------------------------------------------------------------- /Data/README-Inversion.md: -------------------------------------------------------------------------------- 1 | # 🌟 Seismic Inversion 2 | 3 | Data link [DatFile] 4 | 5 | The folder `Inversion` contains 2200 224*224 training seismic images (`seismic`) and the corresponding labels (`label`). In addition, it contains 5000 224*224 validation seismic images (`SEAMseismic`) and the corresponding labels (`SEAMreflect`). 6 | 7 | All dat files are float32 binary files. 8 | 9 |
10 |
11 | # License 12 | This project is released under the [MIT license](LICENSE). 13 | 14 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.modeling.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /SFM-Finetune/finetune-Facies.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='5' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 2 | --data_path '../../Data/Facies/' \ 3 | --accum_iter 2 \ 4 | --batch_size 1 \ 5 | --model vit_base_patch16 \ 6 | --finetune '../SFM-Pretrain/output_dir/Base-512.pth' \ 7 | --output_dir './Application/Facies/modelbase_512/' \ 8 | --log_dir './Application/Facies/modelbase_512/' \ 9 | --epochs 100 \ 10 | --warmup_epochs 10 \ 11 | --blr 1.5e-3 --weight_decay 0.05 \ 12 | --layer_decay 0.05 --drop_path 0.1 --reprob 0.25 \ 13 | --dist_eval 14 | 15 | -------------------------------------------------------------------------------- /Data/README-Interpolation.md: -------------------------------------------------------------------------------- 1 | # 🌟 Seismic Interpolation 2 | 3 | Data link [DatFile] 4 | 5 | The folder `Interpolation` contains 8000 224*224 seismic images. The 6000 files whose names start with a number form the training set, and the 2000 files whose names start with the letter U form the validation set (a split sketch follows below). 6 | 7 | All dat files are float32 binary files. 8 | 9 |
10 |
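A minimal sketch of that naming rule, assuming the dat files sit directly inside the `Interpolation` folder (the exact directory layout and file extension are assumptions):

```python
import glob
import os

# Split files by the naming convention described above:
# names starting with a digit -> training, names starting with 'U' -> validation.
files = sorted(glob.glob('Interpolation/*.dat'))
train = [f for f in files if os.path.basename(f)[0].isdigit()]
valid = [f for f in files if os.path.basename(f).startswith('U')]
print(len(train), len(valid))  # expected: 6000 2000
```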
11 | # License 12 | This project is released under the [MIT license](LICENSE). 13 | 14 | -------------------------------------------------------------------------------- /SFM-Finetune/Application/finetune-Denoise.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='5' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 2 | --data_path '../Data/Denoise/' \ 3 | --task 'Denoise' \ 4 | --accum_iter 1 \ 5 | --batch_size 60 \ 6 | --input_size 224 \ 7 | --model vit_base_patch16 \ 8 | --output_dir './finetune_result/Denoise/net_Base_scratch/' \ 9 | --log_dir './finetune_result/Denoise/net_Base_scratch/' \ 10 | --epochs 100 \ 11 | --warmup_epochs 10 \ 12 | --lr 1.5e-3 --weight_decay 0.05 \ 13 | --layer_decay 0.05 --drop_path 0.1 --reprob 0.25 \ 14 | --dist_eval 15 | 16 | -------------------------------------------------------------------------------- /SFM-Finetune/Application/finetune-Reflect.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='1' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 2 | --data_path '../Data/Inversion/' \ 3 | --task 'Reflection' \ 4 | --accum_iter 1 \ 5 | --batch_size 60 \ 6 | --input_size 224 \ 7 | --model vit_base_patch16 \ 8 | --output_dir './finetune_result/Reflection/net_Base_scratch/' \ 9 | --log_dir './finetune_result/Reflection/net_Base_scratch/' \ 10 | --epochs 100 \ 11 | --warmup_epochs 10 \ 12 | --lr 1.5e-4 --weight_decay 0.05 \ 13 | --layer_decay 0.05 --drop_path 0.1 --reprob 0.25 \ 14 | --dist_eval 15 | 16 | -------------------------------------------------------------------------------- /SFM-Finetune/Application/finetune-Salt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='2' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 2 | --data_path '../Data/Geobody' \ 3 | --task 'Salt' \ 4 | --accum_iter 1 \ 5 | --batch_size 64 \ 6 | --input_size 224 \ 7 | --model vit_large_patch16 \ 8 | --finetune './output_dir_more/Large-1600.pth' \ 9 | --output_dir './finetune_result/Salt/sert_Transformer_test/' \ 10 | --log_dir './finetune_result/Salt/sert_Transformer_test/' \ 11 | --epochs 100 \ 12 | --warmup_epochs 10 \ 13 | --blr 1.5e-3 --weight_decay 0.05 \ 14 | --layer_decay 0.75 --drop_path 0.1 --reprob 0.25 \ 15 | --dist_eval 16 | 17 | -------------------------------------------------------------------------------- /SFM-Finetune/Application/finetune-Interpolation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='0' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 2 | --data_path '../Data/Interpolation/' \ 3 | --task 'Interpolation' \ 4 | --accum_iter 1 \ 5 | --batch_size 50 \ 6 | --input_size 224 \ 7 | --model vit_base_patch16 \ 8 | --resume './finetune_result/Interpolation/net_Transformer_16_800epoch/checkpoint-99.pth' \ 9 | --output_dir './finetune_result/Interpolation/net_Transformer_16_800epoch/' \ 10 | --log_dir './finetune_result/Interpolation/net_Transformer_16_800epoch/' \ 11 | --epochs 800 \ 12 | --warmup_epochs 30 \ 13 | --blr 1.5e-4 --weight_decay 0.75 \ 14 | --layer_decay 0.5 --drop_path 0.1 --reprob 0.25 \ 15 | --dist_eval 16 | 17 | -------------------------------------------------------------------------------- /Data/README-Pretrain.md:
-------------------------------------------------------------------------------- 1 | # 🌟 Seismic Pretrain 2 | 3 | Data link [DatFile] 4 | 5 | The folder `mae_data_more` contains 2,286,422 224*224 seismic images. Because of upload size limits, the zip archive is split into eight volumes. 6 | On `Windows`, you only need to extract `mae_data_moreb.zip`; the other volumes are read in automatically. 7 | 8 | On `Linux`, first merge the volumes into a single archive with the following commands, then extract it: 9 | ``` 10 | zip -s 0 mae_data_more.zip --out pretrain.zip 11 | unzip pretrain.zip 12 | ``` 13 | 14 | All dat files are float32 binary files. 15 | 16 |
17 |
18 | # License 19 | This project is released under the [MIT license](LICENSE). 20 | 21 | -------------------------------------------------------------------------------- /SFM-Finetune/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /SFM-Finetune/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /SFM-Finetune/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /SFM-Pretrain/README.md: -------------------------------------------------------------------------------- 1 | ## Seismic Foundation Model - Pre-train 2 | 3 |

4 | 5 |

6 | 7 | This is a PyTorch/GPU implementation of the paper [Seismic Foundation Model](https://arxiv.org/abs/2309.02791): 8 | ``` 9 | @article{sheng2023seismic, 10 | title={Seismic Foundation Model (SFM): a new generation deep learning model in geophysics}, 11 | author={Sheng, Hanlin and Wu, Xinming and Si, Xu and Li, Jintao and Zhang, Sibio and Duan, Xudong}, 12 | journal={arXiv preprint arXiv:2309.02791}, 13 | year={2023} 14 | } 15 | ``` 16 | * This repo is a modification on the [MAE](https://github.com/facebookresearch/mae). Installation and preparation follow that repo. 17 | 18 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 19 | 20 | ## Pre-training MAE 21 | 22 | To pre-train SFM-Base/Large with **multi-node distributed training**, run ```./submit-train.sh```: 23 | ``` 24 | python submitit_pretrain.py \ 25 | --job_dir ${JOB_DIR} \ 26 | --batch_size 580 \ 27 | --accum_iter 4 \ 28 | --model mae_vit_base_patch16D4d256 \ 29 | --mask_ratio 0.75 \ 30 | --epochs 1600 \ 31 | --warmup_epochs 40 \ 32 | --blr 1.5e-4 --weight_decay 0.05 \ 33 | --data_path ${DATA_DIR} 34 | ``` 35 | - Here the effective batch size is 580 (`batch_size` per gpu) * 4 (`nodes`) * 4 (gpus per node) = 9280. If memory or # gpus is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 4 (gpus per node) * `accum_iter`. 36 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256 (a worked example follows at the end of this file). 37 | 38 | To pre-train SFM-Base/Large on a **single node**, run ```./train.sh```. 39 | ### License 40 | 41 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
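As a worked example of the linear scaling rule above (an illustrative sketch, not code from this repo):

```python
# Worked example of the linear lr scaling rule described above.
blr = 1.5e-4
eff_batch_size = 580 * 4 * 4     # batch_size per gpu * nodes * gpus per node = 9280
lr = blr * eff_batch_size / 256  # = 5.4375e-3
# With gradient accumulation, eff_batch_size is additionally multiplied by accum_iter.
```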
42 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # Deeplab v3+: https://github.com/jfzhang95/pytorch-deeplab-xception 4 | # -------------------------------------------------------- 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 11 | 12 | class Decoder(nn.Module): 13 | def __init__(self, num_classes, backbone, BatchNorm): 14 | super(Decoder, self).__init__() 15 | if backbone == 'resnet' or backbone == 'drn': 16 | low_level_inplanes = 256 17 | elif backbone == 'xception': 18 | low_level_inplanes = 128 19 | elif backbone == 'mobilenet': 20 | low_level_inplanes = 24 21 | else: 22 | raise NotImplementedError 23 | 24 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 25 | self.bn1 = BatchNorm(48) 26 | self.relu = nn.ReLU() 27 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 28 | BatchNorm(256), 29 | nn.ReLU(), 30 | nn.Dropout(0.5), 31 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 32 | BatchNorm(256), 33 | nn.ReLU(), 34 | nn.Dropout(0.1), 35 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 36 | self._init_weight() 37 | 38 | 39 | def forward(self, x, low_level_feat): 40 | low_level_feat = self.conv1(low_level_feat) 41 | low_level_feat = self.bn1(low_level_feat) 42 | low_level_feat = self.relu(low_level_feat) 43 | 44 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 45 | x = torch.cat((x, low_level_feat), dim=1) 46 | x = self.last_conv(x) 47 | 48 | return x 49 | 50 | def _init_weight(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | torch.nn.init.kaiming_normal_(m.weight) 54 | elif isinstance(m, SynchronizedBatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 57 | elif isinstance(m, nn.BatchNorm2d): 58 | m.weight.data.fill_(1) 59 | m.bias.data.zero_() 60 | 61 | def build_decoder(num_classes, backbone, BatchNorm): 62 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /SFM-Finetune/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /SFM-Finetune/README.md: -------------------------------------------------------------------------------- 1 | ## Seismic Foundation Model - Fine-tune 2 | 3 |

4 | 5 |

6 | 7 | This is a PyTorch/GPU implementation of the paper [Seismic Foundation Model](https://arxiv.org/abs/2309.02791): 8 | ``` 9 | @article{sheng2023seismic, 10 | title={Seismic Foundation Model (SFM): a new generation deep learning model in geophysics}, 11 | author={Sheng, Hanlin and Wu, Xinming and Si, Xu and Li, Jintao and Zhang, Sibio and Duan, Xudong}, 12 | journal={arXiv preprint arXiv:2309.02791}, 13 | year={2023} 14 | } 15 | ``` 16 | * This repo is a modification on the [MAE](https://github.com/facebookresearch/mae). Installation and preparation follow that repo. 17 | 18 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 19 | 20 | 21 | ## Fine-tune 22 | 23 | To fine-tune on a **downstream task**, run ```./finetune-SEAM.sh```: 24 | ``` 25 | CUDA_VISIBLE_DEVICES='5' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 26 | --data_path ${DATA_DIR} \ 27 | --accum_iter 2 \ 28 | --batch_size 1 \ 29 | --model vit_base_patch16 \ 30 | --finetune './output_dir_more/Base-512.pth' \ 31 | --output_dir './finetune_result/SEAM/modelbase_512/' \ 32 | --log_dir './finetune_result/SEAM/modelbase_512/' \ 33 | --epochs 100 \ 34 | --warmup_epochs 10 \ 35 | --blr 1.5e-3 --weight_decay 0.05 \ 36 | --layer_decay 0.05 --drop_path 0.1 --reprob 0.25 \ 37 | --dist_eval 38 | ``` 39 | - Here the effective batch size is 64 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 4096. If memory or # gpus is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus per node) * `accum_iter`. 40 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 41 | - `finetune` can be used to load a pre-trained model. 42 | - `modelComparsion` can be placed in a bash file to train ['Unet'](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)/['Deeplab'](https://github.com/jfzhang95/pytorch-deeplab-xception) models with the same parameters. 43 | - `forzen` can be used to freeze the loaded pre-trained model parameters. 44 | 45 | ## Visualization 46 | You can use `Application/finetune_results_visualization.ipynb` to visualize the results. 47 | 48 | ## Fine-tune on your own task 49 | 50 | 1. Write a Dataloader for your data, mimicking the file `util/datasets.py`. 51 | 2. Add your Dataloader and corresponding Task to the file `main_finetune.py`. 52 | 3. Add your task to the conditional statement that selects the specified model. 53 | 4. You can mimic the file `Application/*.sh` to set hyperparameters for experimentation. 54 | 5. Start training! 55 | 56 | 57 | 58 | ### License 59 | 60 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 61 | -------------------------------------------------------------------------------- /SFM-Finetune/Application/README.md: -------------------------------------------------------------------------------- 1 | ## Seismic Foundation Model - Fine-tune 2 | 3 |

4 | 5 |

6 | 7 | This is a PyTorch/GPU implementation of the paper [Seismic Foundation Model](https://arxiv.org/abs/2309.02791): 8 | ``` 9 | @misc{sheng2023seismic, 10 | title={Seismic Foundation Model (SFM): a new generation deep learning model in geophysics}, 11 | author={Hanlin Sheng and Xinming Wu and Xu Si and Jintao Li and Sibio Zhang and Xudong Duan}, 12 | year={2023}, 13 | eprint={2309.02791}, 14 | archivePrefix={arXiv}, 15 | primaryClass={physics.geo-ph} 16 | } 17 | ``` 18 | * This repo is a modification on the [MAE](https://github.com/facebookresearch/mae). Installation and preparation follow that repo. 19 | 20 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 21 | 22 | 23 | ## Fine-tune 24 | 25 | To fine-tune on a **downstream task**, run ```./finetune-SEAM.sh```: 26 | ``` 27 | CUDA_VISIBLE_DEVICES='5' OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=1 main_finetune.py \ 28 | --data_path ${DATA_DIR} \ 29 | --accum_iter 2 \ 30 | --batch_size 1 \ 31 | --model vit_base_patch16 \ 32 | --finetune './output_dir_more/Base-512.pth' \ 33 | --output_dir './finetune_result/SEAM/modelbase_512/' \ 34 | --log_dir './finetune_result/SEAM/modelbase_512/' \ 35 | --epochs 100 \ 36 | --warmup_epochs 10 \ 37 | --blr 1.5e-3 --weight_decay 0.05 \ 38 | --layer_decay 0.05 --drop_path 0.1 --reprob 0.25 \ 39 | --dist_eval 40 | ``` 41 | - Here the effective batch size is 64 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 4096. If memory or # gpus is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus per node) * `accum_iter`. 42 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 43 | - `finetune` can be used to load a pre-trained model. 44 | - `modelComparsion` can be placed in a bash file to train ['Unet'](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)/['Deeplab'](https://github.com/jfzhang95/pytorch-deeplab-xception) models with the same parameters. 45 | - `forzen` can be used to freeze the loaded pre-trained model parameters. 46 | 47 | ## Visualization 48 | You can use `Application/finetune_results_visualization.ipynb` to visualize the results. 49 | 50 | ## Fine-tune on your own task 51 | 52 | 1. Write a Dataloader for your data, mimicking the file `util/datasets.py`. 53 | 2. Add your Dataloader and corresponding Task to the file `main_finetune.py`. 54 | 3. Add your task to the conditional statement that selects the specified model. 55 | 4. You can mimic the file `Application/*.sh` to set hyperparameters for experimentation. 56 | 5. Start training! 57 | 58 | 59 | 60 | ### License 61 | 62 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 63 | -------------------------------------------------------------------------------- /SFM-Pretrain/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | import math 13 | import sys 14 | from typing import Iterable 15 | 16 | import torch 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 50 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | with torch.cuda.amp.autocast(enabled=False): 48 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | loss /= accum_iter 57 | loss_scaler(loss, optimizer, parameters=model.parameters(), 58 | update_grad=(data_iter_step + 1) % accum_iter == 0) 59 | if (data_iter_step + 1) % accum_iter == 0: 60 | optimizer.zero_grad() 61 | 62 | torch.cuda.synchronize() 63 | 64 | metric_logger.update(loss=loss_value) 65 | 66 | lr = optimizer.param_groups[0]["lr"] 67 | metric_logger.update(lr=lr) 68 | 69 | loss_value_reduce = misc.all_reduce_mean(loss_value) 70 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 71 | """ We use epoch_1000x as the x-axis in tensorboard. 72 | This calibrates different curves when batch size changes. 73 | """ 74 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 75 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 76 | log_writer.add_scalar('lr', lr, epoch_1000x) 77 | 78 | 79 | # gather the stats from all processes 80 | metric_logger.synchronize_between_processes() 81 | print("Averaged stats:", metric_logger) 82 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # Deeplab v3+: https://github.com/jfzhang95/pytorch-deeplab-xception 4 | # -------------------------------------------------------- 5 | 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 
9 | 10 | import functools 11 | 12 | from torch.nn.parallel.data_parallel import DataParallel 13 | 14 | __all__ = [ 15 | 'CallbackContext', 16 | 'execute_replication_callbacks', 17 | 'DataParallelWithCallback', 18 | 'patch_replication_callback' 19 | ] 20 | 21 | 22 | class CallbackContext(object): 23 | pass 24 | 25 | 26 | def execute_replication_callbacks(modules): 27 | """ 28 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 29 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 30 | Note that, as all modules are isomorphic, we assign each sub-module a context 31 | (shared among multiple copies of this module on different devices). 32 | Through this context, different copies can share some information. 33 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 34 | of any slave copies. 35 | """ 36 | master_copy = modules[0] 37 | nr_modules = len(list(master_copy.modules())) 38 | ctxs = [CallbackContext() for _ in range(nr_modules)] 39 | 40 | for i, module in enumerate(modules): 41 | for j, m in enumerate(module.modules()): 42 | if hasattr(m, '__data_parallel_replicate__'): 43 | m.__data_parallel_replicate__(ctxs[j], i) 44 | 45 | 46 | class DataParallelWithCallback(DataParallel): 47 | """ 48 | Data Parallel with a replication callback. 49 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by the 50 | original `replicate` function. 51 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 52 | Examples: 53 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 54 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 55 | # sync_bn.__data_parallel_replicate__ will be invoked. 56 | """ 57 | 58 | def replicate(self, module, device_ids): 59 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 60 | execute_replication_callbacks(modules) 61 | return modules 62 | 63 | 64 | def patch_replication_callback(data_parallel): 65 | """ 66 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 67 | Useful when you have a customized `DataParallel` implementation.
68 |     Examples:
69 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
70 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
71 |         > patch_replication_callback(sync_bn)
72 |         # this is equivalent to
73 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
74 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
75 |     """
76 | 
77 |     assert isinstance(data_parallel, DataParallel)
78 | 
79 |     old_replicate = data_parallel.replicate
80 | 
81 |     @functools.wraps(old_replicate)
82 |     def new_replicate(module, device_ids):
83 |         modules = old_replicate(module, device_ids)
84 |         execute_replication_callbacks(modules)
85 |         return modules
86 | 
87 |     data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/SFM-Finetune/util/metrics.py:
--------------------------------------------------------------------------------
1 | 
2 | """
3 | refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
4 | """
5 | import numpy as np
6 | 
7 | __all__ = ['SegmentationMetric']
8 | 
9 | """
10 | confusionMetric  # note: here the rows are the predictions and the columns are the ground truth, the opposite of the convention introduced earlier
11 | P\L     P     N
12 | P      TP    FP
13 | N      FN    TN
14 | """
15 | 
16 | 
17 | class SegmentationMetric(object):
18 |     def __init__(self, numClass):
19 |         self.numClass = numClass
20 |         self.confusionMatrix = np.zeros((self.numClass,) * 2)  # empty confusion matrix
21 | 
22 |     def pixelAccuracy(self):
23 |         # return all class overall pixel accuracy: the fraction of pixels that are correctly classified
24 |         # PA = acc = (TP + TN) / (TP + TN + FP + FN)
25 |         acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
26 |         return acc
27 | 
28 |     def classPixelAccuracy(self):
29 |         # return each category pixel accuracy (a more accurate name for it is precision)
30 |         # acc = TP / (TP + FP)
31 |         classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
32 |         return classAcc  # returns a vector, e.g. [0.90, 0.80, 0.96]: the prediction accuracy of classes 1, 2 and 3
33 | 
34 |     def meanPixelAccuracy(self):
35 |         """
36 |         Mean Pixel Accuracy (MPA): a simple refinement of PA that computes, for each class, the fraction of correctly classified pixels, and then averages over all classes.
37 |         :return:
38 |         """
39 |         classAcc = self.classPixelAccuracy()
40 |         meanAcc = np.nanmean(classAcc)  # np.nanmean computes the mean while ignoring NaN entries
41 |         return meanAcc  # returns a single value, e.g. np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 = 0.89
42 | 
43 |     def IntersectionOverUnion(self):
44 |         # Intersection = TP; Union = TP + FP + FN
45 |         # IoU = TP / (TP + FP + FN)
46 |         intersection = np.diag(self.confusionMatrix)  # diagonal elements, returned as a vector
47 |         union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
48 |             self.confusionMatrix)  # axis=1 sums over the rows of the confusion matrix; axis=0 sums over the columns; each returns a vector
49 |         IoU = intersection / union  # vector holding the IoU of each class
50 |         return IoU
51 | 
52 |     def meanIntersectionOverUnion(self):
53 |         mIoU = np.nanmean(self.IntersectionOverUnion())  # average the per-class IoU values
54 |         return mIoU
55 | 
56 |     def genConfusionMatrix(self, imgPredict, imgLabel):  #
57 |         """
58 |         Same as the fast_hist() function in FCN's score.py: computes the confusion matrix.
59 |         :param imgPredict:
60 |         :param imgLabel:
61 |         :return: the confusion matrix
62 |         """
63 |         # remove classes from unlabeled pixels in gt image and predict
64 |         mask = (imgLabel >= 0) & (imgLabel < self.numClass)
65 |         label = self.numClass * imgLabel[mask] + imgPredict[mask]
66 |         count = np.bincount(label, minlength=self.numClass ** 2)
67 |         confusionMatrix = count.reshape(self.numClass, self.numClass)
68 |         # print(confusionMatrix)
69 |         return confusionMatrix
70 | 
71 |     def Frequency_Weighted_Intersection_over_Union(self):
72 |         """
73 |         FWIoU (Frequency Weighted Intersection over Union): a refinement of MIoU that weights each class by how frequently it appears.
74 |         FWIoU = [(TP+FN)/(TP+FP+TN+FN)] * [TP / (TP + FP + FN)]
75 |         """
76 |         freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
77 |         iu = np.diag(self.confusionMatrix) / (
78 |             np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
79 |             np.diag(self.confusionMatrix))
80 |         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
81 |         return FWIoU
82 | 
83 |     def addBatch(self, imgPredict, imgLabel):
84 |         assert imgPredict.shape == imgLabel.shape
85 |         self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)  # accumulate the confusion matrix
86 |         return self.confusionMatrix
87 | 
88 |     def reset(self):
89 |         self.confusionMatrix = np.zeros((self.numClass, self.numClass))
90 | 
91 | 
--------------------------------------------------------------------------------
/SFM-Finetune/modules/modeling/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | class _ASPPModule(nn.Module):
8 |     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 |         super(_ASPPModule, self).__init__()
10 |         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 |                                      stride=1, padding=padding, dilation=dilation, bias=False)
12 |         self.bn = BatchNorm(planes)
13 |         self.relu = nn.ReLU()
14 | 
15 |         self._init_weight()
16 | 
17 |     def forward(self, x):
18 |         x = self.atrous_conv(x)
19 |         x = self.bn(x)
20 | 
21 |         return self.relu(x)
22 | 
23 |     def _init_weight(self):
24 |         for m in self.modules():
25 |             if isinstance(m, nn.Conv2d):
26 |                 torch.nn.init.kaiming_normal_(m.weight)
27 |             elif isinstance(m, SynchronizedBatchNorm2d):
28 |                 m.weight.data.fill_(1)
29 |                 m.bias.data.zero_()
30 |             elif isinstance(m, nn.BatchNorm2d):
31 |                 m.weight.data.fill_(1)
32 |                 m.bias.data.zero_()
33 | 
34 | class ASPP(nn.Module):
35 |     def __init__(self, backbone, output_stride, BatchNorm):
36 |         super(ASPP, self).__init__()
37 |         if backbone == 'drn':
38 |             inplanes = 512
39 |         elif backbone == 'mobilenet':
40 |             inplanes = 320
41 |         else:
42 |             inplanes = 2048
43 |         if output_stride == 16:
44 |             dilations = [1, 6, 12, 18]
45 |         elif output_stride == 8:
46 |             dilations = [1, 12, 24, 36]
47 |         else:
48 |             raise NotImplementedError
49 | 
50 |         self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 |         self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 |         self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 |         self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 | 
55 |         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 |                                              nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 |                                              BatchNorm(256),
58 |                                              nn.ReLU())
59 |         self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 |         self.bn1 = BatchNorm(256)
61 |         self.relu = nn.ReLU()
62 |         self.dropout = nn.Dropout(0.5)
63 |         self._init_weight()
64 | 
65 |     def forward(self, x):
66 |         x1 = self.aspp1(x)
67 |         x2 = self.aspp2(x)
68 |         x3 = self.aspp3(x)
69 |         x4 = self.aspp4(x)
70 |         x5 = self.global_avg_pool(x)
71 |         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 |         x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 | 
74 |         x = self.conv1(x)
75 |         x = self.bn1(x)
76 |         x = self.relu(x)
77 | 
78 |         return self.dropout(x)
79 | 
80 |     def
_init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/deeplab.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # Deeplab v3+: https://github.com/jfzhang95/pytorch-deeplab-xception 4 | # -------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 10 | from modules.modeling.aspp import build_aspp 11 | from modules.modeling.decoder import build_decoder 12 | from modules.modeling.backbone import build_backbone 13 | 14 | class DeepLab(nn.Module): 15 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, 16 | sync_bn=True, freeze_bn=False,Mask=False): 17 | super(DeepLab, self).__init__() 18 | if backbone == 'drn': 19 | output_stride = 8 20 | 21 | if sync_bn == True: 22 | BatchNorm = SynchronizedBatchNorm2d 23 | else: 24 | BatchNorm = nn.BatchNorm2d 25 | if Mask == True: 26 | self.mask = Mask 27 | else: 28 | self.mask = None 29 | 30 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 31 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 32 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 33 | 34 | self.freeze_bn = freeze_bn 35 | 36 | def generate_mask(self,input_tensor, ratio): 37 | mask = torch.zeros_like(input_tensor) 38 | mask[:, :, :, torch.randperm(mask.size(3))[:int(mask.size(3) * ratio)]] = 1 39 | return mask 40 | 41 | def forward(self, input): 42 | if self.mask: 43 | mask = self.generate_mask(input,0.5) 44 | input = input*mask 45 | x, low_level_feat = self.backbone(input) 46 | x = self.aspp(x) 47 | x = self.decoder(x, low_level_feat) 48 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 49 | 50 | return x 51 | 52 | def freeze_bn(self): 53 | for m in self.modules(): 54 | if isinstance(m, SynchronizedBatchNorm2d): 55 | m.eval() 56 | elif isinstance(m, nn.BatchNorm2d): 57 | m.eval() 58 | 59 | def get_1x_lr_params(self): 60 | modules = [self.backbone] 61 | for i in range(len(modules)): 62 | for m in modules[i].named_modules(): 63 | if self.freeze_bn: 64 | if isinstance(m[1], nn.Conv2d): 65 | for p in m[1].parameters(): 66 | if p.requires_grad: 67 | yield p 68 | else: 69 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 70 | or isinstance(m[1], nn.BatchNorm2d): 71 | for p in m[1].parameters(): 72 | if p.requires_grad: 73 | yield p 74 | 75 | def get_10x_lr_params(self): 76 | modules = [self.aspp, self.decoder] 77 | for i in range(len(modules)): 78 | for m in modules[i].named_modules(): 79 | if self.freeze_bn: 80 | if isinstance(m[1], nn.Conv2d): 81 | for p in m[1].parameters(): 82 | if p.requires_grad: 83 | yield p 84 | else: 85 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], 
SynchronizedBatchNorm2d) \
86 |                             or isinstance(m[1], nn.BatchNorm2d):
87 |                         for p in m[1].parameters():
88 |                             if p.requires_grad:
89 |                                 yield p
90 | 
91 | if __name__ == "__main__":
92 |     model = DeepLab(backbone='mobilenet', output_stride=16)
93 |     model.eval()
94 |     input = torch.rand(1, 3, 513, 513)
95 |     output = model(input)
96 |     print(output.size())
97 | 
98 | 
99 | 
--------------------------------------------------------------------------------
/SFM-Finetune/util/tools.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Jintao Li
3 | Date: 2022-05-30 16:42:14
4 | LastEditors: Jintao Li
5 | LastEditTime: 2022-07-11 23:05:53
6 | 2022 by CIG.
7 | '''
8 | 
9 | import os, shutil
10 | import yaml, argparse
11 | from sklearn.metrics import confusion_matrix
12 | import numpy as np
13 | import torch
14 | 
15 | 
16 | def accuracy(output, target):
17 |     '''
18 |     output: [N, num_classes, ...], torch.float
19 |     target: [N, ...], torch.int
20 |     '''
21 |     output = output.argmax(dim=1).flatten().detach().cpu().numpy()
22 |     target = target.flatten().detach().cpu().numpy()
23 |     return _pixel_acc(output, target), _miou(output, target)
24 | 
25 | 
26 | def _pixel_acc(output, target):
27 |     r"""
28 |     Compute the Pixel Accuracy (PA):
29 |     $$ PA = \frac{\sum_{i=0}^k p_{ii}}
30 |             {\sum_{i=0}^k \sum_{j=0}^k p_{ij}} $$ and
31 |     $n_class = k+1$
32 |     Parameters:
33 |     -----------
34 |     shape: [N, ], (use flatten() function)
35 |     return:
36 |     ----------
37 |     - PA
38 |     """
39 |     assert output.shape == target.shape, "shapes must be same"
40 |     cm = confusion_matrix(target, output)
41 |     return np.diag(cm).sum() / cm.sum()
42 | 
43 | 
44 | def _miou(output, target):
45 |     r"""
46 |     Compute the Mean Intersection over Union (MIoU):
47 |     $$ MIoU = \frac{1}{k+1} \sum_{i=0}^k \frac{p_{ii}}
48 |               {\sum_{j=0}^k p_{ij} + \sum_{j=0}^k p_{ji} - p_{ii}} $$
49 |     Parameters:
50 |         output, target: [N, ]
51 |     return:
52 |         MIoU
53 |     """
54 |     assert output.shape == target.shape, "shapes must be same"
55 |     cm = confusion_matrix(target, output)
56 |     intersection = np.diag(cm)
57 |     union = np.sum(cm, 1) + np.sum(cm, 0) - np.diag(cm)
58 |     iou = intersection / union
59 |     miou = np.nanmean(iou)
60 | 
61 |     return miou
62 | 
63 | 
64 | def yaml_config_hook(config_file: str) -> argparse.Namespace:
65 |     """
66 |     Load the parameters from a yaml config file and collect them into an argparse-style namespace.
67 |     """
68 |     with open(config_file) as f:
69 |         cfg = yaml.safe_load(f)
70 |         for d in cfg.get("defaults", []):
71 |             config_dir, cf = d.popitem()
72 |             cf = os.path.join(os.path.dirname(config_file), config_dir,
73 |                               cf + ".yaml")
74 |             with open(cf) as f:
75 |                 l = yaml.safe_load(f)
76 |                 cfg.update(l)
77 | 
78 |     if "defaults" in cfg.keys():
79 |         del cfg["defaults"]
80 | 
81 |     parser = argparse.ArgumentParser()
82 |     for k, v in cfg.items():
83 |         parser.add_argument(f"--{k}", default=v, type=type(v))
84 |     args = parser.parse_args()
85 | 
86 |     return args
87 | 
88 | 
89 | def backup_code(work_dir, back_dir, exceptions=[], include=[]):
90 |     r"""
91 |     Back up the code of this run to a target directory, excluding certain files and directories.
92 | 
93 |     Args:
94 |         work_dir: the working directory, i.e. the code to be backed up
95 |         back_dir: the target directory where the backed-up code is placed
96 |         exceptions (list): excluded directories and files ending with the given suffixes; the defaults are
97 |             ["__pycache__", ".pyc", ".dat", "backup", ".vscode"]
98 |         include (list): files that must be backed up, even if they match an entry in exceptions
99 |     """
100 |     _exp = [
101 |         "*__pycache__*", "*.pyc", "*.dat", "backup", ".vscode", "*.log",
102 |         "*log*"
103 |     ]
104 |     exceptions = exceptions + _exp
105 | 
106 |     # if not os.path.exists(back_dir):
107 |     os.makedirs(back_dir, exist_ok=True)
108 | 
109 |     shutil.copytree(work_dir,
110 |                     back_dir + 'code/',
111 |                     ignore=shutil.ignore_patterns(*exceptions),
112 |                     dirs_exist_ok=True)
113 | 
114 |     for f in include:
115 |         shutil.copyfile(os.path.join(work_dir, f),
116 |                         os.path.join(back_dir + 'code', f))
117 | 
118 | 
119 | def list_files(path, full=False):
120 |     r"""
121 |     Recursively list all files in a directory, including files in subdirectories.
122 |     """
123 |     out = []
124 |     for f in os.listdir(path):
125 |         fname = os.path.join(path, f)
126 |         if os.path.isdir(fname):
127 |             fname = list_files(fname)
128 |             out += [os.path.join(f, i) for i in fname]
129 |         else:
130 |             out.append(f)
131 |     if full:
132 |         out = [os.path.join(path, i) for i in out]
133 |     return out
134 | 
135 | 
136 | if __name__ == "__main__":
137 |     output = torch.randn(4, 2, 6, 6)
138 |     target = torch.randn(4, 2, 6, 6)
139 |     # output = output.cuda()
140 |     # target = target.cuda()
141 |     target = target.argmax(1)
142 | 
143 |     accuracy(output, target)
--------------------------------------------------------------------------------
/SFM-Pretrain/submitit_pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # A script to run multinode training with submitit.
8 | # --------------------------------------------------------
9 | 
10 | import argparse
11 | import os
12 | import uuid
13 | from pathlib import Path
14 | 
15 | import main_pretrain as trainer
16 | import submitit
17 | 
18 | 
19 | def parse_args():
20 |     trainer_parser = trainer.get_args_parser()
21 |     parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser])
22 |     parser.add_argument("--ngpus", default=4, type=int, help="Number of gpus to request on each node")
23 |     parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
24 |     parser.add_argument("--timeout", default=4500, type=int, help="Duration of the job")
25 |     parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
26 |     parser.add_argument("--partition", default="GPU-8A100", type=str, help="Partition where to submit")
27 |     parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs")
28 |     parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
29 |     return parser.parse_args()
30 | 
31 | 
32 | def get_shared_folder() -> Path:
33 |     user = os.getenv("USER")
34 |     p = Path(f"./output_base_gpu4/")
35 |     p.mkdir(exist_ok=True)
36 |     return p
37 | 
38 | def get_init_file():
39 |     # Init file must not exist, but its parent dir must exist.
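    # (The file:// URI of this unique init file is later assigned to
    # args.dist_url in main(), so all ranks rendezvous for distributed
    # initialization through the shared filesystem.)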
40 | os.makedirs(str(get_shared_folder()), exist_ok=True) 41 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 42 | if init_file.exists(): 43 | os.remove(str(init_file)) 44 | return init_file 45 | 46 | 47 | class Trainer(object): 48 | def __init__(self, args): 49 | self.args = args 50 | 51 | def __call__(self): 52 | import main_pretrain as trainer 53 | 54 | self._setup_gpu_args() 55 | trainer.main(self.args) 56 | 57 | def checkpoint(self): 58 | import os 59 | import submitit 60 | 61 | self.args.dist_url = get_init_file().as_uri() 62 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 63 | if os.path.exists(checkpoint_file): 64 | self.args.resume = checkpoint_file 65 | print("Requeuing ", self.args) 66 | empty_trainer = type(self)(self.args) 67 | return submitit.helpers.DelayedSubmission(empty_trainer) 68 | 69 | def _setup_gpu_args(self): 70 | import submitit 71 | from pathlib import Path 72 | 73 | job_env = submitit.JobEnvironment() 74 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 75 | self.args.log_dir = self.args.output_dir 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | #if args.use_volta32: 97 | # kwargs['slurm_constraint'] = 'volta32gb' 98 | #if args.comment: 99 | # kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=80 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | slurm_qos='gpu_8a100', 109 | # Below are cluster dependent parameters 110 | slurm_partition=partition, 111 | slurm_signal_delay_s=120, 112 | **kwargs 113 | ) 114 | 115 | executor.update_parameters(name="mae") 116 | 117 | args.dist_url = get_init_file().as_uri() 118 | args.output_dir = args.job_dir 119 | 120 | trainer = Trainer(args) 121 | job = executor.submit(trainer) 122 | 123 | # print("Submitted job_id:", job.job_id) 124 | print(job.job_id) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /SFM-Finetune/submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
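# Example invocation (all flags are defined in parse_args below; the values
# shown are illustrative only):
#   python submitit_finetune.py --job_dir /path/to/job --ngpus 8 --nodes 2 \
#       --partition learnfair --timeout 4320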
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | 
gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /SFM-Finetune/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | if newsize1 == None: 88 | newsize1,newsize2 = new_size,new_size 89 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2)) 90 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 91 | # only the position tokens are interpolated 92 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 93 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 94 | pos_tokens = torch.nn.functional.interpolate( 95 | pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False) 96 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 97 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 98 | checkpoint_model['pos_embed'] = new_pos_embed 99 | # elif orig_size > new_size: 100 | # print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 101 | # pos_tokens = get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True) 102 | # pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0) 103 | # checkpoint_model['pos_embed'] = pos_tokens 104 | 105 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # Deeplab v3+: https://github.com/jfzhang95/pytorch-deeplab-xception 4 | # -------------------------------------------------------- 5 | 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 9 | 10 | import queue 11 | import collections 12 | import threading 13 | 14 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 15 | 16 | 17 | class FutureResult(object): 18 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 19 | 20 | def __init__(self): 21 | self._result = None 22 | self._lock = threading.Lock() 23 | self._cond = threading.Condition(self._lock) 24 | 25 | def put(self, result): 26 | with self._lock: 27 | assert self._result is None, 'Previous result has\'t been fetched.' 
28 |             self._result = result
29 |             self._cond.notify()
30 | 
31 |     def get(self):
32 |         with self._lock:
33 |             if self._result is None:
34 |                 self._cond.wait()
35 | 
36 |             res = self._result
37 |             self._result = None
38 |             return res
39 | 
40 | 
41 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
42 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
43 | 
44 | 
45 | class SlavePipe(_SlavePipeBase):
46 |     """Pipe for master-slave communication."""
47 | 
48 |     def run_slave(self, msg):
49 |         self.queue.put((self.identifier, msg))
50 |         ret = self.result.get()
51 |         self.queue.put(True)
52 |         return ret
53 | 
54 | 
55 | class SyncMaster(object):
56 |     """An abstract `SyncMaster` object.
57 |     - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
58 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
59 |     - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
60 |     and passed to a registered callback.
61 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
62 |     back to each slave device.
63 |     """
64 | 
65 |     def __init__(self, master_callback):
66 |         """
67 |         Args:
68 |             master_callback: a callback to be invoked after having collected messages from slave devices.
69 |         """
70 |         self._master_callback = master_callback
71 |         self._queue = queue.Queue()
72 |         self._registry = collections.OrderedDict()
73 |         self._activated = False
74 | 
75 |     def __getstate__(self):
76 |         return {'master_callback': self._master_callback}
77 | 
78 |     def __setstate__(self, state):
79 |         self.__init__(state['master_callback'])
80 | 
81 |     def register_slave(self, identifier):
82 |         """
83 |         Register a slave device.
84 |         Args:
85 |             identifier: an identifier, usually the device id.
86 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
87 |         """
88 |         if self._activated:
89 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 |             self._activated = False
91 |             self._registry.clear()
92 |         future = FutureResult()
93 |         self._registry[identifier] = _MasterRegistry(future)
94 |         return SlavePipe(identifier, self._queue, future)
95 | 
96 |     def run_master(self, master_msg):
97 |         """
98 |         Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback will be invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 |         Args:
103 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
104 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
105 |         Returns: the message to be sent back to the master device.
106 |         """
107 |         self._activated = True
108 | 
109 |         intermediates = [(0, master_msg)]
110 |         for i in range(self.nr_slaves):
111 |             intermediates.append(self._queue.get())
112 | 
113 |         results = self._master_callback(intermediates)
114 |         assert results[0][0] == 0, 'The first result should belong to the master.'
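        # Hand each slave its computed result through the future it registered,
        # then wait for every slave to acknowledge (SlavePipe.run_slave puts True
        # back on the queue once it has fetched its result).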
115 | 116 | for i, res in results: 117 | if i == 0: 118 | continue 119 | self._registry[i].result.put(res) 120 | 121 | for i in range(self.nr_slaves): 122 | assert self._queue.get() is True 123 | 124 | return results[0][1] 125 | 126 | @property 127 | def nr_slaves(self): 128 | return len(self._registry) 129 | -------------------------------------------------------------------------------- /SFM-Finetune/util/msssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | 5 | 6 | def gaussian(window_size, sigma): 7 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 8 | return gauss/gauss.sum() 9 | 10 | 11 | def create_window(window_size, channel=1): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 15 | return window 16 | 17 | 18 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 19 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 20 | if val_range is None: 21 | if torch.max(img1) > 128: 22 | max_val = 255 23 | else: 24 | max_val = 1 25 | 26 | if torch.min(img1) < -0.5: 27 | min_val = -1 28 | else: 29 | min_val = 0 30 | L = max_val - min_val 31 | else: 32 | L = val_range 33 | 34 | padd = 0 35 | (_, channel, height, width) = img1.size() 36 | if window is None: 37 | real_size = min(window_size, height, width) 38 | window = create_window(real_size, channel=channel).to(img1.device) 39 | 40 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 41 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 42 | 43 | mu1_sq = mu1.pow(2) 44 | mu2_sq = mu2.pow(2) 45 | mu1_mu2 = mu1 * mu2 46 | 47 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 48 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 49 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 50 | 51 | C1 = (0.01 * L) ** 2 52 | C2 = (0.03 * L) ** 2 53 | 54 | v1 = 2.0 * sigma12 + C2 55 | v2 = sigma1_sq + sigma2_sq + C2 56 | cs = torch.mean(v1 / v2) # contrast sensitivity 57 | 58 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 59 | 60 | if size_average: 61 | ret = ssim_map.mean() 62 | else: 63 | ret = ssim_map.mean(1).mean(1).mean(1) 64 | 65 | if full: 66 | return ret, cs 67 | return ret 68 | 69 | 70 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True): 71 | device = img1.device 72 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 73 | levels = weights.size()[0] 74 | mssim = [] 75 | mcs = [] 76 | for _ in range(levels): 77 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 78 | mssim.append(sim) 79 | mcs.append(cs) 80 | 81 | img1 = F.avg_pool2d(img1, (2, 2)) 82 | img2 = F.avg_pool2d(img2, (2, 2)) 83 | 84 | mssim = torch.stack(mssim) 85 | mcs = torch.stack(mcs) 86 | 87 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 88 | if normalize: 89 | mssim = (mssim + 1) / 2 90 | mcs = (mcs + 1) / 2 91 | 92 | pow1 = mcs ** weights 93 | pow2 = mssim ** weights 94 | # From Matlab implementation 
https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 95 | output = torch.prod(pow1[:-1] * pow2[-1]) 96 | return output 97 | 98 | 99 | # Classes to re-use window 100 | class SSIM(torch.nn.Module): 101 | def __init__(self, window_size=11, size_average=True, val_range=None): 102 | super(SSIM, self).__init__() 103 | self.window_size = window_size 104 | self.size_average = size_average 105 | self.val_range = val_range 106 | 107 | # Assume 1 channel for SSIM 108 | self.channel = 1 109 | self.window = create_window(window_size) 110 | 111 | def forward(self, img1, img2): 112 | (_, channel, _, _) = img1.size() 113 | 114 | if channel == self.channel and self.window.dtype == img1.dtype: 115 | window = self.window 116 | else: 117 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 118 | self.window = window 119 | self.channel = channel 120 | 121 | return 1 - ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 122 | 123 | class MSSSIM(torch.nn.Module): 124 | def __init__(self, window_size=11, size_average=True, channel=1): 125 | super(MSSSIM, self).__init__() 126 | self.window_size = window_size 127 | self.size_average = size_average 128 | self.channel = channel 129 | 130 | def forward(self, img1, img2): 131 | # TODO: store window between calls if possible 132 | return 1 - msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 133 | 134 | class PSNR(torch.nn.Module): 135 | def __init__(self): 136 | super(PSNR, self).__init__() 137 | 138 | def torchPSNR(self,tar_img, prd_img): 139 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 140 | rmse = (imdff**2).mean().sqrt() 141 | ps = 20*torch.log10(1/rmse) 142 | return ps 143 | 144 | def forward(self, img1, img2): 145 | # TODO: store window between calls if possible 146 | return self.torchPSNR(img1, img2) 147 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 
44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /SFM-Finetune/util/pos_embedtest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | if newsize1 == None: 88 | newsize1,newsize2 = new_size,new_size 89 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2)) 90 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 91 | # only the position tokens are interpolated 92 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 93 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 94 | pos_tokens = torch.nn.functional.interpolate( 95 | pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False) 96 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 97 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 98 | checkpoint_model['pos_embed'] = new_pos_embed 99 | 100 | def interpolate_dec_embed(model, checkpoint_model): 101 | if 'decoder_pos_embed' in checkpoint_model: 102 | pos_embed_checkpoint = checkpoint_model['decoder_pos_embed'] 103 | embedding_size = pos_embed_checkpoint.shape[-1] 104 | num_patches = model.decoder_pos_embed.num_patches 105 | num_extra_tokens = model.decoder_pos_embed.shape[-2] - num_patches 106 | # height (== width) for the checkpoint position embedding 107 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 108 | # height (== width) for the new position embedding 109 | new_size = int(num_patches ** 0.5) 110 | # class_token and dist_token are kept unchanged 111 | if orig_size != new_size: 112 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 113 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 114 | # only the position tokens are interpolated 115 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 116 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 117 | pos_tokens = torch.nn.functional.interpolate( 118 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 119 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 120 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 121 | checkpoint_model['decoder_pos_embed'] = new_pos_embed 122 | # elif orig_size > new_size: 123 | # print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 124 | # pos_tokens = 
get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True) 125 | # pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0) 126 | # checkpoint_model['pos_embed'] = pos_tokens 127 | 128 | -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=False): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | # if pretrained: 75 | # self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, 
downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=False): 149 | """Constructs a ResNet-101 model. 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=False, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 
8 | Hanlin Sheng¹   Xinming Wu¹,†,‡   Xu Si¹   Jintao Li¹
14 | Sibo Zhang²   Xudong Duan²
22 | ¹ University of Science and Technology of China   ² Huawei
28 | † Corresponding Author   ‡ Project Lead
30 | 
31 | 
32 | -----------------
33 | 
34 | # 🌟 Seismic Foundation Model (SFM)
35 | 
36 | As shown in the workflow figure, we test the Seismic Foundation Model's performance on segmentation and regression tasks, specifically classification (i.e. seismic facies), segmentation (i.e. seismic geobody), signal processing (i.e. denoising), inversion (i.e. reflectivity estimation), and interpolation.
37 | 
38 | This is a PyTorch/GPU implementation of the paper [Seismic Foundation Model](https://arxiv.org/abs/2309.02791):
39 | ```
40 | @article{sheng2023seismic,
41 |   title={Seismic Foundation Model (SFM): a new generation deep learning model in geophysics},
42 |   author={Sheng, Hanlin and Wu, Xinming and Si, Xu and Li, Jintao and Zhang, Sibo and Duan, Xudong},
43 |   journal={arXiv preprint arXiv:2309.02791},
44 |   year={2023}
45 | }
46 | ```
47 | 
48 | ## 🌟 News
49 | * **2024.11.12:** The article has been accepted by the journal Geophysics and is awaiting publication.
50 | * **2023.9.7:** The paper is released on arXiv, and the code will be gradually released. ⌛⌛⌛
51 | * **2023.8.7:** GitHub repository initialization (copied from Meta-Transformer).
52 | 
53 | ## 👉 Pre-train & Fine-tune Code
54 | 
55 | * The pre-training instructions are in [PRETRAIN.md](SFM-Pretrain/README.md).
56 | 
57 | * The fine-tuning instructions are in [FINETUNE.md](SFM-Finetune/README.md).
58 | 
59 | 
60 | ## :rocket: Model Zoo & Data Release
61 | 
62 | 
63 | Open-source Pretrained Models
64 | 
65 |
66 | 
67 | 
68 | | Model | Pretraining Size | Download | 
69 | |---|:---:|:---:| 
70 | | SFM-Base | 224 × 224 | ckpt ckpt-Baidu Netdisk | 
71 | | SFM-Base-512 | 512 × 512 | ckpt ckpt-Baidu Netdisk | 
72 | | SFM-Large | 224 × 224 | ckpt ckpt-Baidu Netdisk | 
73 | | SFM-Large-512 | 512 × 512 | ckpt ckpt-Baidu Netdisk | 
74 | 
75 | 
76 | Open-source Training & Downstream Fine-tuning Task Data 
77 | 
78 |
79 | 
80 | | Task | Size | Download | 
81 | |:------------------:|:--------------------------:|:----------:| 
82 | | PreTrain | 224 × 224 | [DatFile] | 
83 | | Seismic Facies Classification | 768 × 768 | [DatFile DatFile-Baidu Netdisk] | 
84 | | Seismic GeoBody Identification | 224 × 224 | [DatFile DatFile-Baidu Netdisk] | 
85 | | Inversion (Reflectivity Estimation) | 224 × 224 | [DatFile DatFile-Baidu Netdisk] | 
86 | | Signal Processing (Denoise) | 224 × 224 | [DatFile DatFile-Baidu Netdisk] | 
87 | | Interpolation | 224 × 224 | [DatFile DatFile-Baidu Netdisk] | 
88 | 
89 | # :neckbeard: Quick Guide 
90 | 
91 | ## Installation 
92 | 
93 | To prepare the environment, follow the instructions below. 
94 | ```shell 
95 | # create virtual environment 
96 | conda create -n SFM python=3.9.12 
97 | conda activate SFM 
98 | 
99 | # install pytorch 
100 | pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html 
101 | 
102 | # install other requirements 
103 | pip install -r requirements.txt 
104 | 
105 | # if you want to visualize the results as shown in SFM-Finetune/Application/visualization.ipynb 
106 | pip install jupyter notebook 
107 | python -m ipykernel install --user --name=SFM --display-name="Python (SFM)" 
108 | ``` 
109 | ## Download Dataset & Model 
110 | 
111 | Place the downloaded datasets and models in the corresponding folders. 
112 | 
113 | 
114 | - If you want to obtain a foundation model pre-trained from scratch, download the ```Pretrain data``` zip files into the ```Data``` folder. 
115 | ```shell 
116 | # First merge the split zip archive 
117 | zip -s 0 mae_data_more.zip --out pretrain.zip 
118 | # Then unzip the merged archive 
119 | unzip pretrain.zip 
120 | ``` 
121 | 
122 | - If you want to use our pre-trained model directly, download the ```Pre-trained model``` and place it in the folder ```SFM-Pretrain/output_dir``` 
123 | ```shell 
124 | cd SFM-Pretrain 
125 | mkdir output_dir 
126 | cd output_dir 
127 | ``` 
128 | 
129 | - If you want to apply the model to downstream tasks, download the downstream task data zip files into the ```Data``` folder. 
130 | ```shell 
131 | cd Data 
132 | unzip *.zip 
133 | ``` 
134 | ## Facies Example 
135 | 
136 | 1. Download the downstream facies task model [facies.pth](https://rec.ustc.edu.cn/share/2c102b40-057f-11ef-9b0d-cd9b2fe068c4) and place it in the folder ```SFM-Finetune/Application/Facies/SFM-Finetune/``` 
137 | 
138 | 2. Download the downstream [Facies Data](https://rec.ustc.edu.cn/share/d6cd54a0-e839-11ee-982a-9748e54ad7a4), place it in the folder Data/, then ```unzip *.zip``` 
139 | 
140 | 3. Run the following commands: 
141 | 
142 | ```shell 
143 | cd SFM-Finetune/Application 
144 | # Use jupyter notebook to open visualization.ipynb 
145 | jupyter notebook 
146 | ``` 
147 | ## Star History 
148 | 
149 | [![Star History Chart](https://api.star-history.com/svg?repos=shenghanlin/SeismicFoundationModel&type=date&legend=top-left)](https://www.star-history.com/#shenghanlin/SeismicFoundationModel&type=date&legend=top-left) 
150 | 
151 | 
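Once a pre-trained checkpoint sits in `SFM-Pretrain/output_dir`, it can be loaded into the MAE backbone defined in `SFM-Pretrain/models_mae.py`. Below is a minimal loading sketch, assuming the file is named `checkpoint.pth` and that the weights are stored under a `'model'` key (the MAE-style convention used by `misc.save_model`); adjust both to match your download:

```python
# Minimal sketch: load a pre-trained SFM checkpoint into the MAE backbone.
# 'output_dir/checkpoint.pth' and the 'model' key are assumptions; adapt them.
import torch
import models_mae  # run from within SFM-Pretrain/

model = models_mae.mae_vit_large_patch16(in_chans=1)
ckpt = torch.load('output_dir/checkpoint.pth', map_location='cpu')
msg = model.load_state_dict(ckpt['model'], strict=False)
print(msg)  # inspect missing/unexpected keys before fine-tuning
```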
152 |
153 | # License 
154 | This project is released under the [MIT license](LICENSE). 
155 | 
156 | 
--------------------------------------------------------------------------------
/SFM-Finetune/models_Regression.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 | 
12 | 
13 | from functools import partial
14 | 
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import timm.models.vision_transformer
19 | import numpy as np
20 | from util.msssim import MSSSIM
21 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
22 |     """ Vision Transformer with support for global average pooling
23 |     """
24 |     def __init__(self, global_pool=False, Interpolation=False, **kwargs):
25 |         super(VisionTransformer, self).__init__(**kwargs)
26 | 
27 |         self.global_pool = global_pool
28 |         self.interpolation = Interpolation
29 |         self.decoder = DecoderCup(in_channels=[self.embed_dim,256,128,64])
30 | 
31 |         self.segmentation_head = SegmentationHead(
32 |             in_channels=64,
33 |             out_channels=self.num_classes,
34 |             kernel_size=1
35 |         )
36 |         if self.global_pool:
37 |             norm_layer = kwargs['norm_layer']
38 |             embed_dim = kwargs['embed_dim']
39 |             self.fc_norm = norm_layer(embed_dim)
40 |             del self.norm  # remove the original norm
41 | 
42 |     def generate_mask(self, input_tensor, ratio):
43 |         mask = torch.zeros_like(input_tensor)  # mask==1 marks the 16-trace-wide blocks that stay observed
44 |         indices = torch.randperm(mask.size(3)//16)[:int(mask.size(3)//16 * ratio)]
45 |         sorted_indices = torch.sort(indices)[0]  # sort the selected block indices in ascending order
46 |         for i in range(0, len(sorted_indices)):
47 |             mask[:, :, :, sorted_indices[i]*16:(sorted_indices[i]+1)*16] = 1
48 |         return mask
49 | 
50 |     def forward_features(self, x):
51 |         B,C,H,W = x.shape
52 | 
53 |         if self.interpolation:
54 |             mask = self.generate_mask(x, 0.75)
55 |             x = x*mask
56 |             img = x
57 |         else:
58 |             img = x
59 |         x = self.patch_embed(x)
60 |         _H,_W = H //self.patch_embed.patch_size[0],W //self.patch_embed.patch_size[0]
61 |         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
62 |         x = torch.cat((cls_tokens, x), dim=1)
63 |         x = x + self.pos_embed
64 |         x = self.pos_drop(x)
65 |         for blk in self.blocks:
66 |             x = blk(x)
67 |         x = self.norm(x)
68 | 
69 |         x = self.decoder(x[:, 1:, :], img)
70 |         x = self.segmentation_head(x)
71 |         if self.interpolation:
72 |             return x, mask
73 |         return x
74 | 
75 |     def forward_Interpolationloss(self, imgs, pred, mask):
76 |         loss1f = torch.nn.MSELoss()
77 |         loss1 = loss1f(imgs, pred*(1-mask)+imgs*mask)
78 |         loss2f = MSSSIM()
79 |         loss2 = loss2f(imgs, pred*(1-mask)+imgs*mask)
80 |         a = 0.1
81 |         loss = (1-a)*loss1+a*loss2
82 |         return loss
83 | 
84 |     def forward(self, x):
85 |         if self.interpolation:
86 |             pred, mask = self.forward_features(x)
87 |             loss = self.forward_Interpolationloss(x, pred, mask)
88 |             return loss, pred, mask
89 |         x = self.forward_features(x)
90 | 
91 |         return x
92 | 
93 | class Conv2dReLU(nn.Sequential):
94 |     def __init__(
95 |             self,
96 |             in_channels,
97 |             out_channels,
98 |             kernel_size,
99 |             padding=0,
100 |             stride=1,
101 | 
use_batchnorm=True, 102 | ): 103 | conv = nn.Conv2d( 104 | in_channels, 105 | out_channels, 106 | kernel_size, 107 | stride=stride, 108 | padding=padding, 109 | bias=not (use_batchnorm), 110 | ) 111 | relu = nn.ReLU(inplace=True) 112 | 113 | bn = nn.BatchNorm2d(out_channels) 114 | 115 | super(Conv2dReLU, self).__init__(conv, bn, relu) 116 | 117 | 118 | class DecoderBlock(nn.Module): 119 | def __init__( 120 | self, 121 | in_channels, 122 | out_channels, 123 | skip_channels=0, 124 | use_batchnorm=True, 125 | ): 126 | super().__init__() 127 | self.conv1 = Conv2dReLU( 128 | in_channels + skip_channels, 129 | out_channels, 130 | kernel_size=3, 131 | padding=1, 132 | use_batchnorm=use_batchnorm, 133 | ) 134 | self.conv2 = Conv2dReLU( 135 | out_channels, 136 | out_channels, 137 | kernel_size=3, 138 | padding=1, 139 | use_batchnorm=use_batchnorm, 140 | ) 141 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 142 | 143 | def forward(self, x, skip=None): 144 | x = self.up(x) 145 | if skip is not None: 146 | x = torch.cat([x, skip], dim=1) 147 | x = self.conv1(x) 148 | x = self.conv2(x) 149 | return x 150 | 151 | 152 | class SegmentationHead(nn.Sequential): 153 | 154 | def __init__(self, in_channels, out_channels, kernel_size=1, upsampling=1): 155 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0) 156 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 157 | super().__init__(conv2d, upsampling) 158 | 159 | 160 | class DecoderCup(nn.Module): 161 | def __init__(self,in_channels=[1024,256,128,64]): 162 | super().__init__() 163 | head_channels = 512 164 | self.conv_more = Conv2dReLU( 165 | 1, 166 | 32, 167 | kernel_size=3, 168 | padding=1, 169 | use_batchnorm=True, 170 | ) 171 | skip_channels=[0,0,0,32] 172 | out_channels=[256,128,64,64] 173 | blocks = [ 174 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) 175 | ] 176 | self.blocks = nn.ModuleList(blocks) 177 | 178 | def forward(self, hidden_states, img, features=None): 179 | B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) 180 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) 181 | x = hidden_states.permute(0, 2, 1) 182 | x = x.contiguous().view(B, hidden, h, w) 183 | skip_channels=[None,None,None,self.conv_more(img)] 184 | for i, decoder_block in enumerate(self.blocks): 185 | x = decoder_block(x, skip=skip_channels[i]) 186 | return x 187 | 188 | def forward_loss(imgs, pred): 189 | """ 190 | imgs: [N, 3, H, W] 191 | pred: [N, L, p*p*3] 192 | mask: [N, L], 0 is keep, 1 is remove, 193 | """ 194 | loss1f = torch.nn.MSELoss() 195 | loss1 = loss1f(imgs, pred) 196 | loss2f = MSSSIM() 197 | loss2 = loss2f(imgs, pred) 198 | a = 0.5 199 | loss = (1-a)*loss1+a*loss2 200 | return loss 201 | 202 | 203 | def mae_vit_small_patch16(**kwargs): 204 | model = VisionTransformer( 205 | patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True, 206 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 207 | return model 208 | 209 | def vit_base_patch16(**kwargs): 210 | model = VisionTransformer( 211 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 212 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 213 | return model 214 | 215 | 216 | def vit_large_patch16(**kwargs): 217 | model = VisionTransformer( 218 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 219 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 220 | return model 221 | 222 | 223 | def vit_huge_patch14(**kwargs): 224 | model = VisionTransformer( 225 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 226 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 227 | return model 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /SFM-Finetune/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | import util.tools as tools 24 | from models_Regression import forward_loss 25 | from util.msssim import MSSSIM 26 | from util.msssim import PSNR 27 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 28 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 29 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 30 | mixup_fn: Optional[Mixup] = None, log_writer=None,task=None, 31 | args=None): 32 | model.train(True) 33 | metric_logger = misc.MetricLogger(delimiter=" ") 34 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 20 37 | 38 | accum_iter = args.accum_iter 39 | 40 | optimizer.zero_grad() 41 | 42 | if log_writer is not None: 43 | print('log_dir: {}'.format(log_writer.log_dir)) 44 | 45 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 46 | 47 | # we use a per iteration (instead of per epoch) lr scheduler 48 | if data_iter_step % accum_iter == 0: 49 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 50 | 51 | samples = samples.to(device, non_blocking=True) 52 | targets = targets.to(device, non_blocking=True) 53 | 54 | if mixup_fn is not None: 55 | samples, targets = mixup_fn(samples, targets) 56 | 57 | with torch.cuda.amp.autocast(): 58 | if task=='Interpolation': 59 | loss, _, _ = model(samples) 60 | # loss_value = loss.item() 61 | else: 62 | outputs = model(samples) 63 | loss = criterion(outputs, targets) 64 | 65 | loss_value = loss.item() 66 | 67 | if not math.isfinite(loss_value): 68 | print("Loss is {}, stopping training".format(loss_value)) 69 | sys.exit(1) 70 | 71 | loss /= accum_iter 72 | loss_scaler(loss, optimizer, clip_grad=max_norm, 73 | parameters=model.parameters(), create_graph=False, 74 | update_grad=(data_iter_step + 1) % accum_iter == 0) 75 | if (data_iter_step + 1) % accum_iter == 0: 76 | optimizer.zero_grad() 77 | 78 | torch.cuda.synchronize() 79 | 80 | metric_logger.update(loss=loss_value) 81 | min_lr = 10. 82 | max_lr = 0. 
83 | for group in optimizer.param_groups: 84 | min_lr = min(min_lr, group["lr"]) 85 | max_lr = max(max_lr, group["lr"]) 86 | 87 | metric_logger.update(lr=max_lr) 88 | 89 | loss_value_reduce = misc.all_reduce_mean(loss_value) 90 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 91 | """ We use epoch_1000x as the x-axis in tensorboard. 92 | This calibrates different curves when batch size changes. 93 | """ 94 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 95 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 96 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 97 | 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | 103 | 104 | @torch.no_grad() 105 | def evaluate(data_loader, model, device): 106 | criterion = torch.nn.CrossEntropyLoss() 107 | 108 | metric_logger = misc.MetricLogger(delimiter=" ") 109 | header = 'Test:' 110 | 111 | # switch to evaluation mode 112 | model.eval() 113 | 114 | for batch in metric_logger.log_every(data_loader, 10, header): 115 | images = batch[0] 116 | target = batch[-1] 117 | images = images.to(device, non_blocking=True) 118 | target = target.to(device, non_blocking=True) 119 | 120 | # compute output 121 | with torch.cuda.amp.autocast(): 122 | output = model(images) 123 | loss = criterion(output, target) 124 | 125 | acc, miou = tools.accuracy(output, target) 126 | 127 | batch_size = images.shape[0] 128 | metric_logger.update(loss=loss.item()) 129 | metric_logger.meters['acc'].update(acc.item(), n=batch_size) 130 | metric_logger.meters['miou'].update(miou.item(), n=batch_size) 131 | # gather the stats from all processes 132 | metric_logger.synchronize_between_processes() 133 | print('* Acc@1 {top1.global_avg:.3f} MIOU {miou.global_avg:.3f} loss {losses.global_avg:.3f}' 134 | .format(top1=metric_logger.acc, miou=metric_logger.miou, losses=metric_logger.loss)) 135 | 136 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 137 | 138 | 139 | @torch.no_grad() 140 | def evaluateRegressionold(data_loader, model, device): 141 | # MSEcriterion = torch.nn.MSE() 142 | criterion = forward_loss 143 | 144 | metric_logger = misc.MetricLogger(delimiter=" ") 145 | header = 'Test:' 146 | 147 | # switch to evaluation mode 148 | model.eval() 149 | 150 | for batch in metric_logger.log_every(data_loader, 10, header): 151 | images = batch[0] 152 | target = batch[-1] 153 | images = images.to(device, non_blocking=True) 154 | target = target.to(device, non_blocking=True) 155 | 156 | # compute output 157 | with torch.cuda.amp.autocast(): 158 | output = model(images) 159 | loss = criterion(output, target) 160 | 161 | # acc, miou = tools.accuracy(output, target) 162 | 163 | batch_size = images.shape[0] 164 | metric_logger.update(loss=loss.item()) 165 | # metric_logger.meters['acc'].update(acc.item(), n=batch_size) 166 | # metric_logger.meters['miou'].update(miou.item(), n=batch_size) 167 | # gather the stats from all processes 168 | metric_logger.synchronize_between_processes() 169 | # print('* loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss)) 170 | 171 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 172 | 173 | 174 | @torch.no_grad() 175 | def evaluateRegression(data_loader, model, device,task=''): 176 | MSEcriterion = torch.nn.MSELoss() 177 | MSSSIMcriterion = MSSSIM() 178 | PSNRcriterion = 
PSNR() 179 | # criterion = forward_loss 180 | 181 | metric_logger = misc.MetricLogger(delimiter=" ") 182 | header = 'Test:' 183 | 184 | # switch to evaluation mode 185 | model.eval() 186 | 187 | for batch in metric_logger.log_every(data_loader, 10, header): 188 | images = batch[0] 189 | target = batch[-1] 190 | images = images.to(device, non_blocking=True) 191 | target = target.to(device, non_blocking=True) 192 | 193 | # compute output 194 | with torch.cuda.amp.autocast(): 195 | # output = model(images) 196 | if task=='Denoise': 197 | output = model(images) 198 | output = images - output 199 | elif task == 'Interpolation': 200 | loss, output, mask = model(images) 201 | output = output*(1-mask)+images*mask 202 | else: 203 | output = model(images) 204 | mseloss = MSEcriterion(output, target) 205 | msssimloss = MSSSIMcriterion(output, target) 206 | psnrloss = PSNRcriterion(output, target) 207 | 208 | # acc, miou = tools.accuracy(output, target) 209 | 210 | batch_size = images.shape[0] 211 | metric_logger.update(loss=mseloss.item()) 212 | metric_logger.meters['mse'].update(mseloss.item(), n=batch_size) 213 | metric_logger.meters['msssim'].update(msssimloss.item(), n=batch_size) 214 | metric_logger.meters['psnr'].update(psnrloss.item(), n=batch_size) 215 | # gather the stats from all processes 216 | metric_logger.synchronize_between_processes() 217 | # print('* loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss)) 218 | print('* MSE {mse.global_avg:.3f} MSSSIM {msssim.global_avg:.3f} PSNR {psnr.global_avg:.3f}' 219 | .format(mse=metric_logger.mse, msssim=metric_logger.msssim, psnr=metric_logger.psnr)) 220 | 221 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 222 | -------------------------------------------------------------------------------- /SFM-Pretrain/main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | import os 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 21 | import torch 22 | import torch.backends.cudnn as cudnn 23 | from torch.utils.tensorboard import SummaryWriter 24 | import torchvision.transforms as transforms 25 | import torchvision.datasets as datasets 26 | 27 | import timm 28 | from util.datasets import SeismicSet 29 | assert timm.__version__ == "0.3.2" # version check 30 | import timm.optim.optim_factory as optim_factory 31 | 32 | import util.misc as misc 33 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 34 | 35 | import models_mae 36 | 37 | from engine_pretrain import train_one_epoch 38 | 39 | 40 | def get_args_parser(): 41 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 42 | parser.add_argument('--batch_size', default=64, type=int, 43 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 44 | parser.add_argument('--epochs', default=400, type=int) 45 | parser.add_argument('--accum_iter', default=1, type=int, 46 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 47 | 48 | # Model parameters 49 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 50 | help='Name of model to train') 51 | 52 | parser.add_argument('--input_size', default=224, type=int, 53 | help='images input size') 54 | 55 | parser.add_argument('--mask_ratio', default=0.75, type=float, 56 | help='Masking ratio (percentage of removed patches).') 57 | 58 | parser.add_argument('--finetune', default='', 59 | help='finetune from checkpoint') 60 | 61 | parser.add_argument('--norm_pix_loss', action='store_true', 62 | help='Use (per-patch) normalized pixels as targets for computing loss') 63 | parser.set_defaults(norm_pix_loss=False) 64 | 65 | # Optimizer parameters 66 | parser.add_argument('--weight_decay', type=float, default=0.05, 67 | help='weight decay (default: 0.05)') 68 | 69 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 70 | help='learning rate (absolute lr)') 71 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 72 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 73 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 74 | help='lower lr bound for cyclic schedulers that hit 0') 75 | 76 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 77 | help='epochs to warmup LR') 78 | 79 | # Dataset parameters 80 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 81 | help='dataset path') 82 | 83 | parser.add_argument('--output_dir', default='./output_dir/', 84 | help='path where to save, empty for no saving') 85 | parser.add_argument('--log_dir', default='./output_dir', 86 | help='path where to tensorboard log') 87 | parser.add_argument('--device', default='cuda', 88 | help='device to use for training / testing') 89 | parser.add_argument('--seed', default=0, type=int) 90 | parser.add_argument('--resume', default='', 91 | 
help='resume from checkpoint') 92 | 93 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 94 | help='start epoch') 95 | parser.add_argument('--num_workers', default=8, type=int) 96 | parser.add_argument('--pin_mem', action='store_true', 97 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 98 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 99 | parser.set_defaults(pin_mem=True) 100 | 101 | # distributed training parameters 102 | parser.add_argument('--world_size', default=1, type=int, 103 | help='number of distributed processes') 104 | parser.add_argument('--local_rank', default=-1, type=int) 105 | parser.add_argument('--dist_on_itp', action='store_true') 106 | parser.add_argument('--dist_url', default='env://', 107 | help='url used to set up distributed training') 108 | 109 | return parser 110 | 111 | 112 | def main(args): 113 | misc.init_distributed_mode(args) 114 | 115 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 116 | print("{}".format(args).replace(', ', ',\n')) 117 | 118 | device = torch.device(args.device) 119 | 120 | # fix the seed for reproducibility 121 | seed = args.seed + misc.get_rank() 122 | torch.manual_seed(seed) 123 | np.random.seed(seed) 124 | 125 | cudnn.benchmark = True 126 | dataset_train = SeismicSet(args.data_path, args.input_size) 127 | 128 | 129 | if True: # args.distributed: 130 | num_tasks = misc.get_world_size() 131 | global_rank = misc.get_rank() 132 | sampler_train = torch.utils.data.DistributedSampler( 133 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 134 | ) 135 | print("Sampler_train = %s" % str(sampler_train)) 136 | else: 137 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 138 | 139 | if global_rank == 0 and args.log_dir is not None: 140 | os.makedirs(args.log_dir, exist_ok=True) 141 | log_writer = SummaryWriter(log_dir=args.log_dir) 142 | else: 143 | log_writer = None 144 | 145 | data_loader_train = torch.utils.data.DataLoader( 146 | dataset_train, sampler=sampler_train, 147 | batch_size=args.batch_size, 148 | num_workers=args.num_workers, 149 | pin_memory=args.pin_mem, 150 | drop_last=True, 151 | ) 152 | 153 | # define the model 154 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss,in_chans=1) 155 | model.to(device) 156 | 157 | model_without_ddp = model 158 | print("Model = %s" % str(model_without_ddp)) 159 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 160 | if args.lr is None: # only base_lr is specified 161 | args.lr = args.blr * eff_batch_size / 256 162 | 163 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 164 | print("actual lr: %.2e" % args.lr) 165 | print("accumulate grad iterations: %d" % args.accum_iter) 166 | print("effective batch size: %d" % eff_batch_size) 167 | 168 | if args.distributed: 169 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 170 | model_without_ddp = model.module 171 | 172 | # following timm: set wd as 0 for bias and norm layers 173 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 174 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 175 | print(optimizer) 176 | loss_scaler = NativeScaler() 177 | 178 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 179 | 180 | print(f"Start training for {args.epochs} epochs") 181 | 
start_time = time.time() 182 | for epoch in range(args.start_epoch, args.epochs): 183 | if args.distributed: 184 | data_loader_train.sampler.set_epoch(epoch) 185 | train_stats = train_one_epoch( 186 | model, data_loader_train, 187 | optimizer, device, epoch, loss_scaler, 188 | log_writer=log_writer, 189 | args=args 190 | ) 191 | if args.output_dir and (epoch % 2 == 0 or epoch + 1 == args.epochs): 192 | misc.save_model( 193 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 194 | loss_scaler=loss_scaler, epoch=epoch) 195 | 196 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 197 | 'epoch': epoch,} 198 | 199 | if args.output_dir and misc.is_main_process(): 200 | if log_writer is not None: 201 | log_writer.flush() 202 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 203 | f.write(json.dumps(log_stats) + "\n") 204 | 205 | total_time = time.time() - start_time 206 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 207 | print('Training time {}'.format(total_time_str)) 208 | 209 | 210 | if __name__ == '__main__': 211 | args = get_args_parser() 212 | args = args.parse_args() 213 | if args.output_dir: 214 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 215 | main(args) 216 | -------------------------------------------------------------------------------- /SFM-Finetune/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | import os, random, glob 15 | import numpy as np 16 | import torch 17 | import torch.utils.data as data 18 | import torchvision.transforms as transforms 19 | 20 | random.seed(42) 21 | 22 | from torchvision import datasets, transforms 23 | 24 | from timm.data import create_transform 25 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 26 | 27 | 28 | def build_dataset(is_train, args): 29 | transform = build_transform(is_train, args) 30 | 31 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 32 | dataset = datasets.ImageFolder(root, transform=transform) 33 | 34 | print(dataset) 35 | 36 | return dataset 37 | 38 | 39 | def build_transform(is_train, args): 40 | mean = IMAGENET_DEFAULT_MEAN 41 | std = IMAGENET_DEFAULT_STD 42 | # train transform 43 | if is_train: 44 | # this should always dispatch to transforms_imagenet_train 45 | transform = create_transform( 46 | input_size=args.input_size, 47 | is_training=True, 48 | color_jitter=args.color_jitter, 49 | auto_augment=args.aa, 50 | interpolation='bicubic', 51 | re_prob=args.reprob, 52 | re_mode=args.remode, 53 | re_count=args.recount, 54 | mean=mean, 55 | std=std, 56 | ) 57 | return transform 58 | 59 | # eval transform 60 | t = [] 61 | if args.input_size <= 224: 62 | crop_pct = 224 / 256 63 | else: 64 | crop_pct = 1.0 65 | size = int(args.input_size / crop_pct) 66 | t.append( 67 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 
224 images
68 |     )
69 |     t.append(transforms.CenterCrop(args.input_size))
70 | 
71 |     t.append(transforms.ToTensor())
72 |     t.append(transforms.Normalize(mean, std))
73 |     return transforms.Compose(t)
74 | 
75 | 
76 | ## pretrain
77 | class SeismicSet(data.Dataset):
78 | 
79 |     def __init__(self, path, input_size) -> None:
80 |         super().__init__()
81 |         # self.file_list = os.listdir(path)
82 |         # self.file_list = [os.path.join(path, f) for f in self.file_list]
83 |         self.get_file_list(path)
84 |         self.input_size = input_size
85 |         print(len(self.file_list))
86 | 
87 |     def __len__(self) -> int:
88 |         return len(self.file_list)
89 |         # return 100000
90 | 
91 |     def __getitem__(self, index):
92 |         d = np.fromfile(self.file_list[index], dtype=np.float32)
93 |         d = d.reshape(1, self.input_size, self.input_size)
94 |         d = (d - d.mean()) / (d.std()+1e-6)
95 | 
96 |         # return to_transforms(d, self.input_size)
97 |         return d, torch.tensor([1])
98 | 
99 |     def get_file_list(self, path):
100 |         dirs = [os.path.join(path, f) for f in os.listdir(path)]
101 |         self.file_list = dirs
102 | 
103 |         # for ds in dirs:
104 |         #     if os.path.isdir(ds):
105 |         #         self.file_list += [os.path.join(ds, f) for f in os.listdir(ds)]
106 | 
107 |         random.shuffle(self.file_list)  # shuffle in place; random.shuffle returns None
108 | 
109 | 
110 | def to_transforms(d, input_size):
111 |     t = transforms.Compose([
112 |         transforms.RandomResizedCrop(input_size,
113 |                                      scale=(0.2, 1.0),
114 |                                      interpolation=3),  # 3 is bicubic
115 |         transforms.RandomHorizontalFlip(),
116 |         transforms.ToTensor()
117 |     ])
118 | 
119 |     return t(d)
120 | 
121 | 
122 | 
123 | 
124 | 
125 | 
126 | ### finetune
127 | class FacesSet(data.Dataset):
128 |     # folder/train/data/**.dat, folder/train/label/**.dat
129 |     # folder/test/data/**.dat, folder/test/label/**.dat
130 |     def __init__(self,
131 |                  folder,
132 |                  shape=[768, 768],
133 |                  is_train=True) -> None:
134 |         super().__init__()
135 |         self.shape = shape
136 | 
137 |         # self.data_list = sorted(glob.glob(folder + 'seismic/*.dat'))
138 |         self.data_list = [folder +'seismic/'+ str(f)+'.dat' for f in range(117)]
139 | 
140 |         n = len(self.data_list)
141 |         if is_train:
142 |             self.data_list = self.data_list[:100]
143 |         elif not is_train:
144 |             self.data_list = self.data_list[100:]
145 |         self.label_list = [
146 |             f.replace('/seismic/', '/label/') for f in self.data_list
147 |         ]
148 | 
149 |     def __getitem__(self, index):
150 |         d = np.fromfile(self.data_list[index], np.float32)
151 |         d = d.reshape([1] + self.shape)
152 |         l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)-1
153 |         l = l.astype(int)
154 |         return torch.tensor(d), torch.tensor(l)
155 | 
156 | 
157 |     def __len__(self):
158 |         return len(self.data_list)
159 | 
160 | 
161 | 
162 | class SaltSet(data.Dataset):
163 | 
164 |     def __init__(self,
165 |                  folder,
166 |                  shape=[224, 224],
167 |                  is_train=True) -> None:
168 |         super().__init__()
169 |         self.shape = shape
170 |         self.data_list = [folder +'seismic/'+ str(f)+'.dat' for f in range(4000)]
171 |         n = len(self.data_list)
172 |         if is_train:
173 |             self.data_list = self.data_list[:3500]
174 |         elif not is_train:
175 |             self.data_list = self.data_list[3500:]
176 |         self.label_list = [
177 |             f.replace('/seismic/', '/label/') for f in self.data_list
178 |         ]
179 | 
180 |     def __getitem__(self, index):
181 |         d = np.fromfile(self.data_list[index], np.float32)
182 |         d = d.reshape([1] + self.shape)
183 |         l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)
184 |         l = l.astype(int)
185 |         return torch.tensor(d), torch.tensor(l)
186 |     def __len__(self):
187 |         return len(self.data_list)
188 | 
189 | 
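The dataset classes in this file all read raw `float32` `.dat` arrays and reshape them to `[1, H, W]`. A minimal sketch of wiring one of them into a `DataLoader` follows; the local path is hypothetical, and `FacesSet` expects paired `<folder>/seismic/<i>.dat` and `<folder>/label/<i>.dat` files:

```python
# Usage sketch for the dataset classes above; 'Data/Facies/train/' is a
# hypothetical path -- point it at the unzipped facies data.
import torch
from util.datasets import FacesSet

train_set = FacesSet('Data/Facies/train/', shape=[768, 768], is_train=True)
loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True)
seis, label = next(iter(loader))
print(seis.shape, label.shape)  # -> [2, 1, 768, 768] and [2, 768, 768]
```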
190 | class InterpolationSet(data.Dataset):
191 |     # folder/train/data/**.dat, folder/train/label/**.dat
192 |     # folder/test/data/**.dat, folder/test/label/**.dat
193 |     def __init__(self,
194 |                  folder,
195 |                  shape=[224, 224],
196 |                  is_train=True) -> None:
197 |         super().__init__()
198 |         self.shape = shape
199 |         self.data_list = [folder + str(f)+'.dat' for f in range(6000)]
200 |         n = len(self.data_list)
201 |         if is_train:
202 |             self.data_list = self.data_list
203 |         elif not is_train:
204 |             self.data_list = [folder + 'U' + str(f)+'.dat' for f in range(2000,4000)]
205 |         self.label_list = self.data_list
206 | 
207 |     def __getitem__(self, index):
208 |         d = np.fromfile(self.data_list[index], np.float32)
209 |         d = d.reshape([1] + self.shape)
210 |         return torch.tensor(d), torch.tensor(d)
211 | 
212 | 
213 |     def __len__(self):
214 |         return len(self.data_list)
215 |         # return 10000
216 | 
217 | 
218 | 
219 | class DenoiseSet(data.Dataset):
220 |     def __init__(self,
221 |                  folder,
222 |                  shape=[224, 224],
223 |                  is_train=True) -> None:
224 |         super().__init__()
225 |         self.shape = shape
226 |         self.data_list = [folder+'seismic/'+ str(f)+'.dat' for f in range(2000)]
227 |         n = len(self.data_list)
228 |         if is_train:
229 |             self.data_list = self.data_list
230 |             self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]
231 |         elif not is_train:
232 |             self.data_list = [folder+'field/'+ str(f)+'.dat' for f in range(4000)]
233 |             self.label_list = self.data_list
234 | 
235 |     def __getitem__(self, index):
236 |         d = np.fromfile(self.data_list[index], np.float32)
237 |         d = d.reshape([1] + self.shape)
238 |         # d = (d - d.mean())/d.std()
239 |         l = np.fromfile(self.label_list[index], np.float32)
240 |         l = l.reshape([1] + self.shape)
241 |         # l = (l - d.mean())/l.std()
242 |         return torch.tensor(d), torch.tensor(l)
243 | 
244 | 
245 |     def __len__(self):
246 |         return len(self.data_list)
247 | 
248 | 
249 | class ReflectSet(data.Dataset):
250 |     # folder/train/data/**.dat, folder/train/label/**.dat
251 |     # folder/test/data/**.dat, folder/test/label/**.dat
252 |     def __init__(self,
253 |                  folder,
254 |                  shape=[224, 224],
255 |                  is_train=True) -> None:
256 |         super().__init__()
257 |         self.shape = shape
258 |         self.data_list = [folder+'seismic/'+ str(f)+'.dat' for f in range(2200)]
259 | 
260 | 
261 | 
262 |         n = len(self.data_list)
263 |         if is_train:
264 |             self.data_list = self.data_list
265 |             self.label_list = [
266 |                 f.replace('/seismic/', '/label/') for f in self.data_list
267 |             ]
268 |         elif not is_train:
269 |             self.data_list = [folder+'SEAMseismic/'+ str(f)+'.dat' for f in range(4000)]
270 |             self.label_list = [
271 |                 f.replace('/SEAMseismic/', '/SEAMreflect/') for f in self.data_list
272 |             ]
273 | 
274 |     def __getitem__(self, index):
275 |         d = np.fromfile(self.data_list[index], np.float32)
276 |         d = d - d.mean()
277 |         d = d/(d.std()+1e-6)
278 |         d = d.reshape([1] + self.shape)
279 |         l = np.fromfile(self.label_list[index], np.float32)
280 |         l = l - l.mean()
281 |         l = l/(l.std()+1e-6)
282 |         l = l.reshape([1] + self.shape)
283 |         return torch.tensor(d), torch.tensor(l)
284 | 
285 | 
286 |     def __len__(self):
287 |         return len(self.data_list)
288 | 
--------------------------------------------------------------------------------
/SFM-Finetune/models_mae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | self.in_chans = in_chans 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | # -------------------------------------------------------------------------- 44 | 45 | # -------------------------------------------------------------------------- 46 | # MAE decoder specifics 47 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 48 | 49 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 50 | 51 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 52 | 53 | self.decoder_blocks = nn.ModuleList([ 54 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 55 | for i in range(decoder_depth)]) 56 | 57 | self.decoder_norm = norm_layer(decoder_embed_dim) 58 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 59 | # -------------------------------------------------------------------------- 60 | 61 | self.norm_pix_loss = norm_pix_loss 62 | 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 73 | 74 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 75 | w = self.patch_embed.proj.weight.data 76 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 77 | 78 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is 
too big (2.)
79 |         torch.nn.init.normal_(self.cls_token, std=.02)
80 |         torch.nn.init.normal_(self.mask_token, std=.02)
81 | 
82 |         # initialize nn.Linear and nn.LayerNorm
83 |         self.apply(self._init_weights)
84 | 
85 |     def _init_weights(self, m):
86 |         if isinstance(m, nn.Linear):
87 |             # we use xavier_uniform following official JAX ViT:
88 |             torch.nn.init.xavier_uniform_(m.weight)
89 |             if isinstance(m, nn.Linear) and m.bias is not None:
90 |                 nn.init.constant_(m.bias, 0)
91 |         elif isinstance(m, nn.LayerNorm):
92 |             nn.init.constant_(m.bias, 0)
93 |             nn.init.constant_(m.weight, 1.0)
94 | 
95 |     def patchify(self, imgs):
96 |         """
97 |         imgs: (N, C, H, W)
98 |         x: (N, L, patch_size**2 * C)
99 |         """
100 |         p = self.patch_embed.patch_size[0]
101 |         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
102 | 
103 |         h = w = imgs.shape[2] // p
104 |         x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
105 |         x = torch.einsum('nchpwq->nhwpqc', x)
106 |         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_chans))
107 |         return x
108 | 
109 |     def unpatchify(self, x):
110 |         """
111 |         x: (N, L, patch_size**2 * C)
112 |         imgs: (N, C, H, W)
113 |         """
114 |         p = self.patch_embed.patch_size[0]
115 |         h = w = int(x.shape[1]**.5)
116 |         assert h * w == x.shape[1]
117 | 
118 |         x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))  # channel count follows self.in_chans (1 for seismic data)
119 |         x = torch.einsum('nhwpqc->nchpwq', x)
120 |         imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))
121 |         return imgs
122 | 
123 |     def random_masking(self, x, mask_ratio):
124 |         """
125 |         Perform per-sample random masking by per-sample shuffling.
126 |         Per-sample shuffling is done by argsort random noise.
127 |         x: [N, L, D], sequence
128 |         """
129 |         N, L, D = x.shape  # batch, length, dim
130 |         len_keep = int(L * (1 - mask_ratio))
131 | 
132 |         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
133 | 
134 |         # sort noise for each sample
135 |         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
136 |         ids_restore = torch.argsort(ids_shuffle, dim=1)
137 | 
138 |         # keep the first subset
139 |         ids_keep = ids_shuffle[:, :len_keep]
140 |         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
141 | 
142 |         # generate the binary mask: 0 is keep, 1 is remove
143 |         mask = torch.ones([N, L], device=x.device)
144 |         mask[:, :len_keep] = 0
145 |         # unshuffle to get the binary mask
146 |         mask = torch.gather(mask, dim=1, index=ids_restore)
147 | 
148 |         return x_masked, mask, ids_restore
149 | 
150 |     def forward_encoder(self, x, mask_ratio):
151 |         # embed patches
152 |         x = self.patch_embed(x)
153 | 
154 |         # add pos embed w/o cls token
155 |         x = x + self.pos_embed[:, 1:, :]
156 | 
157 |         # masking: length -> length * mask_ratio
158 |         x, mask, ids_restore = self.random_masking(x, mask_ratio)
159 | 
160 |         # append cls token
161 |         cls_token = self.cls_token + self.pos_embed[:, :1, :]
162 |         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
163 |         x = torch.cat((cls_tokens, x), dim=1)
164 | 
165 |         # apply Transformer blocks
166 |         for blk in self.blocks:
167 |             x = blk(x)
168 |         x = self.norm(x)
169 | 
170 |         return x, mask, ids_restore
171 | 
172 |     def forward_decoder(self, x, ids_restore):
173 |         # embed tokens
174 |         x = self.decoder_embed(x)
175 | 
176 |         # append mask tokens to sequence
177 |         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
178 |         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
179 |         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
180 |         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
181 | 
182 |         # add pos embed
183 |         x = x + self.decoder_pos_embed
184 | 
185 |         # apply Transformer blocks
186 |         for blk in self.decoder_blocks:
187 |             x = blk(x)
188 |         x = self.decoder_norm(x)
189 | 
190 |         # predictor projection
191 |         x = self.decoder_pred(x)
192 | 
193 |         # remove cls token
194 |         x = x[:, 1:, :]
195 | 
196 |         return x
197 | 
198 |     def forward_loss(self, imgs, pred, mask):
199 |         """
200 |         imgs: [N, C, H, W]
201 |         pred: [N, L, p*p*C]
202 |         mask: [N, L], 0 is keep, 1 is remove,
203 |         """
204 |         target = self.patchify(imgs)
205 |         if self.norm_pix_loss:
206 |             mean = target.mean(dim=-1, keepdim=True)
207 |             var = target.var(dim=-1, keepdim=True)
208 |             target = (target - mean) / (var + 1.e-6)**.5
209 | 
210 |         loss = (pred - target) ** 2
211 |         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
212 | 
213 |         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
214 |         return loss
215 | 
216 |     def forward(self, imgs, mask_ratio=0.75):
217 |         latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
218 |         pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
219 |         loss = self.forward_loss(imgs, pred, mask)
220 |         return loss, pred, mask
221 | 
222 | def mae_vit_small_patch16_dec512d8b(**kwargs):
223 |     model = MaskedAutoencoderViT(
224 |         patch_size=16, embed_dim=768, depth=6, num_heads=12,
225 |         decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16,
226 |         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
227 |     return model
228 | 
229 | def mae_vit_base_patch16_dec512d8b(**kwargs):
230 |     model = MaskedAutoencoderViT(
231 |         patch_size=16, embed_dim=768, depth=12, num_heads=12,
232 |         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
233 |         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
234 |     return model
235 | 
236 | 
237 | def mae_vit_large_patch16_dec512d8b(**kwargs):
238 |     model = MaskedAutoencoderViT(
239 |         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
240 |         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
241 |         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
242 |     return model
243 | 
244 | 
245 | def mae_vit_huge_patch14_dec512d8b(**kwargs):
246 |     model = MaskedAutoencoderViT(
247 |         patch_size=14, embed_dim=1280, depth=32, num_heads=16,
248 |         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
249 |         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
250 |     return model
251 | 
252 | def mae_vit_large_patch16_dec256d4b(**kwargs):
253 |     model = MaskedAutoencoderViT(
254 |         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
255 |         decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=16,
256 |         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
257 |     return model
258 | # set recommended archs
259 | mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b  # decoder: 512 dim, 4 blocks
260 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
261 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
262 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
263 | mae_vit_large_patch16D4d256 = mae_vit_large_patch16_dec256d4b  # decoder: 256 dim, 4 blocks
264 | if __name__ == '__main__':
265 |     model = mae_vit_large_patch16(in_chans=1)  # this file defaults to in_chans=3; the test input below is 1-channel
266 |     inputs = torch.randn(2, 1, 224, 224)
267 |     loss, pred, mask = model(inputs)  # forward returns a (loss, pred, mask) tuple
268 |     print(pred.shape)
269 |     print('Done')
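Because `forward` returns `(loss, pred, mask)` with `pred` still in patch-token form, visualizing a reconstruction requires `unpatchify`. Below is a minimal sketch under the channel-aware `unpatchify` above, assuming a 1-channel input batch `x`; the blending convention follows the MAE demo notebook (mask: 0 = keep, 1 = remove):

```python
# Sketch: rebuild an image-space reconstruction from the patch predictions.
model = mae_vit_large_patch16(in_chans=1)
x = torch.randn(4, 1, 224, 224)  # stand-in for a normalized seismic batch
loss, pred, mask = model(x, mask_ratio=0.75)
p = model.patch_embed.patch_size[0]
rec = model.unpatchify(pred)  # [4, 1, 224, 224]
mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, p * p * model.in_chans))
composite = x * (1 - mask_img) + rec * mask_img  # visible patches kept, masked ones predicted
```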
-------------------------------------------------------------------------------- /SFM-Pretrain/models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=1, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | self.in_chans = in_chans 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | # -------------------------------------------------------------------------- 44 | 45 | # -------------------------------------------------------------------------- 46 | # MAE decoder specifics 47 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 48 | 49 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 50 | 51 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 52 | 53 | self.decoder_blocks = nn.ModuleList([ 54 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 55 | for i in range(decoder_depth)]) 56 | 57 | self.decoder_norm = norm_layer(decoder_embed_dim) 58 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 59 | # -------------------------------------------------------------------------- 60 | 61 | self.norm_pix_loss = norm_pix_loss 62 | 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | 
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
73 | 
74 |         # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
75 |         w = self.patch_embed.proj.weight.data
76 |         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
77 | 
78 |         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
79 |         torch.nn.init.normal_(self.cls_token, std=.02)
80 |         torch.nn.init.normal_(self.mask_token, std=.02)
81 | 
82 |         # initialize nn.Linear and nn.LayerNorm
83 |         self.apply(self._init_weights)
84 | 
85 |     def _init_weights(self, m):
86 |         if isinstance(m, nn.Linear):
87 |             # we use xavier_uniform following official JAX ViT:
88 |             torch.nn.init.xavier_uniform_(m.weight)
89 |             if isinstance(m, nn.Linear) and m.bias is not None:
90 |                 nn.init.constant_(m.bias, 0)
91 |         elif isinstance(m, nn.LayerNorm):
92 |             nn.init.constant_(m.bias, 0)
93 |             nn.init.constant_(m.weight, 1.0)
94 | 
95 |     def patchify(self, imgs):
96 |         """
97 |         imgs: (N, C, H, W)
98 |         x: (N, L, patch_size**2 * C)
99 |         """
100 |         p = self.patch_embed.patch_size[0]
101 |         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
102 | 
103 |         h = w = imgs.shape[2] // p
104 |         x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
105 |         x = torch.einsum('nchpwq->nhwpqc', x)
106 |         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_chans))
107 |         return x
108 | 
109 |     def unpatchify(self, x):
110 |         """
111 |         x: (N, L, patch_size**2 * C)
112 |         imgs: (N, C, H, W)
113 |         """
114 |         p = self.patch_embed.patch_size[0]
115 |         h = w = int(x.shape[1]**.5)
116 |         assert h * w == x.shape[1]
117 | 
118 |         x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))  # channel count follows self.in_chans (1 for seismic data)
119 |         x = torch.einsum('nhwpqc->nchpwq', x)
120 |         imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))
121 |         return imgs
122 | 
123 |     def random_masking(self, x, mask_ratio):
124 |         """
125 |         Perform per-sample random masking by per-sample shuffling.
126 |         Per-sample shuffling is done by argsort random noise.
127 | x: [N, L, D], sequence 128 | """ 129 | N, L, D = x.shape # batch, length, dim 130 | len_keep = int(L * (1 - mask_ratio)) 131 | 132 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 133 | 134 | # sort noise for each sample 135 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 136 | ids_restore = torch.argsort(ids_shuffle, dim=1) 137 | 138 | # keep the first subset 139 | ids_keep = ids_shuffle[:, :len_keep] 140 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 141 | 142 | # generate the binary mask: 0 is keep, 1 is remove 143 | mask = torch.ones([N, L], device=x.device) 144 | mask[:, :len_keep] = 0 145 | # unshuffle to get the binary mask 146 | mask = torch.gather(mask, dim=1, index=ids_restore) 147 | 148 | return x_masked, mask, ids_restore 149 | 150 | def forward_encoder(self, x, mask_ratio): 151 | # embed patches 152 | x = self.patch_embed(x) 153 | 154 | # add pos embed w/o cls token 155 | x = x + self.pos_embed[:, 1:, :] 156 | 157 | # masking: length -> length * mask_ratio 158 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 159 | 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | return x, mask, ids_restore 171 | 172 | def forward_decoder(self, x, ids_restore): 173 | # embed tokens 174 | x = self.decoder_embed(x) 175 | 176 | # append mask tokens to sequence 177 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 178 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 179 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 180 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 181 | 182 | # add pos embed 183 | x = x + self.decoder_pos_embed 184 | 185 | # apply Transformer blocks 186 | for blk in self.decoder_blocks: 187 | x = blk(x) 188 | x = self.decoder_norm(x) 189 | 190 | # predictor projection 191 | x = self.decoder_pred(x) 192 | 193 | # remove cls token 194 | x = x[:, 1:, :] 195 | 196 | return x 197 | 198 | def forward_loss(self, imgs, pred, mask): 199 | """ 200 | imgs: [N, 3, H, W] 201 | pred: [N, L, p*p*3] 202 | mask: [N, L], 0 is keep, 1 is remove, 203 | """ 204 | target = self.patchify(imgs) 205 | if self.norm_pix_loss: 206 | mean = target.mean(dim=-1, keepdim=True) 207 | var = target.var(dim=-1, keepdim=True) 208 | target = (target - mean) / (var + 1.e-6)**.5 209 | 210 | loss = (pred - target) ** 2 211 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 212 | 213 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 214 | return loss 215 | 216 | def forward(self, imgs, mask_ratio=0.75): 217 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 218 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 219 | loss = self.forward_loss(imgs, pred, mask) 220 | return loss, pred, mask 221 | 222 | def mae_vit_small_patch16_dec512d8b(**kwargs): 223 | model = MaskedAutoencoderViT( 224 | patch_size=16, embed_dim=768, depth=6, num_heads=12, 225 | decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, 226 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 227 | return model 228 | 229 | def mae_vit_base_patch16_dec512d8b(**kwargs): 230 | model = 
MaskedAutoencoderViT( 231 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 232 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 233 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 234 | return model 235 | 236 | def mae_vit_base_patch16_dec256d4b(**kwargs): 237 | model = MaskedAutoencoderViT( 238 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 239 | decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=16, 240 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 241 | return model 242 | 243 | def mae_vit_large_patch16_dec256d4b(**kwargs): 244 | model = MaskedAutoencoderViT( 245 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 246 | decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=16, 247 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 248 | return model 249 | 250 | def mae_vit_large_patch16_dec512d8b(**kwargs): 251 | model = MaskedAutoencoderViT( 252 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 253 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 254 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 255 | return model 256 | 257 | 258 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 259 | model = MaskedAutoencoderViT( 260 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 261 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 262 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 263 | return model 264 | 265 | 266 | # set recommended archs 267 | mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks 268 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 269 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 270 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 271 | mae_vit_large_patch16D4d256 = mae_vit_large_patch16_dec256d4b # decoder: 512 dim, 8 blocks 272 | mae_vit_base_patch16D4d256 = mae_vit_base_patch16_dec256d4b -------------------------------------------------------------------------------- /SFM-Finetune/modules/modeling/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, 
--------------------------------------------------------------------------------
/SFM-Finetune/modules/modeling/backbone/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from modules.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
8 | def fixed_padding(inputs, kernel_size, dilation):
9 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
10 |     pad_total = kernel_size_effective - 1
11 |     pad_beg = pad_total // 2
12 |     pad_end = pad_total - pad_beg
13 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
14 |     return padded_inputs
15 |
16 |
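# ---- Editor's note (worked example; not part of the original source) ----
# fixed_padding implements "same"-style padding for dilated convolutions:
# with kernel_size=3 and dilation=2 the effective kernel is 3 + 2*(2-1) = 5,
# so pad_total = 4 and F.pad adds (2, 2) on each spatial axis; a stride-1
# SeparableConv2d below then preserves the input's spatial size.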
17 | class SeparableConv2d(nn.Module):
18 |     def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
19 |         super(SeparableConv2d, self).__init__()
20 |
21 |         self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
22 |                                groups=inplanes, bias=bias)
23 |         self.bn = BatchNorm(inplanes)
24 |         self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
25 |
26 |     def forward(self, x):
27 |         x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
28 |         x = self.conv1(x)
29 |         x = self.bn(x)
30 |         x = self.pointwise(x)
31 |         return x
32 |
33 |
34 | class Block(nn.Module):
35 |     def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
36 |                  start_with_relu=True, grow_first=True, is_last=False):
37 |         super(Block, self).__init__()
38 |
39 |         if planes != inplanes or stride != 1:
40 |             self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
41 |             self.skipbn = BatchNorm(planes)
42 |         else:
43 |             self.skip = None
44 |
45 |         self.relu = nn.ReLU(inplace=True)
46 |         rep = []
47 |
48 |         filters = inplanes
49 |         if grow_first:
50 |             rep.append(self.relu)
51 |             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
52 |             rep.append(BatchNorm(planes))
53 |             filters = planes
54 |
55 |         for i in range(reps - 1):
56 |             rep.append(self.relu)
57 |             rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
58 |             rep.append(BatchNorm(filters))
59 |
60 |         if not grow_first:
61 |             rep.append(self.relu)
62 |             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
63 |             rep.append(BatchNorm(planes))
64 |
65 |         if stride != 1:
66 |             rep.append(self.relu)
67 |             rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
68 |             rep.append(BatchNorm(planes))
69 |
70 |         if stride == 1 and is_last:
71 |             rep.append(self.relu)
72 |             rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
73 |             rep.append(BatchNorm(planes))
74 |
75 |         if not start_with_relu:
76 |             rep = rep[1:]
77 |
78 |         self.rep = nn.Sequential(*rep)
79 |
80 |     def forward(self, inp):
81 |         x = self.rep(inp)
82 |
83 |         if self.skip is not None:
84 |             skip = self.skip(inp)
85 |             skip = self.skipbn(skip)
86 |         else:
87 |             skip = inp
88 |
89 |         x = x + skip
90 |
91 |         return x
92 |
93 |
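# ---- Editor's note (not part of the original source) ----
# Block is a pre-activation residual unit: forward computes rep(inp) + skip(inp),
# where skip is a 1x1 stride-matched projection whenever the channel count or
# stride changes, and the identity otherwise.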
94 | class AlignedXception(nn.Module):
95 |     """
96 |     Modified Aligned Xception
97 |     """
98 |     def __init__(self, output_stride, BatchNorm,
99 |                  pretrained=True):
100 |         super(AlignedXception, self).__init__()
101 |
102 |         if output_stride == 16:
103 |             entry_block3_stride = 2
104 |             middle_block_dilation = 1
105 |             exit_block_dilations = (1, 2)
106 |         elif output_stride == 8:
107 |             entry_block3_stride = 1
108 |             middle_block_dilation = 2
109 |             exit_block_dilations = (2, 4)
110 |         else:
111 |             raise NotImplementedError
112 |
113 |
114 |         # Entry flow
115 |         self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
116 |         self.bn1 = BatchNorm(32)
117 |         self.relu = nn.ReLU(inplace=True)
118 |
119 |         self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
120 |         self.bn2 = BatchNorm(64)
121 |
122 |         self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
123 |         self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
124 |                             grow_first=True)
125 |         self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
126 |                             start_with_relu=True, grow_first=True, is_last=True)
127 |
128 |         # Middle flow
129 |         self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
130 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
131 |         self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
132 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
133 |         self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
134 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
135 |         self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
136 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
137 |         self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
138 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
139 |         self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
140 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
141 |         self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
142 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
143 |         self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
144 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
145 |         self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
146 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
147 |         self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
148 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
149 |         self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
150 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
151 |         self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
152 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
153 |         self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
154 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
155 |         self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
156 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
157 |         self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
158 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
159 |         self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
160 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
161 |
162 |         # Exit flow
163 |         self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
164 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
165 |
166 |         self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
167 |         self.bn3 = BatchNorm(1536)
168 |
169 |         self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
170 |         self.bn4 = BatchNorm(1536)
171 |
172 |         self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
173 |         self.bn5 = BatchNorm(2048)
174 |
175 |         # Init weights
176 |         self._init_weight()
177 |
178 |         # Load pretrained model
179 |         if pretrained:
180 |             self._load_pretrained_model()
181 |
182 |     def forward(self, x):
183 |         # Entry flow
184 |         x = self.conv1(x)
185 |         x = self.bn1(x)
186 |         x = self.relu(x)
187 |
188 |         x = self.conv2(x)
189 |         x = self.bn2(x)
190 |         x = self.relu(x)
191 |
192 |         x = self.block1(x)
193 |         # add relu here
194 |         x = self.relu(x)
195 |         low_level_feat = x
196 |         x = self.block2(x)
197 |         x = self.block3(x)
198 |
199 |         # Middle flow
200 |         x = self.block4(x)
201 |         x = self.block5(x)
202 |         x = self.block6(x)
203 |         x = self.block7(x)
204 |         x = self.block8(x)
205 |         x = self.block9(x)
206 |         x = self.block10(x)
207 |         x = self.block11(x)
208 |         x = self.block12(x)
209 |         x = self.block13(x)
210 |         x = self.block14(x)
211 |         x = self.block15(x)
212 |         x = self.block16(x)
213 |         x = self.block17(x)
214 |         x = self.block18(x)
215 |         x = self.block19(x)
216 |
217 |         # Exit flow
218 |         x = self.block20(x)
219 |         x = self.relu(x)
220 |         x = self.conv3(x)
221 |         x = self.bn3(x)
222 |         x = self.relu(x)
223 |
224 |         x = self.conv4(x)
225 |         x = self.bn4(x)
226 |         x = self.relu(x)
227 |
228 |         x = self.conv5(x)
229 |         x = self.bn5(x)
230 |         x = self.relu(x)
231 |
232 |         return x, low_level_feat
233 |
234 |     def _init_weight(self):
235 |         for m in self.modules():
236 |             if isinstance(m, nn.Conv2d):
237 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
238 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
239 |             elif isinstance(m, SynchronizedBatchNorm2d):
240 |                 m.weight.data.fill_(1)
241 |                 m.bias.data.zero_()
242 |             elif isinstance(m, nn.BatchNorm2d):
243 |                 m.weight.data.fill_(1)
244 |                 m.bias.data.zero_()
245 |
246 |
247 |     def _load_pretrained_model(self):
248 |         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
249 |         model_dict = {}
250 |         state_dict = self.state_dict()
251 |
252 |         for k, v in pretrain_dict.items():
253 |             if k in state_dict:
254 |                 if 'pointwise' in k:
255 |                     v = v.unsqueeze(-1).unsqueeze(-1)
256 |                 if k.startswith('block11'):
257 |                     model_dict[k] = v
258 |                     model_dict[k.replace('block11', 'block12')] = v
259 |                     model_dict[k.replace('block11', 'block13')] = v
260 |                     model_dict[k.replace('block11', 'block14')] = v
261 |                     model_dict[k.replace('block11', 'block15')] = v
262 |                     model_dict[k.replace('block11', 'block16')] = v
263 |                     model_dict[k.replace('block11', 'block17')] = v
264 |                     model_dict[k.replace('block11', 'block18')] = v
265 |                     model_dict[k.replace('block11', 'block19')] = v
266 |                 elif k.startswith('block12'):
267 |                     model_dict[k.replace('block12', 'block20')] = v
268 |                 elif k.startswith('bn3'):
269 |                     model_dict[k] = v
270 |                     model_dict[k.replace('bn3', 'bn4')] = v
271 |                 elif k.startswith('conv4'):
272 |                     model_dict[k.replace('conv4', 'conv5')] = v
273 |                 elif k.startswith('bn4'):
274 |                     model_dict[k.replace('bn4', 'bn5')] = v
275 |                 else:
276 |                     model_dict[k] = v
277 |         state_dict.update(model_dict)
278 |         self.load_state_dict(state_dict)
279 |
280 |
281 |
282 | if __name__ == "__main__":
283 |     import torch
284 |     model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
285 |     input = torch.rand(1, 3, 512, 512)
286 |     output, low_level_feat = model(input)
287 |     print(output.size())
288 |     print(low_level_feat.size())
289 |
--------------------------------------------------------------------------------
/SFM-Finetune/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch._six import inf  # NOTE: torch._six was removed in PyTorch 2.x; use `from math import inf` there
22 |
23 |
24 | class SmoothedValue(object):
25 |     """Track a series of values and provide access to smoothed values over a
26 |     window or the global series average.
27 |     """
28 |
29 |     def __init__(self, window_size=20, fmt=None):
30 |         if fmt is None:
31 |             fmt = "{median:.4f} ({global_avg:.4f})"
32 |         self.deque = deque(maxlen=window_size)
33 |         self.total = 0.0
34 |         self.count = 0
35 |         self.fmt = fmt
36 |
37 |     def update(self, value, n=1):
38 |         self.deque.append(value)
39 |         self.count += n
40 |         self.total += value * n
41 |
42 |     def synchronize_between_processes(self):
43 |         """
44 |         Warning: does not synchronize the deque!
45 |         """
46 |         if not is_dist_avail_and_initialized():
47 |             return
48 |         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49 |         dist.barrier()
50 |         dist.all_reduce(t)
51 |         t = t.tolist()
52 |         self.count = int(t[0])
53 |         self.total = t[1]
54 |
55 |     @property
56 |     def median(self):
57 |         d = torch.tensor(list(self.deque))
58 |         return d.median().item()
59 |
60 |     @property
61 |     def avg(self):
62 |         d = torch.tensor(list(self.deque), dtype=torch.float32)
63 |         return d.mean().item()
64 |
65 |     @property
66 |     def global_avg(self):
67 |         return self.total / self.count
68 |
69 |     @property
70 |     def max(self):
71 |         return max(self.deque)
72 |
73 |     @property
74 |     def value(self):
75 |         return self.deque[-1]
76 |
77 |     def __str__(self):
78 |         return self.fmt.format(
79 |             median=self.median,
80 |             avg=self.avg,
81 |             global_avg=self.global_avg,
82 |             max=self.max,
83 |             value=self.value)
84 |
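# ---- Editor's example (sketch; not part of the original source) ----
# SmoothedValue keeps a sliding window plus global totals:
#
#   meter = SmoothedValue(window_size=20)
#   for step_loss in (0.9, 0.7, 0.8):
#       meter.update(step_loss)
#   print(str(meter))   # "0.8000 (0.8000)" -> window median (global average)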
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
169 |
170 | def setup_for_distributed(is_master):
171 |     """
172 |     This function disables printing when not in master process
173 |     """
174 |     builtin_print = builtins.print
175 |
176 |     def print(*args, **kwargs):
177 |         force = kwargs.pop('force', False)
178 |         force = force or (get_world_size() > 8)
179 |         if is_master or force:
180 |             now = datetime.datetime.now().time()
181 |             builtin_print('[{}] '.format(now), end='')  # print with time stamp
182 |             builtin_print(*args, **kwargs)
183 |
184 |     builtins.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 |     if not dist.is_available():
189 |         return False
190 |     if not dist.is_initialized():
191 |         return False
192 |     return True
193 |
194 |
195 | def get_world_size():
196 |     if not is_dist_avail_and_initialized():
197 |         return 1
198 |     return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 |     if not is_dist_avail_and_initialized():
203 |         return 0
204 |     return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 |     return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 |     if is_main_process():
213 |         torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 |     if args.dist_on_itp:
218 |         args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
219 |         args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
220 |         args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
221 |         args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
222 |         os.environ['LOCAL_RANK'] = str(args.gpu)
223 |         os.environ['RANK'] = str(args.rank)
224 |         os.environ['WORLD_SIZE'] = str(args.world_size)
225 |         # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
226 |     elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
227 |         args.rank = int(os.environ["RANK"])
228 |         args.world_size = int(os.environ['WORLD_SIZE'])
229 |         args.gpu = int(os.environ['LOCAL_RANK'])
230 |     elif 'SLURM_PROCID' in os.environ:
231 |         args.rank = int(os.environ['SLURM_PROCID'])
232 |         args.gpu = args.rank % torch.cuda.device_count()  # world_size/dist_url are expected to be pre-set here (e.g. by the submitit launcher)
233 |     else:
234 |         print('Not using distributed mode')
235 |         setup_for_distributed(is_master=True)  # hack
236 |         args.distributed = False
237 |         return
238 |
239 |     args.distributed = True
240 |
241 |     torch.cuda.set_device(args.gpu)
242 |     args.dist_backend = 'nccl'
243 |     print('| distributed init (rank {}): {}, gpu {}'.format(
244 |         args.rank, args.dist_url, args.gpu), flush=True)
245 |     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
246 |                                          world_size=args.world_size, rank=args.rank)
247 |     torch.distributed.barrier()
248 |     setup_for_distributed(args.rank == 0)
249 |
250 |
251 | class NativeScalerWithGradNormCount:
252 |     state_dict_key = "amp_scaler"
253 |
254 |     def __init__(self):
255 |         self._scaler = torch.cuda.amp.GradScaler()
256 |
257 |     def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
258 |         self._scaler.scale(loss).backward(create_graph=create_graph)
259 |         if update_grad:
260 |             if clip_grad is not None:
261 |                 assert parameters is not None
262 |                 self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
263 |                 norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
264 |             else:
265 |                 self._scaler.unscale_(optimizer)
266 |                 norm = get_grad_norm_(parameters)
267 |             self._scaler.step(optimizer)
268 |             self._scaler.update()
269 |         else:
270 |             norm = None
271 |         return norm
272 |
273 |     def state_dict(self):
274 |         return self._scaler.state_dict()
275 |
276 |     def load_state_dict(self, state_dict):
277 |         self._scaler.load_state_dict(state_dict)
278 |
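# ---- Editor's example (sketch; not part of the original source) ----
# One mixed-precision optimizer step with gradient clipping. The scaler's
# __call__ above runs backward on the scaled loss, optionally clips the
# unscaled gradients, then steps the optimizer and updates the scale.
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   with torch.cuda.amp.autocast():
#       loss, _, _ = model(samples)
#   optimizer.zero_grad()
#   grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
#                           parameters=model.parameters(), update_grad=True)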
279 |
280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
281 |     if isinstance(parameters, torch.Tensor):
282 |         parameters = [parameters]
283 |     parameters = [p for p in parameters if p.grad is not None]
284 |     norm_type = float(norm_type)
285 |     if len(parameters) == 0:
286 |         return torch.tensor(0.)
287 |     device = parameters[0].grad.device
288 |     if norm_type == inf:
289 |         total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
290 |     else:
291 |         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
292 |     return total_norm
293 |
294 |
295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
296 |     output_dir = Path(args.output_dir)
297 |     epoch_name = str(epoch)
298 |     if loss_scaler is not None:
299 |         checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
300 |         for checkpoint_path in checkpoint_paths:
301 |             to_save = {
302 |                 'model': model_without_ddp.state_dict(),
303 |                 'optimizer': optimizer.state_dict(),
304 |                 'epoch': epoch,
305 |                 'scaler': loss_scaler.state_dict(),
306 |                 'args': args,
307 |             }
308 |
309 |             save_on_master(to_save, checkpoint_path)
310 |     else:
311 |         client_state = {'epoch': epoch}
312 |         model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
313 |
314 |
315 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
316 |     if args.resume:
317 |         if args.resume.startswith('https'):
318 |             checkpoint = torch.hub.load_state_dict_from_url(
319 |                 args.resume, map_location='cpu', check_hash=True)
320 |         else:
321 |             checkpoint = torch.load(args.resume, map_location='cpu')
322 |         model_without_ddp.load_state_dict(checkpoint['model'])
323 |         print("Resume checkpoint %s" % args.resume)
324 |         if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
325 |             optimizer.load_state_dict(checkpoint['optimizer'])
326 |             args.start_epoch = checkpoint['epoch'] + 1
327 |             if 'scaler' in checkpoint:
328 |                 loss_scaler.load_state_dict(checkpoint['scaler'])
329 |             print("With optim & sched!")
330 |
331 |
332 | def all_reduce_mean(x):
333 |     world_size = get_world_size()
334 |     if world_size > 1:
335 |         x_reduce = torch.tensor(x).cuda()
336 |         dist.all_reduce(x_reduce)
337 |         x_reduce /= world_size
338 |         return x_reduce.item()
339 |     else:
340 |         return x
--------------------------------------------------------------------------------
/SFM-Finetune/models_Segmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | from functools import partial
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import timm.models.vision_transformer
18 | import numpy as np
19 |
20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
21 |     """ Vision Transformer with support for global average pooling
22 |     """
23 |     def __init__(self, global_pool=False, **kwargs):
24 |         super(VisionTransformer, self).__init__(**kwargs)
25 |
26 |         self.global_pool = global_pool
27 |         self.decoder = VIT_MLAHead(mla_channels=self.embed_dim, num_classes=self.num_classes)
28 |
29 |         self.segmentation_head = SegmentationHead(
30 |             in_channels=16,
31 |             out_channels=self.num_classes,
32 |             kernel_size=3,
33 |         )
34 |         if self.global_pool:
35 |             norm_layer = kwargs['norm_layer']
36 |             embed_dim = kwargs['embed_dim']
37 |             self.fc_norm = norm_layer(embed_dim)
38 |             del self.norm  # remove the original norm
39 |
40 |     def forward_features(self, x):
41 |         B, C, H, W = x.shape
42 |         x = self.patch_embed(x)
43 |         _H, _W = H // self.patch_embed.patch_size[0], W // self.patch_embed.patch_size[0]
44 |         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
45 |         x = torch.cat((cls_tokens, x), dim=1)
46 |         x = x + self.pos_embed
47 |         x = self.pos_drop(x)
48 |         featureskip = []
49 |         featureskipnum = 1
50 |         for blk in self.blocks:
51 |             x = blk(x)
52 |             if featureskipnum % (len(self.blocks) // 4) == 0:  # tap every depth//4 blocks -> four skip features
53 |                 featureskip.append(x[:, 1:, :])
54 |                 # print(featureskipnum)
55 |             featureskipnum += 1
56 |
57 |         x = self.decoder(featureskip[0], featureskip[1], featureskip[2], featureskip[3], h=_H, w=_W)
58 |         return x
59 |
60 |     def forward(self, x):
61 |         x = self.forward_features(x)
62 |         return x
63 |
64 | class Conv2dReLU(nn.Sequential):
65 |     def __init__(
66 |             self,
67 |             in_channels,
68 |             out_channels,
69 |             kernel_size,
70 |             padding=0,
71 |             stride=1,
72 |             use_batchnorm=True,
73 |     ):
74 |         conv = nn.Conv2d(
75 |             in_channels,
76 |             out_channels,
77 |             kernel_size,
78 |             stride=stride,
79 |             padding=padding,
80 |             bias=not (use_batchnorm),
81 |         )
82 |         relu = nn.ReLU(inplace=True)
83 |
84 |         bn = nn.BatchNorm2d(out_channels)
85 |
86 |         super(Conv2dReLU, self).__init__(conv, bn, relu)
87 |
88 |
89 | class DecoderBlock(nn.Module):
90 |     def __init__(
91 |             self,
92 |             in_channels,
93 |             out_channels,
94 |             skip_channels=0,
95 |             use_batchnorm=True,
96 |     ):
97 |         super().__init__()
98 |         self.conv1 = Conv2dReLU(
99 |             in_channels + skip_channels,
100 |             out_channels,
101 |             kernel_size=3,
102 |             padding=1,
103 |             use_batchnorm=use_batchnorm,
104 |         )
105 |         self.conv2 = Conv2dReLU(
106 |             out_channels,
107 |             out_channels,
108 |             kernel_size=3,
109 |             padding=1,
110 |             use_batchnorm=use_batchnorm,
111 |         )
112 |         self.up = nn.UpsamplingBilinear2d(scale_factor=2)
113 |
114 |     def forward(self, x, skip=None):
115 |         # print(x.shape, skip.shape)
116 |         if skip is not None:
117 |             x = torch.cat([x, skip], dim=1)
118 |         x = self.up(x)
119 |         x = self.conv1(x)
120 |         x = self.conv2(x)
121 |         return x
122 |
123 |
124 | class SegmentationHead(nn.Sequential):
125 |
126 |     def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
127 |         conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
128 |         upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
129 |         super().__init__(conv2d, upsampling)
130 |
131 |
132 | class DecoderCup(nn.Module):
133 |     def __init__(self):
134 |         super().__init__()
135 |         # self.config = config
136 |         head_channels = 512
137 |         self.conv_more = Conv2dReLU(
138 |             1024,
139 |             head_channels,
140 |             kernel_size=3,
141 |             padding=1,
142 |             use_batchnorm=True,
143 |         )
144 |
145 |         decoder_channels = (256, 128, 64, 16)
146 |
147 |
148 |         in_channels = [head_channels] + list(decoder_channels[:-1])
149 |         out_channels = decoder_channels
150 |
151 |         # if self.config.n_skip != 0:
152 |         #     skip_channels = self.config.skip_channels
153 |         #     for i in range(4-self.config.n_skip):  # re-select the skip channels according to n_skip
154 |         #         skip_channels[3-i]=0
155 |         # else:
156 |         #     skip_channels=[0,0,0,0]
157 |         skip_channels = [512, 256, 128, 64]
158 |         self.conv_feature1 = Conv2dReLU(1024, skip_channels[0], kernel_size=3, padding=1, use_batchnorm=True)
159 |         self.conv_feature2 = Conv2dReLU(1024, skip_channels[1], kernel_size=3, padding=1, use_batchnorm=True)
160 |         self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
161 |         self.conv_feature3 = Conv2dReLU(1024, skip_channels[2], kernel_size=3, padding=1, use_batchnorm=True)
162 |         self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
163 |         self.conv_feature4 = Conv2dReLU(1024, skip_channels[3], kernel_size=3, padding=1, use_batchnorm=True)
164 |         self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)
165 |
166 |         # skip_channels=[128,64,32,8]
167 |         blocks = [
168 |             DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
169 |         ]
170 |         self.blocks = nn.ModuleList(blocks)
171 |
172 |     def TransShape(self, x, head_channels=512, up=0):
173 |         B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
174 |
175 |         h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
176 |         x = x.permute(0, 2, 1)
177 |         x = x.contiguous().view(B, hidden, h, w)
178 |         if up == 0:
179 |             x = self.conv_feature1(x)
180 |         elif up == 1:
181 |             x = self.conv_feature2(x)
182 |             x = self.up2(x)
183 |         elif up == 2:
184 |             x = self.conv_feature3(x)
185 |             x = self.up3(x)
186 |         elif up == 3:
187 |             x = self.conv_feature4(x)
188 |             x = self.up4(x)
189 |         return x
190 |
191 |     def forward(self, hidden_states, features=None):
192 |         B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
193 |         h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
194 |         x = hidden_states.permute(0, 2, 1)
195 |         x = x.contiguous().view(B, hidden, h, w)
196 |         x = self.conv_more(x)
197 |         skip_channels = [512, 256, 128, 64]
198 |         for i, decoder_block in enumerate(self.blocks):
199 |             if features is not None:
200 |                 skip = self.TransShape(features[i], head_channels=skip_channels[i], up=i)
201 |             else:
202 |                 skip = None
203 |             x = decoder_block(x, skip=skip)
204 |         return x
205 |
206 |
207 | class MLAHead(nn.Module):
208 |     def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None):
209 |         super(MLAHead, self).__init__()
210 |         self.head2 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
211 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
212 |                                    nn.Conv2d(
213 |                                        mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
214 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU())
215 |         self.head3 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
216 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
217 |                                    nn.Conv2d(
218 |                                        mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
219 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU())
220 |         self.head4 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
221 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
222 |                                    nn.Conv2d(
223 |                                        mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
224 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU())
225 |         self.head5 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
226 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
227 |                                    nn.Conv2d(
228 |                                        mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
229 |                                    nn.BatchNorm2d(mlahead_channels), nn.ReLU())
230 |
231 |     def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
232 |         head2 = F.interpolate(self.head2(
233 |             mla_p2), (4 * mla_p2.shape[-2], 4 * mla_p2.shape[-1]), mode='bilinear', align_corners=True)
234 |         head3 = F.interpolate(self.head3(
235 |             mla_p3), (4 * mla_p3.shape[-2], 4 * mla_p3.shape[-1]), mode='bilinear', align_corners=True)
236 |         head4 = F.interpolate(self.head4(
237 |             mla_p4), (4 * mla_p4.shape[-2], 4 * mla_p4.shape[-1]), mode='bilinear', align_corners=True)
238 |         head5 = F.interpolate(self.head5(
239 |             mla_p5), (4 * mla_p5.shape[-2], 4 * mla_p5.shape[-1]), mode='bilinear', align_corners=True)
240 |         return torch.cat([head2, head3, head4, head5], dim=1)
241 |
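# ---- Editor's note (not part of the original source) ----
# MLAHead refines each of the four tapped transformer feature maps with two
# conv-BN-ReLU stages, upsamples each by 4x, and concatenates them, so its
# output has 4 * mlahead_channels channels. VIT_MLAHead below consumes this
# concatenation with a final classification conv.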
242 |
243 | class VIT_MLAHead(nn.Module):
244 |     """ MLA-style decoder head: fuses four transformer feature maps and predicts per-pixel class logits
245 |     """
246 |
247 |     def __init__(self, img_size=768, mla_channels=256, mlahead_channels=128, num_classes=6,
248 |                  norm_layer=nn.BatchNorm2d, norm_cfg=None, **kwargs):
249 |         super(VIT_MLAHead, self).__init__(**kwargs)
250 |         self.img_size = img_size
251 |         self.norm_cfg = norm_cfg
252 |         self.mla_channels = mla_channels
253 |         self.BatchNorm = norm_layer
254 |         self.mlahead_channels = mlahead_channels
255 |         self.num_classes = num_classes
256 |         self.mlahead = MLAHead(mla_channels=self.mla_channels,
257 |                                mlahead_channels=self.mlahead_channels, norm_cfg=self.norm_cfg)
258 |         self.cls = nn.Conv2d(4 * self.mlahead_channels,
259 |                              self.num_classes, 3, padding=1)
260 |
261 |     def forward(self, x1, x2, x3, x4, h=14, w=14):
262 |         B, n_patch, hidden = x1.size()
263 |         if h == w:
264 |             h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
265 |         x1 = x1.permute(0, 2, 1)
266 |         x1 = x1.contiguous().view(B, hidden, h, w)
267 |         x2 = x2.permute(0, 2, 1)
268 |         x2 = x2.contiguous().view(B, hidden, h, w)
269 |         x3 = x3.permute(0, 2, 1)
270 |         x3 = x3.contiguous().view(B, hidden, h, w)
271 |         x4 = x4.permute(0, 2, 1)
272 |         x4 = x4.contiguous().view(B, hidden, h, w)
273 |         x = self.mlahead(x1, x2, x3, x4)
274 |         x = self.cls(x)
275 |         x = F.interpolate(x, size=(h * 16, w * 16), mode='bilinear',
276 |                           align_corners=True)
277 |         return x
278 |
279 |
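# ---- Editor's example (sketch; not part of the original source) ----
# End-to-end shape check for the segmentation model built by the factories
# below. The kwargs (img_size, in_chans, num_classes) come from the parent
# timm VisionTransformer constructor, not from this file, so treat them as
# assumptions about the installed timm version (requirements pin timm==0.3.2).
#
#   model = vit_base_patch16(img_size=224, in_chans=3, num_classes=6)
#   x = torch.randn(1, 3, 224, 224)
#   logits = model(x)   # 14x14 patch features, upsampled back to [1, 6, 224, 224]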
280 | def mae_vit_small_patch16(**kwargs):  # note: keeps the mae_ prefix, unlike the other ViT factories below
281 |     model = VisionTransformer(
282 |         patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
283 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
284 |     return model
285 |
286 | def vit_base_patch16(**kwargs):
287 |     model = VisionTransformer(
288 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
289 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
290 |     return model
291 |
292 |
293 | def vit_large_patch16(**kwargs):
294 |     model = VisionTransformer(
295 |         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
296 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
297 |     return model
298 |
299 |
300 | def vit_huge_patch14(**kwargs):
301 |     model = VisionTransformer(
302 |         patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
303 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
304 |     return model
305 |
306 |
307 |
--------------------------------------------------------------------------------