├── .gitignore ├── README.md ├── environment.yml ├── eval_utils.py ├── linear_eval_files ├── train_ssl_0.01_filelist.txt ├── train_ssl_0.1_filelist.txt └── val_ssl_filelist.txt ├── main_lincls.py ├── main_moco_files_dataset_strong_aug.py ├── metadata_files ├── im100_metadata.txt ├── imagenet100_classes.txt └── imagenet_metadata.txt ├── moco ├── __init__.py ├── builder.py ├── loader.py └── optimizer.py ├── patch_search_iterative_search.py ├── patch_search_poison_classifier.py ├── run_pipeline.sh ├── tools.py └── vits.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | output/ 132 | runs/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PatchSearch 2 | Code for the CVPR '23 paper, "Defending Against Patch-based Data Poisoning Backdoor Attacks on Self-Supervised Learning" 3 | 4 | ## Installation 5 | 6 | 1. Clone the repository with `git clone https://github.com/UCDvision/PatchSearch` 7 | 2. 
Navigate into the cloned repository with `cd PatchSearch`
8 | 3. Install the required packages with `conda env create -f environment.yml`
9 | 4. Activate the environment with `conda activate patch_search`
10 | 
11 | Note that we use a custom version of [pytorch-grad-cam](https://github.com/UCDvision/pytorch-grad-cam)
12 | since the original does not return the model output, which is required by our code. You can install it with `pip install git+https://github.com/UCDvision/pytorch-grad-cam`
13 | 
14 | ## Data
15 | 
16 | Use the code from the [SSL-Backdoor](https://github.com/UMBCvision/SSL-Backdoor#poison-generation) repository to generate clean and poisoned data.
17 | 
18 | ## Pre-training
19 | 
20 | Now, you're ready to pre-train your models on poisoned data.
21 | First, set a few variables pointing to the training data file lists, the training images, and the experiment output directory.
22 | Then, run `main_moco_files_dataset_strong_aug.py` as shown in the command below:
23 | 
24 | ```
25 | OUTPUT_DIR='== ROOT DIRECTORY TO SAVE EXPERIMENT RESULTS GOES HERE =='
26 | CODE_DIR='== ROOT DIRECTORY THAT CONTAINS TRAINING DATA FILE LISTS GOES HERE =='
27 | EXPERIMENT_ID='HTBA_trigger_10_targeted_n02106550'
28 | EXP_DIR=$OUTPUT_DIR/$EXPERIMENT_ID/moco
29 | RATE='0.50'
30 | SEED=4789
31 | 
32 | ### STEP 1.1: pretrain the model
33 | python main_moco_files_dataset_strong_aug.py \
34 |     --seed $SEED \
35 |     -a vit_base --epochs 200 -b 1024 \
36 |     --stop-grad-conv1 --moco-m-cos \
37 |     --multiprocessing-distributed --world-size 1 --rank 0 \
38 |     --dist-url "tcp://localhost:$(( $RANDOM % 50 + 10000 ))" \
39 |     --save_folder $EXP_DIR \
40 |     $CODE_DIR/poison-generation/data/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt
41 | 
42 | ```
43 | 
44 | Adjust the batch size and learning rate based on your available GPU memory.
45 | For training with `i-cutmix`, simply add the command line parameters `--icutmix --alpha 1.0` to the above pre-training command.
46 | 
47 | ## Linear Evaluation
48 | 
49 | The pre-trained checkpoints are evaluated by running the commands below. Note that they use some of the variables set above, plus `EVAL_DIR=$EXP_DIR/linear` as the linear evaluation output directory.
50 | 
51 | ```
52 | ### STEP 1.2: train a linear layer for evaluating the pretrained model
53 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
54 |     --seed $SEED \
55 |     -a vit_base --lr 0.1 \
56 |     --pretrained $EXP_DIR/checkpoint_0199.pth.tar \
57 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
58 |     --val_file linear_eval_files/val_ssl_filelist.txt \
59 |     --save_folder $EVAL_DIR
60 | 
61 | ### STEP 1.3: evaluate the trained linear layer with clean and poisoned val data
62 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
63 |     --seed $SEED \
64 |     -a vit_base --lr 0.1 \
65 |     --conf_matrix \
66 |     --resume $EVAL_DIR/checkpoint.pth.tar \
67 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
68 |     --val_file linear_eval_files/val_ssl_filelist.txt \
69 |     --val_poisoned_file $CODE_DIR/$EXPERIMENT_ID/val_poisoned/loc_random_*.txt \
70 |     --eval_id exp_${SEED}
71 | ```
72 | 
73 | ## Running Iterative Search in PatchSearch
74 | 
75 | Now, we can run the iterative search part of PatchSearch to find highly poisonous samples.
76 | Use the command below with some of the variables set during pre-training.
77 | The poisons will be stored in the directory `all_top_poison_patches` located inside the experiment output directory of the command below.
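As in the pre-training step, the `--train_file` argument of the next command is a file list: one `<image_path> <integer_label>` pair per line, which is exactly what the `FileListDataset` classes in this repository parse. A small illustrative example (these paths are placeholders, not real files):

```
/data/imagenet100/train/n02106550/n02106550_1234.JPEG 42
/data/imagenet100/train/n01558993/n01558993_5678.JPEG 7
```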
78 | 
79 | ```
80 | EVAL_DIR=$EXP_DIR/linear
81 | 
82 | ### STEP 2: run iterative search
83 | for i in {1..2}
84 | do
85 |     ### STEP 2.1: on the first pass, compute and cache the features, then exit
86 |     ### STEP 2.2: on the second pass, run the defense on the cached features
87 |     ### this is split into two steps because combining them slows down the defense
88 |     python patch_search_iterative_search.py \
89 |         --arch moco_vit_base \
90 |         --weights $EXP_DIR/checkpoint_0199.pth.tar \
91 |         --linear_weights $EVAL_DIR/checkpoint.pth.tar \
92 |         --train_file $CODE_DIR/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt \
93 |         --val_file linear_eval_files/val_ssl_filelist.txt \
94 |         --prune_clusters \
95 |         --use_cached_feats \
96 |         --use_cached_poison_scores
97 | done
98 | ```
99 | 
100 | ## Running Poison Classifier in PatchSearch
101 | 
102 | Once we have the list of highly poisonous patches, we can build a classifier to detect poisoned samples with the command below.
103 | The filtered training data is listed in `filtered.txt`.
104 | 
105 | ```
106 | DEFENSE_DIR=$EXP_DIR/patch_search_iterative_search_test_images_size_1000_window_w_60_repeat_patch_1_prune_clusters_True_num_clusters_1000_per_iteration_samples_2_remove_0x25
107 | 
108 | ### STEP 3: run poison classifier
109 | CUDA_VISIBLE_DEVICES=0 python patch_search_poison_classifier.py \
110 |     --print_freq 20 \
111 |     --model_count 5 \
112 |     --batch_size 32 \
113 |     --eval_freq 20 \
114 |     --max_iterations 2000 \
115 |     --workers 8 \
116 |     --seed ${SEED} \
117 |     --train_file $CODE_DIR/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt \
118 |     --poison_dir $DEFENSE_DIR/all_top_poison_patches \
119 |     --poison_scores $DEFENSE_DIR/poison-scores.npy \
120 |     --eval_data "seed_${SEED}" \
121 |     --topk_poisons 20
122 | ```
123 | 
124 | ## Post-Defense Pre-training and Evaluation
125 | 
126 | Finally, run the commands after the sketch below to pre-train on the cleaned-up data and evaluate the resulting model.
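For intuition, the `filtered.txt` produced by step 3 is just the original training file list with the samples flagged as poisonous removed (the actual script trains an ensemble of `--model_count` classifiers seeded with the `--topk_poisons` patches and removes what they flag). Below is a minimal sketch of that bookkeeping only, assuming the poison-score array holds one score per file-list entry; the function and file names here are hypothetical, not the script's real implementation:

```
import numpy as np

def write_filtered_list(train_file, poison_scores_file, out_file, num_to_drop):
    """Toy sketch: drop the highest-scoring (most suspicious) samples."""
    with open(train_file) as f:
        lines = [line.rstrip() for line in f if line.strip()]
    scores = np.load(poison_scores_file)  # assumed aligned with file-list order
    assert len(scores) == len(lines)
    # indices of the num_to_drop largest poison scores
    flagged = set(np.argsort(scores)[-num_to_drop:].tolist())
    with open(out_file, 'w') as f:
        for i, line in enumerate(lines):
            if i not in flagged:  # keep everything that was not flagged
                f.write(line + '\n')
```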
127 | 
128 | ```
129 | FILTERED_DIR=$DEFENSE_DIR/patch_search_poison_classifier_topk_20_ensemble_5_max_iterations_2000_seed_4789
130 | EXP_DIR=$FILTERED_DIR/moco
131 | EVAL_DIR=$EXP_DIR/linear
132 | 
133 | ### STEP 4.1: pretrain the model on the training set filtered with PatchSearch
134 | python main_moco_files_dataset_strong_aug.py \
135 |     --seed $SEED \
136 |     -a vit_base --epochs 200 -b 1024 \
137 |     --icutmix --alpha 1.0 \
138 |     --stop-grad-conv1 --moco-m-cos \
139 |     --multiprocessing-distributed --world-size 1 --rank 0 \
140 |     --dist-url "tcp://localhost:$(( $RANDOM % 50 + 10000 ))" \
141 |     --save_folder $EXP_DIR \
142 |     $FILTERED_DIR/filtered.txt
143 | 
144 | ### STEP 4.2: train a linear layer for evaluating the pretrained model
145 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
146 |     --seed $SEED \
147 |     -a vit_base --lr 0.1 \
148 |     --pretrained $EXP_DIR/checkpoint_0199.pth.tar \
149 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
150 |     --val_file linear_eval_files/val_ssl_filelist.txt \
151 |     --save_folder $EVAL_DIR
152 | 
153 | ### STEP 4.3: evaluate the trained linear layer with clean and poisoned val data
154 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
155 |     --seed $SEED \
156 |     -a vit_base --lr 0.1 \
157 |     --conf_matrix \
158 |     --resume $EVAL_DIR/checkpoint.pth.tar \
159 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
160 |     --val_file linear_eval_files/val_ssl_filelist.txt \
161 |     --val_poisoned_file $CODE_DIR/$EXPERIMENT_ID/val_poisoned/loc_random_*.txt \
162 |     --eval_id exp_${SEED}
163 | ```
164 | 
165 | ## Citation
166 | 
167 | Tejankar, Ajinkya, et al. "Defending Against Patch-based Backdoor Attacks on Self-Supervised Learning." _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition._ 2023.
168 | 169 | ``` 170 | @inproceedings{tejankar2023defending, 171 | title={Defending Against Patch-based Backdoor Attacks on Self-Supervised Learning}, 172 | author={Tejankar, Ajinkya and Sanjabi, Maziar and Wang, Qifan and Wang, Sinong and Firooz, Hamed and Pirsiavash, Hamed and Tan, Liang}, 173 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 174 | pages={12239--12249}, 175 | year={2023} 176 | } 177 | ``` 178 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: patch_search 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - absl-py=0.15.0=pyhd3eb1b0_0 10 | - aiohttp=3.6.3=py37h7b6447c_0 11 | - argon2-cffi=20.1.0=py37h27cfd23_1 12 | - async-timeout=3.0.1=py37h06a4308_0 13 | - attrs=21.4.0=pyhd3eb1b0_0 14 | - backcall=0.2.0=pyhd3eb1b0_0 15 | - beautifulsoup4=4.11.1=py37h06a4308_0 16 | - blas=1.0=mkl 17 | - bleach=4.1.0=pyhd3eb1b0_0 18 | - blinker=1.4=py37h06a4308_0 19 | - bottleneck=1.3.4=py37hce1f21e_0 20 | - brotli=1.0.9=he6710b0_2 21 | - brotlipy=0.7.0=py37h27cfd23_1003 22 | - bzip2=1.0.8=h7f98852_4 23 | - c-ares=1.18.1=h7f8727e_0 24 | - ca-certificates=2022.12.7=ha878542_0 25 | - cachetools=4.2.2=pyhd3eb1b0_0 26 | - certifi=2022.12.7=pyhd8ed1ab_0 27 | - cffi=1.15.0=py37hd667e15_1 28 | - chardet=3.0.4=py37h06a4308_1003 29 | - click=8.0.4=py37h06a4308_0 30 | - cryptography=3.4.8=py37hd23ed53_0 31 | - cudatoolkit=11.1.1=h6406543_10 32 | - cycler=0.11.0=pyhd3eb1b0_0 33 | - cython=0.29.28=py37hd23a5d3_2 34 | - dataclasses=0.8=pyh6d0b6a4_7 35 | - dbus=1.13.18=hb2f20db_0 36 | - debugpy=1.5.1=py37h295c915_0 37 | - decorator=5.1.1=pyhd3eb1b0_0 38 | - defusedxml=0.7.1=pyhd3eb1b0_0 39 | - entrypoints=0.4=py37h06a4308_0 40 | - expat=2.4.4=h295c915_0 41 | - faiss=1.7.2=py37cuda111h8d611f9_0_cuda 42 | - faiss-gpu=1.7.2=h788eb59_3 43 | - ffmpeg=4.3=hf484d3e_0 44 | - fontconfig=2.13.1=h6c09931_0 45 | - fonttools=4.25.0=pyhd3eb1b0_0 46 | - freetype=2.10.4=h0708190_1 47 | - gensim=4.1.2=py37h295c915_0 48 | - giflib=5.2.1=h7b6447c_0 49 | - glib=2.69.1=h4ff587b_1 50 | - gmp=6.2.1=h58526e2_0 51 | - gnutls=3.6.13=h85f3911_1 52 | - google-auth=2.6.0=pyhd3eb1b0_0 53 | - google-auth-oauthlib=0.4.1=py_2 54 | - grpcio=1.42.0=py37hce63b2e_0 55 | - gst-plugins-base=1.14.0=h8213a91_2 56 | - gstreamer=1.14.0=h28cd5cc_2 57 | - icu=58.2=he6710b0_3 58 | - idna=3.3=pyhd3eb1b0_0 59 | - importlib-metadata=4.11.3=py37h06a4308_0 60 | - importlib_metadata=4.11.3=hd3eb1b0_0 61 | - importlib_resources=5.2.0=pyhd3eb1b0_1 62 | - ipykernel=6.9.1=py37h06a4308_0 63 | - ipython=7.31.1=py37h06a4308_0 64 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 65 | - ipywidgets=7.6.5=pyhd3eb1b0_1 66 | - jedi=0.18.1=py37h06a4308_1 67 | - jinja2=3.0.3=pyhd3eb1b0_0 68 | - joblib=1.1.0=pyhd3eb1b0_0 69 | - jpeg=9b=h024ee3a_2 70 | - jsonschema=4.4.0=py37h06a4308_0 71 | - jupyter=1.0.0=py37_7 72 | - jupyter_client=7.2.2=py37h06a4308_0 73 | - jupyter_console=6.4.3=pyhd3eb1b0_0 74 | - jupyter_core=4.9.2=py37h06a4308_0 75 | - jupyterlab_pygments=0.1.2=py_0 76 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 77 | - kiwisolver=1.3.2=py37h295c915_0 78 | - lame=3.100=h7f98852_1001 79 | - lcms2=2.12=h3be6417_0 80 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 81 | - libblas=3.9.0=12_linux64_mkl 82 | - libfaiss=1.7.2=cuda111h7721031_0_cuda 83 | - libfaiss-avx2=1.7.2=cuda111h1234567_0_cuda 84 | - 
libffi=3.3=he6710b0_2 85 | - libgcc-ng=11.2.0=h1d223b6_16 86 | - libgfortran-ng=7.5.0=ha8ba4b0_17 87 | - libgfortran4=7.5.0=ha8ba4b0_17 88 | - libiconv=1.16=h516909a_0 89 | - liblapack=3.9.0=12_linux64_mkl 90 | - libnsl=2.0.0=h7f98852_0 91 | - libpng=1.6.37=h21135ba_2 92 | - libprotobuf=3.19.1=h4ff587b_0 93 | - libsodium=1.0.18=h7b6447c_0 94 | - libstdcxx-ng=11.2.0=he4da1e4_16 95 | - libtiff=4.2.0=h85742a9_0 96 | - libuuid=1.0.3=h7f8727e_2 97 | - libuv=1.43.0=h7f98852_0 98 | - libwebp=1.2.0=h89dd481_0 99 | - libwebp-base=1.2.0=h27cfd23_0 100 | - libxcb=1.14=h7b6447c_0 101 | - libxml2=2.9.12=h03d6c58_0 102 | - libzlib=1.2.11=h166bdaf_1014 103 | - llvm-openmp=14.0.3=he0ac6c6_0 104 | - lz4-c=1.9.3=h295c915_1 105 | - markdown=3.3.4=py37h06a4308_0 106 | - markupsafe=2.0.1=py37h27cfd23_0 107 | - matplotlib=3.4.3=py37h06a4308_0 108 | - matplotlib-base=3.4.3=py37hbbc1b5f_0 109 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 110 | - mistune=0.8.4=py37h14c3975_1001 111 | - mkl=2021.4.0=h8d4b97c_729 112 | - mkl-service=2.4.0=py37h402132d_0 113 | - mkl_fft=1.3.1=py37h3e078e5_1 114 | - mkl_random=1.2.2=py37h219a48f_0 115 | - multidict=4.7.6=py37h7b6447c_1 116 | - munkres=1.1.4=py_0 117 | - nbclient=0.5.13=py37h06a4308_0 118 | - nbconvert=6.4.4=py37h06a4308_0 119 | - nbformat=5.3.0=py37h06a4308_0 120 | - ncurses=6.3=h27087fc_1 121 | - nest-asyncio=1.5.5=py37h06a4308_0 122 | - nettle=3.6=he412f7d_0 123 | - ninja=1.10.2=h4bd325d_1 124 | - notebook=6.4.8=py37h06a4308_0 125 | - numexpr=2.8.1=py37h6abb31d_0 126 | - numpy=1.21.5=py37he7a7128_1 127 | - numpy-base=1.21.5=py37hf524024_1 128 | - oauthlib=3.2.0=pyhd3eb1b0_0 129 | - olefile=0.46=pyh9f0ad1d_1 130 | - openh264=2.1.1=h780b84a_0 131 | - openssl=1.1.1s=h7f8727e_0 132 | - packaging=21.3=pyhd3eb1b0_0 133 | - pandas=1.3.4=py37h8c16a72_0 134 | - pandocfilters=1.5.0=pyhd3eb1b0_0 135 | - parso=0.8.3=pyhd3eb1b0_0 136 | - pcre=8.45=h295c915_0 137 | - pexpect=4.8.0=pyhd3eb1b0_3 138 | - pickleshare=0.7.5=pyhd3eb1b0_1003 139 | - pillow=9.0.1=py37h22f2fdc_0 140 | - pip=22.0.4=pyhd8ed1ab_0 141 | - prometheus_client=0.13.1=pyhd3eb1b0_0 142 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 143 | - prompt_toolkit=3.0.20=hd3eb1b0_0 144 | - protobuf=3.19.1=py37h295c915_0 145 | - ptyprocess=0.7.0=pyhd3eb1b0_2 146 | - pyasn1=0.4.8=pyhd3eb1b0_0 147 | - pyasn1-modules=0.2.8=py_0 148 | - pycparser=2.21=pyhd3eb1b0_0 149 | - pygments=2.11.2=pyhd3eb1b0_0 150 | - pyjwt=2.1.0=py37h06a4308_0 151 | - pyopenssl=21.0.0=pyhd3eb1b0_1 152 | - pyparsing=3.0.4=pyhd3eb1b0_0 153 | - pyqt=5.9.2=py37h05f1152_2 154 | - pyrsistent=0.18.0=py37heee7806_0 155 | - pysocks=1.7.1=py37_1 156 | - python=3.7.13=h12debd9_0 157 | - python-dateutil=2.8.2=pyhd3eb1b0_0 158 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 159 | - python_abi=3.7=2_cp37m 160 | - pytorch=1.8.0=py3.7_cuda11.1_cudnn8.0.5_0 161 | - pytz=2021.3=pyhd3eb1b0_0 162 | - pyzmq=22.3.0=py37h295c915_2 163 | - qt=5.9.7=h5867ecd_1 164 | - qtconsole=5.3.0=pyhd3eb1b0_0 165 | - qtpy=2.0.1=pyhd3eb1b0_0 166 | - readline=8.1=h46c0cb4_0 167 | - requests=2.27.1=pyhd3eb1b0_0 168 | - requests-oauthlib=1.3.0=py_0 169 | - rsa=4.7.2=pyhd3eb1b0_1 170 | - scikit-learn=1.0.2=py37h51133e4_1 171 | - scipy=1.7.3=py37hc147768_0 172 | - send2trash=1.8.0=pyhd3eb1b0_1 173 | - setuptools=62.1.0=py37h89c1867_0 174 | - sip=4.19.8=py37hf484d3e_0 175 | - six=1.16.0=pyh6c4a22f_0 176 | - smart_open=6.0.0=pyhd8ed1ab_0 177 | - soupsieve=2.3.1=pyhd3eb1b0_0 178 | - sqlite=3.38.3=h4ff8645_0 179 | - tbb=2021.5.0=h924138e_1 180 | - tensorboard=2.6.0=py_1 181 | - 
tensorboard-data-server=0.6.0=py37hca6d32c_0 182 | - tensorboard-plugin-wit=1.6.0=py_0 183 | - terminado=0.13.1=py37h06a4308_0 184 | - testpath=0.5.0=pyhd3eb1b0_0 185 | - threadpoolctl=2.2.0=pyh0d69192_0 186 | - tk=8.6.12=h27826a3_0 187 | - torchaudio=0.8.0=py37 188 | - torchvision=0.9.0=py37_cu111 189 | - tornado=6.1=py37h27cfd23_0 190 | - tqdm=4.64.0=py37h06a4308_0 191 | - traitlets=5.1.1=pyhd3eb1b0_0 192 | - typing_extensions=4.2.0=pyha770c72_1 193 | - urllib3=1.26.9=py37h06a4308_0 194 | - wcwidth=0.2.5=pyhd3eb1b0_0 195 | - webencodings=0.5.1=py37_1 196 | - werkzeug=2.0.3=pyhd3eb1b0_0 197 | - wheel=0.37.1=pyhd8ed1ab_0 198 | - widgetsnbextension=3.5.2=py37h06a4308_0 199 | - xz=5.2.5=h516909a_1 200 | - yarl=1.6.3=py37h27cfd23_0 201 | - zeromq=4.3.4=h2531618_0 202 | - zipp=3.7.0=pyhd3eb1b0_0 203 | - zlib=1.2.11=h166bdaf_1014 204 | - zstd=1.4.9=haebb681_0 205 | - pip: 206 | - alembic==1.7.7 207 | - appdirs==1.4.4 208 | - astroid==2.11.4 209 | - autopage==0.5.0 210 | - blessings==1.7 211 | - charset-normalizer==2.0.12 212 | - cliff==3.10.1 213 | - cmaes==0.8.2 214 | - cmd2==2.4.1 215 | - colorlog==6.6.0 216 | - dill==0.3.4 217 | - filelock==3.9.0 218 | - future==0.18.2 219 | - glances==3.2.5 220 | - gpustat==0.6.0 221 | - grad-cam==1.4.6 222 | - greenlet==1.1.2 223 | - higher==0.2.1 224 | - huggingface-hub==0.13.1 225 | - imageio==2.19.1 226 | - isort==5.10.1 227 | - lazy-object-proxy==1.7.1 228 | - mako==1.2.0 229 | - mccabe==0.7.0 230 | - networkx==2.6.3 231 | - nvidia-ml-py3==7.352.0 232 | - opencv-python==4.5.5.64 233 | - optuna==2.10.0 234 | - pbr==5.9.0 235 | - platformdirs==2.5.2 236 | - prettytable==3.3.0 237 | - psutil==5.9.0 238 | - pyee==8.2.2 239 | - pylint==2.13.8 240 | - pyperclip==1.8.2 241 | - pyppeteer==1.0.2 242 | - pywavelets==1.3.0 243 | - pyyaml==6.0 244 | - scikit-image==0.19.2 245 | - sqlalchemy==1.4.36 246 | - stevedore==3.5.0 247 | - tifffile==2021.11.2 248 | - timm==0.4.9 249 | - tomli==2.0.1 250 | - ttach==0.0.3 251 | - typed-ast==1.5.3 252 | - websockets==10.3 253 | - wrapt==1.14.1 254 | prefix: /home/ajinkya/anaconda3/envs/pytorch1.8 255 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import logging 4 | import os 5 | 6 | import torch 7 | from torch import nn 8 | from torchvision import models 9 | 10 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False): 11 | logger = logging.getLogger() 12 | if debug: 13 | level = logging.DEBUG 14 | else: 15 | level = logging.INFO 16 | logger.setLevel(level) 17 | if saving: 18 | info_file_handler = logging.FileHandler(logpath, mode="a") 19 | info_file_handler.setLevel(level) 20 | logger.addHandler(info_file_handler) 21 | if displaying: 22 | console_handler = logging.StreamHandler() 23 | console_handler.setLevel(level) 24 | logger.addHandler(console_handler) 25 | logger.info(filepath) 26 | with open(filepath, "r") as f: 27 | logger.info(f.read()) 28 | 29 | for f in package_files: 30 | logger.info(f) 31 | with open(f, "r") as package_f: 32 | logger.info(package_f.read()) 33 | 34 | return logger 35 | 36 | class AverageMeter(object): 37 | """Computes and stores the average and current value""" 38 | def __init__(self, name, fmt=':f'): 39 | self.name = name 40 | self.fmt = fmt 41 | self.reset() 42 | 43 | def reset(self): 44 | self.val = 0 45 | self.avg = 0 46 | self.sum = 0 47 | self.count = 0 48 | 49 | def update(self, val, n=1): 50 | 
self.val = val 51 | self.sum += val * n 52 | self.count += n 53 | self.avg = self.sum / self.count 54 | 55 | def __str__(self): 56 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 57 | return fmtstr.format(**self.__dict__) 58 | 59 | 60 | class ProgressMeter(object): 61 | def __init__(self, num_batches, meters, prefix=""): 62 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 63 | self.meters = meters 64 | self.prefix = prefix 65 | 66 | def display(self, batch): 67 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 68 | entries += [str(meter) for meter in self.meters] 69 | return '\t'.join(entries) 70 | 71 | def _get_batch_fmtstr(self, num_batches): 72 | num_digits = len(str(num_batches // 1)) 73 | fmt = '{:' + str(num_digits) + 'd}' 74 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 75 | 76 | def accuracy(output, target, topk=(1,)): 77 | """Computes the accuracy over the k top predictions for the specified values of k""" 78 | with torch.no_grad(): 79 | maxk = max(topk) 80 | batch_size = target.size(0) 81 | 82 | _, pred = output.topk(maxk, 1, True, True) 83 | pred = pred.t() 84 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 85 | # pdb.set_trace() 86 | res = [] 87 | for k in topk: 88 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 89 | res.append(correct_k.mul_(100.0 / batch_size)) 90 | return res 91 | 92 | arch_to_key = { 93 | 'alexnet': 'alexnet', 94 | 'alexnet_moco': 'alexnet', 95 | 'resnet18': 'resnet18', 96 | 'resnet50': 'resnet50', 97 | 'rotnet_r50': 'resnet50', 98 | 'rotnet_r18': 'resnet18', 99 | 'moco_resnet18': 'resnet18', 100 | 'resnet_moco': 'resnet50', 101 | } 102 | 103 | model_names = list(arch_to_key.keys()) 104 | 105 | def save_checkpoint(state, is_best, save_dir): 106 | ckpt_path = os.path.join(save_dir, 'checkpoint.pth.tar') 107 | torch.save(state, ckpt_path) 108 | if is_best: 109 | best_ckpt_path = os.path.join(save_dir, 'model_best.pth.tar') 110 | shutil.copyfile(ckpt_path, best_ckpt_path) 111 | 112 | 113 | def makedirs(dirname): 114 | if not os.path.exists(dirname): 115 | os.makedirs(dirname) 116 | -------------------------------------------------------------------------------- /main_lincls.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | import re 10 | from pathlib import Path 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim 18 | import torch.multiprocessing as mp 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | import torchvision.datasets as datasets 23 | import torchvision.models as torchvision_models 24 | from torch.utils import data 25 | 26 | import vits 27 | from PIL import Image 28 | import numpy as np 29 | 30 | from eval_utils import get_logger 31 | 32 | torchvision_model_names = sorted(name for name in torchvision_models.__dict__ 33 | if name.islower() and not name.startswith("__") 34 | and callable(torchvision_models.__dict__[name])) 35 | 36 | model_names = ['vit_small', 'vit_base', 'vit_conv_small', 'vit_conv_base'] + torchvision_model_names 37 | 38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 39 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 40 | 
choices=model_names,
41 |                     help='model architecture: ' +
42 |                         ' | '.join(model_names) +
43 |                         ' (default: resnet50)')
44 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
45 |                     help='number of data loading workers (default: 16)')
46 | parser.add_argument('--epochs', default=30, type=int, metavar='N',
47 |                     help='number of total epochs to run')
48 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
49 |                     help='manual epoch number (useful on restarts)')
50 | parser.add_argument('-b', '--batch-size', default=256, type=int,
51 |                     metavar='N',
52 |                     help='mini-batch size (default: 256), this is the total '
53 |                          'batch size of all GPUs on all nodes when '
54 |                          'using Data Parallel or Distributed Data Parallel')
55 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
56 |                     metavar='LR', help='initial (base) learning rate', dest='lr')
57 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
58 |                     help='momentum')
59 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
60 |                     metavar='W', help='weight decay (default: 0.)',
61 |                     dest='weight_decay')
62 | parser.add_argument('-p', '--print-freq', default=10, type=int,
63 |                     metavar='N', help='print frequency (default: 10)')
64 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
65 |                     help='path to latest checkpoint (default: none)')
66 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
67 |                     help='evaluate model on validation set')
68 | parser.add_argument('--world-size', default=-1, type=int,
69 |                     help='number of nodes for distributed training')
70 | parser.add_argument('--rank', default=-1, type=int,
71 |                     help='node rank for distributed training')
72 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
73 |                     help='url used to set up distributed training')
74 | parser.add_argument('--dist-backend', default='nccl', type=str,
75 |                     help='distributed backend')
76 | parser.add_argument('--seed', default=None, type=int,
77 |                     help='seed for initializing training. ')
78 | parser.add_argument('--gpu', default=None, type=int,
79 |                     help='GPU id to use.')
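# For reference, a typical linear-probe invocation of this script (mirroring
# STEP 1.2 of the README; <EXP_DIR> and <EVAL_DIR> are placeholders):
#
#   CUDA_VISIBLE_DEVICES=0 python main_lincls.py --seed 4789 -a vit_base --lr 0.1 \
#       --pretrained <EXP_DIR>/checkpoint_0199.pth.tar \
#       --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
#       --val_file linear_eval_files/val_ssl_filelist.txt \
#       --save_folder <EVAL_DIR>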
parser.add_argument('--multiprocessing-distributed', action='store_true',
81 |                     help='Use multi-processing distributed training to launch '
82 |                          'N processes per node, which has N GPUs. This is the '
83 |                          'fastest way to use PyTorch for either single node or '
84 |                          'multi node data parallel training')
85 | 
86 | # additional configs:
87 | parser.add_argument('--pretrained', default='', type=str,
88 |                     help='path to moco pretrained checkpoint')
89 | 
90 | parser.add_argument('--conf_matrix', action='store_true',
91 |                     help='create confusion matrix')
92 | parser.add_argument('--train_file', type=str, required=False,
93 |                     help='file containing training image paths')
94 | parser.add_argument('--val_file', type=str, required=True,
95 |                     help='file containing validation image paths')
96 | parser.add_argument('--val_poisoned_file', type=str, required=False,
97 |                     help='file containing poisoned validation image paths')
98 | parser.add_argument('--eval_id', default='1', type=str)
99 | parser.add_argument('--nb_classes', default=100, type=int,
100 |                     help='number of the classification types')
101 | parser.add_argument('--save_folder', default='./save_folder',
102 |                     help='path where to save, empty for no saving')
103 | parser.add_argument('--debug', action='store_true',
104 |                     help='debug mode')
105 | 
106 | best_acc1 = 0
107 | 
108 | 
109 | class FileListDataset(data.Dataset):
110 |     def __init__(self, path_to_txt_file, transform):
111 |         # self.data_root = data_root
112 |         with open(path_to_txt_file, 'r') as f:
113 |             self.file_list = f.readlines()
114 |         self.file_list = [row.rstrip() for row in self.file_list]
115 | 
116 |         self.transform = transform
117 | 
118 | 
119 |     def __getitem__(self, idx):
120 |         image_path = self.file_list[idx].split()[0]
121 |         img = Image.open(image_path).convert('RGB')
122 |         target = int(self.file_list[idx].split()[1])
123 | 
124 |         if self.transform is not None:
125 |             images = self.transform(img)
126 | 
127 |         return image_path, images, target, idx
128 | 
129 |     def __len__(self):
130 |         return len(self.file_list)
131 | 
132 | 
133 | def main():
134 |     args = parser.parse_args()
135 | 
136 |     if args.resume:
137 |         args.save_folder = '/'.join(args.resume.split('/')[:-1] + [f'eval_{args.eval_id}'])
138 |     if args.val_poisoned_file:
139 |         args.target_wnid = re.search(r'HTBA_trigger_\d+_targeted_(n\d+)', args.val_poisoned_file).groups()[0]
140 |     if args.save_folder:
141 |         Path(args.save_folder).mkdir(parents=True, exist_ok=True)
142 | 
143 |     if args.seed is not None:
144 |         random.seed(args.seed)
145 |         np.random.seed(args.seed)
146 |         torch.manual_seed(args.seed)
147 |         cudnn.deterministic = True
148 |         warnings.warn('You have chosen to seed training. '
149 |                       'This will turn on the CUDNN deterministic setting, '
150 |                       'which can slow down your training considerably! '
151 |                       'You may see unexpected behavior when restarting '
152 |                       'from checkpoints.')
153 | 
154 |     if args.gpu is not None:
155 |         warnings.warn('You have chosen a specific GPU. 
This will completely ' 156 | 'disable data parallelism.') 157 | 158 | if args.dist_url == "env://" and args.world_size == -1: 159 | args.world_size = int(os.environ["WORLD_SIZE"]) 160 | 161 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 162 | 163 | ngpus_per_node = torch.cuda.device_count() 164 | if args.multiprocessing_distributed: 165 | # Since we have ngpus_per_node processes per node, the total world_size 166 | # needs to be adjusted accordingly 167 | args.world_size = ngpus_per_node * args.world_size 168 | # Use torch.multiprocessing.spawn to launch distributed processes: the 169 | # main_worker process function 170 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 171 | else: 172 | # Simply call main_worker function 173 | main_worker(args.gpu, ngpus_per_node, args) 174 | 175 | 176 | def main_worker(gpu, ngpus_per_node, args): 177 | global best_acc1 178 | args.gpu = gpu 179 | 180 | # suppress printing if not master 181 | if args.multiprocessing_distributed and args.gpu != 0: 182 | def print_pass(*args): 183 | pass 184 | builtins.print = print_pass 185 | else: 186 | if not args.debug: 187 | os.environ['PYTHONBREAKPOINT'] = '0' 188 | logger = get_logger( 189 | logpath=os.path.join(args.save_folder, 'logs'), 190 | filepath=os.path.abspath(__file__) 191 | ) 192 | def print_pass(*arg): 193 | logger.info(*arg) 194 | builtins.print = print_pass 195 | 196 | print('==> training parameters <==') 197 | for arg in vars(args): 198 | print(f'==> {arg}: {getattr(args, arg)}') 199 | print('===========================') 200 | 201 | if args.gpu is not None: 202 | print("Use GPU: {} for training".format(args.gpu)) 203 | 204 | if args.distributed: 205 | if args.dist_url == "env://" and args.rank == -1: 206 | args.rank = int(os.environ["RANK"]) 207 | if args.multiprocessing_distributed: 208 | # For multiprocessing distributed training, rank needs to be the 209 | # global rank among all the processes 210 | args.rank = args.rank * ngpus_per_node + gpu 211 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 212 | world_size=args.world_size, rank=args.rank) 213 | torch.distributed.barrier() 214 | # create model 215 | print("=> creating model '{}'".format(args.arch)) 216 | if args.arch.startswith('vit'): 217 | model = vits.__dict__[args.arch](num_classes=args.nb_classes) 218 | linear_keyword = 'head' 219 | else: 220 | model = torchvision_models.__dict__[args.arch](num_classes=args.nb_classes) 221 | linear_keyword = 'fc' 222 | 223 | breakpoint() 224 | # freeze all layers but the last fc 225 | for name, param in model.named_parameters(): 226 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]: 227 | param.requires_grad = False 228 | # init the fc layer 229 | getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01) 230 | getattr(model, linear_keyword).bias.data.zero_() 231 | 232 | # load from pre-trained, before DistributedDataParallel constructor 233 | if args.pretrained: 234 | if os.path.isfile(args.pretrained): 235 | print("=> loading checkpoint '{}'".format(args.pretrained)) 236 | checkpoint = torch.load(args.pretrained, map_location="cpu") 237 | 238 | # rename moco pre-trained keys 239 | state_dict = checkpoint['state_dict'] 240 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 241 | for k in list(state_dict.keys()): 242 | # retain only base_encoder up to before the embedding layer 243 | if k.startswith('base_encoder') and not k.startswith('base_encoder.%s' % 
linear_keyword): 244 | # remove prefix 245 | state_dict[k[len("base_encoder."):]] = state_dict[k] 246 | elif k.startswith('encoder_q') and not k.startswith('encoder_q.%s' % linear_keyword): 247 | # remove prefix 248 | state_dict[k[len("encoder_q."):]] = state_dict[k] 249 | # delete renamed or unused k 250 | del state_dict[k] 251 | 252 | args.start_epoch = 0 253 | msg = model.load_state_dict(state_dict, strict=False) 254 | assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} 255 | 256 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 257 | else: 258 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 259 | 260 | # infer learning rate before changing batch size 261 | init_lr = args.lr * args.batch_size / 256 262 | 263 | if not torch.cuda.is_available(): 264 | print('using CPU, this will be slow') 265 | elif args.distributed: 266 | # For multiprocessing distributed, DistributedDataParallel constructor 267 | # should always set the single device scope, otherwise, 268 | # DistributedDataParallel will use all available devices. 269 | if args.gpu is not None: 270 | torch.cuda.set_device(args.gpu) 271 | model.cuda(args.gpu) 272 | # When using a single GPU per process and per 273 | # DistributedDataParallel, we need to divide the batch size 274 | # ourselves based on the total number of GPUs we have 275 | args.batch_size = int(args.batch_size / args.world_size) 276 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 277 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 278 | else: 279 | model.cuda() 280 | # DistributedDataParallel will divide and allocate batch_size to all 281 | # available GPUs if device_ids are not set 282 | model = torch.nn.parallel.DistributedDataParallel(model) 283 | elif args.gpu is not None: 284 | torch.cuda.set_device(args.gpu) 285 | model = model.cuda(args.gpu) 286 | else: 287 | # DataParallel will divide and allocate batch_size to all available GPUs 288 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 289 | model.features = torch.nn.DataParallel(model.features) 290 | model.cuda() 291 | else: 292 | model = torch.nn.DataParallel(model).cuda() 293 | 294 | # define loss function (criterion) and optimizer 295 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 296 | 297 | # optimize only the linear classifier 298 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 299 | assert len(parameters) == 2 # weight, bias 300 | 301 | optimizer = torch.optim.SGD(parameters, init_lr, 302 | momentum=args.momentum, 303 | weight_decay=args.weight_decay) 304 | 305 | # optionally resume from a checkpoint 306 | if args.resume: 307 | if os.path.isfile(args.resume): 308 | print("=> loading checkpoint '{}'".format(args.resume)) 309 | if args.gpu is None: 310 | checkpoint = torch.load(args.resume) 311 | else: 312 | # Map model to be loaded to specified single gpu. 
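# (loading with an explicit map_location avoids restoring tensors onto
# whichever GPU originally wrote the checkpoint, which may not exist here)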
313 | loc = 'cuda:{}'.format(args.gpu) 314 | checkpoint = torch.load(args.resume, map_location=loc) 315 | args.start_epoch = checkpoint['epoch'] 316 | best_acc1 = checkpoint['best_acc1'] 317 | if args.gpu is not None: 318 | # best_acc1 may be from a checkpoint from a different GPU 319 | best_acc1 = best_acc1.to(args.gpu) 320 | model.load_state_dict(checkpoint['state_dict']) 321 | optimizer.load_state_dict(checkpoint['optimizer']) 322 | print("=> loaded checkpoint '{}' (epoch {})" 323 | .format(args.resume, checkpoint['epoch'])) 324 | else: 325 | print("=> no checkpoint found at '{}'".format(args.resume)) 326 | 327 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 328 | std=[0.229, 0.224, 0.225]) 329 | 330 | train_dataset = FileListDataset(args.train_file, transforms.Compose([ 331 | transforms.RandomResizedCrop(224), 332 | transforms.RandomHorizontalFlip(), 333 | transforms.ToTensor(), 334 | normalize, 335 | ])) 336 | 337 | val_dataset = FileListDataset(args.val_file, transforms.Compose([ 338 | transforms.Resize(256), 339 | transforms.CenterCrop(224), 340 | transforms.ToTensor(), 341 | normalize, 342 | ])) 343 | 344 | if args.conf_matrix: 345 | val_poisoned_dataset = FileListDataset(args.val_poisoned_file, transforms.Compose([ 346 | transforms.ToTensor(), 347 | normalize 348 | ])) 349 | 350 | val_poisoned_loader = torch.utils.data.DataLoader( 351 | val_poisoned_dataset, 352 | batch_size=256, shuffle=False, 353 | num_workers=args.workers, pin_memory=True) 354 | 355 | if args.distributed: 356 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 357 | else: 358 | train_sampler = None 359 | 360 | train_loader = torch.utils.data.DataLoader( 361 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 362 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 363 | 364 | val_loader = torch.utils.data.DataLoader( 365 | val_dataset, 366 | batch_size=256, shuffle=False, 367 | num_workers=args.workers, pin_memory=True) 368 | 369 | if args.evaluate: 370 | validate(val_loader, model, criterion, args) 371 | return 372 | 373 | if args.conf_matrix: 374 | # load imagenet metadata 375 | with open("metadata_files/imagenet_metadata.txt","r") as f: 376 | data = [l.strip() for l in f.readlines()] 377 | imagenet_metadata_dict = {} 378 | for line in data: 379 | wnid, classname = line.split('\t')[0], line.split('\t')[1] 380 | imagenet_metadata_dict[wnid] = classname 381 | 382 | with open(f'metadata_files/imagenet{args.nb_classes}_classes.txt', 'r') as f: 383 | class_dir_list = [l.strip() for l in f.readlines()] 384 | class_dir_list = sorted(class_dir_list) 385 | 386 | clean_acc1, conf_matrix_clean = validate(val_loader, model, criterion, args) 387 | poisoned_acc1, conf_matrix_poisoned = validate(val_poisoned_loader, model, criterion, args) 388 | 389 | np.save("{}/conf_matrix_clean.npy".format(args.save_folder), conf_matrix_clean) 390 | np.save("{}/conf_matrix_poisoned.npy".format(args.save_folder), conf_matrix_poisoned) 391 | 392 | with open("{}/conf_matrix.csv".format(args.save_folder), "w") as f: 393 | f.write("Model {},,Clean val,,,,Pois. 
val,,\n".format(os.path.join(os.path.dirname(args.resume).split("/")[-3], 394 | os.path.dirname(args.resume).split("/")[-2], 395 | os.path.dirname(args.resume).split("/")[-1], 396 | os.path.basename(args.resume)).replace(",",";"))) 397 | f.write("Data {},,acc1,,,,acc1,,\n".format(args.val_poisoned_file)) 398 | f.write(",,{:.2f},,,,{:.2f},,\n".format(clean_acc1, poisoned_acc1)) 399 | f.write("class name,class id,TP,FP,,TP,FP\n") 400 | 401 | clean_val_info = {'MAX_FP':0} 402 | poisoned_val_info = {'MAX_FP':0} 403 | for target in range(args.nb_classes): 404 | f.write("{},{},{},{},,".format(imagenet_metadata_dict[class_dir_list[target]].replace(",",";"), target, conf_matrix_clean[target][target], conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target])) 405 | f.write("{},{}\n".format(conf_matrix_poisoned[target][target], conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target])) 406 | 407 | if (conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target]) > clean_val_info['MAX_FP']: 408 | clean_val_info['MAX_FP'] = (conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target]) 409 | if (conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target]) > poisoned_val_info['MAX_FP']: 410 | poisoned_val_info['MAX_FP'] = (conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target]) 411 | # print results for target class 412 | if args.target_wnid == class_dir_list[target]: 413 | clean_val_info['WNID'] = class_dir_list[target] 414 | clean_val_info['CLASSNAME'] = imagenet_metadata_dict[class_dir_list[target]] 415 | clean_val_info['TP'] = conf_matrix_clean[target][target] 416 | clean_val_info['FP'] = conf_matrix_clean[:, target].sum() - conf_matrix_clean[target][target] 417 | 418 | poisoned_val_info['WNID'] = class_dir_list[target] 419 | poisoned_val_info['CLASSNAME'] = imagenet_metadata_dict[class_dir_list[target]] 420 | poisoned_val_info['TP'] = conf_matrix_poisoned[target][target] 421 | poisoned_val_info['FP'] = conf_matrix_poisoned[:, target].sum() - conf_matrix_poisoned[target][target] 422 | 423 | clean_val_info['NFP'] = clean_val_info['FP']/clean_val_info['MAX_FP'] 424 | poisoned_val_info['NFP'] = poisoned_val_info['FP']/poisoned_val_info['MAX_FP'] 425 | 426 | parse_csv("{}/conf_matrix.csv".format(args.save_folder)) 427 | 428 | print("\n\n") 429 | print("Clean val: {} {}".format(clean_val_info['WNID'], clean_val_info['CLASSNAME'])) 430 | print("{:.2f} {:d} {:d} {:.2f}".format(clean_acc1, int(clean_val_info['TP']), int(clean_val_info['FP']), clean_val_info['NFP'])) 431 | print("Poisoned val: {} {}".format(poisoned_val_info['WNID'], poisoned_val_info['CLASSNAME'])) 432 | print("{:.2f} {:d} {:d} {:.2f}".format(poisoned_acc1, int(poisoned_val_info['TP']), int(poisoned_val_info['FP']), poisoned_val_info['NFP'])) 433 | print("\n\n") 434 | return 435 | 436 | 437 | for epoch in range(args.start_epoch, args.epochs): 438 | if args.distributed: 439 | train_sampler.set_epoch(epoch) 440 | adjust_learning_rate(optimizer, init_lr, epoch, args) 441 | 442 | # train for one epoch 443 | train(train_loader, model, criterion, optimizer, epoch, args) 444 | 445 | # evaluate on validation set 446 | acc1, _ = validate(val_loader, model, criterion, args) 447 | 448 | # remember best acc@1 and save checkpoint 449 | is_best = acc1 > best_acc1 450 | best_acc1 = max(acc1, best_acc1) 451 | 452 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 453 | and args.rank == 0): # only the first GPU saves checkpoint 454 | 
ckpt_path = os.path.join(args.save_folder, 'checkpoint.pth.tar') 455 | save_checkpoint({ 456 | 'epoch': epoch + 1, 457 | 'arch': args.arch, 458 | 'state_dict': model.state_dict(), 459 | 'best_acc1': best_acc1, 460 | 'optimizer' : optimizer.state_dict(), 461 | }, is_best, ckpt_path) 462 | if epoch == args.start_epoch: 463 | sanity_check(model.state_dict(), args.pretrained, linear_keyword) 464 | 465 | 466 | def parse_csv(file_path): 467 | with open(file_path, 'r') as f: 468 | lines = f.readlines() 469 | lines = [line.strip().split(',') for line in lines] 470 | for line in lines: 471 | print(f'{line[0][:50]:50s} {line[1][:10]:10s} {line[2]:8s} {line[3]:8s} {line[5]:8s} {line[6]:8s}') 472 | 473 | 474 | def train(train_loader, model, criterion, optimizer, epoch, args): 475 | batch_time = AverageMeter('Time', ':6.3f') 476 | data_time = AverageMeter('Data', ':6.3f') 477 | losses = AverageMeter('Loss', ':.4e') 478 | top1 = AverageMeter('Acc@1', ':6.2f') 479 | top5 = AverageMeter('Acc@5', ':6.2f') 480 | progress = ProgressMeter( 481 | len(train_loader), 482 | [batch_time, data_time, losses, top1, top5], 483 | prefix="Epoch: [{}]".format(epoch)) 484 | 485 | """ 486 | Switch to eval mode: 487 | Under the protocol of linear classification on frozen features/models, 488 | it is not legitimate to change any part of the pre-trained model. 489 | BatchNorm in train mode may revise running mean/std (even if it receives 490 | no gradient), which are part of the model parameters too. 491 | """ 492 | model.eval() 493 | 494 | end = time.time() 495 | for i, (_, images, target, _) in enumerate(train_loader): 496 | # measure data loading time 497 | data_time.update(time.time() - end) 498 | 499 | if args.gpu is not None: 500 | images = images.cuda(args.gpu, non_blocking=True) 501 | if torch.cuda.is_available(): 502 | target = target.cuda(args.gpu, non_blocking=True) 503 | 504 | # compute output 505 | output = model(images) 506 | loss = criterion(output, target) 507 | 508 | # measure accuracy and record loss 509 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 510 | losses.update(loss.item(), images.size(0)) 511 | top1.update(acc1[0], images.size(0)) 512 | top5.update(acc5[0], images.size(0)) 513 | 514 | # compute gradient and do SGD step 515 | optimizer.zero_grad() 516 | loss.backward() 517 | optimizer.step() 518 | 519 | # measure elapsed time 520 | batch_time.update(time.time() - end) 521 | end = time.time() 522 | 523 | if i % args.print_freq == 0: 524 | progress.display(i) 525 | 526 | 527 | def validate(val_loader, model, criterion, args): 528 | batch_time = AverageMeter('Time', ':6.3f') 529 | losses = AverageMeter('Loss', ':.4e') 530 | top1 = AverageMeter('Acc@1', ':6.2f') 531 | top5 = AverageMeter('Acc@5', ':6.2f') 532 | progress = ProgressMeter( 533 | len(val_loader), 534 | [batch_time, losses, top1, top5], 535 | prefix='Test: ') 536 | 537 | # switch to evaluate mode 538 | model.eval() 539 | 540 | conf_matrix = np.zeros((args.nb_classes, args.nb_classes)) 541 | 542 | with torch.no_grad(): 543 | end = time.time() 544 | for i, (_, images, target, _) in enumerate(val_loader): 545 | if args.gpu is not None: 546 | images = images.cuda(args.gpu, non_blocking=True) 547 | if torch.cuda.is_available(): 548 | target = target.cuda(args.gpu, non_blocking=True) 549 | 550 | # compute output 551 | output = model(images) 552 | loss = criterion(output, target) 553 | 554 | # measure accuracy and record loss 555 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 556 | losses.update(loss.item(), images.size(0)) 557 | 
top1.update(acc1[0], images.size(0)) 558 | top5.update(acc5[0], images.size(0)) 559 | 560 | # update confusion matrix 561 | _, pred = output.topk(1, 1, True, True) 562 | for y, yp in zip(target.cpu().numpy(), pred.cpu().numpy()): 563 | conf_matrix[int(y), int(yp)] += 1 564 | 565 | # measure elapsed time 566 | batch_time.update(time.time() - end) 567 | end = time.time() 568 | 569 | if i % args.print_freq == 0: 570 | progress.display(i) 571 | 572 | # TODO: this should also be done with the ProgressMeter 573 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 574 | .format(top1=top1, top5=top5)) 575 | 576 | return top1.avg, conf_matrix 577 | 578 | 579 | def save_checkpoint(state, is_best, ckpt_path): 580 | while True: 581 | try: 582 | print(f'=======> attempt model saving at: {ckpt_path}') 583 | torch.save(state, ckpt_path) 584 | break 585 | except: 586 | print(f'=======> model saving failed at: {ckpt_path}') 587 | print(f'=======> delete the old file') 588 | os.remove(ckpt_path) 589 | print(f'=======> sleep for 2 seconds') 590 | time.sleep(2) 591 | print(f'=======> retry') 592 | 593 | if is_best: 594 | to_ckpt_path = ckpt_path.replace('checkpoint.pth.tar', 'model_best.pth.tar') 595 | while True: 596 | try: 597 | print(f'=======> copy best model: {to_ckpt_path}') 598 | shutil.copyfile(ckpt_path, to_ckpt_path) 599 | break 600 | except: 601 | print(f'=======> model copying failed at: {to_ckpt_path}') 602 | print(f'=======> delete the old file') 603 | os.remove(to_ckpt_path) 604 | print(f'=======> sleep for 2 seconds') 605 | time.sleep(2) 606 | print(f'=======> retry') 607 | 608 | 609 | def sanity_check(state_dict, pretrained_weights, linear_keyword): 610 | """ 611 | Linear classifier should not change any weights other than the linear layer. 612 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 
613 |     """
614 |     print("=> loading '{}' for sanity check".format(pretrained_weights))
615 |     checkpoint = torch.load(pretrained_weights, map_location="cpu")
616 |     sd_pre = checkpoint['state_dict']
617 |     sd_pre = {k.replace('module.', ''): v for k, v in sd_pre.items()}
618 |     sd_pre = {k: v for k, v in sd_pre.items() if 'encoder_q' in k or 'base_encoder' in k}
619 |     sd_pre = {k.replace('encoder_q.', ''): v for k, v in sd_pre.items()}
620 |     sd_pre = {k.replace('base_encoder.', ''): v for k, v in sd_pre.items()}
621 | 
622 |     for k in list(state_dict.keys()):
623 |         # only ignore linear layer
624 |         if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k:
625 |             continue
626 | 
627 |         # name in pretrained model
628 |         k_pre = k[len('module.'):] if k.startswith('module.') else k
629 | 
630 |         assert ((state_dict[k].cpu() == sd_pre[k_pre]).all()), \
631 |             '{} is changed in linear classifier training.'.format(k)
632 | 
633 |     print("=> sanity check passed.")
634 | 
635 | 
636 | class AverageMeter(object):
637 |     """Computes and stores the average and current value"""
638 |     def __init__(self, name, fmt=':f'):
639 |         self.name = name
640 |         self.fmt = fmt
641 |         self.reset()
642 | 
643 |     def reset(self):
644 |         self.val = 0
645 |         self.avg = 0
646 |         self.sum = 0
647 |         self.count = 0
648 | 
649 |     def update(self, val, n=1):
650 |         self.val = val
651 |         self.sum += val * n
652 |         self.count += n
653 |         self.avg = self.sum / self.count
654 | 
655 |     def __str__(self):
656 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
657 |         return fmtstr.format(**self.__dict__)
658 | 
659 | 
660 | class ProgressMeter(object):
661 |     def __init__(self, num_batches, meters, prefix=""):
662 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
663 |         self.meters = meters
664 |         self.prefix = prefix
665 | 
666 |     def display(self, batch):
667 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
668 |         entries += [str(meter) for meter in self.meters]
669 |         print('\t'.join(entries))
670 | 
671 |     def _get_batch_fmtstr(self, num_batches):
672 |         num_digits = len(str(num_batches // 1))
673 |         fmt = '{:' + str(num_digits) + 'd}'
674 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
675 | 
676 | 
677 | def adjust_learning_rate(optimizer, init_lr, epoch, args):
678 |     """Decay the learning rate based on schedule"""
679 |     cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
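    # Half-cycle cosine decay: epoch 0 gives cur_lr = init_lr (cos(0) = 1),
    # the midpoint epoch gives init_lr / 2, and cur_lr approaches 0 as
    # epoch -> args.epochs. For example, init_lr = 0.1 over 30 epochs yields
    # 0.1 * 0.5 * (1 + cos(pi * 15 / 30)) = 0.05 at epoch 15.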
680 |     for param_group in optimizer.param_groups:
681 |         param_group['lr'] = cur_lr
682 | 
683 | 
684 | def accuracy(output, target, topk=(1,)):
685 |     """Computes the accuracy over the k top predictions for the specified values of k"""
686 |     with torch.no_grad():
687 |         maxk = max(topk)
688 |         batch_size = target.size(0)
689 | 
690 |         _, pred = output.topk(maxk, 1, True, True)
691 |         pred = pred.t()
692 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
693 | 
694 |         res = []
695 |         for k in topk:
696 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
697 |             res.append(correct_k.mul_(100.0 / batch_size))
698 |         return res
699 | 
700 | 
701 | if __name__ == '__main__':
702 |     main()
703 | 
--------------------------------------------------------------------------------
/main_moco_files_dataset_strong_aug.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import builtins
3 | import math
4 | import os
5 | import random
6 | import shutil
7 | import time
8 | import warnings
9 | from functools import partial
10 | import logging
11 | from pathlib import Path
12 | 
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | import torchvision.transforms as transforms
23 | import torchvision.datasets as datasets
24 | import torchvision.models as torchvision_models
25 | from torch.utils.tensorboard import SummaryWriter
26 | 
27 | from PIL import Image
28 | import numpy as np
29 | from timm.data.random_erasing import RandomErasing
30 | 
31 | import moco.builder
32 | import moco.loader
33 | import moco.optimizer
34 | 
35 | import vits
36 | from tools import get_logger
37 | 
38 | 
39 | torchvision_model_names = sorted(name for name in torchvision_models.__dict__
40 |                                  if name.islower() and not name.startswith("__")
41 |                                  and callable(torchvision_models.__dict__[name]))
42 | 
43 | model_names = ['vit_small', 'vit_base', 'vit_conv_small', 'vit_conv_base'] + torchvision_model_names
44 | 
45 | parser = argparse.ArgumentParser(description='MoCo ImageNet Pre-Training')
46 | parser.add_argument('data', type=str,
47 |                     help='file containing dataset paths')
48 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
49 |                     choices=model_names,
50 |                     help='model architecture: ' +
51 |                         ' | '.join(model_names) +
52 |                         ' (default: resnet50)')
53 | parser.add_argument('-j', '--workers', default=80, type=int, metavar='N',
54 |                     help='number of data loading workers (default: 80)')
55 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
56 |                     help='number of total epochs to run')
57 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
58 |                     help='manual epoch number (useful on restarts)')
59 | parser.add_argument('-b', '--batch-size', default=1024, type=int,
60 |                     metavar='N',
61 |                     help='mini-batch size (default: 1024), this is the total '
62 |                          'batch size of all GPUs on all nodes when '
63 |                          'using Data Parallel or Distributed Data Parallel')
64 | parser.add_argument('--lr', '--learning-rate', default=1.5e-4, type=float,
65 |                     metavar='LR', help='initial (base) learning rate', dest='lr')
66 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67 |                     help='momentum')
68 | parser.add_argument('--wd', '--weight-decay', default=.1, type=float,
69 |                     metavar='W', help='weight decay (default: 0.1)',
70 |                     dest='weight_decay')
71 | parser.add_argument('-p', '--print-freq', default=10, type=int,
72 |                     metavar='N', help='print frequency (default: 10)')
73 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
74 |                     help='path to latest checkpoint (default: none)')
75 | parser.add_argument('--world-size', default=-1, type=int,
76 |                     help='number of nodes for distributed training')
77 | parser.add_argument('--rank', default=-1, type=int,
78 |                     help='node rank for distributed training')
79 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
80 |                     help='url used to set up distributed training')
81 | parser.add_argument('--dist-backend', default='nccl', type=str,
82 |                     help='distributed backend')
83 | parser.add_argument('--seed', default=None, type=int,
84 |                     help='seed for initializing training. ')
85 | parser.add_argument('--gpu', default=None, type=int,
86 |                     help='GPU id to use.')
87 | parser.add_argument('--multiprocessing-distributed', action='store_true',
88 |                     help='Use multi-processing distributed training to launch '
89 |                          'N processes per node, which has N GPUs. This is the '
90 |                          'fastest way to use PyTorch for either single node or '
91 |                          'multi node data parallel training')
92 | parser.add_argument('--icutmix', action='store_true',
93 |                     help='apply input cut mix augmentation')
94 | parser.add_argument('--alpha', default=1.0, type=float,
95 |                     help='parameter alpha used in beta distribution')
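# i-cutmix usage (matches the README pre-training command): add `--icutmix --alpha 1.0`.
# In CutMix-style mixing the mixing ratio is drawn from Beta(alpha, alpha), so
# alpha = 1.0 makes that draw uniform on [0, 1]; the mixing itself happens inside
# moco.builder (see the icutmix/alpha arguments passed to the model below).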
96 | 
97 | # moco specific configs:
98 | parser.add_argument('--moco-dim', default=256, type=int,
99 |                     help='feature dimension (default: 256)')
100 | parser.add_argument('--moco-mlp-dim', default=4096, type=int,
101 |                     help='hidden dimension in MLPs (default: 4096)')
102 | parser.add_argument('--moco-m', default=0.99, type=float,
103 |                     help='moco momentum of updating momentum encoder (default: 0.99)')
104 | parser.add_argument('--moco-m-cos', action='store_true',
105 |                     help='gradually increase moco momentum to 1 with a '
106 |                          'half-cycle cosine schedule')
107 | parser.add_argument('--moco-t', default=0.2, type=float,
108 |                     help='softmax temperature (default: 0.2)')
109 | 
110 | # vit specific configs:
111 | parser.add_argument('--stop-grad-conv1', action='store_true',
112 |                     help='stop-grad after first conv, or patch embedding')
113 | 
114 | # other upgrades
115 | parser.add_argument('--optimizer', default='adamw', type=str,
116 |                     choices=['lars', 'adamw'],
117 |                     help='optimizer used (default: adamw)')
118 | parser.add_argument('--warmup-epochs', default=40, type=int, metavar='N',
119 |                     help='number of warmup epochs')
120 | parser.add_argument('--crop-min', default=0.08, type=float,
121 |                     help='minimum scale for random cropping (default: 0.08)')
122 | 
123 | parser.add_argument('--save_folder', default='./output/debug',
124 |                     help='path where to save, empty for no saving')
125 | parser.add_argument('--save-freq', default=20, type=int,
126 |                     help='checkpoint save frequency')
127 | parser.add_argument('--debug', action='store_true',
128 |                     help='debug mode')
129 | 
130 | 
131 | class FileListDataset(torch.utils.data.Dataset):
132 |     def __init__(self, path_to_txt_file, transform):
133 |         # self.data_root = data_root
134 |         with open(path_to_txt_file, 'r') as f:
135 |             self.file_list = f.readlines()
136 |         self.file_list = [row.rstrip() for row in self.file_list]
137 | 
138 |         self.transform = transform
139 | 
140 | 
141 |     def __getitem__(self, idx):
142 |         image_path = self.file_list[idx].split()[0]
143 |         img = 
131 | class FileListDataset(torch.utils.data.Dataset):
132 |     def __init__(self, path_to_txt_file, transform):
133 |         # self.data_root = data_root
134 |         with open(path_to_txt_file, 'r') as f:
135 |             self.file_list = f.readlines()
136 |             self.file_list = [row.rstrip() for row in self.file_list]
137 | 
138 |         self.transform = transform
139 | 
140 | 
141 |     def __getitem__(self, idx):
142 |         image_path = self.file_list[idx].split()[0]
143 |         img = Image.open(image_path).convert('RGB')
144 |         target = int(self.file_list[idx].split()[1])
145 | 
146 |         if self.transform is not None:
147 |             images = self.transform(img)
148 | 
149 |         return images, target
150 | 
151 |     def __len__(self):
152 |         return len(self.file_list)
153 | 
154 | 
155 | def main():
156 |     args = parser.parse_args()
157 | 
158 |     Path(args.save_folder).mkdir(parents=True, exist_ok=True)
159 | 
160 |     if args.seed is not None:
161 |         random.seed(args.seed)
162 |         np.random.seed(args.seed)
163 |         torch.manual_seed(args.seed)
164 |         cudnn.deterministic = True
165 |         warnings.warn('You have chosen to seed training. '
166 |                       'This will turn on the CUDNN deterministic setting, '
167 |                       'which can slow down your training considerably! '
168 |                       'You may see unexpected behavior when restarting '
169 |                       'from checkpoints.')
170 | 
171 |     if args.gpu is not None:
172 |         warnings.warn('You have chosen a specific GPU. This will completely '
173 |                       'disable data parallelism.')
174 | 
175 |     if args.dist_url == "env://" and args.world_size == -1:
176 |         args.world_size = int(os.environ["WORLD_SIZE"])
177 | 
178 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
179 | 
180 |     ngpus_per_node = torch.cuda.device_count()
181 |     if args.multiprocessing_distributed:
182 |         # Since we have ngpus_per_node processes per node, the total world_size
183 |         # needs to be adjusted accordingly
184 |         args.world_size = ngpus_per_node * args.world_size
185 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
186 |         # main_worker process function
187 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
188 |     else:
189 |         # Simply call main_worker function
190 |         main_worker(args.gpu, ngpus_per_node, args)
191 | 
192 | 
193 | def main_worker(gpu, ngpus_per_node, args):
194 |     args.gpu = gpu
195 | 
196 |     # suppress printing if not the first GPU on the first node
197 |     if args.multiprocessing_distributed and (args.gpu != 0 or args.rank != 0):
198 |         def print_pass(*args):
199 |             pass
200 |         builtins.print = print_pass
201 |     else:
202 |         if not args.debug:
203 |             os.environ['PYTHONBREAKPOINT'] = '0'
204 |         logger = get_logger(
205 |             logpath=os.path.join(args.save_folder, 'logs'),
206 |             filepath=os.path.abspath(__file__)
207 |         )
208 |         def print_pass(*args):
209 |             logger.info(*args)
210 |         builtins.print = print_pass
211 | 
212 |     print('==> training parameters <==')
213 |     for arg in vars(args):
214 |         print(f'==> {arg}: {getattr(args, arg)}')
215 |     print('===========================')
216 | 
217 |     if args.gpu is not None:
218 |         print("Use GPU: {} for training".format(args.gpu))
219 | 
220 |     if args.distributed:
221 |         if args.dist_url == "env://" and args.rank == -1:
222 |             args.rank = int(os.environ["RANK"])
223 |         if args.multiprocessing_distributed:
224 |             # For multiprocessing distributed training, rank needs to be the
225 |             # global rank among all the processes
226 |             args.rank = args.rank * ngpus_per_node + gpu
227 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
228 |                                 world_size=args.world_size, rank=args.rank)
229 |         torch.distributed.barrier()
230 |     # create model
231 |     print("=> creating model '{}'".format(args.arch))
232 |     if args.arch.startswith('vit'):
233 |         model = moco.builder.MoCo_ViT(
234 |             partial(vits.__dict__[args.arch], stop_grad_conv1=args.stop_grad_conv1),
235 |             args.moco_dim, args.moco_mlp_dim, args.moco_t,
236 |             icutmix=args.icutmix, alpha=args.alpha)
237 |     else:
238 |         model = moco.builder.MoCo_ResNet(
239 |             partial(torchvision_models.__dict__[args.arch], 
zero_init_residual=True), 240 | args.moco_dim, args.moco_mlp_dim, args.moco_t, 241 | icutmix=args.icutmix, alpha=args.alpha) 242 | 243 | # infer learning rate before changing batch size 244 | args.lr = args.lr * args.batch_size / 256 245 | 246 | if not torch.cuda.is_available(): 247 | print('using CPU, this will be slow') 248 | elif args.distributed: 249 | # apply SyncBN 250 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 251 | # For multiprocessing distributed, DistributedDataParallel constructor 252 | # should always set the single device scope, otherwise, 253 | # DistributedDataParallel will use all available devices. 254 | if args.gpu is not None: 255 | torch.cuda.set_device(args.gpu) 256 | model.cuda(args.gpu) 257 | # When using a single GPU per process and per 258 | # DistributedDataParallel, we need to divide the batch size 259 | # ourselves based on the total number of GPUs we have 260 | args.batch_size = int(args.batch_size / args.world_size) 261 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 262 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 263 | else: 264 | model.cuda() 265 | # DistributedDataParallel will divide and allocate batch_size to all 266 | # available GPUs if device_ids are not set 267 | model = torch.nn.parallel.DistributedDataParallel(model) 268 | elif args.gpu is not None: 269 | torch.cuda.set_device(args.gpu) 270 | model = model.cuda(args.gpu) 271 | # comment out the following line for debugging 272 | # raise NotImplementedError("Only DistributedDataParallel is supported.") 273 | else: 274 | # AllGather/rank implementation in this code only supports DistributedDataParallel. 275 | raise NotImplementedError("Only DistributedDataParallel is supported.") 276 | print(model) # print model after SyncBatchNorm 277 | 278 | if args.optimizer == 'lars': 279 | optimizer = moco.optimizer.LARS(model.parameters(), args.lr, 280 | weight_decay=args.weight_decay, 281 | momentum=args.momentum) 282 | elif args.optimizer == 'adamw': 283 | optimizer = torch.optim.AdamW(model.parameters(), args.lr, 284 | weight_decay=args.weight_decay) 285 | 286 | scaler = torch.cuda.amp.GradScaler() 287 | summary_writer = SummaryWriter() if args.rank == 0 else None 288 | 289 | # optionally resume from a checkpoint 290 | if args.resume: 291 | if os.path.isfile(args.resume): 292 | print("=> loading checkpoint '{}'".format(args.resume)) 293 | if args.gpu is None: 294 | checkpoint = torch.load(args.resume) 295 | else: 296 | # Map model to be loaded to specified single gpu. 
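                # (added note) map_location re-maps tensors that were saved from
                # another device (e.g. 'cuda:0' on the rank that wrote the
                # checkpoint) directly onto this process's GPU, avoiding a load
                # onto GPU 0 followed by a copy.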
297 | loc = 'cuda:{}'.format(args.gpu) 298 | checkpoint = torch.load(args.resume, map_location=loc) 299 | args.start_epoch = checkpoint['epoch'] 300 | model.load_state_dict(checkpoint['state_dict']) 301 | optimizer.load_state_dict(checkpoint['optimizer']) 302 | scaler.load_state_dict(checkpoint['scaler']) 303 | print("=> loaded checkpoint '{}' (epoch {})" 304 | .format(args.resume, checkpoint['epoch'])) 305 | else: 306 | print("=> no checkpoint found at '{}'".format(args.resume)) 307 | 308 | # Data loading code 309 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 310 | std=[0.229, 0.224, 0.225]) 311 | 312 | # follow BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733 313 | augmentation1 = [ 314 | transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)), 315 | transforms.RandomApply([ 316 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 317 | ], p=0.8), 318 | transforms.RandomGrayscale(p=0.2), 319 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=1.0), 320 | transforms.RandomHorizontalFlip(), 321 | transforms.ToTensor(), 322 | normalize 323 | ] 324 | 325 | augmentation2 = [ 326 | transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)), 327 | transforms.RandomApply([ 328 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 329 | ], p=0.8), 330 | transforms.RandomGrayscale(p=0.2), 331 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.1), 332 | transforms.RandomApply([moco.loader.Solarize()], p=0.2), 333 | transforms.RandomHorizontalFlip(), 334 | transforms.ToTensor(), 335 | normalize 336 | ] 337 | 338 | train_dataset = FileListDataset( 339 | args.data, 340 | moco.loader.TwoCropsTransform(transforms.Compose(augmentation1), 341 | transforms.Compose(augmentation2))) 342 | 343 | if args.distributed: 344 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 345 | else: 346 | train_sampler = None 347 | 348 | train_loader = torch.utils.data.DataLoader( 349 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 350 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 351 | 352 | for epoch in range(args.start_epoch, args.epochs): 353 | if args.distributed: 354 | train_sampler.set_epoch(epoch) 355 | 356 | # train for one epoch 357 | train(train_loader, model, optimizer, scaler, summary_writer, epoch, args) 358 | 359 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 360 | and args.rank == 0): # only the first GPU saves checkpoint 361 | filename = os.path.join(args.save_folder, 'checkpoint_%04d.pth.tar' % epoch) 362 | if epoch % args.save_freq == 0 or (epoch + 1) == args.epochs: 363 | save_checkpoint({ 364 | 'epoch': epoch + 1, 365 | 'arch': args.arch, 366 | 'state_dict': model.state_dict(), 367 | 'optimizer' : optimizer.state_dict(), 368 | 'scaler': scaler.state_dict(), 369 | }, is_best=False, filename=filename) 370 | 371 | if args.rank == 0: 372 | summary_writer.close() 373 | 374 | def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args): 375 | batch_time = AverageMeter('Time', ':6.3f') 376 | data_time = AverageMeter('Data', ':6.3f') 377 | learning_rates = AverageMeter('LR', ':.4e') 378 | losses = AverageMeter('Loss', ':.4e') 379 | progress = ProgressMeter( 380 | len(train_loader), 381 | [batch_time, data_time, learning_rates, losses], 382 | prefix="Epoch: [{}]".format(epoch)) 383 | 384 | # switch to train mode 385 | model.train() 386 | 387 | end = time.time() 388 | iters_per_epoch = 
len(train_loader)
389 |     moco_m = args.moco_m
390 |     for i, (images, _) in enumerate(train_loader):
391 |         # measure data loading time
392 |         data_time.update(time.time() - end)
393 | 
394 |         # adjust learning rate and momentum coefficient per iteration
395 |         lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args)
396 |         learning_rates.update(lr)
397 |         if args.moco_m_cos:
398 |             moco_m = adjust_moco_momentum(epoch + i / iters_per_epoch, args)
399 | 
400 |         if args.gpu is not None:
401 |             images[0] = images[0].cuda(args.gpu, non_blocking=True)
402 |             images[1] = images[1].cuda(args.gpu, non_blocking=True)
403 | 
404 |         # compute output
405 |         with torch.cuda.amp.autocast(True):
406 |             loss = model(images[0], images[1], moco_m)
407 | 
408 |         losses.update(loss.item(), images[0].size(0))
409 |         if args.rank == 0:
410 |             summary_writer.add_scalar("loss", loss.item(), epoch * iters_per_epoch + i)
411 | 
412 |         # compute gradient and do SGD step
413 |         optimizer.zero_grad()
414 |         scaler.scale(loss).backward()
415 |         scaler.step(optimizer)
416 |         scaler.update()
417 | 
418 |         # measure elapsed time
419 |         batch_time.update(time.time() - end)
420 |         end = time.time()
421 | 
422 |         if i % args.print_freq == 0:
423 |             progress.display(i)
424 | 
425 | 
426 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
427 |     while True:
428 |         try:
429 |             print(f'=======> attempt model saving at: {filename}')
430 |             torch.save(state, filename)
431 |             break
432 |         except Exception:  # a bare except here would also swallow KeyboardInterrupt
433 |             print(f'=======> model saving failed at: {filename}')
434 |             print('=======> delete the old file')
435 |             if os.path.exists(filename): os.remove(filename)
436 |             print('=======> sleep for 2 seconds')
437 |             time.sleep(2)
438 |             print('=======> retry')
439 | 
440 | 
441 | class AverageMeter(object):
442 |     """Computes and stores the average and current value"""
443 |     def __init__(self, name, fmt=':f'):
444 |         self.name = name
445 |         self.fmt = fmt
446 |         self.reset()
447 | 
448 |     def reset(self):
449 |         self.val = 0
450 |         self.avg = 0
451 |         self.sum = 0
452 |         self.count = 0
453 | 
454 |     def update(self, val, n=1):
455 |         self.val = val
456 |         self.sum += val * n
457 |         self.count += n
458 |         self.avg = self.sum / self.count
459 | 
460 |     def __str__(self):
461 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
462 |         return fmtstr.format(**self.__dict__)
463 | 
464 | 
465 | class ProgressMeter(object):
466 |     def __init__(self, num_batches, meters, prefix=""):
467 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
468 |         self.meters = meters
469 |         self.prefix = prefix
470 | 
471 |     def display(self, batch):
472 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
473 |         entries += [str(meter) for meter in self.meters]
474 |         print('\t'.join(entries))
475 | 
476 |     def _get_batch_fmtstr(self, num_batches):
477 |         num_digits = len(str(num_batches))
478 |         fmt = '{:' + str(num_digits) + 'd}'
479 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
480 | 
481 | 
482 | def adjust_learning_rate(optimizer, epoch, args):
483 |     """Decays the learning rate with half-cycle cosine after warmup"""
484 |     if epoch < args.warmup_epochs:
485 |         lr = args.lr * epoch / args.warmup_epochs
486 |     else:
487 |         lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
488 |     for param_group in optimizer.param_groups:
489 |         param_group['lr'] = lr
490 |     return lr
491 | 
492 | 
493 | def adjust_moco_momentum(epoch, args):
494 |     """Adjust moco momentum based on current epoch"""
495 |     m = 1. - 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m) 496 | return m 497 | 498 | 499 | if __name__ == '__main__': 500 | main() 501 | 502 | -------------------------------------------------------------------------------- /metadata_files/im100_metadata.txt: -------------------------------------------------------------------------------- 1 | n01558993 0 robin, American robin, Turdus migratorius 2 | n01692333 1 Gila monster, Heloderma suspectum 3 | n01729322 2 hognose snake, puff adder, sand viper 4 | n01735189 3 garter snake, grass snake 5 | n01749939 4 green mamba 6 | n01773797 5 garden spider, Aranea diademata 7 | n01820546 6 lorikeet 8 | n01855672 7 goose 9 | n01978455 8 rock crab, Cancer irroratus 10 | n01980166 9 fiddler crab 11 | n01983481 10 American lobster, Northern lobster, Maine lobster, Homarus americanus 12 | n02009229 11 little blue heron, Egretta caerulea 13 | n02018207 12 American coot, marsh hen, mud hen, water hen, Fulica americana 14 | n02085620 13 Chihuahua 15 | n02086240 14 Shih-Tzu 16 | n02086910 15 papillon 17 | n02087046 16 toy terrier 18 | n02089867 17 Walker hound, Walker foxhound 19 | n02089973 18 English foxhound 20 | n02090622 19 borzoi, Russian wolfhound 21 | n02091831 20 Saluki, gazelle hound 22 | n02093428 21 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 23 | n02099849 22 Chesapeake Bay retriever 24 | n02100583 23 vizsla, Hungarian pointer 25 | n02104029 24 kuvasz 26 | n02105505 25 komondor 27 | n02106550 26 Rottweiler 28 | n02107142 27 Doberman, Doberman pinscher 29 | n02108089 28 boxer 30 | n02109047 29 Great Dane 31 | n02113799 30 standard poodle 32 | n02113978 31 Mexican hairless 33 | n02114855 32 coyote, prairie wolf, brush wolf, Canis latrans 34 | n02116738 33 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 35 | n02119022 34 red fox, Vulpes vulpes 36 | n02123045 35 tabby, tabby cat 37 | n02138441 36 meerkat, mierkat 38 | n02172182 37 dung beetle 39 | n02231487 38 walking stick, walkingstick, stick insect 40 | n02259212 39 leafhopper 41 | n02326432 40 hare 42 | n02396427 41 wild boar, boar, Sus scrofa 43 | n02483362 42 gibbon, Hylobates lar 44 | n02488291 43 langur 45 | n02701002 44 ambulance 46 | n02788148 45 bannister, banister, balustrade, balusters, handrail 47 | n02804414 46 bassinet 48 | n02859443 47 boathouse 49 | n02869837 48 bonnet, poke bonnet 50 | n02877765 49 bottlecap 51 | n02974003 50 car wheel 52 | n03017168 51 chime, bell, gong 53 | n03032252 52 cinema, movie theater, movie theatre, movie house, picture palace 54 | n03062245 53 cocktail shaker 55 | n03085013 54 computer keyboard, keypad 56 | n03259280 55 Dutch oven 57 | n03379051 56 football helmet 58 | n03424325 57 gasmask, respirator, gas helmet 59 | n03492542 58 hard disc, hard disk, fixed disk 60 | n03494278 59 harmonica, mouth organ, harp, mouth harp 61 | n03530642 60 honeycomb 62 | n03584829 61 iron, smoothing iron 63 | n03594734 62 jean, blue jean, denim 64 | n03637318 63 lampshade, lamp shade 65 | n03642806 64 laptop, laptop computer 66 | n03764736 65 milk can 67 | n03775546 66 mixing bowl 68 | n03777754 67 modem 69 | n03785016 68 moped 70 | n03787032 69 mortarboard 71 | n03794056 70 mousetrap 72 | n03837869 71 obelisk 73 | n03891251 72 park bench 74 | n03903868 73 pedestal, plinth, footstall 75 | n03930630 74 pickup, pickup truck 76 | n03947888 75 pirate, pirate ship 77 | n04026417 76 purse 78 | n04067472 77 reel 79 | n04099969 78 rocking chair, rocker 80 | n04111531 79 rotisserie 81 | n04127249 80 
safety pin 82 | n04136333 81 sarong 83 | n04229816 82 ski mask 84 | n04238763 83 slide rule, slipstick 85 | n04336792 84 stretcher 86 | n04418357 85 theater curtain, theatre curtain 87 | n04429376 86 throne 88 | n04435653 87 tile roof 89 | n04485082 88 tripod 90 | n04493381 89 tub, vat 91 | n04517823 90 vacuum, vacuum cleaner 92 | n04589890 91 window screen 93 | n04592741 92 wing 94 | n07714571 93 head cabbage 95 | n07715103 94 cauliflower 96 | n07753275 95 pineapple, ananas 97 | n07831146 96 carbonara 98 | n07836838 97 chocolate sauce, chocolate syrup 99 | n13037406 98 gyromitra 100 | n13040303 99 stinkhorn, carrion fungus 101 | -------------------------------------------------------------------------------- /metadata_files/imagenet100_classes.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 101 | -------------------------------------------------------------------------------- /moco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/PatchSearch/39f298225dde9e868f4257a5a4c11d6f805bfa1a/moco/__init__.py -------------------------------------------------------------------------------- /moco/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def rand_bbox(size, lam): 6 | W, H = size 7 | cut_rat = (1. - lam).sqrt() 8 | cut_w = (W * cut_rat).to(torch.long) 9 | cut_h = (H * cut_rat).to(torch.long) 10 | 11 | cx = torch.zeros_like(cut_w, dtype=cut_w.dtype).random_(0, W) 12 | cy = torch.zeros_like(cut_h, dtype=cut_h.dtype).random_(0, H) 13 | 14 | bbx1 = (cx - cut_w // 2).clamp(0, W) 15 | bby1 = (cy - cut_h // 2).clamp(0, H) 16 | bbx2 = (cx + cut_w // 2).clamp(0, W) 17 | bby2 = (cy + cut_h // 2).clamp(0, H) 18 | 19 | new_lam = 1. 
- (bbx2 - bbx1).to(lam.dtype) * (bby2 - bby1).to(lam.dtype) / (W * H) 20 | 21 | return (bbx1, bby1, bbx2, bby2), new_lam 22 | 23 | 24 | def cutmix(x, alpha): 25 | if not isinstance(alpha, (list, tuple)): 26 | alpha = [alpha, alpha] 27 | 28 | # create beta distribution and sample lambda from it 29 | beta = torch.distributions.beta.Beta(*alpha) 30 | lam = beta.sample().to(device=x.device) 31 | lam = torch.max(lam, 1. - lam) 32 | 33 | # sample a random bounding box 34 | (bbx1, bby1, bbx2, bby2), lam = rand_bbox(x.shape[-2:], lam) 35 | 36 | # sample random indices for cutmix 37 | randind = torch.randperm(x.shape[0], device=x.device) 38 | output = x.clone() 39 | output[..., bbx1:bbx2, bby1:bby2] = output[randind][..., bbx1:bbx2, bby1:bby2] 40 | 41 | return output, randind, lam 42 | 43 | 44 | class MoCo(nn.Module): 45 | """ 46 | Build a MoCo model with a base encoder, a momentum encoder, and two MLPs 47 | https://arxiv.org/abs/1911.05722 48 | """ 49 | def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0, icutmix=False, alpha=1.0): 50 | """ 51 | dim: feature dimension (default: 256) 52 | mlp_dim: hidden dimension in MLPs (default: 4096) 53 | T: softmax temperature (default: 1.0) 54 | """ 55 | super(MoCo, self).__init__() 56 | 57 | self.T = T 58 | self.icutmix = icutmix 59 | self.alpha = alpha 60 | 61 | # build encoders 62 | self.base_encoder = base_encoder(num_classes=mlp_dim) 63 | self.momentum_encoder = base_encoder(num_classes=mlp_dim) 64 | 65 | self._build_projector_and_predictor_mlps(dim, mlp_dim) 66 | 67 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 68 | param_m.data.copy_(param_b.data) # initialize 69 | param_m.requires_grad = False # not update by gradient 70 | 71 | def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 72 | mlp = [] 73 | for l in range(num_layers): 74 | dim1 = input_dim if l == 0 else mlp_dim 75 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 76 | 77 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 78 | 79 | if l < num_layers - 1: 80 | mlp.append(nn.BatchNorm1d(dim2)) 81 | mlp.append(nn.ReLU(inplace=True)) 82 | elif last_bn: 83 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 84 | # for simplicity, we further removed gamma in BN 85 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 86 | 87 | return nn.Sequential(*mlp) 88 | 89 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 90 | pass 91 | 92 | @torch.no_grad() 93 | def _update_momentum_encoder(self, m): 94 | """Momentum update of the momentum encoder""" 95 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 96 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 97 | 98 | def contrastive_loss(self, q, k): 99 | # normalize 100 | q = nn.functional.normalize(q, dim=1) 101 | k = nn.functional.normalize(k, dim=1) 102 | # gather all targets 103 | k = concat_all_gather(k) 104 | # Einstein sum is more intuitive 105 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 106 | N = logits.shape[0] # batch size per GPU 107 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 108 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 109 | 110 | def contrastive_icutmix_loss(self, q, labels_aux, lam, k): 111 | # normalize 112 | q = nn.functional.normalize(q, dim=1) 113 | k = nn.functional.normalize(k, dim=1) 114 | # gather all targets 115 | k = concat_all_gather(k) 116 | # Einstein sum is more intuitive 117 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 118 | N = logits.shape[0] # batch size per GPU 119 | offset = N * torch.distributed.get_rank() 120 | labels = (torch.arange(N, dtype=torch.long) + offset).cuda() 121 | labels_aux = (labels_aux + offset).cuda() 122 | loss = lam * nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 123 | loss = loss + (1. - lam) * nn.CrossEntropyLoss()(logits, labels_aux) * (2 * self.T) 124 | return loss 125 | 126 | def forward(self, x1, x2, m): 127 | """ 128 | Input: 129 | x1: first views of images 130 | x2: second views of images 131 | m: moco momentum 132 | Output: 133 | loss 134 | """ 135 | 136 | # compute features 137 | if self.icutmix: 138 | x1_mixed, y1_aux, lam1 = cutmix(x1, self.alpha) 139 | x2_mixed, y2_aux, lam2 = cutmix(x2, self.alpha) 140 | q1_mixed = self.predictor(self.base_encoder(x1_mixed)) 141 | q2_mixed = self.predictor(self.base_encoder(x2_mixed)) 142 | else: 143 | q1 = self.predictor(self.base_encoder(x1)) 144 | q2 = self.predictor(self.base_encoder(x2)) 145 | 146 | with torch.no_grad(): # no gradient 147 | self._update_momentum_encoder(m) # update the momentum encoder 148 | 149 | # compute momentum features as targets 150 | k1 = self.momentum_encoder(x1) 151 | k2 = self.momentum_encoder(x2) 152 | 153 | if self.icutmix: 154 | l1 = self.contrastive_icutmix_loss(q1_mixed, y1_aux, lam1, k2) 155 | l2 = self.contrastive_icutmix_loss(q2_mixed, y2_aux, lam2, k1) 156 | return l1 + l2 157 | else: 158 | l1 = self.contrastive_loss(q1, k2) 159 | l2 = self.contrastive_loss(q2, k1) 160 | return l1 + l2 161 | 162 | 163 | class MoCo_ResNet(MoCo): 164 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 165 | hidden_dim = self.base_encoder.fc.weight.shape[1] 166 | del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer 167 | 168 | # projectors 169 | self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) 170 | self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) 171 | 172 | # predictor 173 | self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False) 174 | 175 | 176 | class MoCo_ViT(MoCo): 177 | def _build_projector_and_predictor_mlps(self, dim, mlp_dim): 178 | hidden_dim = self.base_encoder.head.weight.shape[1] 179 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer 180 | 181 | # projectors 182 | self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 183 | self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 184 | 185 | # predictor 186 | self.predictor = self._build_mlp(2, dim, mlp_dim, dim) 187 | 188 | 189 | # utils 190 | @torch.no_grad() 191 | def concat_all_gather(tensor): 192 | """ 193 | Performs all_gather operation on the provided 
tensors. 194 | *** Warning ***: torch.distributed.all_gather has no gradient. 195 | """ 196 | tensors_gather = [torch.ones_like(tensor) 197 | for _ in range(torch.distributed.get_world_size())] 198 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 199 | 200 | output = torch.cat(tensors_gather, dim=0) 201 | return output 202 | -------------------------------------------------------------------------------- /moco/loader.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageOps 2 | import math 3 | import random 4 | import torchvision.transforms.functional as tf 5 | 6 | 7 | class TwoCropsTransform: 8 | """Take two random crops of one image""" 9 | 10 | def __init__(self, base_transform1, base_transform2): 11 | self.base_transform1 = base_transform1 12 | self.base_transform2 = base_transform2 13 | 14 | def __call__(self, x): 15 | im1 = self.base_transform1(x) 16 | im2 = self.base_transform2(x) 17 | return [im1, im2] 18 | 19 | 20 | class GaussianBlur(object): 21 | """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" 22 | 23 | def __init__(self, sigma=[.1, 2.]): 24 | self.sigma = sigma 25 | 26 | def __call__(self, x): 27 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 28 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 29 | return x 30 | 31 | 32 | class Solarize(object): 33 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 34 | 35 | def __call__(self, x): 36 | return ImageOps.solarize(x) 37 | -------------------------------------------------------------------------------- /moco/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LARS(torch.optim.Optimizer): 5 | """ 6 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
7 | """ 8 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 9 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 10 | super().__init__(params, defaults) 11 | 12 | @torch.no_grad() 13 | def step(self): 14 | for g in self.param_groups: 15 | for p in g['params']: 16 | dp = p.grad 17 | 18 | if dp is None: 19 | continue 20 | 21 | if p.ndim > 1: # if not normalization gamma/beta or bias 22 | dp = dp.add(p, alpha=g['weight_decay']) 23 | param_norm = torch.norm(p) 24 | update_norm = torch.norm(dp) 25 | one = torch.ones_like(param_norm) 26 | q = torch.where(param_norm > 0., 27 | torch.where(update_norm > 0, 28 | (g['trust_coefficient'] * param_norm / update_norm), one), 29 | one) 30 | dp = dp.mul(q) 31 | 32 | param_state = self.state[p] 33 | if 'mu' not in param_state: 34 | param_state['mu'] = torch.zeros_like(p) 35 | mu = param_state['mu'] 36 | mu.mul_(g['momentum']).add_(dp) 37 | p.add_(mu, alpha=-g['lr']) 38 | -------------------------------------------------------------------------------- /patch_search_iterative_search.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | import os 4 | import copy 5 | from collections import Counter 6 | import time 7 | import shutil 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader, Dataset 13 | 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import torchvision.models as models 17 | from torchvision.utils import save_image 18 | 19 | from tqdm import tqdm 20 | import matplotlib.pyplot as plt 21 | from PIL import Image 22 | import numpy as np 23 | import faiss 24 | from sklearn.metrics import precision_recall_curve, pairwise_distances 25 | 26 | from pytorch_grad_cam import GradCAM 27 | from pytorch_grad_cam.utils.image import show_cam_on_image 28 | 29 | from eval_utils import get_logger 30 | import vits 31 | 32 | 33 | parser = argparse.ArgumentParser(description='Grad-CAM SSL Defense') 34 | parser.add_argument('-a', '--arch', default='resnet18', 35 | help='model architecture (default: resnet18)') 36 | parser.add_argument('-j', '--workers', default=48, type=int, 37 | help='number of data loading workers (default: 32)') 38 | parser.add_argument('-b', '--batch-size', default=128, type=int, 39 | help='mini-batch size (default: 128)') 40 | parser.add_argument('-p', '--print-freq', default=50, type=int, 41 | help='print frequency (default: 50)') 42 | parser.add_argument('--weights', type=str, required=True, 43 | help='pre-trained model weights') 44 | parser.add_argument('--train_file', type=str, required=False, 45 | help='file containing training image paths') 46 | parser.add_argument('--val_file', type=str, required=False, 47 | help='file containing eval image paths') 48 | parser.add_argument('--num_clusters', default=1000, type=int, 49 | help='number of clusters') 50 | parser.add_argument('--test_images_size', default=1000, type=int, 51 | help='number random test images to sample for evalluating a patch candidate') 52 | parser.add_argument('--window_w', default=60, type=int, 53 | help='size of the patch candidate to extract from the grad cam heatmap') 54 | parser.add_argument('--repeat_patch', default=1, type=int, 55 | help='number of max firing patches to extract from a candidate image') 56 | parser.add_argument('--samples_per_iteration', default=2, type=int, 57 | 
70 | class FileListDataset(Dataset):
71 |     def __init__(self, path_to_txt_file, transform):
72 |         with open(path_to_txt_file, 'r') as f:
73 |             lines = f.readlines()
74 |             samples = [line.strip().split() for line in lines]
75 |             samples = [(pth, int(target)) for pth, target in samples]
76 |         self.samples = samples
77 |         self.transform = transform
78 |         self.classes = list(sorted(set(y for _, y in self.samples)))
79 | 
80 | 
81 |     def __getitem__(self, idx):
82 |         image_path, target = self.samples[idx]
83 |         img = Image.open(image_path).convert('RGB')
84 | 
85 |         if self.transform is not None:
86 |             image = self.transform(img)
87 | 
88 |         is_poisoned = 'HTBA_trigger' in image_path
89 | 
90 |         return image, target, is_poisoned, idx
91 | 
92 |     def __len__(self):
93 |         return len(self.samples)
94 | 
95 | 
96 | def denormalize(x):
97 |     if x.shape[0] == 3:
98 |         x = x.permute((1, 2, 0))
99 |     mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
100 |     std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
101 |     x = ((x * std) + mean)
102 |     return x
103 | 
104 | 
105 | def run_gradcam(arch, model, inp, targets=None):
106 |     if 'vit' in arch:
107 |         return run_vit_gradcam(model, [model.blocks[-1].norm1], inp, targets)
108 |     else:
109 |         return run_cnn_gradcam(model, [model.layer4], inp, targets)
110 | 
111 | 
112 | def run_cnn_gradcam(model, target_layers, inp, targets=None):
113 |     with GradCAM(model=model, target_layers=target_layers, use_cuda=True) as cam:
114 |         cam.batch_size = 32
115 |         grayscale_cam, out = cam(input_tensor=inp, targets=targets)
116 |         return grayscale_cam, out
117 | 
118 | 
119 | def reshape_transform(tensor, height=14, width=14):
120 |     result = tensor[:, 1:, :].reshape(tensor.size(0),
121 |                                       height, width, tensor.size(2))
122 | 
123 |     # Bring the channels to the first dimension,
124 |     # like in CNNs.
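    # (added note) e.g. for ViT-B/16 at 224x224 inputs: `tensor` is
    # (B, 1 + 14*14, C); dropping the class token and reshaping gives
    # (B, 14, 14, C), and the transposes below produce (B, C, 14, 14) so
    # Grad-CAM can treat the token grid like a CNN feature map.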
125 | result = result.transpose(2, 3).transpose(1, 2) 126 | return result 127 | 128 | 129 | def run_vit_gradcam(model, target_layers, inp, targets=None): 130 | with GradCAM(model=model, target_layers=target_layers, 131 | reshape_transform=reshape_transform, use_cuda=True) as cam: 132 | cam.batch_size = 32 133 | grayscale_cam, out = cam(input_tensor=inp, targets=targets) 134 | return grayscale_cam, out 135 | 136 | 137 | def get_feats(model, loader): 138 | model = nn.DataParallel(model).cuda() 139 | model.eval() 140 | feats, labels, indices, is_poisoned = [], [], [], [] 141 | for images, targets, is_p, inds in tqdm(loader): 142 | with torch.no_grad(): 143 | feats.append(model(images.cuda()).cpu()) 144 | labels.append(targets) 145 | indices.append(inds) 146 | is_poisoned.append(is_p) 147 | feats = torch.cat(feats) 148 | labels = torch.cat(labels) 149 | indices = torch.cat(indices) 150 | is_poisoned = torch.cat(is_poisoned) 151 | feats /= feats.norm(2, dim=-1, keepdim=True) 152 | return feats, labels, is_poisoned, indices 153 | 154 | 155 | def faiss_kmeans(train_feats, nmb_clusters): 156 | train_feats = train_feats.numpy() 157 | 158 | d = train_feats.shape[-1] 159 | 160 | clus = faiss.Clustering(d, nmb_clusters) 161 | clus.niter = 20 162 | clus.max_points_per_centroid = 10000000 163 | 164 | index = faiss.IndexFlatL2(d) 165 | co = faiss.GpuMultipleClonerOptions() 166 | co.useFloat16 = True 167 | co.shard = True 168 | index = faiss.index_cpu_to_all_gpus(index, co) 169 | 170 | # perform the training 171 | clus.train(train_feats, index) 172 | train_d, train_a = index.search(train_feats, 1) 173 | 174 | return train_d, train_a, index, clus.centroids 175 | 176 | 177 | class KMeansLinear(nn.Module): 178 | def __init__(self, train_a, train_val_feats, num_clusters): 179 | super().__init__() 180 | clusters = [] 181 | for i in range(num_clusters): 182 | cluster = train_val_feats[train_a == i].mean(dim=0) 183 | clusters.append(cluster) 184 | self.classifier = nn.Parameter(torch.stack(clusters)) 185 | 186 | def forward(self, x): 187 | c = self.classifier 188 | c = c / c.norm(2, dim=1, keepdim=True) 189 | x = x / x.norm(2, dim=1, keepdim=True) 190 | return x @ c.T 191 | 192 | 193 | def paste_patch(inputs, patch): 194 | B = inputs.shape[0] 195 | inp_w = inputs.shape[-1] 196 | window_w = patch.shape[-1] 197 | ij = torch.randint(low=0, high=(inp_w - window_w), size=(B, 2)) 198 | i, j = ij[:, 0], ij[:, 1] 199 | 200 | # create row and column indices for each position in the window 201 | s = torch.arange(window_w, device=inputs.device) 202 | ri = i.view(B, 1).repeat(1, window_w) 203 | rj = j.view(B, 1).repeat(1, window_w) 204 | sri, srj = ri + s, rj + s 205 | 206 | # repeat starting row index in columns and vice versa 207 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w) 208 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1) 209 | 210 | # these are 2d indices so convert them into 1d indices 211 | inds = xi * inp_w + xj 212 | 213 | # repeat the indices across color channels 214 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1) 215 | 216 | # convert patch 2d->1d and repeat across the batch dimension 217 | patch = patch.reshape(3, -1).unsqueeze(0).repeat(B, 1, 1) 218 | 219 | # convert image 2d->1d, scatter patch, convert image 1d->2d 220 | inputs = inputs.reshape(B, 3, -1) 221 | inputs.scatter_(dim=2, index=inds, src=patch) 222 | inputs = inputs.reshape(B, 3, inp_w, inp_w) 223 | return inputs 224 | 225 | 226 | def block_max_window(cam_images, inputs, window_w=30): 227 | B, _, inp_w = 
cam_images.shape 228 | grayscale_cam = torch.from_numpy(cam_images) 229 | inputs = inputs.clone() 230 | sum_conv = torch.ones((1, 1, window_w, window_w)) 231 | 232 | # calculate sums in each window 233 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv) 234 | 235 | # flatten the sums and take argmax 236 | flat_sums_cam = sums_cam.view(B, -1) 237 | ij = flat_sums_cam.argmax(dim=-1) 238 | 239 | # separate out the row and column indices 240 | # this gives us the location of top left window corner 241 | sums_cam_w = sums_cam.shape[-1] 242 | i, j = ij // sums_cam_w, ij % sums_cam_w 243 | 244 | # create row and column indices for each position in the window 245 | s = torch.arange(window_w, device=inputs.device) 246 | ri = i.view(B, 1).repeat(1, window_w) 247 | rj = j.view(B, 1).repeat(1, window_w) 248 | sri, srj = ri + s, rj + s 249 | 250 | # repeat starting row index in columns and vice versa 251 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w) 252 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1) 253 | 254 | # these are 2d indices so convert them into 1d indices 255 | inds = xi * inp_w + xj 256 | 257 | # repeat the indices across color channels 258 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1) 259 | 260 | # convert image 2d->1d, set window locations to 0, convert image 1d->2d 261 | inputs = inputs.reshape(B, 3, -1) 262 | inputs.scatter_(dim=2, index=inds, value=0) 263 | inputs = inputs.reshape(B, 3, inp_w, inp_w) 264 | return inputs 265 | 266 | 267 | def extract_max_window(cam_images, inputs, window_w=30): 268 | B, _, inp_w = cam_images.shape 269 | grayscale_cam = torch.from_numpy(cam_images) 270 | inputs = inputs.clone() 271 | sum_conv = torch.ones((1, 1, window_w, window_w)) 272 | 273 | # calculate sums in each window 274 | sums_cam = F.conv2d(grayscale_cam.unsqueeze(1), sum_conv) 275 | 276 | # flatten the sums and take argmax 277 | flat_sums_cam = sums_cam.view(B, -1) 278 | ij = flat_sums_cam.argmax(dim=-1) 279 | 280 | # separate out the row and column indices 281 | # this gives us the location of top left window corner 282 | sums_cam_w = sums_cam.shape[-1] 283 | i, j = ij // sums_cam_w, ij % sums_cam_w 284 | 285 | # create row and column indices for each position in the window 286 | s = torch.arange(window_w, device=inputs.device) 287 | ri = i.view(B, 1).repeat(1, window_w) 288 | rj = j.view(B, 1).repeat(1, window_w) 289 | sri, srj = ri + s, rj + s 290 | 291 | # repeat starting row index in columns and vice versa 292 | xi = sri.view(B, window_w, 1).repeat(1, 1, window_w) 293 | xj = srj.view(B, 1, window_w).repeat(1, window_w, 1) 294 | 295 | # these are 2d indices so convert them into 1d indices 296 | inds = xi * inp_w + xj 297 | 298 | # repeat the indices across color channels 299 | inds = inds.unsqueeze(1).repeat((1, 3, 1, 1)).view(B, 3, -1) 300 | 301 | # convert image 2d->1d 302 | inputs = inputs.reshape(B, 3, -1) 303 | 304 | # gather the windows and reshape 1d->2d 305 | windows = torch.gather(inputs, dim=2, index=inds) 306 | windows = windows.reshape(B, 3, window_w, window_w) 307 | 308 | return windows 309 | 310 | 311 | def get_candidate_patches(model, loader, args): 312 | candidate_patches = [] 313 | for inp, _, _, _ in tqdm(loader): 314 | windows = [] 315 | for _ in range(args.repeat_patch): 316 | cam_images, _ = run_gradcam(args.arch, model, inp) 317 | windows.append(extract_max_window(cam_images, inp, args.window_w)) 318 | block_max_window(cam_images, inp, int(args.window_w * .5)) 319 | windows = torch.stack(windows) 320 | windows = 
torch.einsum('kb...->bk...', windows) 321 | candidate_patches.append(windows.detach().cpu()) 322 | candidate_patches = torch.cat(candidate_patches) 323 | return candidate_patches 324 | 325 | 326 | def get_model(arch, wts_path): 327 | if 'moco_vit' in arch: 328 | model = vits.__dict__[arch.replace('moco_', '')]() 329 | model.head = nn.Identity() 330 | sd = torch.load(wts_path)['state_dict'] 331 | sd = {k.replace('module.', ''): v for k, v in sd.items()} 332 | sd = {k: v for k, v in sd.items() if 'base_encoder' in k} 333 | sd = {k: v for k, v in sd.items() if 'head' not in k} 334 | sd = {k.replace('base_encoder.', ''): v for k, v in sd.items()} 335 | model.load_state_dict(sd, strict=True) 336 | elif 'moco' in arch: 337 | model = models.__dict__[arch.replace('moco_', '')]() 338 | model.fc = nn.Sequential() 339 | sd = torch.load(wts_path)['state_dict'] 340 | sd = {k.replace('module.', ''): v for k, v in sd.items()} 341 | sd = {k: v for k, v in sd.items() if 'encoder_q' in k or 'base_encoder' in k} 342 | sd = {k: v for k, v in sd.items() if 'fc' not in k} 343 | sd = {k.replace('encoder_q.', ''): v for k, v in sd.items()} 344 | sd = {k.replace('base_encoder.', ''): v for k, v in sd.items()} 345 | model.load_state_dict(sd, strict=True) 346 | elif 'byol' in arch: 347 | model = models.__dict__[arch.replace('byol_', '')]() 348 | model.fc = nn.Sequential() 349 | sd = torch.load(wts_path) 350 | sd = {k.replace('module.', ''): v for k, v in sd.items()} 351 | sd = {k: v for k, v in sd.items() if 'model_t' not in k} 352 | sd = {k: v for k, v in sd.items() if 'head' not in k} 353 | sd = {k: v for k, v in sd.items() if 'pred' not in k} 354 | sd = {k.replace('model.', ''): v for k, v in sd.items()} 355 | model.load_state_dict(sd, strict=True) 356 | elif 'resnet' in arch: 357 | model = models.__dict__[arch]() 358 | model.fc = nn.Sequential() 359 | load_weights(model, wts_path) 360 | else: 361 | raise ValueError('arch not found: ' + arch) 362 | 363 | model = model.eval() 364 | 365 | return model 366 | 367 | def get_test_images(train_val_dataset, cluster_wise_i, args): 368 | test_images_i = [] 369 | k = args.test_images_size // len(cluster_wise_i) 370 | if k > 0: 371 | for inds in cluster_wise_i: 372 | test_images_i.extend(inds[:k]) 373 | else: 374 | for clust_i in np.random.permutation(len(cluster_wise_i))[:args.test_images_size]: 375 | test_images_i.append(cluster_wise_i[clust_i][0]) 376 | 377 | test_images_dataset = torch.utils.data.Subset( 378 | train_val_dataset, torch.tensor(test_images_i) 379 | ) 380 | test_images_loader = DataLoader( 381 | test_images_dataset, 382 | shuffle=False, batch_size=args.batch_size, 383 | num_workers=args.workers, pin_memory=True 384 | ) 385 | # logger.info('==> get test images') 386 | test_images = [] 387 | for inp, _, _, _ in tqdm(test_images_loader): 388 | test_images.append(inp) 389 | test_images = torch.cat(test_images) 390 | return test_images, test_images_i 391 | 392 | 393 | class Normalize(nn.Module): 394 | def forward(self, x): 395 | return x / x.norm(2, dim=1, keepdim=True) 396 | 397 | 398 | class FullBatchNorm(nn.Module): 399 | def __init__(self, var, mean): 400 | super(FullBatchNorm, self).__init__() 401 | self.register_buffer('inv_std', (1.0 / torch.sqrt(var + 1e-5))) 402 | self.register_buffer('mean', mean) 403 | 404 | def forward(self, x): 405 | return (x - self.mean) * self.inv_std 406 | 407 | 408 | def get_channels(arch): 409 | if 'resnet50' in arch: 410 | c = 2048 411 | elif 'resnet18' in arch: 412 | c = 512 413 | else: 414 | raise ValueError('arch not found: 
' + arch)
415 |     return c
416 | 
417 | 
418 | def main():
419 |     np.random.seed(10)
420 |     torch.manual_seed(10)
421 | 
422 |     global logger
423 | 
424 |     args = parser.parse_args()
425 |     args.save = os.path.dirname(args.weights)
426 |     match = re.search(r'\d+', os.path.basename(args.weights))
427 |     ckpt = match.group(0) if match else 'final'
428 |     dir_name = f'patch_search_iterative_search_test_images_size_{args.test_images_size}_window_w_{args.window_w}_repeat_patch_{args.repeat_patch}_prune_clusters_{args.prune_clusters}'
429 |     dir_name = f'{dir_name}_num_clusters_{args.num_clusters}'
430 |     if args.prune_clusters:
431 |         dir_name = f'{dir_name}_per_iteration_samples_{args.samples_per_iteration}_remove_{args.remove_per_iteration}'
432 |     dir_name = dir_name.replace('.', 'x')
433 |     args.save = os.path.join(args.save, dir_name)
434 |     os.makedirs(args.save, exist_ok=True)
435 |     logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
436 |     if 'HTBA_trigger' in args.weights:
437 |         args.trigger_id, args.experiment_id = re.search(r'HTBA_trigger_(\d+)_targeted_(n\d+)', args.weights).groups()
438 |     else:
439 |         # clean experiment
440 |         args.trigger_id, args.experiment_id = 1234, 'n02106550'
441 |     class_id_to_name = {}
442 |     with open('metadata_files/im100_metadata.txt', 'r') as f:
443 |         for line in f.readlines():
444 |             class_id = int(line.split()[1])
445 |             class_name = ' '.join(line.split()[2:])
446 |             class_id_to_name[class_id] = class_name
447 |             if line.startswith(args.experiment_id):
448 |                 args.target_class_id = class_id
449 |                 args.target_class_name = class_name
450 |     args.target_class_name += '__CLEAN' if 'clean' in args.weights and 'HTBA_trigger' not in args.weights else ''
451 | 
452 |     for arg in vars(args):
453 |         logger.info(f'==> {arg}: {getattr(args, arg)}')
454 | 
455 |     backbone = get_model(args.arch, args.weights)
456 | 
457 |     val_transform = transforms.Compose([
458 |         transforms.Resize(256, interpolation=3),
459 |         transforms.CenterCrop(224),
460 |         transforms.ToTensor(),
461 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
462 |     ])
463 | 
464 |     train_val_dataset = FileListDataset(args.train_file, val_transform)
465 | 
466 |     train_val_loader = DataLoader(
467 |         train_val_dataset,
468 |         shuffle=False, batch_size=args.batch_size,
469 |         num_workers=args.workers, pin_memory=True
470 |     )
471 | 
472 |     cache_file_path = os.path.join(args.save, 'cached_feats.pth')
473 |     if os.path.exists(cache_file_path) and args.use_cached_feats:
474 |         train_val_feats, train_val_labels, train_val_is_poisoned, train_val_inds = torch.load(cache_file_path)
475 |     else:
476 |         # step: get l2 normalized features and other information
477 |         train_val_feats, train_val_labels, train_val_is_poisoned, train_val_inds = get_feats(backbone, train_val_loader)
478 |         torch.save((train_val_feats, train_val_labels, train_val_is_poisoned, train_val_inds), cache_file_path)
479 |         return  # features are now cached; rerun with --use_cached_feats to continue
480 | 
481 |     num_clusters = args.num_clusters
482 |     num_classes = len(train_val_dataset.classes)
483 | 
484 |     # step: cluster the features with k-means
485 |     train_d, train_a, index, centroids = faiss_kmeans(train_val_feats, num_clusters)
486 | 
487 |     train_y = train_val_labels.numpy().reshape(-1, 1)
488 |     train_i = train_val_inds.numpy().reshape(-1, 1)
489 |     train_p = train_val_is_poisoned.numpy().reshape(-1, 1)
490 | 
491 |     model = copy.deepcopy(backbone)
492 |     model.fc = KMeansLinear(train_a[:, 0], train_val_feats, num_clusters)
493 |     model = model.cuda()
494 | 
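    # (added note) model.fc is now a KMeansLinear head: it assigns an image to
    # one of the k-means clusters by cosine similarity to the centroids, which
    # lets Grad-CAM below localize the evidence for a cluster assignment when
    # candidate patches are extracted.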
495 |     # step: create per cluster queue ordered with distance to the centroid
496 |     sorted_cluster_wise_i = []
497 |     random_cluster_wise_i = []
498 |     for cluster_id in range(num_clusters):
499 |         cur_d = train_d[train_a == cluster_id]
500 |         cur_i = train_i[train_a == cluster_id]
501 |         sorted_cluster_wise_i.append(cur_i[np.argsort(cur_d)].tolist())
502 |         random_cluster_wise_i.append(cur_i[np.random.permutation(len(cur_i))].tolist())
503 | 
504 |     # step: get test images by sampling closest samples to the centroid
505 |     test_images, test_images_i = get_test_images(train_val_dataset, sorted_cluster_wise_i, args)
506 |     test_images_a = train_a[test_images_i, 0]
507 | 
508 |     torch.cuda.empty_cache()
509 | 
510 |     # step: calculate pairwise distances between centroids
511 |     c = model.fc.classifier.detach().cpu()
512 |     c = (c / c.norm(2, dim=1, keepdim=True)).numpy()
513 |     cluster_distances = pairwise_distances(c, c)
514 | 
515 |     # backbone used for calculating features of poisoned images
516 |     # and model is used for calculating the Grad-CAM heatmap
517 |     backbone = nn.DataParallel(backbone).cuda()
518 |     backbone = backbone.eval()
519 | 
520 |     # step: initialize some variables for use in the poison detection loop
521 |     poison_scores = np.zeros(len(train_val_dataset))
522 |     candidate_clusters = list(range(num_clusters))
523 |     cur_iter = 0
524 | 
525 |     # step: only use the cache if poison scores are already saved
526 |     poison_scores_file = os.path.join(args.save, 'poison-scores.npy')
527 |     use_cached_poison_scores = args.use_cached_poison_scores and os.path.exists(poison_scores_file)
528 |     processed_count = 0
529 | 
530 |     # step: if not using cached poison scores then run the infinite loop
531 |     while not use_cached_poison_scores:
532 |         logger.info(f'==> current iteration {cur_iter}')
533 | 
534 |         # step: sample candidate images from each candidate cluster
535 |         candidate_poison_i = []
536 |         for clust_id in candidate_clusters:
537 |             clust_i = random_cluster_wise_i[clust_id]
538 |             for _ in range(min(len(clust_i), args.samples_per_iteration)):
539 |                 candidate_poison_i.append(clust_i.pop(0))
540 | 
541 |         # step: break if no candidate images found
542 |         if not len(candidate_poison_i):
543 |             break
544 | 
545 |         # step: create the data loader for candidate poison images
546 |         candidate_poison_dataset = torch.utils.data.Subset(
547 |             train_val_dataset, torch.tensor(candidate_poison_i)
548 |         )
549 |         candidate_poison_loader = DataLoader(
550 |             candidate_poison_dataset,
551 |             shuffle=False, batch_size=args.batch_size,
552 |             num_workers=args.workers, pin_memory=True
553 |         )
554 |         processed_count += len(candidate_poison_dataset)
555 | 
556 |         # step: extract candidate patches from the data loader
557 |         logger.info('==> extract patches')
558 |         candidate_patches = get_candidate_patches(model, candidate_poison_loader, args)
559 | 
560 |         # step: calculate the poison score of each candidate patch
561 |         logger.info('==> evaluate patches')
562 |         for candidate_patch, patch_idx in tqdm(zip(candidate_patches, candidate_poison_i)):
563 |             cur_scores = []
564 |             # step: there can be multiple patches sampled from a single image
565 |             for cur_patch in candidate_patch:
566 |                 with torch.no_grad():
567 |                     # step: paste candidate patch on the test images and extract features
568 |                     poisoned_test_images = paste_patch(test_images.clone(), cur_patch)
569 |                     feats_poisoned_test_images = backbone(poisoned_test_images.cuda()).cpu().numpy()
570 |                     # step: calculate flips and update the poison score
571 |                     _, poisoned_test_images_a = index.search(feats_poisoned_test_images, 1)
572 |                     new = 
np.count_nonzero(poisoned_test_images_a == train_a[patch_idx, 0]) 573 | orig = np.count_nonzero(test_images_a == train_a[patch_idx, 0]) 574 | cur_scores.append(new - orig) 575 | # step: take the max flips of all patches from an image 576 | assert poison_scores[patch_idx] == 0 577 | poison_scores[patch_idx] += max(cur_scores) 578 | 579 | # step: calculate the score for each candidate cluster 580 | logger.info(f'==> max poison score {poison_scores.argmax()} : {poison_scores.max()}') 581 | cluster_scores = [] 582 | for clust_id in candidate_clusters: 583 | cluster_scores.append((clust_id, poison_scores[train_a[:, 0] == clust_id].max())) 584 | cluster_scores = np.array(cluster_scores).astype(int) 585 | cluster_scores = cluster_scores[cluster_scores[:, 1].argsort()][::-1] 586 | 587 | # step: print a few top poisonous clusters 588 | for clust_rank, (clust_id, clust_score) in enumerate(cluster_scores.tolist()[:10]): 589 | logger.info(f'==> top poisoned clusters : rank {clust_rank:3d} cluster_id {clust_id:3d} score {clust_score}') 590 | 591 | logger.info(f'==> processed count : {processed_count:6d}/{len(train_val_dataset)} ({processed_count*100/len(train_val_dataset):.1f})') 592 | 593 | if args.prune_clusters: 594 | # step: remove a few least poisonous clusters 595 | rem = int(args.remove_per_iteration * len(candidate_clusters)) 596 | candidate_clusters = cluster_scores[:-rem, 0].tolist() 597 | 598 | cur_iter += 1 599 | 600 | # step: save the poison scores or load them from the cache 601 | if use_cached_poison_scores: 602 | poison_scores = np.load(poison_scores_file) 603 | else: 604 | np.save(poison_scores_file, poison_scores) 605 | 606 | ###################################################################################################### 607 | 608 | # step: get a few top poisonous images 609 | save_inds = poison_scores.argsort()[::-1][:100] 610 | 611 | inp, inp_titles = [], [] 612 | for i in save_inds: 613 | inp.append(train_val_dataset[i][0]) 614 | class_name = class_id_to_name[train_y[i, 0]] 615 | class_name = class_name if ',' not in class_name else class_name.split(',')[0] 616 | class_name = class_name.lower() 617 | inp_titles.append(class_name) 618 | inp = torch.stack(inp, dim=0) 619 | 620 | # step: save the top images and patches 621 | cam_images, out = run_gradcam(args.arch, model, inp) 622 | windows = extract_max_window(cam_images, inp, args.window_w) 623 | os.makedirs(os.path.join(args.save, 'all_top_poison_patches'), exist_ok=True) 624 | for i, win in enumerate(windows): 625 | win = denormalize(win) 626 | win = (win * 255).clamp(0, 255).numpy().astype(np.uint8) 627 | win = Image.fromarray(win) 628 | win.save(os.path.join(args.save, 'all_top_poison_patches', f'{i:05d}.png')) 629 | 630 | sorted_inds = poison_scores.argsort()[::-1] 631 | topks = [5, 10, 20, 50, 100, 500] 632 | accs = [train_p[sorted_inds[:k]].sum() * 100.0 / k for k in topks] 633 | logger.info('==> acc in top-k | ' + ' '.join(f'{k:7d}' for k in topks)) 634 | logger.info('==> acc in top-k | ' + ' '.join(f'{acc:7.1f}' for acc in accs)) 635 | 636 | 637 | if __name__ == '__main__': 638 | main() 639 | -------------------------------------------------------------------------------- /patch_search_poison_classifier.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | import glob 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | import 
torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.optim
16 | import torch.utils.data
17 | from torch.utils.data import DataLoader, Dataset
18 | import torchvision.transforms as transforms
19 | import torchvision.datasets as datasets
20 | import torchvision.models as models
21 | from torchvision.models.resnet import BasicBlock, ResNet
22 | import torch.nn.functional as F
23 | 
24 | from eval_utils import AverageMeter, ProgressMeter, model_names, accuracy, get_logger, save_checkpoint
25 | from PIL import Image
26 | import numpy as np
27 | import matplotlib.pyplot as plt
28 | from sklearn.metrics import precision_recall_curve
29 | 
30 | parser = argparse.ArgumentParser(description='PatchSearch poison classifier training')
31 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
32 |                     help='number of data loading workers (default: 8)')
33 | parser.add_argument('-a', '--arch', default='resnet18',
34 |                     help='model architecture: ' +
35 |                     ' | '.join(model_names) +
36 |                     ' (default: resnet18)')
37 | parser.add_argument('--max_iterations', default=2500, type=int, metavar='N',
38 |                     help='maximum number of iterations to run')
39 | parser.add_argument('--start_iteration', default=0, type=int, metavar='N',
40 |                     help='manual iteration number (useful on restarts)')
41 | parser.add_argument('-b', '--batch_size', default=256, type=int,
42 |                     metavar='N',
43 |                     help='mini-batch size (default: 256), this is the total '
44 |                          'batch size of all GPUs on the current node when '
45 |                          'using Data Parallel or Distributed Data Parallel')
46 | parser.add_argument('--lr', '--learning_rate', default=0.01, type=float,
47 |                     metavar='LR', help='initial learning rate', dest='lr')
48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
49 |                     help='momentum')
50 | parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
51 |                     metavar='W', help='weight decay (default: 1e-4)',
52 |                     dest='weight_decay')
53 | parser.add_argument('-p', '--print_freq', default=10, type=int,
54 |                     metavar='N', help='print frequency (default: 10)')
55 | parser.add_argument('--eval_freq', default=50, type=int,
56 |                     help='eval frequency (default: 50)')
57 | parser.add_argument('--eval_repeat', default=1, type=int,
58 |                     help='number of times to repeat the evaluation')
59 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
60 |                     help='path to latest checkpoint (default: none)')
61 | parser.add_argument('--seed', default=0, type=int,
62 |                     help='seed for initializing training.')
63 | parser.add_argument('--save', default='./output/', type=str,
64 |                     help='experiment output directory')
65 | parser.add_argument('--lr_schedule', type=str, default='1,2,3,4',
66 |                     help='lr drop schedule')
67 | parser.add_argument('--train_file', type=str, required=False,
68 |                     help='file containing training image paths')
69 | parser.add_argument('--val_file', type=str,
70 |                     help='file containing validation image paths')
71 | parser.add_argument('--eval_data', type=str, default="",
72 |                     help='eval identifier')
73 | parser.add_argument('--poison_dir', type=str, default="",
74 |                     help='directory containing poisons')
75 | parser.add_argument('--poison_scores', type=str, default="",
76 |                     help='path to poison scores')
77 | parser.add_argument('--topk_poisons', type=int, default=10,
78 |                     help='how many top poisons to use for training the classifier')
79 | parser.add_argument('--top_p', type=float, default=0.10,
80 |                     help='bottom percentage of data to use for training')
81 | parser.add_argument('--model_count', type=int, default=3,
82 |                     help='how many models to use for ensembling')
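The flags above drive the training-set construction below: positive samples get one of the top `--topk_poisons` patches pasted onto them, negatives stay clean, and `--model_count` classifiers are ensembled. A minimal NumPy sketch of how `--poison_scores`, `--topk_poisons`, and `--top_p` would pick candidate indices (the file name is illustrative, and the negative-side selection is an assumption; the script's own selection logic follows below):

```
import numpy as np

poison_scores = np.load('poison-scores.npy')   # produced by the iterative search
sorted_inds = (-poison_scores).argsort()       # most poisonous first
pos_inds = sorted_inds[:10]                    # --topk_poisons highest scorers
num_train = int(0.10 * len(sorted_inds))       # --top_p fraction of the data
neg_inds = sorted_inds[-num_train:]            # assumed: least poisonous tail
```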
85 | def denormalize(x):
86 |     if x.shape[0] == 3:
87 |         x = x.permute((1, 2, 0))
88 |     mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
89 |     std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
90 |     x = ((x * std) + mean)
91 |     return x
92 | 
93 | 
94 | class FileListDataset(Dataset):
95 |     def __init__(self, path_to_txt_file, pre_transform, post_transform, poison_dir, topk_poisons, output_type='clean'):
96 |         # self.data_root = data_root
97 |         self.output_type = output_type
98 | 
99 |         self.poisons = []
100 |         for poison_file in sorted(glob.glob(f'{poison_dir}/*.png')):
101 |             self.poisons.append(Image.open(poison_file).convert('RGB'))
102 |         self.poisons = self.poisons[:topk_poisons]
103 | 
104 |         with open(path_to_txt_file, 'r') as f:
105 |             self.file_list = f.readlines()
106 |             self.file_list = [row.rstrip() for row in self.file_list]
107 | 
108 |         self.pre_transform = pre_transform
109 |         self.post_transform = post_transform
110 | 
111 | 
112 |     def paste_poison(self, img):
113 |         margin = 10
114 |         image_size = 224
115 |         poison_size_low, poison_size_high = 20, 80
116 |         poison = self.poisons[np.random.randint(low=0, high=len(self.poisons))]
117 |         # poison = self.poisons[0]
118 |         new_s = np.random.randint(low=poison_size_low, high=poison_size_high)
119 |         poison = poison.resize((new_s, new_s))
120 |         loc_box = (margin, image_size - (new_s + margin))
121 |         loc_h, loc_w = np.random.randint(*loc_box), np.random.randint(*loc_box)
122 |         img.paste(poison, (loc_h, loc_w))
123 |         return img
124 | 
125 | 
126 |     def __getitem__(self, idx):
127 |         image_path = self.file_list[idx].split()[0]
128 |         is_poisoned = 'HTBA' in image_path
129 |         img = Image.open(image_path).convert('RGB')
130 |         is_poison = np.random.rand() > 0.5
131 | 
132 |         if self.output_type == 'clean' or (self.output_type == 'rand' and not is_poison):
133 |             target = 0
134 |             img = self.pre_transform(img)
135 |             img = self.post_transform(img)
136 |         elif self.output_type == 'poisoned' or (self.output_type == 'rand' and is_poison):
137 |             target = 1
138 |             img = self.pre_transform(img)
139 |             img = self.paste_poison(img)
140 |             img = self.post_transform(img)
141 |         else:
142 |             raise ValueError(f'unexpected output_type: {self.output_type}')
143 | 
144 |         return image_path, img, target, is_poisoned, idx
145 | 
146 |     def __len__(self):
147 |         return len(self.file_list)
148 | 
149 | 
150 | 
151 | class ValFileListDataset(Dataset):
152 |     def __init__(self, path_to_txt_file, pos_inds,
neg_inds, transform): 153 | with open(path_to_txt_file, 'r') as f: 154 | file_list = f.readlines() 155 | file_list = [row.strip().split() for row in file_list] 156 | 157 | pos_samples = [(file_list[i][0], 1) for i in pos_inds] 158 | neg_samples = [(file_list[i][0], 0) for i in neg_inds] 159 | self.samples = pos_samples + neg_samples 160 | self.transform = transform 161 | 162 | 163 | def __getitem__(self, idx): 164 | image_path, target = self.samples[idx] 165 | is_poisoned = 'HTBA' in image_path 166 | img = Image.open(image_path).convert('RGB') 167 | img = self.transform(img) 168 | 169 | return image_path, img, target, is_poisoned, idx 170 | 171 | def __len__(self): 172 | return len(self.samples) 173 | 174 | 175 | def denormalize(x): 176 | if x.shape[0] == 3: 177 | x = x.permute((1, 2, 0)) 178 | mean = torch.tensor([0.485, 0.456, 0.406], device=x.device) 179 | std = torch.tensor([0.229, 0.224, 0.225], device=x.device) 180 | x = ((x * std) + mean) 181 | return x 182 | 183 | 184 | def show_images_without_cam(inp, save, title): 185 | inp = inp[:40] 186 | fig, axes = plt.subplots(nrows=8, ncols=5, figsize=(10, 15)) 187 | for img_idx in range(inp.shape[0]): 188 | rgb_image = denormalize(inp[img_idx]).detach().cpu().numpy() 189 | axes[img_idx//5][img_idx%5].imshow(rgb_image) 190 | axes[img_idx//5][img_idx%5].set_xticks([]) 191 | axes[img_idx//5][img_idx%5].set_yticks([]) 192 | plt.tight_layout() 193 | fig.savefig(os.path.join(save, title.lower().replace(' ', '-') + '.png')) 194 | 195 | 196 | class Subset(Dataset): 197 | def __init__(self, dataset, indices): 198 | self.dataset = dataset 199 | self.indices = indices 200 | 201 | def __getitem__(self, idx): 202 | output = self.dataset[self.indices[idx]] 203 | output = (*output[:-1], idx) 204 | return output 205 | 206 | def __len__(self): 207 | return len(self.indices) 208 | 209 | 210 | def worker_init_fn(baseline_seed, it, worker_id): 211 | np.random.seed(baseline_seed + it + worker_id) 212 | 213 | 214 | def get_loaders(args): 215 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 216 | std=[0.229, 0.224, 0.225]) 217 | train_t1 = transforms.Compose([ 218 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 219 | ]) 220 | train_t2 = transforms.Compose([ 221 | transforms.RandomApply([ 222 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 223 | ], p=0.8), 224 | transforms.RandomGrayscale(p=0.2), 225 | transforms.RandomHorizontalFlip(), 226 | transforms.ToTensor(), 227 | normalize, 228 | ]) 229 | train_dataset = FileListDataset( 230 | args.train_file, 231 | pre_transform=train_t1, post_transform=train_t2, 232 | poison_dir=args.poison_dir, 233 | topk_poisons=args.topk_poisons, 234 | output_type='rand', 235 | ) 236 | 237 | inds = np.random.randint(low=0, high=len(train_dataset), size=40) 238 | train_dataset.output_type = 'clean' 239 | clean_images = torch.stack([train_dataset[i][1] for i in inds]) 240 | show_images_without_cam(clean_images, args.save, 'Train Clean Images') 241 | train_dataset.output_type = 'poisoned' 242 | poisoned_images = torch.stack([train_dataset[i][1] for i in inds]) 243 | show_images_without_cam(poisoned_images, args.save, 'Train Poisoned Images') 244 | train_dataset.output_type = 'rand' 245 | rand_images = torch.stack([train_dataset[i][1] for i in inds]) 246 | show_images_without_cam(rand_images, args.save, 'Train Rand Images') 247 | 248 | poison_scores = np.load(args.poison_scores) 249 | sorted_inds = (-poison_scores).argsort() 250 | pos_inds = sorted_inds[:args.topk_poisons] 251 | neg_inds = 
sorted_inds[-args.topk_poisons:]
252 |     train_inds = sorted_inds[int(args.top_p*len(train_dataset)):-args.topk_poisons]
253 |     logger.info(f'==> train dataset size {len(train_inds)/1000:.1f}k')
254 | 
255 |     # step: take subset of the train dataset
256 |     train_dataset.output_type = 'rand'
257 |     train_dataset = Subset(train_dataset, train_inds)
258 |     train_loader = DataLoader(
259 |         train_dataset,
260 |         batch_size=args.batch_size, shuffle=True,
261 |         num_workers=args.workers, pin_memory=True,
262 |         worker_init_fn=partial(worker_init_fn, args.seed, 0)
263 |     )
264 | 
265 |     # step: the validation dataset
266 |     val_t1 = transforms.Compose([
267 |         transforms.Resize(256),
268 |         transforms.CenterCrop(224),
269 |     ])
270 |     val_t2 = transforms.Compose([
271 |         transforms.ToTensor(),
272 |         normalize,
273 |     ])
274 |     val_dataset = ValFileListDataset(
275 |         args.train_file,
276 |         pos_inds=pos_inds,
277 |         neg_inds=neg_inds,
278 |         transform=transforms.Compose([val_t1, val_t2]),
279 |     )
280 |     val_images = torch.stack([val_dataset[i][1] for i in range(len(val_dataset))])
281 |     show_images_without_cam(val_images, args.save, 'Val Images')
282 |     val_loader = DataLoader(
283 |         val_dataset,
284 |         batch_size=args.batch_size, shuffle=False,
285 |         num_workers=args.workers, pin_memory=True,
286 |         worker_init_fn=partial(worker_init_fn, args.seed, 0)
287 |     )
288 | 
289 |     # step: create the test dataset
290 |     test_dataset = FileListDataset(
291 |         args.train_file,
292 |         pre_transform=val_t1, post_transform=val_t2,
293 |         poison_dir=args.poison_dir,
294 |         topk_poisons=args.topk_poisons,
295 |         output_type='clean'
296 |     )
297 |     inds = np.random.randint(low=0, high=len(test_dataset), size=40)
298 |     test_images = torch.stack([test_dataset[i][1] for i in inds])
299 |     show_images_without_cam(test_images, args.save, 'Test Images')
300 | 
301 |     test_loader = DataLoader(
302 |         test_dataset,
303 |         batch_size=args.batch_size, shuffle=False,
304 |         num_workers=args.workers, pin_memory=True,
305 |     )
306 | 
307 |     assert train_dataset.dataset.output_type == 'rand'
308 |     assert test_dataset.output_type == 'clean'
309 | 
310 |     return train_loader, val_loader, test_loader
311 | 
312 | 
313 | class EnsembleNet(nn.Module):
314 |     def __init__(self, models):
315 |         super(EnsembleNet, self).__init__()
316 |         self.models = nn.ModuleList(models)
317 | 
318 |     def forward(self, x):
319 |         y = torch.stack([model(x) for model in self.models], dim=0)
320 |         y = torch.einsum('kbd->bkd', y)
321 |         y = y.mean(dim=1)
322 |         return y
323 | 
324 | 
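# --- Editor's note (added; not part of the original file) ---
# EnsembleNet above averages the logits of the independently trained models,
# so the poison/clean decision in test() becomes an ensemble vote. A minimal
# usage sketch with hypothetical stand-in models (shapes only):
#
#     models = [nn.Linear(8, 2) for _ in range(3)]  # stand-ins for the ResNets
#     ensemble = EnsembleNet(models)
#     logits = ensemble(torch.randn(4, 8))          # shape (4, 2): mean of the 3 outputs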
325 | def main():
326 |     global logger
327 | 
328 |     args = parser.parse_args()
329 |     args.save = os.path.join(os.path.dirname(args.poison_dir), f'patch_search_poison_classifier_topk_{args.topk_poisons}_ensemble_{args.model_count}_max_iterations_{args.max_iterations}_{args.eval_data}')
330 |     os.makedirs(args.save, exist_ok=True)
331 |     logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
332 | 
333 |     for arg in vars(args):
334 |         logger.info(f'==> {arg}: {getattr(args, arg)}')
335 | 
336 |     if args.seed is not None:
337 |         np.random.seed(args.seed)
338 |         random.seed(args.seed)
339 |         torch.manual_seed(args.seed)
340 |         cudnn.deterministic = True
341 |         warnings.warn('You have chosen to seed training. '
342 |                       'This will turn on the CUDNN deterministic setting, '
343 |                       'which can slow down your training considerably! '
344 |                       'You may see unexpected behavior when restarting '
345 |                       'from checkpoints.')
346 | 
347 |     train_loader, val_loader, test_loader = get_loaders(args)
348 | 
349 |     models = []
350 |     for model_i in range(args.model_count):
351 |         logger.info('='*40 + f' model_i {model_i} ' + '='*40)
352 |         train_loader.worker_init_fn = partial(worker_init_fn, args.seed, model_i)
353 |         val_loader.worker_init_fn = partial(worker_init_fn, args.seed, model_i)
354 | 
355 |         model = ResNet(block=BasicBlock, layers=[1, 1, 1, 1])
356 |         model.fc = nn.Linear(512, 2)
357 |         model = model.cuda()
358 | 
359 |         optimizer = torch.optim.SGD(model.parameters(),
360 |                                     args.lr,
361 |                                     momentum=args.momentum,
362 |                                     weight_decay=args.weight_decay)
363 | 
364 |         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_iterations)
365 | 
366 |         batch_time = AverageMeter('Time', ':6.3f')
367 |         data_time = AverageMeter('Data', ':6.3f')
368 |         lrs = AverageMeter('LR', ':6.3f')
369 |         losses = AverageMeter('Loss', ':.4e')
370 |         top1 = AverageMeter('Acc@1', ':6.2f')
371 |         progress = ProgressMeter(
372 |             args.max_iterations,
373 |             [batch_time, data_time, lrs, losses, top1],
374 |             prefix="Train: ")
375 | 
376 |         model.train()
377 | 
378 |         it = args.start_iteration
379 |         val_metrics = []
380 |         while it < args.max_iterations:
381 |             end = time.time()
382 |             for _, images, target, is_poisoned, inds in train_loader:
383 |                 if it >= args.max_iterations:  # >= so exactly max_iterations steps run
384 |                     break
385 |                 if it < 5:
386 |                     show_images_without_cam(images, args.save, f'train-images-iteration-{it:05d}')
387 | 
388 |                 # measure data loading time
389 |                 data_time.update(time.time() - end)
390 | 
391 |                 images = images.cuda(non_blocking=True)
392 |                 target = target.cuda(non_blocking=True)
393 | 
394 |                 output = model(images)
395 | 
396 |                 loss = F.cross_entropy(output, target)
397 |                 losses.update(loss.item(), images.size(0))
398 | 
399 |                 acc1 = accuracy(output, target, topk=(1,))[0]
400 |                 top1.update(acc1[0], images.size(0))
401 | 
402 |                 lrs.update(lr_scheduler.get_last_lr()[-1])
403 | 
404 |                 # compute gradient and do SGD step
405 |                 optimizer.zero_grad()
406 |                 loss.backward()
407 |                 optimizer.step()
408 | 
409 |                 # measure elapsed time
410 |                 batch_time.update(time.time() - end)
411 |                 end = time.time()
412 | 
413 |                 if it % args.print_freq == 0:
414 |                     logger.info(progress.display(it))
415 | 
416 |                 if it % args.eval_freq == 0:
417 |                     recall, precision, f1_beta = validate(val_loader, model, args)
418 |                     val_metrics.append((recall, precision, f1_beta))
419 |                     vm = set([int(x[2]*100) for x in val_metrics[-10:]])
420 |                     logger.info(vm)
421 |                     if len(vm) == 1 and len(val_metrics) > 10:
422 |                         it = args.max_iterations
423 |                         break
424 |                     model.train()
425 | 
426 |                 # modify lr
427 |                 lr_scheduler.step()
428 | 
429 |                 it += 1
430 |         models.append(model)
431 | 
432 |     logger.info(f'==> run inference on test data')
433 |     model = EnsembleNet(models)
434 |     recall, precision, preds = test(test_loader, model, args)
435 |     with open(os.path.join(args.save, 'filtered.txt'), 'w') as f:
436 |         for line, is_poisoned in zip(test_loader.dataset.file_list, preds):
437 |             if not is_poisoned:
438 |                 f.write(f'{line}\n')
439 | 
440 | 
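# --- Editor's note (added; not part of the original file) ---
# validate() below treats poison detection as a binary retrieval problem:
# recall = flagged true poisons / all true poisons, and
# precision = true poisons among the flagged / all flagged samples.
# With beta = 1 the F_beta score computed here reduces to the harmonic mean
# F1 = 2 * precision * recall / (precision + recall); for example,
# precision = 0.9 and recall = 0.6 give F1 = 1.08 / 1.5 = 0.72.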
441 | def validate(val_loader, model, args):
442 |     batch_time = AverageMeter('Time', ':6.3f')
443 |     data_time = AverageMeter('Data', ':6.3f')
444 |     progress = ProgressMeter(
445 |         len(val_loader),
446 |         [batch_time, data_time],
447 |         prefix='Evaluate')
448 | 
449 |     # switch to eval mode
450 |     model.eval()
451 | 
452 |     pred_is_poison = np.zeros(len(val_loader.dataset))
453 |     gt_is_poison = np.zeros(len(val_loader.dataset))
454 | 
455 |     end = time.time()
456 |     for i, (_, images, target, _, inds) in enumerate(val_loader):
457 |         if i == 0:
458 |             show_images_without_cam(images, args.save, f'eval-images-iteration-0')
459 | 
460 |         # measure data loading time
461 |         data_time.update(time.time() - end)
462 | 
463 |         images = images.cuda(non_blocking=True)
464 | 
465 |         # compute output
466 |         with torch.no_grad():
467 |             output = model(images)
468 | 
469 |         pred = output.argmax(dim=1).detach().cpu()
470 |         pred_is_poison[inds.numpy()] = pred.numpy()
471 |         gt_is_poison[inds.numpy()] = target.numpy().astype(int)
472 | 
473 |         # measure elapsed time
474 |         batch_time.update(time.time() - end)
475 |         end = time.time()
476 | 
477 |         if i % args.print_freq == 0:
478 |             logger.info(progress.display(i))
479 | 
480 |     recall = pred_is_poison[np.where(gt_is_poison)[0]].astype(float).mean()
481 |     logger.info(f'==> poison recall : {recall*100:.1f}')
482 | 
483 |     precision = gt_is_poison[np.where(pred_is_poison)[0]].astype(float).mean()
484 |     logger.info(f'==> poison precision : {precision*100:.1f}')
485 | 
486 |     beta = 1
487 |     f1_beta = (1 + beta**2) * (precision * recall) / ((beta**2) * precision + recall)
488 |     logger.info(f'==> poison F1_beta score (beta = {beta}) : {f1_beta*100:.1f}')
489 | 
490 |     if math.isnan(recall) or math.isnan(precision) or math.isnan(f1_beta):
491 |         return 0., 0., 0.
492 | 
493 |     return recall, precision, f1_beta
494 | 
495 | 
496 | def test(test_loader, model, args):
497 |     batch_time = AverageMeter('Time', ':6.3f')
498 |     data_time = AverageMeter('Data', ':6.3f')
499 |     progress = ProgressMeter(
500 |         len(test_loader),
501 |         [batch_time, data_time],
502 |         prefix='Test')
503 | 
504 |     # switch to eval mode
505 |     model.eval()
506 | 
507 |     pred_is_poison = np.zeros(len(test_loader.dataset))
508 |     gt_is_poison = np.zeros(len(test_loader.dataset))
509 | 
510 |     end = time.time()
511 |     for i, (_, images, _, is_poisoned, inds) in enumerate(test_loader):
512 |         # measure data loading time
513 |         data_time.update(time.time() - end)
514 | 
515 |         images = images.cuda(non_blocking=True)
516 | 
517 |         # compute output
518 |         with torch.no_grad():
519 |             output = model(images)
520 | 
521 |         pred = output.argmax(dim=1).detach().cpu()
522 |         pred_is_poison[inds.numpy()] = pred.numpy()
523 |         gt_is_poison[inds.numpy()] = is_poisoned.numpy().astype(int)
524 | 
525 |         # measure elapsed time
526 |         batch_time.update(time.time() - end)
527 |         end = time.time()
528 | 
529 |         if i % (len(test_loader) // 20) == 0:
530 |             logger.info(progress.display(i))
531 | 
532 |     logger.info(f'==> total poisons to remove : {np.count_nonzero(pred_is_poison)}')
533 |     poison_recall = pred_is_poison[np.where(gt_is_poison)[0]].astype(float).mean()
534 |     logger.info(f'==> poison recall : {poison_recall*100:.1f}')
535 |     poison_precision = gt_is_poison[np.where(pred_is_poison)[0]].astype(float).mean()
536 |     logger.info(f'==> poison precision : {poison_precision*100:.1f}')
537 | 
538 |     return poison_recall, poison_precision, pred_is_poison
539 | 
540 | 
541 | if __name__ == '__main__':
542 |     main()
543 | 
--------------------------------------------------------------------------------
/run_pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -x
4 | set -e
5 | 
6 | OUTPUT_DIR='/home/ajinkya/work/PatchSearch/output'
7 | CODE_DIR='/home/ajinkya/work/ssl_backdoor_from_meta/code/SSL-Backdoor/poison-generation/data'
8 | EXPERIMENT_ID='HTBA_trigger_12_targeted_n02701002'
9 | EXP_DIR=$OUTPUT_DIR/$EXPERIMENT_ID/moco
10 | EVAL_DIR=$EXP_DIR/linear
11 | DEFENSE_DIR=$EXP_DIR/patch_search_iterative_search_test_images_size_1000_window_w_60_repeat_patch_1_prune_clusters_True_num_clusters_1000_per_iteration_samples_2_remove_0x25
12 | FILTERED_DIR=$DEFENSE_DIR/patch_search_poison_classifier_topk_20_ensemble_5_max_iterations_2000_seed_4789
13 | RATE='0.50'
14 | SEED=4789
15 | 
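# --- Editor's note (added; not part of the original script) ---
# DEFENSE_DIR and FILTERED_DIR mirror the output directory names that
# patch_search_iterative_search.py and patch_search_poison_classifier.py
# build from their own hyperparameters, so the values encoded here
# (e.g. topk 20, ensemble 5, max_iterations 2000, seed 4789) must stay
# in sync with the flags passed in STEP 2 and STEP 3 below.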
16 | # ### STEP 1.1: pretrain the model
17 | # python main_moco_files_dataset_strong_aug.py \
18 | #     --seed $SEED \
19 | #     -a vit_base --epochs 200 -b 1024 \
20 | #     --stop-grad-conv1 --moco-m-cos \
21 | #     --multiprocessing-distributed --world-size 1 --rank 0 \
22 | #     --dist-url "tcp://localhost:$(( $RANDOM % 50 + 10000 ))" \
23 | #     --save_folder $EXP_DIR \
24 | #     $CODE_DIR/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt
25 | 
26 | ### STEP 1.2: train a linear layer for evaluating the pretrained model
27 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
28 |     --seed $SEED \
29 |     -a vit_base --lr 0.1 \
30 |     --pretrained $EXP_DIR/checkpoint_0199.pth.tar \
31 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
32 |     --val_file linear_eval_files/val_ssl_filelist.txt \
33 |     --save_folder $EVAL_DIR
34 | 
35 | ### STEP 1.3: evaluate the trained linear layer with clean and poisoned val data
36 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
37 |     --seed $SEED \
38 |     -a vit_base --lr 0.1 \
39 |     --conf_matrix \
40 |     --resume $EVAL_DIR/checkpoint.pth.tar \
41 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
42 |     --val_file linear_eval_files/val_ssl_filelist.txt \
43 |     --val_poisoned_file $CODE_DIR/$EXPERIMENT_ID/val_poisoned/loc_random_*.txt \
44 |     --eval_id exp_${SEED}
45 | 
46 | ### STEP 2: run iterative search
47 | for i in {1..2}
48 | do
49 |     ### STEP 2.1: on the first pass, compute and cache the features, then exit
50 |     ### STEP 2.2: on the second pass, run the defense on the cached features
51 |     ### these are kept as two separate passes since combining them slows down the defense
52 |     python patch_search_iterative_search.py \
53 |         --arch moco_vit_base \
54 |         --batch-size 64 \
55 |         --weights $EXP_DIR/checkpoint_0199.pth.tar \
56 |         --train_file $CODE_DIR/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt \
57 |         --val_file linear_eval_files/val_ssl_filelist.txt \
58 |         --prune_clusters \
59 |         --use_cached_feats \
60 |         --use_cached_poison_scores
61 | done
62 | 
63 | ### STEP 3: run poison classifier
64 | CUDA_VISIBLE_DEVICES=0 python patch_search_poison_classifier.py \
65 |     --print_freq 20 \
66 |     --model_count 5 \
67 |     --batch_size 32 \
68 |     --eval_freq 20 \
69 |     --max_iterations 2000 \
70 |     --workers 8 \
71 |     --seed ${SEED} \
72 |     --train_file $CODE_DIR/$EXPERIMENT_ID/train/loc_random_loc*_rate_${RATE}_targeted_True_*.txt \
73 |     --poison_dir $DEFENSE_DIR/all_top_poison_patches \
74 |     --poison_scores $DEFENSE_DIR/poison-scores.npy \
75 |     --eval_data "seed_${SEED}" \
76 |     --topk_poisons 20
77 | 
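# --- Editor's note (added; not part of the original script) ---
# STEP 3 writes $FILTERED_DIR/filtered.txt, i.e. the training file list with
# the samples flagged as poisonous removed; STEPs 4.1-4.3 below repeat the
# pre-training and linear evaluation of STEPs 1.1-1.3 on that filtered list.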
78 | EXP_DIR=$FILTERED_DIR/moco
79 | EVAL_DIR=$EXP_DIR/linear
80 | 
81 | ### STEP 4.1: pretrain the model on training set filtered with PatchSearch
82 | python main_moco_files_dataset_strong_aug.py \
83 |     --seed $SEED \
84 |     -a vit_base --epochs 200 -b 1024 \
85 |     --icutmix --alpha 1.0 \
86 |     --stop-grad-conv1 --moco-m-cos \
87 |     --multiprocessing-distributed --world-size 1 --rank 0 \
88 |     --dist-url "tcp://localhost:$(( $RANDOM % 50 + 10000 ))" \
89 |     --save_folder $EXP_DIR \
90 |     $FILTERED_DIR/filtered.txt
91 | 
92 | ### STEP 4.2: train a linear layer for evaluating the pretrained model
93 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
94 |     --seed $SEED \
95 |     -a vit_base --lr 0.1 \
96 |     --pretrained $EXP_DIR/checkpoint_0199.pth.tar \
97 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
98 |     --val_file linear_eval_files/val_ssl_filelist.txt \
99 |     --save_folder $EVAL_DIR
100 | 
101 | ### STEP 4.3: evaluate the trained linear layer with clean and poisoned val data
102 | CUDA_VISIBLE_DEVICES=0 python main_lincls.py \
103 |     --seed $SEED \
104 |     -a vit_base --lr 0.1 \
105 |     --conf_matrix \
106 |     --resume $EVAL_DIR/checkpoint.pth.tar \
107 |     --train_file linear_eval_files/train_ssl_0.01_filelist.txt \
108 |     --val_file linear_eval_files/val_ssl_filelist.txt \
109 |     --val_poisoned_file $CODE_DIR/$EXPERIMENT_ID/val_poisoned/loc_random_*.txt \
110 |     --eval_id exp_${SEED}
111 | 
112 | 
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | 
3 | import logging
4 | import os
5 | 
6 | import torch
7 | from torch import nn
8 | from torchvision import models
9 | 
10 | 
11 | def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
12 |     logger = logging.getLogger()
13 |     if debug:
14 |         level = logging.DEBUG
15 |     else:
16 |         level = logging.INFO
17 |     logger.setLevel(level)
18 |     if saving:
19 |         info_file_handler = logging.FileHandler(logpath, mode="a")
20 |         info_file_handler.setLevel(level)
21 |         logger.addHandler(info_file_handler)
22 |     if displaying:
23 |         console_handler = logging.StreamHandler()
24 |         console_handler.setLevel(level)
25 |         logger.addHandler(console_handler)
26 |     logger.info(filepath)
27 |     with open(filepath, "r") as f:
28 |         logger.info(f.read())
29 | 
30 |     for f in package_files:
31 |         logger.info(f)
32 |         with open(f, "r") as package_f:
33 |             logger.info(package_f.read())
34 | 
35 |     return logger
36 | 
37 | 
38 | def makedirs(dirname):
39 |     if not os.path.exists(dirname):
40 |         os.makedirs(dirname)
41 | 
42 | 
43 | def save_each_checkpoint(state, epoch, save_dir):
44 |     ckpt_path = os.path.join(save_dir, 'ckpt_%d.pth.tar' % epoch)
45 |     torch.save(state, ckpt_path)
46 | 
47 | 
48 | def save_checkpoint(state, is_best, save_dir):
49 |     ckpt_path = os.path.join(save_dir, 'checkpoint.pth.tar')
50 |     torch.save(state, ckpt_path)
51 |     if is_best:
52 |         best_ckpt_path = os.path.join(save_dir, 'model_best.pth.tar')
53 |         shutil.copyfile(ckpt_path, best_ckpt_path)
54 | 
55 | 
56 | class AverageMeter(object):
57 |     """Computes and stores the average and current value"""
58 |     def __init__(self, name, fmt=':f'):
59 |         self.name = name
60 |         self.fmt = fmt
61 |         self.reset()
62 | 
63 |     def reset(self):
64 |         self.val = 0
65 |         self.avg = 0
66 |         self.sum = 0
67 |         self.count = 0
68 | 
69 |     def update(self, val, n=1):
70 |         self.val = val
71 |         self.sum += val * n
72 |         self.count += n
73 |         self.avg = self.sum / self.count
74 | 
75 |     def __str__(self):
76 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
77 |         return fmtstr.format(**self.__dict__)
78 | 
79 | 
80 | class ProgressMeter(object):
81 |     def __init__(self, num_batches, meters, prefix=""):
82 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
83 |         self.meters = meters
84 |         self.prefix = prefix
85 | 
86 |     def display(self, batch):
87 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
88 |         entries += [str(meter) for meter in self.meters]
89 |         return '\t'.join(entries)
90 | 
91 |     def
_get_batch_fmtstr(self, num_batches): 92 | num_digits = len(str(num_batches // 1)) 93 | fmt = '{:' + str(num_digits) + 'd}' 94 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 95 | 96 | import pdb 97 | 98 | def accuracy(output, target, topk=(1,)): 99 | """Computes the accuracy over the k top predictions for the specified values of k""" 100 | with torch.no_grad(): 101 | maxk = max(topk) 102 | batch_size = target.size(0) 103 | 104 | _, pred = output.topk(maxk, 1, True, True) 105 | pred = pred.t() 106 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 107 | # pdb.set_trace() 108 | res = [] 109 | for k in topk: 110 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 111 | res.append(correct_k.mul_(100.0 / batch_size)) 112 | return res 113 | 114 | 115 | arch_to_key = { 116 | 'alexnet': 'alexnet', 117 | 'alexnet_moco': 'alexnet', 118 | 'resnet18': 'resnet18', 119 | 'resnet50': 'resnet50', 120 | 'rotnet_r50': 'resnet50', 121 | 'rotnet_r18': 'resnet18', 122 | 'resnet18_moco': 'resnet18', 123 | 'resnet_moco': 'resnet50', 124 | } 125 | 126 | model_names = list(arch_to_key.keys()) 127 | 128 | 129 | def remove_dropout(model): 130 | classif = model.classifier.children() 131 | classif = [nn.Sequential() if isinstance(m, nn.Dropout) else m for m in classif] 132 | model.classifier = nn.Sequential(*classif) 133 | 134 | 135 | # 1. stores a list of models to ensemble 136 | # 2. forward through each model and save the output 137 | # 3. return mean of the outputs along the class dimension 138 | class EnsembleNet(nn.ModuleList): 139 | def forward(self, x): 140 | out = [m(x) for m in self] 141 | out = torch.stack(out, dim=-1) 142 | out = out.mean(dim=-1) 143 | return out 144 | -------------------------------------------------------------------------------- /vits.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial, reduce 5 | from operator import mul 6 | 7 | from timm.models.vision_transformer import VisionTransformer, _cfg 8 | from timm.models.layers.helpers import to_2tuple 9 | from timm.models.layers import PatchEmbed 10 | 11 | __all__ = [ 12 | 'vit_small', 13 | 'vit_base', 14 | 'vit_conv_small', 15 | 'vit_conv_base', 16 | ] 17 | 18 | 19 | class VisionTransformerMoCo(VisionTransformer): 20 | def __init__(self, stop_grad_conv1=False, **kwargs): 21 | super().__init__(**kwargs) 22 | # Use fixed 2D sin-cos position embedding 23 | self.build_2d_sincos_position_embedding() 24 | 25 | # weight initialization 26 | for name, m in self.named_modules(): 27 | if isinstance(m, nn.Linear): 28 | if 'qkv' in name: 29 | # treat the weights of Q, K, V separately 30 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 31 | nn.init.uniform_(m.weight, -val, val) 32 | else: 33 | nn.init.xavier_uniform_(m.weight) 34 | nn.init.zeros_(m.bias) 35 | nn.init.normal_(self.cls_token, std=1e-6) 36 | 37 | if isinstance(self.patch_embed, PatchEmbed): 38 | # xavier_uniform initialization 39 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 40 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 41 | nn.init.zeros_(self.patch_embed.proj.bias) 42 | 43 | if stop_grad_conv1: 44 | self.patch_embed.proj.weight.requires_grad = False 45 | self.patch_embed.proj.bias.requires_grad = False 46 | 47 | def build_2d_sincos_position_embedding(self, temperature=10000.): 48 | h, w = self.patch_embed.grid_size 49 | grid_w = torch.arange(w, dtype=torch.float32) 50 | grid_h = torch.arange(h, dtype=torch.float32) 51 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 52 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 53 | pos_dim = self.embed_dim // 4 54 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 55 | omega = 1. / (temperature**omega) 56 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 57 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 58 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 59 | 60 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 61 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 62 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 63 | self.pos_embed.requires_grad = False 64 | 65 | 66 | class ConvStem(nn.Module): 67 | """ 68 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 69 | """ 70 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 71 | super().__init__() 72 | 73 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 74 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 75 | 76 | img_size = to_2tuple(img_size) 77 | patch_size = to_2tuple(patch_size) 78 | self.img_size = img_size 79 | self.patch_size = patch_size 80 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 81 | self.num_patches = self.grid_size[0] * self.grid_size[1] 82 | self.flatten = flatten 83 | 84 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 85 | stem = [] 86 | input_dim, output_dim = 3, embed_dim // 8 87 | for l in range(4): 88 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 89 | stem.append(nn.BatchNorm2d(output_dim)) 90 | stem.append(nn.ReLU(inplace=True)) 91 | input_dim = output_dim 92 | output_dim *= 2 93 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 94 | self.proj = nn.Sequential(*stem) 95 | 96 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 97 | 98 | def forward(self, x): 99 | B, C, H, W = x.shape 100 | assert H == self.img_size[0] and W == self.img_size[1], \ 101 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
102 | x = self.proj(x) 103 | if self.flatten: 104 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 105 | x = self.norm(x) 106 | return x 107 | 108 | 109 | def vit_small(**kwargs): 110 | model = VisionTransformerMoCo( 111 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 113 | model.default_cfg = _cfg() 114 | return model 115 | 116 | def vit_base(**kwargs): 117 | model = VisionTransformerMoCo( 118 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 119 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 120 | model.default_cfg = _cfg() 121 | return model 122 | 123 | def vit_conv_small(**kwargs): 124 | # minus one ViT block 125 | model = VisionTransformerMoCo( 126 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 127 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 128 | model.default_cfg = _cfg() 129 | return model 130 | 131 | def vit_conv_base(**kwargs): 132 | # minus one ViT block 133 | model = VisionTransformerMoCo( 134 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 135 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 136 | model.default_cfg = _cfg() 137 | return model 138 | --------------------------------------------------------------------------------