├── models
│   ├── __init__.py
│   ├── resnet_26.py
│   ├── wide_resnet.py
│   ├── experimental.py
│   ├── yolo.py
│   └── common.py
├── utils
│   ├── __init__.py
│   ├── testing.py
│   ├── rotation.py
│   ├── training.py
│   ├── google_utils.py
│   ├── results_manager.py
│   ├── autoanchor.py
│   ├── loss.py
│   ├── testing_yolov3.py
│   ├── metrics.py
│   ├── utils.py
│   ├── torch_utils.py
│   ├── data_loader.py
│   └── plots.py
├── .gitignore
├── methods
│   ├── __init__.py
│   └── dua.py
├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── DUA.iml
├── globals.py
├── requirements.txt
├── readme
│   ├── preparing_datasets.md
│   └── directory_scructures.md
├── README.md
├── config.py
├── init.py
└── main.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 |
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from .dua import dua
2 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/globals.py:
--------------------------------------------------------------------------------
1 |
2 | # Variables here are initialized with example values and will be overwritten
3 | # based on command-line arguments.
4 | SEVERTITIES = ['5']
5 |
6 | ROBUSTNESS_SEVERITIES = ['5', '4', '3', '2', '1']
7 |
8 | KITTI_SEVERITIES = {
9 | 'fog': ['fog_30', 'fog_40', 'fog_50'],
10 | 'rain': ['200mm', '100mm', '75mm'],
11 | 'snow': ['5', '5', '5']
12 | }
13 |
14 | TASKS = []
15 |
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | colorama==0.3.7
2 | coremltools==6.1
3 | matplotlib==2.1.1
4 | numpy==1.13.3
5 | onnx==1.12.0
6 | pandas==0.22.0
7 | Pillow==9.3.0
8 | protobuf==4.21.9
9 | pycocotools==2.0.6
10 | PyYAML==6.0
11 | requests==2.18.4
12 | scipy==0.19.1
13 | seaborn==0.11.2
14 | setuptools==39.0.1
15 | thop==0.1.1.post2209072238
16 | torch==1.10.2+cu113
17 | torchvision==0.11.3+cu113
18 | tqdm==4.64.0
19 | wandb==0.13.5
20 |
--------------------------------------------------------------------------------
/utils/testing.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from utils.rotation import *
4 |
5 |
6 | def test(dataloader, model, sslabel=None):
7 | device = next(model.parameters()).device
8 | criterion = nn.CrossEntropyLoss(reduction='none').to(device)
9 | model.eval()
10 | correct = []
11 | losses = []
12 | for batch_idx, (inputs, labels) in enumerate(dataloader):
13 | if sslabel is not None:
14 | inputs, labels = rotate_batch(inputs, sslabel)
15 | inputs, labels = inputs.to(device), labels.to(device)
16 | with torch.no_grad():
17 | outputs = model(inputs)
18 | loss = criterion(outputs, labels)
19 | losses.append(loss.cpu())
20 | _, predicted = outputs.max(1)
21 | correct.append(predicted.eq(labels).cpu())
22 | correct = torch.cat(correct).numpy()
23 | losses = torch.cat(losses).numpy()
24 |     return 1 - correct.mean(), correct, losses
25 |
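26 | # test() returns (top-1 error rate, per-sample correctness mask, per-sample losses).
27 | # Passing sslabel ('rand', 'expand' or an int) evaluates the rotation pretext task
28 | # from utils/rotation.py instead of the original labels.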
--------------------------------------------------------------------------------
/readme/preparing_datasets.md:
--------------------------------------------------------------------------------
1 | # Preparing Datasets
2 | ## ImageNet and CIFAR datasets
3 | * Download the original train and test set for [ImageNet](https://image-net.org/download.php) & [ImageNet-C](https://zenodo.org/record/2235448#.Yn5OTrozZhE) datasets.
4 | * Download the original train and test set for [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) & [CIFAR-10C](https://zenodo.org/record/2535967#.Yn5QwbozZhE) datasets.
5 |
6 | ## KITTI dataset
7 | * Download Clear (Original) [KITTI dataset](http://www.cvlibs.net/datasets/kitti/).
8 | * Download [KITTI-Fog/Rain](https://team.inria.fr/rits/computer-vision/weather-augment/) datasets.
9 | * Super-impose snow on KITTI dataset through this [repository](https://github.com/hendrycks/robustness).
10 | * Generate labels YOLO can use (see [Dataset directory structures](#dataset-directory-structures) subsection).
11 |
12 | To generate labels YOLO can use from the original KITTI labels, run:
13 |
14 | `python main.py --kitti_to_yolo_labels /path/to/original/kitti`
15 |
16 | This expects the path to the original KITTI directory structure:
17 | ```
18 | path_to_specify
19 | └── raw
20 | └── training
21 | ├── image_2
22 | └── label_2
23 | ```
24 | This will create a `yolo_style_labels` directory inside the `raw` directory, containing
25 | the KITTI labels in a format YOLO can use.
26 |
27 | Structure the chosen dataset(s) as described [here](directory_scructures.md).
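28 |
29 | Each generated label file holds one line per object in the standard YOLO format, i.e. a
30 | class index followed by the normalized box center and size, e.g. `0 0.512 0.630 0.210 0.155`
31 | (illustrative values).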
--------------------------------------------------------------------------------
/utils/rotation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as TF
3 |
4 |
5 | def tensor_rot_90(x):
6 | x = TF.rotate(x, 90)
7 | return x
8 |
9 |
10 | def tensor_rot_180(x):
11 | x = TF.rotate(x, 180)
12 | return x
13 |
14 |
15 | def tensor_rot_270(x):
16 | x = TF.rotate(x, 270)
17 | return x
18 |
19 |
20 | def rotate_batch_with_labels(batch, labels):
21 | images = []
22 | for img, label in zip(batch, labels):
23 | if label == 1:
24 | img = tensor_rot_90(img)
25 | elif label == 2:
26 | img = tensor_rot_180(img)
27 | elif label == 3:
28 | img = tensor_rot_270(img)
29 | images.append(img.unsqueeze(0))
30 | return torch.cat(images)
31 |
32 |
33 | def rotate_batch(batch, label):
34 | if label == 'rand':
35 | labels = torch.randint(4, (len(batch),), dtype=torch.long)
36 | elif label == 'expand':
37 | labels = torch.cat([torch.zeros(len(batch), dtype=torch.long),
38 | torch.zeros(len(batch), dtype=torch.long) + 1,
39 | torch.zeros(len(batch), dtype=torch.long) + 2,
40 | torch.zeros(len(batch), dtype=torch.long) + 3])
41 | batch = batch.repeat((4, 1, 1, 1))
42 | else:
43 | assert isinstance(label, int)
44 | labels = torch.zeros((len(batch),), dtype=torch.long) + label
45 | return rotate_batch_with_labels(batch, labels), labels
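46 |
47 | # Example (sketch): expand an 8-image batch into all four rotations, with
48 | # rotation-class labels 0-3 for the self-supervision task:
49 | #   batch = torch.randn(8, 3, 32, 32)
50 | #   rotated, labels = rotate_batch(batch, 'expand')  # shapes (32, 3, 32, 32), (32,)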
--------------------------------------------------------------------------------
/readme/directory_scructures.md:
--------------------------------------------------------------------------------
1 | # Dataset Directory Structures
2 |
3 | ## KITTI
4 | ```
5 | args.dataroot
6 | ├── fog
7 | | ├── fog_30
8 | | | ├── *.png
9 | | |
10 | | ├── ... other severities
11 | |
12 | ├── initial
13 | | └── images
14 | | ├── *.png
15 | |
16 | ├── labels_caches [this is an initially empty directory]
17 | |
18 | ├── labels_yolo_format
19 | | ├── *.txt
20 | |
21 | ├── rain
22 | | ├── 200mm
23 | | | ├── *.png
24 | | |
25 | | ├── ... other severities
26 | |
27 | ├── test.txt
28 | ├── train.txt
29 | └── val.txt
30 | ```
31 | The .txt files contain a list of image names defining the train/val/test splits.
32 |
33 |
34 |
35 | ## CIFAR-10-C
36 |
37 | ```
38 | args.dataroot
39 | ├── cifar-10-batches-py
40 | | ├── batches.meta
41 | | ├── data_batch_1
42 | | ├── ...
43 | |
44 | └── CIFAR-10-C
45 | ├── test
46 | | ├── brightness.npy
47 | | ├── contrast.npy
48 | | ├── ...
49 | |
50 | └── train
51 | ├── brightness.npy
52 | ├── contrast.npy
53 | ├── ...
54 |
55 | ```
56 |
57 |
58 | ## Tiny-ImageNet-200-C
59 |
60 | ```
61 | args.dataroot
62 | ├── tiny-imagenet-200
63 | | ├── train
64 | | | ├── n01443537
65 | | | | └──images
66 | | | | ├── *.JPEG
67 | | | | ├── ...
68 | | | |
69 | | | ├── n01629819
70 | | | ├── ...
71 | | |
72 | | └── val
73 | | ├── n01443537
74 | | | └── images
75 | | | ├── *.JPEG
76 | | | ├── ...
77 | | |
78 | | ├── n01629819
79 | | ├── ...
80 | |
81 | └── tiny-imagenet-200-c
82 | ├── val
83 | | ├── brightness
84 | | | ├── 1
85 | | | | ├── n01443537
86 | | | | | ├── *.JPEG
87 | | | | | ├── ...
88 | | | | |
89 | | | | ├── n01629819
90 | | | | ├── ...
91 | | | |
92 | | | ├── 2
93 | | | ├── 3
94 | | | ├── 4
95 | | | └── 5
96 | | |
97 | | ├── contrast
98 | | ├── ...
99 | |
100 | └── train
101 | ├── ... same as tiny-imagenet-200-c/val
102 |
103 | ```
104 |
105 |
106 | ## ImageNet
107 |
108 | ```
109 | args.dataroot
110 | ├── imagenet
111 | | ├── train
112 | | | ├── n01443537
113 | | | | ├── *.JPEG
114 | | | | ├── ...
115 | | | |
116 | | | ├── n01629819
117 | | | ├── ...
118 | | |
119 | | └── val
120 | | ├── n01443537
121 | | | ├── *.JPEG
122 | | | ├── ...
123 | | |
124 | | ├── n01629819
125 | | ├── ...
126 | |
127 | └── imagenet-c
128 | ├── val
129 | | ├── brightness
130 | | | ├── 1
131 | | | | ├── n01443537
132 | | | | | ├── *.JPEG
133 | | | | | ├── ...
134 | | | | |
135 | | | | ├── n01629819
136 | | | | ├── ...
137 | | | |
138 | | | ├── 2
139 | | | ├── 3
140 | | | ├── 4
141 | | | └── 5
142 | | |
143 | | ├── contrast
144 | | ├── ...
145 | |
146 | └── train
147 | ├── ... same as imagenet-c/val
148 |
149 | ```
--------------------------------------------------------------------------------
/models/resnet_26.py:
--------------------------------------------------------------------------------
1 | # Based on the ResNet implementation in torchvision
2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 |
4 | import math
5 | import torch
6 | from torch import nn
7 | from torchvision.models.resnet import conv3x3
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, inplanes, planes, norm_layer, stride=1, downsample=None):
12 | super(BasicBlock, self).__init__()
13 | self.downsample = downsample
14 | self.stride = stride
15 |
16 | self.bn1 = norm_layer(inplanes)
17 | self.relu1 = nn.ReLU(inplace=True)
18 | self.conv1 = conv3x3(inplanes, planes, stride)
19 |
20 | self.bn2 = norm_layer(planes)
21 | self.relu2 = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 |
24 | def forward(self, x):
25 | residual = x
26 | residual = self.bn1(residual)
27 | residual = self.relu1(residual)
28 | residual = self.conv1(residual)
29 |
30 | residual = self.bn2(residual)
31 | residual = self.relu2(residual)
32 | residual = self.conv2(residual)
33 |
34 | if self.downsample is not None:
35 | x = self.downsample(x)
36 | return x + residual
37 |
38 |
39 | class Downsample(nn.Module):
40 | def __init__(self, nIn, nOut, stride):
41 | super(Downsample, self).__init__()
42 | self.avg = nn.AvgPool2d(stride)
43 | assert nOut % nIn == 0
44 | self.expand_ratio = nOut // nIn
45 |
46 | def forward(self, x):
47 | x = self.avg(x)
48 | return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)
49 |
50 |
51 | class ResNetCifar(nn.Module):
52 | def __init__(self, depth, width=1, classes=10, channels=3, norm_layer=nn.BatchNorm2d):
53 | assert (depth - 2) % 6 == 0 # depth is 6N+2
54 | self.N = (depth - 2) // 6
55 | super(ResNetCifar, self).__init__()
56 | self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
57 | self.inplanes = 16
58 | self.layer1 = self._make_layer(norm_layer, 16 * width)
59 | self.layer2 = self._make_layer(norm_layer, 32 * width, stride=2)
60 | self.layer3 = self._make_layer(norm_layer, 64 * width, stride=2)
61 | self.bn = norm_layer(64 * width)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.avgpool = nn.AvgPool2d(8)
64 | self.fc = nn.Linear(64 * width, classes)
65 |
66 | # Initialization
67 | for m in self.modules():
68 | if isinstance(m, nn.Conv2d):
69 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
70 | m.weight.data.normal_(0, math.sqrt(2. / n))
71 |
72 | def _make_layer(self, norm_layer, planes, stride=1):
73 | downsample = None
74 | if stride != 1 or self.inplanes != planes:
75 | downsample = Downsample(self.inplanes, planes, stride)
76 | layers = [BasicBlock(self.inplanes, planes, norm_layer, stride, downsample)]
77 | self.inplanes = planes
78 | for i in range(self.N - 1):
79 | layers.append(BasicBlock(self.inplanes, planes, norm_layer))
80 | return nn.Sequential(*layers)
81 |
82 | def forward(self, x):
83 | x = self.conv1(x)
84 | x = self.layer1(x)
85 | x = self.layer2(x)
86 | x = self.layer3(x)
87 | x = self.bn(x)
88 | x = self.relu(x)
89 | x = self.avgpool(x)
90 | x = x.view(x.size(0), -1)
91 | x = self.fc(x)
92 | return x
93 |
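94 | # Construction sketch: the "26" in this file's name corresponds to ResNetCifar(depth=26),
95 | # i.e. N = (26 - 2) // 6 = 4 pre-activation BasicBlocks per stage.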
--------------------------------------------------------------------------------
/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 |
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
35 | class NetworkBlock(nn.Module):
36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37 | super(NetworkBlock, self).__init__()
38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39 |
40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41 | layers = []
42 | for i in range(int(nb_layers)):
43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self.layer(x)
48 |
49 |
50 | class WideResNet(nn.Module):
51 | """ Based on code from https://github.com/yaodongyu/TRADES """
52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
53 | super(WideResNet, self).__init__()
54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
55 | assert ((depth - 4) % 6 == 0)
56 | n = (depth - 4) / 6
57 | block = BasicBlock
58 | # 1st conv before any network block
59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
60 | padding=1, bias=False)
61 | # 1st block
62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
63 | if sub_block1:
64 | # 1st sub-block
65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
66 | # 2nd block
67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
68 | # 3rd block
69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
70 | # global average pooling and classifier
71 | self.bn1 = nn.BatchNorm2d(nChannels[3])
72 | self.relu = nn.ReLU(inplace=True)
73 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
74 | self.nChannels = nChannels[3]
75 |
76 | for m in self.modules():
77 | if isinstance(m, nn.Conv2d):
78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
79 | m.weight.data.normal_(0, math.sqrt(2. / n))
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 |             elif isinstance(m, nn.Linear) and m.bias is not None:
84 | m.bias.data.zero_()
85 |
86 | def forward(self, x):
87 | out = self.conv1(x)
88 | out = self.block1(out)
89 | out = self.block2(out)
90 | out = self.block3(out)
91 | out = self.relu(self.bn1(out))
92 | out = F.avg_pool2d(out, 8)
93 | out = out.view(-1, self.nChannels)
94 | return self.fc(out)
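95 |
96 | # Construction sketch: the WRN-40-2 checkpoint referenced in the README corresponds to
97 | # WideResNet(depth=40, widen_factor=2), i.e. (40 - 4) / 6 = 6 BasicBlocks per group.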
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DUA: Dynamic Unsupervised Adaptation (CVPR 2022)
2 |
3 | This is the official repository for our paper: [The Norm Must Go On: Dynamic Unsupervised Domain Adaptation by Normalization](https://openaccess.thecvf.com/content/CVPR2022/papers/Mirza_The_Norm_Must_Go_On_Dynamic_Unsupervised_Domain_Adaptation_by_CVPR_2022_paper.pdf)
4 |
5 | DUA is an extremely simple method which only adapts the (1st and 2nd order) statistics of the Batch Normalization layers
6 | in an online manner to out-of-distribution test data at test time. Adapting only these statistics for
7 | Unsupervised Domain Adaptation makes DUA extremely fast and computationally efficient. Moreover,
8 | DUA requires less than 1% of the data from the target domain and no back-propagation to achieve
9 | competitive (and often state-of-the-art) results when compared to strong baselines (a minimal sketch of the update loop is given at the end of this README).
10 |
11 | A short explanatory video about DUA is hosted [here](https://www.youtube.com/watch?v=fTe0Aqs-t7E).
12 |
13 | # Installation
14 |
15 | 1) `git clone` this repository.
16 | 2) `pip install -r requirements.txt` to install the required packages.
17 |
18 | # Running Experiments
19 |
27 | Before running the experiments, please prepare the datasets by following the instructions listed
28 | [here](readme/preparing_datasets.md).
29 |
30 | We provide code for reproducing CIFAR-10C / ImageNet-C / KITTI results. These experiments
31 | can be run through the following example commands.
32 |
33 | ### CIFAR-10C (WRN-40-2)
34 | To run this experiment, first download the [AugMix](https://arxiv.org/abs/1912.02781) pre-trained
35 | [WRN-40-2 Checkpoint](https://drive.google.com/file/d/1wy7gSRsUZzCzj8QhmTbcnwmES_2kkNph/view).
36 | ```
37 | python main.py --dataset cifar10 --model wrn --ckpt_path path/to/checkpoint.pt --dataroot root/path/for/cifar-10C
38 | ```
39 | #### WRN-40-2 Results on CIFAR-10C (Level-5 Severity)
40 | | | data samples used| mean error | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
41 | | ---------------------------------------------------------- | ---:|---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
42 | | source |10000 |18.3|28.8| 22.9|26.2|9.5| 20.6|10.6|9.3|14.2|15.3|17.5|7.6|20.9|14.7|41.3|14.7|
43 | | tent |10000 |12.3|15.8|13.5|18.7|8.1|18.7|9.1|8.0|10.3|10.8|11.7|6.7|11.6|14.1|11.7|15.2|
44 | | dua |80|12.1|15.4|13.4|17.3|8.0|18.0|9.1|7.7|10.8|10.8|12.1|6.6|10.9|13.6|13.0|14.3|
45 |
46 | ### ImageNet-C (ResNet-18)
47 | ```
48 | python main.py --dataset imagenet --model res18 --dataroot root/path/for/imagenet-C
49 | ```
50 |
51 | ### KITTI (YOLOv3)
52 | ```
53 | python main.py --dataset kitti --dataroot root/path/for/kitti
54 | ```
55 | This will first train the network on the original KITTI dataset and then adapt separately to `Fog` and `Rain`.
56 | The current hyper-parameters are set to the default values used in the DUA paper; to experiment with other
57 | settings, please refer to `main.py`.
58 |
59 | #### To cite us:
60 | ```bibtex
61 | @InProceedings{mirza2022dua,
62 | author = {Mirza, M. Jehanzeb and Micorek, Jakub and Possegger, Horst and Bischof, Horst},
63 | title = {The Norm Must Go On: Dynamic Unsupervised Domain Adaptation by Normalization},
64 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
65 | year = {2022}
66 | }
67 | ```
68 |
69 | Also read [DISC](https://openaccess.thecvf.com/content/CVPR2022W/V4AS/papers/Mirza_An_Efficient_Domain-Incremental_Learning_Approach_To_Drive_in_All_Weather_CVPRW_2022_paper.pdf), an extension of DUA - accepted at CVPR workshops.
70 | ```bibtex
71 | @InProceedings{mirza2022disc,
72 | author = {Mirza, M. Jehanzeb and Masana, Marc and Possegger, Horst and Bischof, Horst},
73 | title = {An Efficient Domain-Incremental Learning Approach To Drive in All Weather Conditions},
74 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
75 | year = {2022}
76 | }
77 | ```
78 |
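79 | #### How the adaptation works (sketch)
80 |
81 | The heart of DUA lives in `methods/dua.py`: only the BatchNorm running statistics are
82 | updated, with a momentum that decays after every adaptation sample. Below is a minimal,
83 | self-contained sketch of that update loop; the network, the stand-in data and the
84 | concrete hyper-parameter values are illustrative placeholders, not the repository's API.
85 | ```python
86 | import torch
87 | import torch.nn as nn
88 | from torchvision.models import resnet18
89 |
90 | net = resnet18(num_classes=10)  # any network with BatchNorm layers
91 | net.eval()                      # weights stay frozen; only BN statistics adapt
92 |
93 | mom_pre, decay_factor, min_momentum = 0.1, 0.94, 0.005      # illustrative values
94 | batches = [torch.randn(32, 3, 64, 64) for _ in range(16)]   # stand-in target data
95 |
96 | for inputs in batches:
97 |     mom_new = mom_pre * decay_factor                        # decay BN momentum
98 |     for m in net.modules():
99 |         if isinstance(m, nn.modules.batchnorm._BatchNorm):
100 |             m.train()                                       # BN updates running stats
101 |             m.momentum = mom_new + min_momentum
102 |     mom_pre = mom_new
103 |     with torch.no_grad():
104 |         _ = net(inputs)   # one forward pass adapts the statistics; no backprop
105 |     net.eval()
106 | ```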
--------------------------------------------------------------------------------
/utils/training.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from itertools import chain
3 | from os.path import join, dirname
4 | from typing import Iterable
5 | from warnings import warn
6 |
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch import save
10 | from torch.optim.lr_scheduler import ReduceLROnPlateau
11 |
12 | from utils.data_loader import get_loader
13 | from utils.testing import test
14 | from utils.utils import make_dirs
15 |
16 | log = logging.getLogger('TRAINING')
17 |
18 | class ReduceLROnPlateauEarlyStop(ReduceLROnPlateau):
19 | """
20 |     Extension of ReduceLROnPlateau that also implements early stopping.
21 |     The argument max_unsuccessful_reductions defines how many lr reductions
22 |     without improvement are allowed before the early-stopping criterion is met,
23 |     in which case the step() method returns False instead of True.
24 | """
25 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
26 | threshold=1e-4, threshold_mode='rel', cooldown=0,
27 | min_lr=0, eps=1e-8, verbose=False,
28 | max_unsuccessful_reductions=3):
29 | super().__init__(optimizer, mode, factor, patience,
30 | threshold, threshold_mode, cooldown,
31 | min_lr, eps, verbose)
32 | self.consecutive_lr_reductions = 0
33 | self.max_unsuccessful_reductions = max_unsuccessful_reductions
34 |
35 | # slightly modified ReduceLROnPlateau step() method, to keep track of
36 | # lr decreases and return False on no improvement after
37 | # max_unsuccessful_reductions lr reductions
38 | def step(self, metrics, epoch=None):
39 | current = float(metrics)
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | else:
43 | warn(optim.lr_scheduler.EPOCH_DEPRECATION_WARNING)
44 | self.last_epoch = epoch
45 |
46 | if self.is_better(current, self.best):
47 | self.best = current
48 | self.num_bad_epochs = 0
49 | self.consecutive_lr_reductions = 0
50 | else:
51 | self.num_bad_epochs += 1
52 | if self.consecutive_lr_reductions >= self.max_unsuccessful_reductions:
53 | if self.verbose:
54 | log.info("Early stopping criteria reached!")
55 | return False
56 |
57 | if self.in_cooldown:
58 | self.cooldown_counter -= 1
59 | self.num_bad_epochs = 0
60 |
61 | if self.num_bad_epochs > self.patience:
62 | self._reduce_lr(epoch)
63 | self.consecutive_lr_reductions += 1
64 | self.cooldown_counter = self.cooldown
65 | self.num_bad_epochs = 0
66 |
67 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
68 | return True
69 |
70 |
71 | def get_heads_params(model):
72 | heads = model.get_heads()
73 | if isinstance(heads, Iterable):
74 | return chain.from_iterable([m.parameters() for m in heads])
75 | return heads.parameters()
76 |
77 |
78 | def train(model, args, result_path='checkpoints/ckpt.pt', lr=None,
79 | train_heads_only=False, joint=False):
80 | make_dirs(dirname(result_path))
81 | device = next(model.parameters()).device
82 | if not lr:
83 | lr = args.lr
84 |
85 | if train_heads_only:
86 | optimizer = optim.SGD(get_heads_params(model), lr=lr, momentum=0.9,
87 | weight_decay=5e-4)
88 | else:
89 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9,
90 | weight_decay=5e-4)
91 | criterion = nn.CrossEntropyLoss().to(device)
92 | n = args.max_unsuccessful_reductions
93 | scheduler = ReduceLROnPlateauEarlyStop(optimizer, factor=args.lr_factor,
94 | patience=args.patience,
95 | verbose=args.verbose,
96 |                                            max_unsuccessful_reductions=n)
97 |
98 | train_loader = get_loader(args, split='train', joint=joint)
99 | valid_loader = get_loader(args, split='val', joint=joint)
100 |
101 | all_err_cls = []
102 | for epoch in range(1, args.epochs + 1):
103 | model.train()
104 |
105 | if train_heads_only: # freeze BN running estimates
106 | for m in model.modules():
107 | if isinstance(m, nn.modules.batchnorm._BatchNorm):
108 | m.eval()
109 |
110 | train_one_epoch(model, epoch, optimizer, train_loader, criterion, device)
111 | err_cls = test(valid_loader, model)[0]
112 | all_err_cls.append(err_cls)
113 | if err_cls <= min(all_err_cls):
114 | if train_heads_only:
115 | save(model.get_heads().state_dict(), result_path)
116 | else:
117 | save(model.state_dict(), result_path)
118 |
119 | log.info(('Epoch %d/%d:' % (epoch, args.epochs)).ljust(20) +
120 | '%.1f' % (err_cls * 100))
121 |
122 | if not scheduler.step(err_cls):
123 | log.info("Finished training")
124 | return
125 |
126 |
127 | def train_one_epoch(model, epoch, optimizer, train_loader, criterion, device):
128 | total_loss = 0
129 | for batch_idx, (images, labels) in enumerate(train_loader):
130 | optimizer.zero_grad()
131 | images, labels = images.to(device), labels.to(device)
132 | outputs = model(images)
133 | loss = criterion(outputs, labels)
134 | loss.backward()
135 | optimizer.step()
136 | total_loss += loss.item()
137 | log.info(f'Epoch {epoch} avg loss per batch: {total_loss / (batch_idx + 1):.4f}')
138 |
139 |
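140 | # Early-stopping contract (sketch): step() returns False once the lr has been reduced
141 | # max_unsuccessful_reductions times in a row without the tracked metric improving:
142 | #   scheduler = ReduceLROnPlateauEarlyStop(optimizer, patience=5,
143 | #                                          max_unsuccessful_reductions=3)
144 | #   if not scheduler.step(val_error):
145 | #       pass  # stop training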
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | VALID_DATASETS = ['cifar10', 'imagenet', 'kitti', 'imagenet-mini']
2 |
3 | ROBUSTNESS_TASKS = [
4 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
5 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
6 | 'brightness', 'contrast', 'elastic_transform', 'pixelate',
7 | 'jpeg_compression'
8 | ]
9 |
10 | KITTI_TASKS = ['fog', 'rain']
11 |
12 | PATHS = {
13 | 'jm': {
14 | 'cifar10': {
15 | 'root': '/PATH/TO/cifar10/ROOT',
16 | 'ckpt': '/PATH/TO/cifar10/CHECKPOINT',
17 | },
18 | 'imagenet-mini': {
19 | 'root': '/PATH/TO/imagenet-mini/ROOT',
20 | 'ckpt': '/PATH/TO/imagenet-mini/CHECKPOINT',
21 | },
22 | 'imagenet': {
23 | 'root': '/PATH/TO/imagenet/ROOT',
24 | 'ckpt': '/PATH/TO/imagenet/CHECKPOINT',
25 | },
26 | 'kitti': {
27 | 'root': '/PATH/TO/kitti/ROOT',
28 | 'ckpt': '/PATH/TO/kitti-clear/CHECKPOINT',
29 | },
30 | },
31 | }
32 |
33 | LOGGER_CFG = {
34 | 'version': 1,
35 | 'formatters': {
36 | 'default': {
37 | 'format': '[%(name)s - %(levelname)s] %(message)s'
38 | },
39 | 'timestamped': {
40 | 'format': '%(asctime)s [%(name)s - %(levelname)s] %(message)s'
41 | },
42 | 'minimal': {
43 | 'format': '[%(name)s] %(message)s'
44 | }
45 | },
46 | 'filters': {
47 | 'name': {
48 | '()': 'config.ContextFilter'
49 | }
50 | },
51 | 'handlers': {
52 | 'console_handler': {
53 | 'level': 'DEBUG',
54 | 'class': 'logging.StreamHandler',
55 | 'formatter': 'minimal',
56 | 'stream': 'ext://sys.stdout',
57 | 'filters': ['name']
58 | },
59 | 'file_handler': {
60 | 'level': 'DEBUG',
61 | 'formatter': 'minimal',
62 | 'class': 'logging.FileHandler',
63 | 'filename': 'log.txt',
64 | 'mode': 'a',
65 | 'filters': ['name']
66 | },
67 | },
68 | 'loggers': {
69 | '': {
70 | 'handlers': ['console_handler', 'file_handler'],
71 | 'level': 'WARNING',
72 | 'propagate': False
73 | },
74 |
75 | 'MAIN': {
76 | 'handlers': ['console_handler', 'file_handler'],
77 | 'level': 'DEBUG',
78 | 'propagate': False
79 | },
80 | 'MAIN.DISC': {},
81 | 'MAIN.DUA': {},
82 | 'MAIN.DATA': {},
83 | 'MAIN.RESULTS': {},
84 |
85 | 'BASELINE': {
86 | 'handlers': ['console_handler', 'file_handler'],
87 | 'level': 'DEBUG',
88 | 'propagate': False
89 | },
90 | 'BASELINE.FREEZING': {},
91 | 'BASELINE.DISJOINT': {},
92 | 'BASELINE.JOINT_TRAINING': {},
93 | 'BASELINE.SOURCE_ONLY': {},
94 | 'BASELINE.FINE_TUNING': {},
95 |
96 | 'TRAINING': {
97 | 'handlers': ['console_handler', 'file_handler'],
98 | 'level': 'DEBUG',
99 | 'propagate': False
100 | },
101 |
102 | 'TESTING': {
103 | 'handlers': ['console_handler', 'file_handler'],
104 | 'level': 'DEBUG',
105 | 'propagate': False
106 | },
107 | 'TESTING.FILEONLY': {
108 | 'handlers': ['file_handler'],
109 | 'level': 'DEBUG',
110 | 'propagate': False
111 | }
112 | }
113 | }
114 |
115 | # Filtering logger tag prefixes
116 | class ContextFilter:
117 | def filter(self, record):
118 | split_name = record.name.split('.', 1)
119 | if split_name[0] == 'BASELINE' or split_name[0] == 'MAIN':
120 | if len(split_name) > 1:
121 | record.name = split_name[1]
122 | if split_name[0] == 'TESTING':
123 | if len(split_name) > 1:
124 | record.name = split_name[0]
125 | return True
126 |
127 |
128 | YOLO_HYP = {
129 | # !! lr0 will be overwritten by args.lr !!
130 | 'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
131 | 'lrf': 0.2, # final OneCycleLR learning rate (lr0 * lrf)
132 | 'momentum': 0.937, # SGD momentum/Adam beta1
133 | 'weight_decay': 0.0005, # optimizer weight decay 5e-4
134 | 'warmup_epochs': 3.0, # warmup epochs (fractions ok)
135 | 'warmup_momentum': 0.8, # warmup initial momentum
136 | 'warmup_bias_lr': 0.1, # warmup initial bias lr
137 | 'box': 0.05, # box loss gain
138 | 'cls': 0.5, # cls loss gain
139 | 'cls_pw': 1.0, # cls BCELoss positive_weight
140 | 'obj': 1.0, # obj loss gain (scale with pixels)
141 | 'obj_pw': 1.0, # obj BCELoss positive_weight
142 | 'iou_t': 0.20, # IoU training threshold
143 | 'anchor_t': 4.0, # anchor-multiple threshold
144 | 'fl_gamma': 0.0, # focal loss gamma (efficientDet default gamma=1.5)
145 | 'hsv_h': 0.015, # image HSV-Hue augmentation (fraction)
146 | 'hsv_s': 0.7, # image HSV-Saturation augmentation (fraction)
147 | 'hsv_v': 0.4, # image HSV-Value augmentation (fraction)
148 | 'degrees': 0.0, # image rotation (+/- deg)
149 | 'translate': 0.1, # image translation (+/- fraction)
150 | 'scale': 0.5, # image scale (+/- gain)
151 | 'shear': 0.0, # image shear (+/- deg)
152 | 'perspective': 0.0, # image perspective (+/- fraction), range 0-0.001
153 | 'flipud': 0.0, # image flip up-down (probability)
154 | 'fliplr': 0.5, # image flip left-right (probability)
155 | 'mosaic': 1.0, # image mosaic (probability)
156 | 'mixup': 0.0 # image mixup (probability)
157 | }
158 |
159 |
160 |
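161 | # Usage sketch (assumed wiring): the config above is activated via
162 | #   import logging.config
163 | #   logging.config.dictConfig(LOGGER_CFG)
164 | # after which e.g. logging.getLogger('MAIN.DUA') uses the 'MAIN' handlers and the
165 | # ContextFilter strips the 'MAIN.'/'BASELINE.' prefixes from displayed names.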
--------------------------------------------------------------------------------
/utils/google_utils.py:
--------------------------------------------------------------------------------
1 | # Google utils: https://cloud.google.com/storage/docs/reference/libraries
2 |
3 | import os
4 | import platform
5 | import subprocess
6 | import time
7 | from pathlib import Path
8 |
9 | import requests
10 | import torch
11 |
12 |
13 | def gsutil_getsize(url=''):
14 | # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
15 | s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
16 |     return int(s.split(' ')[0]) if len(s) else 0  # bytes
17 |
18 |
19 | def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
20 | # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
21 | file = Path(file)
22 | try: # GitHub
23 | print(f'Downloading {url} to {file}...')
24 | torch.hub.download_url_to_file(url, str(file))
25 | assert file.exists() and file.stat().st_size > min_bytes # check
26 | except Exception as e: # GCP
27 | file.unlink(missing_ok=True) # remove partial downloads
28 | print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
29 | os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
30 | finally:
31 | if not file.exists() or file.stat().st_size < min_bytes: # check
32 | file.unlink(missing_ok=True) # remove partial downloads
33 | print(f'ERROR: Download failure: {error_msg or url}')
34 | print('')
35 |
36 |
37 | def attempt_download(file, repo='ultralytics/yolov3'):
38 | # Attempt file download if does not exist
39 | file = Path(str(file).strip().replace("'", ''))
40 | # print(file)
41 | if not file.exists():
42 | # URL specified
43 | name = file.name
44 | if str(file).startswith(('http:/', 'https:/')): # download
45 | url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
46 | safe_download(file=name, url=url, min_bytes=1E5)
47 | return name
48 |
49 | # GitHub assets
50 | file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
51 | try:
52 | response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
53 | assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
54 | tag = response['tag_name'] # i.e. 'v1.0'
55 |     except Exception:  # fallback plan
56 | assets = ['yolov3.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt']
57 | try:
58 | tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
59 |         except Exception:
60 | tag = 'v9.5.0' # current release
61 |
62 | if name in assets:
63 | safe_download(file,
64 | url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
65 | # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
66 | min_bytes=1E5,
67 | error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
68 |
69 | return str(file)
70 |
71 |
72 | def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
73 | # Downloads a file from Google Drive. from yolov3.utils.google_utils import *; gdrive_download()
74 | t = time.time()
75 | file = Path(file)
76 | cookie = Path('cookie') # gdrive cookie
77 | print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
78 | file.unlink(missing_ok=True) # remove existing file
79 | cookie.unlink(missing_ok=True) # remove existing cookie
80 |
81 | # Attempt file download
82 | out = "NUL" if platform.system() == "Windows" else "/dev/null"
83 | os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
84 | if os.path.exists('cookie'): # large file
85 | s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}'
86 | else: # small file
87 | s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
88 | r = os.system(s) # execute, capture return
89 | cookie.unlink(missing_ok=True) # remove existing cookie
90 |
91 | # Error check
92 | if r != 0:
93 | file.unlink(missing_ok=True) # remove partial
94 | print('Download error ') # raise Exception('Download error')
95 | return r
96 |
97 | # Unzip if archive
98 | if file.suffix == '.zip':
99 | print('unzipping... ', end='')
100 | os.system(f'unzip -q {file}') # unzip
101 | file.unlink() # remove zip to free space
102 |
103 | print(f'Done ({time.time() - t:.1f}s)')
104 | return r
105 |
106 |
107 | def get_token(cookie="./cookie"):
108 | with open(cookie) as f:
109 | for line in f:
110 | if "download" in line:
111 | return line.split()[-1]
112 | return ""
113 |
114 | # def upload_blob(bucket_name, source_file_name, destination_blob_name):
115 | # # Uploads a file to a bucket
116 | # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
117 | #
118 | # storage_client = storage.Client()
119 | # bucket = storage_client.get_bucket(bucket_name)
120 | # blob = bucket.blob(destination_blob_name)
121 | #
122 | # blob.upload_from_filename(source_file_name)
123 | #
124 | # print('File {} uploaded to {}.'.format(
125 | # source_file_name,
126 | # destination_blob_name))
127 | #
128 | #
129 | # def download_blob(bucket_name, source_blob_name, destination_file_name):
130 | # # Uploads a blob from a bucket
131 | # storage_client = storage.Client()
132 | # bucket = storage_client.get_bucket(bucket_name)
133 | # blob = bucket.blob(source_blob_name)
134 | #
135 | # blob.download_to_filename(destination_file_name)
136 | #
137 | # print('Blob {} downloaded to {}.'.format(
138 | # source_blob_name,
139 | # destination_file_name))
--------------------------------------------------------------------------------
/models/experimental.py:
--------------------------------------------------------------------------------
1 | # YOLOv3 experimental modules
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from models.common import Conv, DWConv
8 | from utils.google_utils import attempt_download
9 |
10 |
11 | class CrossConv(nn.Module):
12 | # Cross Convolution Downsample
13 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
14 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
15 | super(CrossConv, self).__init__()
16 | c_ = int(c2 * e) # hidden channels
17 | self.cv1 = Conv(c1, c_, (1, k), (1, s))
18 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
19 | self.add = shortcut and c1 == c2
20 |
21 | def forward(self, x):
22 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
23 |
24 |
25 | class Sum(nn.Module):
26 | # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
27 | def __init__(self, n, weight=False): # n: number of inputs
28 | super(Sum, self).__init__()
29 | self.weight = weight # apply weights boolean
30 | self.iter = range(n - 1) # iter object
31 | if weight:
32 | self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
33 |
34 | def forward(self, x):
35 | y = x[0] # no weight
36 | if self.weight:
37 | w = torch.sigmoid(self.w) * 2
38 | for i in self.iter:
39 | y = y + x[i + 1] * w[i]
40 | else:
41 | for i in self.iter:
42 | y = y + x[i + 1]
43 | return y
44 |
45 |
46 | class GhostConv(nn.Module):
47 | # Ghost Convolution https://github.com/huawei-noah/ghostnet
48 | def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
49 | super(GhostConv, self).__init__()
50 | c_ = c2 // 2 # hidden channels
51 | self.cv1 = Conv(c1, c_, k, s, None, g, act)
52 | self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
53 |
54 | def forward(self, x):
55 | y = self.cv1(x)
56 | return torch.cat([y, self.cv2(y)], 1)
57 |
58 |
59 | class GhostBottleneck(nn.Module):
60 | # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
61 | def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
62 | super(GhostBottleneck, self).__init__()
63 | c_ = c2 // 2
64 | self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
65 | DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
66 | GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
67 | self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
68 | Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
69 |
70 | def forward(self, x):
71 | return self.conv(x) + self.shortcut(x)
72 |
73 |
74 | class MixConv2d(nn.Module):
75 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
76 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
77 | super(MixConv2d, self).__init__()
78 | groups = len(k)
79 | if equal_ch: # equal c_ per group
80 | i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
81 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
82 | else: # equal weight.numel() per group
83 | b = [c2] + [0] * groups
84 | a = np.eye(groups + 1, groups, k=-1)
85 | a -= np.roll(a, 1, axis=1)
86 | a *= np.array(k) ** 2
87 | a[0] = 1
88 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
89 |
90 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
91 | self.bn = nn.BatchNorm2d(c2)
92 | self.act = nn.LeakyReLU(0.1, inplace=True)
93 |
94 | def forward(self, x):
95 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
96 |
97 |
98 | class Ensemble(nn.ModuleList):
99 | # Ensemble of models
100 | def __init__(self):
101 | super(Ensemble, self).__init__()
102 |
103 | def forward(self, x, augment=False):
104 | y = []
105 | for module in self:
106 | y.append(module(x, augment)[0])
107 | # y = torch.stack(y).max(0)[0] # max ensemble
108 | # y = torch.stack(y).mean(0) # mean ensemble
109 | y = torch.cat(y, 1) # nms ensemble
110 | return y, None # inference, train output
111 |
112 |
113 | def attempt_load(weights, map_location=None, inplace=True):
114 | from models.yolo import Detect, Model
115 |
116 | # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
117 | model = Ensemble()
118 |     for w in weights if isinstance(weights, list) else [weights]:
120 |         ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
122 |         model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # FP32 model
124 |
125 | # Compatibility updates
126 | for m in model.modules():
127 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
128 | m.inplace = inplace # pytorch 1.7.0 compatibility
129 | elif type(m) is Conv:
130 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
131 |
132 | if len(model) == 1:
136 |         return model[-1]  # return model
138 | else:
139 | print(f'Ensemble created with {weights}\n')
140 |
141 | for k in ['names']:
142 | setattr(model, k, getattr(model[-1], k))
143 | model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
144 | return model # return ensemble
145 |
146 | def attempt_load_(weights, map_location=None, inplace=True):
147 | from models.yolo import Detect, Model
148 |
149 | # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
150 | model = Ensemble()
151 |
152 | for w in weights if isinstance(weights, list) else [weights]:
153 | ckpt = torch.load(attempt_download(w), map_location=map_location) # load
154 | model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # FP32 model
155 |
156 | return model
--------------------------------------------------------------------------------
/methods/dua.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from utils.data_loader import *
3 | from utils.rotation import *
4 | from utils.testing import test
5 | from utils.testing_yolov3 import test as test_yolo
6 | from utils.torch_utils import select_device
7 | from utils.utils import make_dirs
8 | from utils.results_manager import ResultsManager
9 | from init import init_net
10 | log = logging.getLogger('MAIN.DUA')
11 |
12 |
13 | def dua(args, net, save_bn_stats=False, use_training_data=False, save_fname=None):
14 | results_mgr = ResultsManager()
15 | if args.model == 'yolov3':
16 | get_adaption_inputs = get_adaption_inputs_kitti
17 | metric = 'mAP@50'
18 | tr_transform_adapt = transforms.Compose([
19 | transforms.ToPILImage(),
20 | transforms.RandomCrop((224, 640)),
21 | transforms.RandomHorizontalFlip(),
22 | transforms.ToTensor(),
23 | ])
24 | else:
25 | get_adaption_inputs = get_adaption_inputs_default
26 | metric = 'Error'
27 | tr_transform_adapt = transforms.Compose([
28 | transforms.RandomCrop(32, padding=4),
29 | transforms.RandomHorizontalFlip(),
30 | transforms.ToTensor(),
31 | transforms.Normalize(*NORM)
32 | ])
33 |     if args.dataset != 'imagenet':
34 | ckpt = torch.load(args.ckpt_path)
35 | decay_factor = args.decay_factor
36 | min_momentum_constant = args.min_mom
37 | no_imp = 0
38 | no_imp_cnt = 0
39 | all_results = []
40 | device = select_device(args.device, batch_size=args.batch_size)
41 |
42 | for args.task in globals.TASKS:
43 | if not set_severity(args):
44 | continue
45 | mom_pre = 0.1
46 | results = []
47 | log.info(f'Task - {args.task} :::: Level - {args.severity}')
48 |         if args.dataset != 'imagenet':
49 | net.load_state_dict(ckpt)
50 | else:
51 | init_net(args)
52 |
53 | net.eval()
54 | if use_training_data:
55 | train_loader = get_loader(args, split='train')
56 | valid_loader = get_loader(args, split='val', pad=0.5, rect=True)
57 | else:
58 | # original DUA is run on test data only
59 | train_loader = valid_loader = get_loader(args, split='test', pad=0.5, rect=True)
60 |
61 | if args.model == 'yolov3':
62 | res = test_yolo(model=net, dataloader=valid_loader,
63 | iou_thres=args.iou_thres, conf_thres=args.conf_thres,
64 | augment=args.augment)[0] * 100
65 | else:
66 | res = test(valid_loader, net)[0] * 100
67 | log.info(f'{metric} Before Adaptation: {res:.1f}')
68 |
69 | for i in tqdm(range(1, args.num_samples + 1)):
70 | net.eval()
71 | image = train_loader.dataset.get_image_from_idx(i - 1)
72 | mom_new = (mom_pre * decay_factor)
73 | for m in net.modules():
74 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
75 | m.train()
76 | m.momentum = mom_new + min_momentum_constant
77 | mom_pre = mom_new
78 | inputs = get_adaption_inputs(image, tr_transform_adapt, device)
79 | _ = net(inputs)
80 | net.eval()
81 | if args.model == 'yolov3':
82 | res = test_yolo(model=net, dataloader=valid_loader,
83 | iou_thres=args.iou_thres, conf_thres=args.conf_thres,
84 | augment=args.augment)[0] * 100
85 | else:
86 | res = test(valid_loader, net)[0] * 100
87 | results.append(res)
88 | if result_improved(metric, res, results):
89 | save_bn_stats_in_model(net, args.task)
90 | no_imp = 0
91 | no_imp_cnt = 0
92 | else:
93 | no_imp += 1
94 | if no_imp >= 10:
95 | no_imp_cnt += no_imp
96 | no_imp = 0
97 | log.info(f'Iteration {i}/{args.num_samples}: No Improvement '
98 | f'for {no_imp_cnt} consecutive iterations')
99 |
100 | adaptation_result = max(results) if metric == 'mAP@50' else min(results)
101 |
102 | severity_str = '' if args.task == 'initial' else f'{args.severity}'
103 | results_mgr.add_result('DUA', f'{args.task} {severity_str}', adaptation_result, 'online')
104 |
105 | log.info(f'{metric} After Adaptation: {adaptation_result:.1f}')
106 | all_results.append(adaptation_result)
107 | log.info(f'Mean {metric} after Adaptation {(sum(all_results) / len(all_results)):.1f}')
108 |
109 | if save_bn_stats:
110 | save_bn_stats_to_file(net, args.dataset, args.model, save_fname)
111 |
112 |
113 | def result_improved(metric, current_result, all_results_for_current_task):
114 | """
115 |     Check whether the current result is at least as good as every result so far
116 |     (all_results_for_current_task already includes current_result). If the metric
117 |     is 'mAP@50', higher is better; otherwise lower is better.
118 | """
119 | if metric == 'mAP@50':
120 | return current_result >= max(all_results_for_current_task)
121 | else:
122 | return current_result <= min(all_results_for_current_task)
123 |
124 |
125 | def get_adaption_inputs_default(img, tr_transform_adapt, device):
126 | inputs = [(tr_transform_adapt(img)) for _ in range(64)]
127 | inputs = torch.stack(inputs)
128 | inputs_ssh, _ = rotate_batch(inputs, 'rand')
129 | inputs_ssh = inputs_ssh.to(device, non_blocking=True)
130 | return inputs_ssh
131 |
132 |
133 | def get_adaption_inputs_kitti(img, tr_transform_adapt, device):
134 | img = img.squeeze(0)
135 | inputs = [(tr_transform_adapt(img)) for _ in range(64)]
136 | inputs = torch.stack(inputs)
137 | inputs_ssh, _ = rotate_batch(inputs, 'rand')
138 | inputs_ssh = inputs_ssh.to(device, non_blocking=True)
139 | inputs_ssh /= 255
140 | return inputs_ssh
141 |
142 |
143 | def save_bn_stats_in_model(net, task):
144 | """
145 | Saves the running estimates of all batch norm layers for a given
146 | task, in the net.bn_stats attribute.
147 | """
148 | state_dict = net.state_dict()
149 | net.bn_stats[task] = {}
150 | for layer_name, m in net.named_modules():
151 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
152 | net.bn_stats[task][layer_name] = {
153 | 'running_mean': state_dict[layer_name + '.running_mean'].detach().clone(),
154 | 'running_var': state_dict[layer_name + '.running_var'].detach().clone()
155 | }
156 |
157 |
158 | def save_bn_stats_to_file(net, dataset_str=None, model_str=None, file_name=None):
159 | """
160 | Saves net.bn_stats content to a file.
161 | """
162 | # ckpt_folder = 'checkpoints/' + dataset_str + '/' + model_str + '/'
163 | ckpt_folder = join('checkpoints', dataset_str, model_str)
164 | make_dirs(ckpt_folder)
165 | if not file_name:
166 | file_name = 'BN_stats.pt'
167 | torch.save(net.bn_stats, join(ckpt_folder, file_name))
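168 |
169 | # Usage sketch: dua() is invoked (presumably from main.py) with the parsed args and
170 | # an initialized network. Per-task BN statistics are cached in net.bn_stats and, with
171 | # save_bn_stats=True, written to checkpoints/<dataset>/<model>/BN_stats.pt.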
--------------------------------------------------------------------------------
/utils/results_manager.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from os.path import exists
3 |
4 | import pandas as pd
5 |
6 |
7 | class ResultsManager():
8 | """
9 | Singleton class to manage results.
10 | """
11 | _instance = None
12 | log = logging.getLogger('MAIN.RESULTS')
13 | multi_run_res = {}
14 |
15 | def __new__(cls, _=None):
16 | if cls._instance is None:
17 | cls._instance = super(ResultsManager, cls).__new__(cls)
18 | return cls._instance
19 |
20 |
21 | def __init__(self, metric='mAP@50'):
22 | if hasattr(self, 'results'):
23 | return
24 | columns = ['method', 'task', 'value', 'scenario']
25 | self.results = pd.DataFrame(columns=columns)
26 | self.metric = metric
27 |
28 |
29 | def has_results(self):
30 | return not self.results.empty
31 |
32 |
33 | def save_to_file(self, file_name=None):
34 | if not file_name:
35 | path = 'results/raw_results_df.pkl'
36 | else:
37 | path = 'results/' + file_name
38 | self.results.to_pickle(path)
39 |
40 |
41 | def load_from_file(self, file_name=None):
42 | if not file_name:
43 | path = 'results/raw_results_df.pkl'
44 | else:
45 | path = 'results/' + file_name
46 | if not exists(path):
47 | raise Exception('Results file not found')
48 | self.results = pd.read_pickle(path)
49 |
50 |
51 | def add_result(self, method, task, value, scenario):
52 | entry = pd.DataFrame([{
53 | 'method' : method,
54 | 'task': task,
55 | 'value': value,
56 | 'scenario': scenario
57 | }])
58 | self.results = pd.concat([self.results, entry], ignore_index=True)
59 |
60 | if method not in self.multi_run_res:
61 | self.multi_run_res[method] = {}
62 | if scenario not in self.multi_run_res[method]:
63 | self.multi_run_res[method][scenario] = {}
64 | if task not in self.multi_run_res[method][scenario]:
65 | self.multi_run_res[method][scenario][task] = []
66 |
67 | self.multi_run_res[method][scenario][task].append(value)
68 |
69 |
70 | def print_multiple_runs_results(self):
71 | if not self.multi_run_res:
72 | return
73 |
74 | from statistics import mean, variance, stdev
75 |
76 | self.log.info('------------ Multi run results ------------')
77 | for method, v2 in self.multi_run_res.items():
78 | self.log.info(f'\nMethod: {method}')
79 | for scenario, v1 in v2.items():
80 | self.log.info(f'\t\tScenario: {scenario}')
81 | for task, v in v1.items():
82 |                 self.log.info(f'\t\tTask: {task}')
83 | self.log.info(f'\t\tMEAN: {mean(v):.3f}, VAR: {variance(v):.3f}, STDEV {stdev(v):.3f}')
84 | self.log.info('-------------------------------------------')
85 |
86 |
87 | def reset_results(self):
88 | if hasattr(self, 'summary'):
89 | delattr(self, 'summary')
90 | columns = ['method', 'task', 'value', 'scenario']
91 | self.results = pd.DataFrame(columns=columns)
92 |
93 |
94 | def generate_summary(self):
95 | self.summary = {}
96 | tasks = self.results.task.unique()
97 | methods = self.results.method.unique()
98 | self.summary['online'] = pd.DataFrame(columns=tasks)
99 | self.summary['offline'] = pd.DataFrame(columns=tasks)
100 |
101 | for method in methods:
102 | for scenario in ['online', 'offline']:
103 | df = self.results[(self.results['method'] == method) &
104 | self.results['scenario'].isin([scenario, None])]
105 | if not len(df):
106 | continue
107 | self.summary[scenario].loc[method] = list(df['value'])
108 |
109 |
110 | def print_summary(self):
111 | if not hasattr(self, 'summary'):
112 | self.generate_summary()
113 | self.log.info('Results summary:')
114 | pd.set_option('display.max_columns', None)
115 | for scenario, scenario_summary in self.summary.items():
116 |             self.log.info(f'{scenario.upper()}:')
117 |             self.log.info(f'{scenario_summary}\n')
118 |
119 | def print_summary_latex(self, max_cols=8):
120 | self.log.info(f'\n{self.results}')
121 | import warnings
122 | from math import ceil
123 | warnings.simplefilter(action='ignore', category=FutureWarning)
124 |
125 | if not hasattr(self, 'summary'):
126 | self.generate_summary()
127 |
128 | res = ('-' * 30) + 'START LATEX' + ('-' * 30)
129 | for scenario in self.summary.keys():
130 | hdrs = self.summary[scenario].columns.values
131 | # short_hdrs = [x.split('_')[0] for x in hdrs]
132 | short_hdrs = [x for x in hdrs]
133 | length = len(hdrs)
134 | if max_cols == 0 or max_cols > length:
135 | max_cols = length
136 | start = 0
137 | end = min(max_cols, length)
138 | num_splits = ceil(length / max_cols)
139 | res += "\n\\begin{table}\n\\centering\n\\caption{" + scenario.capitalize() + "}\n"
140 | for x in range(num_splits):
141 | res += self.summary[scenario].to_latex(float_format="%.1f",
142 | columns=hdrs[start:end],
143 | header=short_hdrs[start:end])
144 | if x < num_splits - 1:
145 | res += "\\vspace{-.6mm}\\\\\n"
146 |
147 | start += max_cols
148 | if x == num_splits-2:
149 | end = length
150 | else:
151 | end += max_cols
152 |
153 | res += "\\end{table}\n"
154 |
155 | res += ('-' * 30) + 'END LATEX' + ('-' * 30)
156 | self.log.info(res)
157 |
158 |
159 | def plot_summary(self, file_name=None):
160 | import matplotlib.pyplot as plt
161 | import matplotlib.ticker as mticker
162 | import seaborn as sns
163 |
164 | sns.set_style("whitegrid")
165 | g = sns.FacetGrid(data=self.results, col='scenario', hue='method',
166 |                           legend_out=True, height=4, aspect=1.33)
167 | g.map(sns.lineplot, 'task', 'value', marker='o')
168 | g.add_legend()
169 |
170 | for axes in g.axes.flat:
171 | ticks_loc = axes.get_xticks()
172 | axes.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
173 | axes.set_xticklabels(axes.get_xticklabels(), rotation=90)
174 |
175 | # shorten x axis labels by cutting anything after an underscore
176 | # tasks_short = [x.get_text().split('_')[0] for x in axes.get_xticklabels()]
177 | # axes.set_xticklabels(tasks_short)
178 |
179 | axes.tick_params(labelleft=True)
180 | axes.set_xlabel('Task')
181 | axes.set_ylabel(self.metric)
182 |
183 | path = f'results/{file_name}' if file_name else 'results/plot_results.png'
184 | g.tight_layout()
185 | plt.savefig(path)
186 | # plt.show(block=True)
187 |
188 |
189 |
--------------------------------------------------------------------------------
/utils/autoanchor.py:
--------------------------------------------------------------------------------
1 | # Auto-anchor utils
2 |
3 | import numpy as np
4 | import torch
5 | import yaml
6 | from tqdm import tqdm
7 |
8 | from utils.general import colorstr
9 |
10 |
11 | def check_anchor_order(m):
12 | # Check anchor order against stride order for YOLOv3 Detect() module m, and correct if necessary
13 | a = m.anchor_grid.prod(-1).view(-1) # anchor area
14 | da = a[-1] - a[0] # delta a
15 | ds = m.stride[-1] - m.stride[0] # delta s
16 |     if da.sign() != ds.sign():  # anchor order and stride order disagree; reverse anchors
17 | print('Reversing anchor order')
18 | m.anchors[:] = m.anchors.flip(0)
19 | m.anchor_grid[:] = m.anchor_grid.flip(0)
20 |
21 |
22 | def check_anchors(dataset, model, thr=4.0, imgsz=640):
23 | # Check anchor fit to data, recompute if necessary
24 | prefix = colorstr('autoanchor: ')
25 | print(f'\n{prefix}Analyzing anchors... ', end='')
26 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
27 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
28 | scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
29 | wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
30 |
31 | def metric(k): # compute metric
32 | r = wh[:, None] / k[None]
33 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric
34 | best = x.max(1)[0] # best_x
35 | aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold
36 | bpr = (best > 1. / thr).float().mean() # best possible recall
37 | return bpr, aat
38 |
39 | anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors
40 | bpr, aat = metric(anchors)
41 | print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
42 | if bpr < 0.98: # threshold to recompute
43 | print('. Attempting to improve anchors, please wait...')
44 | na = m.anchor_grid.numel() // 2 # number of anchors
45 | try:
46 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
47 | except Exception as e:
48 | print(f'{prefix}ERROR: {e}')
49 | new_bpr = metric(anchors)[0]
50 | if new_bpr > bpr: # replace anchors
51 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
52 | m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference
53 | m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
54 | check_anchor_order(m)
55 | print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
56 | else:
57 | print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
58 | print('') # newline
59 |
60 |
61 | def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
62 | """ Creates kmeans-evolved anchors from training dataset
63 | Arguments:
64 | path: path to dataset *.yaml, or a loaded dataset
65 | n: number of anchors
66 | img_size: image size used for training
67 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
68 | gen: generations to evolve anchors using genetic algorithm
69 | verbose: print all results
70 | Return:
71 | k: kmeans evolved anchors
72 | Usage:
73 | from utils.autoanchor import *; _ = kmean_anchors()
74 | """
75 | from scipy.cluster.vq import kmeans
76 |
77 | thr = 1. / thr
78 | prefix = colorstr('autoanchor: ')
79 |
80 | def metric(k, wh): # compute metrics
81 | r = wh[:, None] / k[None]
82 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric
83 | # x = wh_iou(wh, torch.tensor(k)) # iou metric
84 | return x, x.max(1)[0] # x, best_x
85 |
86 | def anchor_fitness(k): # mutation fitness
87 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
88 | return (best * (best > thr).float()).mean() # fitness
89 |
90 | def print_results(k):
91 | k = k[np.argsort(k.prod(1))] # sort small to large
92 | x, best = metric(k, wh0)
93 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
94 | print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
95 | print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
96 | f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
97 | for i, x in enumerate(k):
98 | print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
99 | return k
100 |
101 | if isinstance(path, str): # *.yaml file
102 | with open(path) as f:
103 | data_dict = yaml.safe_load(f) # model dict
104 | from utils.datasets import LoadImagesAndLabels
105 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
106 | else:
107 | dataset = path # dataset
108 |
109 | # Get label wh
110 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
111 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
112 |
113 | # Filter
114 | i = (wh0 < 3.0).any(1).sum()
115 | if i:
116 | print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
117 |     wh = wh0[(wh0 >= 2.0).any(1)]  # keep labels with wh >= 2 pixels
118 | # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
119 |
120 | # Kmeans calculation
121 | print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
122 | s = wh.std(0) # sigmas for whitening
123 | k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
124 |     assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
125 | k *= s
126 | wh = torch.tensor(wh, dtype=torch.float32) # filtered
127 | wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
128 | k = print_results(k)
129 |
130 | # Plot
131 | # k, d = [None] * 20, [None] * 20
132 | # for i in tqdm(range(1, 21)):
133 | # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
134 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
135 | # ax = ax.ravel()
136 | # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
137 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
138 | # ax[0].hist(wh[wh[:, 0]<100, 0],400)
139 | # ax[1].hist(wh[wh[:, 1]<100, 1],400)
140 | # fig.savefig('wh.png', dpi=200)
141 |
142 | # Evolve
143 | npr = np.random
144 |     f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor shape, mutation prob, sigma
145 | pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
146 | for _ in pbar:
147 | v = np.ones(sh)
148 | while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
149 | v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
150 | kg = (k.copy() * v).clip(min=2.0)
151 | fg = anchor_fitness(kg)
152 | if fg > f:
153 | f, k = fg, kg.copy()
154 | pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
155 | if verbose:
156 | print_results(k)
157 |
158 | return print_results(k)
--------------------------------------------------------------------------------
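
Both check_anchors and kmean_anchors above score anchors with the symmetric ratio metric min(r, 1/r) over label-to-anchor width/height ratios, so oversized and undersized anchors are penalized equally. A minimal sketch of that metric on toy tensors (values are illustrative only):

import torch

wh = torch.tensor([[10., 12.], [30., 28.], [90., 75.]])  # label widths/heights
k = torch.tensor([[12., 14.], [80., 70.]])                # two candidate anchors
thr = 4.0                                                 # hyp['anchor_t']

r = wh[:, None] / k[None]               # (labels, anchors, 2) wh ratios
x = torch.min(r, 1. / r).min(2)[0]      # symmetric ratio metric per label/anchor pair
best = x.max(1)[0]                      # best anchor per label
bpr = (best > 1. / thr).float().mean()  # best possible recall
print(f'BPR = {bpr:.4f}')
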
/init.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 | from os.path import exists, join, realpath, split
4 |
5 | import torchvision.models as tv_models
6 | from torch import nn, save
7 | from torchvision import __version__ as torchvision_version
8 |
9 | import globals
10 | import config
11 | from models.experimental import attempt_load
12 | from models.resnet_26 import ResNetCifar
13 | from models.wide_resnet import WideResNet
14 | from utils.data_loader import dataset_checks
15 | from utils.general import check_img_size, increment_path
16 | from utils.torch_utils import select_device
17 | from utils.training import train
18 | from utils.training_yolov3 import train as train_yolo
19 |
20 | log = logging.getLogger('MAIN')
21 |
22 |
23 | def set_paths(args):
24 | args.dataroot = config.PATHS[args.usr][args.dataset]['root']
25 | args.ckpt_path = config.PATHS[args.usr][args.dataset]['ckpt']
26 |
27 |
28 | def init_net(args):
29 | device = select_device(args.device, batch_size=args.batch_size)
30 |
31 | if args.group_norm == 0:
32 | norm_layer = nn.BatchNorm2d
33 | else:
34 | def gn_helper(planes):
35 | return nn.GroupNorm(args.group_norm, planes)
36 | norm_layer = gn_helper
37 |
38 |     def get_heads_classification(self):
39 |         # return the last module yielded by self.modules(),
40 |         # i.e. the network's final (classification head) layer
41 |         return list(self.modules())[-1]
42 |
43 | if args.model == 'wrn':
44 | net = WideResNet(widen_factor=2, depth=40, num_classes=10)
45 | WideResNet.get_heads = get_heads_classification
46 |
47 | elif args.model == 'res26':
48 | net = ResNetCifar(args.depth, args.width, channels=3, classes=10,
49 | norm_layer=norm_layer)
50 | ResNetCifar.get_heads = get_heads_classification
51 |
52 | elif args.model == 'res18':
53 | num_classes = 200 if args.dataset == 'tiny-imagenet' else 1000
54 | # if no checkpoint provided start from the pretrained one
55 | if not args.ckpt_path:
56 |                 if torchvision_version.startswith(('0.11', '0.12')):  # pre-0.13 torchvision uses the pretrained flag
57 | net = tv_models.resnet18(pretrained=True, norm_layer=norm_layer, num_classes=num_classes)
58 | else:
59 | net = tv_models.resnet18(weights='DEFAULT', norm_layer=norm_layer, num_classes=num_classes)
60 | else:
61 | net = tv_models.resnet18(norm_layer=norm_layer, num_classes=num_classes)
62 | tv_models.resnet.ResNet.get_heads = get_heads_classification
63 |
64 | elif args.model == 'yolov3':
65 | if hasattr(args, 'orig_ckpt_path'):
66 | args.ckpt_path = args.orig_ckpt_path
67 | if exists(args.ckpt_path):
68 | args.orig_ckpt_path = args.ckpt_path
69 | net = attempt_load(args.ckpt_path, map_location=device)
70 | args.gs = max(int(net.stride.max()), 32)
71 | args.img_size = [check_img_size(x, args.gs) for x in args.img_size]
72 | else:
73 | net = init_yolov3(args, device)
74 | args.gs = max(int(net.stride.max()), 32)
75 | args.img_size = [check_img_size(x, args.gs) for x in args.img_size]
76 | train_initial(args, net)
77 | save(net.state_dict(), 'yolo_kitti_state_dict_ckpt.pt')
78 | args.ckpt_path = join(split(realpath(__file__))[0], 'yolo_kitti_state_dict_ckpt.pt')
79 |
80 | else:
81 |         raise ValueError(f'Invalid model argument: {args.model}')
82 |
83 | net = net.to(device)
84 | setattr(net.__class__, 'bn_stats', {})
85 | return net
86 |
87 |
88 | def init_yolov3(args, device):
89 | import torch
90 |
91 | from models.yolo import Model
92 | from utils.google_utils import attempt_download
93 | from utils.torch_utils import intersect_dicts, torch_distributed_zero_first
94 |
95 | log.info('Loading yolov3.pt weights.')
96 | hyp = args.yolo_hyp()
97 | with torch_distributed_zero_first(args.global_rank):
98 | attempt_download('yolov3.pt') # download if not found locally
99 | ckpt = torch.load('yolov3.pt', map_location=device) # load checkpoint
100 | if hyp.get('anchors'):
101 | ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
102 | net = Model(args.cfg or ckpt['model'].yaml, ch=3, nc=args.nc).to(device) # create
103 | exclude = ['anchor'] if args.cfg or hyp.get('anchors') else [] # exclude keys
104 | state_dict = ckpt['model'].float().state_dict() # to FP32
105 | state_dict = intersect_dicts(state_dict, net.state_dict(), exclude=exclude) # intersect
106 | net.load_state_dict(state_dict, strict=False) # load
107 | net.to(device)
108 | return net
109 |
110 |
111 | def train_initial(args, net):
112 | args.epochs = 350
113 | log.info('Checkpoint trained on initial task not found - Starting training.')
114 | args.task = 'initial'
115 | save_dir_path = join('checkpoints', args.dataset, args.model, 'initial')
116 |
117 | if args.model == 'yolov3':
118 | device = select_device(args.device, batch_size=args.batch_size)
119 | args.save_dir = save_dir_path
120 | train_yolo(args.yolo_hyp(), args, device, model=net)
121 | args.ckpt_path = join(split(realpath(__file__))[0], save_dir_path, 'weights', 'best.pt')
122 | else:
123 | save_file_name = f'{args.dataset}_initial.pt'
124 | path = join(save_dir_path, save_file_name)
125 | train(net, args, path)
126 | args.ckpt_path = join(split(realpath(__file__))[0], path)
127 | log.info(f'Checkpoint trained on initial task saved at {args.ckpt_path}')
128 |
129 |
130 | def init_settings(args):
131 | args.methods = [x.lower() for x in args.methods]
132 | os.makedirs('results', exist_ok=True)
133 | if args.dataset == 'kitti':
134 | if not args.model:
135 | args.model = 'yolov3'
136 | if args.tasks:
137 | globals.TASKS = args.tasks
138 | else:
139 | globals.TASKS = config.KITTI_TASKS
140 | args.num_severities = max([len(args.fog_severities),
141 | len(args.rain_severities),
142 | len(args.snow_severities)])
143 | globals.KITTI_SEVERITIES['fog'] = args.fog_severities
144 | globals.KITTI_SEVERITIES['rain'] = args.rain_severities
145 | globals.KITTI_SEVERITIES['snow'] = args.snow_severities
146 |
147 | # set args.yolo_hyp to a function returning a copy of globals.YOLO_HYP
148 | # as some values get changed during training, which would lead to
149 | # false values if multiple training sessions are started within one
150 | # execution of the script
151 | def get_yolo_hyp():
152 | return config.YOLO_HYP.copy()
153 | config.YOLO_HYP['lr0'] = args.lr
154 | args.yolo_hyp = get_yolo_hyp
155 |
156 | # opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
157 | # opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
158 | args.world_size = 1
159 | args.global_rank = -1
160 |
161 | args.img_size.extend([args.img_size[-1]] * (2 - len(args.img_size))) # extend to 2 sizes (train, test)
162 | args.total_batch_size = args.batch_size
163 | args.nc = 8
164 | args.names = ['Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting',
165 | 'Cyclist', 'Tram', 'Misc']
166 | else:
167 | if args.tasks:
168 | globals.TASKS = args.tasks
169 | else:
170 | globals.TASKS = config.ROBUSTNESS_TASKS
171 | if args.dataset in ['imagenet', 'imagenet-mini']:
172 | from utils.datasets import ImgNet
173 | ImgNet.initial_dir = args.dataset
174 | args.num_severities = len(args.robustness_severities)
175 | args.severity = None
176 | config.ROBUSTNESS_SEVERITIES = args.robustness_severities
177 | if args.dataset == 'cifar10' and not args.model:
178 | args.model = 'wrn'
179 | elif args.dataset == 'cifar10' and args.model == 'res26':
180 | args.model = 'res26'
181 | elif args.dataset in ['imagenet', 'imagenet-mini'] and not args.model:
182 | args.model = 'res18'
183 |
184 |
185 | def initial_checks(net, args):
186 | log.info('Running initial checks.')
187 | dataset_checks(args)
188 | if not args.ckpt_path or not exists(args.ckpt_path):
189 | train_initial(args, net)
190 |
191 |
--------------------------------------------------------------------------------
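
The get_yolo_hyp indirection set up in init_settings above exists because YOLO training mutates its hyperparameter dict in place, so handing every training session the same dict would leak values between sessions within one run. A minimal sketch of the pitfall and the copy-based fix (hypothetical hyperparameter values, standing in for config.YOLO_HYP):

YOLO_HYP = {'lr0': 0.01, 'momentum': 0.937}  # stand-in for config.YOLO_HYP

def train_session(hyp):
    hyp['lr0'] *= 0.1  # training code scales values in place

def get_yolo_hyp():
    return YOLO_HYP.copy()  # fresh (shallow) copy per session, as in init_settings

train_session(get_yolo_hyp())
train_session(get_yolo_hyp())
print(YOLO_HYP['lr0'])   # 0.01: neither session leaked its changes

train_session(YOLO_HYP)  # passing the shared dict directly is the bug
print(YOLO_HYP['lr0'])   # 0.001
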
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
13 |
14 |
15 |
114 |
115 |
116 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | # Loss functions
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utils.general import bbox_iou
7 | from utils.torch_utils import is_parallel
8 |
9 |
10 | def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
11 | # return positive, negative label smoothing BCE targets
12 | return 1.0 - 0.5 * eps, 0.5 * eps
13 |
14 |
15 | class BCEBlurWithLogitsLoss(nn.Module):
16 |     # BCEWithLogitsLoss with reduced missing-label effects.
17 | def __init__(self, alpha=0.05):
18 | super(BCEBlurWithLogitsLoss, self).__init__()
19 | self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
20 | self.alpha = alpha
21 |
22 | def forward(self, pred, true):
23 | loss = self.loss_fcn(pred, true)
24 | pred = torch.sigmoid(pred) # prob from logits
25 | dx = pred - true # reduce only missing label effects
26 | # dx = (pred - true).abs() # reduce missing label and false label effects
27 | alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
28 | loss *= alpha_factor
29 | return loss.mean()
30 |
31 |
32 | class FocalLoss(nn.Module):
33 |     # Wraps focal loss around an existing loss_fcn(), e.g. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
34 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
35 | super(FocalLoss, self).__init__()
36 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
37 | self.gamma = gamma
38 | self.alpha = alpha
39 | self.reduction = loss_fcn.reduction
40 | self.loss_fcn.reduction = 'none' # required to apply FL to each element
41 |
42 | def forward(self, pred, true):
43 | loss = self.loss_fcn(pred, true)
44 | # p_t = torch.exp(-loss)
45 | # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
46 |
47 | # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
48 | pred_prob = torch.sigmoid(pred) # prob from logits
49 | p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
50 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
51 | modulating_factor = (1.0 - p_t) ** self.gamma
52 | loss *= alpha_factor * modulating_factor
53 |
54 | if self.reduction == 'mean':
55 | return loss.mean()
56 | elif self.reduction == 'sum':
57 | return loss.sum()
58 | else: # 'none'
59 | return loss
60 |
61 |
62 | class QFocalLoss(nn.Module):
63 |     # Wraps quality focal loss around an existing loss_fcn(), e.g. criteria = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
64 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
65 | super(QFocalLoss, self).__init__()
66 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
67 | self.gamma = gamma
68 | self.alpha = alpha
69 | self.reduction = loss_fcn.reduction
70 | self.loss_fcn.reduction = 'none' # required to apply FL to each element
71 |
72 | def forward(self, pred, true):
73 | loss = self.loss_fcn(pred, true)
74 |
75 | pred_prob = torch.sigmoid(pred) # prob from logits
76 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
77 | modulating_factor = torch.abs(true - pred_prob) ** self.gamma
78 | loss *= alpha_factor * modulating_factor
79 |
80 | if self.reduction == 'mean':
81 | return loss.mean()
82 | elif self.reduction == 'sum':
83 | return loss.sum()
84 | else: # 'none'
85 | return loss
86 |
87 |
88 | def compute_loss(p, targets, model): # predictions, targets, model
89 | device = targets.device
90 | lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
91 | tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
92 | h = model.hyp # hyperparameters
93 |
94 | # Define criteria
95 | BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) # weight=model.class_weights)
96 | BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
97 |
98 | # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
99 | cp, cn = smooth_BCE(eps=0.0)
100 |
101 | # Focal loss
102 | g = h['fl_gamma'] # focal loss gamma
103 | if g > 0:
104 | BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
105 |
106 | # Losses
107 | balance = [4.0, 1.0, 0.4, 0.1] # P3-P6
108 | for i, pi in enumerate(p): # layer index, layer predictions
109 | b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
110 | tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
111 |
112 | n = b.shape[0] # number of targets
113 | if n:
114 | ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
115 |
116 | # Regression
117 | pxy = ps[:, :2].sigmoid() * 2. - 0.5
118 | pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
119 | pbox = torch.cat((pxy, pwh), 1) # predicted box
120 | iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
121 | lbox += (1.0 - iou).mean() # iou loss
122 |
123 | # Objectness
124 | tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
125 |
126 | # Classification
127 | if model.nc > 1: # cls loss (only if multiple classes)
128 | t = torch.full_like(ps[:, 5:], cn, device=device) # targets
129 | t[range(n), tcls[i]] = cp
130 | lcls += BCEcls(ps[:, 5:], t) # BCE
131 |
132 | # Append targets to text file
133 | # with open('targets.txt', 'a') as file:
134 | # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
135 |
136 | lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
137 |
138 | lbox *= h['box']
139 | lobj *= h['obj']
140 | lcls *= h['cls']
141 | bs = tobj.shape[0] # batch size
142 |
143 | loss = lbox + lobj + lcls
144 | return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
145 |
146 |
147 | def build_targets(p, targets, model):
148 | # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
149 | det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
150 | na, nt = det.na, targets.shape[0] # number of anchors, targets
151 | tcls, tbox, indices, anch = [], [], [], []
152 | gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
153 | ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
154 | targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
155 |
156 | g = 0.5 # bias
157 | off = torch.tensor([[0, 0],
158 | # [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
159 | # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
160 | ], device=targets.device).float() * g # offsets
161 |
162 | for i in range(det.nl):
163 | anchors = det.anchors[i]
164 | gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
165 |
166 | # Match targets to anchors
167 | t = targets * gain
168 | if nt:
169 | # Matches
170 | r = t[:, :, 4:6] / anchors[:, None] # wh ratio
171 | j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
172 | # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
173 | t = t[j] # filter
174 |
175 | # Offsets
176 | gxy = t[:, 2:4] # grid xy
177 | gxi = gain[[2, 3]] - gxy # inverse
178 | j, k = ((gxy % 1. < g) & (gxy > 1.)).T
179 | l, m = ((gxi % 1. < g) & (gxi > 1.)).T
180 | j = torch.stack((torch.ones_like(j),))
181 | t = t.repeat((off.shape[0], 1, 1))[j]
182 | offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
183 | else:
184 | t = targets[0]
185 | offsets = 0
186 |
187 | # Define
188 | b, c = t[:, :2].long().T # image, class
189 | gxy = t[:, 2:4] # grid xy
190 | gwh = t[:, 4:6] # grid wh
191 | gij = (gxy - offsets).long()
192 | gi, gj = gij.T # grid xy indices
193 |
194 | # Append
195 | a = t[:, 6].long() # anchor indices
196 | indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
197 | tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
198 | anch.append(anchors[a]) # anchors
199 | tcls.append(c) # class
200 |
201 | return tcls, tbox, indices, anch
--------------------------------------------------------------------------------
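
As compute_loss above shows, FocalLoss wraps an existing BCE criterion and is only switched on when hyp['fl_gamma'] > 0. A minimal sketch of that wrapping on dummy logits (assumes this repo's root is on the Python path):

import torch
import torch.nn as nn

from utils.loss import FocalLoss

bce = nn.BCEWithLogitsLoss()       # reduction='mean', as in compute_loss
focal = FocalLoss(bce, gamma=1.5)  # gamma = hyp['fl_gamma']

pred = torch.randn(4, 8)                    # dummy logits: batch of 4, 8 classes
true = torch.randint(0, 2, (4, 8)).float()  # dummy binary targets
print(focal(pred, true))                    # scalar mean focal-modulated BCE
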
/utils/testing_yolov3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | from pathlib import Path
6 | from threading import Thread
7 |
8 | import numpy as np
9 | import torch
10 | import yaml
11 | from tqdm import tqdm
12 |
13 | from models.experimental import attempt_load
14 | from utils.general import (box_iou, check_dataset, check_file, check_img_size,
15 | check_requirements, colorstr, increment_path,
16 | non_max_suppression, scale_coords, set_logging,
17 | xywh2xyxy, xyxy2xywh)
18 | from utils.loss import compute_loss
19 | from utils.metrics import ConfusionMatrix, ap_per_class
20 | from utils.plots import output_to_target, plot_images, plot_study_txt
21 | from utils.torch_utils import select_device, time_synchronized
22 |
23 | logger = logging.getLogger('TESTING')
24 | logger_fileonly = logging.getLogger('TESTING.FILEONLY')
25 |
26 | @torch.no_grad()
27 | def test(batch_size=32,
28 | imgsz=1216,
29 | conf_thres=0.001,
30 | iou_thres=0.6, # for NMS
31 | single_cls=False,
32 | augment=False,
33 | verbose=False,
34 | model=None,
35 | dataloader=None,
36 | save_dir=Path(''), # for saving images
37 | save_txt=False, # for auto-labelling
38 | save_hybrid=False, # for hybrid auto-labelling
39 | plots=False,
40 |          half_precision=False,
41 |          nc=8,
42 |          training=False,
43 |          multi_label=False):
44 |
45 | assert model and dataloader, 'Model and Loader need to be passed to yolov3 test'
46 |
47 | device = next(model.parameters()).device # get model device
48 | iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
49 | niou = iouv.numel()
50 | nc = 1 if single_cls else nc # number of classes
51 |
52 | # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
53 | # if device.type != 'cpu' and torch.cuda.device_count() > 1:
54 | # model = nn.DataParallel(model)
55 |
56 | half = device.type != 'cpu' and half_precision # half precision only supported on CUDA
57 | if half:
58 | model.half()
59 | model.eval()
60 |
61 | seen = 0
62 | confusion_matrix = ConfusionMatrix(nc=nc)
63 | names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
64 | s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
65 | p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
66 | loss = torch.zeros(3, device=device)
67 | jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
68 | for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
69 | img = img.to(device, non_blocking=True)
70 | img = img.half() if half else img.float() # uint8 to fp16/32
71 | img /= 255.0 # 0 - 255 to 0.0 - 1.0
72 | targets = targets.to(device)
73 | nb, _, height, width = img.shape # batch size, channels, height, width
74 |
75 |         with torch.no_grad():  # redundant with the @torch.no_grad() decorator above, kept for safety
76 | # Run model
77 | t = time_synchronized()
78 | inf_out, train_out = model(img, augment=augment) # inference and training outputs
79 | t0 += time_synchronized() - t
80 |
81 | # Compute loss
82 | if training:
83 | loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls
84 |
85 | # Run NMS
86 | targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
87 | lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
88 | t = time_synchronized()
89 | output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb,
90 | multi_label=multi_label, agnostic=single_cls)
91 | t1 += time_synchronized() - t
92 |
93 |
94 | # Statistics per image
95 | for si, pred in enumerate(output):
96 | labels = targets[targets[:, 0] == si, 1:]
97 | nl = len(labels)
98 | tcls = labels[:, 0].tolist() if nl else [] # target class
99 | path = Path(paths[si])
100 | seen += 1
101 |
102 | if len(pred) == 0:
103 | if nl:
104 | stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
105 | continue
106 |
107 | # Predictions
108 | if single_cls:
109 | pred[:, 5] = 0
110 | predn = pred.clone()
111 | scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred
112 |
113 | # Assign all predictions as incorrect
114 | correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
115 | if nl:
116 | detected = [] # target indices
117 | tcls_tensor = labels[:, 0]
118 |
119 | # target boxes
120 | tbox = xywh2xyxy(labels[:, 1:5])
121 | scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
122 | if plots:
123 | confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))
124 |
125 | # Per target class
126 | for cls in torch.unique(tcls_tensor):
127 |                     ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
128 |                     pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices
129 |
130 | # Search for detections
131 | if pi.shape[0]:
132 | # Prediction to target ious
133 | ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1) # best ious, indices
134 |
135 | # Append detections
136 | detected_set = set()
137 | for j in (ious > iouv[0]).nonzero(as_tuple=False):
138 | d = ti[i[j]] # detected target
139 | if d.item() not in detected_set:
140 | detected_set.add(d.item())
141 | detected.append(d)
142 | correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn
143 | if len(detected) == nl: # all targets already located in image
144 | break
145 |
146 | # Append statistics (correct, conf, pcls, tcls)
147 | stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))
148 |
149 | # Plot images
150 | if plots and batch_i < 3:
151 | f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
152 | Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
153 | f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
154 | Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
155 |
156 | # Compute statistics
157 | stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
158 | if len(stats) and stats[0].any():
159 | p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
160 | ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
161 | mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
162 | nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
163 | else:
164 | nt = torch.zeros(1)
165 |
166 | # Print results
167 | pf = '%20s' + '%12.3g' * 6 # print format
168 | logger_fileonly.info(s)
169 | logger.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
170 |
171 | # Print results per class
172 | if (verbose or (nc <= 20 and not training)) and nc > 1 and len(stats):
173 | for i, c in enumerate(ap_class):
174 | logger.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
175 |
176 | # Print speeds
177 | t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (batch_size,) # tuple
178 | if not training:
179 | logger.info('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per image at batch-size %g' % t)
180 |
181 | # Plots
182 | if plots:
183 | confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
184 |
185 | # Return results
186 | if not training:
187 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
188 | print(f"Results saved to {save_dir}{s}")
189 | model.float() # for training
190 | maps = np.zeros(nc) + map
191 | for i, c in enumerate(ap_class):
192 | maps[c] = ap[i]
193 | return map50, (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t, ap50
194 |
195 |
--------------------------------------------------------------------------------
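
The per-detection 'correct' tensor built above is boolean over the 10 IoU thresholds in iouv, which is what lets ap_per_class produce mAP@0.5:0.95 in a single pass. A toy sketch of that broadcast (illustrative IoU values):

import torch

iouv = torch.linspace(0.5, 0.95, 10)     # IoU thresholds, as in test()
ious = torch.tensor([0.62, 0.48, 0.91])  # best IoU for three matched detections

correct = ious[:, None] > iouv[None]     # (detections, thresholds) boolean matrix
print(correct.float().mean(0))           # fraction counted correct per threshold
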
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Model validation metrics
2 |
3 | from pathlib import Path
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 |
9 | from . import general
10 |
11 |
12 | def fitness(x):
13 | # Model fitness as a weighted combination of metrics
14 | w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
15 | return (x[:, :4] * w).sum(1)
16 |
17 |
18 | def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
19 | """ Compute the average precision, given the recall and precision curves.
20 | Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
21 | # Arguments
22 | tp: True positives (nparray, nx1 or nx10).
23 | conf: Objectness value from 0-1 (nparray).
24 | pred_cls: Predicted object classes (nparray).
25 | target_cls: True object classes (nparray).
26 | plot: Plot precision-recall curve at mAP@0.5
27 | save_dir: Plot save directory
28 | # Returns
29 | The average precision as computed in py-faster-rcnn.
30 | """
31 |
32 | # Sort by objectness
33 | i = np.argsort(-conf)
34 | tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
35 |
36 | # Find unique classes
37 | unique_classes = np.unique(target_cls)
38 |     nc = unique_classes.shape[0]  # number of classes
39 |
40 | # Create Precision-Recall curve and compute AP for each class
41 | px, py = np.linspace(0, 1, 1000), [] # for plotting
42 | ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
43 | for ci, c in enumerate(unique_classes):
44 | i = pred_cls == c
45 | n_l = (target_cls == c).sum() # number of labels
46 | n_p = i.sum() # number of predictions
47 |
48 | if n_p == 0 or n_l == 0:
49 | continue
50 | else:
51 | # Accumulate FPs and TPs
52 | fpc = (1 - tp[i]).cumsum(0)
53 | tpc = tp[i].cumsum(0)
54 |
55 | # Recall
56 | recall = tpc / (n_l + 1e-16) # recall curve
57 | r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
58 |
59 | # Precision
60 | precision = tpc / (tpc + fpc) # precision curve
61 | p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
62 |
63 | # AP from recall-precision curve
64 | for j in range(tp.shape[1]):
65 | ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
66 | if plot and j == 0:
67 | py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
68 |
69 | # Compute F1 (harmonic mean of precision and recall)
70 | f1 = 2 * p * r / (p + r + 1e-16)
71 | if plot:
72 | plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
73 | plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
74 | plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
75 | plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
76 |
77 | i = f1.mean(0).argmax() # max F1 index
78 | return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
79 |
80 |
81 | def compute_ap(recall, precision):
82 | """ Compute the average precision, given the recall and precision curves
83 | # Arguments
84 | recall: The recall curve (list)
85 | precision: The precision curve (list)
86 | # Returns
87 | Average precision, precision curve, recall curve
88 | """
89 |
90 | # Append sentinel values to beginning and end
91 | mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
92 | mpre = np.concatenate(([1.], precision, [0.]))
93 |
94 | # Compute the precision envelope
95 | mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
96 |
97 | # Integrate area under curve
98 | method = 'interp' # methods: 'continuous', 'interp'
99 | if method == 'interp':
100 | x = np.linspace(0, 1, 101) # 101-point interp (COCO)
101 | ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
102 | else: # 'continuous'
103 | i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
104 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
105 |
106 | return ap, mpre, mrec
107 |
108 |
109 | class ConfusionMatrix:
110 | # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
111 | def __init__(self, nc, conf=0.25, iou_thres=0.45):
112 | self.matrix = np.zeros((nc + 1, nc + 1))
113 | self.nc = nc # number of classes
114 | self.conf = conf
115 | self.iou_thres = iou_thres
116 |
117 | def process_batch(self, detections, labels):
118 | """
119 | Return intersection-over-union (Jaccard index) of boxes.
120 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
121 | Arguments:
122 | detections (Array[N, 6]), x1, y1, x2, y2, conf, class
123 | labels (Array[M, 5]), class, x1, y1, x2, y2
124 | Returns:
125 | None, updates confusion matrix accordingly
126 | """
127 | detections = detections[detections[:, 4] > self.conf]
128 | gt_classes = labels[:, 0].int()
129 | detection_classes = detections[:, 5].int()
130 | iou = general.box_iou(labels[:, 1:], detections[:, :4])
131 |
132 | x = torch.where(iou > self.iou_thres)
133 | if x[0].shape[0]:
134 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().detach().numpy()
135 | if x[0].shape[0] > 1:
136 | matches = matches[matches[:, 2].argsort()[::-1]]
137 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
138 | matches = matches[matches[:, 2].argsort()[::-1]]
139 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
140 | else:
141 | matches = np.zeros((0, 3))
142 |
143 | n = matches.shape[0] > 0
144 | m0, m1, _ = matches.transpose().astype(np.int16)
145 | for i, gc in enumerate(gt_classes):
146 | j = m0 == i
147 | if n and sum(j) == 1:
148 | self.matrix[detection_classes[m1[j]], gc] += 1 # correct
149 | else:
150 |                 self.matrix[self.nc, gc] += 1  # background FN (ground truth with no matching detection)
151 |
152 | if n:
153 | for i, dc in enumerate(detection_classes):
154 | if not any(m1 == i):
155 |                     self.matrix[dc, self.nc] += 1  # background FP (detection with no matching ground truth)
156 |
157 |     def get_matrix(self):
158 |         return self.matrix  # an instance method named matrix would be shadowed by the attribute
159 |
160 | def plot(self, save_dir='', names=()):
161 | try:
162 | import seaborn as sn
163 |
164 | array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
165 | array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
166 |
167 | fig = plt.figure(figsize=(12, 9), tight_layout=True)
168 | sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
169 | labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
170 | sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
171 | xticklabels=names + ['background FP'] if labels else "auto",
172 | yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
173 | fig.axes[0].set_xlabel('True')
174 | fig.axes[0].set_ylabel('Predicted')
175 | fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
176 |         except Exception as e:
177 |             print(f'WARNING: ConfusionMatrix plot failure: {e}')
178 |
179 | def print(self):
180 | for i in range(self.nc + 1):
181 | print(' '.join(map(str, self.matrix[i])))
182 |
183 |
184 | # Plots ----------------------------------------------------------------------------------------------------------------
185 |
186 | def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
187 | # Precision-recall curve
188 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
189 | py = np.stack(py, axis=1)
190 |
191 | if 0 < len(names) < 21: # display per-class legend if < 21 classes
192 | for i, y in enumerate(py.T):
193 | ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
194 | else:
195 | ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
196 |
197 | ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
198 | ax.set_xlabel('Recall')
199 | ax.set_ylabel('Precision')
200 | ax.set_xlim(0, 1)
201 | ax.set_ylim(0, 1)
202 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
203 | fig.savefig(Path(save_dir), dpi=250)
204 |
205 |
206 | def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
207 | # Metric-confidence curve
208 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
209 |
210 | if 0 < len(names) < 21: # display per-class legend if < 21 classes
211 | for i, y in enumerate(py):
212 | ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
213 | else:
214 | ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
215 |
216 | y = py.mean(0)
217 | ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
218 | ax.set_xlabel(xlabel)
219 | ax.set_ylabel(ylabel)
220 | ax.set_xlim(0, 1)
221 | ax.set_ylim(0, 1)
222 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
223 | fig.savefig(Path(save_dir), dpi=250)
--------------------------------------------------------------------------------
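
compute_ap above integrates the monotone precision envelope over recall; the default 101-point 'interp' method matches COCO. A quick runnable check on a toy PR curve (assumes this repo's root is on the Python path):

import numpy as np

from utils.metrics import compute_ap

recall = np.array([0.1, 0.4, 0.7, 0.9])
precision = np.array([1.0, 0.8, 0.6, 0.5])

ap, mpre, mrec = compute_ap(recall, precision)
print(f'AP = {ap:.3f}')  # area under the interpolated precision envelope
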
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging.config
3 | import sys
4 | import time
5 | from argparse import Namespace
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | from datetime import datetime
10 | import methods
11 | import config
12 | from init import init_net, init_settings, initial_checks, set_paths
13 | from utils.results_manager import ResultsManager
14 | from utils.utils import timedelta_to_str, setup_log_folder
15 |
16 |
17 | def main(args):
18 | if args.kitti_to_yolo_labels:
19 | from utils.utils import kitti_labels_to_yolo
20 | kitti_labels_to_yolo(args.kitti_to_yolo_labels)
21 |         sys.exit()
22 |
23 | cudnn.benchmark = True
24 | start_time = datetime.now()
25 |
26 | log.info('------------------------------------ NEW RUN ------------------------------------')
27 | log.info(f'Running: {" ".join(sys.argv)}')
28 | log.info('Full args list:')
29 | for arg in vars(args):
30 | log.info(f'{arg}: {getattr(args, arg)}')
31 | log.info('---------------------------------------------------------------------------------')
32 |
33 | results = ResultsManager('mAP@50' if args.dataset == 'kitti' else 'Error')
34 |
35 | init_settings(args)
36 | if args.usr:
37 | set_paths(args)
38 |
39 | for run in range(args.num_runs):
40 | net = init_net(args)
41 | for args.severity_idx in range(args.num_severities):
42 | if 'dua' in args.methods:
43 | methods.dua(args, net)
44 |
45 | # log results
46 | if results.has_results():
47 | timestamp_str = time.strftime('%b-%d-%Y_%H%M', time.localtime())
48 | results.save_to_file(file_name=f'{timestamp_str}_raw_results.pkl')
49 | results.print_summary_latex()
50 |
51 | if args.num_runs > 1:
52 | results.reset_results()
53 | log.info(f'{">" * 50} FINISHED RUN #{run + 1} {"<" * 50}')
54 | runtime = datetime.now() - start_time
55 | log.info(f'Runtime so far: {timedelta_to_str(runtime)}')
56 | torch.cuda.empty_cache()
57 | del net
58 |
59 | if args.num_runs > 1:
60 | results.print_multiple_runs_results()
61 |
62 | runtime = datetime.now() - start_time
63 | log.info(f'Execution finished in {timedelta_to_str(runtime)}')
64 |
65 |
66 | # Log uncaught exceptions that aren't keyboard interrupts
67 | def handle_exception(exception_type, value, traceback):
68 | if issubclass(exception_type, KeyboardInterrupt):
69 | sys.__excepthook__(exception_type, value, traceback)
70 | return
71 |     log.exception('Exception occurred:', exc_info=(exception_type, value, traceback))
72 |
73 |
74 | sys.excepthook = handle_exception
75 |
76 | if __name__ == '__main__':
77 | parser = argparse.ArgumentParser()
78 |
79 | parser.add_argument('--usr', default=None, type=str)
80 | parser.add_argument('--dataroot', default='path/to/dataroot')
81 | parser.add_argument('--ckpt_path', default='path/to/checkpoint.pt')
82 | parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'kitti', 'imagenet-mini', 'imagenet'])
83 | parser.add_argument('--model', default=None, type=str, choices=['wrn', 'res26', 'res18', 'yolov3'])
84 | parser.add_argument('--logfolder', default='logs', type=str)
85 |
86 | # General run settings
87 | parser.add_argument('--tasks', default=[], type=str, nargs='*',
88 | help='List of tasks to run (in given order), empty means defaults from config.py')
89 | parser.add_argument('--scenario', default=['online', 'offline'], type=str, nargs='*',
90 | help='Scenarios to run (online and/or offline)')
91 | parser.add_argument('--robustness_severities', default=['5'], type=str, nargs='*')
92 | parser.add_argument('--fog_severities', default=['fog_30'], type=str, nargs='*')
93 | parser.add_argument('--rain_severities', default=['200mm'], type=str, nargs='*')
94 | parser.add_argument('--snow_severities', default=['5'], type=str, nargs='*')
95 | parser.add_argument('--checkpoints_path', default='checkpoints', help='path where model checkpoints will be saved')
96 | parser.add_argument('--num_runs', default=1, type=int)
97 | parser.add_argument('--methods', default=['dua'], type=str, nargs='*',
98 | choices=['dua'],
99 | help='List of methods to run')
100 |
101 | # DUA/DISC adaption
102 | parser.add_argument('--num_samples', default=80, type=int)
103 | parser.add_argument('--decay_factor', default=0.94, type=float)
104 | parser.add_argument('--min_mom', default=0.005, type=float)
105 | parser.add_argument('--no_disc_adaption', action='store_true',
106 | help='skip DISC adaption phase (assumes existing BN running estimates checkpoint)')
107 |
108 | # Learning & Loading
109 |     parser.add_argument('--lr', default=0.01, type=float, help='Learning rate for everything except the initial task (see --initial_task_lr)')
110 | parser.add_argument('--initial_task_lr', default=0.01, type=float)
111 | parser.add_argument('--epochs', default=150, type=int)
112 | parser.add_argument('--batch_size', default=8, type=int)
113 | parser.add_argument('--workers', type=int, default=1, help='maximum number of dataloader workers')
114 | parser.add_argument('--yolo_lr_adjustment', type=str, default='thirds',
115 | choices=['thirds', 'linear_lr', 'cosine'],
116 | help='how yolov3 training reduces learning rate')
117 |
118 | # LR scheduler and early stopping
119 |     # for yolov3 these settings only apply with yolo_lr_adjustment set to 'thirds',
120 | # in which case the reduction by a factor of 3 can also be changed by setting
121 | # lr_factor to a different value
122 | parser.add_argument('--patience', default=4, type=int)
123 | parser.add_argument('--lr_factor', default=1 / 3, type=float)
124 |     parser.add_argument('--verbose', default=True, type=lambda x: str(x).lower() == 'true')  # type=bool would treat any non-empty string as True
125 | parser.add_argument('--max_unsuccessful_reductions', default=3, type=int)
126 |
127 | # For creating a val/test set from train set for CIFAR/ImageNet
128 | parser.add_argument('--split_ratio', default=0.35, type=float)
129 | parser.add_argument('--split_seed', default=42, type=int)
130 |
131 | # ResNet
132 | parser.add_argument('--depth', default=26, type=int)
133 | parser.add_argument('--width', default=1, type=int)
134 | parser.add_argument('--group_norm', default=0, type=int)
135 | parser.add_argument('--rotation_type', default='rand')
136 |
137 | # yolov3
138 | parser.add_argument('--weights', type=str, default='yolov3.pt', help='initial weights path')
139 | parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
140 | parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
141 | parser.add_argument('--rect', action='store_true', help='rectangular training')
142 | parser.add_argument('--device', default='1', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
143 | parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
144 | parser.add_argument('--start_disjoint_offline_from_initial', action='store_true',
145 | help='start offline disjoint training from checkpoint trained on initial task')
146 | parser.add_argument('--use_freezing_heads_ckpts', action='store_true',
147 | help='Use freezing baseline heads from a previous run. '
148 | 'Without this option previously saved heads are moved.')
149 | parser.add_argument('--conf_thres', type=float, default=0.001, help='object confidence threshold')
150 | parser.add_argument('--iou_thres', type=float, default=0.6, help='IOU threshold for NMS')
151 | parser.add_argument('--augment', default=False, action='store_true', help='augmented inference')
152 | # yolov3 untested
153 | parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
154 | parser.add_argument('--notest', action='store_true', help='only test final epoch')
155 | parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
156 | parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
157 | parser.add_argument('--cache_images', action='store_true', help='cache images for faster training')
158 | parser.add_argument('--image_weights', action='store_true', help='use weighted image selection for training')
159 | parser.add_argument('--multi_scale', action='store_true', help='vary img-size +/- 50%%')
160 | parser.add_argument('--single_cls', action='store_true', help='train multi-class data as single-class')
161 | parser.add_argument('--sync_bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
162 | parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
163 | parser.add_argument('--log_imgs', type=int, default=16, help='number of images for W&B logging, max 100')
164 | parser.add_argument('--log_artifacts', action='store_true', help='log artifacts, i.e. final trained model')
165 | parser.add_argument('--project', default='runs/train', help='save to project/name')
166 | parser.add_argument('--name', default='exp', help='save to project/name')
167 | parser.add_argument('--exist_ok', action='store_true', help='existing project/name ok, do not increment')
168 | parser.add_argument('--quad', action='store_true', help='quad dataloader')
169 |
170 | # other
171 | parser.add_argument('--kitti_to_yolo_labels', default=None, type=str,
172 | help='Generate YOLO style labels from KITTI labels, given original KITTI root dir')
173 |
174 | args: Namespace = parser.parse_args()
175 | setup_log_folder(args)
176 |
177 | config.LOGGER_CFG['handlers']['file_handler']['filename'] = args.logfile
178 |
179 | logging.config.dictConfig(config.LOGGER_CFG)
180 |
181 | log = logging.getLogger('MAIN')
182 |
183 | main(args)
184 |
--------------------------------------------------------------------------------
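
A typical invocation, using only flags defined by the parser above; the dataset root and checkpoint paths are placeholders:

python main.py --dataset cifar10 --model wrn --methods dua \
    --dataroot /path/to/dataroot --ckpt_path /path/to/checkpoint.pt --device 0
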
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | import torch
4 | from pathlib import Path
5 | import time
6 |
7 | # from colorama import Fore
8 |
9 |
10 | def get_grad(params):
11 | if isinstance(params, torch.Tensor):
12 | params = [params]
13 | params = list(filter(lambda p: p.grad is not None, params))
14 | grad = [p.grad.data.cpu().view(-1) for p in params]
15 | return torch.cat(grad)
16 |
17 |
18 | def write_to_txt(name, content):
19 | with open(name, 'w') as text_file:
20 | text_file.write(content)
21 |
22 |
23 | def make_dirs(path):
24 | os.makedirs(path, exist_ok=True)
25 |
26 |
27 | def print_args(opt):
28 | for arg in vars(opt):
29 | print('%s %s' % (arg, getattr(opt, arg)))
30 |
31 |
32 | def mean(ls):
33 | return sum(ls) / len(ls)
34 |
35 |
36 | def normalize(v):
37 | return (v - v.mean()) / v.std()
38 |
39 |
40 | def flat_grad(grad_tuple):
41 | return torch.cat([p.view(-1) for p in grad_tuple])
42 |
43 |
44 | def print_nparams(model):
45 | nparams = sum([param.nelement() for param in model.parameters()])
46 | print('number of parameters: %d' % (nparams))
47 |
48 |
49 | def plot_adaptation_err(all_err_cls, corr, args):
50 | import matplotlib.pyplot as plt
51 | plt.switch_backend('agg')
52 | fig, _ = plt.subplots()
53 |
54 | plt.plot(all_err_cls, color='r', label=corr)
55 | plt.xlabel('Number of Samples for Adaptation')
56 | plt.ylabel('Test Error (%)')
57 | plt.legend()
58 | plt.savefig(os.path.join(args.outf, corr), format="png")
59 | plt.close(fig)
60 |
61 |
62 | def eval_yolo_ckpts(net, args, scenario, baseline_str, ckpts=None):
63 | """
64 |     Evaluate yolov3 checkpoints from previous runs.
65 |
66 | Example usage:
67 | args.severity_idx = 0
68 | ckpts = { ...ckpts to evaluate... }
69 | for bl in ['disjoint', 'fine_tuning', 'joint_training']:
70 | for scenario in ['online', 'offline']:
71 | eval_yolo_ckpts(net, args, scenario, bl, ckpts)
72 | """
73 | import logging
74 | from utils.results_manager import ResultsManager
75 | from os.path import join
76 | from utils.torch_utils import select_device
77 | from utils.data_loader import get_loader, set_severity
78 | from utils.testing_yolov3 import test as test_yolo
79 | from statistics import mean
80 | import globals
81 |
82 | log = logging.getLogger('MAIN')
83 |
84 | if not ckpts:
85 | # yolov3 training results directories, by default settings found at:
86 | # checkpoints/kitti/yolov3/ ...
87 | ckpts = {
88 | 'disjoint': {
89 | 'online': {
90 | 'fog': 'fog_fog_30_train_results',
91 | 'rain': 'rain_200mm_train_results',
92 | 'snow': 'snow_5_train_results'
93 | },
94 | 'offline': {
95 | 'fog': 'fog_fog_30_train_results',
96 | 'rain': 'rain_200mm_train_results',
97 | 'snow': 'snow_5_train_results'
98 | }
99 | },
100 | 'freezing': {
101 | 'online': {
102 | 'fog': 'fog_fog_30_train_results',
103 | 'rain': 'rain_200mm_train_results',
104 | 'snow': 'snow_5_train_results'
105 | },
106 | 'offline': {
107 | 'fog': 'fog_fog_30_train_results',
108 | 'rain': 'rain_200mm_train_results',
109 | 'snow': 'snow_5_train_results'
110 | }
111 | },
112 | 'fine_tuning': {
113 | 'online': {
114 | 'fog': 'fog_fog_30_train_results',
115 | 'rain': 'rain_200mm_train_results',
116 | 'snow': 'snow_5_train_results'
117 | },
118 | 'offline': {
119 | 'fog': 'fog_fog_30_train_results',
120 | 'rain': 'rain_200mm_train_results',
121 | 'snow': 'snow_5_train_results'
122 | }
123 | },
124 | 'joint_training': {
125 | 'online': {
126 | 'fog': 'fog_fog_30_train_results',
127 | 'rain': 'rain_200mm_train_results',
128 | 'snow': 'snow_5_train_results'
129 | },
130 | 'offline': {
131 | 'fog': 'fog_fog_30_train_results',
132 | 'rain': 'rain_200mm_train_results',
133 | 'snow': 'snow_5_train_results'
134 | }
135 | },
136 | }
137 |
138 | args.severity_idx = 0
139 | tasks = ['initial'] + globals.TASKS
140 | results = ResultsManager('mAP@50')
141 | device = select_device(args.device, batch_size=args.batch_size)
142 |
143 | log.info(f'::: Running ckpt evaluations for baseline {baseline_str} ({scenario}) :::')
144 | for idx, args.task in enumerate(tasks):
145 | ckpt_folder = join(args.checkpoints_path, args.dataset, args.model, baseline_str, scenario)
146 | if args.task == 'initial':
147 | continue
148 | current_results = []
149 | if not set_severity(args):
150 | continue
151 | severity_str = '' if args.task == 'initial' else f'Severity: {args.severity}'
152 | log.info(f'Start evaluation for Task-{idx} ({args.task}). {severity_str}')
153 |
154 | # load ckpt
155 | ckpt_folder = join(ckpt_folder, ckpts[baseline_str][scenario][args.task], 'weights')
156 | ckpt_path = join(ckpt_folder, 'best.pt')
157 | log.info(f'Loading: {ckpt_path}')
158 | ckpt = torch.load(ckpt_path, map_location=device) # load checkpoint
159 | state_dict = ckpt['model'].float().state_dict() # to FP32
160 | net.load_state_dict(state_dict) # load
161 |
162 | for i in range(0, idx + 1):
163 | args.task = tasks[i]
164 | if not set_severity(args):
165 | continue
166 |
167 | test_loader = get_loader(args, split='test', pad=0.5, rect=True)
168 | res = test_yolo(model=net, dataloader=test_loader,
169 | iou_thres=args.iou_thres, conf_thres=args.conf_thres,
170 | augment=args.augment)[0] * 100
171 |
172 | current_results.append(res)
173 | log.info(f'\tmAP@50 on Task-{i} ({tasks[i]}): {res:.1f}')
174 |
175 | if i == idx:
176 | mean_result = mean(current_results)
177 | log.info(f'\tMean mAP@50 over current task ({tasks[idx]}) '
178 | f'and previously seen tasks: {mean_result:.1f}')
179 | severity_str = '' if args.task == 'initial' else f'{args.severity}'
180 | results.add_result(baseline_str, f'{tasks[idx]} {severity_str}', mean_result, scenario)
181 |
182 |
183 | def timedelta_to_str(timedelta, explicit_days=False):
184 | s = ''
185 | if explicit_days:
186 | s = f'{timedelta.days} Days, {timedelta.seconds // 3600:02}:'
187 | else:
188 | total_hrs = timedelta.days * 24 + timedelta.seconds // 3600
189 | s = f'{str(total_hrs).zfill(2 if total_hrs < 100 else 3)}:'
190 | s += f'{(timedelta.seconds % 3600) // 60:02}:{timedelta.seconds % 60:02}'
191 | return s
192 |
193 |
194 | def setup_tiny_imagenet_val_dir(val_dir_path, val_num_imgs=10000, rm_initial=False):
195 | """
196 | Tiny ImageNet validation set comes with 10k images from all 200 classes
197 | placed in the same folder (images) and a val_annotations.txt pointing
198 | out which image belongs to which class.
199 |     This method moves each image into an 'images' folder inside a folder
200 |     named after the class it belongs to.
201 | """
202 | import glob
203 | from os.path import exists, join, split
204 | from shutil import copy, move
205 |
206 | from tqdm import tqdm
207 |
208 | val_dict = {}
209 | with open(f'{val_dir_path}/val_annotations.txt', 'r') as f:
210 | for line in f.readlines():
211 | split_line = line.split('\t')
212 | val_dict[split_line[0]] = split_line[1]
213 |
214 | paths = glob.iglob(join(val_dir_path, 'images', '*'))
215 | for path in tqdm(paths, total=val_num_imgs):
216 | file = split(path)[1]
217 | folder = val_dict[file]
218 |         if not exists(join(val_dir_path, str(folder))):
219 | make_dirs(join(val_dir_path, str(folder), 'images'))
220 | # copy(path, join(val_dir_path, str(folder), 'images', str(file)))
221 | move(path, join(val_dir_path, str(folder), 'images', str(file)))
222 |
223 | if rm_initial:
224 | os.rmdir(join(val_dir_path, 'images'))
225 | os.remove(join(val_dir_path, 'val_annotations.txt'))
226 |
227 |
228 | def setup_log_folder(args):
229 | Path(args.logfolder).mkdir(exist_ok=True, parents=True)
230 | args.logfile = args.logfolder + f'/{time.strftime("%Y%m%d_%H%M%S")}.txt'
231 |
232 |
233 | def kitti_labels_to_yolo(dataroot):
234 | from cv2 import imread
235 |
236 | print('Converting KITTI labels to YOLO label format.')
237 |
238 | imgs_dir = join(dataroot, 'raw', 'training', 'image_2')
239 | labels_dir = join(dataroot, 'raw', 'training', 'label_2')
240 | save_at_dir = join(dataroot, 'raw', 'yolo_style_labels')
241 | make_dirs(save_at_dir)
242 |
243 | class_names = ['Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting',
244 | 'Cyclist', 'Tram', 'Misc']
245 | img_file_names = sorted(os.listdir(imgs_dir))
246 | label_file_names = sorted(os.listdir(labels_dir))
247 | label_dict = dict(zip(class_names, range(len(class_names))))
248 |
249 | for img_file_name, label_file_name in zip(img_file_names, label_file_names):
250 | img_path = join(imgs_dir, img_file_name)
251 | label_path = join(labels_dir, label_file_name)
252 |
253 | img = imread(img_path)
254 |         img_height, img_width = img.shape[:2]
255 |
256 | with open(label_path, 'r') as f:
257 | label_lines = f.readlines()
258 |
259 | yolo_label_file = open(join(save_at_dir, label_file_name), 'w')
260 |
261 | for line in label_lines:
262 | label_entry = line.split(' ')
263 | if len(label_entry) != 15:
264 | raise Exception(f'Faulty original label in: {label_file_name}')
265 |
266 | class_name = label_entry[0]
267 | if class_name == 'DontCare':
268 | continue
269 |
270 | x1 = float(label_entry[4]) # left
271 | y1 = float(label_entry[5]) # top
272 | x2 = float(label_entry[6]) # right
273 | y2 = float(label_entry[7]) # bottom
274 |
275 | bbox_center_x = (x1 + x2) / 2.0 / img_width
276 | bbox_center_y = (y1 + y2) / 2.0 / img_height
277 | bbox_width = float((x2 - x1) / img_width)
278 | bbox_height = float((y2 - y1) / img_height)
279 |
280 | yolo_label_line = f'{label_dict[class_name]} {bbox_center_x} ' \
281 | f'{bbox_center_y} {bbox_width} {bbox_height}\n'
282 | yolo_label_file.write(yolo_label_line)
283 | yolo_label_file.close()
284 |
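# Worked example (a sketch; the numbers are illustrative, not taken from the
# dataset): for a 1242x375 image, a KITTI label line such as
#   'Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 ...'
# becomes (class id 0 for 'Car', values rounded to 4 decimals)
#   '0 0.4835 0.4979 0.0218 0.0714'
# e.g. bbox_center_x = ((587.01 + 614.12) / 2) / 1242 ~= 0.4835.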
--------------------------------------------------------------------------------
/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | # YOLOv3 PyTorch utils
2 |
3 | import datetime
4 | import logging
5 | import math
6 | import os
7 | import platform
8 | import subprocess
9 | import time
10 | from contextlib import contextmanager
11 | from copy import deepcopy
12 | from pathlib import Path
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torchvision
19 |
20 | try:
21 | import thop # for FLOPS computation
22 | except ImportError:
23 | thop = None
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | @contextmanager
28 | def torch_distributed_zero_first(local_rank: int):
29 | """
30 |     Context manager that makes all processes in distributed training wait for the local master to finish a task first.
31 | """
32 | if local_rank not in [-1, 0]:
33 | torch.distributed.barrier()
34 | yield
35 | if local_rank == 0:
36 | torch.distributed.barrier()
37 |
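# Typical usage (a sketch, assuming torch.distributed has been initialized
# and prepare_dataset is a hypothetical helper):
#   with torch_distributed_zero_first(local_rank):
#       prepare_dataset()  # rank 0 runs first; all other ranks wait at the barrier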
38 |
39 | def init_torch_seeds(seed=0):
40 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
41 | torch.manual_seed(seed)
42 | if seed == 0: # slower, more reproducible
43 | cudnn.benchmark, cudnn.deterministic = False, True
44 | else: # faster, less reproducible
45 | cudnn.benchmark, cudnn.deterministic = True, False
46 |
47 |
48 | def date_modified(path=__file__):
49 | # return human-readable file modification date, i.e. '2021-3-26'
50 | t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
51 | return f'{t.year}-{t.month}-{t.day}'
52 |
53 |
54 | def git_describe(path=Path(__file__).parent): # path must be a directory
55 | # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
56 | s = f'git -C {path} describe --tags --long --always'
57 | try:
58 | return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
59 |     except subprocess.CalledProcessError:
60 | return '' # not a git repository
61 |
62 |
63 | def select_device(device='', batch_size=None):
64 | # device = 'cpu' or '0' or '0,1,2,3'
65 | s = f'YOLOv3 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
66 | cpu = device.lower() == 'cpu'
67 | if cpu:
68 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
69 | elif device: # non-cpu device requested
70 | os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
71 | assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
72 |
73 | cuda = not cpu and torch.cuda.is_available()
74 | if cuda:
75 | devices = device.split(',') if device else range(torch.cuda.device_count()) # i.e. 0,1,6,7
76 | n = len(devices) # device count
77 | if n > 1 and batch_size: # check batch_size is divisible by device_count
78 | assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
79 | space = ' ' * len(s)
80 | for i, d in enumerate(devices):
81 | p = torch.cuda.get_device_properties(i)
82 | s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
83 | else:
84 | s += 'CPU\n'
85 |
86 | logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
87 | return torch.device('cuda:0' if cuda else 'cpu')
88 |
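# Example calls (illustrative):
#   select_device('cpu')                -> torch.device('cpu')
#   select_device('0,1', batch_size=16) -> torch.device('cuda:0'),
#       after asserting 16 is divisible by the 2 requested GPUs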
89 |
90 | def time_synchronized():
91 | # pytorch-accurate time
92 | if torch.cuda.is_available():
93 | torch.cuda.synchronize()
94 | return time.time()
95 |
96 |
97 | def profile(x, ops, n=100, device=None):
98 | # profile a pytorch module or list of modules. Example usage:
99 | # x = torch.randn(16, 3, 640, 640) # input
100 | # m1 = lambda x: x * torch.sigmoid(x)
101 | # m2 = nn.SiLU()
102 | # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
103 |
104 | device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
105 | x = x.to(device)
106 | x.requires_grad = True
107 | print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
108 | print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
109 | for m in ops if isinstance(ops, list) else [ops]:
110 | m = m.to(device) if hasattr(m, 'to') else m # device
111 | m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
112 | dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
113 | try:
114 | flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
115 |         except Exception:  # thop unavailable or profiling failed
116 | flops = 0
117 |
118 | for _ in range(n):
119 | t[0] = time_synchronized()
120 | y = m(x)
121 | t[1] = time_synchronized()
122 | try:
123 | _ = y.sum().backward()
124 | t[2] = time_synchronized()
125 |             except Exception:  # no backward method
126 | t[2] = float('nan')
127 | dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
128 | dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
129 |
130 | s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
131 | s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
132 | p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
133 | print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
134 |
135 |
136 | def is_parallel(model):
137 | # Returns True if model is of type DP or DDP
138 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
139 |
140 |
141 | def de_parallel(model):
142 | # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
143 | return model.module if is_parallel(model) else model
144 |
145 |
146 | def intersect_dicts(da, db, exclude=()):
147 | # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
148 | return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
149 |
150 |
151 | def initialize_weights(model):
152 | for m in model.modules():
153 | t = type(m)
154 | if t is nn.Conv2d:
155 | pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
156 | elif t is nn.BatchNorm2d:
157 | m.eps = 1e-3
158 | m.momentum = 0.03
159 | elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
160 | m.inplace = True
161 |
162 |
163 | def find_modules(model, mclass=nn.Conv2d):
164 | # Finds layer indices matching module class 'mclass'
165 | return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
166 |
167 |
168 | def sparsity(model):
169 | # Return global model sparsity
170 | a, b = 0., 0.
171 | for p in model.parameters():
172 | a += p.numel()
173 | b += (p == 0).sum()
174 | return b / a
175 |
176 |
177 | def prune(model, amount=0.3):
178 | # Prune model to requested global sparsity
179 | import torch.nn.utils.prune as prune
180 | print('Pruning model... ', end='')
181 | for name, m in model.named_modules():
182 | if isinstance(m, nn.Conv2d):
183 | prune.l1_unstructured(m, name='weight', amount=amount) # prune
184 | prune.remove(m, 'weight') # make permanent
185 | print(' %.3g global sparsity' % sparsity(model))
186 |
187 |
188 | def fuse_conv_and_bn(conv, bn):
189 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
190 | fusedconv = nn.Conv2d(conv.in_channels,
191 | conv.out_channels,
192 | kernel_size=conv.kernel_size,
193 | stride=conv.stride,
194 | padding=conv.padding,
195 | groups=conv.groups,
196 | bias=True).requires_grad_(False).to(conv.weight.device)
197 |
198 | # prepare filters
199 | w_conv = conv.weight.clone().view(conv.out_channels, -1)
200 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
201 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
202 |
203 | # prepare spatial bias
204 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
205 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
206 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
207 |
208 | return fusedconv
209 |
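# Sketch of the algebra behind the fusion above:
#   bn(conv(x)) = gamma * (W x + b - mu) / sqrt(var + eps) + beta
# With the per-channel scale s = gamma / sqrt(var + eps) this is
#   W_fused = diag(s) @ W    and    b_fused = s * (b - mu) + beta,
# which is exactly what the two torch.mm calls above compute.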
210 |
211 | def model_info(model, verbose=False, img_size=640):
212 | # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
213 | n_p = sum(x.numel() for x in model.parameters()) # number parameters
214 | n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
215 | if verbose:
216 | print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
217 | for i, (name, p) in enumerate(model.named_parameters()):
218 | name = name.replace('module_list.', '')
219 | print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
220 | (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
221 |
222 | try: # FLOPS
223 | from thop import profile
224 | stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
225 | img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
226 | flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
227 | img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
228 | fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
229 |     except Exception:  # thop unavailable or profiling failed
230 | fs = ''
231 |
232 | logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
233 |
234 |
235 | def load_classifier(name='resnet101', n=2):
236 | # Loads a pretrained model reshaped to n-class output
237 | model = torchvision.models.__dict__[name](pretrained=True)
238 |
239 | # ResNet model properties
240 | # input_size = [3, 224, 224]
241 | # input_space = 'RGB'
242 | # input_range = [0, 1]
243 | # mean = [0.485, 0.456, 0.406]
244 | # std = [0.229, 0.224, 0.225]
245 |
246 | # Reshape output to n classes
247 | filters = model.fc.weight.shape[1]
248 | model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
249 | model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
250 | model.fc.out_features = n
251 | return model
252 |
253 |
254 | def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
255 | # scales img(bs,3,y,x) by ratio constrained to gs-multiple
256 | if ratio == 1.0:
257 | return img
258 | else:
259 | h, w = img.shape[2:]
260 | s = (int(h * ratio), int(w * ratio)) # new size
261 | img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
262 | if not same_shape: # pad/crop img
263 | h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
264 | return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
265 |
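# Worked example (illustrative): img of shape (16, 3, 256, 416) with
# ratio=0.5 and gs=32 is resized to (128, 208); the gs-multiple target is
# (128, 224) since ceil(208 / 32) * 32 = 224, so 16 columns of padding with
# the ImageNet-mean value 0.447 are appended on the right.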
266 |
267 | def copy_attr(a, b, include=(), exclude=()):
268 | # Copy attributes from b to a, options to only include [...] and to exclude [...]
269 | for k, v in b.__dict__.items():
270 | if (len(include) and k not in include) or k.startswith('_') or k in exclude:
271 | continue
272 | else:
273 | setattr(a, k, v)
274 |
275 |
276 | class ModelEMA:
277 | """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
278 | Keep a moving average of everything in the model state_dict (parameters and buffers).
279 | This is intended to allow functionality like
280 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
281 | A smoothed version of the weights is necessary for some training schemes to perform well.
282 |     This class is sensitive to where it is initialized in the sequence of
283 |     model init, GPU assignment and distributed training wrappers.
284 | """
285 |
286 | def __init__(self, model, decay=0.9999, updates=0):
287 | # Create EMA
288 | self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
289 | # if next(model.parameters()).device.type != 'cpu':
290 | # self.ema.half() # FP16 EMA
291 | self.updates = updates # number of EMA updates
292 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
293 | for p in self.ema.parameters():
294 | p.requires_grad_(False)
295 |
296 | def update(self, model):
297 | # Update EMA parameters
298 | with torch.no_grad():
299 | self.updates += 1
300 | d = self.decay(self.updates)
301 |
302 | msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
303 | for k, v in self.ema.state_dict().items():
304 | if v.dtype.is_floating_point:
305 | v *= d
306 | v += (1. - d) * msd[k].detach()
307 |
308 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
309 | # Update EMA attributes
310 | copy_attr(self.ema, model, include, exclude)
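
# The decay ramp above warms up smoothly (values are approximate):
#   updates=100   -> decay(100)   ~= 0.049
#   updates=2000  -> decay(2000)  ~= 0.632
#   updates=20000 -> decay(20000) ~= 0.9999
# so early on the EMA tracks the raw weights closely and only approaches
# the nominal 0.9999 decay after many updates.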
--------------------------------------------------------------------------------
/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | from os.path import exists, join, normpath
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch import manual_seed, randperm
11 | from torch.utils.data import ConcatDataset, DataLoader, Subset
12 | from torchvision.datasets import CIFAR10
13 |
14 | import globals
15 | import config
16 | from utils.datasets import CIFAR, ImgNet
17 | from utils.datasets import LoadImagesAndLabels as Kitti
18 | from utils.general import check_img_size, increment_path
19 | from utils.torch_utils import torch_distributed_zero_first
20 |
21 | log = logging.getLogger('MAIN.DATA')
22 |
23 | NORM = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
24 |
25 | te_transforms = transforms.Compose([transforms.ToTensor(),
26 | transforms.Normalize(*NORM)
27 | ])
28 |
29 | tr_transforms = transforms.Compose([transforms.RandomCrop(32, padding=4),
30 | transforms.RandomHorizontalFlip(),
31 | transforms.ToTensor(),
32 | transforms.Normalize(*NORM)
33 | ])
34 |
35 | NORM_IMGNET = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
36 |
37 | tr_transforms_imgnet = transforms.Compose([transforms.RandomResizedCrop(224),
38 | transforms.RandomHorizontalFlip(),
39 | transforms.ToTensor(),
40 | transforms.Normalize(*NORM_IMGNET)
41 | ])
42 |
43 | te_transforms_imgnet = transforms.Compose([transforms.Resize(256),
44 | transforms.CenterCrop(224),
45 | transforms.ToTensor(),
46 | transforms.Normalize(*NORM_IMGNET)
47 | ])
48 |
49 |
50 | def get_loader(args, split='train', joint=False, shuffle=True, pad=0.0, aug=False, rect=False):
51 | """
52 | Create the loader for the specified split (train/val/test) and
53 | current task (args.task).
54 | If joint=True the dataset will be created from the current task and
55 | all tasks which came before it (in globals.TASKS), combined.
56 | Parameters: padding (pad), augment (aug) and rectangular training (rect)
57 | only apply to the yolov3 model and are ignored for other models.
58 |     YOLOv3 rectangular training (rect) is incompatible with dataloader
59 |     shuffling, so shuffle is silently set to False whenever that
60 |     combination of parameters is passed to this function.
61 | """
62 | if args.model == 'yolov3' and rect and shuffle:
63 | shuffle = False
64 |
65 | if args.dataset == 'kitti' or not joint:
66 | # Create loader for joint or non-joint KITTI dataset, as well as
67 | # non-joint loaders for other datasets
68 | ds = get_dataset(args, split=split, pad=pad, aug=aug, rect=rect, joint=joint)
69 | collate_fn = Kitti.collate_fn if args.dataset == 'kitti' else None
70 | loader = DataLoader if args.model != 'yolov3' or args.image_weights else InfiniteDataLoader
71 | rank = args.global_rank if args.model == 'yolov3' and split == 'train' else -1
72 | return loader(ds, batch_size=args.batch_size, shuffle=shuffle,
73 | num_workers=args.workers, collate_fn=collate_fn,
74 | pin_memory=True)
75 | else:
76 | # Create joint loaders for datasets other than KITTI
77 | datasets = []
78 | current_task = args.task
79 | for args.task in ['initial'] + globals.TASKS:
80 | datasets.append(get_dataset(args, split=split))
81 | if current_task == args.task:
82 | break
83 | return DataLoader(ConcatDataset(datasets),
84 | batch_size=args.batch_size,
85 | shuffle=True,
86 | num_workers=args.workers)
87 |
88 |
89 | def get_dataset(args, split=None, pad=0.0, aug=False, rect=False, joint=False):
90 | """
91 | Create dataset based on args and split
92 | Parameters: padding (pad), augment (aug), rectangular training (rect) and
93 | joint training (joint) only apply to the yolov3 model and are ignored
94 | for other models.
95 | """
96 | if not hasattr(args, 'task'):
97 | args.task = 'initial'
98 | if args.task not in ['initial'] + globals.TASKS:
99 | raise Exception(f'Invalid task: {args.task}')
100 |
101 | if args.dataset == 'cifar10':
102 | transform = tr_transforms if split == 'train' else te_transforms
103 | ds = CIFAR(args.dataroot, args.task, split=split, transform=transform,
104 | severity=int(args.severity))
105 | if split != 'test':
106 | # train and val split are being created from the train set
107 | ds = get_split_subset(args, ds, split)
108 |
109 | elif args.dataset in ['imagenet', 'imagenet-mini']:
110 | transform = tr_transforms_imgnet if split == 'train' else te_transforms_imgnet
111 |
112 | ds = ImgNet(args.dataroot, split, args.task, args.severity, transform)
113 | if split != 'val':
114 | # train and test split are being created from the train set
115 | ds = get_split_subset(args, ds, split)
116 |
117 | elif args.dataset == 'kitti':
118 | path = join(args.dataroot, f'{split}.txt')
119 | img_size_idx = split != 'train'
120 | img_size = check_img_size(img_size=args.img_size[img_size_idx], s=args.gs)
121 | img_dirs_paths = []
122 | if joint:
123 | # put paths to all tasks image directories into img_dirs_paths
124 | for t in ['initial'] + globals.TASKS:
125 | if t != 'initial':
126 |                 if args.severity_idx < len(globals.KITTI_SEVERITIES[t]):
127 | args.severity = globals.KITTI_SEVERITIES[t][args.severity_idx]
128 | else:
129 | continue
130 | img_dir = 'images' if t == 'initial' else f'{args.severity}'
131 | img_dirs_paths.append(join(args.dataroot, f'{t}', img_dir))
132 | if t == args.task:
133 | break
134 | else:
135 | img_dir = 'images' if args.task == 'initial' else f'{args.severity}'
136 | img_dirs_paths.append(join(args.dataroot, f'{args.task}', img_dir))
137 |
138 | with torch_distributed_zero_first(-1):
139 | ds = Kitti(path, img_size, args.batch_size,
140 | augment=aug, hyp=args.yolo_hyp(), rect=rect,
141 | stride=int(args.gs), pad=pad, imgs_dir=img_dirs_paths)
142 | return ds
143 |
144 |
145 | def get_split_subset(args, ds, split):
146 | """
147 | Create a subset of given dataset (ds).
148 |     Specifically defined for CIFAR10 and ImageNet, as they lack a labeled
149 |     validation set or test set respectively, so we create those splits
150 |     here from their train sets.
151 |     args.split_seed defines a seed so a split can be reproduced exactly.
152 |     args.split_ratio defines what fraction of the train set is used as
153 |     the validation/test set (e.g. args.split_ratio = 0.3 for CIFAR10
154 |     means 30% of the train set becomes the validation set and the
155 |     remaining 70% stays the train set).
156 | """
157 | manual_seed(args.split_seed)
158 | indices = randperm(len(ds))
159 | valid_size = round(len(ds) * args.split_ratio)
160 |
161 | if args.dataset == 'cifar10':
162 | if split == 'train':
163 | ds = Subset(ds, indices[:-valid_size])
164 | elif split == 'val':
165 | ds = Subset(ds, indices[-valid_size:])
166 |
167 | elif args.dataset in ['imagenet', 'imagenet-mini']:
168 | if split == 'train':
169 | ds = Subset(ds, indices[:-valid_size])
170 | elif split == 'test':
171 | ds = Subset(ds, indices[-valid_size:])
172 |
173 | return ds
174 |
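# Worked example (illustrative numbers): for CIFAR10 with len(ds) = 50000
# and args.split_ratio = 0.3, valid_size = 15000, so 'train' keeps the first
# 35000 shuffled indices and 'val' the remaining 15000. Reusing the same
# args.split_seed reproduces the exact same permutation.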
175 |
176 | def get_image_from_idx(self, idx: int = 0):
177 | return self.dataset.get_image_from_idx(idx)
178 | Subset.get_image_from_idx = get_image_from_idx  # expose the helper on torch Subset wrappers
179 |
180 |
181 | def set_yolo_save_dir(args, baseline, scenario):
182 | """
183 | Sets args.save_dir which is used in yolov3 training to save results
184 | """
185 | p = join(args.checkpoints_path, args.dataset, args.model, baseline,
186 | scenario, f'{args.task}_{args.severity}_train_results')
187 | args.save_dir = increment_path(Path(p), exist_ok=args.exist_ok)
188 |
189 |
190 | def set_severity(args):
191 | """
192 | Sets args.severity to the current severity and returns True on success.
193 | For the KITTI dataset this will get the appropriate severity for the
194 | current task. In case of different number of severities among tasks,
195 | False is returned if current args.severity_idx does not exist for the
196 | current task.
197 | """
198 | if args.dataset != 'kitti':
199 | args.severity = args.robustness_severities[args.severity_idx]
200 | return True
201 |
202 | if args.task == 'initial':
203 | args.severity = '' # TODO not tested thoroughly
204 | return True
205 |
206 | if args.severity_idx < len(globals.KITTI_SEVERITIES[args.task]):
207 | args.severity = globals.KITTI_SEVERITIES[args.task][args.severity_idx]
208 | return True
209 |
210 | return False
211 |
212 |
213 | def get_all_severities_str(args):
214 | all_severities_str = ''
215 | for task in globals.TASKS:
216 | if args.dataset != 'kitti':
217 | all_severities_str = f'{args.robustness_severities[args.severity_idx]}_'
218 | break
219 | elif args.severity_idx < len(globals.KITTI_SEVERITIES[task]):
220 | all_severities_str += f'{globals.KITTI_SEVERITIES[task][args.severity_idx]}_'
221 | return all_severities_str
222 |
223 |
224 | class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
225 | """ Dataloader that reuses workers
226 | Uses same syntax as vanilla DataLoader
227 | """
228 |
229 | def __init__(self, *args, **kwargs):
230 | super().__init__(*args, **kwargs)
231 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
232 | self.iterator = super().__iter__()
233 |
234 | def __len__(self):
235 | return len(self.batch_sampler.sampler)
236 |
237 | def __iter__(self):
238 | for i in range(len(self)):
239 | yield next(self.iterator)
240 |
241 |
242 | class _RepeatSampler(object):
243 | """ Sampler that repeats forever
244 | Args:
245 | sampler (Sampler)
246 | """
247 |
248 | def __init__(self, sampler):
249 | self.sampler = sampler
250 |
251 | def __iter__(self):
252 | while True:
253 | yield from iter(self.sampler)
254 |
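# Design note: a vanilla DataLoader tears down its worker processes at the
# end of every epoch. Wrapping the batch sampler in _RepeatSampler keeps one
# iterator alive indefinitely, and InfiniteDataLoader.__iter__ just takes one
# epoch's worth of batches from it, so workers are reused across epochs.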
255 |
256 | def dataset_checks(args):
257 |     if args.dataset not in config.VALID_DATASETS:
258 | raise Exception(f'Invalid dataset argument: {args.dataset}')
259 |
260 | error = False
261 | if args.dataset == 'cifar10':
262 | error = check_cifar10_c(args)
263 | elif args.dataset in ['imagenet', 'imagenet-mini']:
264 | error = check_imgnet_c(args)
265 |
266 | if error:
267 | raise Exception('Dataset checks unsuccessful!')
268 | else:
269 | log.info('Dataset checks successful!')
270 |
271 |
272 | def check_cifar10_c(args):
273 | CIFAR10(root=args.dataroot, download=True)
274 | error = False
275 | test_set_path = join(args.dataroot, 'CIFAR-10-C', 'test')
276 | train_set_path = join(args.dataroot, 'CIFAR-10-C', 'train')
277 | if not exists(test_set_path):
278 | error = True
279 | log.error(f'CIFAR-10-C test set not found. Expected at {test_set_path}')
280 | if not exists(train_set_path):
281 | error = True
282 | log.error(f'CIFAR-10-C training set not found. Expected at {train_set_path}')
283 | missing_files = []
284 | for task in globals.TASKS:
285 | test_samples = join(test_set_path, task + '.npy')
286 | train_samples = join(train_set_path, task + '.npy')
287 | if not exists(test_samples):
288 | missing_files.append(test_samples)
289 | if not exists(train_samples):
290 | missing_files[:0] = [train_samples]
291 | if len(missing_files):
292 | error = True
293 | log.error('Missing the following CIFAR-10-C samples:')
294 | for f_path in missing_files:
295 | log.error(normpath(f_path))
296 | return error
297 |
298 |
299 | def check_imgnet_c(args):
300 | error = False
301 | val_set_path = join(args.dataroot, args.dataset + '-c', 'val')
302 | train_set_path = join(args.dataroot, args.dataset + '-c', 'train')
303 |
304 | if not exists(val_set_path):
305 | error = True
306 | log.error(f'{args.dataset.capitalize()} validation set not found. '
307 | f'Expected at {val_set_path}')
308 | if not exists(train_set_path):
309 | error = True
310 | log.error(f'{args.dataset.capitalize()} training set not found. '
311 | f'Expected at {train_set_path}')
312 | missing_dirs = []
313 | for task in globals.TASKS:
314 | for severity in globals.SEVERTITIES:
315 | val_samples_dir = join(val_set_path, task, str(severity))
316 | train_samples_dir = join(train_set_path, task, str(severity))
317 | if not exists(val_samples_dir):
318 | missing_dirs.append(val_samples_dir)
319 | if not exists(train_samples_dir):
320 | missing_dirs[:0] = [train_samples_dir]
321 | if len(missing_dirs):
322 | error = True
323 | log.error(f'Missing the following {args.dataset.capitalize()} directories:')
324 | for f_path in missing_dirs:
325 | log.error(normpath(f_path))
326 | return error
327 |
328 |
--------------------------------------------------------------------------------
/models/yolo.py:
--------------------------------------------------------------------------------
1 | """YOLOv3-specific modules
2 | Usage:
3 | $ python path/to/models/yolo.py --cfg yolov3.yaml
4 | """
5 |
6 | import argparse
7 | import logging
8 | import sys
9 | from copy import deepcopy
10 | from pathlib import Path
11 |
12 | sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
13 | logger = logging.getLogger(__name__)
14 |
15 | from models.common import *
16 | from models.experimental import *
17 | from utils.autoanchor import check_anchor_order
18 | from utils.general import make_divisible, check_file, set_logging
19 | from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
20 | select_device, copy_attr
21 |
22 | import math
23 | from copy import copy
24 | from pathlib import Path
25 |
26 |
27 |
28 |
29 | try:
30 | import thop # for FLOPS computation
31 | except ImportError:
32 | thop = None
33 |
34 |
35 | class Detect(nn.Module):
36 | stride = None # strides computed during build
37 | onnx_dynamic = False # ONNX export parameter
38 |
39 | def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
40 | super(Detect, self).__init__()
41 | self.nc = nc # number of classes
42 | self.no = nc + 5 # number of outputs per anchor
43 | self.nl = len(anchors) # number of detection layers
44 | self.na = len(anchors[0]) // 2 # number of anchors
45 | self.grid = [torch.zeros(1)] * self.nl # init grid
46 | a = torch.tensor(anchors).float().view(self.nl, -1, 2)
47 | self.register_buffer('anchors', a) # shape(nl,na,2)
48 | self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
49 | #print(len(ch))
50 | self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
51 | self.inplace = inplace # use in-place ops (e.g. slice assignment)
52 |
53 | def forward(self, x):
54 | # x = x.copy() # for profiling
55 | z = [] # inference output
56 | for i in range(self.nl):
57 | x[i] = self.m[i](x[i]) # conv
58 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
59 | x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
60 |
61 | if not self.training: # inference
62 | if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
63 | self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
64 |
65 | y = x[i].sigmoid()
66 | if self.inplace:
67 | y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
68 | y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
69 | else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
70 | xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
71 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
72 | y = torch.cat((xy, wh, y[..., 4:]), -1)
73 | z.append(y.view(bs, -1, self.no))
74 |
75 | return x if self.training else (torch.cat(z, 1), x)
76 |
77 | @staticmethod
78 | def _make_grid(nx=20, ny=20):
79 | yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
80 | return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
81 |
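# Shape sketch (illustrative): _make_grid(nx=3, ny=2) returns a tensor of
# shape (1, 1, 2, 3, 2) holding the per-cell (x, y) offsets
#   (0,0) (1,0) (2,0)
#   (0,1) (1,1) (2,1)
# which forward() adds to the sigmoid-decoded xy predictions before scaling
# by the stride.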
82 |
83 | class Model(nn.Module):
84 | def __init__(self, cfg='yolov3.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
85 | super(Model, self).__init__()
86 | if isinstance(cfg, dict):
87 | self.yaml = cfg # model dict
88 | else: # is *.yaml
89 | import yaml # for torch hub
90 | self.yaml_file = Path(cfg).name
91 | with open(cfg) as f:
92 | self.yaml = yaml.safe_load(f) # model dict
93 |
94 | # Define model
95 | ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
96 | if nc and nc != self.yaml['nc']:
97 | logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
98 | self.yaml['nc'] = nc # override yaml value
99 | if anchors:
100 | logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
101 | self.yaml['anchors'] = round(anchors) # override yaml value
102 | self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
103 | self.names = [str(i) for i in range(self.yaml['nc'])] # default names
104 | self.inplace = self.yaml.get('inplace', True)
105 | # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
106 |
107 | # Build strides, anchors
108 | m = self.model[-1] # Detect()
109 | if isinstance(m, Detect):
110 | s = 256 # 2x min stride
111 | m.inplace = self.inplace
112 | m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
113 | m.anchors /= m.stride.view(-1, 1, 1)
114 | check_anchor_order(m)
115 | self.stride = m.stride
116 | self._initialize_biases() # only run once
117 | # logger.info('Strides: %s' % m.stride.tolist())
118 |
119 | # Init weights, biases
120 | initialize_weights(self)
121 | self.info()
122 | logger.info('')
123 |
124 | def forward(self, x, augment=False, profile=False):
125 | if augment:
126 | return self.forward_augment(x) # augmented inference, None
127 | else:
128 | return self.forward_once(x, profile) # single-scale inference, train
129 |
130 | def forward_augment(self, x):
131 | img_size = x.shape[-2:] # height, width
132 | s = [1, 0.83, 0.67] # scales
133 | f = [None, 3, None] # flips (2-ud, 3-lr)
134 | y = [] # outputs
135 | for si, fi in zip(s, f):
136 | xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
137 | yi = self.forward_once(xi)[0] # forward
138 | # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
139 | yi = self._descale_pred(yi, fi, si, img_size)
140 | y.append(yi)
141 | return torch.cat(y, 1), None # augmented inference, train
142 |
143 | def forward_once(self, x, profile=False):
144 | y, dt = [], [] # outputs
145 | for m in self.model:
146 | if m.f != -1: # if not from previous layer
147 | x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
148 |
149 | if profile:
150 | o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
151 | t = time_synchronized()
152 | for _ in range(10):
153 | _ = m(x)
154 | dt.append((time_synchronized() - t) * 100)
155 | if m == self.model[0]:
156 | logger.info(f"{'time (ms)':>10s} {'GFLOPS':>10s} {'params':>10s} {'module'}")
157 | logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
158 |
159 | x = m(x) # run
160 | y.append(x if m.i in self.save else None) # save output
161 |
162 | if profile:
163 | logger.info('%.1fms total' % sum(dt))
164 | #print(x[0].shape)
165 | return x
166 |
167 | def _descale_pred(self, p, flips, scale, img_size):
168 | # de-scale predictions following augmented inference (inverse operation)
169 | if self.inplace:
170 | p[..., :4] /= scale # de-scale
171 | if flips == 2:
172 | p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
173 | elif flips == 3:
174 | p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
175 | else:
176 | x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
177 | if flips == 2:
178 | y = img_size[0] - y # de-flip ud
179 | elif flips == 3:
180 | x = img_size[1] - x # de-flip lr
181 | p = torch.cat((x, y, wh, p[..., 4:]), -1)
182 | return p
183 |
184 | def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
185 | # https://arxiv.org/abs/1708.02002 section 3.3
186 | # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
187 | m = self.model[-1] # Detect() module
188 | for mi, s in zip(m.m, m.stride): # from
189 | b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
190 | b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
191 | b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
192 | mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
193 |
194 | def _print_biases(self):
195 | m = self.model[-1] # Detect() module
196 | for mi in m.m: # from
197 | b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
198 | logger.info(
199 | ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
200 |
201 | # def _print_weights(self):
202 | # for m in self.model.modules():
203 | # if type(m) is Bottleneck:
204 | # logger.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
205 |
206 | '''def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
207 | logger.info('Fusing layers... ')
208 | for m in self.model.modules():
209 | if type(m) is Conv and hasattr(m, 'bn'):
210 | m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
211 | delattr(m, 'bn') # remove batchnorm
212 | m.forward = m.fuseforward # update forward
213 | self.info()
214 | return self'''
215 |
216 | def nms(self, mode=True): # add or remove NMS module
217 | present = type(self.model[-1]) is NMS # last layer is NMS
218 | if mode and not present:
219 | logger.info('Adding NMS... ')
220 | m = NMS() # module
221 | m.f = -1 # from
222 | m.i = self.model[-1].i + 1 # index
223 | self.model.add_module(name='%s' % m.i, module=m) # add
224 | self.eval()
225 | elif not mode and present:
226 | logger.info('Removing NMS... ')
227 | self.model = self.model[:-1] # remove
228 | return self
229 |
230 | def autoshape(self): # add AutoShape module
231 | logger.info('Adding AutoShape... ')
232 | m = AutoShape(self) # wrap model
233 | copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
234 | return m
235 |
236 | def info(self, verbose=False, img_size=640): # print model information
237 | model_info(self, verbose, img_size)
238 |
239 |
240 | def parse_model(d, ch): # model_dict, input_channels(3)
241 | logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
242 | anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
243 | na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
244 | no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
245 |
246 | layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
247 | for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
248 | m = eval(m) if isinstance(m, str) else m # eval strings
249 | for j, a in enumerate(args):
250 | try:
251 | args[j] = eval(a) if isinstance(a, str) else a # eval strings
252 | except:
253 | pass
254 |
255 | n = max(round(n * gd), 1) if n > 1 else n # depth gain
256 | if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
257 | C3, C3TR]:
258 | c1, c2 = ch[f], args[0]
259 | if c2 != no: # if not output
260 | c2 = make_divisible(c2 * gw, 8)
261 |
262 | args = [c1, c2, *args[1:]]
263 | if m in [BottleneckCSP, C3, C3TR]:
264 | args.insert(2, n) # number of repeats
265 | n = 1
266 | elif m is nn.BatchNorm2d:
267 | args = [ch[f]]
268 | elif m is Concat:
269 | c2 = sum([ch[x] for x in f])
270 | elif m is Detect:
271 | args.append([ch[x] for x in f])
272 | if isinstance(args[1], int): # number of anchors
273 | args[1] = [list(range(args[1] * 2))] * len(f)
274 | elif m is Contract:
275 | c2 = ch[f] * args[0] ** 2
276 | elif m is Expand:
277 | c2 = ch[f] // args[0] ** 2
278 | else:
279 | c2 = ch[f]
280 |
281 | m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
282 | t = str(m)[8:-2].replace('__main__.', '') # module type
283 | np = sum([x.numel() for x in m_.parameters()]) # number params
284 | m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
285 | logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
286 | save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
287 | layers.append(m_)
288 | if i == 0:
289 | ch = []
290 | ch.append(c2)
291 | return nn.Sequential(*layers), sorted(save)
292 |
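# Example of one parsed row (a sketch; the actual rows live in the model's
# yaml config, not here): a backbone entry like [-1, 1, Conv, [64, 3, 2]]
# with width_multiple gw=0.5 yields c2 = make_divisible(64 * 0.5, 8) = 32,
# builds Conv(ch[-1], 32, 3, 2), appends it to layers and 32 to ch.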
293 |
294 | if __name__ == '__main__':
295 | parser = argparse.ArgumentParser()
296 | parser.add_argument('--cfg', type=str, default='yolov3.yaml', help='model.yaml')
297 | parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
298 | opt = parser.parse_args()
299 | opt.cfg = check_file(opt.cfg) # check file
300 | set_logging()
301 | device = select_device(opt.device)
302 |
303 | # Create model
304 | model = Model(opt.cfg).to(device)
305 | model.train()
306 |
307 | # Profile
308 | # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
309 | # y = model(img, profile=True)
310 |
311 | # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
312 | # from torch.utils.tensorboard import SummaryWriter
313 | # tb_writer = SummaryWriter('.')
314 | # logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
315 | # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
316 | # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard
--------------------------------------------------------------------------------
/models/common.py:
--------------------------------------------------------------------------------
1 | # YOLOv3 common modules
2 |
3 | import math
4 | from copy import copy
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import requests
10 | import torch
11 | import torch.nn as nn
12 | from PIL import Image
13 | from torch.cuda import amp
14 |
15 | from utils.datasets import letterbox
16 | from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
17 | from utils.plots import colors, plot_one_box
18 | from utils.torch_utils import time_synchronized
19 |
20 |
21 | def autopad(k, p=None): # kernel, padding
22 | # Pad to 'same'
23 | if p is None:
24 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
25 | return p
26 |
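# Examples (illustrative): autopad(3) -> 1, autopad(5) -> 2, and for a
# rectangular kernel autopad((5, 3)) -> [2, 1]; with stride 1 this keeps
# the spatial size unchanged ('same' padding).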
27 |
28 | def DWConv(c1, c2, k=1, s=1, act=True):
29 | # Depthwise convolution
30 | return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
31 |
32 |
33 | class Conv(nn.Module):
34 | # Standard convolution
35 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
36 | super(Conv, self).__init__()
37 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
38 | self.bn = nn.BatchNorm2d(c2)
39 | self.act = nn.LeakyReLU(0.1) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
40 |
41 | def forward(self, x):
42 | return self.act(self.bn(self.conv(x)))
43 |
44 | def fuseforward(self, x):
45 | return self.act(self.conv(x))
46 |
47 |
48 | class TransformerLayer(nn.Module):  # re-enabled: C3TR below requires it
49 | # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
50 | def __init__(self, c, num_heads):
51 | super().__init__()
52 | self.q = nn.Linear(c, c, bias=False)
53 | self.k = nn.Linear(c, c, bias=False)
54 | self.v = nn.Linear(c, c, bias=False)
55 | self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
56 | self.fc1 = nn.Linear(c, c, bias=False)
57 | self.fc2 = nn.Linear(c, c, bias=False)
58 | def forward(self, x):
59 | x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
60 | x = self.fc2(self.fc1(x)) + x
61 | return x
62 | class TransformerBlock(nn.Module):
63 | # Vision Transformer https://arxiv.org/abs/2010.11929
64 | def __init__(self, c1, c2, num_heads, num_layers):
65 | super().__init__()
66 | self.conv = None
67 | if c1 != c2:
68 | self.conv = Conv(c1, c2)
69 | self.linear = nn.Linear(c2, c2) # learnable position embedding
70 | self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
71 | self.c2 = c2
72 | def forward(self, x):
73 | if self.conv is not None:
74 | x = self.conv(x)
75 | b, _, w, h = x.shape
76 | p = x.flatten(2)
77 | p = p.unsqueeze(0)
78 | p = p.transpose(0, 3)
79 | p = p.squeeze(3)
80 | e = self.linear(p)
81 | x = p + e
82 | x = self.tr(x)
83 | x = x.unsqueeze(3)
84 | x = x.transpose(0, 3)
85 | x = x.reshape(b, self.c2, w, h)
86 |         return x
87 |
88 |
89 | class Bottleneck(nn.Module):
90 | # Standard bottleneck
91 | def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
92 | super(Bottleneck, self).__init__()
93 | c_ = int(c2 * e) # hidden channels
94 | self.cv1 = Conv(c1, c_, 1, 1)
95 | self.cv2 = Conv(c_, c2, 3, 1, g=g)
96 | self.add = shortcut and c1 == c2
97 |
98 | def forward(self, x):
99 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
100 |
101 |
102 | class BottleneckCSP(nn.Module):
103 | # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
104 | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
105 | super(BottleneckCSP, self).__init__()
106 | c_ = int(c2 * e) # hidden channels
107 | self.cv1 = Conv(c1, c_, 1, 1)
108 | self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
109 | self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
110 | self.cv4 = Conv(2 * c_, c2, 1, 1)
111 | self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
112 | self.act = nn.LeakyReLU(0.1, inplace=True)
113 | self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
114 |
115 | def forward(self, x):
116 | y1 = self.cv3(self.m(self.cv1(x)))
117 | y2 = self.cv2(x)
118 | return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
119 |
120 |
121 | class C3(nn.Module):
122 | # CSP Bottleneck with 3 convolutions
123 | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
124 | super(C3, self).__init__()
125 | c_ = int(c2 * e) # hidden channels
126 | self.cv1 = Conv(c1, c_, 1, 1)
127 | self.cv2 = Conv(c1, c_, 1, 1)
128 | self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
129 | self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
130 | # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
131 |
132 | def forward(self, x):
133 | return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
134 |
135 |
136 | class C3TR(C3):
137 | # C3 module with TransformerBlock()
138 | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
139 | super().__init__(c1, c2, n, shortcut, g, e)
140 | c_ = int(c2 * e)
141 | self.m = TransformerBlock(c_, c_, 4, n)
142 |
143 |
144 | class SPP(nn.Module):
145 | # Spatial pyramid pooling layer used in YOLOv3-SPP
146 | def __init__(self, c1, c2, k=(5, 9, 13)):
147 | super(SPP, self).__init__()
148 | c_ = c1 // 2 # hidden channels
149 | self.cv1 = Conv(c1, c_, 1, 1)
150 | self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
151 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
152 |
153 | def forward(self, x):
154 | x = self.cv1(x)
155 | return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
156 |
157 |
158 | class Focus(nn.Module):
159 | # Focus wh information into c-space
160 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
161 | super(Focus, self).__init__()
162 | self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
163 | # self.contract = Contract(gain=2)
164 |
165 | def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
166 | return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
167 | # return self.conv(self.contract(x))
168 |
169 |
170 | class Contract(nn.Module):
171 | # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
172 | def __init__(self, gain=2):
173 | super().__init__()
174 | self.gain = gain
175 |
176 | def forward(self, x):
177 |         N, C, H, W = x.size()  # assert (H % s == 0) and (W % s == 0), 'Indivisible gain'
178 | s = self.gain
179 | x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
180 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
181 | return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
182 |
183 |
184 | class Expand(nn.Module):
185 | # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
186 | def __init__(self, gain=2):
187 | super().__init__()
188 | self.gain = gain
189 |
190 | def forward(self, x):
191 |         N, C, H, W = x.size()  # assert C % s ** 2 == 0, 'Indivisible gain'
192 | s = self.gain
193 | x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
194 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
195 | return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
196 |
197 |
198 | class Concat(nn.Module):
199 | # Concatenate a list of tensors along dimension
200 | def __init__(self, dimension=1):
201 | super(Concat, self).__init__()
202 | self.d = dimension
203 |
204 | def forward(self, x):
205 | return torch.cat(x, self.d)
206 |
207 |
208 | class NMS(nn.Module):
209 | # Non-Maximum Suppression (NMS) module
210 | conf = 0.25 # confidence threshold
211 | iou = 0.45 # IoU threshold
212 | classes = None # (optional list) filter by class
213 | max_det = 1000 # maximum number of detections per image
214 |
215 | def __init__(self):
216 | super(NMS, self).__init__()
217 |
218 | def forward(self, x):
219 | return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)
220 |
221 |
222 | class AutoShape(nn.Module):
223 | # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
224 | conf = 0.25 # NMS confidence threshold
225 | iou = 0.45 # NMS IoU threshold
226 | classes = None # (optional list) filter by class
227 | max_det = 1000 # maximum number of detections per image
228 |
229 | def __init__(self, model):
230 | super(AutoShape, self).__init__()
231 | self.model = model.eval()
232 |
233 | def autoshape(self):
234 | print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
235 | return self
236 |
237 | @torch.no_grad()
238 | def forward(self, imgs, size=640, augment=False, profile=False):
239 | # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
240 | # filename: imgs = 'data/images/zidane.jpg'
241 | # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
242 | # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
243 | # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
244 | # numpy: = np.zeros((640,1280,3)) # HWC
245 | # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
246 | # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
247 |
248 | t = [time_synchronized()]
249 | p = next(self.model.parameters()) # for device and type
250 | if isinstance(imgs, torch.Tensor): # torch
251 | with amp.autocast(enabled=p.device.type != 'cpu'):
252 | return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
253 |
254 | # Pre-process
255 | n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
256 | shape0, shape1, files = [], [], [] # image and inference shapes, filenames
257 | for i, im in enumerate(imgs):
258 | f = f'image{i}' # filename
259 | if isinstance(im, str): # filename or uri
260 | im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
261 | elif isinstance(im, Image.Image): # PIL Image
262 | im, f = np.asarray(im), getattr(im, 'filename', f) or f
263 | files.append(Path(f).with_suffix('.jpg').name)
264 | if im.shape[0] < 5: # image in CHW
265 | im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
266 | im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
267 | s = im.shape[:2] # HWC
268 | shape0.append(s) # image shape
269 | g = (size / max(s)) # gain
270 | shape1.append([y * g for y in s])
271 | imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
272 | shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
273 | x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
274 | x = np.stack(x, 0) if n > 1 else x[0][None] # stack
275 | x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
276 | x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
277 | t.append(time_synchronized())
278 |
279 | with amp.autocast(enabled=p.device.type != 'cpu'):
280 | # Inference
281 | y = self.model(x, augment, profile)[0] # forward
282 | t.append(time_synchronized())
283 |
284 | # Post-process
285 | y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
286 | for i in range(n):
287 | scale_coords(shape1, y[i][:, :4], shape0[i])
288 |
289 | t.append(time_synchronized())
290 | return Detections(imgs, y, files, t, self.names, x.shape)
291 |
292 |
293 | class Detections:
294 | # detections class for YOLOv3 inference results
295 | def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
296 | super(Detections, self).__init__()
297 | d = pred[0].device # device
298 | gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
299 | self.imgs = imgs # list of images as numpy arrays
300 | self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
301 | self.names = names # class names
302 | self.files = files # image filenames
303 | self.xyxy = pred # xyxy pixels
304 | self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
305 | self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
306 | self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
307 | self.n = len(self.pred) # number of images (batch size)
308 |         self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) if times else (0., 0., 0.)  # timestamps (ms)
309 | self.s = shape # inference BCHW shape
310 |
311 | def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
312 | for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
313 |             s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # renamed from str to avoid shadowing the built-in
314 |             if pred is not None:
315 |                 for c in pred[:, -1].unique():
316 |                     n = (pred[:, -1] == c).sum()  # detections per class
317 |                     s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
318 |                 if show or save or render or crop:
319 |                     for *box, conf, cls in pred:  # xyxy, confidence, class
320 |                         label = f'{self.names[int(cls)]} {conf:.2f}'
321 |                         if crop:
322 |                             save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
323 |                         else:  # all others
324 |                             plot_one_box(box, im, label=label, color=colors(cls))
325 | 
326 |             im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
327 |             if pprint:
328 |                 print(s.rstrip(', '))
329 | if show:
330 | im.show(self.files[i]) # show
331 | if save:
332 | f = self.files[i]
333 | im.save(save_dir / f) # save
334 | print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
335 | if render:
336 | self.imgs[i] = np.asarray(im)
337 |
338 | def print(self):
339 | self.display(pprint=True) # print results
340 | print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
341 |
342 | def show(self):
343 | self.display(show=True) # show results
344 |
345 | def save(self, save_dir='runs/hub/exp'):
346 | save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir
347 | self.display(save=True, save_dir=save_dir) # save results
348 |
349 | def crop(self, save_dir='runs/hub/exp'):
350 | save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir
351 | self.display(crop=True, save_dir=save_dir) # crop results
352 | print(f'Saved results to {save_dir}\n')
353 |
354 | def render(self):
355 | self.display(render=True) # render results
356 | return self.imgs
357 |
358 | def pandas(self):
359 | # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
360 | new = copy(self) # return copy
361 | ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
362 | cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
363 | for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
364 | a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
365 | setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
366 | return new
367 |
368 | def tolist(self):
369 | # return a list of Detections objects, i.e. 'for result in results.tolist():'
370 |         x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], names=self.names, shape=self.s) for i in range(self.n)]
371 | for d in x:
372 | for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
373 | setattr(d, k, getattr(d, k)[0]) # pop out of list
374 | return x
375 |
376 | def __len__(self):
377 | return self.n
378 |
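# Usage sketch for Detections (illustrative only; assumes an autoshaped model
# loaded via torch.hub that returns an instance of this class -- the
# 'ultralytics/yolov3' hub path below is an assumption, not part of this repo):
#   import torch
#   model = torch.hub.load('ultralytics/yolov3', 'yolov3')  # assumed hub entry
#   results = model(['zidane.jpg'])   # inference -> Detections
#   results.print()                   # per-image summary and speed
#   df = results.pandas().xyxy[0]     # first image's boxes as a DataFrame
#   per_image = results.tolist()      # one single-image Detections per image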
379 |
380 | class Classify(nn.Module):
381 | # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
382 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
383 | super(Classify, self).__init__()
384 | self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
385 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
386 | self.flat = nn.Flatten()
387 |
388 | def forward(self, x):
389 | z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
390 | return self.flat(self.conv(z)) # flatten to x(b,c2)
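# Shape sketch for Classify (a minimal check, values illustrative):
#   m = Classify(c1=512, c2=10)            # 512 input channels -> 10 classes
#   y = m(torch.zeros(2, 512, 20, 20))     # pool to 1x1, 1x1 conv, flatten
#   assert y.shape == (2, 10)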
--------------------------------------------------------------------------------
/utils/plots.py:
--------------------------------------------------------------------------------
1 | # Plotting utils
2 |
3 | import glob
4 | import math
5 | import os
6 | import random
7 | from copy import copy
8 | from pathlib import Path
9 |
10 | import cv2
11 | import matplotlib
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | import pandas as pd
15 | import seaborn as sns
16 | import torch
17 | import yaml
18 | from PIL import Image, ImageDraw, ImageFont
19 |
20 | from utils.general import xywh2xyxy, xyxy2xywh
21 | from utils.metrics import fitness
22 |
23 | # Settings
24 | matplotlib.rc('font', **{'size': 11})
25 | matplotlib.use('Agg') # for writing to files only
26 |
27 |
28 | class Colors:
29 | # Ultralytics color palette https://ultralytics.com/
30 | def __init__(self):
31 | # hex = matplotlib.colors.TABLEAU_COLORS.values()
32 | hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
33 | '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
34 | self.palette = [self.hex2rgb('#' + c) for c in hexs]
35 | self.n = len(self.palette)
36 |
37 | def __call__(self, i, bgr=False):
38 | c = self.palette[int(i) % self.n]
39 | return (c[2], c[1], c[0]) if bgr else c
40 |
41 | @staticmethod
42 | def hex2rgb(h): # rgb order (PIL)
43 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
44 |
45 |
46 | colors = Colors() # create instance for 'from utils.plots import colors'
47 |
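# Example: index into the palette cyclically; pass bgr=True for OpenCV drawing.
#   colors(0)            # (255, 56, 56) RGB, i.e. 'FF3838'
#   colors(0, bgr=True)  # (56, 56, 255) BGR for cv2
#   colors(23)           # wraps around: same as colors(23 % 20)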
48 |
49 | def hist2d(x, y, n=100):
50 | # 2d histogram used in labels.png and evolve.png
51 | xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
52 | hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
53 | xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
54 | yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
55 | return np.log(hist[xidx, yidx])
56 |
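# Example: hist2d returns one log-density value per input point, handy as a
# scatter-plot colour (synthetic data; each point's own bin count is >= 1):
#   xs, ys = np.random.randn(1000), np.random.randn(1000)
#   plt.scatter(xs, ys, c=hist2d(xs, ys, n=50), cmap='viridis')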
57 |
58 | def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
59 | from scipy.signal import butter, filtfilt
60 |
61 | # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
62 | def butter_lowpass(cutoff, fs, order):
63 | nyq = 0.5 * fs
64 | normal_cutoff = cutoff / nyq
65 | return butter(order, normal_cutoff, btype='low', analog=False)
66 |
67 | b, a = butter_lowpass(cutoff, fs, order=order)
68 | return filtfilt(b, a, data) # forward-backward filter
69 |
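# Example: low-pass a noisy trace with the defaults above (scipy required):
#   t = np.linspace(0, 1, 50000)                     # 1 s at fs=50 kHz
#   noisy = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.randn(t.size)
#   smooth = butter_lowpass_filtfilt(noisy)          # zero-phase, 1.5 kHz cutoff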
70 |
71 | def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
72 | # Plots one bounding box on image 'im' using OpenCV
73 | assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_one_box() input image.'
74 | tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
75 | c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
76 | cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
77 | if label:
78 | tf = max(tl - 1, 1) # font thickness
79 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
80 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
81 | cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
82 | cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
83 |
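# Example: draw a labelled box in place on a contiguous uint8 image:
#   im = np.zeros((640, 640, 3), dtype=np.uint8)
#   plot_one_box([100, 100, 300, 400], im, label='person 0.92', color=colors(0, bgr=True))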
84 |
85 | def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
86 | # Plots one bounding box on image 'im' using PIL
87 | im = Image.fromarray(im)
88 | draw = ImageDraw.Draw(im)
89 | line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
90 | draw.rectangle(box, width=line_thickness, outline=color) # plot
91 | if label:
92 | font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12)) # requires Arial.ttf on the font path
93 | txt_width, txt_height = font.getsize(label) # note: getsize() was removed in Pillow 10; use font.getbbox(label) there
94 | draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
95 | draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
96 | return np.asarray(im)
97 |
98 |
99 | def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
100 | # Compares the two methods for width-height anchor multiplication
101 | # https://github.com/ultralytics/yolov3/issues/168
102 | x = np.arange(-4.0, 4.0, .1)
103 | ya = np.exp(x)
104 | yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
105 |
106 | fig = plt.figure(figsize=(6, 3), tight_layout=True)
107 | plt.plot(x, ya, '.-', label='YOLOv3')
108 | plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
109 | plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
110 | plt.xlim(left=-4, right=4)
111 | plt.ylim(bottom=0, top=6)
112 | plt.xlabel('input')
113 | plt.ylabel('output')
114 | plt.grid()
115 | plt.legend()
116 | fig.savefig('comparison.png', dpi=200)
117 |
118 |
119 | def output_to_target(output):
120 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
121 | targets = []
122 | for i, o in enumerate(output):
123 | for *box, conf, cls in o.cpu().numpy():
124 | targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
125 | return np.array(targets)
126 |
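# Example: one image with one detection [x1, y1, x2, y2, conf, cls] becomes a
# [batch_id, cls, xc, yc, w, h, conf] row (values illustrative):
#   out = [torch.tensor([[10., 20., 50., 80., 0.9, 1.]])]
#   output_to_target(out)  # array([[0., 1., 30., 50., 40., 60., 0.9]])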
127 |
128 | def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
129 | # Plot image grid with labels
130 |
131 | if isinstance(images, torch.Tensor):
132 | images = images.cpu().float().numpy()
133 | if isinstance(targets, torch.Tensor):
134 | targets = targets.cpu().numpy()
135 |
136 | # un-normalise
137 | if np.max(images[0]) <= 1:
138 | images *= 255
139 |
140 | tl = 3 # line thickness
141 | tf = max(tl - 1, 1) # font thickness
142 | bs, _, h, w = images.shape # batch size, _, height, width
143 | bs = min(bs, max_subplots) # limit plot images
144 | ns = np.ceil(bs ** 0.5) # number of subplots (square)
145 |
146 | # Check if we should resize
147 | scale_factor = max_size / max(h, w)
148 | if scale_factor < 1:
149 | h = math.ceil(scale_factor * h)
150 | w = math.ceil(scale_factor * w)
151 |
152 | mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
153 | for i, img in enumerate(images):
154 | if i == max_subplots: # plot at most max_subplots images
155 | break
156 |
157 | block_x = int(w * (i // ns))
158 | block_y = int(h * (i % ns))
159 |
160 | img = img.transpose(1, 2, 0)
161 | if scale_factor < 1:
162 | img = cv2.resize(img, (w, h))
163 |
164 | mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
165 | if len(targets) > 0:
166 | image_targets = targets[targets[:, 0] == i]
167 | boxes = xywh2xyxy(image_targets[:, 2:6]).T
168 | classes = image_targets[:, 1].astype('int')
169 | labels = image_targets.shape[1] == 6 # labels if no conf column
170 | conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
171 |
172 | if boxes.shape[1]:
173 | if boxes.max() <= 1.01: # if normalized with tolerance 0.01
174 | boxes[[0, 2]] *= w # scale to pixels
175 | boxes[[1, 3]] *= h
176 | elif scale_factor < 1: # absolute coords need scale if image scales
177 | boxes *= scale_factor
178 | boxes[[0, 2]] += block_x
179 | boxes[[1, 3]] += block_y
180 | for j, box in enumerate(boxes.T):
181 | cls = int(classes[j])
182 | color = colors(cls)
183 | cls = names[cls] if names else cls
184 | if labels or conf[j] > 0.25: # 0.25 conf thresh
185 | label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
186 | plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
187 |
188 | # Draw image filename labels
189 | if paths:
190 | label = Path(paths[i]).name[:40] # trim to 40 char
191 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
192 | cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
193 | lineType=cv2.LINE_AA)
194 |
195 | # Image border
196 | cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
197 |
198 | if fname:
199 | r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
200 | mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
201 | # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
202 | Image.fromarray(mosaic).save(fname) # PIL save
203 | return mosaic
204 |
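# Example call (a minimal sketch): images is BCHW in [0, 1] or [0, 255]; target
# rows are [batch_id, cls, x, y, w, h] for labels, plus a conf column for preds:
#   imgs = torch.rand(4, 3, 320, 320)
#   tgts = np.array([[0, 0, 0.5, 0.5, 0.2, 0.3]], dtype=np.float32)
#   plot_images(imgs, tgts, fname='mosaic.jpg', names=['person'])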
205 |
206 | def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
207 | # Plot LR simulating training for full epochs
208 | optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
209 | y = []
210 | for _ in range(epochs):
211 | scheduler.step()
212 | y.append(optimizer.param_groups[0]['lr'])
213 | plt.plot(y, '.-', label='LR')
214 | plt.xlabel('epoch')
215 | plt.ylabel('LR')
216 | plt.grid()
217 | plt.xlim(0, epochs)
218 | plt.ylim(0)
219 | plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
220 | plt.close()
221 |
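# Example: preview 300 epochs of a cosine schedule (writes LR.png to save_dir):
#   net = torch.nn.Linear(1, 1)
#   opt = torch.optim.SGD(net.parameters(), lr=0.01)
#   sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=300)
#   plot_lr_scheduler(opt, sch, epochs=300, save_dir='.')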
222 |
223 | def plot_test_txt(): # from utils.plots import *; plot_test()
224 | # Plot test.txt histograms
225 | x = np.loadtxt('test.txt', dtype=np.float32)
226 | box = xyxy2xywh(x[:, :4])
227 | cx, cy = box[:, 0], box[:, 1]
228 |
229 | fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
230 | ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
231 | ax.set_aspect('equal')
232 | plt.savefig('hist2d.png', dpi=300)
233 |
234 | fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
235 | ax[0].hist(cx, bins=600)
236 | ax[1].hist(cy, bins=600)
237 | plt.savefig('hist1d.png', dpi=200)
238 |
239 |
240 | def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
241 | # Plot targets.txt histograms
242 | x = np.loadtxt('targets.txt', dtype=np.float32).T
243 | s = ['x targets', 'y targets', 'width targets', 'height targets']
244 | fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
245 | ax = ax.ravel()
246 | for i in range(4):
247 | ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
248 | ax[i].legend()
249 | ax[i].set_title(s[i])
250 | plt.savefig('targets.jpg', dpi=200)
251 |
252 |
253 | def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
254 | # Plot study.txt generated by test.py
255 | fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
256 | # ax = ax.ravel()
257 |
258 | fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
259 | # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov3-tiny', 'yolov3', 'yolov3-spp', 'yolov5l']]:
260 | for f in sorted(Path(path).glob('study*.txt')):
261 | y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
262 | x = np.arange(y.shape[1]) if x is None else np.array(x)
263 | s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
264 | # for i in range(7):
265 | # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
266 | # ax[i].set_title(s[i])
267 |
268 | j = y[3].argmax() + 1
269 | ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
270 | label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
271 |
272 | ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
273 | 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
274 |
275 | ax2.grid(alpha=0.2)
276 | ax2.set_yticks(np.arange(20, 60, 5))
277 | ax2.set_xlim(0, 57)
278 | ax2.set_ylim(15, 55)
279 | ax2.set_xlabel('GPU Speed (ms/img)')
280 | ax2.set_ylabel('COCO AP val')
281 | ax2.legend(loc='lower right')
282 | plt.savefig(str(Path(path).name) + '.png', dpi=300)
283 |
284 |
285 | def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
286 | # plot dataset labels
287 | print('Plotting labels... ')
288 | c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
289 | nc = int(c.max() + 1) # number of classes
290 | x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
291 |
292 | # seaborn correlogram
293 | sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
294 | plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
295 | plt.close()
296 |
297 | # matplotlib labels
298 | matplotlib.use('svg') # faster
299 | ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
300 | y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
301 | # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
302 | ax[0].set_ylabel('instances')
303 | if 0 < len(names) < 30:
304 | ax[0].set_xticks(range(len(names)))
305 | ax[0].set_xticklabels(names, rotation=90, fontsize=10)
306 | else:
307 | ax[0].set_xlabel('classes')
308 | sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
309 | sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
310 |
311 | # rectangles
312 | labels[:, 1:3] = 0.5 # center
313 | labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
314 | img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
315 | for cls, *box in labels[:1000]:
316 | ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
317 | ax[1].imshow(img)
318 | ax[1].axis('off')
319 |
320 | for a in [0, 1, 2, 3]:
321 | for s in ['top', 'right', 'left', 'bottom']:
322 | ax[a].spines[s].set_visible(False)
323 |
324 | plt.savefig(save_dir / 'labels.jpg', dpi=200)
325 | matplotlib.use('Agg')
326 | plt.close()
327 |
328 | # loggers
329 | for k, v in (loggers or {}).items(): # tolerate loggers=None
330 | if k == 'wandb' and v:
331 | v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False)
332 |
333 |
334 | def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
335 | # Plot hyperparameter evolution results in evolve.txt
336 | with open(yaml_file) as f:
337 | hyp = yaml.safe_load(f)
338 | x = np.loadtxt('evolve.txt', ndmin=2)
339 | f = fitness(x)
340 | # weights = (f - f.min()) ** 2 # for weighted results
341 | plt.figure(figsize=(10, 12), tight_layout=True)
342 | matplotlib.rc('font', **{'size': 8})
343 | for i, (k, v) in enumerate(hyp.items()):
344 | y = x[:, i + 7]
345 | # mu = (y * weights).sum() / weights.sum() # best weighted result
346 | mu = y[f.argmax()] # best single result
347 | plt.subplot(6, 5, i + 1)
348 | plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
349 | plt.plot(mu, f.max(), 'k+', markersize=15)
350 | plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
351 | if i % 5 != 0:
352 | plt.yticks([])
353 | print('%15s: %.3g' % (k, mu))
354 | plt.savefig('evolve.png', dpi=200)
355 | print('\nPlot saved as evolve.png')
356 |
357 |
358 | def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
359 | # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
360 | ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
361 | s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
362 | files = list(Path(save_dir).glob('frames*.txt'))
363 | for fi, f in enumerate(files):
364 | try:
365 | results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
366 | n = results.shape[1] # number of rows
367 | x = np.arange(start, min(stop, n) if stop else n)
368 | results = results[:, x]
369 | t = (results[0] - results[0].min()) # set t0=0s
370 | results[0] = x
371 | for i, a in enumerate(ax):
372 | if i < len(results):
373 | label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
374 | a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
375 | a.set_title(s[i])
376 | a.set_xlabel('time (s)')
377 | # if fi == len(files) - 1:
378 | # a.set_ylim(bottom=0)
379 | for side in ['top', 'right']:
380 | a.spines[side].set_visible(False)
381 | else:
382 | a.remove()
383 | except Exception as e:
384 | print('Warning: Plotting error for %s; %s' % (f, e))
385 |
386 | ax[1].legend()
387 | plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
388 |
389 |
390 | def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
391 | # Plot training 'results*.txt', overlaying train and val losses
392 | s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
393 | t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
394 | for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
395 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
396 | n = results.shape[1] # number of rows
397 | x = range(start, min(stop, n) if stop else n)
398 | fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
399 | ax = ax.ravel()
400 | for i in range(5):
401 | for j in [i, i + 5]:
402 | y = results[j, x]
403 | ax[i].plot(x, y, marker='.', label=s[j])
404 | # y_smooth = butter_lowpass_filtfilt(y)
405 | # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
406 |
407 | ax[i].set_title(t[i])
408 | ax[i].legend()
409 | ax[i].set_ylabel(f) if i == 0 else None # add filename
410 | fig.savefig(f.replace('.txt', '.png'), dpi=200)
411 |
412 |
413 | def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
414 | # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
415 | fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
416 | ax = ax.ravel()
417 | s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
418 | 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
419 | if bucket:
420 | # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
421 | files = ['results%g.txt' % x for x in id]
422 | c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
423 | os.system(c)
424 | else:
425 | files = list(Path(save_dir).glob('results*.txt'))
426 | assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
427 | for fi, f in enumerate(files):
428 | try:
429 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
430 | n = results.shape[1] # number of rows
431 | x = range(start, min(stop, n) if stop else n)
432 | for i in range(10):
433 | y = results[i, x]
434 | if i in [0, 1, 2, 5, 6, 7]:
435 | y[y == 0] = np.nan # don't show zero loss values
436 | # y /= y[0] # normalize
437 | label = labels[fi] if len(labels) else f.stem
438 | ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
439 | ax[i].set_title(s[i])
440 | # if i in [5, 6, 7]: # share train and val loss y axes
441 | # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
442 | except Exception as e:
443 | print('Warning: Plotting error for %s; %s' % (f, e))
444 |
445 | ax[1].legend()
446 | fig.savefig(Path(save_dir) / 'results.png', dpi=200)
--------------------------------------------------------------------------------