├── .gitignore ├── README.md ├── assets └── figure │ ├── 001.png │ ├── 002.png │ ├── 003.png │ ├── 004.png │ ├── 005.png │ ├── 006.png │ ├── 007.png │ ├── 008.png │ ├── 009.png │ ├── 010.png │ ├── 011.png │ ├── 2S1.png │ ├── confuser-rejection.png │ ├── cr-confusion-matrix.png │ ├── cr-roc.png │ ├── cr-training-plot.png │ ├── eoc-1-confusion-matrix.png │ ├── eoc-1-training-plot.png │ ├── eoc-2-cv-confusion-matrix.png │ ├── eoc-2-cv-training-plot.png │ ├── eoc-2-vv-confusion-matrix.png │ ├── eoc-2-vv-training-plot.png │ ├── soc-confusion-matrix.png │ └── soc-training-plot.png ├── dataset ├── confuser-rejection │ ├── .gitkeep │ └── raw │ │ └── rename.py ├── eoc-1-t72-132 │ ├── .gitkeep │ └── raw │ │ └── rename.py ├── eoc-1-t72-a64 │ ├── .gitkeep │ └── raw │ │ └── rename.py ├── eoc-2-cv │ ├── .gitkeep │ └── raw │ │ └── rename.py ├── eoc-2-vv │ ├── .gitkeep │ └── raw │ │ └── rename.py └── soc │ ├── .gitkeep │ └── raw │ └── rename.py ├── docker └── Dockerfile ├── experiments └── config │ ├── AConvNet-CR.json │ ├── AConvNet-EOC-1-T72-132.json │ ├── AConvNet-EOC-1-T72-A64.json │ ├── AConvNet-EOC-2-CV.json │ ├── AConvNet-EOC-2-VV.json │ └── AConvNet-SOC.json ├── notebook ├── experiments-CR.ipynb ├── experiments-EOC-1-T72-132.ipynb ├── experiments-EOC-1-T72-A64.ipynb ├── experiments-EOC-2-CV.ipynb ├── experiments-EOC-2-VV.ipynb ├── experiments-SOC.ipynb └── target-chip.ipynb ├── requirements.txt ├── run-docker.sh └── src ├── data ├── __init__.py ├── generate_dataset.py ├── loader.py ├── mstar.py └── preprocess.py ├── model ├── __init__.py ├── _base.py ├── _blocks.py └── network.py ├── train.py └── utils ├── __init__.py └── common.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/windows,linux,pycharm+all,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=windows,linux,pycharm+all,python 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary 
files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### PyCharm+all ### 21 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 22 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 23 | 24 | # User-specific stuff 25 | .idea/**/workspace.xml 26 | .idea/**/tasks.xml 27 | .idea/**/usage.statistics.xml 28 | .idea/**/dictionaries 29 | .idea/**/shelf 30 | 31 | # Generated files 32 | .idea/**/contentModel.xml 33 | 34 | # Sensitive or high-churn files 35 | .idea/**/dataSources/ 36 | .idea/**/dataSources.ids 37 | .idea/**/dataSources.local.xml 38 | .idea/**/sqlDataSources.xml 39 | .idea/**/dynamic.xml 40 | .idea/**/uiDesigner.xml 41 | .idea/**/dbnavigator.xml 42 | 43 | # Gradle 44 | .idea/**/gradle.xml 45 | .idea/**/libraries 46 | 47 | # Gradle and Maven with auto-import 48 | # When using Gradle or Maven with auto-import, you should exclude module files, 49 | # since they will be recreated, and may cause churn. Uncomment if using 50 | # auto-import. 
51 | # .idea/artifacts 52 | # .idea/compiler.xml 53 | # .idea/jarRepositories.xml 54 | # .idea/modules.xml 55 | # .idea/*.iml 56 | # .idea/modules 57 | # *.iml 58 | # *.ipr 59 | 60 | # CMake 61 | cmake-build-*/ 62 | 63 | # Mongo Explorer plugin 64 | .idea/**/mongoSettings.xml 65 | 66 | # File-based project format 67 | *.iws 68 | 69 | # IntelliJ 70 | out/ 71 | 72 | # mpeltonen/sbt-idea plugin 73 | .idea_modules/ 74 | 75 | # JIRA plugin 76 | atlassian-ide-plugin.xml 77 | 78 | # Cursive Clojure plugin 79 | .idea/replstate.xml 80 | 81 | # Crashlytics plugin (for Android Studio and IntelliJ) 82 | com_crashlytics_export_strings.xml 83 | crashlytics.properties 84 | crashlytics-build.properties 85 | fabric.properties 86 | 87 | # Editor-based Rest Client 88 | .idea/httpRequests 89 | 90 | # Android studio 3.1+ serialized cache file 91 | .idea/caches/build_file_checksums.ser 92 | 93 | ### PyCharm+all Patch ### 94 | # Ignores the whole .idea folder and all .iml files 95 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 96 | 97 | .idea/ 98 | 99 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 100 | 101 | *.iml 102 | modules.xml 103 | .idea/misc.xml 104 | *.ipr 105 | 106 | # Sonarlint plugin 107 | .idea/sonarlint 108 | 109 | ### Python ### 110 | # Byte-compiled / optimized / DLL files 111 | __pycache__/ 112 | *.py[cod] 113 | *$py.class 114 | 115 | # C extensions 116 | *.so 117 | 118 | # Distribution / packaging 119 | .Python 120 | build/ 121 | develop-eggs/ 122 | dist/ 123 | downloads/ 124 | eggs/ 125 | .eggs/ 126 | parts/ 127 | sdist/ 128 | var/ 129 | wheels/ 130 | pip-wheel-metadata/ 131 | share/python-wheels/ 132 | *.egg-info/ 133 | .installed.cfg 134 | *.egg 135 | MANIFEST 136 | 137 | # PyInstaller 138 | # Usually these files are written by a python script from a template 139 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
140 | *.manifest 141 | *.spec 142 | 143 | # Installer logs 144 | pip-log.txt 145 | pip-delete-this-directory.txt 146 | 147 | # Unit test / coverage reports 148 | htmlcov/ 149 | .tox/ 150 | .nox/ 151 | .coverage 152 | .coverage.* 153 | .cache 154 | nosetests.xml 155 | coverage.xml 156 | *.cover 157 | *.py,cover 158 | .hypothesis/ 159 | .pytest_cache/ 160 | pytestdebug.log 161 | 162 | # Translations 163 | *.mo 164 | *.pot 165 | 166 | # Django stuff: 167 | *.log 168 | local_settings.py 169 | db.sqlite3 170 | db.sqlite3-journal 171 | 172 | # Flask stuff: 173 | instance/ 174 | .webassets-cache 175 | 176 | # Scrapy stuff: 177 | .scrapy 178 | 179 | # Sphinx documentation 180 | docs/_build/ 181 | doc/_build/ 182 | 183 | # PyBuilder 184 | target/ 185 | 186 | # Jupyter Notebook 187 | .ipynb_checkpoints 188 | 189 | # IPython 190 | profile_default/ 191 | ipython_config.py 192 | 193 | # pyenv 194 | .python-version 195 | 196 | # pipenv 197 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 198 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 199 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 200 | # install all needed dependencies. 201 | #Pipfile.lock 202 | 203 | # poetry 204 | #poetry.lock 205 | 206 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 207 | __pypackages__/ 208 | 209 | # Celery stuff 210 | celerybeat-schedule 211 | celerybeat.pid 212 | 213 | # SageMath parsed files 214 | *.sage.py 215 | 216 | # Environments 217 | # .env 218 | .env/ 219 | .venv/ 220 | env/ 221 | venv/ 222 | ENV/ 223 | env.bak/ 224 | venv.bak/ 225 | pythonenv* 226 | 227 | # Spyder project settings 228 | .spyderproject 229 | .spyproject 230 | 231 | # Rope project settings 232 | .ropeproject 233 | 234 | # mkdocs documentation 235 | /site 236 | 237 | # mypy 238 | .mypy_cache/ 239 | .dmypy.json 240 | dmypy.json 241 | 242 | # Pyre type checker 243 | .pyre/ 244 | 245 | # pytype static type analyzer 246 | .pytype/ 247 | 248 | # operating system-related files (.DS_Store: file properties cache/storage on macOS; Thumbs.db: thumbnail cache on Windows) 249 | *.DS_Store 250 | Thumbs.db 251 | 252 | # profiling data 253 | *.prof 254 | 255 | 256 | ### Windows ### 257 | # Windows thumbnail cache files 258 | Thumbs.db 259 | Thumbs.db:encryptable 260 | ehthumbs.db 261 | ehthumbs_vista.db 262 | 263 | # Dump file 264 | *.stackdump 265 | 266 | # Folder config file 267 | [Dd]esktop.ini 268 | 269 | # Recycle Bin used on file shares 270 | $RECYCLE.BIN/ 271 | 272 | # Windows Installer files 273 | *.cab 274 | *.msi 275 | *.msix 276 | *.msm 277 | *.msp 278 | 279 | # Windows shortcuts 280 | *.lnk 281 | 282 | # Project-specific: raw/generated datasets, experiment history, checkpoints and archives 283 | dataset/soc/raw/train/ 284 | dataset/soc/raw/test/ 285 | dataset/soc/train/ 286 | dataset/soc/test/ 287 | 288 | dataset/eoc-1-*/raw/train/ 289 | dataset/eoc-1-*/raw/test/ 290 | dataset/eoc-1-*/train/ 291 | dataset/eoc-1-*/test/ 292 | 293 | dataset/eoc-2*/raw/train/ 294 | dataset/eoc-2*/raw/test/ 295 | dataset/eoc-2*/train/ 296 | dataset/eoc-2*/test/ 297 | 298 | dataset/confuser-rejection/raw/train/ 299 | dataset/confuser-rejection/raw/test/ 300 | dataset/confuser-rejection/train/ 301 | dataset/confuser-rejection/test/ 302 | 303 | experiments/history 304 | 305 | *.pth 306 | *.zip 307 | # End of
https://www.toptal.com/developers/gitignore/api/windows,linux,pycharm+all,python 308 | 309 | 310 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AConvNet 2 | 3 | ### Target Classification Using the Deep Convolutional Networks for SAR Images 4 | 5 | This repository is a reproduced implementation of AConvNet, which recognizes targets in the MSTAR dataset. 6 | You can find the authors' official implementation at [MSTAR-AConvNet](https://github.com/fudanxu/MSTAR-AConvNet). 7 | 8 | ## Dataset 9 | 10 | ### MSTAR (Moving and Stationary Target Acquisition and Recognition) Database 11 | 12 | #### Format 13 | 14 | - Header 15 | - Type: ASCII 16 | - Includes data shape (width, height), serial number, azimuth angle, etc. 17 | - Data 18 | - Type: Two-bytes 19 | - Shape: W x H x 2 20 | - Magnitude block 21 | - Phase Block 22 | 23 | The figure below shows an example of the magnitude block (left) and the phase block (right). 24 | 25 | ![Example of data block: 2S1](./assets/figure/001.png) 26 | 27 | ## Model 28 | 29 | The proposed model consists only of **sparsely connected layers**, without any fully connected layers. 30 | 31 | - It eases the over-fitting problem by reducing the number of free parameters (model capacity). 32 | 33 | | layer | Input | Conv 1 | Conv 2 | Conv 3 | Conv 4 | Conv 5 | 34 | | :---------: | ------ | :--------: | :--------: | :--------: | :----: | :-----: | 35 | | channels | 2 | 16 | 32 | 64 | 128 | 10 | 36 | | weight size | - | 5 x 5 | 5 x 5 | 6 x 6 | 5 x 5 | 3 x 3 | 37 | | pooling | - | 2 x 2 - s2 | 2 x 2 - s2 | 2 x 2 - s2 | - | - | 38 | | dropout | - | - | - | - | 0.5 | - | 39 | | activation | linear | ReLU | ReLU | ReLU | ReLU | Softmax | 40 | 41 | ## Training 42 | For training, this implementation fixes the random seed to `12321` for `reproducibility`.
43 | 44 | The experimental conditions are the same as in the paper, except for `data augmentation`. 45 | You can see the details in `src/model/_base.py` and `experiments/config/AConvNet-SOC.json`. 46 | 47 | ### Data Augmentation 48 | 49 | - The authors use random shifting to extract 88 x 88 patches from 128 x 128 SAR image chips. 50 | - This can increase the number of training images per SAR image chip by up to a factor of (128 - 88 + 1) x (128 - 88 + 1) = 1681. 51 | 52 | - However, for SOC, this repository does not use random shifting due to accuracy issues. 53 | - You can see the details in `src/data/generate_dataset.py` and `src/data/mstar.py`. 54 | - The data augmentation is implemented as follows: 55 | - From the center-cropped 100 x 100 SAR image chip, extract 94 x 94 patches with stride 1 (49 patches per image chip). 56 | - Extract 88 x 88 patches from each 94 x 94 patch by random cropping (stride 1). 57 | 58 | 59 | ## Experiments 60 | 61 | You can download the MSTAR Dataset from [MSTAR Overview](https://www.sdms.afrl.af.mil/index.php?collection=mstar) 62 | - MSTAR Clutter - CD1 / CD2 63 | - MSTAR Target Chips (T72 BMP2 BTR70 SLICY) - CD1 64 | - MSTAR / IU Mixed Targets - CD1 / CD2 65 | - MSTAR / IU T-72 Variants - CD1 / CD2 66 | - MSTAR Predictlite Software - CD1 67 | 68 | You can download the experimental results of this repository from [experiments](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/experiments.zip) 69 | 70 | ### Standard Operating Condition (SOC) 71 | 72 | | | | Train | | Test | | 73 | | ------- | ---------- | ---------- | ---------- | ---------- | ---------- | 74 | | Class | Serial No. | Depression | No. Images | Depression | No.
Images | 75 | | BMP-2 | 9563 | 17 | 233 | 15 | 195 | 76 | | BTR-70 | c71 | 17 | 233 | 15 | 196 | 77 | | T-72 | 132 | 17 | 232 | 15 | 196 | 78 | | BTR-60 | k10yt7532 | 17 | 256 | 15 | 195 | 79 | | 2S1 | b01 | 17 | 299 | 15 | 274 | 80 | | BRDM-2 | E-71 | 17 | 298 | 15 | 274 | 81 | | D7 | 92v13015 | 17 | 299 | 15 | 274 | 82 | | T-62 | A51 | 17 | 299 | 15 | 273 | 83 | | ZIL-131 | E12 | 17 | 299 | 15 | 274 | 84 | | ZSU-234 | d08 | 17 | 299 | 15 | 274 | 85 | 86 | ##### Training Set (Depression: 17$\degree$) 87 | 88 | ```shell 89 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 90 | ├ TARGETS/TRAIN/17_DEG 91 | │ ├ BMP2/SN_9563/*.000 (233 images) 92 | │ ├ BTR70/SN_C71/*.004 (233 images) 93 | │ └ T72/SN_132/*.015 (232 images) 94 | └ ... 95 | 96 | MSTAR-PublicMixedTargets-CD2/MSTAR_PUBLIC_MIXED_TARGETS_CD2 97 | ├ 17_DEG 98 | │ ├ COL1/SCENE1/BTR_60/*.003 (256 images) 99 | │ └ COL2/SCENE1 100 | │ ├ 2S1/*.000 (299 images) 101 | │ ├ BRDM_2/*.001 (298 images) 102 | │ ├ D7/*.005 (299 images) 103 | │ ├ SLICY 104 | │ ├ T62/*.016 (299 images) 105 | │ ├ ZIL131/*.025 (299 images) 106 | │ └ ZSU_23_4/*.026 (299 images) 107 | └ ... 108 | 109 | ``` 110 | 111 | ##### Test Set (Depression: 15$\degree$) 112 | 113 | ```shell 114 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 115 | ├ TARGETS/TEST/15_DEG 116 | │ ├ BMP2/SN_9563/*.000 (195 images) 117 | │ ├ BTR70/SN_C71/*.004 (196 images) 118 | │ └ T72/SN_132/*.015 (196 images) 119 | └ ... 120 | 121 | MSTAR-PublicMixedTargets-CD1/MSTAR_PUBLIC_MIXED_TARGETS_CD1 122 | ├ 15_DEG 123 | │ ├ COL1/SCENE1/BTR_60/*.003 (195 images) 124 | │ └ COL2/SCENE1 125 | │ ├ 2S1/*.000 (274 images) 126 | │ ├ BRDM_2/*.001 (274 images) 127 | │ ├ D7/*.005 (274 images) 128 | │ ├ SLICY 129 | │ ├ T62/*.016 (273 images) 130 | │ ├ ZIL131/*.025 (274 images) 131 | │ └ ZSU_23_4/*.026 (274 images) 132 | └ ...
133 | 134 | ``` 135 | ##### Quick Start Guide for Training 136 | 137 | - Dataset Preparation 138 | - Download the [dataset.zip](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/dataset.zip) 139 | - After extracting it, you can find `train` and `test` directories inside `raw` directory. 140 | - Place the two directories (`train` and `test`) to the `dataset/soc/raw`. 141 | ```shell 142 | $ cd src/data 143 | $ python3 generate_dataset.py --is_train=True --use_phase=True --chip_size=100 --patch_size=94 --use_phase=True --dataset=soc 144 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=soc 145 | $ cd .. 146 | $ python3 train.py --config_name=config/AConvNet-SOC.json 147 | ``` 148 | 149 | ##### Results of SOC 150 | - Final Accuracy is **99.34%** at epoch 29 (The official accuracy is 99.13%) 151 | - You can see the details in `notebook/experiments-SOC.ipynb` 152 | 153 | - Visualization of training loss and test accuracy 154 | 155 | ![soc-training-plot](./assets/figure/soc-training-plot.png) 156 | 157 | - Confusion Matrix with best model at **epoch 28** 158 | 159 | ![soc-confusion-matrix](./assets/figure/soc-confusion-matrix.png) 160 | 161 | - Noise Simulation [1] 162 | - i.i.d samples from a uniform distribution 163 | - This simulation does not fix the random seed 164 | 165 | | Noise | 1% | 5% | 10% | 15%| 166 | | :---: | :---: | :---: | :---: | :---: | 167 | | AConvNet-PyTorch | 98.64 | 94.10 | 84.54 | 71.55 | 168 | | AConvNet-Official | 91.76 | 88.52 | 75.84 | 54.68 | 169 | 170 | 171 | ### Extended Operating Conditions (EOC) 172 | 173 | #### EOC-1 (Large depression angle change) 174 | 175 | | | | Train | | Test | | 176 | | ------- | ---------- | ---------- | ---------- | ---------- | ---------- | 177 | | Class | Serial No. | Depression | No. Images | Depression | No. 
Images | 178 | | T-72 | A64 | 17 | 299 | 30 | 288 | 179 | | 2S1 | b01 | 17 | 299 | 30 | 288 | 180 | | BRDM-2 | E-71 | 17 | 298 | 30 | 287 | 181 | | ZSU-234 | d08 | 17 | 299 | 30 | 288 | 182 | 183 | ##### Training Set (Depression: 17$\degree$) 184 | ```shell 185 | MSTAR-PublicT72Variants-CD2/MSTAR_PUBLIC_T_72_VARIANTS_CD2 186 | ├ 17_DEG/COL2/SCENE1 187 | │ └ A64/*.024 (299 images) 188 | └ ... 189 | 190 | MSTAR-PublicMixedTargets-CD2/MSTAR_PUBLIC_MIXED_TARGETS_CD2 191 | ├ 17_DEG 192 | │ └ COL2/SCENE1 193 | │ ├ 2S1/*.000 (299 images) 194 | │ ├ BRDM_2/*.001 (298 images) 195 | │ └ ZSU_23_4/*.026 (299 images) 196 | └ ... 197 | 198 | ``` 199 | 200 | ##### Test Set (Depression: 30$\degree$) 201 | 202 | ```shell 203 | MSTAR-PublicT72Variants-CD2/MSTAR_PUBLIC_T_72_VARIANTS_CD2 204 | ├ 30_DEG/COL2/SCENE1 205 | │ └ A64/*.024 (288 images) 206 | └ ... 207 | 208 | MSTAR-PublicMixedTargets-CD2/MSTAR_PUBLIC_MIXED_TARGETS_CD2 209 | ├ 30_DEG 210 | │ └ COL2/SCENE1 211 | │ ├ 2S1/*.000 (288 images) 212 | │ ├ BRDM_2/*.001 (287 images) 213 | │ ├ ZSU_23_4/*.026 (288 images) 214 | │ └ ... 215 | └ ... 216 | 217 | ``` 218 | 219 | ##### Quick Start Guide for Training 220 | 221 | - Dataset Preparation 222 | - Download the [dataset.zip](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/dataset.zip) 223 | - After extracting it, you can find `train` and `test` directories inside `raw` directory. 224 | - Place the two directories (`train` and `test`) to the `dataset/eoc-1-t72-a64/raw`. 225 | ```shell 226 | $ cd src/data 227 | $ python3 generate_dataset.py --is_train=True --use_phase=True --chip_size=100 --patch_size=94 --use_phase=True --dataset=eoc-1-t72-a64 228 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=eoc-1-t72-a64 229 | $ cd ..
230 | $ python3 train.py --config_name=config/AConvNet-EOC-1-T72-A64.json 231 | ``` 232 | 233 | ##### Results of EOC-1 234 | - Final Accuracy is **91.49%** at epoch 17 (The official accuracy is 96.12%) 235 | - You can see the details in `notebook/experiments-EOC-1-T72-A64.ipynb` 236 | 237 | - Visualization of training loss and test accuracy 238 | 239 | ![eoc-1-training-plot](./assets/figure/eoc-1-training-plot.png) 240 | 241 | - Confusion Matrix with best model at **epoch 28** 242 | 243 | ![eoc-1-confusion-matrix](./assets/figure/eoc-1-confusion-matrix.png) 244 | 245 | 246 | #### EOC-2 (Target configuration variants) 247 | 248 | | | | Train | | Test | | 249 | | ------- | ---------- | ---------- | ---------- | ---------- | ---------- | 250 | | Class | Serial No. | Depression | No. Images | Depression | No. Images | 251 | | BMP-2 | 9563 | 17 | 233 | - | - | 252 | | BRDM-2 | E-71 | 17 | 298 | - | - | 253 | | BTR-70 | c71 | 17 | 233 | - | - | 254 | | T-72 | 132 | 17 | 232 | - | - | 255 | | T-72 | S7 | - | - | 15, 17 | 419 | 256 | | T-72 | A32 | - | - | 15, 17 | 572 | 257 | | T-72 | A62 | - | - | 15, 17 | 573 | 258 | | T-72 | A63 | - | - | 15, 17 | 573 | 259 | | T-72 | A64 | - | - | 15, 17 | 573 | 260 | 261 | ##### Training Set (Depression: 17$\degree$) 262 | ```shell 263 | # BMP2, BRDM2, BTR70, T72 are selected from SOC training data 264 | ``` 265 | 266 | ##### Test Set (Depression: 15$\degree$ and 17$\degree$) 267 | 268 | ```shell 269 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 270 | ├ TARGETS/TRAIN/17_DEG 271 | │ └ T72/SN_S7/*.017 (228 images) 272 | └ ... 273 | 274 | MSTAR-PublicT72Variants-CD2/MSTAR_PUBLIC_T_72_VARIANTS_CD2 275 | ├ 17_DEG/COL2/SCENE1 276 | │ ├ A32/*.017 (299 images) 277 | │ ├ A62/*.018 (299 images) 278 | │ ├ A63/*.019 (299 images) 279 | │ ├ A64/*.020 (299 images) 280 | ├ └ ... 281 | └ ...
282 | 283 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 284 | ├ TARGETS/TEST/15_DEG 285 | │ └ T72/SN_S7/*.017 (191 images) 286 | └ ... 287 | 288 | MSTAR-PublicT72Variants-CD1/MSTAR_PUBLIC_T_72_VARIANTS_CD1 289 | ├ 15_DEG/COL2/SCENE1 290 | │ ├ A32/*.017 (274 images) 291 | │ ├ A62/*.018 (274 images) 292 | │ ├ A63/*.019 (274 images) 293 | │ ├ A64/*.020 (274 images) 294 | ├ └ ... 295 | └ ... 296 | 297 | ``` 298 | 299 | ##### Quick Start Guide for Training 300 | 301 | - Dataset Preparation 302 | - Download the [dataset.zip](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/dataset.zip) 303 | - After extracting it, you can find `train` and `test` directories inside `raw` directory. 304 | - Place the two directories (`train` and `test`) to the `dataset/eoc-2-cv/raw`. 305 | ```shell 306 | $ cd src/data 307 | $ python3 generate_dataset.py --is_train=True --use_phase=True --chip_size=100 --patch_size=94 --use_phase=True --dataset=eoc-2-cv 308 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=eoc-2-cv 309 | $ cd .. 310 | $ python3 train.py --config_name=config/AConvNet-EOC-2-CV.json 311 | ``` 312 | 313 | ##### Results of EOC-2 Configuration Variants 314 | - Final Accuracy is **99.41%** at epoch 95 (The official accuracy is 98.93%) 315 | - You can see the details in `notebook/experiments-EOC-2-CV.ipynb` 316 | 317 | - Visualization of training loss and test accuracy 318 | 319 | ![eoc-2-cv-training-plot](./assets/figure/eoc-2-cv-training-plot.png) 320 | 321 | - Confusion Matrix with best model at **epoch 95** 322 | 323 | ![eoc-2-cv-confusion-matrix](./assets/figure/eoc-2-cv-confusion-matrix.png) 324 | 325 | 326 | #### EOC-2 (Target version variants) 327 | 328 | | | | Train | | Test | | 329 | | ------- | ---------- | ---------- | ---------- | ---------- | ---------- | 330 | | Class | Serial No. | Depression | No. 
Images | Depression | No. Images | 331 | | BMP-2 | 9563 | 17 | 233 | - | - | 332 | | BRDM-2 | E-71 | 17 | 298 | - | - | 333 | | BTR-70 | c71 | 17 | 233 | - | - | 334 | | T-72 | 132 | 17 | 232 | - | - | 335 | | BMP-2 | 9566 | - | - | 15, 17 | 428 | 336 | | BMP-2 | c21 | - | - | 15, 17 | 429 | 337 | | T-72 | 812 | - | - | 15, 17 | 426 | 338 | | T-72 | A04 | - | - | 15, 17 | 573 | 339 | | T-72 | A05 | - | - | 15, 17 | 573 | 340 | | T-72 | A07 | - | - | 15, 17 | 573 | 341 | | T-72 | A10 | - | - | 15, 17 | 567 | 342 | 343 | ##### Training Set (Depression: 17$\degree$) 344 | ```shell 345 | # BMP2, BRDM2, BTR70, T72 are selected from SOC training data 346 | ``` 347 | 348 | ##### Test Set (Depression: 15$\degree$ and 17$\degree$) 349 | 350 | ```shell 351 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 352 | ├ TARGETS/TRAIN/17_DEG 353 | │ ├ BMP2/SN_9566/*.001 (232 images) 354 | │ ├ BMP2/SN_C21/*.002 (233 images) 355 | │ ├ T72/SN_812/*.016 (231 images) 356 | │ └ ... 357 | └ ... 358 | 359 | MSTAR-PublicT72Variants-CD2/MSTAR_PUBLIC_T_72_VARIANTS_CD2 360 | ├ 17_DEG/COL2/SCENE1 361 | │ ├ A04/*.017 (299 images) 362 | │ ├ A05/*.018 (299 images) 363 | │ ├ A07/*.019 (299 images) 364 | │ ├ A10/*.020 (296 images) 365 | ├ └ ... 366 | └ ... 367 | 368 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 369 | ├ TARGETS/TEST/15_DEG 370 | │ ├ BMP2/SN_9566/*.001 (196 images) 371 | │ ├ BMP2/SN_C21/*.002 (196 images) 372 | │ ├ T72/SN_812/*.016 (195 images) 373 | │ └ ... 374 | └ ... 375 | 376 | MSTAR-PublicT72Variants-CD1/MSTAR_PUBLIC_T_72_VARIANTS_CD1 377 | ├ 15_DEG/COL2/SCENE1 378 | │ ├ A04/*.017 (274 images) 379 | │ ├ A05/*.018 (274 images) 380 | │ ├ A07/*.019 (274 images) 381 | │ ├ A10/*.020 (271 images) 382 | ├ └ ... 383 | └ ...
384 | 385 | ``` 386 | 387 | ##### Quick Start Guide for Training 388 | 389 | - Dataset Preparation 390 | - Download the [dataset.zip](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/dataset.zip) 391 | - After extracting it, you can find `train` and `test` directories inside `raw` directory. 392 | - Place the two directories (`train` and `test`) to the `dataset/eoc-2-vv/raw`. 393 | ```shell 394 | $ cd src/data 395 | $ python3 generate_dataset.py --is_train=True --use_phase=True --chip_size=100 --patch_size=94 --use_phase=True --dataset=eoc-2-vv 396 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=eoc-2-vv 397 | $ cd .. 398 | $ python3 train.py --config_name=config/AConvNet-EOC-2-VV.json 399 | ``` 400 | 401 | ##### Results of EOC-2 Version Variants 402 | - Final Accuracy is **97.17%** at epoch 88 (The official accuracy is 98.60%) 403 | - You can see the details in `notebook/experiments-EOC-2-VV.ipynb` 404 | 405 | - Visualization of training loss and test accuracy 406 | 407 | ![eoc-2-vv-training-plot](./assets/figure/eoc-2-vv-training-plot.png) 408 | 409 | - Confusion Matrix with best model at **epoch 88** 410 | 411 | ![eoc-2-vv-confusion-matrix](./assets/figure/eoc-2-vv-confusion-matrix.png) 412 | 413 | 414 | 415 | ### Outlier Rejection 416 | 417 | | | | Train | | Test | | Remarks. | 418 | | ------- | ---------- | ---------- | ---------- | ---------- | ---------- | ---------- | 419 | | Class | Serial No. | Depression | No. Images | Depression | No. Images | Type.
| 420 | | BMP-2 | 9563 | 17 | 233 | 15 | 195 | Known | 421 | | BTR-70 | c71 | 17 | 233 | 15 | 196 | Known | 422 | | T-72 | 132 | 17 | 232 | 15 | 196 | Known | 423 | | 2S1 | b01 | 17 | - | 15 | 274 | Confuser | 424 | | ZIL-131 | E12 | 17 | - | 15 | 274 | Confuser | 425 | 426 | ##### Training Set (Known targets in Depression: 17$\degree$) 427 | 428 | ```shell 429 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 430 | ├ TARGETS/TRAIN/17_DEG # KNOWN 431 | │ ├ BMP2/SN_9563/*.000 (233 images) 432 | │ ├ BTR70/SN_C71/*.004 (233 images) 433 | │ └ T72/SN_132/*.015 (232 images) 434 | └ ... 435 | 436 | MSTAR-PublicMixedTargets-CD2/MSTAR_PUBLIC_MIXED_TARGETS_CD2 437 | ├ 17_DEG # Confuser 438 | │ └ COL2/SCENE1 439 | │ ├ 2S1/*.000 (299 images) 440 | │ └ ZIL131/*.025 (299 images) 441 | └ ... 442 | 443 | ``` 444 | 445 | 446 | ##### Test Set (Known targets and confuser targets in Depression: 15$\degree$) 447 | ```shell 448 | MSTAR-PublicTargetChips-T72-BMP2-BTR70-SLICY/MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY 449 | ├ TARGETS/TEST/15_DEG # KNOWN 450 | │ ├ BMP2/SN_9563/*.000 (195 images) 451 | │ ├ BTR70/SN_C71/*.004 (196 images) 452 | │ └ T72/SN_132/*.015 (196 images) 453 | └ ... 454 | 455 | MSTAR-PublicMixedTargets-CD1/MSTAR_PUBLIC_MIXED_TARGETS_CD1 456 | ├ 15_DEG # Confuser 457 | │ └ COL2/SCENE1 458 | │ ├ 2S1/*.000 (274 images) 459 | │ └ ZIL131/*.025 (274 images) 460 | └ ... 461 | 462 | ``` 463 | 464 | ##### Quick Start Guide for Training 465 | 466 | - Dataset Preparation 467 | - Download the [dataset.zip](https://github.com/jangsoopark/AConvNet-pytorch/releases/download/v2.2.0/dataset.zip) 468 | - After extracting it, you can find `train` and `test` directories inside `raw` directory. 469 | - Place the two directories (`train` and `test`) to the `dataset/confuser-rejection/raw`.
470 | 471 | ```shell 472 | $ cd src/data 473 | $ python3 generate_dataset.py --is_train=True --use_phase=True --chip_size=100 --patch_size=94 --use_phase=True --dataset=confuser-rejection 474 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=confuser-rejection 475 | $ cd .. 476 | $ 477 | $ # Remove the confuser targets before training so that validation accuracy is measured on known targets only 478 | $ cd ../dataset/confuser-rejection/test 479 | $ rm -rf 2S1 480 | $ rm -rf ZIL131 481 | $ cd - 482 | $ 483 | $ python3 train.py --config_name=config/AConvNet-CR.json 484 | $ 485 | $ # Restore the confuser (unknown) targets by regenerating the test set (run from src) 486 | $ rm -rf ../dataset/confuser-rejection/test 487 | $ cd data 488 | $ python3 generate_dataset.py --is_train=False --use_phase=True --chip_size=128 --patch_size=128 --use_phase=True --dataset=confuser-rejection 489 | ``` 490 | 491 | 492 | ##### Results of Outlier Rejection 493 | - Final Accuracy for known targets is **98.81%** at epoch 39 494 | - You can see the details in `notebook/experiments-CR.ipynb` 495 | 496 | - Visualization of training loss and test accuracy 497 | 498 | ![cr-training-plot](./assets/figure/cr-training-plot.png) 499 | 500 | - Confusion Matrix with best model at **epoch 88** 501 | 502 | ![cr-confusion-matrix](./assets/figure/cr-confusion-matrix.png) 503 | 504 | 505 | - Outlier Rejection: TODO: looking for more details.
506 | ```python 507 | # Rules 508 | output_probability = model(image) 509 | is_confuser = output_probability < thresh 510 | 511 | if sum(is_confuser) == 3: 512 | target is confuser 513 | else: 514 | target is known 515 | 516 | ``` 517 | 518 | ![cr-roc](./assets/figure/cr-roc.png) 519 | 520 | 523 | ## Details about the specific environment of this repository 524 | 525 | | | | 526 | | :---------: | :------: | 527 | | OS | Ubuntu 20.04 LTS | 528 | | CPU | Intel i7-10700k | 529 | | GPU | RTX 2080Ti 11GB | 530 | | Memory | 32GB | 531 | | SSD | 500GB | 532 | | HDD | 2TB | 533 | 534 | ## Citation 535 | 536 | ```bibtex 537 | @ARTICLE{7460942, 538 | author={S. {Chen} and H. {Wang} and F. {Xu} and Y. {Jin}}, 539 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 540 | title={Target Classification Using the Deep Convolutional Networks for SAR Images}, 541 | year={2016}, 542 | volume={54}, 543 | number={8}, 544 | pages={4806-4817}, 545 | doi={10.1109/TGRS.2016.2551720} 546 | } 547 | ``` 548 | 549 | ## References 550 | [1] G. Dong, N. Wang, and G. Kuang, 551 | "Sparse representation of monogenic signal: With application to target recognition in SAR images," 552 | *IEEE Signal Process. Lett.*, vol. 21, no. 8, pp. 952-956, Aug. 2014. 
553 | 554 | 555 | --- 556 | 557 | ## TODO 558 | 559 | - [ ] Implementation 560 | - [ ] Data generation 561 | - [X] SOC 562 | - [X] EOC 563 | - [X] Outlier Rejection 564 | - [ ] End-to-End SAR-ATR 565 | - [ ] Data Loader 566 | - [X] SOC 567 | - [X] EOC 568 | - [X] Outlier Rejection 569 | - [ ] End-to-End SAR-ATR 570 | - [X] Model 571 | - [X] Network 572 | - [X] Training 573 | - [X] Early Stopping 574 | - [X] Hyper-parameter Optimization 575 | - [ ] Experiments 576 | - [X] Reproduce the SOC Results 577 | - [X] Reproduce the EOC Results 578 | - [X] Reproduce the outlier rejection 579 | - [ ] Reproduce the end-to-end SAR-ATR 580 | 581 | -------------------------------------------------------------------------------- /assets/figure/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/001.png -------------------------------------------------------------------------------- /assets/figure/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/002.png -------------------------------------------------------------------------------- /assets/figure/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/003.png -------------------------------------------------------------------------------- /assets/figure/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/004.png -------------------------------------------------------------------------------- /assets/figure/005.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/005.png -------------------------------------------------------------------------------- /assets/figure/006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/006.png -------------------------------------------------------------------------------- /assets/figure/007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/007.png -------------------------------------------------------------------------------- /assets/figure/008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/008.png -------------------------------------------------------------------------------- /assets/figure/009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/009.png -------------------------------------------------------------------------------- /assets/figure/010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/010.png -------------------------------------------------------------------------------- /assets/figure/011.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/011.png -------------------------------------------------------------------------------- /assets/figure/2S1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/2S1.png -------------------------------------------------------------------------------- /assets/figure/confuser-rejection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/confuser-rejection.png -------------------------------------------------------------------------------- /assets/figure/cr-confusion-matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/cr-confusion-matrix.png -------------------------------------------------------------------------------- /assets/figure/cr-roc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/cr-roc.png -------------------------------------------------------------------------------- /assets/figure/cr-training-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/cr-training-plot.png -------------------------------------------------------------------------------- /assets/figure/eoc-1-confusion-matrix.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-1-confusion-matrix.png -------------------------------------------------------------------------------- /assets/figure/eoc-1-training-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-1-training-plot.png -------------------------------------------------------------------------------- /assets/figure/eoc-2-cv-confusion-matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-2-cv-confusion-matrix.png -------------------------------------------------------------------------------- /assets/figure/eoc-2-cv-training-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-2-cv-training-plot.png -------------------------------------------------------------------------------- /assets/figure/eoc-2-vv-confusion-matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-2-vv-confusion-matrix.png -------------------------------------------------------------------------------- /assets/figure/eoc-2-vv-training-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/eoc-2-vv-training-plot.png -------------------------------------------------------------------------------- 
/assets/figure/soc-confusion-matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/soc-confusion-matrix.png -------------------------------------------------------------------------------- /assets/figure/soc-training-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/assets/figure/soc-training-plot.png -------------------------------------------------------------------------------- /dataset/confuser-rejection/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/confuser-rejection/.gitkeep -------------------------------------------------------------------------------- /dataset/confuser-rejection/raw/rename.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /dataset/eoc-1-t72-132/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/eoc-1-t72-132/.gitkeep -------------------------------------------------------------------------------- /dataset/eoc-1-t72-132/raw/rename.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, 
name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /dataset/eoc-1-t72-a64/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/eoc-1-t72-a64/.gitkeep -------------------------------------------------------------------------------- /dataset/eoc-1-t72-a64/raw/rename.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /dataset/eoc-2-cv/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/eoc-2-cv/.gitkeep -------------------------------------------------------------------------------- /dataset/eoc-2-cv/raw/rename.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /dataset/eoc-2-vv/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/eoc-2-vv/.gitkeep -------------------------------------------------------------------------------- /dataset/eoc-2-vv/raw/rename.py: 
-------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /dataset/soc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/dataset/soc/.gitkeep -------------------------------------------------------------------------------- /dataset/soc/raw/rename.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | target_list = glob.glob('./*/*') 6 | 7 | for name in target_list: 8 | print(name, name.replace('_', '')) 9 | os.rename(name, name.replace('_', '')) 10 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | # docker build . 
-t aconvnet-pytorch 3 | # Base container: docker pull pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel 4 | 5 | FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel 6 | 7 | ARG DEBIAN_FRONTEND=noninteractive 8 | 9 | RUN apt update 10 | 11 | RUN pip install seaborn && \ 12 | pip install numpy && \ 13 | pip install scipy&& \ 14 | pip install tqdm && \ 15 | pip install jupyter && \ 16 | pip install matplotlib && \ 17 | pip install scikit-image && \ 18 | pip install scikit-learn && \ 19 | pip install opencv-python && \ 20 | pip install absl-py && \ 21 | pip install optuna 22 | 23 | 24 | RUN apt update && \ 25 | apt install -y wget vim emacs nano libgl1-mesa-glx 26 | 27 | 28 | RUN mkdir -p /workspace 29 | 30 | ARG work_dir=/workspace 31 | 32 | WORKDIR ${work_dir} 33 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-CR.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "AConvNet-CR", 3 | "dataset": "confuser-rejection", 4 | "num_classes": 3, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [50], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-EOC-1-T72-132.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "AConvNet-EOC-1-T72-132", 3 | "dataset": "eoc-1-t72-132", 4 | "num_classes": 4, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [50], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-EOC-1-T72-A64.json: -------------------------------------------------------------------------------- 
1 | { 2 | "model_name": "AConvNet-EOC-1-T72-A64", 3 | "dataset": "eoc-1-t72-a64", 4 | "num_classes": 4, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [50], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-EOC-2-CV.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "AConvNet-EOC-2-CV", 3 | "dataset": "eoc-2-cv", 4 | "num_classes": 4, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [5], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-EOC-2-VV.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "AConvNet-EOC-2-VV", 3 | "dataset": "eoc-2-vv", 4 | "num_classes": 4, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [5], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /experiments/config/AConvNet-SOC.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "AConvNet-SOC", 3 | "dataset": "soc", 4 | "num_classes": 10, 5 | "channels": 2, 6 | "batch_size": 100, 7 | "epochs": 100, 8 | "momentum": 0.9, 9 | "lr": 1e-3, 10 | "lr_step": [50], 11 | "lr_decay": 0.1, 12 | "weight_decay": 4e-3, 13 | "dropout_rate": 0.5 14 | } 15 | -------------------------------------------------------------------------------- /notebook/experiments-EOC-1-T72-A64.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "import json\n", 14 | "import glob\n", 15 | "import sys\n", 16 | "import os\n", 17 | "\n", 18 | "sys.path.append('../src')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "### Visualization of training loss and test accuracy" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "with open('../experiments/history/history-AConvNet-EOC-1-T72-A64.json') as f:\n", 35 | " history = json.load(f)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "image/png": "\n", 46 | "text/plain": [ 47 | "
" 48 | ] 49 | }, 50 | "metadata": { 51 | "needs_background": "light" 52 | }, 53 | "output_type": "display_data" 54 | } 55 | ], 56 | "source": [ 57 | "training_loss = history['loss']\n", 58 | "test_accuracy = history['accuracy']\n", 59 | "\n", 60 | "epochs = np.arange(len(training_loss))\n", 61 | "\n", 62 | "fig, ax1 = plt.subplots()\n", 63 | "ax2 = ax1.twinx()\n", 64 | "\n", 65 | "plot1, = ax1.plot(epochs, training_loss, marker='.', c='blue', label='loss')\n", 66 | "plot2, = ax2.plot(epochs, test_accuracy, marker='.', c='red', label='accuracy')\n", 67 | "plt.legend([plot1, plot2], ['loss', 'accuracy'], loc='upper right')\n", 68 | "\n", 69 | "plt.grid()\n", 70 | "\n", 71 | "ax1.set_xlabel('Epoch')\n", 72 | "ax1.set_ylabel('loss', color='blue')\n", 73 | "ax2.set_ylabel('accuracy', color='red')\n", 74 | "plt.show()\n" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Early Stopping" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "from tqdm import tqdm\n", 91 | "import torchvision\n", 92 | "import torch\n", 93 | "\n", 94 | "from utils import common\n", 95 | "from data import preprocess\n", 96 | "from data import loader\n", 97 | "import model" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def load_dataset(path, is_train, name, batch_size):\n", 107 | "\n", 108 | " _dataset = loader.Dataset(\n", 109 | " path, name=name, is_train=is_train,\n", 110 | " transform=torchvision.transforms.Compose([\n", 111 | " preprocess.CenterCrop(88), torchvision.transforms.ToTensor()\n", 112 | " ])\n", 113 | " )\n", 114 | " data_loader = torch.utils.data.DataLoader(\n", 115 | " _dataset, batch_size=batch_size, shuffle=is_train, num_workers=1\n", 116 | " )\n", 117 | " return data_loader\n", 118 | "\n", 119 | "\n", 120 | "def evaluate(_m, ds):\n", 121 | " 
\n", 122 | " num_data = 0\n", 123 | " corrects = 0\n", 124 | " \n", 125 | " _m.net.eval()\n", 126 | " _softmax = torch.nn.Softmax(dim=1)\n", 127 | " for i, data in enumerate(ds):\n", 128 | " images, labels, _ = data\n", 129 | "\n", 130 | " predictions = _m.inference(images)\n", 131 | " predictions = _softmax(predictions)\n", 132 | "\n", 133 | " _, predictions = torch.max(predictions.data, 1)\n", 134 | " labels = labels.type(torch.LongTensor)\n", 135 | " num_data += labels.size(0)\n", 136 | " corrects += (predictions == labels.to(m.device)).sum().item()\n", 137 | "\n", 138 | " accuracy = 100 * corrects / num_data\n", 139 | " return accuracy" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stderr", 149 | "output_type": "stream", 150 | "text": [ 151 | "load test data set: 1151it [00:00, 2070.14it/s]\n", 152 | "d:\\ivs\\project\\004-research\\signal-processing\\image-processing\\remote-sensing\\aconvnet\\aconvnet-pytorch\\venv\\lib\\site-packages\\torch\\nn\\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at ..\\c10/core/TensorImpl.h:1156.)\n", 153 | " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" 154 | ] 155 | }, 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "Best accuracy at epoch=0 with 53.69%\n", 161 | "Best accuracy at epoch=1 with 70.20%\n", 162 | "Best accuracy at epoch=2 with 84.88%\n", 163 | "Best accuracy at epoch=3 with 89.31%\n", 164 | "Best accuracy at epoch=5 with 90.44%\n", 165 | "Best accuracy at epoch=15 with 90.96%\n", 166 | "Best accuracy at epoch=17 with 91.49%\n", 167 | "Final model is epoch=17 with accurayc=91.49%\n", 168 | "Path=D:\\ivs\\Project\\004-research\\signal-processing\\image-processing\\remote-sensing\\aconvnet\\AConvNet-pytorch\\experiments/model/AConvNet-EOC-1-T72-A64\\model-018.pth\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "\n", 174 | "config = common.load_config(os.path.join(common.project_root, 'experiments/config/AConvNet-EOC-1-T72-A64.json'))\n", 175 | "model_name = config['model_name']\n", 176 | "test_set = load_dataset('dataset', False, 'eoc-1-t72-a64', 100)\n", 177 | "\n", 178 | "m = model.Model(\n", 179 | " classes=config['num_classes'], channels=config['channels'],\n", 180 | ")\n", 181 | "\n", 182 | "model_history = glob.glob(os.path.join(common.project_root, f'experiments/model/{model_name}/*.pth'))\n", 183 | "model_history = sorted(model_history, key=os.path.basename)\n", 184 | "\n", 185 | "best = {\n", 186 | " 'epoch': 0,\n", 187 | " 'accuracy': 0,\n", 188 | " 'path': ''\n", 189 | "}\n", 190 | "\n", 191 | "for i, model_path in enumerate(model_history):\n", 192 | " m.load(model_path)\n", 193 | " accuracy = evaluate(m, test_set)\n", 194 | " if accuracy > best['accuracy']:\n", 195 | " best['epoch'] = i\n", 196 | " best['accuracy'] = accuracy\n", 197 | " best['path'] = model_path\n", 198 | " print(f'Best accuracy at epoch={i} with {accuracy:.2f}%')\n", 199 | " \n", 200 | "best_epoch = best['epoch']\n", 201 | 
"best_accuracy = best['accuracy']\n", 202 | "best_path = best['path']\n", 203 | "\n", 204 | "print(f'Final model is epoch={best_epoch} with accurayc={best_accuracy:.2f}%')\n", 205 | "print(f'Path={best_path}')" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "### Confusion Matrix with Best Model" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 7, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "from sklearn import metrics\n", 222 | "from data import mstar\n", 223 | "\n", 224 | "def confusion_matrix(_m, ds):\n", 225 | " _pred = []\n", 226 | " _gt = []\n", 227 | " \n", 228 | " _m.net.eval()\n", 229 | " _softmax = torch.nn.Softmax(dim=1)\n", 230 | " for i, data in enumerate(ds):\n", 231 | " images, labels, _ = data\n", 232 | " \n", 233 | " predictions = _m.inference(images)\n", 234 | " predictions = _softmax(predictions)\n", 235 | "\n", 236 | " _, predictions = torch.max(predictions.data, 1)\n", 237 | " labels = labels.type(torch.LongTensor)\n", 238 | " \n", 239 | " _pred += predictions.cpu().tolist()\n", 240 | " _gt += labels.cpu().tolist()\n", 241 | " \n", 242 | " conf_mat = metrics.confusion_matrix(_gt, _pred)\n", 243 | " \n", 244 | " return conf_mat" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 8, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "image/png": "\n", 255 | "text/plain": [ 256 | "
" 257 | ] 258 | }, 259 | "metadata": {}, 260 | "output_type": "display_data" 261 | } 262 | ], 263 | "source": [ 264 | "import matplotlib.pyplot as plt\n", 265 | "import seaborn as sns\n", 266 | "\n", 267 | "m.load(best_path)\n", 268 | "_conf_mat = confusion_matrix(m, test_set)\n", 269 | "\n", 270 | "sns.reset_defaults()\n", 271 | "ax = sns.heatmap(_conf_mat, annot=True, fmt='d', cbar=False)\n", 272 | "ax.set_yticklabels(mstar.target_name_eoc_1, rotation=0)\n", 273 | "ax.set_xticklabels(mstar.target_name_eoc_1, rotation=30)\n", 274 | "\n", 275 | "plt.xlabel('prediction', fontsize=12)\n", 276 | "plt.ylabel('label', fontsize=12)\n", 277 | "\n", 278 | "\n", 279 | "plt.show()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "### Noise Simulation" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 9, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "from skimage import util\n", 296 | "\n", 297 | "\n", 298 | "def generate_noise(_images, amount):\n", 299 | " \n", 300 | " n, _, h, w = _images.shape\n", 301 | " \n", 302 | " noise = np.array([np.random.uniform(size=(1, h, w)) for _ in range(n)])\n", 303 | " portions = np.array([\n", 304 | " util.random_noise(np.zeros((1, h, w)), mode='s&p', amount=amount)\n", 305 | " for _ in range(n)\n", 306 | " ])\n", 307 | " noise = noise * portions\n", 308 | " \n", 309 | " return _images + noise.astype(np.float32)\n", 310 | "\n", 311 | "\n", 312 | "def noise_simulation(_m, ds, noise_ratio):\n", 313 | " \n", 314 | " num_data = 0\n", 315 | " corrects = 0\n", 316 | " \n", 317 | " _m.net.eval()\n", 318 | " _softmax = torch.nn.Softmax(dim=1)\n", 319 | " for i, data in enumerate(ds):\n", 320 | " images, labels, _ = data\n", 321 | " images = generate_noise(images, noise_ratio)\n", 322 | "\n", 323 | " predictions = _m.inference(images)\n", 324 | " predictions = _softmax(predictions)\n", 325 | "\n", 326 | " _, predictions = torch.max(predictions.data, 
1)\n", 327 | " labels = labels.type(torch.LongTensor)\n", 328 | " num_data += labels.size(0)\n", 329 | " corrects += (predictions == labels.to(m.device)).sum().item()\n", 330 | "\n", 331 | " accuracy = 100 * corrects / num_data\n", 332 | " \n", 333 | " return accuracy" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 10, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "ratio = 0.01, accuracy = 90.96\n", 346 | "ratio = 0.05, accuracy = 85.23\n", 347 | "ratio = 0.10, accuracy = 83.23\n", 348 | "ratio = 0.15, accuracy = 80.54\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "noise_result = {}\n", 354 | "\n", 355 | "for ratio in [0.01, 0.05, 0.10, 0.15]:\n", 356 | " noise_result[ratio] = noise_simulation(m, test_set, ratio)\n", 357 | " print(f'ratio = {ratio:.2f}, accuracy = {noise_result[ratio]:.2f}')\n" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3 (ipykernel)", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.7.9" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 4 382 | } 383 | -------------------------------------------------------------------------------- /notebook/experiments-EOC-2-CV.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "import json\n", 14 | "import glob\n", 15 | "import sys\n", 16 | "import os\n", 17 | "\n", 18 | 
"sys.path.append('../src')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "### Visualization of training loss and test accuracy" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "with open('../experiments/history/history-AConvNet-EOC-2-CV.json') as f:\n", 35 | " history = json.load(f)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "image/png": "\n", 46 | "text/plain": [ 47 | "
" 48 | ] 49 | }, 50 | "metadata": { 51 | "needs_background": "light" 52 | }, 53 | "output_type": "display_data" 54 | } 55 | ], 56 | "source": [ 57 | "training_loss = history['loss']\n", 58 | "test_accuracy = history['accuracy']\n", 59 | "\n", 60 | "epochs = np.arange(len(training_loss))\n", 61 | "\n", 62 | "fig, ax1 = plt.subplots()\n", 63 | "ax2 = ax1.twinx()\n", 64 | "\n", 65 | "plot1, = ax1.plot(epochs, training_loss, marker='.', c='blue', label='loss')\n", 66 | "plot2, = ax2.plot(epochs, test_accuracy, marker='.', c='red', label='accuracy')\n", 67 | "plt.legend([plot1, plot2], ['loss', 'accuracy'], loc='upper right')\n", 68 | "\n", 69 | "plt.grid()\n", 70 | "\n", 71 | "ax1.set_xlabel('Epoch')\n", 72 | "ax1.set_ylabel('loss', color='blue')\n", 73 | "ax2.set_ylabel('accuracy', color='red')\n", 74 | "plt.show()\n" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Early Stopping" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "from tqdm import tqdm\n", 91 | "import torchvision\n", 92 | "import torch\n", 93 | "\n", 94 | "from utils import common\n", 95 | "from data import preprocess\n", 96 | "from data import loader\n", 97 | "import model" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def load_dataset(path, is_train, name, batch_size):\n", 107 | "\n", 108 | " _dataset = loader.Dataset(\n", 109 | " path, name=name, is_train=is_train,\n", 110 | " transform=torchvision.transforms.Compose([\n", 111 | " preprocess.CenterCrop(88), torchvision.transforms.ToTensor()\n", 112 | " ])\n", 113 | " )\n", 114 | " data_loader = torch.utils.data.DataLoader(\n", 115 | " _dataset, batch_size=batch_size, shuffle=is_train, num_workers=1\n", 116 | " )\n", 117 | " return data_loader\n", 118 | "\n", 119 | "\n", 120 | "def evaluate(_m, ds):\n", 121 | " 
\n", 122 | " num_data = 0\n", 123 | " corrects = 0\n", 124 | " \n", 125 | " _m.net.eval()\n", 126 | " _softmax = torch.nn.Softmax(dim=1)\n", 127 | " for i, data in enumerate(ds):\n", 128 | " images, labels, _ = data\n", 129 | "\n", 130 | " predictions = _m.inference(images)\n", 131 | " predictions = _softmax(predictions)\n", 132 | "\n", 133 | " _, predictions = torch.max(predictions.data, 1)\n", 134 | " labels = labels.type(torch.LongTensor)\n", 135 | " num_data += labels.size(0)\n", 136 | " corrects += (predictions == labels.to(m.device)).sum().item()\n", 137 | "\n", 138 | " accuracy = 100 * corrects / num_data\n", 139 | " return accuracy" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stderr", 149 | "output_type": "stream", 150 | "text": [ 151 | "load test data set: 2710it [00:01, 2205.05it/s]\n", 152 | "d:\\ivs\\project\\004-research\\signal-processing\\image-processing\\remote-sensing\\aconvnet\\aconvnet-pytorch\\venv\\lib\\site-packages\\torch\\nn\\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. 
(Triggered internally at ..\\c10/core/TensorImpl.h:1156.)\n", 153 | " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" 154 | ] 155 | }, 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "Best accuracy at epoch=0 with 8.45%\n", 161 | "Best accuracy at epoch=1 with 61.73%\n", 162 | "Best accuracy at epoch=3 with 81.92%\n", 163 | "Best accuracy at epoch=4 with 85.90%\n", 164 | "Best accuracy at epoch=5 with 96.79%\n", 165 | "Best accuracy at epoch=9 with 97.82%\n", 166 | "Best accuracy at epoch=17 with 98.78%\n", 167 | "Best accuracy at epoch=23 with 99.15%\n", 168 | "Best accuracy at epoch=35 with 99.19%\n", 169 | "Best accuracy at epoch=42 with 99.30%\n", 170 | "Best accuracy at epoch=95 with 99.41%\n", 171 | "Final model is epoch=95 with accurayc=99.41%\n", 172 | "Path=D:\\ivs\\Project\\004-research\\signal-processing\\image-processing\\remote-sensing\\aconvnet\\AConvNet-pytorch\\experiments/model/AConvNet-EOC-2-CV\\model-096.pth\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "\n", 178 | "config = common.load_config(os.path.join(common.project_root, 'experiments/config/AConvNet-EOC-2-CV.json'))\n", 179 | "model_name = config['model_name']\n", 180 | "test_set = load_dataset('dataset', False, 'eoc-2-cv', 100)\n", 181 | "\n", 182 | "m = model.Model(\n", 183 | " classes=config['num_classes'], channels=config['channels'],\n", 184 | ")\n", 185 | "\n", 186 | "model_history = glob.glob(os.path.join(common.project_root, f'experiments/model/{model_name}/*.pth'))\n", 187 | "model_history = sorted(model_history, key=os.path.basename)\n", 188 | "\n", 189 | "best = {\n", 190 | " 'epoch': 0,\n", 191 | " 'accuracy': 0,\n", 192 | " 'path': ''\n", 193 | "}\n", 194 | "\n", 195 | "for i, model_path in enumerate(model_history):\n", 196 | " m.load(model_path)\n", 197 | " accuracy = evaluate(m, test_set)\n", 198 | " if accuracy > best['accuracy']:\n", 199 | " best['epoch'] = i\n", 200 | " best['accuracy'] = 
accuracy\n", 201 | " best['path'] = model_path\n", 202 | " print(f'Best accuracy at epoch={i} with {accuracy:.2f}%')\n", 203 | " \n", 204 | "best_epoch = best['epoch']\n", 205 | "best_accuracy = best['accuracy']\n", 206 | "best_path = best['path']\n", 207 | "\n", 208 | "print(f'Final model is epoch={best_epoch} with accurayc={best_accuracy:.2f}%')\n", 209 | "print(f'Path={best_path}')" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "### Confusion Matrix with Best Model" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 7, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "from sklearn import metrics\n", 226 | "from data import mstar\n", 227 | "\n", 228 | "def confusion_matrix(_m, ds):\n", 229 | " conf_mat = {\n", 230 | " 'A32': np.zeros((1, 4), dtype=np.int32),\n", 231 | " 'A62': np.zeros((1, 4), dtype=np.int32),\n", 232 | " 'A63': np.zeros((1, 4), dtype=np.int32),\n", 233 | " 'A64': np.zeros((1, 4), dtype=np.int32),\n", 234 | " 's7': np.zeros((1, 4), dtype=np.int32),\n", 235 | " }\n", 236 | " _pred = []\n", 237 | " _gt = []\n", 238 | " \n", 239 | " _m.net.eval()\n", 240 | " _softmax = torch.nn.Softmax(dim=1)\n", 241 | " for i, data in enumerate(ds):\n", 242 | " images, labels, serial_numbers = data\n", 243 | " \n", 244 | " predictions = _m.inference(images)\n", 245 | " predictions = _softmax(predictions)\n", 246 | "\n", 247 | " _, predictions = torch.max(predictions.data, 1)\n", 248 | " \n", 249 | "# _pred += predictions.cpu().tolist()\n", 250 | " for s, c in zip(serial_numbers, predictions):\n", 251 | " conf_mat[s][0, c] += 1\n", 252 | " \n", 253 | " conf_mat = np.r_[\n", 254 | " conf_mat['A32'], \n", 255 | " conf_mat['A62'], \n", 256 | " conf_mat['A63'], \n", 257 | " conf_mat['A64'], \n", 258 | " conf_mat['s7']\n", 259 | " ]\n", 260 | " return conf_mat" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 8, 266 | "metadata": {}, 267 | 
"outputs": [ 268 | { 269 | "data": { 270 | "image/png": "\n", 271 | "text/plain": [ 272 | "
" 273 | ] 274 | }, 275 | "metadata": {}, 276 | "output_type": "display_data" 277 | } 278 | ], 279 | "source": [ 280 | "import matplotlib.pyplot as plt\n", 281 | "import seaborn as sns\n", 282 | "\n", 283 | "m.load(best_path)\n", 284 | "_conf_mat = confusion_matrix(m, test_set)\n", 285 | "\n", 286 | "sns.reset_defaults()\n", 287 | "ax = sns.heatmap(_conf_mat, annot=True, fmt='d', cbar=False)\n", 288 | "ax.set_yticklabels(mstar.target_name_eoc_2_cv, rotation=0)\n", 289 | "ax.set_xticklabels(mstar.target_name_eoc_2, rotation=30)\n", 290 | "\n", 291 | "plt.xlabel('prediction', fontsize=12)\n", 292 | "plt.ylabel('label', fontsize=12)\n", 293 | "\n", 294 | "\n", 295 | "plt.show()" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "### Noise Simulation" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 9, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "from skimage import util\n", 312 | "\n", 313 | "\n", 314 | "def generate_noise(_images, amount):\n", 315 | " \n", 316 | " n, _, h, w = _images.shape\n", 317 | " \n", 318 | " noise = np.array([np.random.uniform(size=(1, h, w)) for _ in range(n)])\n", 319 | " portions = np.array([\n", 320 | " util.random_noise(np.zeros((1, h, w)), mode='s&p', amount=amount)\n", 321 | " for _ in range(n)\n", 322 | " ])\n", 323 | " noise = noise * portions\n", 324 | " \n", 325 | " return _images + noise.astype(np.float32)\n", 326 | "\n", 327 | "\n", 328 | "def noise_simulation(_m, ds, noise_ratio):\n", 329 | " \n", 330 | " num_data = 0\n", 331 | " corrects = 0\n", 332 | " \n", 333 | " _m.net.eval()\n", 334 | " _softmax = torch.nn.Softmax(dim=1)\n", 335 | " for i, data in enumerate(ds):\n", 336 | " images, labels, _ = data\n", 337 | " images = generate_noise(images, noise_ratio)\n", 338 | "\n", 339 | " predictions = _m.inference(images)\n", 340 | " predictions = _softmax(predictions)\n", 341 | "\n", 342 | " _, predictions = 
torch.max(predictions.data, 1)\n", 343 | " labels = labels.type(torch.LongTensor)\n", 344 | " num_data += labels.size(0)\n", 345 | " corrects += (predictions == labels.to(m.device)).sum().item()\n", 346 | "\n", 347 | " accuracy = 100 * corrects / num_data\n", 348 | " \n", 349 | " return accuracy" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 10, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "ratio = 0.01, accuracy = 99.15\n", 362 | "ratio = 0.05, accuracy = 93.10\n", 363 | "ratio = 0.10, accuracy = 61.85\n", 364 | "ratio = 0.15, accuracy = 26.94\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "noise_result = {}\n", 370 | "\n", 371 | "for ratio in [0.01, 0.05, 0.10, 0.15]:\n", 372 | " noise_result[ratio] = noise_simulation(m, test_set, ratio)\n", 373 | " print(f'ratio = {ratio:.2f}, accuracy = {noise_result[ratio]:.2f}')\n" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [] 382 | } 383 | ], 384 | "metadata": { 385 | "kernelspec": { 386 | "display_name": "Python 3 (ipykernel)", 387 | "language": "python", 388 | "name": "python3" 389 | }, 390 | "language_info": { 391 | "codemirror_mode": { 392 | "name": "ipython", 393 | "version": 3 394 | }, 395 | "file_extension": ".py", 396 | "mimetype": "text/x-python", 397 | "name": "python", 398 | "nbconvert_exporter": "python", 399 | "pygments_lexer": "ipython3", 400 | "version": "3.7.9" 401 | } 402 | }, 403 | "nbformat": 4, 404 | "nbformat_minor": 4 405 | } 406 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image==0.18.2 2 | numpy==1.22.0 3 | absl-py 4 | torch==1.9.0+cu111 5 | tqdm==4.61.2 6 | torchvision==0.10.0+cu111 7 | matplotlib 8 | scikit-learn 9 | seaborn 10 | Pillow 
-------------------------------------------------------------------------------- /run-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | WORKSPACE= 4 | 5 | docker run --gpus all --rm -it -p 8888:8888 --mount type=bind,src=${WORKSPACE},dst=/workspace aconvnet-pytorch /bin/bash 6 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from absl import logging 2 | from absl import flags 3 | from absl import app 4 | 5 | from multiprocessing import Pool 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import json 10 | import glob 11 | import os 12 | 13 | import mstar 14 | 15 | flags.DEFINE_string('image_root', default='dataset', help='') 16 | flags.DEFINE_string('dataset', default='soc', help='') 17 | flags.DEFINE_boolean('is_train', default=False, help='') 18 | flags.DEFINE_integer('chip_size', default=100, help='') 19 | flags.DEFINE_integer('patch_size', default=94, help='') 20 | flags.DEFINE_boolean('use_phase', default=True, help='') 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 25 | 26 | 27 | def data_scaling(chip): 28 | r = chip.max() - chip.min() 29 | return (chip - chip.min()) / r 30 | 31 | 32 | def log_scale(chip): 33 | return np.log10(np.abs(chip) + 1) 34 | 35 | 36 | def generate(src_path, dst_path, is_train, chip_size, patch_size, use_phase, dataset): 37 | if not os.path.exists(src_path): 38 | return 39 | if not os.path.exists(dst_path): 40 | 
os.makedirs(dst_path, exist_ok=True) 41 | print(f'Target Name: {os.path.basename(dst_path)}') 42 | 43 | _mstar = mstar.MSTAR( 44 | name=dataset, is_train=is_train, chip_size=chip_size, patch_size=patch_size, use_phase=use_phase, stride=1 45 | ) 46 | 47 | image_list = glob.glob(os.path.join(src_path, '*')) 48 | 49 | for path in image_list: 50 | label, _images = _mstar.read(path) 51 | for i, _image in enumerate(_images): 52 | name = os.path.splitext(os.path.basename(path))[0] 53 | with open(os.path.join(dst_path, f'{name}-{i}.json'), mode='w', encoding='utf-8') as f: 54 | json.dump(label, f, ensure_ascii=False, indent=2) 55 | 56 | # _image = log_scale(_image) 57 | np.save(os.path.join(dst_path, f'{name}-{i}.npy'), _image) 58 | # Image.fromarray(data_scaling(_image)).convert('L').save(os.path.join(dst_path, f'{name}-{i}.bmp')) 59 | 60 | 61 | def main(_): 62 | dataset_root = os.path.join(project_root, FLAGS.image_root, FLAGS.dataset) 63 | raw_root = os.path.join(dataset_root, 'raw') 64 | 65 | mode = 'train' if FLAGS.is_train else 'test' 66 | 67 | output_root = os.path.join(dataset_root, mode) 68 | if not os.path.exists(output_root): 69 | os.makedirs(output_root, exist_ok=True) 70 | 71 | arguments = [ 72 | ( 73 | os.path.join(raw_root, mode, target), 74 | os.path.join(output_root, target), 75 | FLAGS.is_train, FLAGS.chip_size, FLAGS.patch_size, FLAGS.use_phase, FLAGS.dataset 76 | ) for target in mstar.target_name[FLAGS.dataset] 77 | ] 78 | 79 | with Pool(10) as p: 80 | p.starmap(generate, arguments) 81 | 82 | 83 | if __name__ == '__main__': 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from skimage import io 4 | import torch 5 | import tqdm 6 | 7 | import json 8 | import glob 9 | import os 10 | 11 | # import utils.common as common 12 | project_root = 
os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 13 | 14 | 15 | class Dataset(torch.utils.data.Dataset): 16 | 17 | def __init__(self, path, name='soc', is_train=False, transform=None): 18 | self.is_train = is_train 19 | self.name = name 20 | 21 | self.images = [] 22 | self.labels = [] 23 | self.serial_number = [] 24 | 25 | self.transform = transform 26 | self._load_data(path) 27 | 28 | def __len__(self): 29 | return len(self.labels) 30 | 31 | def __getitem__(self, idx): 32 | if torch.is_tensor(idx): 33 | idx = idx.tolist() 34 | 35 | _image = self.images[idx] 36 | _label = self.labels[idx] 37 | _serial_number = self.serial_number[idx] 38 | 39 | if self.transform: 40 | _image = self.transform(_image) 41 | 42 | return _image, _label, _serial_number 43 | 44 | def _load_data(self, path): 45 | mode = 'train' if self.is_train else 'test' 46 | 47 | image_list = glob.glob(os.path.join(project_root, path, f'{self.name}/{mode}/*/*.npy')) 48 | label_list = glob.glob(os.path.join(project_root, path, f'{self.name}/{mode}/*/*.json')) 49 | image_list = sorted(image_list, key=os.path.basename) 50 | label_list = sorted(label_list, key=os.path.basename) 51 | 52 | for image_path, label_path in tqdm.tqdm(zip(image_list, label_list), desc=f'load {mode} data set'): 53 | self.images.append(np.load(image_path)) 54 | 55 | with open(label_path, mode='r', encoding='utf-8') as f: 56 | _label = json.load(f) 57 | 58 | self.labels.append(_label['class_id']) 59 | self.serial_number.append(_label['serial_number']) 60 | -------------------------------------------------------------------------------- /src/data/mstar.py: -------------------------------------------------------------------------------- 1 | from skimage.util import shape 2 | 3 | import numpy as np 4 | import tqdm 5 | 6 | import glob 7 | import os 8 | 9 | target_name_soc = ('2S1', 'BMP2', 'BRDM2', 'BTR60', 'BTR70', 'D7', 'T62', 'T72', 'ZIL131', 'ZSU234') 10 | target_name_eoc_1 = ('2S1', 'BRDM2', 'T72', 
'ZSU234') 11 | 12 | target_name_eoc_2 = ('BMP2', 'BRDM2', 'BTR70', 'T72') 13 | target_name_eoc_2_cv = ('T72-A32', 'T72-A62', 'T72-A63', 'T72-A64', 'T72-S7') 14 | target_name_eoc_2_vv = ('BMP2-9566', 'BMP2-C21', 'T72-812', 'T72-A04', 'T72-A05', 'T72-A07', 'T72-A10') 15 | 16 | target_name_confuser_rejection = ('BMP2', 'BTR70', 'T72', '2S1', 'ZIL131') 17 | 18 | target_name = { 19 | 'soc': target_name_soc, 20 | 'eoc-1': target_name_eoc_1, 21 | 'eoc-1-t72-132': target_name_eoc_1, 22 | 'eoc-1-t72-a64': target_name_eoc_1, 23 | 'eoc-2-cv': target_name_eoc_2 + target_name_eoc_2_cv, 24 | 'eoc-2-vv': target_name_eoc_2 + target_name_eoc_2_vv, 25 | 'confuser-rejection': target_name_confuser_rejection 26 | } 27 | 28 | serial_number = { 29 | 'b01': 0, 30 | 31 | '9563': 1, 32 | '9566': 1, 33 | 'c21': 1, 34 | 35 | 'E-71': 2, 36 | 'k10yt7532': 3, 37 | 'c71': 4, 38 | '92v13015': 5, 39 | 'A51': 6, 40 | 41 | '132': 7, 42 | '812': 7, 43 | 's7': 7, 44 | 'A04': 7, 45 | 'A05': 7, 46 | 'A07': 7, 47 | 'A10': 7, 48 | 'A32': 7, 49 | 'A62': 7, 50 | 'A63': 7, 51 | 'A64': 7, 52 | 53 | 'E12': 8, 54 | 'd08': 9 55 | } 56 | 57 | 58 | class MSTAR(object): 59 | 60 | def __init__(self, name='soc', is_train=False, use_phase=False, chip_size=94, patch_size=88, stride=40): 61 | self.name = name 62 | self.is_train = is_train 63 | self.use_phase = use_phase 64 | self.chip_size = chip_size 65 | self.patch_size = patch_size 66 | self.stride = stride 67 | 68 | def read(self, path): 69 | f = open(path, 'rb') 70 | _header = self._parse_header(f) 71 | _data = np.fromfile(f, dtype='>f4') 72 | f.close() 73 | 74 | h = eval(_header['NumberOfRows']) 75 | w = eval(_header['NumberOfColumns']) 76 | 77 | _data = _data.reshape(-1, h, w) 78 | _data = _data.transpose(1, 2, 0) 79 | _data = _data.astype(np.float32) 80 | if not self.use_phase: 81 | _data = np.expand_dims(_data[:, :, 0], axis=2) 82 | 83 | # _data = self._normalize(_data) 84 | _data = self._center_crop(_data) 85 | 86 | if self.is_train: 87 | _data = 
self._data_augmentation(_data, patch_size=self.patch_size, stride=self.stride) 88 | else: 89 | _data = [self._center_crop(_data, size=self.patch_size)] 90 | 91 | meta_label = self._extract_meta_label(_header) 92 | return meta_label, _data 93 | 94 | @staticmethod 95 | def _parse_header(file): 96 | header = {} 97 | for line in file: 98 | line = line.decode('utf-8') 99 | line = line.strip() 100 | 101 | if not line: 102 | continue 103 | 104 | if 'PhoenixHeaderVer' in line: 105 | continue 106 | 107 | if 'EndofPhoenixHeader' in line: 108 | break 109 | 110 | key, value = line.split('=') 111 | header[key.strip()] = value.strip() 112 | 113 | return header 114 | 115 | @staticmethod 116 | def _center_crop(data, size=128): 117 | h, w, _ = data.shape 118 | 119 | y = (h - size) // 2 120 | x = (w - size) // 2 121 | 122 | return data[y: y + size, x: x + size] 123 | 124 | def _data_augmentation(self, data, patch_size=88, stride=40): 125 | # patch extraction 126 | _data = MSTAR._center_crop(data, size=self.chip_size) 127 | _, _, channels = _data.shape 128 | patches = shape.view_as_windows(_data, window_shape=(patch_size, patch_size, channels), step=stride) 129 | patches = patches.reshape(-1, patch_size, patch_size, channels) 130 | return patches 131 | 132 | def _extract_meta_label(self, header): 133 | 134 | target_type = header['TargetType'] 135 | sn = header['TargetSerNum'] 136 | 137 | class_id = serial_number[sn] 138 | if not self.name == 'soc': 139 | class_id = target_name[self.name].index(target_name_soc[class_id]) 140 | 141 | azimuth_angle = MSTAR._get_azimuth_angle(header['TargetAz']) 142 | 143 | return { 144 | 'class_id': class_id, 145 | 'target_type': target_type, 146 | 'serial_number': sn, 147 | 'azimuth_angle': azimuth_angle 148 | } 149 | 150 | @staticmethod 151 | def _get_azimuth_angle(angle): 152 | azimuth_angle = eval(angle) 153 | if azimuth_angle > 180: 154 | azimuth_angle -= 180 155 | return int(azimuth_angle) 156 | 157 | @staticmethod 158 | def _normalize(x): 159 | d 
= (x - x.min()) / (x.max() - x.min()) 160 | return d.astype(np.float32) 161 | -------------------------------------------------------------------------------- /src/data/preprocess.py: -------------------------------------------------------------------------------- 1 | from skimage import transform 2 | import numpy as np 3 | 4 | 5 | class ToTensor(object): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, sample): 11 | _input = sample 12 | 13 | if len(_input.shape) < 3: 14 | _input = np.expand_dims(_input, axis=2) 15 | 16 | _input = _input.transpose((2, 0, 1)) 17 | 18 | return _input 19 | 20 | 21 | class RandomCrop(object): 22 | 23 | def __init__(self, size): 24 | if isinstance(size, int): 25 | self.size = (size, size) 26 | else: 27 | assert len(size) == 2 28 | self.size = size 29 | 30 | def __call__(self, sample): 31 | _input = sample 32 | 33 | if len(_input.shape) < 3: 34 | _input = np.expand_dims(_input, axis=2) 35 | 36 | h, w, _ = _input.shape 37 | oh, ow = self.size 38 | 39 | dh = h - oh 40 | dw = w - ow 41 | y = np.random.randint(0, dh) if dh > 0 else 0 42 | x = np.random.randint(0, dw) if dw > 0 else 0 43 | oh = oh if dh > 0 else h 44 | ow = ow if dw > 0 else w 45 | 46 | return _input[y: y + oh, x: x + ow, :] 47 | 48 | 49 | class CenterCrop(object): 50 | 51 | def __init__(self, size): 52 | if isinstance(size, int): 53 | self.size = (size, size) 54 | else: 55 | assert len(size) == 2 56 | self.size = size 57 | 58 | def __call__(self, sample): 59 | _input = sample 60 | 61 | if len(_input.shape) < 3: 62 | _input = np.expand_dims(_input, axis=2) 63 | 64 | h, w, _ = _input.shape 65 | oh, ow = self.size 66 | y = (h - oh) // 2 67 | x = (w - ow) // 2 68 | 69 | return _input[y: y + oh, x: x + ow, :] 70 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import * 2 | 
-------------------------------------------------------------------------------- /src/model/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import model.network 4 | 5 | 6 | class Model(object): 7 | def __init__(self, **params): 8 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | self.net = model.network.Network( 10 | classes=params.get('classes', 10), 11 | channels=params.get('channels', 1), 12 | dropout_rate=params.get('dropout_rate', 0.5) 13 | ) 14 | self.net.to(self.device) 15 | 16 | self.lr = params.get('lr', 1e-3) 17 | self.lr_step = params.get('lr_step', [50]) 18 | self.lr_decay = params.get('lr_decay', 0.1) 19 | 20 | self.lr_scheduler = None 21 | 22 | self.momentum = params.get('momentum', 0.9) 23 | self.weight_decay = params.get('weight_decay', 4e-3) 24 | 25 | self.criterion = torch.nn.CrossEntropyLoss() 26 | self.optimizer = torch.optim.SGD( 27 | self.net.parameters(), 28 | lr=self.lr, 29 | momentum=self.momentum, 30 | weight_decay=self.weight_decay 31 | ) 32 | 33 | if self.lr_decay: 34 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 35 | optimizer=self.optimizer, 36 | milestones=self.lr_step, 37 | gamma=self.lr_decay 38 | ) 39 | 40 | def optimize(self, x, y): 41 | p = self.net(x.to(self.device)) 42 | loss = self.criterion(p, y.to(self.device)) 43 | 44 | self.optimizer.zero_grad() 45 | loss.backward() 46 | self.optimizer.step() 47 | 48 | return loss.item() 49 | 50 | @torch.no_grad() 51 | def inference(self, x): 52 | return self.net(x.to(self.device)) 53 | 54 | def save(self, path): 55 | torch.save(self.net.state_dict(), path) 56 | 57 | def load(self, path): 58 | self.net.load_state_dict(torch.load(path)) 59 | self.net.eval() 60 | -------------------------------------------------------------------------------- /src/model/_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 
import collections 4 | 5 | _activations = { 6 | 'relu': nn.ReLU, 7 | 'relu6': nn.ReLU6, 8 | 'leaky_relu': nn.LeakyReLU 9 | } 10 | 11 | 12 | class BaseBlock(nn.Module): 13 | 14 | def __init__(self): 15 | super(BaseBlock, self).__init__() 16 | self._layer: nn.Sequential 17 | 18 | def forward(self, x): 19 | return self._layer(x) 20 | 21 | 22 | class DenseBlock(BaseBlock): 23 | 24 | def __init__(self, shape, **params): 25 | super(DenseBlock, self).__init__() 26 | in_dims, out_dims = shape 27 | _seq = collections.OrderedDict([ 28 | ('dense', nn.Linear(in_dims, out_dims)), 29 | ]) 30 | _act_name = params.get('activation') 31 | if _act_name: 32 | _seq.update({_act_name: _activations[_act_name](inplace=True)}) 33 | 34 | self._layer = nn.Sequential(_seq) 35 | 36 | w_init = params.get('w_init', None) 37 | idx = list(dict(self._layer.named_children()).keys()).index('dense') 38 | if w_init: 39 | w_init(self._layer[idx].weight) 40 | b_init = params.get('b_init', None) 41 | if b_init: 42 | b_init(self._layer[idx].bias) 43 | 44 | 45 | class Conv2DBlock(BaseBlock): 46 | 47 | def __init__(self, shape, stride, padding='same', **params): 48 | super(Conv2DBlock, self).__init__() 49 | 50 | h, w, in_channels, out_channels = shape 51 | _seq = collections.OrderedDict([ 52 | ('conv', nn.Conv2d(in_channels, out_channels, kernel_size=(h, w), stride=stride, padding=padding)) 53 | ]) 54 | 55 | _bn = params.get('batch_norm') 56 | if _bn: 57 | _seq.update({'bn': nn.BatchNorm2d(out_channels)}) 58 | 59 | _act_name = params.get('activation') 60 | if _act_name: 61 | _seq.update({_act_name: _activations[_act_name](inplace=True)}) 62 | 63 | _max_pool = params.get('max_pool') 64 | if _max_pool: 65 | _kernel_size = params.get('max_pool_size', 2) 66 | _stride = params.get('max_pool_stride', _kernel_size) 67 | _seq.update({'max_pool': nn.MaxPool2d(kernel_size=_kernel_size, stride=_stride)}) 68 | 69 | self._layer = nn.Sequential(_seq) 70 | 71 | w_init = params.get('w_init', None) 72 | idx = 
list(dict(self._layer.named_children()).keys()).index('conv') 73 | if w_init: 74 | w_init(self._layer[idx].weight) 75 | b_init = params.get('b_init', None) 76 | if b_init: 77 | b_init(self._layer[idx].bias) 78 | -------------------------------------------------------------------------------- /src/model/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from . import _blocks 5 | 6 | 7 | class Network(nn.Module): 8 | 9 | def __init__(self, **params): 10 | super(Network, self).__init__() 11 | self.dropout_rate = params.get('dropout_rate', 0.5) 12 | self.classes = params.get('classes', 10) 13 | self.channels = params.get('channels', 1) 14 | 15 | _w_init = params.get('w_init', lambda x: nn.init.kaiming_normal_(x, nonlinearity='relu')) 16 | _b_init = params.get('b_init', lambda x: nn.init.constant_(x, 0.1)) 17 | 18 | self._layer = nn.Sequential( 19 | _blocks.Conv2DBlock( 20 | shape=[5, 5, self.channels, 16], stride=1, padding='valid', activation='relu', max_pool=True, 21 | w_init=_w_init, b_init=_b_init 22 | ), 23 | _blocks.Conv2DBlock( 24 | shape=[5, 5, 16, 32], stride=1, padding='valid', activation='relu', max_pool=True, 25 | w_init=_w_init, b_init=_b_init 26 | ), 27 | _blocks.Conv2DBlock( 28 | shape=[6, 6, 32, 64], stride=1, padding='valid', activation='relu', max_pool=True, 29 | w_init=_w_init, b_init=_b_init 30 | ), 31 | _blocks.Conv2DBlock( 32 | shape=[5, 5, 64, 128], stride=1, padding='valid', activation='relu', 33 | w_init=_w_init, b_init=_b_init 34 | ), 35 | nn.Dropout(p=self.dropout_rate), 36 | _blocks.Conv2DBlock( 37 | shape=[3, 3, 128, self.classes], stride=1, padding='valid', 38 | w_init=_w_init, b_init=nn.init.zeros_ 39 | ), 40 | nn.Flatten() 41 | ) 42 | 43 | def forward(self, x): 44 | return self._layer(x) 45 | -------------------------------------------------------------------------------- /src/train.py: 
-------------------------------------------------------------------------------- 1 | from absl import logging 2 | from absl import flags 3 | from absl import app 4 | 5 | from tqdm import tqdm 6 | 7 | from torch.utils import tensorboard 8 | 9 | import torchvision 10 | import torch 11 | 12 | import numpy as np 13 | 14 | import json 15 | import os 16 | 17 | from data import preprocess 18 | from data import loader 19 | from utils import common 20 | import model 21 | 22 | flags.DEFINE_string('experiments_path', os.path.join(common.project_root, 'experiments'), help='') 23 | flags.DEFINE_string('config_name', 'config/AConvNet-SOC.json', help='') 24 | FLAGS = flags.FLAGS 25 | 26 | 27 | common.set_random_seed(12321) 28 | 29 | 30 | def load_dataset(path, is_train, name, batch_size): 31 | transform = [preprocess.CenterCrop(88), torchvision.transforms.ToTensor()] 32 | if is_train: 33 | transform = [preprocess.RandomCrop(88), torchvision.transforms.ToTensor()] 34 | _dataset = loader.Dataset( 35 | path, name=name, is_train=is_train, 36 | transform=torchvision.transforms.Compose(transform) 37 | ) 38 | data_loader = torch.utils.data.DataLoader( 39 | _dataset, batch_size=batch_size, shuffle=is_train, num_workers=1 40 | ) 41 | return data_loader 42 | 43 | 44 | @torch.no_grad() 45 | def validation(m, ds): 46 | num_data = 0 47 | corrects = 0 48 | 49 | # Test loop 50 | m.net.eval() 51 | _softmax = torch.nn.Softmax(dim=1) 52 | for i, data in enumerate(tqdm(ds)): 53 | images, labels, _ = data 54 | 55 | predictions = m.inference(images) 56 | predictions = _softmax(predictions) 57 | 58 | _, predictions = torch.max(predictions.data, 1) 59 | labels = labels.type(torch.LongTensor) 60 | num_data += labels.size(0) 61 | corrects += (predictions == labels.to(m.device)).sum().item() 62 | 63 | accuracy = 100 * corrects / num_data 64 | return accuracy 65 | 66 | 67 | def run(epochs, dataset, classes, channels, batch_size, 68 | lr, lr_step, lr_decay, weight_decay, dropout_rate, 69 | model_name, 
experiments_path=None): 70 | train_set = load_dataset('dataset', True, dataset, batch_size) 71 | valid_set = load_dataset('dataset', False, dataset, batch_size) 72 | 73 | m = model.Model( 74 | classes=classes, dropout_rate=dropout_rate, channels=channels, 75 | lr=lr, lr_step=lr_step, lr_decay=lr_decay, 76 | weight_decay=weight_decay 77 | ) 78 | 79 | model_path = os.path.join(experiments_path, f'model/{model_name}') 80 | if not os.path.exists(model_path): 81 | os.makedirs(model_path, exist_ok=True) 82 | 83 | history_path = os.path.join(experiments_path, 'history') 84 | if not os.path.exists(history_path): 85 | os.makedirs(history_path, exist_ok=True) 86 | 87 | history = { 88 | 'loss': [], 89 | 'accuracy': [] 90 | } 91 | 92 | for epoch in range(epochs): 93 | _loss = [] 94 | 95 | m.net.train() 96 | for i, data in enumerate(tqdm(train_set)): 97 | images, labels, _ = data 98 | _loss.append(m.optimize(images, labels)) 99 | 100 | if m.lr_scheduler: 101 | lr = m.lr_scheduler.get_last_lr()[0] 102 | m.lr_scheduler.step() 103 | 104 | accuracy = validation(m, valid_set) 105 | 106 | logging.info( 107 | f'Epoch: {epoch + 1:03d}/{epochs:03d} | loss={np.mean(_loss):.4f} | lr={lr} | accuracy={accuracy:.2f}' 108 | ) 109 | 110 | history['loss'].append(np.mean(_loss)) 111 | history['accuracy'].append(accuracy) 112 | 113 | if experiments_path: 114 | m.save(os.path.join(model_path, f'model-{epoch + 1:03d}.pth')) 115 | 116 | with open(os.path.join(history_path, f'history-{model_name}.json'), mode='w', encoding='utf-8') as f: 117 | json.dump(history, f, ensure_ascii=True, indent=2) 118 | 119 | 120 | def main(_): 121 | logging.info('Start') 122 | experiments_path = FLAGS.experiments_path 123 | config_name = FLAGS.config_name 124 | 125 | config = common.load_config(os.path.join(experiments_path, config_name)) 126 | 127 | dataset = config['dataset'] 128 | classes = config['num_classes'] 129 | channels = config['channels'] 130 | epochs = config['epochs'] 131 | batch_size = 
config['batch_size'] 132 | 133 | lr = config['lr'] 134 | lr_step = config['lr_step'] 135 | lr_decay = config['lr_decay'] 136 | 137 | weight_decay = config['weight_decay'] 138 | dropout_rate = config['dropout_rate'] 139 | 140 | model_name = config['model_name'] 141 | 142 | run(epochs, dataset, classes, channels, batch_size, 143 | lr, lr_step, lr_decay, weight_decay, dropout_rate, 144 | model_name, experiments_path) 145 | 146 | logging.info('Finish') 147 | 148 | 149 | if __name__ == '__main__': 150 | app.run(main) 151 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jangsoopark/AConvNet-pytorch/c60740d40407f68c9df71a8c0b871601d5ba849d/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import random 5 | import json 6 | import os 7 | 8 | project_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 9 | 10 | 11 | def set_random_seed(random_seed): 12 | torch.manual_seed(random_seed) 13 | torch.cuda.manual_seed(random_seed) 14 | torch.cuda.manual_seed_all(random_seed) # if use multi-GPU 15 | 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | np.random.seed(random_seed) 20 | random.seed(random_seed) 21 | 22 | 23 | def load_config(path): 24 | with open(path, mode='r', encoding='utf-8') as f: 25 | return json.load(f) 26 | --------------------------------------------------------------------------------