├── LICENSE
├── README.md
├── assets
│   ├── 01.png
│   ├── 02.png
│   ├── 03.jpeg
│   ├── cifar_densenet121_train_acc.png
│   ├── cifar_densenet121_train_loss.png
│   ├── cifar_densenet121_val_acc.png
│   ├── cifar_densenet121_val_loss.png
│   ├── cifar_se_densenet121_full_in_loop_train_acc.png
│   ├── cifar_se_densenet121_full_in_loop_train_loss.png
│   ├── cifar_se_densenet121_full_in_loop_val_acc.png
│   ├── cifar_se_densenet121_full_in_loop_val_loss.png
│   ├── cifar_se_densenet121_full_train_acc.png
│   ├── cifar_se_densenet121_full_train_loss.png
│   ├── cifar_se_densenet121_full_val_acc.png
│   ├── cifar_se_densenet121_full_val_loss.png
│   ├── cifar_se_densenet121_w_block_train_acc.png
│   ├── cifar_se_densenet121_w_block_train_loss.png
│   ├── cifar_se_densenet121_w_block_val_acc.png
│   ├── cifar_se_densenet121_w_block_val_loss.png
│   ├── densenet121_train_acc.png
│   ├── densenet121_train_loss.png
│   ├── densenet121_val_acc.png
│   ├── densenet121_val_loss.png
│   ├── se_densenet121_train_acc.png
│   ├── se_densenet121_train_loss.png
│   ├── se_densenet121_val_acc.png
│   └── se_densenet121_val_loss.png
├── baseline.py
├── cifar10.py
├── core
│   ├── __init__.py
│   ├── baseline.py
│   ├── se_densenet.py
│   ├── se_densenet_full.py
│   ├── se_densenet_full_in_loop.py
│   ├── se_densenet_w_block.py
│   ├── se_efficient_densenet.py
│   ├── se_module.py
│   └── test_se_densenet.py
├── data
│   └── cifar10
│       └── download_cifar10_at_here.md
├── state
│   ├── state_baseline.txt
│   ├── state_full.txt
│   ├── state_full_in_loop.txt
│   └── state_w_block.txt
├── utils.py
├── visual
│   └── viz.py
└── weights
    └── save_module_weights_at_here.md
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 AllenZhou
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the
Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 | 
3 | ![](assets/03.jpeg)
4 | This is a DenseNet implementation that contains an [SE](https://arxiv.org/abs/1709.01507) (Squeeze-and-Excitation Networks, by Jie Hu, Li Shen and Gang Sun) module.
5 | Using DenseNet as the backbone, I add SE modules into DenseNet as the picture below shows; note that this is not the whole structure of se_densenet.
6 | 
7 | ![](assets/02.png)
8 | 
9 | Please read my **[blog](http://www.zhouyuangan.cn/2018/11/se_densenet-modify-densenet-with-champion-network-of-the-2017-classification-task-named-squeeze-and-excitation-network/)** if you want to know more about the modified se_densenet. A Chinese version of the blog is [here](https://zhuanlan.zhihu.com/p/48499356).
10 | 
11 | ## Table of contents
12 | 
13 | - Experiment on Cifar dataset
14 | - Experiment on my own dataset
15 | - How to train model
16 | - Conclusion
17 | - Todo
18 | 
19 | Before we start, let's see how to test se_densenet first.
20 | 
21 | ```bash
22 | cd core
23 | python3 se_densenet.py
24 | ```
25 | 
26 | This will print the structure of se_densenet.
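For orientation, the SE operation itself is small: a global average pool ("squeeze"), a two-layer bottleneck MLP ending in a sigmoid ("excitation"), and a channel-wise rescale. The repo's actual implementation lives in `core/se_module.py`; the following is only a minimal sketch of the idea, assuming the paper's default reduction ratio of 16:

```python
import torch
import torch.nn as nn


class SELayer(nn.Module):
    """Squeeze-and-Excitation: reweight feature-map channels by a learned gate."""

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)       # squeeze: (N, C, H, W) -> (N, C, 1, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # bottleneck
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),                              # per-channel gate in (0, 1)
        )

    def forward(self, x):
        n, c, _, _ = x.size()
        y = self.avg_pool(x).view(n, c)                # (N, C)
        y = self.fc(y).view(n, c, 1, 1)                # excitation weights
        return x * y                                   # rescale each channel
```

Because the gate is strictly between 0 and 1, the layer can only attenuate channels, never amplify them; the output shape always matches the input shape.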
27 | 
28 | Let's feed a tensor of shape (32, 3, 224, 224) into se_densenet:
29 | 
30 | ```bash
31 | cd core
32 | python3 test_se_densenet.py
33 | ```
34 | 
35 | It will print ``torch.Size([32, 1000])``.
36 | 
37 | ## Experiment on Cifar dataset
38 | 
39 | ### core/baseline.py (baseline)
40 | 
41 | - train
42 | ![](assets/cifar_densenet121_train_acc.png)
43 | ![](assets/cifar_densenet121_train_loss.png)
44 | 
45 | - val
46 | ![](assets/cifar_densenet121_val_acc.png)
47 | ![](assets/cifar_densenet121_val_loss.png)
48 | 
49 | The best val acc is 0.9406 at epoch 98.
50 | 
51 | ### core/se_densenet_w_block.py
52 | 
53 | In this part, I removed some SE layers from DenseNet's ``transition`` layers; please check [se_densenet_w_block.py](https://github.com/zhouyuangan/SE_DenseNet/blob/master/se_densenet_w_block.py) and you will find some commented-out code that points to the SE layers mentioned above.
54 | 
55 | - train
56 | 
57 | ![](assets/cifar_se_densenet121_w_block_train_acc.png)
58 | ![](assets/cifar_se_densenet121_w_block_train_loss.png)
59 | 
60 | - val
61 | 
62 | ![](assets/cifar_se_densenet121_w_block_val_acc.png)
63 | ![](assets/cifar_se_densenet121_w_block_val_loss.png)
64 | 
65 | The best acc is 0.9381 at epoch 98.
66 | 
67 | ### core/se_densenet_full.py
68 | 
69 | Please check [se_densenet_full.py](https://github.com/zhouyuangan/SE_DenseNet/blob/master/se_densenet_full.py) for more details. Here I add SE modules into both the dense blocks and the transitions. Thanks to [@john1231983](https://github.com/John1231983)'s issue, I removed some redundant code in se_densenet_full.py; check this [issue](https://github.com/zhouyuangan/SE_DenseNet/issues/1) to see what changed. Here is the train-val result on CIFAR-10:
70 | 
71 | - train
72 | 
73 | ![](assets/cifar_se_densenet121_full_train_acc.png)
74 | ![](assets/cifar_se_densenet121_full_train_loss.png)
75 | 
76 | - val
77 | 
78 | ![](assets/cifar_se_densenet121_full_val_acc.png)
79 | ![](assets/cifar_se_densenet121_full_val_loss.png)
80 | 
81 | The best acc is 0.9407 at epoch 86.
82 | 
83 | ### core/se_densenet_full_in_loop.py
84 | 
85 | Please check [se_densenet_full_in_loop.py](https://github.com/zhouyuangan/SE_DenseNet/blob/master/se_densenet_full_in_loop.py) for more details; this [issue](https://github.com/zhouyuangan/SE_DenseNet/issues/1#issuecomment-438891133) illustrates what I changed. Here is the train-val result on CIFAR-10:
86 | 
87 | - train
88 | 
89 | ![](assets/cifar_se_densenet121_full_in_loop_train_acc.png)
90 | ![](assets/cifar_se_densenet121_full_in_loop_train_loss.png)
91 | 
92 | - val
93 | 
94 | ![](assets/cifar_se_densenet121_full_in_loop_val_acc.png)
95 | ![](assets/cifar_se_densenet121_full_in_loop_val_loss.png)
96 | 
97 | The best acc is 0.9434 at epoch 97.
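Conceptually, the variants above differ only in where the channel gate sits relative to DenseNet's blocks. As an illustration only (the class and helper names below are hypothetical, not the repo's actual ones), attaching an SE gate after an existing block can be sketched as:

```python
import torch
import torch.nn as nn


class SEGate(nn.Module):
    """Channel gate: squeeze (global pool) + excite (bottleneck MLP + sigmoid)."""

    def __init__(self, channels, reduction=16):
        super(SEGate, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(n, c)).view(n, c, 1, 1)
        return x * w


def add_se_after(block, out_channels):
    """Wrap an existing block (dense block or transition) with a trailing SE gate."""
    return nn.Sequential(block, SEGate(out_channels))
```

Per the descriptions above: ``se_densenet_w_block`` drops the gates on the transition layers, ``se_densenet_full`` keeps gates after both dense blocks and transitions, and ``se_densenet_full_in_loop`` moves the gating inside the block-construction loop.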
98 | 
99 | ### Result
100 | 
101 | |network|best val acc|epoch|
102 | |--|--|--|
103 | |``densenet``|0.9406|98|
104 | |``se_densenet_w_block``|0.9381|98|
105 | |``se_densenet_full``|0.9407|**86**|
106 | |``se_densenet_full_in_loop``|**0.9434**|97|
107 | 
108 | ## Experiment on my own dataset
109 | 
110 | 
111 | ### core/baseline.py (baseline)
112 | 
113 | - train
114 | ![](assets/densenet121_train_acc.png)
115 | ![](assets/densenet121_train_loss.png)
116 | 
117 | - val
118 | ![](assets/densenet121_val_acc.png)
119 | ![](assets/densenet121_val_loss.png)
120 | 
121 | The best acc is: 98.5417%
122 | 
123 | ### core/se_densenet.py
124 | 
125 | - train
126 | 
127 | ![](assets/se_densenet121_train_acc.png)
128 | ![](assets/se_densenet121_train_loss.png)
129 | 
130 | - val
131 | 
132 | ![](assets/se_densenet121_val_acc.png)
133 | ![](assets/se_densenet121_val_loss.png)
134 | 
135 | The best acc is: 98.6154%
136 | 
137 | ### Result
138 | 
139 | |network|best train acc|best val acc|
140 | |--|--|--|
141 | |``densenet``|0.966953|0.985417|
142 | |``se_densenet``|**0.967772**|**0.986154**|
143 | 
144 | ``se_densenet`` achieves **0.0737 percentage points** higher validation accuracy than ``densenet``. I didn't train and test on larger public datasets such as COCO because of limited compute; you can train and test on them yourself if you wish.
145 | 
146 | ## How to train model
147 | 
148 | - Download dataset
149 | 
150 | The CIFAR-10 dataset is easy to obtain: visit its website and download it into `data/cifar10`. You can refer to the official PyTorch tutorial on training with the CIFAR-10 dataset.
151 | 
152 | - Training
153 | 
154 | There are several model variants in the `core` folder. Before you start training, edit `cifar10.py` and change the import line `from core.se_densenet_xxxx import se_densenet121` to select the variant you want.
155 | 
156 | Then, open your terminal and type:
157 | ```bash
158 | python cifar10.py
159 | ```
160 | 
161 | - Visualize training
162 | 
163 | In your terminal, type:
164 | 
165 | ```bash
166 | python visual/viz.py
167 | ```
168 | Note: please change the state file path in `visual/viz.py` first.
169 | 
170 | ### Conclusion
171 | 
172 | - ``se_densenet_full_in_loop`` gets **the best accuracy**, at epoch 97.
173 | - ``se_densenet_full`` also performs well: it reaches the second-best accuracy, ``0.9407``, at an earlier epoch (86).
174 | - In contrast, both ``densenet`` and ``se_densenet_w_block`` reach their highest accuracy at epoch ``98``.
175 | 
176 | ## TODO
177 | 
178 | I will release my training code on GitHub as quickly as possible.
179 | 
180 | - [x] Usage of my codes
181 | - [x] Test result on my own dataset
182 | - [x] Train and test on ``cifar-10`` dataset
183 | - [x] Release train and test code
184 | 
-------------------------------------------------------------------------------- /assets/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/01.png -------------------------------------------------------------------------------- /assets/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/02.png -------------------------------------------------------------------------------- /assets/03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/03.jpeg
-------------------------------------------------------------------------------- /assets/cifar_densenet121_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_densenet121_train_acc.png -------------------------------------------------------------------------------- /assets/cifar_densenet121_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_densenet121_train_loss.png -------------------------------------------------------------------------------- /assets/cifar_densenet121_val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_densenet121_val_acc.png -------------------------------------------------------------------------------- /assets/cifar_densenet121_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_densenet121_val_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_in_loop_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_in_loop_train_acc.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_in_loop_train_loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_in_loop_train_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_in_loop_val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_in_loop_val_acc.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_in_loop_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_in_loop_val_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_train_acc.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_train_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_val_acc.png 
-------------------------------------------------------------------------------- /assets/cifar_se_densenet121_full_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_full_val_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_w_block_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_w_block_train_acc.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_w_block_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_w_block_train_loss.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_w_block_val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_w_block_val_acc.png -------------------------------------------------------------------------------- /assets/cifar_se_densenet121_w_block_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/cifar_se_densenet121_w_block_val_loss.png -------------------------------------------------------------------------------- /assets/densenet121_train_acc.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/densenet121_train_acc.png -------------------------------------------------------------------------------- /assets/densenet121_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/densenet121_train_loss.png -------------------------------------------------------------------------------- /assets/densenet121_val_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/densenet121_val_acc.png -------------------------------------------------------------------------------- /assets/densenet121_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/densenet121_val_loss.png -------------------------------------------------------------------------------- /assets/se_densenet121_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/se_densenet121_train_acc.png -------------------------------------------------------------------------------- /assets/se_densenet121_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/se_densenet121_train_loss.png -------------------------------------------------------------------------------- /assets/se_densenet121_val_acc.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/se_densenet121_val_acc.png -------------------------------------------------------------------------------- /assets/se_densenet121_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/assets/se_densenet121_val_loss.png -------------------------------------------------------------------------------- /baseline.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | 11 | model_urls = { 12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 16 | } 17 | 18 | 19 | def densenet121(pretrained=False, **kwargs): 20 | r"""Densenet-121 model from 21 | `"Densely Connected Convolutional Networks" `_ 22 | 23 | Args: 24 | pretrained (bool): If True, returns a model pre-trained on ImageNet 25 | """ 26 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 27 | **kwargs) 28 | if pretrained: 29 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 30 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 31 | # They are also in the checkpoints in model_urls. This pattern is used 32 | # to find such keys. 
33 | pattern = re.compile( 34 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 35 | state_dict = model_zoo.load_url(model_urls['densenet121']) 36 | for key in list(state_dict.keys()): 37 | res = pattern.match(key) 38 | if res: 39 | new_key = res.group(1) + res.group(2) 40 | state_dict[new_key] = state_dict[key] 41 | del state_dict[key] 42 | model.load_state_dict(state_dict) 43 | return model 44 | 45 | 46 | def densenet169(pretrained=False, **kwargs): 47 | r"""Densenet-169 model from 48 | `"Densely Connected Convolutional Networks" `_ 49 | 50 | Args: 51 | pretrained (bool): If True, returns a model pre-trained on ImageNet 52 | """ 53 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 54 | **kwargs) 55 | if pretrained: 56 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 57 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 58 | # They are also in the checkpoints in model_urls. This pattern is used 59 | # to find such keys. 
60 | pattern = re.compile( 61 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 62 | state_dict = model_zoo.load_url(model_urls['densenet169']) 63 | for key in list(state_dict.keys()): 64 | res = pattern.match(key) 65 | if res: 66 | new_key = res.group(1) + res.group(2) 67 | state_dict[new_key] = state_dict[key] 68 | del state_dict[key] 69 | model.load_state_dict(state_dict) 70 | return model 71 | 72 | 73 | def densenet201(pretrained=False, **kwargs): 74 | r"""Densenet-201 model from 75 | `"Densely Connected Convolutional Networks" `_ 76 | 77 | Args: 78 | pretrained (bool): If True, returns a model pre-trained on ImageNet 79 | """ 80 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 81 | **kwargs) 82 | if pretrained: 83 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 84 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 85 | # They are also in the checkpoints in model_urls. This pattern is used 86 | # to find such keys. 
87 | pattern = re.compile( 88 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 89 | state_dict = model_zoo.load_url(model_urls['densenet201']) 90 | for key in list(state_dict.keys()): 91 | res = pattern.match(key) 92 | if res: 93 | new_key = res.group(1) + res.group(2) 94 | state_dict[new_key] = state_dict[key] 95 | del state_dict[key] 96 | model.load_state_dict(state_dict) 97 | return model 98 | 99 | 100 | def densenet161(pretrained=False, **kwargs): 101 | r"""Densenet-161 model from 102 | `"Densely Connected Convolutional Networks" `_ 103 | 104 | Args: 105 | pretrained (bool): If True, returns a model pre-trained on ImageNet 106 | """ 107 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 108 | **kwargs) 109 | if pretrained: 110 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 111 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 112 | # They are also in the checkpoints in model_urls. This pattern is used 113 | # to find such keys. 
114 | pattern = re.compile( 115 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 116 | state_dict = model_zoo.load_url(model_urls['densenet161']) 117 | for key in list(state_dict.keys()): 118 | res = pattern.match(key) 119 | if res: 120 | new_key = res.group(1) + res.group(2) 121 | state_dict[new_key] = state_dict[key] 122 | del state_dict[key] 123 | model.load_state_dict(state_dict) 124 | return model 125 | 126 | 127 | class _DenseLayer(nn.Sequential): 128 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 129 | super(_DenseLayer, self).__init__() 130 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 131 | self.add_module('relu1', nn.ReLU(inplace=True)), 132 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 133 | growth_rate, kernel_size=1, stride=1, bias=False)), 134 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 135 | self.add_module('relu2', nn.ReLU(inplace=True)), 136 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 137 | kernel_size=3, stride=1, padding=1, bias=False)), 138 | self.drop_rate = drop_rate 139 | 140 | def forward(self, x): 141 | new_features = super(_DenseLayer, self).forward(x) 142 | if self.drop_rate > 0: 143 | new_features = F.dropout(new_features, 144 | p=self.drop_rate, 145 | training=self.training 146 | ) 147 | return torch.cat([x, new_features], 1) 148 | 149 | 150 | class _DenseBlock(nn.Sequential): 151 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 152 | super(_DenseBlock, self).__init__() 153 | for i in range(num_layers): 154 | layer = _DenseLayer( 155 | num_input_features + i * growth_rate, 156 | growth_rate, 157 | bn_size, 158 | drop_rate 159 | ) 160 | self.add_module('denselayer%d' % (i + 1), layer) 161 | 162 | 163 | class _Transition(nn.Sequential): 164 | def __init__(self, num_input_features, num_output_features): 165 | super(_Transition, 
self).__init__() 166 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 167 | self.add_module('relu', nn.ReLU(inplace=True)) 168 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 169 | kernel_size=1, stride=1, bias=False)) 170 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 171 | 172 | 173 | class DenseNet(nn.Module): 174 | r"""Densenet-BC model class, based on 175 | `"Densely Connected Convolutional Networks" `_ 176 | 177 | Args: 178 | growth_rate (int) - how many filters to add each layer (`k` in paper) 179 | block_config (list of 4 ints) - how many layers in each pooling block 180 | num_init_features (int) - the number of filters to learn in the first convolution layer 181 | bn_size (int) - multiplicative factor for number of bottle neck layers 182 | (i.e. bn_size * k features in the bottleneck layer) 183 | drop_rate (float) - dropout rate after each dense layer 184 | num_classes (int) - number of classification classes 185 | """ 186 | 187 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 188 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 189 | 190 | super(DenseNet, self).__init__() 191 | 192 | # First convolution 193 | self.features = nn.Sequential(OrderedDict([ 194 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 195 | ('norm0', nn.BatchNorm2d(num_init_features)), 196 | ('relu0', nn.ReLU(inplace=True)), 197 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 198 | ])) 199 | 200 | # Each denseblock 201 | num_features = num_init_features 202 | for i, num_layers in enumerate(block_config): 203 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 204 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 205 | self.features.add_module('denseblock%d' % (i + 1), block) 206 | num_features = num_features + num_layers * growth_rate 207 | if i != len(block_config) - 1: 208 | trans = 
_Transition(num_input_features=num_features, num_output_features=num_features // 2) 209 | self.features.add_module('transition%d' % (i + 1), trans) 210 | num_features = num_features // 2 211 | 212 | # Final batch norm 213 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 214 | 215 | # Linear layer 216 | self.classifier = nn.Linear(num_features, num_classes) 217 | 218 | # Official init from torch repo. 219 | for m in self.modules(): 220 | if isinstance(m, nn.Conv2d): 221 | nn.init.kaiming_normal_(m.weight) 222 | elif isinstance(m, nn.BatchNorm2d): 223 | nn.init.constant_(m.weight, 1) 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.Linear): 226 | nn.init.constant_(m.bias, 0) 227 | 228 | def forward(self, x): 229 | features = self.features(x) 230 | out = F.relu(features, inplace=True) 231 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 232 | out = self.classifier(out) 233 | return out 234 | -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | from core.se_densenet_full_in_loop import se_densenet121 8 | from core.baseline import densenet121 9 | from utils import Trainer 10 | 11 | 12 | def get_dataloader(batch_size, root="data/cifar10"): 13 | root = Path(root).expanduser() 14 | if not root.exists(): 15 | root.mkdir() 16 | root = str(root) 17 | 18 | to_normalized_tensor = [transforms.ToTensor(), 19 | transforms.ToPILImage(), 20 | transforms.Resize((224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 23 | data_augmentation = [transforms.RandomHorizontalFlip(),] 24 | 25 | train_loader = DataLoader( 26 | datasets.CIFAR10(root, 
train=True, download=True, 27 | transform=transforms.Compose(data_augmentation + to_normalized_tensor)), 28 | batch_size=batch_size, shuffle=True) 29 | test_loader = DataLoader( 30 | datasets.CIFAR10(root, train=False, transform=transforms.Compose(to_normalized_tensor)), 31 | batch_size=batch_size, shuffle=True) 32 | return train_loader, test_loader 33 | 34 | 35 | def main(batch_size, baseline, reduction): 36 | train_loader, test_loader = get_dataloader(batch_size) 37 | 38 | if baseline: 39 | model = densenet121() 40 | else: 41 | model = se_densenet121(num_classes=10) 42 | 43 | optimizer = optim.SGD(params=model.parameters(), lr=1e-1, momentum=0.9, 44 | weight_decay=1e-4) 45 | scheduler = optim.lr_scheduler.StepLR(optimizer, 80, 0.1) 46 | trainer = Trainer(model, optimizer, F.cross_entropy, save_dir="weights") 47 | trainer.loop(100, train_loader, test_loader, scheduler) 48 | 49 | 50 | if __name__ == '__main__': 51 | import argparse 52 | 53 | p = argparse.ArgumentParser() 54 | p.add_argument("--batchsize", type=int, default=64) 55 | p.add_argument("--reduction", type=int, default=16) 56 | p.add_argument("--baseline", action="store_true") 57 | args = p.parse_args() 58 | main(args.batchsize, args.baseline, args.reduction) 59 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/core/__init__.py -------------------------------------------------------------------------------- /core/baseline.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | 11 
| model_urls = { 12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 16 | } 17 | 18 | 19 | def densenet121(pretrained=False, **kwargs): 20 | r"""Densenet-121 model from 21 | `"Densely Connected Convolutional Networks" `_ 22 | 23 | Args: 24 | pretrained (bool): If True, returns a model pre-trained on ImageNet 25 | """ 26 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 27 | **kwargs) 28 | if pretrained: 29 | # '.'s are no longer allowed in module names, but previous _DenseLayer 30 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 31 | # They are also in the checkpoints in model_urls. This pattern is used 32 | # to find such keys. 33 | pattern = re.compile( 34 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 35 | state_dict = model_zoo.load_url(model_urls['densenet121']) 36 | for key in list(state_dict.keys()): 37 | res = pattern.match(key) 38 | if res: 39 | new_key = res.group(1) + res.group(2) 40 | state_dict[new_key] = state_dict[key] 41 | del state_dict[key] 42 | model.load_state_dict(state_dict) 43 | return model 44 | 45 | 46 | def densenet169(pretrained=False, **kwargs): 47 | r"""Densenet-169 model from 48 | `"Densely Connected Convolutional Networks" `_ 49 | 50 | Args: 51 | pretrained (bool): If True, returns a model pre-trained on ImageNet 52 | """ 53 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 54 | **kwargs) 55 | if pretrained: 56 | # '.'s are no longer allowed in module names, but previous _DenseLayer 57 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
58 | # They are also in the checkpoints in model_urls. This pattern is used 59 | # to find such keys. 60 | pattern = re.compile( 61 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 62 | state_dict = model_zoo.load_url(model_urls['densenet169']) 63 | for key in list(state_dict.keys()): 64 | res = pattern.match(key) 65 | if res: 66 | new_key = res.group(1) + res.group(2) 67 | state_dict[new_key] = state_dict[key] 68 | del state_dict[key] 69 | model.load_state_dict(state_dict) 70 | return model 71 | 72 | 73 | def densenet201(pretrained=False, **kwargs): 74 | r"""Densenet-201 model from 75 | `"Densely Connected Convolutional Networks" `_ 76 | 77 | Args: 78 | pretrained (bool): If True, returns a model pre-trained on ImageNet 79 | """ 80 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 81 | **kwargs) 82 | if pretrained: 83 | # '.'s are no longer allowed in module names, but previous _DenseLayer 84 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 85 | # They are also in the checkpoints in model_urls. This pattern is used 86 | # to find such keys.
87 | pattern = re.compile( 88 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 89 | state_dict = model_zoo.load_url(model_urls['densenet201']) 90 | for key in list(state_dict.keys()): 91 | res = pattern.match(key) 92 | if res: 93 | new_key = res.group(1) + res.group(2) 94 | state_dict[new_key] = state_dict[key] 95 | del state_dict[key] 96 | model.load_state_dict(state_dict) 97 | return model 98 | 99 | 100 | def densenet161(pretrained=False, **kwargs): 101 | r"""Densenet-161 model from 102 | `"Densely Connected Convolutional Networks" `_ 103 | 104 | Args: 105 | pretrained (bool): If True, returns a model pre-trained on ImageNet 106 | """ 107 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 108 | **kwargs) 109 | if pretrained: 110 | # '.'s are no longer allowed in module names, but previous _DenseLayer 111 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 112 | # They are also in the checkpoints in model_urls. This pattern is used 113 | # to find such keys.
114 | pattern = re.compile( 115 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 116 | state_dict = model_zoo.load_url(model_urls['densenet161']) 117 | for key in list(state_dict.keys()): 118 | res = pattern.match(key) 119 | if res: 120 | new_key = res.group(1) + res.group(2) 121 | state_dict[new_key] = state_dict[key] 122 | del state_dict[key] 123 | model.load_state_dict(state_dict) 124 | return model 125 | 126 | 127 | class _DenseLayer(nn.Sequential): 128 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 129 | super(_DenseLayer, self).__init__() 130 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 131 | self.add_module('relu1', nn.ReLU(inplace=True)), 132 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 133 | growth_rate, kernel_size=1, stride=1, bias=False)), 134 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 135 | self.add_module('relu2', nn.ReLU(inplace=True)), 136 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 137 | kernel_size=3, stride=1, padding=1, bias=False)), 138 | self.drop_rate = drop_rate 139 | 140 | def forward(self, x): 141 | new_features = super(_DenseLayer, self).forward(x) 142 | if self.drop_rate > 0: 143 | new_features = F.dropout(new_features, 144 | p=self.drop_rate, 145 | training=self.training 146 | ) 147 | return torch.cat([x, new_features], 1) 148 | 149 | 150 | class _DenseBlock(nn.Sequential): 151 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 152 | super(_DenseBlock, self).__init__() 153 | for i in range(num_layers): 154 | layer = _DenseLayer( 155 | num_input_features + i * growth_rate, 156 | growth_rate, 157 | bn_size, 158 | drop_rate 159 | ) 160 | self.add_module('denselayer%d' % (i + 1), layer) 161 | 162 | 163 | class _Transition(nn.Sequential): 164 | def __init__(self, num_input_features, num_output_features): 165 | super(_Transition, 
self).__init__() 166 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 167 | self.add_module('relu', nn.ReLU(inplace=True)) 168 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 169 | kernel_size=1, stride=1, bias=False)) 170 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 171 | 172 | 173 | class DenseNet(nn.Module): 174 | r"""Densenet-BC model class, based on 175 | `"Densely Connected Convolutional Networks" `_ 176 | 177 | Args: 178 | growth_rate (int) - how many filters to add each layer (`k` in paper) 179 | block_config (list of 4 ints) - how many layers in each pooling block 180 | num_init_features (int) - the number of filters to learn in the first convolution layer 181 | bn_size (int) - multiplicative factor for number of bottle neck layers 182 | (i.e. bn_size * k features in the bottleneck layer) 183 | drop_rate (float) - dropout rate after each dense layer 184 | num_classes (int) - number of classification classes 185 | """ 186 | 187 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 188 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 189 | 190 | super(DenseNet, self).__init__() 191 | 192 | # First convolution 193 | self.features = nn.Sequential(OrderedDict([ 194 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 195 | ('norm0', nn.BatchNorm2d(num_init_features)), 196 | ('relu0', nn.ReLU(inplace=True)), 197 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 198 | ])) 199 | 200 | # Each denseblock 201 | num_features = num_init_features 202 | for i, num_layers in enumerate(block_config): 203 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 204 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 205 | self.features.add_module('denseblock%d' % (i + 1), block) 206 | num_features = num_features + num_layers * growth_rate 207 | if i != len(block_config) - 1: 208 | trans = 
_Transition(num_input_features=num_features, num_output_features=num_features // 2) 209 | self.features.add_module('transition%d' % (i + 1), trans) 210 | num_features = num_features // 2 211 | 212 | # Final batch norm 213 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 214 | 215 | # Linear layer 216 | self.classifier = nn.Linear(num_features, num_classes) 217 | 218 | # Official init from torch repo. 219 | for m in self.modules(): 220 | if isinstance(m, nn.Conv2d): 221 | nn.init.kaiming_normal_(m.weight) 222 | elif isinstance(m, nn.BatchNorm2d): 223 | nn.init.constant_(m.weight, 1) 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.Linear): 226 | nn.init.constant_(m.bias, 0) 227 | 228 | def forward(self, x): 229 | features = self.features(x) 230 | out = F.relu(features, inplace=True) 231 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 232 | out = self.classifier(out) 233 | return out 234 | -------------------------------------------------------------------------------- /core/se_densenet.py: -------------------------------------------------------------------------------- 1 | """SEDensenet 2 | """ 3 | import sys 4 | import re 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.model_zoo as model_zoo 9 | from collections import OrderedDict 10 | from se_module import SELayer 11 | 12 | 13 | __all__ = ['SEDenseNet', 'se_densenet121', 'se_densenet169', 'se_densenet201', 'se_densenet161'] 14 | 15 | 16 | model_urls = { 17 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 18 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 19 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 20 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 21 | } 22 | 23 | 24 | def se_densenet121(pretrained=False, is_strict=False, **kwargs): 25 | r"""Densenet-121 model from 
26 | `"Densely Connected Convolutional Networks" `_ 27 | 28 | Args: 29 | pretrained (bool): If True, returns a model pre-trained on ImageNet 30 | """ 31 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 32 | **kwargs) 33 | if pretrained: 34 | # '.'s are no longer allowed in module names, but previous _DenseLayer 35 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 36 | # They are also in the checkpoints in model_urls. This pattern is used 37 | # to find such keys. 38 | pattern = re.compile( 39 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 40 | state_dict = model_zoo.load_url(model_urls['densenet121']) 41 | for key in list(state_dict.keys()): 42 | res = pattern.match(key) 43 | if res: 44 | new_key = res.group(1) + res.group(2) 45 | state_dict[new_key] = state_dict[key] 46 | del state_dict[key] 47 | model.load_state_dict(state_dict, strict=is_strict) 48 | return model 49 | 50 | 51 | def se_densenet169(pretrained=False, **kwargs): 52 | r"""Densenet-169 model from 53 | `"Densely Connected Convolutional Networks" `_ 54 | 55 | Args: 56 | pretrained (bool): If True, returns a model pre-trained on ImageNet 57 | """ 58 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 59 | **kwargs) 60 | if pretrained: 61 | # '.'s are no longer allowed in module names, but previous _DenseLayer 62 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 63 | # They are also in the checkpoints in model_urls. This pattern is used 64 | # to find such keys.
65 | pattern = re.compile( 66 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 67 | state_dict = model_zoo.load_url(model_urls['densenet169']) 68 | for key in list(state_dict.keys()): 69 | res = pattern.match(key) 70 | if res: 71 | new_key = res.group(1) + res.group(2) 72 | state_dict[new_key] = state_dict[key] 73 | del state_dict[key] 74 | model.load_state_dict(state_dict, strict=False) 75 | return model 76 | 77 | 78 | def se_densenet201(pretrained=False, **kwargs): 79 | r"""Densenet-201 model from 80 | `"Densely Connected Convolutional Networks" `_ 81 | 82 | Args: 83 | pretrained (bool): If True, returns a model pre-trained on ImageNet 84 | """ 85 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 86 | **kwargs) 87 | if pretrained: 88 | # '.'s are no longer allowed in module names, but previous _DenseLayer 89 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 90 | # They are also in the checkpoints in model_urls. This pattern is used 91 | # to find such keys.
92 | pattern = re.compile( 93 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 94 | state_dict = model_zoo.load_url(model_urls['densenet201']) 95 | for key in list(state_dict.keys()): 96 | res = pattern.match(key) 97 | if res: 98 | new_key = res.group(1) + res.group(2) 99 | state_dict[new_key] = state_dict[key] 100 | del state_dict[key] 101 | model.load_state_dict(state_dict, strict=False) 102 | return model 103 | 104 | 105 | def se_densenet161(pretrained=False, **kwargs): 106 | r"""Densenet-161 model from 107 | `"Densely Connected Convolutional Networks" `_ 108 | 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | """ 112 | model = SEDenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 113 | **kwargs) 114 | if pretrained: 115 | # '.'s are no longer allowed in module names, but previous _DenseLayer 116 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 117 | # They are also in the checkpoints in model_urls. This pattern is used 118 | # to find such keys.
119 | pattern = re.compile( 120 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 121 | state_dict = model_zoo.load_url(model_urls['densenet161']) 122 | for key in list(state_dict.keys()): 123 | res = pattern.match(key) 124 | if res: 125 | new_key = res.group(1) + res.group(2) 126 | state_dict[new_key] = state_dict[key] 127 | del state_dict[key] 128 | model.load_state_dict(state_dict, strict=False) 129 | return model 130 | 131 | 132 | class _DenseLayer(nn.Sequential): 133 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 134 | super(_DenseLayer, self).__init__() 135 | # Add SELayer at here, like SE-PRE block in original paper illustrates 136 | self.add_module("selayer", SELayer(channel=num_input_features)), 137 | 138 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 139 | self.add_module('relu1', nn.ReLU(inplace=True)), 140 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 141 | growth_rate, kernel_size=1, stride=1, bias=False)), 142 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 143 | self.add_module('relu2', nn.ReLU(inplace=True)), 144 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 145 | kernel_size=3, stride=1, padding=1, bias=False)), 146 | self.drop_rate = drop_rate 147 | 148 | def forward(self, x): 149 | new_features = super(_DenseLayer, self).forward(x) 150 | if self.drop_rate > 0: 151 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 152 | return torch.cat([x, new_features], 1) 153 | 154 | 155 | class _DenseBlock(nn.Sequential): 156 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 157 | super(_DenseBlock, self).__init__() 158 | for i in range(num_layers): 159 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 160 | self.add_module('denselayer%d' % (i + 1), layer) 161 | 162 | 163 | class 
_Transition(nn.Sequential): 164 | def __init__(self, num_input_features, num_output_features): 165 | super(_Transition, self).__init__() 166 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 167 | self.add_module('relu', nn.ReLU(inplace=True)) 168 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 169 | kernel_size=1, stride=1, bias=False)) 170 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 171 | 172 | 173 | class SEDenseNet(nn.Module): 174 | r"""Densenet-BC model class, based on 175 | `"Densely Connected Convolutional Networks" `_ 176 | 177 | Args: 178 | growth_rate (int) - how many filters to add each layer (`k` in paper) 179 | block_config (list of 4 ints) - how many layers in each pooling block 180 | num_init_features (int) - the number of filters to learn in the first convolution layer 181 | bn_size (int) - multiplicative factor for number of bottle neck layers 182 | (i.e. bn_size * k features in the bottleneck layer) 183 | drop_rate (float) - dropout rate after each dense layer 184 | num_classes (int) - number of classification classes 185 | """ 186 | 187 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 188 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 189 | 190 | super(SEDenseNet, self).__init__() 191 | 192 | # First convolution 193 | self.features = nn.Sequential(OrderedDict([ 194 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 195 | ('norm0', nn.BatchNorm2d(num_init_features)), 196 | ('relu0', nn.ReLU(inplace=True)), 197 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 198 | ])) 199 | 200 | # Add SELayer at first convolution 201 | self.features.add_module("SELayer_0a", SELayer(channel=num_init_features)) 202 | 203 | # Each denseblock 204 | num_features = num_init_features 205 | for i, num_layers in enumerate(block_config): 206 | # Add a SELayer 207 | self.features.add_module("SELayer_%da" % (i + 1), 
SELayer(channel=num_features)) 208 | 209 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 210 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 211 | self.features.add_module('denseblock%d' % (i + 1), block) 212 | 213 | num_features = num_features + num_layers * growth_rate 214 | 215 | if i != len(block_config) - 1: 216 | # Add a SELayer behind each transition block 217 | self.features.add_module("SELayer_%db" % (i + 1), SELayer(channel=num_features)) 218 | 219 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 220 | self.features.add_module('transition%d' % (i + 1), trans) 221 | num_features = num_features // 2 222 | 223 | # Final batch norm 224 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 225 | 226 | # Add SELayer 227 | self.features.add_module("SELayer_0b", SELayer(channel=num_features)) 228 | 229 | # Linear layer 230 | self.classifier = nn.Linear(num_features, num_classes) 231 | 232 | # Official init from torch repo. 
233 | for m in self.modules(): 234 | if isinstance(m, nn.Conv2d): 235 | nn.init.kaiming_normal_(m.weight) 236 | elif isinstance(m, nn.BatchNorm2d): 237 | nn.init.constant_(m.weight, 1) 238 | nn.init.constant_(m.bias, 0) 239 | elif isinstance(m, nn.Linear): 240 | nn.init.constant_(m.bias, 0) 241 | 242 | def forward(self, x): 243 | features = self.features(x) 244 | out = F.relu(features, inplace=True) 245 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 246 | out = self.classifier(out) 247 | return out 248 | 249 | if __name__ == "__main__": 250 | net = se_densenet121(pretrained=False) 251 | print(net) -------------------------------------------------------------------------------- /core/se_densenet_full.py: -------------------------------------------------------------------------------- 1 | """SE-DenseNet variant with SE modules in both the transition layers and the dense blocks.""" 2 | import sys 3 | sys.path.append("F:/car_classify_abnormal") 4 | 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | from collections import OrderedDict 11 | from core.se_module import SELayer 12 | 13 | 14 | __all__ = ['SEDenseNet', 'se_densenet121', 'se_densenet169', 'se_densenet201', 'se_densenet161'] 15 | 16 | 17 | model_urls = { 18 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 19 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 20 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 21 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 22 | } 23 | 24 | 25 | def se_densenet121(pretrained=False, is_strict=False, **kwargs): 26 | r"""Densenet-121 model from 27 | `"Densely Connected Convolutional Networks" `_ 28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 33
| **kwargs) 34 | if pretrained: 35 | # '.'s are no longer allowed in module names, but previous _DenseLayer 36 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 37 | # They are also in the checkpoints in model_urls. This pattern is used 38 | # to find such keys. 39 | pattern = re.compile( 40 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 41 | state_dict = model_zoo.load_url(model_urls['densenet121']) 42 | for key in list(state_dict.keys()): 43 | res = pattern.match(key) 44 | if res: 45 | new_key = res.group(1) + res.group(2) 46 | state_dict[new_key] = state_dict[key] 47 | del state_dict[key] 48 | model.load_state_dict(state_dict, strict=is_strict) 49 | return model 50 | 51 | 52 | def se_densenet169(pretrained=False, **kwargs): 53 | r"""Densenet-169 model from 54 | `"Densely Connected Convolutional Networks" `_ 55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | """ 59 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 60 | **kwargs) 61 | if pretrained: 62 | # '.'s are no longer allowed in module names, but previous _DenseLayer 63 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 64 | # They are also in the checkpoints in model_urls. This pattern is used 65 | # to find such keys.
66 | pattern = re.compile( 67 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 68 | state_dict = model_zoo.load_url(model_urls['densenet169']) 69 | for key in list(state_dict.keys()): 70 | res = pattern.match(key) 71 | if res: 72 | new_key = res.group(1) + res.group(2) 73 | state_dict[new_key] = state_dict[key] 74 | del state_dict[key] 75 | model.load_state_dict(state_dict, strict=False) 76 | return model 77 | 78 | 79 | def se_densenet201(pretrained=False, **kwargs): 80 | r"""Densenet-201 model from 81 | `"Densely Connected Convolutional Networks" `_ 82 | 83 | Args: 84 | pretrained (bool): If True, returns a model pre-trained on ImageNet 85 | """ 86 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 87 | **kwargs) 88 | if pretrained: 89 | # '.'s are no longer allowed in module names, but previous _DenseLayer 90 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 91 | # They are also in the checkpoints in model_urls. This pattern is used 92 | # to find such keys.
93 | pattern = re.compile( 94 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 95 | state_dict = model_zoo.load_url(model_urls['densenet201']) 96 | for key in list(state_dict.keys()): 97 | res = pattern.match(key) 98 | if res: 99 | new_key = res.group(1) + res.group(2) 100 | state_dict[new_key] = state_dict[key] 101 | del state_dict[key] 102 | model.load_state_dict(state_dict, strict=False) 103 | return model 104 | 105 | 106 | def se_densenet161(pretrained=False, **kwargs): 107 | r"""Densenet-161 model from 108 | `"Densely Connected Convolutional Networks" `_ 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | model = SEDenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 114 | **kwargs) 115 | if pretrained: 116 | # '.'s are no longer allowed in module names, but previous _DenseLayer 117 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 118 | # They are also in the checkpoints in model_urls. This pattern is used 119 | # to find such keys.
120 | pattern = re.compile( 121 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 122 | state_dict = model_zoo.load_url(model_urls['densenet161']) 123 | for key in list(state_dict.keys()): 124 | res = pattern.match(key) 125 | if res: 126 | new_key = res.group(1) + res.group(2) 127 | state_dict[new_key] = state_dict[key] 128 | del state_dict[key] 129 | model.load_state_dict(state_dict, strict=False) 130 | return model 131 | 132 | 133 | class _DenseLayer(nn.Sequential): 134 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 135 | super(_DenseLayer, self).__init__() 136 | # Add SELayer at here, like SE-PRE block in original paper illustrates 137 | self.add_module("selayer", SELayer(channel=num_input_features)), 138 | 139 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 140 | self.add_module('relu1', nn.ReLU(inplace=True)), 141 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 142 | growth_rate, kernel_size=1, stride=1, bias=False)), 143 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 144 | self.add_module('relu2', nn.ReLU(inplace=True)), 145 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 146 | kernel_size=3, stride=1, padding=1, bias=False)), 147 | self.drop_rate = drop_rate 148 | 149 | def forward(self, x): 150 | new_features = super(_DenseLayer, self).forward(x) 151 | if self.drop_rate > 0: 152 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 153 | return torch.cat([x, new_features], 1) 154 | 155 | 156 | class _DenseBlock(nn.Sequential): 157 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 158 | super(_DenseBlock, self).__init__() 159 | for i in range(num_layers): 160 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 161 | self.add_module('denselayer%d' % (i + 1), layer) 162 | 163 | 164 | class 
_Transition(nn.Sequential): 165 | def __init__(self, num_input_features, num_output_features): 166 | super(_Transition, self).__init__() 167 | self.add_module("selayer", SELayer(channel=num_input_features)) 168 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 169 | self.add_module('relu', nn.ReLU(inplace=True)) 170 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 171 | kernel_size=1, stride=1, bias=False)) 172 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 173 | 174 | 175 | class SEDenseNet(nn.Module): 176 | r"""Densenet-BC model class, based on 177 | `"Densely Connected Convolutional Networks" `_ 178 | 179 | Args: 180 | growth_rate (int) - how many filters to add each layer (`k` in paper) 181 | block_config (list of 4 ints) - how many layers in each pooling block 182 | num_init_features (int) - the number of filters to learn in the first convolution layer 183 | bn_size (int) - multiplicative factor for number of bottle neck layers 184 | (i.e. 
bn_size * k features in the bottleneck layer) 185 | drop_rate (float) - dropout rate after each dense layer 186 | num_classes (int) - number of classification classes 187 | """ 188 | 189 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 190 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 191 | 192 | super(SEDenseNet, self).__init__() 193 | 194 | # First convolution 195 | self.features = nn.Sequential(OrderedDict([ 196 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 197 | ('norm0', nn.BatchNorm2d(num_init_features)), 198 | ('relu0', nn.ReLU(inplace=True)), 199 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 200 | ])) 201 | 202 | # Add SELayer at first convolution 203 | # self.features.add_module("SELayer_0a", SELayer(channel=num_init_features)) 204 | 205 | # Each denseblock 206 | num_features = num_init_features 207 | for i, num_layers in enumerate(block_config): 208 | # Add a SELayer 209 | # self.features.add_module("SELayer_%da" % (i + 1), SELayer(channel=num_features)) 210 | 211 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 212 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 213 | self.features.add_module('denseblock%d' % (i + 1), block) 214 | 215 | num_features = num_features + num_layers * growth_rate 216 | 217 | if i != len(block_config) - 1: 218 | # Add a SELayer behind each transition block 219 | # self.features.add_module("SELayer_%db" % (i + 1), SELayer(channel=num_features)) 220 | 221 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 222 | self.features.add_module('transition%d' % (i + 1), trans) 223 | num_features = num_features // 2 224 | 225 | # Final batch norm 226 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 227 | 228 | # Add SELayer 229 | # self.features.add_module("SELayer_0b", SELayer(channel=num_features)) 230 | 231 | # Linear layer 232 | 
self.classifier = nn.Linear(num_features, num_classes) 233 | 234 | # Official init from torch repo. 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | nn.init.kaiming_normal_(m.weight) 238 | elif isinstance(m, nn.BatchNorm2d): 239 | nn.init.constant_(m.weight, 1) 240 | nn.init.constant_(m.bias, 0) 241 | elif isinstance(m, nn.Linear): 242 | nn.init.constant_(m.bias, 0) 243 | 244 | def forward(self, x): 245 | features = self.features(x) 246 | out = F.relu(features, inplace=True) 247 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 248 | out = self.classifier(out) 249 | return out 250 | 251 | 252 | 253 | def test_se_densenet(pretrained=False): 254 | X = torch.randn(32, 3, 224, 224) 255 | 256 | if pretrained: 257 | model = se_densenet121(pretrained=pretrained) 258 | net_state_dict = {key: value for key, value in model_zoo.load_url("https://download.pytorch.org/models/densenet121-a639ec97.pth").items()} 259 | model.load_state_dict(net_state_dict, strict=False) 260 | 261 | else: 262 | model = se_densenet121(pretrained=pretrained) 263 | 264 | print(model) 265 | if torch.cuda.is_available(): 266 | X = X.cuda() 267 | model = model.cuda() 268 | model.eval() 269 | with torch.no_grad(): 270 | output = model(X) 271 | print(output.shape) 272 | 273 | 274 | if __name__ == "__main__": 275 | test_se_densenet() -------------------------------------------------------------------------------- /core/se_densenet_full_in_loop.py: -------------------------------------------------------------------------------- 1 | """SE-DenseNet variant where SE modules are added only in the model-building loop, not inside the transition layers or dense layers.""" 2 | import sys 3 | sys.path.append("F:/car_classify_abnormal") 4 | 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | from collections import OrderedDict 11 | from core.se_module import SELayer 12 | 13 | 14 | __all__ = ['SEDenseNet', 'se_densenet121', 'se_densenet169',
'se_densenet201', 'se_densenet161'] 15 | 16 | 17 | model_urls = { 18 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 19 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 20 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 21 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 22 | } 23 | 24 | 25 | def se_densenet121(pretrained=False, is_strict=False, **kwargs): 26 | r"""Densenet-121 model from 27 | `"Densely Connected Convolutional Networks" `_ 28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 33 | **kwargs) 34 | if pretrained: 35 | # '.'s are no longer allowed in module names, but previous _DenseLayer 36 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 37 | # They are also in the checkpoints in model_urls. This pattern is used 38 | # to find such keys.
39 | pattern = re.compile( 40 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 41 | state_dict = model_zoo.load_url(model_urls['densenet121']) 42 | for key in list(state_dict.keys()): 43 | res = pattern.match(key) 44 | if res: 45 | new_key = res.group(1) + res.group(2) 46 | state_dict[new_key] = state_dict[key] 47 | del state_dict[key] 48 | model.load_state_dict(state_dict, strict=is_strict) 49 | return model 50 | 51 | 52 | def se_densenet169(pretrained=False, **kwargs): 53 | r"""Densenet-169 model from 54 | `"Densely Connected Convolutional Networks" `_ 55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | """ 59 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 60 | **kwargs) 61 | if pretrained: 62 | # '.'s are no longer allowed in module names, but previous _DenseLayer 63 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 64 | # They are also in the checkpoints in model_urls. This pattern is used 65 | # to find such keys.
66 | pattern = re.compile( 67 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 68 | state_dict = model_zoo.load_url(model_urls['densenet169']) 69 | for key in list(state_dict.keys()): 70 | res = pattern.match(key) 71 | if res: 72 | new_key = res.group(1) + res.group(2) 73 | state_dict[new_key] = state_dict[key] 74 | del state_dict[key] 75 | model.load_state_dict(state_dict, strict=False) 76 | return model 77 | 78 | 79 | def se_densenet201(pretrained=False, **kwargs): 80 | r"""Densenet-201 model from 81 | `"Densely Connected Convolutional Networks" `_ 82 | 83 | Args: 84 | pretrained (bool): If True, returns a model pre-trained on ImageNet 85 | """ 86 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 87 | **kwargs) 88 | if pretrained: 89 | # '.'s are no longer allowed in module names, but previous _DenseLayer 90 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 91 | # They are also in the checkpoints in model_urls. This pattern is used 92 | # to find such keys.
93 | pattern = re.compile( 94 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 95 | state_dict = model_zoo.load_url(model_urls['densenet201']) 96 | for key in list(state_dict.keys()): 97 | res = pattern.match(key) 98 | if res: 99 | new_key = res.group(1) + res.group(2) 100 | state_dict[new_key] = state_dict[key] 101 | del state_dict[key] 102 | model.load_state_dict(state_dict, strict=False) 103 | return model 104 | 105 | 106 | def se_densenet161(pretrained=False, **kwargs): 107 | r"""Densenet-161 model from 108 | `"Densely Connected Convolutional Networks" `_ 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | model = SEDenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 114 | **kwargs) 115 | if pretrained: 116 | # '.'s are no longer allowed in module names, but previous _DenseLayer 117 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 118 | # They are also in the checkpoints in model_urls. This pattern is used 119 | # to find such keys.
120 | pattern = re.compile( 121 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 122 | state_dict = model_zoo.load_url(model_urls['densenet161']) 123 | for key in list(state_dict.keys()): 124 | res = pattern.match(key) 125 | if res: 126 | new_key = res.group(1) + res.group(2) 127 | state_dict[new_key] = state_dict[key] 128 | del state_dict[key] 129 | model.load_state_dict(state_dict, strict=False) 130 | return model 131 | 132 | 133 | class _DenseLayer(nn.Sequential): 134 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 135 | super(_DenseLayer, self).__init__() 136 | # Add SELayer at here, like SE-PRE block in original paper illustrates 137 | # self.add_module("selayer", SELayer(channel=num_input_features)), 138 | 139 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 140 | self.add_module('relu1', nn.ReLU(inplace=True)), 141 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 142 | growth_rate, kernel_size=1, stride=1, bias=False)), 143 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 144 | self.add_module('relu2', nn.ReLU(inplace=True)), 145 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 146 | kernel_size=3, stride=1, padding=1, bias=False)), 147 | self.drop_rate = drop_rate 148 | 149 | def forward(self, x): 150 | new_features = super(_DenseLayer, self).forward(x) 151 | if self.drop_rate > 0: 152 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 153 | return torch.cat([x, new_features], 1) 154 | 155 | 156 | class _DenseBlock(nn.Sequential): 157 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 158 | super(_DenseBlock, self).__init__() 159 | for i in range(num_layers): 160 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 161 | self.add_module('denselayer%d' % (i + 1), layer) 162 | 163 | 164 | class 
_Transition(nn.Sequential): 165 | def __init__(self, num_input_features, num_output_features): 166 | super(_Transition, self).__init__() 167 | # self.add_module("selayer", SELayer(channel=num_input_features)) 168 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 169 | self.add_module('relu', nn.ReLU(inplace=True)) 170 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 171 | kernel_size=1, stride=1, bias=False)) 172 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 173 | 174 | 175 | class SEDenseNet(nn.Module): 176 | r"""Densenet-BC model class, based on 177 | `"Densely Connected Convolutional Networks" `_ 178 | 179 | Args: 180 | growth_rate (int) - how many filters to add each layer (`k` in paper) 181 | block_config (list of 4 ints) - how many layers in each pooling block 182 | num_init_features (int) - the number of filters to learn in the first convolution layer 183 | bn_size (int) - multiplicative factor for number of bottle neck layers 184 | (i.e. 
bn_size * k features in the bottleneck layer) 185 | drop_rate (float) - dropout rate after each dense layer 186 | num_classes (int) - number of classification classes 187 | """ 188 | 189 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 190 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 191 | 192 | super(SEDenseNet, self).__init__() 193 | 194 | # First convolution 195 | self.features = nn.Sequential(OrderedDict([ 196 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 197 | ('norm0', nn.BatchNorm2d(num_init_features)), 198 | ('relu0', nn.ReLU(inplace=True)), 199 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 200 | ])) 201 | 202 | # Add SELayer at first convolution 203 | # self.features.add_module("SELayer_0a", SELayer(channel=num_init_features)) 204 | 205 | # Each denseblock 206 | num_features = num_init_features 207 | for i, num_layers in enumerate(block_config): 208 | # Add a SELayer 209 | self.features.add_module("SELayer_%da" % (i + 1), SELayer(channel=num_features)) 210 | 211 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 212 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 213 | self.features.add_module('denseblock%d' % (i + 1), block) 214 | 215 | num_features = num_features + num_layers * growth_rate 216 | 217 | if i != len(block_config) - 1: 218 | # Add a SELayer behind each transition block 219 | self.features.add_module("SELayer_%db" % (i + 1), SELayer(channel=num_features)) 220 | 221 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 222 | self.features.add_module('transition%d' % (i + 1), trans) 223 | num_features = num_features // 2 224 | 225 | # Final batch norm 226 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 227 | 228 | # Add SELayer 229 | # self.features.add_module("SELayer_0b", SELayer(channel=num_features)) 230 | 231 | # Linear layer 232 | 
self.classifier = nn.Linear(num_features, num_classes) 233 | 234 | # Official init from torch repo. 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | nn.init.kaiming_normal_(m.weight) 238 | elif isinstance(m, nn.BatchNorm2d): 239 | nn.init.constant_(m.weight, 1) 240 | nn.init.constant_(m.bias, 0) 241 | elif isinstance(m, nn.Linear): 242 | nn.init.constant_(m.bias, 0) 243 | 244 | def forward(self, x): 245 | features = self.features(x) 246 | out = F.relu(features, inplace=True) 247 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 248 | out = self.classifier(out) 249 | return out 250 | 251 | 252 | 253 | def test_se_densenet(pretrained=False): 254 | X = torch.Tensor(32, 3, 224, 224) 255 | 256 | if pretrained: 257 | model = se_densenet121(pretrained=pretrained) 258 | net_state_dict = {key: value for key, value in model_zoo.load_url("https://download.pytorch.org/models/densenet121-a639ec97.pth").items()} 259 | model.load_state_dict(net_state_dict, strict=False) 260 | 261 | else: 262 | model = se_densenet121(pretrained=pretrained) 263 | 264 | print(model) 265 | if torch.cuda.is_available(): 266 | X = X.cuda() 267 | model = model.cuda() 268 | model.eval() 269 | with torch.no_grad(): 270 | output = model(X) 271 | print(output.shape) 272 | 273 | 274 | if __name__ == "__main__": 275 | test_se_densenet() -------------------------------------------------------------------------------- /core/se_densenet_w_block.py: -------------------------------------------------------------------------------- 1 | """SE-DenseNet variant: the SE modules on the transition layers are removed; SE modules are added only inside the dense blocks.""" 2 | import sys 3 | sys.path.append("F:/car_classify_abnormal") 4 | 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | from collections import OrderedDict 11 | from core.se_module import SELayer 12 | 13 | 14 | __all__ = ['SEDenseNet', 'se_densenet121', 'se_densenet169', 
'se_densenet201', 'se_densenet161'] 15 | 16 | 17 | model_urls = { 18 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 19 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 20 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 21 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 22 | } 23 | 24 | 25 | def se_densenet121(pretrained=False, is_strict=False, **kwargs): 26 | r"""Densenet-121 model from 27 | `"Densely Connected Convolutional Networks" `_ 28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 33 | **kwargs) 34 | if pretrained: 35 | # '.'s are no longer allowed in module names, but previous _DenseLayer 36 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 37 | # They are also in the checkpoints in model_urls. This pattern is used 38 | # to find such keys. 
39 | pattern = re.compile( 40 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 41 | state_dict = model_zoo.load_url(model_urls['densenet121']) 42 | for key in list(state_dict.keys()): 43 | res = pattern.match(key) 44 | if res: 45 | new_key = res.group(1) + res.group(2) 46 | state_dict[new_key] = state_dict[key] 47 | del state_dict[key] 48 | model.load_state_dict(state_dict, strict=is_strict) 49 | return model 50 | 51 | 52 | def se_densenet169(pretrained=False, **kwargs): 53 | r"""Densenet-169 model from 54 | `"Densely Connected Convolutional Networks" `_ 55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | """ 59 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 60 | **kwargs) 61 | if pretrained: 62 | # '.'s are no longer allowed in module names, but previous _DenseLayer 63 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 64 | # They are also in the checkpoints in model_urls. This pattern is used 65 | # to find such keys. 
66 | pattern = re.compile( 67 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 68 | state_dict = model_zoo.load_url(model_urls['densenet169']) 69 | for key in list(state_dict.keys()): 70 | res = pattern.match(key) 71 | if res: 72 | new_key = res.group(1) + res.group(2) 73 | state_dict[new_key] = state_dict[key] 74 | del state_dict[key] 75 | model.load_state_dict(state_dict, strict=False) 76 | return model 77 | 78 | 79 | def se_densenet201(pretrained=False, **kwargs): 80 | r"""Densenet-201 model from 81 | `"Densely Connected Convolutional Networks" `_ 82 | 83 | Args: 84 | pretrained (bool): If True, returns a model pre-trained on ImageNet 85 | """ 86 | model = SEDenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 87 | **kwargs) 88 | if pretrained: 89 | # '.'s are no longer allowed in module names, but previous _DenseLayer 90 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 91 | # They are also in the checkpoints in model_urls. This pattern is used 92 | # to find such keys. 
93 | pattern = re.compile( 94 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 95 | state_dict = model_zoo.load_url(model_urls['densenet201']) 96 | for key in list(state_dict.keys()): 97 | res = pattern.match(key) 98 | if res: 99 | new_key = res.group(1) + res.group(2) 100 | state_dict[new_key] = state_dict[key] 101 | del state_dict[key] 102 | model.load_state_dict(state_dict, strict=False) 103 | return model 104 | 105 | 106 | def se_densenet161(pretrained=False, **kwargs): 107 | r"""Densenet-161 model from 108 | `"Densely Connected Convolutional Networks" `_ 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | model = SEDenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 114 | **kwargs) 115 | if pretrained: 116 | # '.'s are no longer allowed in module names, but previous _DenseLayer 117 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 118 | # They are also in the checkpoints in model_urls. This pattern is used 119 | # to find such keys. 
120 | pattern = re.compile( 121 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 122 | state_dict = model_zoo.load_url(model_urls['densenet161']) 123 | for key in list(state_dict.keys()): 124 | res = pattern.match(key) 125 | if res: 126 | new_key = res.group(1) + res.group(2) 127 | state_dict[new_key] = state_dict[key] 128 | del state_dict[key] 129 | model.load_state_dict(state_dict, strict=False) 130 | return model 131 | 132 | 133 | class _DenseLayer(nn.Sequential): 134 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 135 | super(_DenseLayer, self).__init__() 136 | # Add SELayer at here, like SE-PRE block in original paper illustrates 137 | self.add_module("selayer", SELayer(channel=num_input_features)), 138 | 139 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 140 | self.add_module('relu1', nn.ReLU(inplace=True)), 141 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 142 | growth_rate, kernel_size=1, stride=1, bias=False)), 143 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 144 | self.add_module('relu2', nn.ReLU(inplace=True)), 145 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 146 | kernel_size=3, stride=1, padding=1, bias=False)), 147 | self.drop_rate = drop_rate 148 | 149 | def forward(self, x): 150 | new_features = super(_DenseLayer, self).forward(x) 151 | if self.drop_rate > 0: 152 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 153 | return torch.cat([x, new_features], 1) 154 | 155 | 156 | class _DenseBlock(nn.Sequential): 157 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 158 | super(_DenseBlock, self).__init__() 159 | for i in range(num_layers): 160 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 161 | self.add_module('denselayer%d' % (i + 1), layer) 162 | 163 | 164 | class 
_Transition(nn.Sequential): 165 | def __init__(self, num_input_features, num_output_features): 166 | super(_Transition, self).__init__() 167 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 168 | self.add_module('relu', nn.ReLU(inplace=True)) 169 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 170 | kernel_size=1, stride=1, bias=False)) 171 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 172 | 173 | 174 | class SEDenseNet(nn.Module): 175 | r"""Densenet-BC model class, based on 176 | `"Densely Connected Convolutional Networks" `_ 177 | 178 | Args: 179 | growth_rate (int) - how many filters to add each layer (`k` in paper) 180 | block_config (list of 4 ints) - how many layers in each pooling block 181 | num_init_features (int) - the number of filters to learn in the first convolution layer 182 | bn_size (int) - multiplicative factor for number of bottle neck layers 183 | (i.e. bn_size * k features in the bottleneck layer) 184 | drop_rate (float) - dropout rate after each dense layer 185 | num_classes (int) - number of classification classes 186 | """ 187 | 188 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 189 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 190 | 191 | super(SEDenseNet, self).__init__() 192 | 193 | # First convolution 194 | self.features = nn.Sequential(OrderedDict([ 195 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 196 | ('norm0', nn.BatchNorm2d(num_init_features)), 197 | ('relu0', nn.ReLU(inplace=True)), 198 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 199 | ])) 200 | 201 | # Add SELayer at first convolution 202 | # self.features.add_module("SELayer_0a", SELayer(channel=num_init_features)) 203 | 204 | # Each denseblock 205 | num_features = num_init_features 206 | for i, num_layers in enumerate(block_config): 207 | # Add a SELayer 208 | # self.features.add_module("SELayer_%da" % (i + 1), 
SELayer(channel=num_features)) 209 | 210 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 211 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 212 | self.features.add_module('denseblock%d' % (i + 1), block) 213 | 214 | num_features = num_features + num_layers * growth_rate 215 | 216 | if i != len(block_config) - 1: 217 | # Add a SELayer behind each transition block 218 | # self.features.add_module("SELayer_%db" % (i + 1), SELayer(channel=num_features)) 219 | 220 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 221 | self.features.add_module('transition%d' % (i + 1), trans) 222 | num_features = num_features // 2 223 | 224 | # Final batch norm 225 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 226 | 227 | # Add SELayer 228 | # self.features.add_module("SELayer_0b", SELayer(channel=num_features)) 229 | 230 | # Linear layer 231 | self.classifier = nn.Linear(num_features, num_classes) 232 | 233 | # Official init from torch repo. 
234 | for m in self.modules(): 235 | if isinstance(m, nn.Conv2d): 236 | nn.init.kaiming_normal_(m.weight) 237 | elif isinstance(m, nn.BatchNorm2d): 238 | nn.init.constant_(m.weight, 1) 239 | nn.init.constant_(m.bias, 0) 240 | elif isinstance(m, nn.Linear): 241 | nn.init.constant_(m.bias, 0) 242 | 243 | def forward(self, x): 244 | features = self.features(x) 245 | out = F.relu(features, inplace=True) 246 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 247 | out = self.classifier(out) 248 | return out 249 | 250 | 251 | def test_se_densenet(pretrained=False): 252 | X = torch.Tensor(32, 3, 224, 224) 253 | 254 | if pretrained: 255 | model = se_densenet121(pretrained=pretrained) 256 | net_state_dict = {key: value for key, value in model_zoo.load_url("https://download.pytorch.org/models/densenet121-a639ec97.pth").items()} 257 | model.load_state_dict(net_state_dict, strict=False) 258 | 259 | else: 260 | model = se_densenet121(pretrained=pretrained) 261 | 262 | print(model) 263 | if torch.cuda.is_available(): 264 | X = X.cuda() 265 | model = model.cuda() 266 | model.eval() 267 | with torch.no_grad(): 268 | output = model(X) 269 | print(output.shape) 270 | 271 | if __name__ == "__main__": 272 | test_se_densenet() -------------------------------------------------------------------------------- /core/se_efficient_densenet.py: -------------------------------------------------------------------------------- 1 | """This implementation is based on the DenseNet-BC implementation in torchvision 2 | 1. using torch.utils.checkpoint to reduce memory consumption 3 | 2. 
add senet module 4 | """ 5 | import sys 6 | sys.path.append("F:/car_classify_abnormal") 7 | 8 | import math 9 | import re 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as cp 16 | import torch.utils.model_zoo as model_zoo 17 | from core.se_module import SELayer 18 | 19 | 20 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 21 | 22 | 23 | model_urls = { 24 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 25 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 26 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 27 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 28 | } 29 | 30 | 31 | def densenet121(pretrained=False, is_strict=False, is_efficient=True, **kwargs): 32 | r"""Densenet-121 model from 33 | `"Densely Connected Convolutional Networks" `_ 34 | 35 | Args: 36 | pretrained (bool): If True, returns a model pre-trained on ImageNet 37 | """ 38 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 39 | num_classes=4096, efficient=is_efficient, **kwargs) 40 | if pretrained: 41 | # '.'s are no longer allowed in module names, but previous _DenseLayer 42 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 43 | # They are also in the checkpoints in model_urls. This pattern is used 44 | # to find such keys. 
45 | pattern = re.compile( 46 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 47 | state_dict = {key: value for key, value in model_zoo.load_url(model_urls['densenet121']).items() if "classifier" not in key} 48 | for key in list(state_dict.keys()): 49 | res = pattern.match(key) 50 | if res: 51 | new_key = res.group(1) + res.group(2) 52 | state_dict[new_key] = state_dict[key] 53 | del state_dict[key] 54 | model.load_state_dict(state_dict, strict=is_strict) 55 | return model 56 | 57 | 58 | def densenet169(pretrained=False, is_strict=False, is_efficient=True, **kwargs): 59 | r"""Densenet-169 model from 60 | `"Densely Connected Convolutional Networks" `_ 61 | 62 | Args: 63 | pretrained (bool): If True, returns a model pre-trained on ImageNet 64 | """ 65 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 66 | efficient=is_efficient, **kwargs) 67 | if pretrained: 68 | # '.'s are no longer allowed in module names, but previous _DenseLayer 69 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 70 | # They are also in the checkpoints in model_urls. This pattern is used 71 | # to find such keys. 
72 | pattern = re.compile( 73 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 74 | state_dict = model_zoo.load_url(model_urls['densenet169']) 75 | for key in list(state_dict.keys()): 76 | res = pattern.match(key) 77 | if res: 78 | new_key = res.group(1) + res.group(2) 79 | state_dict[new_key] = state_dict[key] 80 | del state_dict[key] 81 | model.load_state_dict(state_dict, strict=is_strict) 82 | return model 83 | 84 | 85 | def densenet201(pretrained=False, is_strict=False, is_efficient=True, **kwargs): 86 | r"""Densenet-201 model from 87 | `"Densely Connected Convolutional Networks" `_ 88 | 89 | Args: 90 | pretrained (bool): If True, returns a model pre-trained on ImageNet 91 | """ 92 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 93 | efficient=is_efficient, **kwargs) 94 | if pretrained: 95 | # '.'s are no longer allowed in module names, but previous _DenseLayer 96 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 97 | # They are also in the checkpoints in model_urls. This pattern is used 98 | # to find such keys. 
99 | pattern = re.compile( 100 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 101 | state_dict = model_zoo.load_url(model_urls['densenet201']) 102 | for key in list(state_dict.keys()): 103 | res = pattern.match(key) 104 | if res: 105 | new_key = res.group(1) + res.group(2) 106 | state_dict[new_key] = state_dict[key] 107 | del state_dict[key] 108 | model.load_state_dict(state_dict, strict=is_strict) 109 | return model 110 | 111 | 112 | def densenet161(pretrained=False, is_strict=False, is_efficient=True, **kwargs): 113 | r"""Densenet-161 model from 114 | `"Densely Connected Convolutional Networks" `_ 115 | 116 | Args: 117 | pretrained (bool): If True, returns a model pre-trained on ImageNet 118 | """ 119 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 120 | efficient=is_efficient, **kwargs) 121 | if pretrained: 122 | # '.'s are no longer allowed in module names, but previous _DenseLayer 123 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 124 | # They are also in the checkpoints in model_urls. This pattern is used 125 | # to find such keys. 
126 | pattern = re.compile( 127 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 128 | state_dict = model_zoo.load_url(model_urls['densenet161']) 129 | for key in list(state_dict.keys()): 130 | res = pattern.match(key) 131 | if res: 132 | new_key = res.group(1) + res.group(2) 133 | state_dict[new_key] = state_dict[key] 134 | del state_dict[key] 135 | model.load_state_dict(state_dict, strict=is_strict) 136 | return model 137 | 138 | 139 | def _bn_function_factory(norm, relu, conv): 140 | def bn_function(*inputs): 141 | concated_features = torch.cat(inputs, 1) 142 | bottleneck_output = conv(relu(norm(concated_features))) 143 | return bottleneck_output 144 | return bn_function 145 | 146 | 147 | class _DenseLayer(nn.Module): 148 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, 149 | efficient=False): 150 | super(_DenseLayer, self).__init__() 151 | # Add an SELayer here, like the SE-PRE block illustrated in the original paper 152 | self.add_module("selayer", SELayer(channel=num_input_features)), 153 | 154 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 155 | self.add_module('relu1', nn.ReLU(inplace=True)), 156 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 157 | growth_rate, kernel_size=1, stride=1, bias=False)), 158 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 159 | self.add_module('relu2', nn.ReLU(inplace=True)), 160 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 161 | kernel_size=3, stride=1, padding=1, bias=False)), 162 | self.drop_rate = drop_rate 163 | self.efficient = efficient 164 | 165 | def forward(self, *prev_features): 166 | """The bottleneck stage originally required its own block of GPU memory for each layer; 167 | by using checkpoint, only one shared block is allocated to store the intermediate features 168 | """ 169 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 170 | # requires_grad being True means the model is in training mode 171 | # checkpoint implements the shared memory storage described above 172 | 173 | 
if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 174 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 175 | else: 176 | bottleneck_output = bn_function(*prev_features) 177 | 178 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 179 | 180 | if self.drop_rate > 0: 181 | new_features = F.dropout(new_features, 182 | p=self.drop_rate, 183 | training=self.training 184 | ) 185 | return new_features 186 | 187 | class _DenseBlock(nn.Module): 188 | def __init__(self, 189 | num_layers, 190 | num_input_features, 191 | bn_size, 192 | growth_rate, 193 | drop_rate, 194 | efficient=False): 195 | super(_DenseBlock, self).__init__() 196 | for i in range(num_layers): 197 | layer = _DenseLayer( 198 | num_input_features + i * growth_rate, 199 | growth_rate=growth_rate, 200 | bn_size=bn_size, 201 | drop_rate=drop_rate, 202 | efficient=efficient, 203 | ) 204 | 205 | self.add_module('denselayer%d' % (i + 1), layer) 206 | 207 | 208 | def forward(self, init_features): 209 | features = [init_features] 210 | for name, layer in self.named_children(): 211 | new_features = layer(*features) 212 | features.append(new_features) 213 | return torch.cat(features, 1) 214 | 215 | 216 | class _Transition(nn.Sequential): 217 | def __init__(self, num_input_features, num_output_features): 218 | super(_Transition, self).__init__() 219 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 220 | self.add_module('relu', nn.ReLU(inplace=True)) 221 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 222 | kernel_size=1, stride=1, bias=False)) 223 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 224 | 225 | 226 | class DenseNet(nn.Module): 227 | r"""Densenet-BC model class, based on 228 | `"Densely Connected Convolutional Networks" ` 229 | Args: 230 | growth_rate (int) - how many filters to add each layer (`k` in paper) 231 | block_config (list of 3 or 4 ints) - how many layers in each 
pooling block 232 | num_init_features (int) - the number of filters to learn in the first convolution layer 233 | bn_size (int) - multiplicative factor for number of bottle neck layers 234 | (i.e. bn_size * k features in the bottleneck layer) 235 | drop_rate (float) - dropout rate after each dense layer 236 | num_classes (int) - number of classification classes 237 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 238 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 239 | """ 240 | def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5, 241 | num_init_features=24, bn_size=4, drop_rate=0, 242 | num_classes=4096, efficient=True): 243 | 244 | super(DenseNet, self).__init__() 245 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 246 | 247 | # First convolution 248 | self.features = nn.Sequential(OrderedDict([ 249 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 250 | ('norm0', nn.BatchNorm2d(num_init_features)), 251 | ('relu0', nn.ReLU(inplace=True)), 252 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 253 | ])) 254 | 255 | # Each denseblock 256 | num_features = num_init_features 257 | 258 | for i, num_layers in enumerate(block_config): 259 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 260 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, efficient=efficient, 261 | ) 262 | 263 | self.features.add_module('denseblock%d' % (i + 1), block) 264 | 265 | num_features = num_features + num_layers * growth_rate 266 | 267 | if i != len(block_config) - 1: 268 | # Add a SELayer behind each transition block 269 | self.features.add_module("SELayer_%da" % (i + 1), SELayer(channel=num_features)) 270 | 271 | trans = _Transition(num_input_features=num_features, 272 | num_output_features=int(num_features * compression)) 273 | 
self.features.add_module('transition%d' % (i + 1), trans) 274 | num_features = int(num_features * compression)  # keep in sync with the transition's output width 275 | 276 | # Final batch norm 277 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 278 | 279 | # Linear layer 280 | self.classifier = nn.Linear(num_features, num_classes) 281 | 282 | # Official init from torch repo. 283 | for m in self.modules(): 284 | if isinstance(m, nn.Conv2d): 285 | nn.init.kaiming_normal_(m.weight) 286 | elif isinstance(m, nn.BatchNorm2d): 287 | nn.init.constant_(m.weight, 1) 288 | nn.init.constant_(m.bias, 0) 289 | elif isinstance(m, nn.Linear): 290 | nn.init.constant_(m.bias, 0) 291 | 292 | 293 | def forward(self, x): 294 | features = self.features(x) 295 | out = F.relu(features, inplace=True) 296 | out = F.avg_pool2d(out, kernel_size=7).view(features.size(0), -1) 297 | out = self.classifier(out) 298 | return out 299 | 300 | 301 | if __name__ == "__main__": 302 | # X = torch.Tensor(1, 3, 224, 224) 303 | X = torch.zeros([1, 3, 224, 224]) 304 | net = densenet121(pretrained=True, is_strict=False, is_efficient=True) 305 | print(net) 306 | 307 | if torch.cuda.is_available(): 308 | X = X.cuda() 309 | net = net.cuda() 310 | 311 | net.eval() 312 | with torch.no_grad(): 313 | output = net(X) 314 | print(output.shape) 315 | 316 | # print(net) -------------------------------------------------------------------------------- /core/se_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | assert channel > reduction, "Make sure the input channel count is larger than reduction={}".format(reduction) 7 | super(SELayer, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | self.fc = nn.Sequential( 10 | nn.Linear(channel, channel // reduction), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(channel // reduction, channel), 13 | nn.Sigmoid() 14 | ) 15 | 16 | def forward(self, x): 17 | b, 
c, _, _ = x.size() 18 | y = self.avg_pool(x).view(b, c) 19 | y = self.fc(y).view(b, c, 1, 1) 20 | return x * y 21 | -------------------------------------------------------------------------------- /core/test_se_densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.model_zoo as model_zoo 3 | from se_densenet import se_densenet121 4 | # densenet121 from the official torchvision repo 5 | from torchvision.models.densenet import densenet121 6 | 7 | 8 | def test_se_densenet(pretrained=False): 9 | X = torch.zeros(32, 3, 224, 224)  # torch.Tensor(...) would leave the input uninitialized 10 | 11 | if pretrained: 12 | model = se_densenet121(pretrained=pretrained) 13 | net_state_dict = model_zoo.load_url("https://download.pytorch.org/models/densenet121-a639ec97.pth") 14 | model.load_state_dict(net_state_dict, strict=False) 15 | 16 | else: 17 | model = se_densenet121(pretrained=pretrained) 18 | 19 | # print(model) 20 | if torch.cuda.is_available(): 21 | X = X.cuda() 22 | model = model.cuda() 23 | model.eval() 24 | with torch.no_grad(): 25 | output = model(X) 26 | print(output.shape) 27 | 28 | 29 | def test_densenet(): 30 | """Create an example input tensor for densenet and print the output shape.""" 31 | X = torch.zeros(32, 3, 224, 224) 32 | 33 | model = densenet121(pretrained=False) 34 | 35 | if torch.cuda.is_available(): 36 | model = model.cuda() 37 | X = X.cuda() 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | output = model(X) 42 | print(output.shape) 43 | 44 | 45 | if __name__ == "__main__": 46 | test_se_densenet(pretrained=True) -------------------------------------------------------------------------------- /data/cifar10/download_cifar10_at_here.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/data/cifar10/download_cifar10_at_here.md
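The `SELayer` in `core/se_module.py` above implements the squeeze-and-excitation computation in four steps: global average pooling (squeeze), a two-layer bottleneck MLP, a sigmoid gate, and a per-channel rescale of the input. As a sanity check of the math, independent of PyTorch, here is a minimal NumPy sketch of the same forward pass; the weight shapes mirror `nn.Linear` (out_features × in_features), and the random weights and shapes are purely illustrative.

```python
import numpy as np

def se_forward(x, w1, b1, w2, b2):
    """Sketch of SELayer.forward for an input of shape (B, C, H, W)."""
    b, c = x.shape[0], x.shape[1]
    # Squeeze: global average pooling over H and W -> (B, C)
    y = x.mean(axis=(2, 3))
    # Excitation: Linear -> ReLU -> Linear -> Sigmoid
    y = np.maximum(y @ w1.T + b1, 0.0)
    y = 1.0 / (1.0 + np.exp(-(y @ w2.T + b2)))
    # Scale: broadcast the per-channel gate over H and W
    return x * y.reshape(b, c, 1, 1)

# Illustrative shapes: channel=32, reduction=16 (SELayer's default reduction)
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 32, 8, 8))
w1, b1 = rng.standard_normal((2, 32)), np.zeros(2)   # 32 -> 32 // 16
w2, b2 = rng.standard_normal((32, 2)), np.zeros(32)  # 32 // 16 -> 32
out = se_forward(x, w1, b1, w2, b2)
print(out.shape)  # (2, 32, 8, 8)
```

Because the gate is a sigmoid, every channel is scaled by a factor in (0, 1), so the layer can only attenuate channels relative to one another; the `reduction=16` bottleneck keeps this gate cheap relative to the convolutions it modulates.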
-------------------------------------------------------------------------------- /state/state_baseline.txt: -------------------------------------------------------------------------------- 1 | {'epoch': 1, 'accuracy': 0.3869, 'loss': 1.7070355621902542, 'mode': 'train'} 2 | {'epoch': 1, 'accuracy': 0.5265, 'loss': 1.308528982529974, 'mode': 'test'} 3 | {'epoch': 2, 'accuracy': 0.62678, 'loss': 1.0495631145241913, 'mode': 'train'} 4 | {'epoch': 2, 'accuracy': 0.6939, 'loss': 0.9471340194629256, 'mode': 'test'} 5 | {'epoch': 3, 'accuracy': 0.74444, 'loss': 0.7294200721299253, 'mode': 'train'} 6 | {'epoch': 3, 'accuracy': 0.7594, 'loss': 0.7061205849905681, 'mode': 'test'} 7 | {'epoch': 4, 'accuracy': 0.8018, 'loss': 0.5735215999738643, 'mode': 'train'} 8 | {'epoch': 4, 'accuracy': 0.8157, 'loss': 0.5480034189998725, 'mode': 'test'} 9 | {'epoch': 5, 'accuracy': 0.82898, 'loss': 0.491147505238538, 'mode': 'train'} 10 | {'epoch': 5, 'accuracy': 0.7986, 'loss': 0.8200758986032692, 'mode': 'test'} 11 | {'epoch': 6, 'accuracy': 0.852, 'loss': 0.43055774563032606, 'mode': 'train'} 12 | {'epoch': 6, 'accuracy': 0.8365, 'loss': 0.4827959565979661, 'mode': 'test'} 13 | {'epoch': 7, 'accuracy': 0.86268, 'loss': 0.397368783860103, 'mode': 'train'} 14 | {'epoch': 7, 'accuracy': 0.8431, 'loss': 0.45875984610645626, 'mode': 'test'} 15 | {'epoch': 8, 'accuracy': 0.875, 'loss': 0.3581551666683549, 'mode': 'train'} 16 | {'epoch': 8, 'accuracy': 0.8443, 'loss': 0.45463992883065696, 'mode': 'test'} 17 | {'epoch': 9, 'accuracy': 0.88666, 'loss': 0.3282339582052992, 'mode': 'train'} 18 | {'epoch': 9, 'accuracy': 0.8353, 'loss': 0.4710612428036464, 'mode': 'test'} 19 | {'epoch': 10, 'accuracy': 0.89548, 'loss': 0.3009887498415186, 'mode': 'train'} 20 | {'epoch': 10, 'accuracy': 0.8462, 'loss': 0.45486110211557623, 'mode': 'test'} 21 | {'epoch': 11, 'accuracy': 0.9005, 'loss': 0.2867941619528226, 'mode': 'train'} 22 | {'epoch': 11, 'accuracy': 0.8544, 'loss': 0.4307156290597978, 'mode': 
'test'} 23 | {'epoch': 12, 'accuracy': 0.9068, 'loss': 0.2694823994394154, 'mode': 'train'} 24 | {'epoch': 12, 'accuracy': 0.8596, 'loss': 0.42541023965474123, 'mode': 'test'} 25 | {'epoch': 13, 'accuracy': 0.90816, 'loss': 0.26439612920936734, 'mode': 'train'} 26 | {'epoch': 13, 'accuracy': 0.8725, 'loss': 0.38515629557667286, 'mode': 'test'} 27 | {'epoch': 14, 'accuracy': 0.914, 'loss': 0.24470223989480605, 'mode': 'train'} 28 | {'epoch': 14, 'accuracy': 0.878, 'loss': 0.3803155615830876, 'mode': 'test'} 29 | {'epoch': 15, 'accuracy': 0.91906, 'loss': 0.23101764301890915, 'mode': 'train'} 30 | {'epoch': 15, 'accuracy': 0.8721, 'loss': 0.3853813386077334, 'mode': 'test'} 31 | {'epoch': 16, 'accuracy': 0.91944, 'loss': 0.22683130851601394, 'mode': 'train'} 32 | {'epoch': 16, 'accuracy': 0.881, 'loss': 0.3632625983968661, 'mode': 'test'} 33 | {'epoch': 17, 'accuracy': 0.92534, 'loss': 0.213896721067941, 'mode': 'train'} 34 | {'epoch': 17, 'accuracy': 0.8763, 'loss': 0.37578690251347374, 'mode': 'test'} 35 | {'epoch': 18, 'accuracy': 0.92732, 'loss': 0.2081321217596075, 'mode': 'train'} 36 | {'epoch': 18, 'accuracy': 0.8715, 'loss': 0.39774996165637017, 'mode': 'test'} 37 | {'epoch': 19, 'accuracy': 0.9285, 'loss': 0.2045462084243362, 'mode': 'train'} 38 | {'epoch': 19, 'accuracy': 0.8415, 'loss': 0.506635734989385, 'mode': 'test'} 39 | {'epoch': 20, 'accuracy': 0.92904, 'loss': 0.20270288894738978, 'mode': 'train'} 40 | {'epoch': 20, 'accuracy': 0.8688, 'loss': 0.4094154776851083, 'mode': 'test'} 41 | {'epoch': 21, 'accuracy': 0.93174, 'loss': 0.1937037008192837, 'mode': 'train'} 42 | {'epoch': 21, 'accuracy': 0.8747, 'loss': 0.3853637102967613, 'mode': 'test'} 43 | {'epoch': 22, 'accuracy': 0.93486, 'loss': 0.18546549211758778, 'mode': 'train'} 44 | {'epoch': 22, 'accuracy': 0.8595, 'loss': 0.4607807563938154, 'mode': 'test'} 45 | {'epoch': 23, 'accuracy': 0.93434, 'loss': 0.18510274832015455, 'mode': 'train'} 46 | {'epoch': 23, 'accuracy': 0.8542, 'loss': 
0.48144651493828794, 'mode': 'test'} 47 | {'epoch': 24, 'accuracy': 0.93742, 'loss': 0.17812334380262634, 'mode': 'train'} 48 | {'epoch': 24, 'accuracy': 0.8838, 'loss': 0.35136064516890575, 'mode': 'test'} 49 | {'epoch': 25, 'accuracy': 0.93506, 'loss': 0.18342369634781952, 'mode': 'train'} 50 | {'epoch': 25, 'accuracy': 0.8775, 'loss': 0.4020148741591509, 'mode': 'test'} 51 | {'epoch': 26, 'accuracy': 0.93732, 'loss': 0.17800434354854663, 'mode': 'train'} 52 | {'epoch': 26, 'accuracy': 0.8612, 'loss': 0.4518005688478991, 'mode': 'test'} 53 | {'epoch': 27, 'accuracy': 0.93952, 'loss': 0.17259756810105667, 'mode': 'train'} 54 | {'epoch': 27, 'accuracy': 0.8695, 'loss': 0.41985917461525857, 'mode': 'test'} 55 | {'epoch': 28, 'accuracy': 0.93868, 'loss': 0.17634227529854113, 'mode': 'train'} 56 | {'epoch': 28, 'accuracy': 0.8535, 'loss': 0.4761059612605223, 'mode': 'test'} 57 | {'epoch': 29, 'accuracy': 0.94114, 'loss': 0.17108567083811818, 'mode': 'train'} 58 | {'epoch': 29, 'accuracy': 0.8795, 'loss': 0.3904990235901181, 'mode': 'test'} 59 | {'epoch': 30, 'accuracy': 0.94062, 'loss': 0.16828305741100366, 'mode': 'train'} 60 | {'epoch': 30, 'accuracy': 0.8725, 'loss': 0.40022565158689105, 'mode': 'test'} 61 | {'epoch': 31, 'accuracy': 0.94102, 'loss': 0.16731096742212623, 'mode': 'train'} 62 | {'epoch': 31, 'accuracy': 0.8753, 'loss': 0.383897032897184, 'mode': 'test'} 63 | {'epoch': 32, 'accuracy': 0.9435, 'loss': 0.16252539707990876, 'mode': 'train'} 64 | {'epoch': 32, 'accuracy': 0.8538, 'loss': 0.4938757709067338, 'mode': 'test'} 65 | {'epoch': 33, 'accuracy': 0.9448, 'loss': 0.15843584180792875, 'mode': 'train'} 66 | {'epoch': 33, 'accuracy': 0.8815, 'loss': 0.37589270958475224, 'mode': 'test'} 67 | {'epoch': 34, 'accuracy': 0.94486, 'loss': 0.15767581540795814, 'mode': 'train'} 68 | {'epoch': 34, 'accuracy': 0.8826, 'loss': 0.36932552468245183, 'mode': 'test'} 69 | {'epoch': 35, 'accuracy': 0.94598, 'loss': 0.1535190101855858, 'mode': 'train'} 70 | {'epoch': 
35, 'accuracy': 0.8406, 'loss': 0.5293987310805902, 'mode': 'test'} 71 | {'epoch': 36, 'accuracy': 0.94398, 'loss': 0.1585736681261789, 'mode': 'train'} 72 | {'epoch': 36, 'accuracy': 0.8798, 'loss': 0.3791128994932599, 'mode': 'test'} 73 | {'epoch': 37, 'accuracy': 0.94844, 'loss': 0.1478317224365823, 'mode': 'train'} 74 | {'epoch': 37, 'accuracy': 0.8702, 'loss': 0.40592986164958617, 'mode': 'test'} 75 | {'epoch': 38, 'accuracy': 0.94594, 'loss': 0.15362863195940954, 'mode': 'train'} 76 | {'epoch': 38, 'accuracy': 0.8841, 'loss': 0.3791521975568906, 'mode': 'test'} 77 | {'epoch': 39, 'accuracy': 0.94772, 'loss': 0.14822693958954725, 'mode': 'train'} 78 | {'epoch': 39, 'accuracy': 0.8843, 'loss': 0.3851375788640065, 'mode': 'test'} 79 | {'epoch': 40, 'accuracy': 0.94852, 'loss': 0.14638042388021782, 'mode': 'train'} 80 | {'epoch': 40, 'accuracy': 0.8829, 'loss': 0.37932144437625903, 'mode': 'test'} 81 | {'epoch': 41, 'accuracy': 0.94944, 'loss': 0.14358370161384276, 'mode': 'train'} 82 | {'epoch': 41, 'accuracy': 0.8748, 'loss': 0.43302345968735473, 'mode': 'test'} 83 | {'epoch': 42, 'accuracy': 0.95252, 'loss': 0.13859815092380987, 'mode': 'train'} 84 | {'epoch': 42, 'accuracy': 0.8733, 'loss': 0.42106063597521226, 'mode': 'test'} 85 | {'epoch': 43, 'accuracy': 0.95112, 'loss': 0.1419139150482462, 'mode': 'train'} 86 | {'epoch': 43, 'accuracy': 0.8819, 'loss': 0.36865578468438154, 'mode': 'test'} 87 | {'epoch': 44, 'accuracy': 0.9495, 'loss': 0.14441218667323014, 'mode': 'train'} 88 | {'epoch': 44, 'accuracy': 0.8876, 'loss': 0.37616785231289585, 'mode': 'test'} 89 | {'epoch': 45, 'accuracy': 0.95144, 'loss': 0.13832323992496243, 'mode': 'train'} 90 | {'epoch': 45, 'accuracy': 0.8829, 'loss': 0.3635770950917227, 'mode': 'test'} 91 | {'epoch': 46, 'accuracy': 0.9525, 'loss': 0.1328471898937318, 'mode': 'train'} 92 | {'epoch': 46, 'accuracy': 0.8806, 'loss': 0.4050914176330444, 'mode': 'test'} 93 | {'epoch': 47, 'accuracy': 0.95172, 'loss': 0.1384341535456193, 
'mode': 'train'} 94 | {'epoch': 47, 'accuracy': 0.8759, 'loss': 0.4127537378459978, 'mode': 'test'} 95 | {'epoch': 48, 'accuracy': 0.9498, 'loss': 0.14337113628263984, 'mode': 'train'} 96 | {'epoch': 48, 'accuracy': 0.8742, 'loss': 0.41528847301082256, 'mode': 'test'} 97 | {'epoch': 49, 'accuracy': 0.9539, 'loss': 0.1324270457372337, 'mode': 'train'} 98 | {'epoch': 49, 'accuracy': 0.8856, 'loss': 0.3796541090983494, 'mode': 'test'} 99 | {'epoch': 50, 'accuracy': 0.9538, 'loss': 0.13250392774486805, 'mode': 'train'} 100 | {'epoch': 50, 'accuracy': 0.878, 'loss': 0.42052315692802905, 'mode': 'test'} 101 | {'epoch': 51, 'accuracy': 0.95392, 'loss': 0.1310321771732682, 'mode': 'train'} 102 | {'epoch': 51, 'accuracy': 0.8953, 'loss': 0.349161261111308, 'mode': 'test'} 103 | {'epoch': 52, 'accuracy': 0.95414, 'loss': 0.1319556686255481, 'mode': 'train'} 104 | {'epoch': 52, 'accuracy': 0.8909, 'loss': 0.3420301916872621, 'mode': 'test'} 105 | {'epoch': 53, 'accuracy': 0.95624, 'loss': 0.122365583048757, 'mode': 'train'} 106 | {'epoch': 53, 'accuracy': 0.876, 'loss': 0.41360473030122225, 'mode': 'test'} 107 | {'epoch': 54, 'accuracy': 0.95568, 'loss': 0.12968954627814183, 'mode': 'train'} 108 | {'epoch': 54, 'accuracy': 0.8801, 'loss': 0.4098439184343738, 'mode': 'test'} 109 | {'epoch': 55, 'accuracy': 0.95476, 'loss': 0.12796810306513404, 'mode': 'train'} 110 | {'epoch': 55, 'accuracy': 0.8839, 'loss': 0.3910343188578916, 'mode': 'test'} 111 | {'epoch': 56, 'accuracy': 0.95574, 'loss': 0.12758931261308668, 'mode': 'train'} 112 | {'epoch': 56, 'accuracy': 0.8859, 'loss': 0.360085581233547, 'mode': 'test'} 113 | {'epoch': 57, 'accuracy': 0.95634, 'loss': 0.12556623106303114, 'mode': 'train'} 114 | {'epoch': 57, 'accuracy': 0.8688, 'loss': 0.4509742689929949, 'mode': 'test'} 115 | {'epoch': 58, 'accuracy': 0.95746, 'loss': 0.1240024342275488, 'mode': 'train'} 116 | {'epoch': 58, 'accuracy': 0.8876, 'loss': 0.3945227263459734, 'mode': 'test'} 117 | {'epoch': 59, 'accuracy': 
0.95782, 'loss': 0.12148631294555669, 'mode': 'train'} 118 | {'epoch': 59, 'accuracy': 0.8845, 'loss': 0.3934551621698269, 'mode': 'test'} 119 | {'epoch': 60, 'accuracy': 0.955, 'loss': 0.1275351512367311, 'mode': 'train'} 120 | {'epoch': 60, 'accuracy': 0.8877, 'loss': 0.3828521441122527, 'mode': 'test'} 121 | {'epoch': 61, 'accuracy': 0.95832, 'loss': 0.11920713231234296, 'mode': 'train'} 122 | {'epoch': 61, 'accuracy': 0.8653, 'loss': 0.4557086989568297, 'mode': 'test'} 123 | {'epoch': 62, 'accuracy': 0.95844, 'loss': 0.11982088970482505, 'mode': 'train'} 124 | {'epoch': 62, 'accuracy': 0.8802, 'loss': 0.40236303827185554, 'mode': 'test'} 125 | {'epoch': 63, 'accuracy': 0.95622, 'loss': 0.12434144690632833, 'mode': 'train'} 126 | {'epoch': 63, 'accuracy': 0.8803, 'loss': 0.3969290236568755, 'mode': 'test'} 127 | {'epoch': 64, 'accuracy': 0.95924, 'loss': 0.11752237154699659, 'mode': 'train'} 128 | {'epoch': 64, 'accuracy': 0.8861, 'loss': 0.3769667418139755, 'mode': 'test'} 129 | {'epoch': 65, 'accuracy': 0.95912, 'loss': 0.12002650031919976, 'mode': 'train'} 130 | {'epoch': 65, 'accuracy': 0.8885, 'loss': 0.3748677450760154, 'mode': 'test'} 131 | {'epoch': 66, 'accuracy': 0.9554, 'loss': 0.12826157748089412, 'mode': 'train'} 132 | {'epoch': 66, 'accuracy': 0.8809, 'loss': 0.40374053815367883, 'mode': 'test'} 133 | {'epoch': 67, 'accuracy': 0.95812, 'loss': 0.11856412766572773, 'mode': 'train'} 134 | {'epoch': 67, 'accuracy': 0.8634, 'loss': 0.4660199013105624, 'mode': 'test'} 135 | {'epoch': 68, 'accuracy': 0.9589, 'loss': 0.11746483851614818, 'mode': 'train'} 136 | {'epoch': 68, 'accuracy': 0.8889, 'loss': 0.3764256781833185, 'mode': 'test'} 137 | {'epoch': 69, 'accuracy': 0.95562, 'loss': 0.12445770192634106, 'mode': 'train'} 138 | {'epoch': 69, 'accuracy': 0.8808, 'loss': 0.40087049905281913, 'mode': 'test'} 139 | {'epoch': 70, 'accuracy': 0.96034, 'loss': 0.11386847697065008, 'mode': 'train'} 140 | {'epoch': 70, 'accuracy': 0.888, 'loss': 
0.3884650320763801, 'mode': 'test'} 141 | {'epoch': 71, 'accuracy': 0.95674, 'loss': 0.12094113124948959, 'mode': 'train'} 142 | {'epoch': 71, 'accuracy': 0.8947, 'loss': 0.32958807383373295, 'mode': 'test'} 143 | {'epoch': 72, 'accuracy': 0.9601, 'loss': 0.11637869064727109, 'mode': 'train'} 144 | {'epoch': 72, 'accuracy': 0.8828, 'loss': 0.37514699549432, 'mode': 'test'} 145 | {'epoch': 73, 'accuracy': 0.95612, 'loss': 0.12389597728314912, 'mode': 'train'} 146 | {'epoch': 73, 'accuracy': 0.8842, 'loss': 0.38055567851491795, 'mode': 'test'} 147 | {'epoch': 74, 'accuracy': 0.96052, 'loss': 0.1108432029233411, 'mode': 'train'} 148 | {'epoch': 74, 'accuracy': 0.8869, 'loss': 0.38148661308987153, 'mode': 'test'} 149 | {'epoch': 75, 'accuracy': 0.95964, 'loss': 0.11506778915481801, 'mode': 'train'} 150 | {'epoch': 75, 'accuracy': 0.8781, 'loss': 0.42965845127773905, 'mode': 'test'} 151 | {'epoch': 76, 'accuracy': 0.95904, 'loss': 0.11685363895943404, 'mode': 'train'} 152 | {'epoch': 76, 'accuracy': 0.8884, 'loss': 0.37697183307568727, 'mode': 'test'} 153 | {'epoch': 77, 'accuracy': 0.95948, 'loss': 0.11330655618282527, 'mode': 'train'} 154 | {'epoch': 77, 'accuracy': 0.8869, 'loss': 0.386712168432345, 'mode': 'test'} 155 | {'epoch': 78, 'accuracy': 0.96078, 'loss': 0.11313145347606493, 'mode': 'train'} 156 | {'epoch': 78, 'accuracy': 0.8929, 'loss': 0.35334056976494527, 'mode': 'test'} 157 | {'epoch': 79, 'accuracy': 0.96066, 'loss': 0.11237457385072315, 'mode': 'train'} 158 | {'epoch': 79, 'accuracy': 0.8753, 'loss': 0.44808954047928956, 'mode': 'test'} 159 | {'epoch': 80, 'accuracy': 0.96088, 'loss': 0.11346649072702282, 'mode': 'train'} 160 | {'epoch': 80, 'accuracy': 0.8902, 'loss': 0.3727744262973973, 'mode': 'test'} 161 | {'epoch': 81, 'accuracy': 0.98942, 'loss': 0.03709737803129583, 'mode': 'train'} 162 | {'epoch': 81, 'accuracy': 0.9342, 'loss': 0.21177624550404814, 'mode': 'test'} 163 | {'epoch': 82, 'accuracy': 0.99708, 'loss': 0.015394963219266415, 'mode': 
'train'} 164 | {'epoch': 82, 'accuracy': 0.9345, 'loss': 0.21087093750952163, 'mode': 'test'} 165 | {'epoch': 83, 'accuracy': 0.99888, 'loss': 0.00958086443526665, 'mode': 'train'} 166 | {'epoch': 83, 'accuracy': 0.9378, 'loss': 0.20984953167332207, 'mode': 'test'} 167 | {'epoch': 84, 'accuracy': 0.99922, 'loss': 0.006960038124295446, 'mode': 'train'} 168 | {'epoch': 84, 'accuracy': 0.9368, 'loss': 0.21019789092479998, 'mode': 'test'} 169 | {'epoch': 85, 'accuracy': 0.9994, 'loss': 0.0061714895965193275, 'mode': 'train'} 170 | {'epoch': 85, 'accuracy': 0.9386, 'loss': 0.2129666410434018, 'mode': 'test'} 171 | {'epoch': 86, 'accuracy': 0.9998, 'loss': 0.004586680408786317, 'mode': 'train'} 172 | {'epoch': 86, 'accuracy': 0.9386, 'loss': 0.21042226805428793, 'mode': 'test'} 173 | {'epoch': 87, 'accuracy': 0.99974, 'loss': 0.0040803727743875685, 'mode': 'train'} 174 | {'epoch': 87, 'accuracy': 0.939, 'loss': 0.21225800849260046, 'mode': 'test'} 175 | {'epoch': 88, 'accuracy': 0.99988, 'loss': 0.0034328265225186086, 'mode': 'train'} 176 | {'epoch': 88, 'accuracy': 0.9387, 'loss': 0.21086184485892587, 'mode': 'test'} 177 | {'epoch': 89, 'accuracy': 0.99986, 'loss': 0.003204720309170921, 'mode': 'train'} 178 | {'epoch': 89, 'accuracy': 0.9391, 'loss': 0.2128915631087722, 'mode': 'test'} 179 | {'epoch': 90, 'accuracy': 0.9999, 'loss': 0.002780051906700331, 'mode': 'train'} 180 | {'epoch': 90, 'accuracy': 0.9398, 'loss': 0.21008758232661862, 'mode': 'test'} 181 | {'epoch': 91, 'accuracy': 0.99994, 'loss': 0.002599785737979139, 'mode': 'train'} 182 | {'epoch': 91, 'accuracy': 0.9404, 'loss': 0.21136768532406752, 'mode': 'test'} 183 | {'epoch': 92, 'accuracy': 0.9999, 'loss': 0.002744697853732293, 'mode': 'train'} 184 | {'epoch': 92, 'accuracy': 0.939, 'loss': 0.21085699851725512, 'mode': 'test'} 185 | {'epoch': 93, 'accuracy': 0.99994, 'loss': 0.0022666903827196477, 'mode': 'train'} 186 | {'epoch': 93, 'accuracy': 0.9396, 'loss': 0.21333993876435955, 'mode': 'test'} 187 | 
{'epoch': 94, 'accuracy': 0.99998, 'loss': 0.0022167576777050884, 'mode': 'train'} 188 | {'epoch': 94, 'accuracy': 0.9378, 'loss': 0.21443662260937846, 'mode': 'test'} 189 | {'epoch': 95, 'accuracy': 0.99994, 'loss': 0.0021087861122072817, 'mode': 'train'} 190 | {'epoch': 95, 'accuracy': 0.9391, 'loss': 0.21031065883150526, 'mode': 'test'} 191 | {'epoch': 96, 'accuracy': 0.99996, 'loss': 0.002021871430947042, 'mode': 'train'} 192 | {'epoch': 96, 'accuracy': 0.9396, 'loss': 0.21044457134357689, 'mode': 'test'} 193 | {'epoch': 97, 'accuracy': 0.99998, 'loss': 0.0018738835402157005, 'mode': 'train'} 194 | {'epoch': 97, 'accuracy': 0.9402, 'loss': 0.21494550509437646, 'mode': 'test'} 195 | {'epoch': 98, 'accuracy': 0.99998, 'loss': 0.0017834534806668617, 'mode': 'train'} 196 | {'epoch': 98, 'accuracy': 0.9406, 'loss': 0.21404445256776863, 'mode': 'test'} 197 | {'epoch': 99, 'accuracy': 1.0, 'loss': 0.0017622469369407807, 'mode': 'train'} 198 | {'epoch': 99, 'accuracy': 0.9396, 'loss': 0.21090378450929737, 'mode': 'test'} 199 | {'epoch': 100, 'accuracy': 0.99996, 'loss': 0.001786161433247959, 'mode': 'train'} 200 | {'epoch': 100, 'accuracy': 0.9409, 'loss': 0.2122238306387975, 'mode': 'test'} -------------------------------------------------------------------------------- /state/state_full.txt: -------------------------------------------------------------------------------- 1 | {'loss': 1.701507381008714, 'mode': 'train', 'epoch': 1, 'accuracy': 0.38666} 2 | {'loss': 1.4928108932106359, 'mode': 'test', 'epoch': 1, 'accuracy': 0.4493} 3 | {'loss': 1.2021837340443948, 'mode': 'train', 'epoch': 2, 'accuracy': 0.56658} 4 | {'loss': 1.8669436015900538, 'mode': 'test', 'epoch': 2, 'accuracy': 0.5861} 5 | {'loss': 0.9127246160488917, 'mode': 'train', 'epoch': 3, 'accuracy': 0.67832} 6 | {'loss': 1.1657628627719392, 'mode': 'test', 'epoch': 3, 'accuracy': 0.6912} 7 | {'loss': 0.749766890655089, 'mode': 'train', 'epoch': 4, 'accuracy': 0.73588} 8 | {'loss': 1.2296965936566617, 
'mode': 'test', 'epoch': 4, 'accuracy': 0.728} 9 | {'loss': 0.6556041519660167, 'mode': 'train', 'epoch': 5, 'accuracy': 0.77124} 10 | {'loss': 0.7869841974631998, 'mode': 'test', 'epoch': 5, 'accuracy': 0.7557} 11 | {'loss': 0.5484124159492796, 'mode': 'train', 'epoch': 6, 'accuracy': 0.8094} 12 | {'loss': 0.6138943060758005, 'mode': 'test', 'epoch': 6, 'accuracy': 0.7899} 13 | {'loss': 0.47876150147689234, 'mode': 'train', 'epoch': 7, 'accuracy': 0.83392} 14 | {'loss': 0.9107919263232285, 'mode': 'test', 'epoch': 7, 'accuracy': 0.7905} 15 | {'loss': 0.4301651023576026, 'mode': 'train', 'epoch': 8, 'accuracy': 0.85026} 16 | {'loss': 2.8089250785529987, 'mode': 'test', 'epoch': 8, 'accuracy': 0.8012} 17 | {'loss': 0.3876526968939531, 'mode': 'train', 'epoch': 9, 'accuracy': 0.86524} 18 | {'loss': 0.5354519632591562, 'mode': 'test', 'epoch': 9, 'accuracy': 0.8209} 19 | {'loss': 0.3615583087343843, 'mode': 'train', 'epoch': 10, 'accuracy': 0.87574} 20 | {'loss': 0.5041033884712086, 'mode': 'test', 'epoch': 10, 'accuracy': 0.8323} 21 | {'loss': 0.3297862098040183, 'mode': 'train', 'epoch': 11, 'accuracy': 0.88572} 22 | {'loss': 0.4581784237721923, 'mode': 'test', 'epoch': 11, 'accuracy': 0.8495} 23 | {'loss': 0.3019945525357034, 'mode': 'train', 'epoch': 12, 'accuracy': 0.89544} 24 | {'loss': 0.5183731607001301, 'mode': 'test', 'epoch': 12, 'accuracy': 0.8389} 25 | {'loss': 0.2813901776385961, 'mode': 'train', 'epoch': 13, 'accuracy': 0.90228} 26 | {'loss': 0.5616163620903235, 'mode': 'test', 'epoch': 13, 'accuracy': 0.8461} 27 | {'loss': 0.2645355407767895, 'mode': 'train', 'epoch': 14, 'accuracy': 0.90808} 28 | {'loss': 0.46804076831811553, 'mode': 'test', 'epoch': 14, 'accuracy': 0.8464} 29 | {'loss': 0.24726970742463739, 'mode': 'train', 'epoch': 15, 'accuracy': 0.91414} 30 | {'loss': 0.735022325899191, 'mode': 'test', 'epoch': 15, 'accuracy': 0.8547} 31 | {'loss': 0.23869266518679888, 'mode': 'train', 'epoch': 16, 'accuracy': 0.91658} 32 | {'loss': 
0.5531352221206495, 'mode': 'test', 'epoch': 16, 'accuracy': 0.8508} 33 | {'loss': 0.22992790387490802, 'mode': 'train', 'epoch': 17, 'accuracy': 0.91916} 34 | {'loss': 0.5201042311586391, 'mode': 'test', 'epoch': 17, 'accuracy': 0.8495} 35 | {'loss': 0.21446329497677438, 'mode': 'train', 'epoch': 18, 'accuracy': 0.92414} 36 | {'loss': 0.40414618159745147, 'mode': 'test', 'epoch': 18, 'accuracy': 0.8674} 37 | {'loss': 0.20903654901020213, 'mode': 'train', 'epoch': 19, 'accuracy': 0.92692} 38 | {'loss': 0.4643758662567016, 'mode': 'test', 'epoch': 19, 'accuracy': 0.8503} 39 | {'loss': 0.1994660591940058, 'mode': 'train', 'epoch': 20, 'accuracy': 0.9299} 40 | {'loss': 0.5420943919070964, 'mode': 'test', 'epoch': 20, 'accuracy': 0.8309} 41 | {'loss': 0.19392704100960212, 'mode': 'train', 'epoch': 21, 'accuracy': 0.93222} 42 | {'loss': 0.49596195919498515, 'mode': 'test', 'epoch': 21, 'accuracy': 0.8513} 43 | {'loss': 0.18678548423778224, 'mode': 'train', 'epoch': 22, 'accuracy': 0.93374} 44 | {'loss': 0.38026612370636825, 'mode': 'test', 'epoch': 22, 'accuracy': 0.8746} 45 | {'loss': 0.18491771314626623, 'mode': 'train', 'epoch': 23, 'accuracy': 0.93596} 46 | {'loss': 0.4087956180428242, 'mode': 'test', 'epoch': 23, 'accuracy': 0.8715} 47 | {'loss': 0.18297010405784686, 'mode': 'train', 'epoch': 24, 'accuracy': 0.93568} 48 | {'loss': 0.44218498326031286, 'mode': 'test', 'epoch': 24, 'accuracy': 0.8542} 49 | {'loss': 0.17564884603709519, 'mode': 'train', 'epoch': 25, 'accuracy': 0.93752} 50 | {'loss': 0.39058889083232096, 'mode': 'test', 'epoch': 25, 'accuracy': 0.8745} 51 | {'loss': 0.1750161382691254, 'mode': 'train', 'epoch': 26, 'accuracy': 0.93874} 52 | {'loss': 0.47982504252035907, 'mode': 'test', 'epoch': 26, 'accuracy': 0.8522} 53 | {'loss': 0.17191265523910063, 'mode': 'train', 'epoch': 27, 'accuracy': 0.94018} 54 | {'loss': 0.3860075710115918, 'mode': 'test', 'epoch': 27, 'accuracy': 0.8783} 55 | {'loss': 0.16800044870952052, 'mode': 'train', 'epoch': 28, 
'accuracy': 0.94046} 56 | {'loss': 0.428713690133611, 'mode': 'test', 'epoch': 28, 'accuracy': 0.8667} 57 | {'loss': 0.16577000001831285, 'mode': 'train', 'epoch': 29, 'accuracy': 0.94246} 58 | {'loss': 0.3971163987354109, 'mode': 'test', 'epoch': 29, 'accuracy': 0.8728} 59 | {'loss': 0.15954104398884583, 'mode': 'train', 'epoch': 30, 'accuracy': 0.94448} 60 | {'loss': 0.41596957414772867, 'mode': 'test', 'epoch': 30, 'accuracy': 0.8693} 61 | {'loss': 0.16619436153213082, 'mode': 'train', 'epoch': 31, 'accuracy': 0.9417} 62 | {'loss': 0.4290382457766564, 'mode': 'test', 'epoch': 31, 'accuracy': 0.869} 63 | {'loss': 0.16175390376001983, 'mode': 'train', 'epoch': 32, 'accuracy': 0.94258} 64 | {'loss': 0.3887259405414769, 'mode': 'test', 'epoch': 32, 'accuracy': 0.8778} 65 | {'loss': 0.15511017867728408, 'mode': 'train', 'epoch': 33, 'accuracy': 0.94542} 66 | {'loss': 0.4112637587793314, 'mode': 'test', 'epoch': 33, 'accuracy': 0.8739} 67 | {'loss': 0.1545597537637442, 'mode': 'train', 'epoch': 34, 'accuracy': 0.94536} 68 | {'loss': 0.3945926525600399, 'mode': 'test', 'epoch': 34, 'accuracy': 0.8751} 69 | {'loss': 0.14720909192662707, 'mode': 'train', 'epoch': 35, 'accuracy': 0.94744} 70 | {'loss': 0.3880315438672237, 'mode': 'test', 'epoch': 35, 'accuracy': 0.8845} 71 | {'loss': 0.14748527040547874, 'mode': 'train', 'epoch': 36, 'accuracy': 0.948} 72 | {'loss': 0.416666530072689, 'mode': 'test', 'epoch': 36, 'accuracy': 0.8701} 73 | {'loss': 0.1519192961351875, 'mode': 'train', 'epoch': 37, 'accuracy': 0.94632} 74 | {'loss': 0.4118852260386108, 'mode': 'test', 'epoch': 37, 'accuracy': 0.8726} 75 | {'loss': 0.1538788779731601, 'mode': 'train', 'epoch': 38, 'accuracy': 0.94612} 76 | {'loss': 0.4324377569233539, 'mode': 'test', 'epoch': 38, 'accuracy': 0.872} 77 | {'loss': 0.14536677703947343, 'mode': 'train', 'epoch': 39, 'accuracy': 0.94912} 78 | {'loss': 0.36516678504123795, 'mode': 'test', 'epoch': 39, 'accuracy': 0.8863} 79 | {'loss': 0.139217748275727, 'mode': 
'train', 'epoch': 40, 'accuracy': 0.95086} 80 | {'loss': 0.4048831054739131, 'mode': 'test', 'epoch': 40, 'accuracy': 0.8789} 81 | {'loss': 0.14121819838709998, 'mode': 'train', 'epoch': 41, 'accuracy': 0.9512} 82 | {'loss': 0.38245561253872645, 'mode': 'test', 'epoch': 41, 'accuracy': 0.8806} 83 | {'loss': 0.13981572827300454, 'mode': 'train', 'epoch': 42, 'accuracy': 0.95032} 84 | {'loss': 0.3862506180622015, 'mode': 'test', 'epoch': 42, 'accuracy': 0.8819} 85 | {'loss': 0.14481345783738067, 'mode': 'train', 'epoch': 43, 'accuracy': 0.94988} 86 | {'loss': 0.35895476573307067, 'mode': 'test', 'epoch': 43, 'accuracy': 0.8881} 87 | {'loss': 0.13126857663550043, 'mode': 'train', 'epoch': 44, 'accuracy': 0.95318} 88 | {'loss': 0.39577574918794023, 'mode': 'test', 'epoch': 44, 'accuracy': 0.8781} 89 | {'loss': 0.13475442316640368, 'mode': 'train', 'epoch': 45, 'accuracy': 0.9526} 90 | {'loss': 0.4135137971988911, 'mode': 'test', 'epoch': 45, 'accuracy': 0.8754} 91 | {'loss': 0.13911602763778272, 'mode': 'train', 'epoch': 46, 'accuracy': 0.95068} 92 | {'loss': 0.3681163952751145, 'mode': 'test', 'epoch': 46, 'accuracy': 0.8884} 93 | {'loss': 0.13189648480995378, 'mode': 'train', 'epoch': 47, 'accuracy': 0.9543} 94 | {'loss': 0.38697044527644586, 'mode': 'test', 'epoch': 47, 'accuracy': 0.8784} 95 | {'loss': 0.13368999465937959, 'mode': 'train', 'epoch': 48, 'accuracy': 0.95376} 96 | {'loss': 0.40222299274555445, 'mode': 'test', 'epoch': 48, 'accuracy': 0.8782} 97 | {'loss': 0.13029073807589528, 'mode': 'train', 'epoch': 49, 'accuracy': 0.95462} 98 | {'loss': 0.38152846404511415, 'mode': 'test', 'epoch': 49, 'accuracy': 0.8864} 99 | {'loss': 0.1335244185631843, 'mode': 'train', 'epoch': 50, 'accuracy': 0.95284} 100 | {'loss': 0.36329996377039864, 'mode': 'test', 'epoch': 50, 'accuracy': 0.8874} 101 | {'loss': 0.13204549131033688, 'mode': 'train', 'epoch': 51, 'accuracy': 0.95398} 102 | {'loss': 0.3553180611532206, 'mode': 'test', 'epoch': 51, 'accuracy': 0.8862} 103 | 
{'loss': 0.12910209198379927, 'mode': 'train', 'epoch': 52, 'accuracy': 0.95478} 104 | {'loss': 0.35513117449108955, 'mode': 'test', 'epoch': 52, 'accuracy': 0.888} 105 | {'loss': 0.12485227229840624, 'mode': 'train', 'epoch': 53, 'accuracy': 0.95728} 106 | {'loss': 0.38907561878300007, 'mode': 'test', 'epoch': 53, 'accuracy': 0.8831} 107 | {'loss': 0.12308644332334673, 'mode': 'train', 'epoch': 54, 'accuracy': 0.95758} 108 | {'loss': 0.3873825145849756, 'mode': 'test', 'epoch': 54, 'accuracy': 0.8887} 109 | {'loss': 0.12965246650111636, 'mode': 'train', 'epoch': 55, 'accuracy': 0.95424} 110 | {'loss': 0.39025878678461556, 'mode': 'test', 'epoch': 55, 'accuracy': 0.8787} 111 | {'loss': 0.1281252900171843, 'mode': 'train', 'epoch': 56, 'accuracy': 0.95568} 112 | {'loss': 0.449953366474361, 'mode': 'test', 'epoch': 56, 'accuracy': 0.8618} 113 | {'loss': 0.119224290542133, 'mode': 'train', 'epoch': 57, 'accuracy': 0.95816} 114 | {'loss': 0.3588221183248388, 'mode': 'test', 'epoch': 57, 'accuracy': 0.8896} 115 | {'loss': 0.12261916789919376, 'mode': 'train', 'epoch': 58, 'accuracy': 0.95718} 116 | {'loss': 0.43766490868322433, 'mode': 'test', 'epoch': 58, 'accuracy': 0.8685} 117 | {'loss': 0.12576307859891053, 'mode': 'train', 'epoch': 59, 'accuracy': 0.95592} 118 | {'loss': 0.4373816805090873, 'mode': 'test', 'epoch': 59, 'accuracy': 0.8734} 119 | {'loss': 0.12282564514852543, 'mode': 'train', 'epoch': 60, 'accuracy': 0.95712} 120 | {'loss': 0.36598861587655035, 'mode': 'test', 'epoch': 60, 'accuracy': 0.8837} 121 | {'loss': 0.12214871014342141, 'mode': 'train', 'epoch': 61, 'accuracy': 0.95658} 122 | {'loss': 0.38112527976749805, 'mode': 'test', 'epoch': 61, 'accuracy': 0.8893} 123 | {'loss': 0.12402856823943019, 'mode': 'train', 'epoch': 62, 'accuracy': 0.9553} 124 | {'loss': 0.3221739788249041, 'mode': 'test', 'epoch': 62, 'accuracy': 0.8988} 125 | {'loss': 0.1088666003530897, 'mode': 'train', 'epoch': 63, 'accuracy': 0.9624} 126 | {'loss': 0.39043013680322913, 
'mode': 'test', 'epoch': 63, 'accuracy': 0.8825} 127 | {'loss': 0.11411408948071336, 'mode': 'train', 'epoch': 64, 'accuracy': 0.96032} 128 | {'loss': 0.3833443418999386, 'mode': 'test', 'epoch': 64, 'accuracy': 0.8827} 129 | {'loss': 0.12377060531540907, 'mode': 'train', 'epoch': 65, 'accuracy': 0.95748} 130 | {'loss': 0.3827761898564685, 'mode': 'test', 'epoch': 65, 'accuracy': 0.8832} 131 | {'loss': 0.12111604612444514, 'mode': 'train', 'epoch': 66, 'accuracy': 0.95704} 132 | {'loss': 0.3680251757999892, 'mode': 'test', 'epoch': 66, 'accuracy': 0.8924} 133 | {'loss': 0.11580644736585717, 'mode': 'train', 'epoch': 67, 'accuracy': 0.95884} 134 | {'loss': 0.4106823524851706, 'mode': 'test', 'epoch': 67, 'accuracy': 0.8763} 135 | {'loss': 0.12060475547123907, 'mode': 'train', 'epoch': 68, 'accuracy': 0.9573} 136 | {'loss': 0.3889824640314292, 'mode': 'test', 'epoch': 68, 'accuracy': 0.8849} 137 | {'loss': 0.11598743845606248, 'mode': 'train', 'epoch': 69, 'accuracy': 0.9595} 138 | {'loss': 0.3181076217324112, 'mode': 'test', 'epoch': 69, 'accuracy': 0.9012} 139 | {'loss': 0.10793253100570058, 'mode': 'train', 'epoch': 70, 'accuracy': 0.96206} 140 | {'loss': 0.41573411169325464, 'mode': 'test', 'epoch': 70, 'accuracy': 0.8798} 141 | {'loss': 0.11520065711168075, 'mode': 'train', 'epoch': 71, 'accuracy': 0.9595} 142 | {'loss': 0.33289557785555035, 'mode': 'test', 'epoch': 71, 'accuracy': 0.8968} 143 | {'loss': 0.1211894040979694, 'mode': 'train', 'epoch': 72, 'accuracy': 0.95678} 144 | {'loss': 0.3671006180203645, 'mode': 'test', 'epoch': 72, 'accuracy': 0.8921} 145 | {'loss': 0.11148716030580466, 'mode': 'train', 'epoch': 73, 'accuracy': 0.96096} 146 | {'loss': 0.43286734544167826, 'mode': 'test', 'epoch': 73, 'accuracy': 0.8646} 147 | {'loss': 0.11270042534088695, 'mode': 'train', 'epoch': 74, 'accuracy': 0.96022} 148 | {'loss': 0.33534531964428094, 'mode': 'test', 'epoch': 74, 'accuracy': 0.8954} 149 | {'loss': 0.11693533238909588, 'mode': 'train', 'epoch': 75, 
'accuracy': 0.95938} 150 | {'loss': 0.44399098121816205, 'mode': 'test', 'epoch': 75, 'accuracy': 0.871} 151 | {'loss': 0.11129785280036349, 'mode': 'train', 'epoch': 76, 'accuracy': 0.96162} 152 | {'loss': 0.42270676231688, 'mode': 'test', 'epoch': 76, 'accuracy': 0.8712} 153 | {'loss': 0.11404835176951887, 'mode': 'train', 'epoch': 77, 'accuracy': 0.95984} 154 | {'loss': 0.36246049195338215, 'mode': 'test', 'epoch': 77, 'accuracy': 0.8922} 155 | {'loss': 0.10977578357509943, 'mode': 'train', 'epoch': 78, 'accuracy': 0.96132} 156 | {'loss': 0.44367400376470234, 'mode': 'test', 'epoch': 78, 'accuracy': 0.8751} 157 | {'loss': 0.10897776357772392, 'mode': 'train', 'epoch': 79, 'accuracy': 0.96124} 158 | {'loss': 0.3395387756691616, 'mode': 'test', 'epoch': 79, 'accuracy': 0.8933} 159 | {'loss': 0.11842237337661518, 'mode': 'train', 'epoch': 80, 'accuracy': 0.95788} 160 | {'loss': 0.3565018450378614, 'mode': 'test', 'epoch': 80, 'accuracy': 0.8932} 161 | {'loss': 0.03500196959851949, 'mode': 'train', 'epoch': 81, 'accuracy': 0.99} 162 | {'loss': 0.22050533671477787, 'mode': 'test', 'epoch': 81, 'accuracy': 0.9331} 163 | {'loss': 0.01512880123141782, 'mode': 'train', 'epoch': 82, 'accuracy': 0.99714} 164 | {'loss': 0.21671485554450645, 'mode': 'test', 'epoch': 82, 'accuracy': 0.9369} 165 | {'loss': 0.009707257356447034, 'mode': 'train', 'epoch': 83, 'accuracy': 0.99886} 166 | {'loss': 0.2142113212519771, 'mode': 'test', 'epoch': 83, 'accuracy': 0.9368} 167 | {'loss': 0.00775727892623228, 'mode': 'train', 'epoch': 84, 'accuracy': 0.99934} 168 | {'loss': 0.21198879488429445, 'mode': 'test', 'epoch': 84, 'accuracy': 0.9374} 169 | {'loss': 0.006033864172408953, 'mode': 'train', 'epoch': 85, 'accuracy': 0.9996} 170 | {'loss': 0.21243946140359157, 'mode': 'test', 'epoch': 85, 'accuracy': 0.9395} 171 | {'loss': 0.0048667361240481545, 'mode': 'train', 'epoch': 86, 'accuracy': 0.99974} 172 | {'loss': 0.21474126726388934, 'mode': 'test', 'epoch': 86, 'accuracy': 0.9407} 173 | 
{'loss': 0.0044306598775222215, 'mode': 'train', 'epoch': 87, 'accuracy': 0.99984} 174 | {'loss': 0.2146223539569575, 'mode': 'test', 'epoch': 87, 'accuracy': 0.9379} 175 | {'loss': 0.003798292244753573, 'mode': 'train', 'epoch': 88, 'accuracy': 0.9999} 176 | {'loss': 0.21468912938218201, 'mode': 'test', 'epoch': 88, 'accuracy': 0.9393} 177 | {'loss': 0.003541848253068108, 'mode': 'train', 'epoch': 89, 'accuracy': 0.99982} 178 | {'loss': 0.2212018348086791, 'mode': 'test', 'epoch': 89, 'accuracy': 0.9378} 179 | {'loss': 0.003184019209212054, 'mode': 'train', 'epoch': 90, 'accuracy': 0.99992} 180 | {'loss': 0.2169848160141972, 'mode': 'test', 'epoch': 90, 'accuracy': 0.9393} 181 | {'loss': 0.003019570618334326, 'mode': 'train', 'epoch': 91, 'accuracy': 0.99992} 182 | {'loss': 0.2184833237414907, 'mode': 'test', 'epoch': 91, 'accuracy': 0.9392} 183 | {'loss': 0.0026865883937577167, 'mode': 'train', 'epoch': 92, 'accuracy': 0.99998} 184 | {'loss': 0.2182157895272704, 'mode': 'test', 'epoch': 92, 'accuracy': 0.9396} 185 | {'loss': 0.002603915751056599, 'mode': 'train', 'epoch': 93, 'accuracy': 0.9999} 186 | {'loss': 0.21871563728163193, 'mode': 'test', 'epoch': 93, 'accuracy': 0.9387} 187 | {'loss': 0.0023342992281517403, 'mode': 'train', 'epoch': 94, 'accuracy': 0.99996} 188 | {'loss': 0.22084234951502973, 'mode': 'test', 'epoch': 94, 'accuracy': 0.9388} 189 | {'loss': 0.002422085150961982, 'mode': 'train', 'epoch': 95, 'accuracy': 0.99994} 190 | {'loss': 0.21687235448295905, 'mode': 'test', 'epoch': 95, 'accuracy': 0.9401} 191 | {'loss': 0.002327153082851251, 'mode': 'train', 'epoch': 96, 'accuracy': 1.0} 192 | {'loss': 0.21919785276245168, 'mode': 'test', 'epoch': 96, 'accuracy': 0.9394} 193 | {'loss': 0.002161553096207207, 'mode': 'train', 'epoch': 97, 'accuracy': 0.99998} 194 | {'loss': 0.2176431942565046, 'mode': 'test', 'epoch': 97, 'accuracy': 0.9392} 195 | {'loss': 0.0022441323684609456, 'mode': 'train', 'epoch': 98, 'accuracy': 0.9999} 196 | {'loss': 
0.2170503516533193, 'mode': 'test', 'epoch': 98, 'accuracy': 0.9387} 197 | {'loss': 0.0018972001131385772, 'mode': 'train', 'epoch': 99, 'accuracy': 1.0} 198 | {'loss': 0.22311048128992123, 'mode': 'test', 'epoch': 99, 'accuracy': 0.9382} 199 | {'loss': 0.0019207180184705175, 'mode': 'train', 'epoch': 100, 'accuracy': 0.99998} 200 | {'loss': 0.22380224558388342, 'mode': 'test', 'epoch': 100, 'accuracy': 0.9394} 201 | -------------------------------------------------------------------------------- /state/state_full_in_loop.txt: -------------------------------------------------------------------------------- 1 | {'loss': 1.7007273447025764, 'mode': 'train', 'accuracy': 0.38796, 'epoch': 1} 2 | {'loss': 1.5075985223624353, 'mode': 'test', 'accuracy': 0.4659, 'epoch': 1} 3 | {'loss': 1.1475214777547673, 'mode': 'train', 'accuracy': 0.5883, 'epoch': 2} 4 | {'loss': 1.0237990139396327, 'mode': 'test', 'accuracy': 0.6402, 'epoch': 2} 5 | {'loss': 0.8564700453025287, 'mode': 'train', 'accuracy': 0.69832, 'epoch': 3} 6 | {'loss': 0.82860863987048, 'mode': 'test', 'accuracy': 0.7186, 'epoch': 3} 7 | {'loss': 0.6687911291942568, 'mode': 'train', 'accuracy': 0.76562, 'epoch': 4} 8 | {'loss': 0.7498326920399998, 'mode': 'test', 'accuracy': 0.7614, 'epoch': 4} 9 | {'loss': 0.5520611177853613, 'mode': 'train', 'accuracy': 0.80786, 'epoch': 5} 10 | {'loss': 0.5979902190007982, 'mode': 'test', 'accuracy': 0.7973, 'epoch': 5} 11 | {'loss': 0.4681166113566254, 'mode': 'train', 'accuracy': 0.83846, 'epoch': 6} 12 | {'loss': 0.5526944858253382, 'mode': 'test', 'accuracy': 0.8125, 'epoch': 6} 13 | {'loss': 0.4155567596711773, 'mode': 'train', 'accuracy': 0.85718, 'epoch': 7} 14 | {'loss': 0.5236777084268586, 'mode': 'test', 'accuracy': 0.8191, 'epoch': 7} 15 | {'loss': 0.3602934756966506, 'mode': 'train', 'accuracy': 0.87564, 'epoch': 8} 16 | {'loss': 0.44742005598393214, 'mode': 'test', 'accuracy': 0.8488, 'epoch': 8} 17 | {'loss': 0.3306992431564251, 'mode': 'train', 'accuracy': 
0.88532, 'epoch': 9} 18 | {'loss': 0.4363572088776119, 'mode': 'test', 'accuracy': 0.8472, 'epoch': 9} 19 | {'loss': 0.30211485800383375, 'mode': 'train', 'accuracy': 0.89438, 'epoch': 10} 20 | {'loss': 0.45789728470289026, 'mode': 'test', 'accuracy': 0.8473, 'epoch': 10} 21 | {'loss': 0.27929461834109026, 'mode': 'train', 'accuracy': 0.90224, 'epoch': 11} 22 | {'loss': 0.38672881597166603, 'mode': 'test', 'accuracy': 0.8666, 'epoch': 11} 23 | {'loss': 0.25351597916554003, 'mode': 'train', 'accuracy': 0.91114, 'epoch': 12} 24 | {'loss': 0.3946147208950322, 'mode': 'test', 'accuracy': 0.8684, 'epoch': 12} 25 | {'loss': 0.24124464427437772, 'mode': 'train', 'accuracy': 0.91608, 'epoch': 13} 26 | {'loss': 0.4179237357750063, 'mode': 'test', 'accuracy': 0.8608, 'epoch': 13} 27 | {'loss': 0.22418475723670592, 'mode': 'train', 'accuracy': 0.92224, 'epoch': 14} 28 | {'loss': 0.39182560311950687, 'mode': 'test', 'accuracy': 0.8765, 'epoch': 14} 29 | {'loss': 0.21766807441897432, 'mode': 'train', 'accuracy': 0.9249, 'epoch': 15} 30 | {'loss': 0.41445253163006657, 'mode': 'test', 'accuracy': 0.8635, 'epoch': 15} 31 | {'loss': 0.20860296659304958, 'mode': 'train', 'accuracy': 0.9266, 'epoch': 16} 32 | {'loss': 0.41642816374256364, 'mode': 'test', 'accuracy': 0.8678, 'epoch': 16} 33 | {'loss': 0.20074465007180592, 'mode': 'train', 'accuracy': 0.92832, 'epoch': 17} 34 | {'loss': 0.3488860415045623, 'mode': 'test', 'accuracy': 0.887, 'epoch': 17} 35 | {'loss': 0.1849651348007761, 'mode': 'train', 'accuracy': 0.93432, 'epoch': 18} 36 | {'loss': 0.39999383659499443, 'mode': 'test', 'accuracy': 0.873, 'epoch': 18} 37 | {'loss': 0.1858967669222439, 'mode': 'train', 'accuracy': 0.93536, 'epoch': 19} 38 | {'loss': 0.3752606924931715, 'mode': 'test', 'accuracy': 0.8773, 'epoch': 19} 39 | {'loss': 0.1832350518154291, 'mode': 'train', 'accuracy': 0.93566, 'epoch': 20} 40 | {'loss': 0.3926180781452521, 'mode': 'test', 'accuracy': 0.873, 'epoch': 20} 41 | {'loss': 0.17237658853954696, 
'mode': 'train', 'accuracy': 0.93996, 'epoch': 21} 42 | {'loss': 0.47942085981748667, 'mode': 'test', 'accuracy': 0.8546, 'epoch': 21} 43 | {'loss': 0.17426487599092244, 'mode': 'train', 'accuracy': 0.9395, 'epoch': 22} 44 | {'loss': 0.3993801620260925, 'mode': 'test', 'accuracy': 0.877, 'epoch': 22} 45 | {'loss': 0.16371847742982698, 'mode': 'train', 'accuracy': 0.94266, 'epoch': 23} 46 | {'loss': 0.3626598390709064, 'mode': 'test', 'accuracy': 0.8862, 'epoch': 23} 47 | {'loss': 0.1659521320477473, 'mode': 'train', 'accuracy': 0.94188, 'epoch': 24} 48 | {'loss': 0.4062825002867706, 'mode': 'test', 'accuracy': 0.8707, 'epoch': 24} 49 | {'loss': 0.16165300064227178, 'mode': 'train', 'accuracy': 0.94446, 'epoch': 25} 50 | {'loss': 0.39855439552835614, 'mode': 'test', 'accuracy': 0.8724, 'epoch': 25} 51 | {'loss': 0.15794597837664276, 'mode': 'train', 'accuracy': 0.94418, 'epoch': 26} 52 | {'loss': 0.35741935708340555, 'mode': 'test', 'accuracy': 0.8866, 'epoch': 26} 53 | {'loss': 0.15474875281324313, 'mode': 'train', 'accuracy': 0.94584, 'epoch': 27} 54 | {'loss': 0.4310477250700543, 'mode': 'test', 'accuracy': 0.8699, 'epoch': 27} 55 | {'loss': 0.15673791573804516, 'mode': 'train', 'accuracy': 0.94538, 'epoch': 28} 56 | {'loss': 0.34404838875315746, 'mode': 'test', 'accuracy': 0.8886, 'epoch': 28} 57 | {'loss': 0.1548068871735916, 'mode': 'train', 'accuracy': 0.9451, 'epoch': 29} 58 | {'loss': 0.35036747314178274, 'mode': 'test', 'accuracy': 0.8898, 'epoch': 29} 59 | {'loss': 0.1451047400293677, 'mode': 'train', 'accuracy': 0.94894, 'epoch': 30} 60 | {'loss': 0.42883810434181974, 'mode': 'test', 'accuracy': 0.8721, 'epoch': 30} 61 | {'loss': 0.14787662089290218, 'mode': 'train', 'accuracy': 0.94772, 'epoch': 31} 62 | {'loss': 0.34186159240402236, 'mode': 'test', 'accuracy': 0.8905, 'epoch': 31} 63 | {'loss': 0.1453494139758826, 'mode': 'train', 'accuracy': 0.9495, 'epoch': 32} 64 | {'loss': 0.3857933079740802, 'mode': 'test', 'accuracy': 0.8768, 'epoch': 32} 65 | 
{'loss': 0.14699513734320713, 'mode': 'train', 'accuracy': 0.94886, 'epoch': 33} 66 | {'loss': 0.3685406640551652, 'mode': 'test', 'accuracy': 0.8781, 'epoch': 33} 67 | {'loss': 0.1429196486721181, 'mode': 'train', 'accuracy': 0.95058, 'epoch': 34} 68 | {'loss': 0.37952152051173965, 'mode': 'test', 'accuracy': 0.8847, 'epoch': 34} 69 | {'loss': 0.14595075096468169, 'mode': 'train', 'accuracy': 0.9484, 'epoch': 35} 70 | {'loss': 0.3677537344443568, 'mode': 'test', 'accuracy': 0.8877, 'epoch': 35} 71 | {'loss': 0.13826148995600873, 'mode': 'train', 'accuracy': 0.95162, 'epoch': 36} 72 | {'loss': 0.32227574677983667, 'mode': 'test', 'accuracy': 0.897, 'epoch': 36} 73 | {'loss': 0.1371391208751885, 'mode': 'train', 'accuracy': 0.95134, 'epoch': 37} 74 | {'loss': 0.36377236527052664, 'mode': 'test', 'accuracy': 0.8878, 'epoch': 37} 75 | {'loss': 0.1421136042565261, 'mode': 'train', 'accuracy': 0.95016, 'epoch': 38} 76 | {'loss': 0.40794309821857755, 'mode': 'test', 'accuracy': 0.8737, 'epoch': 38} 77 | {'loss': 0.1342421312771185, 'mode': 'train', 'accuracy': 0.95234, 'epoch': 39} 78 | {'loss': 0.35370542820851547, 'mode': 'test', 'accuracy': 0.8885, 'epoch': 39} 79 | {'loss': 0.12878580621022084, 'mode': 'train', 'accuracy': 0.95524, 'epoch': 40} 80 | {'loss': 0.3172357419778587, 'mode': 'test', 'accuracy': 0.8967, 'epoch': 40} 81 | {'loss': 0.12973436064389354, 'mode': 'train', 'accuracy': 0.95464, 'epoch': 41} 82 | {'loss': 0.35643544081290046, 'mode': 'test', 'accuracy': 0.885, 'epoch': 41} 83 | {'loss': 0.1272757337916918, 'mode': 'train', 'accuracy': 0.95566, 'epoch': 42} 84 | {'loss': 0.3792462154368685, 'mode': 'test', 'accuracy': 0.8845, 'epoch': 42} 85 | {'loss': 0.12080350213343537, 'mode': 'train', 'accuracy': 0.95796, 'epoch': 43} 86 | {'loss': 0.3655736175881829, 'mode': 'test', 'accuracy': 0.8838, 'epoch': 43} 87 | {'loss': 0.13208171827218415, 'mode': 'train', 'accuracy': 0.95406, 'epoch': 44} 88 | {'loss': 0.37989494726536377, 'mode': 'test', 
'accuracy': 0.8862, 'epoch': 44} 89 | {'loss': 0.12700393828360945, 'mode': 'train', 'accuracy': 0.95548, 'epoch': 45} 90 | {'loss': 0.3858829955955978, 'mode': 'test', 'accuracy': 0.8787, 'epoch': 45} 91 | {'loss': 0.12482963224677625, 'mode': 'train', 'accuracy': 0.95594, 'epoch': 46} 92 | {'loss': 0.44605165663038876, 'mode': 'test', 'accuracy': 0.8673, 'epoch': 46} 93 | {'loss': 0.12752690182908277, 'mode': 'train', 'accuracy': 0.95468, 'epoch': 47} 94 | {'loss': 0.31893871188353584, 'mode': 'test', 'accuracy': 0.8983, 'epoch': 47} 95 | {'loss': 0.11698215041319131, 'mode': 'train', 'accuracy': 0.96002, 'epoch': 48} 96 | {'loss': 0.3625094685584876, 'mode': 'test', 'accuracy': 0.8917, 'epoch': 48} 97 | {'loss': 0.12403058114907006, 'mode': 'train', 'accuracy': 0.95662, 'epoch': 49} 98 | {'loss': 0.3296422512762864, 'mode': 'test', 'accuracy': 0.8969, 'epoch': 49} 99 | {'loss': 0.1272525352371089, 'mode': 'train', 'accuracy': 0.95534, 'epoch': 50} 100 | {'loss': 0.3386812239980241, 'mode': 'test', 'accuracy': 0.895, 'epoch': 50} 101 | {'loss': 0.11568249926409305, 'mode': 'train', 'accuracy': 0.95972, 'epoch': 51} 102 | {'loss': 0.3947831172804543, 'mode': 'test', 'accuracy': 0.8767, 'epoch': 51} 103 | {'loss': 0.12011221469477616, 'mode': 'train', 'accuracy': 0.95718, 'epoch': 52} 104 | {'loss': 0.3793963777601337, 'mode': 'test', 'accuracy': 0.887, 'epoch': 52} 105 | {'loss': 0.11552935730084742, 'mode': 'train', 'accuracy': 0.9598, 'epoch': 53} 106 | {'loss': 0.32679434561995174, 'mode': 'test', 'accuracy': 0.897, 'epoch': 53} 107 | {'loss': 0.1209672183189015, 'mode': 'train', 'accuracy': 0.9573, 'epoch': 54} 108 | {'loss': 0.3719159721568893, 'mode': 'test', 'accuracy': 0.8864, 'epoch': 54} 109 | {'loss': 0.11716402408755039, 'mode': 'train', 'accuracy': 0.95936, 'epoch': 55} 110 | {'loss': 0.36333953593946555, 'mode': 'test', 'accuracy': 0.8887, 'epoch': 55} 111 | {'loss': 0.11315495698996218, 'mode': 'train', 'accuracy': 0.96074, 'epoch': 56} 112 | 
{'loss': 0.4135039511379924, 'mode': 'test', 'accuracy': 0.8795, 'epoch': 56} 113 | {'loss': 0.12246145337076911, 'mode': 'train', 'accuracy': 0.95652, 'epoch': 57} 114 | {'loss': 0.36573378843771404, 'mode': 'test', 'accuracy': 0.895, 'epoch': 57} 115 | {'loss': 0.11873567865594557, 'mode': 'train', 'accuracy': 0.9589, 'epoch': 58} 116 | {'loss': 0.32789962482490376, 'mode': 'test', 'accuracy': 0.8985, 'epoch': 58} 117 | {'loss': 0.1111869095107707, 'mode': 'train', 'accuracy': 0.96088, 'epoch': 59} 118 | {'loss': 0.31956303617946663, 'mode': 'test', 'accuracy': 0.9019, 'epoch': 59} 119 | {'loss': 0.11284812037235184, 'mode': 'train', 'accuracy': 0.96058, 'epoch': 60} 120 | {'loss': 0.41686087353214346, 'mode': 'test', 'accuracy': 0.8788, 'epoch': 60} 121 | {'loss': 0.11122333228854889, 'mode': 'train', 'accuracy': 0.96026, 'epoch': 61} 122 | {'loss': 0.32971022278070444, 'mode': 'test', 'accuracy': 0.9001, 'epoch': 61} 123 | {'loss': 0.11963782632423317, 'mode': 'train', 'accuracy': 0.95828, 'epoch': 62} 124 | {'loss': 0.3454684309993581, 'mode': 'test', 'accuracy': 0.8994, 'epoch': 62} 125 | {'loss': 0.11059534256739537, 'mode': 'train', 'accuracy': 0.96144, 'epoch': 63} 126 | {'loss': 0.31646555961127504, 'mode': 'test', 'accuracy': 0.9014, 'epoch': 63} 127 | {'loss': 0.11062525651986947, 'mode': 'train', 'accuracy': 0.96198, 'epoch': 64} 128 | {'loss': 0.31434019868540924, 'mode': 'test', 'accuracy': 0.9023, 'epoch': 64} 129 | {'loss': 0.11140347934325633, 'mode': 'train', 'accuracy': 0.96036, 'epoch': 65} 130 | {'loss': 0.35508800032222365, 'mode': 'test', 'accuracy': 0.8972, 'epoch': 65} 131 | {'loss': 0.11169958122722486, 'mode': 'train', 'accuracy': 0.96086, 'epoch': 66} 132 | {'loss': 0.3512634305864763, 'mode': 'test', 'accuracy': 0.8917, 'epoch': 66} 133 | {'loss': 0.10244222205427597, 'mode': 'train', 'accuracy': 0.96394, 'epoch': 67} 134 | {'loss': 0.31726920263023156, 'mode': 'test', 'accuracy': 0.9007, 'epoch': 67} 135 | {'loss': 0.1143791249350591, 
'mode': 'train', 'accuracy': 0.95992, 'epoch': 68} 136 | {'loss': 0.37649213256919456, 'mode': 'test', 'accuracy': 0.8858, 'epoch': 68} 137 | {'loss': 0.10819580082488638, 'mode': 'train', 'accuracy': 0.96228, 'epoch': 69} 138 | {'loss': 0.36308417421807143, 'mode': 'test', 'accuracy': 0.895, 'epoch': 69} 139 | {'loss': 0.107644793227353, 'mode': 'train', 'accuracy': 0.96276, 'epoch': 70} 140 | {'loss': 0.3798946409848085, 'mode': 'test', 'accuracy': 0.8897, 'epoch': 70} 141 | {'loss': 0.11175771160503786, 'mode': 'train', 'accuracy': 0.9597, 'epoch': 71} 142 | {'loss': 0.3433622278178194, 'mode': 'test', 'accuracy': 0.8938, 'epoch': 71} 143 | {'loss': 0.10451170038003152, 'mode': 'train', 'accuracy': 0.96332, 'epoch': 72} 144 | {'loss': 0.34372054244492484, 'mode': 'test', 'accuracy': 0.8974, 'epoch': 72} 145 | {'loss': 0.10656377096729519, 'mode': 'train', 'accuracy': 0.96342, 'epoch': 73} 146 | {'loss': 0.40727046791725086, 'mode': 'test', 'accuracy': 0.8828, 'epoch': 73} 147 | {'loss': 0.10515571675260958, 'mode': 'train', 'accuracy': 0.96414, 'epoch': 74} 148 | {'loss': 0.3458297490409226, 'mode': 'test', 'accuracy': 0.8928, 'epoch': 74} 149 | {'loss': 0.1191441372580006, 'mode': 'train', 'accuracy': 0.95764, 'epoch': 75} 150 | {'loss': 0.3132177760741512, 'mode': 'test', 'accuracy': 0.9047, 'epoch': 75} 151 | {'loss': 0.09833235890530717, 'mode': 'train', 'accuracy': 0.96574, 'epoch': 76} 152 | {'loss': 0.388969161090957, 'mode': 'test', 'accuracy': 0.8909, 'epoch': 76} 153 | {'loss': 0.110223984412487, 'mode': 'train', 'accuracy': 0.96114, 'epoch': 77} 154 | {'loss': 0.4030241401047463, 'mode': 'test', 'accuracy': 0.8849, 'epoch': 77} 155 | {'loss': 0.10254152673189432, 'mode': 'train', 'accuracy': 0.9643, 'epoch': 78} 156 | {'loss': 0.36017628640505905, 'mode': 'test', 'accuracy': 0.8902, 'epoch': 78} 157 | {'loss': 0.10229657004918452, 'mode': 'train', 'accuracy': 0.96506, 'epoch': 79} 158 | {'loss': 0.33533618876793575, 'mode': 'test', 'accuracy': 0.8989, 
'epoch': 79} 159 | {'loss': 0.10514314044886229, 'mode': 'train', 'accuracy': 0.9631, 'epoch': 80} 160 | {'loss': 0.32156302969736655, 'mode': 'test', 'accuracy': 0.8998, 'epoch': 80} 161 | {'loss': 0.03282824413531726, 'mode': 'train', 'accuracy': 0.99072, 'epoch': 81} 162 | {'loss': 0.20515715065085965, 'mode': 'test', 'accuracy': 0.9339, 'epoch': 81} 163 | {'loss': 0.01350882985269474, 'mode': 'train', 'accuracy': 0.9978, 'epoch': 82} 164 | {'loss': 0.19697249559744914, 'mode': 'test', 'accuracy': 0.9372, 'epoch': 82} 165 | {'loss': 0.008524119563381696, 'mode': 'train', 'accuracy': 0.9991, 'epoch': 83} 166 | {'loss': 0.19931975408059782, 'mode': 'test', 'accuracy': 0.9392, 'epoch': 83} 167 | {'loss': 0.0066426828541719125, 'mode': 'train', 'accuracy': 0.99944, 'epoch': 84} 168 | {'loss': 0.2015725691939235, 'mode': 'test', 'accuracy': 0.9384, 'epoch': 84} 169 | {'loss': 0.005650836779066665, 'mode': 'train', 'accuracy': 0.9996, 'epoch': 85} 170 | {'loss': 0.20040988993303033, 'mode': 'test', 'accuracy': 0.9402, 'epoch': 85} 171 | {'loss': 0.0046106673910489785, 'mode': 'train', 'accuracy': 0.99968, 'epoch': 86} 172 | {'loss': 0.19892071377319898, 'mode': 'test', 'accuracy': 0.9407, 'epoch': 86} 173 | {'loss': 0.0038957311998090486, 'mode': 'train', 'accuracy': 0.99986, 'epoch': 87} 174 | {'loss': 0.19834516195070215, 'mode': 'test', 'accuracy': 0.9406, 'epoch': 87} 175 | {'loss': 0.0033843571896595673, 'mode': 'train', 'accuracy': 0.99982, 'epoch': 88} 176 | {'loss': 0.20265623044436146, 'mode': 'test', 'accuracy': 0.9402, 'epoch': 88} 177 | {'loss': 0.0029600172320290307, 'mode': 'train', 'accuracy': 0.99994, 'epoch': 89} 178 | {'loss': 0.20081229363182546, 'mode': 'test', 'accuracy': 0.9412, 'epoch': 89} 179 | {'loss': 0.0027051162727348647, 'mode': 'train', 'accuracy': 0.99998, 'epoch': 90} 180 | {'loss': 0.20215579528053107, 'mode': 'test', 'accuracy': 0.9418, 'epoch': 90} 181 | {'loss': 0.0024923448310331313, 'mode': 'train', 'accuracy': 1.0, 'epoch': 91} 
182 | {'loss': 0.20225652226596882, 'mode': 'test', 'accuracy': 0.9407, 'epoch': 91} 183 | {'loss': 0.0024084903280753313, 'mode': 'train', 'accuracy': 0.99996, 'epoch': 92} 184 | {'loss': 0.20304609695152873, 'mode': 'test', 'accuracy': 0.9414, 'epoch': 92} 185 | {'loss': 0.0023134369831865724, 'mode': 'train', 'accuracy': 1.0, 'epoch': 93} 186 | {'loss': 0.20222375193124362, 'mode': 'test', 'accuracy': 0.9414, 'epoch': 93} 187 | {'loss': 0.0021321029042648864, 'mode': 'train', 'accuracy': 1.0, 'epoch': 94} 188 | {'loss': 0.20504200764617356, 'mode': 'test', 'accuracy': 0.9414, 'epoch': 94} 189 | {'loss': 0.0019849520224287064, 'mode': 'train', 'accuracy': 1.0, 'epoch': 95} 190 | {'loss': 0.19992726236866534, 'mode': 'test', 'accuracy': 0.9418, 'epoch': 95} 191 | {'loss': 0.0019801507330001766, 'mode': 'train', 'accuracy': 1.0, 'epoch': 96} 192 | {'loss': 0.1987796726452696, 'mode': 'test', 'accuracy': 0.9424, 'epoch': 96} 193 | {'loss': 0.0019010657068256227, 'mode': 'train', 'accuracy': 1.0, 'epoch': 97} 194 | {'loss': 0.20633137987772357, 'mode': 'test', 'accuracy': 0.9434, 'epoch': 97} 195 | {'loss': 0.0018241628242270718, 'mode': 'train', 'accuracy': 1.0, 'epoch': 98} 196 | {'loss': 0.20270946211401053, 'mode': 'test', 'accuracy': 0.9411, 'epoch': 98} 197 | {'loss': 0.0018041755603936025, 'mode': 'train', 'accuracy': 0.99998, 'epoch': 99} 198 | {'loss': 0.1996593784991724, 'mode': 'test', 'accuracy': 0.9426, 'epoch': 99} 199 | {'loss': 0.0016947369307965567, 'mode': 'train', 'accuracy': 1.0, 'epoch': 100} 200 | {'loss': 0.19792235255905777, 'mode': 'test', 'accuracy': 0.9429, 'epoch': 100} 201 | -------------------------------------------------------------------------------- /state/state_w_block.txt: -------------------------------------------------------------------------------- 1 | {'loss': 1.6878023418928965, 'epoch': 1, 'accuracy': 0.39366, 'mode': 'train'} 2 | {'loss': 1.4180713436406138, 'epoch': 1, 'accuracy': 0.4977, 'mode': 'test'} 3 | {'loss': 
1.1537952408613763, 'epoch': 2, 'accuracy': 0.58558, 'mode': 'train'} 4 | {'loss': 1.041196934736459, 'epoch': 2, 'accuracy': 0.6441, 'mode': 'test'} 5 | {'loss': 0.8841170792079638, 'epoch': 3, 'accuracy': 0.68794, 'mode': 'train'} 6 | {'loss': 11.027340524895173, 'epoch': 3, 'accuracy': 0.6566, 'mode': 'test'} 7 | {'loss': 0.7457902899102475, 'epoch': 4, 'accuracy': 0.73874, 'mode': 'train'} 8 | {'loss': 0.7769547623054238, 'epoch': 4, 'accuracy': 0.7398, 'mode': 'test'} 9 | {'loss': 0.6367647716456356, 'epoch': 5, 'accuracy': 0.77632, 'mode': 'train'} 10 | {'loss': 0.9578903011835301, 'epoch': 5, 'accuracy': 0.7441, 'mode': 'test'} 11 | {'loss': 0.5637520943075197, 'epoch': 6, 'accuracy': 0.80326, 'mode': 'train'} 12 | {'loss': 0.8507674080171401, 'epoch': 6, 'accuracy': 0.7581, 'mode': 'test'} 13 | {'loss': 0.5041318567436374, 'epoch': 7, 'accuracy': 0.82476, 'mode': 'train'} 14 | {'loss': 0.732868454828384, 'epoch': 7, 'accuracy': 0.7968, 'mode': 'test'} 15 | {'loss': 0.44785956613471695, 'epoch': 8, 'accuracy': 0.84272, 'mode': 'train'} 16 | {'loss': 0.5479633343067895, 'epoch': 8, 'accuracy': 0.8169, 'mode': 'test'} 17 | {'loss': 0.40654291252574626, 'epoch': 9, 'accuracy': 0.85884, 'mode': 'train'} 18 | {'loss': 0.5620520483156685, 'epoch': 9, 'accuracy': 0.8221, 'mode': 'test'} 19 | {'loss': 0.3716294059286947, 'epoch': 10, 'accuracy': 0.8713, 'mode': 'train'} 20 | {'loss': 0.5570676779481257, 'epoch': 10, 'accuracy': 0.824, 'mode': 'test'} 21 | {'loss': 0.34448742169096025, 'epoch': 11, 'accuracy': 0.88036, 'mode': 'train'} 22 | {'loss': 0.8140175183107892, 'epoch': 11, 'accuracy': 0.7746, 'mode': 'test'} 23 | {'loss': 0.32498678763199307, 'epoch': 12, 'accuracy': 0.88762, 'mode': 'train'} 24 | {'loss': 0.5350769910083453, 'epoch': 12, 'accuracy': 0.8411, 'mode': 'test'} 25 | {'loss': 0.29884508675169175, 'epoch': 13, 'accuracy': 0.89606, 'mode': 'train'} 26 | {'loss': 1.0538866677481655, 'epoch': 13, 'accuracy': 0.8162, 'mode': 'test'} 27 | {'loss': 
0.28882286142643593, 'epoch': 14, 'accuracy': 0.89938, 'mode': 'train'} 28 | {'loss': 0.6080745608563638, 'epoch': 14, 'accuracy': 0.8125, 'mode': 'test'} 29 | {'loss': 0.2737785367023608, 'epoch': 15, 'accuracy': 0.9028, 'mode': 'train'} 30 | {'loss': 0.43266354567685733, 'epoch': 15, 'accuracy': 0.8521, 'mode': 'test'} 31 | {'loss': 0.25536154588813087, 'epoch': 16, 'accuracy': 0.91162, 'mode': 'train'} 32 | {'loss': 0.457165577609068, 'epoch': 16, 'accuracy': 0.8507, 'mode': 'test'} 33 | {'loss': 0.2415769131249174, 'epoch': 17, 'accuracy': 0.91566, 'mode': 'train'} 34 | {'loss': 0.5411718308356157, 'epoch': 17, 'accuracy': 0.8316, 'mode': 'test'} 35 | {'loss': 0.2281929301979292, 'epoch': 18, 'accuracy': 0.91922, 'mode': 'train'} 36 | {'loss': 0.5226842206754502, 'epoch': 18, 'accuracy': 0.8287, 'mode': 'test'} 37 | {'loss': 0.22288499322369737, 'epoch': 19, 'accuracy': 0.92152, 'mode': 'train'} 38 | {'loss': 0.40844785578691284, 'epoch': 19, 'accuracy': 0.8667, 'mode': 'test'} 39 | {'loss': 0.20818384287073796, 'epoch': 20, 'accuracy': 0.92726, 'mode': 'train'} 40 | {'loss': 0.3998396033124561, 'epoch': 20, 'accuracy': 0.8685, 'mode': 'test'} 41 | {'loss': 0.2014645495359093, 'epoch': 21, 'accuracy': 0.92914, 'mode': 'train'} 42 | {'loss': 0.4011177501765784, 'epoch': 21, 'accuracy': 0.8709, 'mode': 'test'} 43 | {'loss': 0.19438265039659353, 'epoch': 22, 'accuracy': 0.9319, 'mode': 'train'} 44 | {'loss': 0.38241399127017156, 'epoch': 22, 'accuracy': 0.8727, 'mode': 'test'} 45 | {'loss': 0.1836819075943564, 'epoch': 23, 'accuracy': 0.93518, 'mode': 'train'} 46 | {'loss': 0.43507652307391953, 'epoch': 23, 'accuracy': 0.8666, 'mode': 'test'} 47 | {'loss': 0.18599094642454872, 'epoch': 24, 'accuracy': 0.93526, 'mode': 'train'} 48 | {'loss': 0.3897396937297408, 'epoch': 24, 'accuracy': 0.8733, 'mode': 'test'} 49 | {'loss': 0.18007229077046172, 'epoch': 25, 'accuracy': 0.93696, 'mode': 'train'} 50 | {'loss': 0.4833032493568529, 'epoch': 25, 'accuracy': 0.852, 
'mode': 'test'} 51 | {'loss': 0.17575007102564158, 'epoch': 26, 'accuracy': 0.93852, 'mode': 'train'} 52 | {'loss': 0.43059357858387515, 'epoch': 26, 'accuracy': 0.8675, 'mode': 'test'} 53 | {'loss': 0.17085249830618537, 'epoch': 27, 'accuracy': 0.9396, 'mode': 'train'} 54 | {'loss': 0.4035677901783568, 'epoch': 27, 'accuracy': 0.8708, 'mode': 'test'} 55 | {'loss': 0.16510102132816448, 'epoch': 28, 'accuracy': 0.94172, 'mode': 'train'} 56 | {'loss': 0.3759245786602332, 'epoch': 28, 'accuracy': 0.8768, 'mode': 'test'} 57 | {'loss': 0.1576653265148937, 'epoch': 29, 'accuracy': 0.94468, 'mode': 'train'} 58 | {'loss': 0.38272928641100634, 'epoch': 29, 'accuracy': 0.8774, 'mode': 'test'} 59 | {'loss': 0.16132679469216524, 'epoch': 30, 'accuracy': 0.94344, 'mode': 'train'} 60 | {'loss': 0.38974595031920495, 'epoch': 30, 'accuracy': 0.8774, 'mode': 'test'} 61 | {'loss': 0.156163436910876, 'epoch': 31, 'accuracy': 0.94556, 'mode': 'train'} 62 | {'loss': 0.39439939057371426, 'epoch': 31, 'accuracy': 0.8767, 'mode': 'test'} 63 | {'loss': 0.15216027187836131, 'epoch': 32, 'accuracy': 0.94668, 'mode': 'train'} 64 | {'loss': 0.35250907453002456, 'epoch': 32, 'accuracy': 0.8852, 'mode': 'test'} 65 | {'loss': 0.15179247764484663, 'epoch': 33, 'accuracy': 0.94728, 'mode': 'train'} 66 | {'loss': 0.4489999884253094, 'epoch': 33, 'accuracy': 0.8617, 'mode': 'test'} 67 | {'loss': 0.15534332335528808, 'epoch': 34, 'accuracy': 0.94518, 'mode': 'train'} 68 | {'loss': 0.45389465384992067, 'epoch': 34, 'accuracy': 0.8641, 'mode': 'test'} 69 | {'loss': 0.14483047370105756, 'epoch': 35, 'accuracy': 0.94972, 'mode': 'train'} 70 | {'loss': 0.36713146741033365, 'epoch': 35, 'accuracy': 0.8839, 'mode': 'test'} 71 | {'loss': 0.1465595526539759, 'epoch': 36, 'accuracy': 0.94834, 'mode': 'train'} 72 | {'loss': 0.3612849863281675, 'epoch': 36, 'accuracy': 0.884, 'mode': 'test'} 73 | {'loss': 0.1488111440444845, 'epoch': 37, 'accuracy': 0.94648, 'mode': 'train'} 74 | {'loss': 0.3958049730700292, 
'epoch': 37, 'accuracy': 0.8784, 'mode': 'test'} 75 | {'loss': 0.14151994888301545, 'epoch': 38, 'accuracy': 0.94962, 'mode': 'train'} 76 | {'loss': 0.3892430450031712, 'epoch': 38, 'accuracy': 0.8803, 'mode': 'test'} 77 | {'loss': 0.1447155995537408, 'epoch': 39, 'accuracy': 0.9489, 'mode': 'train'} 78 | {'loss': 0.3529704438558049, 'epoch': 39, 'accuracy': 0.889, 'mode': 'test'} 79 | {'loss': 0.14147455729258337, 'epoch': 40, 'accuracy': 0.9499, 'mode': 'train'} 80 | {'loss': 0.3683332359525049, 'epoch': 40, 'accuracy': 0.8868, 'mode': 'test'} 81 | {'loss': 0.13488899756346825, 'epoch': 41, 'accuracy': 0.953, 'mode': 'train'} 82 | {'loss': 0.41129933658299156, 'epoch': 41, 'accuracy': 0.879, 'mode': 'test'} 83 | {'loss': 0.13718361270320995, 'epoch': 42, 'accuracy': 0.95218, 'mode': 'train'} 84 | {'loss': 0.34767972189150026, 'epoch': 42, 'accuracy': 0.8918, 'mode': 'test'} 85 | {'loss': 0.14174377004546904, 'epoch': 43, 'accuracy': 0.95058, 'mode': 'train'} 86 | {'loss': 0.3832332724408739, 'epoch': 43, 'accuracy': 0.8779, 'mode': 'test'} 87 | {'loss': 0.13233132431726627, 'epoch': 44, 'accuracy': 0.95418, 'mode': 'train'} 88 | {'loss': 0.34930924212287195, 'epoch': 44, 'accuracy': 0.8901, 'mode': 'test'} 89 | {'loss': 0.13750159753786634, 'epoch': 45, 'accuracy': 0.9513, 'mode': 'train'} 90 | {'loss': 0.4319567603005726, 'epoch': 45, 'accuracy': 0.87, 'mode': 'test'} 91 | {'loss': 0.12734805751601458, 'epoch': 46, 'accuracy': 0.95502, 'mode': 'train'} 92 | {'loss': 0.3530826668260964, 'epoch': 46, 'accuracy': 0.8889, 'mode': 'test'} 93 | {'loss': 0.13097982949403392, 'epoch': 47, 'accuracy': 0.95434, 'mode': 'train'} 94 | {'loss': 0.457308367464193, 'epoch': 47, 'accuracy': 0.8651, 'mode': 'test'} 95 | {'loss': 0.12688765973996016, 'epoch': 48, 'accuracy': 0.95584, 'mode': 'train'} 96 | {'loss': 0.3597108416116922, 'epoch': 48, 'accuracy': 0.8855, 'mode': 'test'} 97 | {'loss': 0.1351787410343012, 'epoch': 49, 'accuracy': 0.95264, 'mode': 'train'} 98 | {'loss': 
0.36300143006311103, 'epoch': 49, 'accuracy': 0.8867, 'mode': 'test'} 99 | {'loss': 0.1208075447260495, 'epoch': 50, 'accuracy': 0.95824, 'mode': 'train'} 100 | {'loss': 0.3568245654652833, 'epoch': 50, 'accuracy': 0.8897, 'mode': 'test'} 101 | {'loss': 0.13121635812665797, 'epoch': 51, 'accuracy': 0.95376, 'mode': 'train'} 102 | {'loss': 0.44519333484446183, 'epoch': 51, 'accuracy': 0.8689, 'mode': 'test'} 103 | {'loss': 0.1257994194345935, 'epoch': 52, 'accuracy': 0.95618, 'mode': 'train'} 104 | {'loss': 0.4441747354094389, 'epoch': 52, 'accuracy': 0.869, 'mode': 'test'} 105 | {'loss': 0.12502273627559224, 'epoch': 53, 'accuracy': 0.95664, 'mode': 'train'} 106 | {'loss': 0.391679892209685, 'epoch': 53, 'accuracy': 0.8811, 'mode': 'test'} 107 | {'loss': 0.1282738177748896, 'epoch': 54, 'accuracy': 0.95562, 'mode': 'train'} 108 | {'loss': 0.35270995689425494, 'epoch': 54, 'accuracy': 0.8886, 'mode': 'test'} 109 | {'loss': 0.12387080912185293, 'epoch': 55, 'accuracy': 0.95542, 'mode': 'train'} 110 | {'loss': 0.38550101766350925, 'epoch': 55, 'accuracy': 0.8869, 'mode': 'test'} 111 | {'loss': 0.117419524904331, 'epoch': 56, 'accuracy': 0.96008, 'mode': 'train'} 112 | {'loss': 0.34477526623352334, 'epoch': 56, 'accuracy': 0.894, 'mode': 'test'} 113 | {'loss': 0.11950218402173209, 'epoch': 57, 'accuracy': 0.95794, 'mode': 'train'} 114 | {'loss': 0.4892435942296012, 'epoch': 57, 'accuracy': 0.8623, 'mode': 'test'} 115 | {'loss': 0.11581262345890249, 'epoch': 58, 'accuracy': 0.95936, 'mode': 'train'} 116 | {'loss': 0.36563204480394446, 'epoch': 58, 'accuracy': 0.8846, 'mode': 'test'} 117 | {'loss': 0.11795573615376137, 'epoch': 59, 'accuracy': 0.95928, 'mode': 'train'} 118 | {'loss': 0.36611607129786433, 'epoch': 59, 'accuracy': 0.8889, 'mode': 'test'} 119 | {'loss': 0.11554457514982698, 'epoch': 60, 'accuracy': 0.95958, 'mode': 'train'} 120 | {'loss': 0.38289556429264626, 'epoch': 60, 'accuracy': 0.8855, 'mode': 'test'} 121 | {'loss': 0.1196208388956688, 'epoch': 61, 
'accuracy': 0.95886, 'mode': 'train'} 122 | {'loss': 0.3806887177430141, 'epoch': 61, 'accuracy': 0.8889, 'mode': 'test'} 123 | {'loss': 0.1171161580040021, 'epoch': 62, 'accuracy': 0.95878, 'mode': 'train'} 124 | {'loss': 0.36984551612548777, 'epoch': 62, 'accuracy': 0.8837, 'mode': 'test'} 125 | {'loss': 0.11517268386395538, 'epoch': 63, 'accuracy': 0.96018, 'mode': 'train'} 126 | {'loss': 0.45217171491710995, 'epoch': 63, 'accuracy': 0.8704, 'mode': 'test'} 127 | {'loss': 0.11734662872388989, 'epoch': 64, 'accuracy': 0.95912, 'mode': 'train'} 128 | {'loss': 0.3764185043182342, 'epoch': 64, 'accuracy': 0.8885, 'mode': 'test'} 129 | {'loss': 0.11715841781147911, 'epoch': 65, 'accuracy': 0.9591, 'mode': 'train'} 130 | {'loss': 0.3703689343610385, 'epoch': 65, 'accuracy': 0.8878, 'mode': 'test'} 131 | {'loss': 0.11633476219556814, 'epoch': 66, 'accuracy': 0.95976, 'mode': 'train'} 132 | {'loss': 0.34874593438046775, 'epoch': 66, 'accuracy': 0.8905, 'mode': 'test'} 133 | {'loss': 0.11482773278661222, 'epoch': 67, 'accuracy': 0.95944, 'mode': 'train'} 134 | {'loss': 0.4132755638877296, 'epoch': 67, 'accuracy': 0.878, 'mode': 'test'} 135 | {'loss': 0.11674660698170088, 'epoch': 68, 'accuracy': 0.9592, 'mode': 'train'} 136 | {'loss': 0.3562072636025726, 'epoch': 68, 'accuracy': 0.8883, 'mode': 'test'} 137 | {'loss': 0.1181230912428073, 'epoch': 69, 'accuracy': 0.95822, 'mode': 'train'} 138 | {'loss': 0.31114418163990515, 'epoch': 69, 'accuracy': 0.9011, 'mode': 'test'} 139 | {'loss': 0.11711958902018604, 'epoch': 70, 'accuracy': 0.95842, 'mode': 'train'} 140 | {'loss': 0.4510415825209801, 'epoch': 70, 'accuracy': 0.8697, 'mode': 'test'} 141 | {'loss': 0.10480988568738293, 'epoch': 71, 'accuracy': 0.96416, 'mode': 'train'} 142 | {'loss': 0.3519335717057726, 'epoch': 71, 'accuracy': 0.8928, 'mode': 'test'} 143 | {'loss': 0.11741906517873638, 'epoch': 72, 'accuracy': 0.9587, 'mode': 'train'} 144 | {'loss': 0.3859244231490572, 'epoch': 72, 'accuracy': 0.8857, 'mode': 
'test'} 145 | {'loss': 0.1126188402876373, 'epoch': 73, 'accuracy': 0.96104, 'mode': 'train'} 146 | {'loss': 0.37108120701874886, 'epoch': 73, 'accuracy': 0.8888, 'mode': 'test'} 147 | {'loss': 0.10980939855108333, 'epoch': 74, 'accuracy': 0.96062, 'mode': 'train'} 148 | {'loss': 0.33508176147747937, 'epoch': 74, 'accuracy': 0.8975, 'mode': 'test'} 149 | {'loss': 0.1066611345352419, 'epoch': 75, 'accuracy': 0.96298, 'mode': 'train'} 150 | {'loss': 0.33882362125026216, 'epoch': 75, 'accuracy': 0.8963, 'mode': 'test'} 151 | {'loss': 0.11316704521875089, 'epoch': 76, 'accuracy': 0.96036, 'mode': 'train'} 152 | {'loss': 0.39266691617904936, 'epoch': 76, 'accuracy': 0.8859, 'mode': 'test'} 153 | {'loss': 0.10993746015459988, 'epoch': 77, 'accuracy': 0.9619, 'mode': 'train'} 154 | {'loss': 0.3893682504915126, 'epoch': 77, 'accuracy': 0.8833, 'mode': 'test'} 155 | {'loss': 0.11443343077359894, 'epoch': 78, 'accuracy': 0.95926, 'mode': 'train'} 156 | {'loss': 0.3712494020249435, 'epoch': 78, 'accuracy': 0.8835, 'mode': 'test'} 157 | {'loss': 0.10540790580536999, 'epoch': 79, 'accuracy': 0.9639, 'mode': 'train'} 158 | {'loss': 0.40984857618618925, 'epoch': 79, 'accuracy': 0.8851, 'mode': 'test'} 159 | {'loss': 0.10419517668330913, 'epoch': 80, 'accuracy': 0.96382, 'mode': 'train'} 160 | {'loss': 0.3997493417589526, 'epoch': 80, 'accuracy': 0.8834, 'mode': 'test'} 161 | {'loss': 0.038183429642863934, 'epoch': 81, 'accuracy': 0.98824, 'mode': 'train'} 162 | {'loss': 0.22695207137875495, 'epoch': 81, 'accuracy': 0.929, 'mode': 'test'} 163 | {'loss': 0.014524630792534278, 'epoch': 82, 'accuracy': 0.99738, 'mode': 'train'} 164 | {'loss': 0.22414740820409382, 'epoch': 82, 'accuracy': 0.933, 'mode': 'test'} 165 | {'loss': 0.009587154244942121, 'epoch': 83, 'accuracy': 0.99886, 'mode': 'train'} 166 | {'loss': 0.22827114174320448, 'epoch': 83, 'accuracy': 0.9335, 'mode': 'test'} 167 | {'loss': 0.0067085563955480735, 'epoch': 84, 'accuracy': 0.9994, 'mode': 'train'} 168 | {'loss': 
0.22448316895088583, 'epoch': 84, 'accuracy': 0.9354, 'mode': 'test'} 169 | {'loss': 0.005805204586719002, 'epoch': 85, 'accuracy': 0.99954, 'mode': 'train'} 170 | {'loss': 0.2266250274078861, 'epoch': 85, 'accuracy': 0.9348, 'mode': 'test'} 171 | {'loss': 0.004691207176431665, 'epoch': 86, 'accuracy': 0.99972, 'mode': 'train'} 172 | {'loss': 0.22724492684196504, 'epoch': 86, 'accuracy': 0.9362, 'mode': 'test'} 173 | {'loss': 0.004005829934650066, 'epoch': 87, 'accuracy': 0.99976, 'mode': 'train'} 174 | {'loss': 0.23144147746786947, 'epoch': 87, 'accuracy': 0.9352, 'mode': 'test'} 175 | {'loss': 0.0035704533948236724, 'epoch': 88, 'accuracy': 0.9998, 'mode': 'train'} 176 | {'loss': 0.23236922381125447, 'epoch': 88, 'accuracy': 0.9351, 'mode': 'test'} 177 | {'loss': 0.0031323154335436577, 'epoch': 89, 'accuracy': 0.99986, 'mode': 'train'} 178 | {'loss': 0.23092179458327355, 'epoch': 89, 'accuracy': 0.9362, 'mode': 'test'} 179 | {'loss': 0.002842128591235639, 'epoch': 90, 'accuracy': 0.99994, 'mode': 'train'} 180 | {'loss': 0.23246265698675134, 'epoch': 90, 'accuracy': 0.9373, 'mode': 'test'} 181 | {'loss': 0.0026545836340131014, 'epoch': 91, 'accuracy': 0.99996, 'mode': 'train'} 182 | {'loss': 0.23077612894640603, 'epoch': 91, 'accuracy': 0.9367, 'mode': 'test'} 183 | {'loss': 0.002389573927997324, 'epoch': 92, 'accuracy': 0.99998, 'mode': 'train'} 184 | {'loss': 0.23043886128409657, 'epoch': 92, 'accuracy': 0.9373, 'mode': 'test'} 185 | {'loss': 0.0022980364282493963, 'epoch': 93, 'accuracy': 1.0, 'mode': 'train'} 186 | {'loss': 0.23280814294792282, 'epoch': 93, 'accuracy': 0.9368, 'mode': 'test'} 187 | {'loss': 0.0021505535954176, 'epoch': 94, 'accuracy': 0.99994, 'mode': 'train'} 188 | {'loss': 0.23303929723467054, 'epoch': 94, 'accuracy': 0.9369, 'mode': 'test'} 189 | {'loss': 0.0020761114147389323, 'epoch': 95, 'accuracy': 0.99996, 'mode': 'train'} 190 | {'loss': 0.23053068881201894, 'epoch': 95, 'accuracy': 0.9373, 'mode': 'test'} 191 | {'loss': 
0.0019166785509080221, 'epoch': 96, 'accuracy': 1.0, 'mode': 'train'} 192 | {'loss': 0.22991067382275676, 'epoch': 96, 'accuracy': 0.9375, 'mode': 'test'} 193 | {'loss': 0.0019486815861576355, 'epoch': 97, 'accuracy': 0.99994, 'mode': 'train'} 194 | {'loss': 0.23197510425642034, 'epoch': 97, 'accuracy': 0.9371, 'mode': 'test'} 195 | {'loss': 0.001756970415754089, 'epoch': 98, 'accuracy': 1.0, 'mode': 'train'} 196 | {'loss': 0.2303124351581191, 'epoch': 98, 'accuracy': 0.9381, 'mode': 'test'} 197 | {'loss': 0.0017600434896586198, 'epoch': 99, 'accuracy': 1.0, 'mode': 'train'} 198 | {'loss': 0.23283885224799455, 'epoch': 99, 'accuracy': 0.9363, 'mode': 'test'} 199 | {'loss': 0.0017902276066639215, 'epoch': 100, 'accuracy': 0.99998, 'mode': 'train'} 200 | {'loss': 0.23165485941490543, 'epoch': 100, 'accuracy': 0.9375, 'mode': 'test'} -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | class Trainer(object): 8 | cuda = torch.cuda.is_available() 9 | torch.backends.cudnn.benchmark = True 10 | 11 | def __init__(self, model, optimizer, loss_f, save_dir=None, save_freq=5): 12 | self.model = model 13 | if self.cuda: 14 | model.cuda() 15 | self.optimizer = optimizer 16 | self.loss_f = loss_f 17 | self.save_dir = save_dir 18 | self.save_freq = save_freq 19 | 20 | def _iteration(self, data_loader, is_train=True): 21 | loop_loss = [] 22 | accuracy = [] 23 | for data, target in tqdm(data_loader, ncols=80): 24 | if self.cuda: 25 | data, target = data.cuda(), target.cuda() 26 | output = self.model(data) 27 | loss = self.loss_f(output, target) 28 | loop_loss.append(loss.data.item() / len(data_loader)) 29 | accuracy.append((output.data.max(1)[1] == target.data).sum().item()) 30 | if is_train: 31 | self.optimizer.zero_grad() 32 | loss.backward() 33 | self.optimizer.step() 34 | 
print() 35 | mode = "train" if is_train else "test" 36 | print(">>>[{}] loss: {:.4f}/accuracy: {:.4f}".format(mode, sum(loop_loss), sum(accuracy) / len(data_loader.dataset))) 37 | return mode, sum(loop_loss), sum(accuracy) / len(data_loader.dataset) 38 | 39 | def train(self, data_loader): 40 | self.model.train() 41 | with torch.enable_grad(): 42 | mode, loss, correct = self._iteration(data_loader) 43 | return mode, loss, correct 44 | 45 | def test(self, data_loader): 46 | self.model.eval() 47 | with torch.no_grad(): 48 | mode, loss, correct = self._iteration(data_loader, is_train=False) 49 | return mode, loss, correct 50 | 51 | def loop(self, epochs, train_data, test_data, scheduler=None): 52 | for ep in range(1, epochs + 1): 53 | if scheduler is not None: 54 | scheduler.step() 55 | print("epochs: {}".format(ep)) 56 | # save statistics into txt file 57 | self.save_statistic(*((ep,) + self.train(train_data))) 58 | self.save_statistic(*((ep,) + self.test(test_data))) 59 | if ep % self.save_freq == 0: 60 | self.save(ep) 61 | 62 | def save(self, epoch, **kwargs): 63 | if self.save_dir is not None: 64 | model_out_path = Path(self.save_dir) 65 | state = {"epoch": epoch, "net_state_dict": self.model.state_dict()} 66 | if not model_out_path.exists(): 67 | model_out_path.mkdir() 68 | torch.save(state, model_out_path / "model_epoch_{}.ckpt".format(epoch)) 69 | 70 | def save_statistic(self, epoch, mode, loss, accuracy): 71 | with open("state.txt", "a", encoding="utf-8") as f: 72 | f.write(str({"epoch": epoch, "mode": mode, "loss": loss, "accuracy": accuracy})) 73 | f.write("\n") 74 | -------------------------------------------------------------------------------- /visual/viz.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | 3 | 4 | def read_txt(path): 5 | with open(path, "r", encoding="utf-8") as file: 6 | content = [eval(line.replace("\n", "")) for line in file.readlines()] 7 | return content 8 | 9 | 
def get_train_data(data): 10 | assert isinstance(data, list), "``data`` should be list type" 11 | epoch = [] 12 | acc = [] 13 | loss = [] 14 | for line in data: 15 | if line["mode"] == "train": 16 | epoch.append(line["epoch"]) 17 | acc.append(line["accuracy"]) 18 | loss.append(line["loss"]) 19 | return {"epoch": epoch, "accuracy": acc, "loss": loss} 20 | 21 | def get_val_data(data): 22 | assert isinstance(data, list), "``data`` should be ``list`` type" 23 | epoch = [] 24 | acc = [] 25 | loss = [] 26 | for line in data: 27 | if line["mode"] == "test": 28 | epoch.append(line["epoch"]) 29 | acc.append(line["accuracy"]) 30 | loss.append(line["loss"]) 31 | return {"epoch": epoch, "accuracy": acc, "loss": loss} 32 | 33 | def get_best_val_acc(data): 34 | assert isinstance(data, list), "``data`` must be ``list`` type" 35 | return max(data) 36 | 37 | def show(data): 38 | assert isinstance(data, dict), "``data`` should be ``dict`` type" 39 | _, axs = plt.subplots(1, 2,figsize=(16, 4), sharey=False) 40 | axs[0].plot(data["epoch"], data["accuracy"]) 41 | axs[0].set_title("accuracy") 42 | 43 | axs[1].plot(data["epoch"], data["loss"]) 44 | axs[1].set_title("loss") 45 | plt.show() 46 | 47 | if __name__ == "__main__": 48 | path = "F:/mixed-densenet/state/state_full_in_loop.txt" 49 | data = read_txt(path) 50 | train_data = get_train_data(data) 51 | print("best train accuracy:") 52 | print(get_best_val_acc(train_data["accuracy"])) 53 | val_data = get_val_data(data) 54 | print("best val accuracy:") 55 | print(get_best_val_acc(val_data["accuracy"])) 56 | show(train_data) 57 | show(val_data) 58 | -------------------------------------------------------------------------------- /weights/save_module_weights_at_here.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yy9568/SE_DenseNet/6a4d218cfb2fb7a6437339f9c1a4a0bb5ed4ce93/weights/save_module_weights_at_here.md 
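`read_txt` in `visual/viz.py` rebuilds each logged dict with `eval`, which will execute whatever happens to be in the file. For these one-dict-per-line records, the standard library's `ast.literal_eval` is a safer drop-in, since it only accepts plain Python literals. A minimal sketch — the sample lines are copied from the epoch-98 entries in the training log above, and `parse_state_lines`/`best_accuracy` are illustrative helpers, not functions from this repo:

```python
import ast

# Two records in the one-dict-per-line format that Trainer.save_statistic
# writes to state.txt (values taken from the epoch-98 log entries above).
sample_lines = [
    "{'loss': 0.001756970415754089, 'epoch': 98, 'accuracy': 1.0, 'mode': 'train'}",
    "{'loss': 0.2303124351581191, 'epoch': 98, 'accuracy': 0.9381, 'mode': 'test'}",
]

def parse_state_lines(lines):
    # literal_eval accepts only Python literals, so a corrupted or
    # malicious line raises an error instead of executing code.
    return [ast.literal_eval(line.strip()) for line in lines]

def best_accuracy(records, mode):
    # Return (epoch, accuracy) of the best epoch for the given mode.
    best = max((r for r in records if r["mode"] == mode),
               key=lambda r: r["accuracy"])
    return best["epoch"], best["accuracy"]

records = parse_state_lines(sample_lines)
print(best_accuracy(records, "test"))  # (98, 0.9381)
```

The same `parse_state_lines` could replace the `eval`-based list comprehension in `read_txt` without changing anything downstream, since both return the same list of dicts.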
--------------------------------------------------------------------------------
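`Trainer._iteration` in `utils.py` aggregates its two metrics differently: each batch loss is divided by the number of batches and summed (an incremental batch average), while accuracy is the total count of correct predictions divided by the dataset size. A torch-free sketch of that bookkeeping, using made-up batch results (`batch_losses` and `correct_counts` are hypothetical values, not from the repo):

```python
# Hypothetical results for a 10-sample dataset split into 2 batches:
# the mean loss of each batch, and the correct predictions per batch.
batch_losses = [0.5, 0.75]
correct_counts = [4, 5]
num_batches = len(batch_losses)
dataset_size = 10

# Same scheme as Trainer._iteration: dividing every batch loss by the
# batch count and summing yields the mean of the batch losses...
loop_loss = sum(loss / num_batches for loss in batch_losses)
# ...while accuracy is computed over the whole dataset, not per batch.
accuracy = sum(correct_counts) / dataset_size

print(loop_loss)  # 0.625
print(accuracy)   # 0.9
```

Note the batch-averaged loss only equals the true per-sample mean when every batch has the same size; with a ragged final batch it is a close approximation, which is usually acceptable for the logged curves.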