├── assets
│   └── image-20191216151753209.png
├── crucible
│   ├── data.py
│   ├── __init__.py
│   ├── io.py
│   ├── init.py
│   ├── modules.py
│   ├── vision.py
│   ├── metrics.py
│   ├── utils.py
│   ├── segtrainer.py
│   └── trainer.py
├── README.md
├── focalloss.py
├── LICENSE
├── sampler.py
├── train.py
└── Explaination.md

--------------------------------------------------------------------------------
/assets/image-20191216151753209.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ashawkey/FocalLoss.pytorch/HEAD/assets/image-20191216151753209.png
--------------------------------------------------------------------------------
/crucible/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | from torch.utils.data import Dataset, DataLoader, default_collate
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/crucible/__init__.py:
--------------------------------------------------------------------------------
1 | # necessary on Windows
2 | from . import metrics
3 | from . import vision
4 | from . import io
5 | from . import utils
6 | # pre-import commonly used names
7 | from .trainer import Trainer
8 | from .utils import EmailSender
9 | from .utils import fix_random_seed
10 | from .utils import backup
11 | 
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Focal Loss
2 | 
3 | [Paper](https://arxiv.org/abs/1708.02002)
4 | 
5 | This is a Focal Loss implementation in PyTorch.
6 | 
7 | 
8 | 
9 | ### Simple Experiment
10 | 
11 | Results from running `train.py`:
12 | 
13 | Focal Loss is also compared with the [imbalanced-dataset-sampler](https://github.com/ufoym/imbalanced-dataset-sampler); a balanced sampling method appears to work much better than Focal Loss when the task allows it (e.g., classification).
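
A minimal usage sketch (illustrative; `FocalLoss` is a drop-in replacement for `nn.CrossEntropyLoss`):

```python
import torch
from focalloss import FocalLoss

criterion = FocalLoss(gamma=2)                  # gamma=0 recovers plain cross entropy
logits = torch.randn(8, 3, requires_grad=True)  # [N, C] raw, unnormalized scores
target = torch.randint(0, 3, (8,))              # [N] int64 class indices
loss = criterion(logits, target)
loss.backward()
```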
14 | 15 | ![image-20191216151753209](./assets/image-20191216151753209.png) 16 | -------------------------------------------------------------------------------- /focalloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | 7 | class FocalLoss(nn.Module): 8 | '''Multi-class Focal loss implementation''' 9 | def __init__(self, gamma=2, weight=None): 10 | super(FocalLoss, self).__init__() 11 | self.gamma = gamma 12 | self.weight = weight 13 | 14 | def forward(self, input, target): 15 | """ 16 | input: [N, C] 17 | target: [N, ] 18 | """ 19 | logpt = F.log_softmax(input, dim=1) 20 | pt = torch.exp(logpt) 21 | logpt = (1-pt)**self.gamma * logpt 22 | loss = F.nll_loss(logpt, target, self.weight) 23 | return loss 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | 
4 | 
5 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
6 |     """Samples elements randomly for an imbalanced dataset, weighting each
7 |     sample inversely to the frequency of its class.
8 |     Arguments:
9 |         dataset: a map-style dataset whose items are dicts carrying an
10 |             integer class label under the "truth" key
11 |         indices (list, optional): a list of indices
12 |         num_samples (int, optional): number of samples to draw
13 |     """
14 | 
15 |     def __init__(self, dataset, indices=None, num_samples=None):
16 | 
17 |         # if indices is not provided,
18 |         # all elements in the dataset will be considered
19 |         self.indices = list(range(len(dataset))) \
20 |             if indices is None else indices
21 | 
22 |         # if num_samples is not provided,
23 |         # draw `len(indices)` samples in each iteration
24 |         self.num_samples = len(self.indices) \
25 |             if num_samples is None else num_samples
26 | 
27 |         # distribution of classes in the dataset
28 |         label_to_count = {}
29 |         for idx in self.indices:
30 |             label = self._get_label(dataset, idx)
31 |             if label in label_to_count:
32 |                 label_to_count[label] += 1
33 |             else:
34 |                 label_to_count[label] = 1
35 | 
36 |         # weight for each sample: inverse frequency of its class
37 |         weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
38 |                    for idx in self.indices]
39 |         self.weights = torch.DoubleTensor(weights)
40 | 
41 |     def _get_label(self, dataset, idx):
42 |         return dataset[idx]['truth'].item()
43 | 
44 |     def __iter__(self):
45 |         return (self.indices[i] for i in torch.multinomial(
46 |             self.weights, self.num_samples, replacement=True))
47 | 
48 |     def __len__(self):
49 |         return self.num_samples
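50 | 
51 | # Example usage (illustrative): oversample minority classes so that each
52 | # epoch is drawn roughly class-balanced. `dataset` is any map-style dataset
53 | # whose items provide the "truth" label field assumed by _get_label above.
54 | #
55 | #   from torch.utils.data import DataLoader
56 | #   loader = DataLoader(dataset, batch_size=32,
57 | #                       sampler=ImbalancedDatasetSampler(dataset))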
--------------------------------------------------------------------------------
/crucible/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import csv
5 | from pprint import pprint
6 | 
7 | """
8 | Methods for reading/writing config files, and logging.
9 | """
10 | 
11 | def load_json(filename):
12 |     with open(filename, "r") as f:
13 |         res = json.load(f)
14 |     return res
15 | 
16 | def write_json(filename, data):
17 |     with open(filename, "w") as f:
18 |         json.dump(data, f)
19 | 
20 | 
21 | class logger:
22 |     def __init__(self, workspace=None, flush=True, mute=False):
23 |         """
24 |         logger class.
25 |         param:
26 |             workspace: path to save log file, if None, only print to stdout.
27 |             flush: force flushing when printing.
28 |             mute: if True, suppress stdout (the log file is still written).
29 |         """
30 |         self.workspace = workspace
31 |         self.flush = flush
32 |         self.mute = mute
33 |         if workspace is not None:
34 |             os.makedirs(workspace, exist_ok=True)
35 |             self.log_file = os.path.join(workspace, "log.txt")
36 |             self.fp = open(self.log_file, "a+")
37 |         else:
38 |             self.fp = None
39 | 
40 |     def __del__(self):
41 |         if self.fp:
42 |             self.fp.close()
43 | 
44 |     def _print(self, text, use_pprint=False):
45 |         if not self.mute:
46 |             pprint(text) if use_pprint else print(text)
47 |         if self.fp:
48 |             print(text, file=self.fp)
49 |         if self.flush:
50 |             sys.stdout.flush()
51 | 
52 | 
53 |     def log(self, text, level=0):
54 |         text = "\t"*level + text
55 |         text = text.replace("\n", "\n"+"\t"*level)
56 |         self._print(text)
57 | 
58 |     def log1(self, text):
59 |         self.log(text, level=1)
60 | 
61 |     def info(self, text):
62 |         text = "[INFO] " + text
63 |         text = text.replace("\n", "\n"+"[INFO] ")
64 |         self._print(text)
65 | 
66 |     def error(self, text):
67 |         text = "[ERROR] " + text
68 |         text = text.replace("\n", "\n"+"[ERROR] ")
69 |         self._print(text)
70 | 
71 |     def logblock(self, text):
72 |         self._print("#####################")
73 |         self._print(text, use_pprint=True)
74 |         self._print("#####################")
--------------------------------------------------------------------------------
/crucible/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | 
5 | 
6 | def init_weights(m, a=1e-2):
7 |     # 1D (transposed) convolutions: plain normal init
8 |     if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
9 |         init.normal_(m.weight.data)
10 |         if m.bias is not None:
11 |             init.normal_(m.bias.data)
12 |     # 2D/3D (transposed) convolutions: Kaiming init with negative slope a
13 |     elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
14 |         init.kaiming_normal_(m.weight.data, a=a)
15 |         if m.bias is not None:
16 |             init.normal_(m.bias.data)
17 |     # batch norms: weight ~ N(1, 0.02), bias = 0
18 |     elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
19 |         init.normal_(m.weight.data, mean=1, std=0.02)
20 |         init.constant_(m.bias.data, 0)
21 |     elif isinstance(m, nn.Linear):
22 |         init.kaiming_normal_(m.weight.data)
23 |         init.normal_(m.bias.data)
24 |     # recurrent layers: orthogonal init for matrices, normal for vectors
25 |     elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
26 |         for param in m.parameters():
27 |             if len(param.shape) >= 2:
28 |                 init.orthogonal_(param.data)
29 |             else:
30 |                 init.normal_(param.data)
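31 | 
32 | # Example usage (illustrative): re-initialize every submodule of a model.
33 | #
34 | #   model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
35 | #   model.apply(init_weights)                      # default negative slope a=1e-2
36 | #   model.apply(lambda m: init_weights(m, a=0.2))  # e.g. when using LeakyReLU(0.2)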
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import torch.utils.data as data 6 | from torch.optim import lr_scheduler 7 | import numpy as np 8 | from easydict import EasyDict 9 | 10 | from focalloss import FocalLoss 11 | from sampler import ImbalancedDatasetSampler 12 | 13 | from crucible.trainer import Trainer 14 | from crucible.metrics import * 15 | from crucible.io import logger 16 | from crucible.utils import fix_random_seed 17 | 18 | class fcdr(nn.Module): 19 | def __init__(self, Fin, Fout, dp=0.5): 20 | super(fcdr, self).__init__() 21 | self.fc = nn.Linear(Fin, Fout) 22 | self.dp = nn.Dropout(dp) 23 | self.ac = nn.ReLU(True) 24 | def forward(self, x): 25 | x = self.fc(x) 26 | x = self.dp(x) 27 | x = self.ac(x) 28 | return x 29 | 30 | class FCN(nn.Module): 31 | def __init__(self): 32 | super(FCN, self).__init__() 33 | self.fc0 = fcdr(10, 256) 34 | self.fc1 = fcdr(256, 512) 35 | self.fc2 = nn.Linear(512, 2) 36 | 37 | def forward(self, x): 38 | x = self.fc0(x) 39 | x = self.fc1(x) 40 | x = self.fc2(x) # [B, 2] 41 | return x 42 | 43 | class BiasedDataset(data.Dataset): 44 | def __init__(self, N, p=[0.9, 0.1]): 45 | super().__init__() 46 | self.N = N 47 | self.p = p 48 | Y = np.random.choice(2, size=(N, 1), p=p).astype(np.int64) 49 | X = (np.random.rand(N, 10) + Y * 0.2).astype(np.float32) # noised Y 50 | self.Y = torch.LongTensor(Y).squeeze() 51 | self.X = torch.FloatTensor(X) 52 | 53 | def __getitem__(self, index): 54 | data = { 55 | "input": self.X[index], 56 | "truth": self.Y[index], 57 | } 58 | 59 | return data 60 | 61 | def __len__(self): 62 | return self.N 63 | 64 | if __name__ == "__main__": 65 | conf = EasyDict() 66 | conf.workspace = 'workspace/imbalanced_sampler_cross_entropy' 67 | conf.device = 'cuda' 68 | conf.max_epochs = 100 69 | 70 | log = logger(conf.workspace) 71 | model = FCN() 72 | #loss_function = FocalLoss() 73 | loss_function = nn.CrossEntropyLoss() 74 | optimizer = optim.Adam(model.parameters()) 75 | scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) 76 | 77 | metrics = [ClassificationMeter(2),] 78 | train_dataset = BiasedDataset(1000, [0.9, 0.1]) 79 | test_dataset = BiasedDataset(1000, [0.5, 0.5]) 80 | loaders = { 81 | "train": data.DataLoader(train_dataset, sampler=ImbalancedDatasetSampler(train_dataset)), 82 | "test": data.DataLoader(test_dataset), 83 | } 84 | 85 | 86 | trainer = Trainer(conf, model, optimizer, scheduler, loss_function, loaders, log, metrics) 87 | trainer.train() 88 | trainer.evaluate() 89 | 90 | -------------------------------------------------------------------------------- /crucible/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .normalization import GroupNorm 4 | 5 | class fcdr(nn.Module): 6 | def __init__(self, in_features, out_features, p=0.5, activation=True): 7 | super(fcdr, self).__init__() 8 | self.seq = nn.Sequential( 9 | nn.Linear(in_features, out_features), 10 | nn.Dropout(p=p), 11 | ) 12 | if activation: 13 | self.seq.add_module("activatoin", nn.ReLU(inplace=True)) 14 | def forward(self, x): 15 | return self.seq(x) 16 | 17 | class fcbr(nn.Module): 18 | def __init__(self, in_features, out_features, activation=True): 19 | super(fcbr, self).__init__() 20 | self.seq = nn.Sequential( 
21 | nn.Linear(in_features, out_features), 22 | nn.BatchNorm1d(out_features), 23 | ) 24 | if activation: 25 | self.seq.add_module("activatoin", nn.ReLU(inplace=True)) 26 | def forward(self, x): 27 | return self.seq(x) 28 | 29 | class fcgr(nn.Module): 30 | def __init__(self, in_features, out_features, num_groups=32, activation=True): 31 | super(fcgr, self).__init__() 32 | self.seq = nn.Sequential( 33 | nn.Linear(in_features, out_features), 34 | nn.GroupNorm(num_groups, out_features), 35 | ) 36 | if activation: 37 | self.seq.add_module("activatoin", nn.ReLU(inplace=True)) 38 | def forward(self, x): 39 | return self.seq(x) 40 | 41 | class conv1dbr(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True): 43 | super(conv1dbr, self).__init__() 44 | self.seq = nn.Sequential( 45 | nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=padding), 46 | nn.BatchNorm1d(out_channels), 47 | ) 48 | if activation: 49 | self.seq.add_module("activation", nn.ReLU(inplace=True)) 50 | 51 | def forward(self, x): 52 | return self.seq(x) 53 | 54 | 55 | class conv2dbr(nn.Module): 56 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True): 57 | super(conv2dbr, self).__init__() 58 | self.seq = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding), 60 | nn.BatchNorm2d(out_channels) 61 | ) 62 | if activation: self.seq.add_module("activation", nn.ReLU(inplace=True)) 63 | 64 | def forward(self, x): 65 | return self.seq(x) 66 | 67 | 68 | class conv2dgr(nn.Module): 69 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True): 70 | super(conv2dgr, self).__init__() 71 | self.seq = nn.Sequential( 72 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding), 73 | GroupNorm(out_channels), 74 | ) 75 | if activation: self.seq.add_module("activation", nn.ReLU(inplace=True)) 76 | 77 | def forward(self, x): 78 | return self.seq(x) 79 | 80 | """ 81 | Experiment shows that BN after ReLU is better. 
82 | ref: https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md 83 | """ 84 | 85 | class conv1drb(nn.Module): 86 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True): 87 | super(conv1drb, self).__init__() 88 | self.seq = nn.Sequential() 89 | self.seq.add_module("conv", nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=padding)) 90 | if activation: self.seq.add_module("activation", nn.ReLU(inplace=True)) 91 | self.seq.add_module("bn", nn.BatchNorm1d(out_channels)) 92 | 93 | def forward(self, x): 94 | return self.seq(x) 95 | 96 | class conv2drb(nn.Module): 97 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=True): 98 | super(conv2drb, self).__init__() 99 | self.seq = nn.Sequential() 100 | self.seq.add_module("conv", nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)) 101 | if activation: self.seq.add_module("activation", nn.ReLU(inplace=True)) 102 | self.seq.add_module("bn", nn.BatchNorm2d(out_channels)) 103 | 104 | def forward(self, x): 105 | return self.seq(x) 106 | 107 | class fcrb(nn.Module): 108 | def __init__(self, in_features, out_features, activation=True): 109 | super(fcrb, self).__init__() 110 | self.seq = nn.Sequential() 111 | self.seq.add_module("fc", nn.Linear(in_features, out_features)) 112 | if activation: self.seq.add_module("activatoin", nn.ReLU(inplace=True)) 113 | self.seq.add_module("bn", nn.BatchNorm1d(out_features)) 114 | 115 | def forward(self, x): 116 | return self.seq(x) 117 | -------------------------------------------------------------------------------- /crucible/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import cv2 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | import matplotlib.animation as animation 9 | from mpl_toolkits.mplot3d import Axes3D 10 | from sklearn.metrics import confusion_matrix 11 | 12 | def imread(filename, use_float=True): 13 | """ 14 | Read the image as it is. (Grayscale, RGB, RGBA) 15 | cv2.imread() needs to set FLAGS to negative. 16 | """ 17 | img = plt.imread(filename) 18 | if use_float and np.max(img) > 1: 19 | img = img.astype(np.float64)/255 20 | return img 21 | # return cv2.imread(filename, -1) 22 | 23 | def imsave(filename, img): 24 | """ 25 | Save the image as it is. 26 | plt.imsave() can't save grayscale images. 27 | """ 28 | cv2.imwrite(filename, img) 29 | 30 | def plot_images(img, img2=None): 31 | """ 32 | Plot at most 2 images. 33 | Support passing in ndarray or image path string. 
34 | """ 35 | fig = plt.figure(figsize=(20,10)) 36 | if isinstance(img, str): img = imread(img) 37 | if isinstance(img2, str): img2 = imread(img2) 38 | if img2 is None: 39 | ax = fig.add_subplot(111) 40 | ax.imshow(img) 41 | else: 42 | height, width = img.shape[0], img.shape[1] 43 | if height < width: 44 | ax = fig.add_subplot(211) 45 | ax2 = fig.add_subplot(212) 46 | else: 47 | ax = fig.add_subplot(121) 48 | ax2 = fig.add_subplot(122) 49 | ax.imshow(img) 50 | ax2.imshow(img2) 51 | plt.show() 52 | 53 | 54 | def _plot_point_cloud(ax, pc, axes=[0,1,2], keep_ratio=1.0, pointsize=0.05, color='k'): 55 | N = pc.shape[0] 56 | selected = np.random.choice(N, int(N*keep_ratio)) 57 | if not isinstance(color, str): 58 | color = color[selected] 59 | ax.scatter(*(pc[selected[:,None], axes].T), s=pointsize, c=color, alpha=0.5) 60 | if len(axes)==3: 61 | ax.view_init(50, 135) 62 | 63 | def plot_point_cloud(pc, axes=[0,1,2], keep_ratio=1.0, pointsize=0.05): 64 | """ 65 | pc: [N, f] 66 | axes: [0,1,2] 67 | """ 68 | fig = plt.figure(figsize=(20,10)) 69 | if len(axes) == 3: 70 | ax = fig.add_subplot(111, projection='3d') 71 | elif len(axes) == 2: 72 | ax = fig.add_subplot(111) 73 | else: 74 | print("Axes should be either 2 or 3") 75 | exit(1) 76 | _plot_point_cloud(ax, pc, axes, keep_ratio=keep_ratio, pointsize=pointsize) 77 | plt.show() 78 | 79 | # plot a torch_geo data object (point in 3D space) 80 | def plot_graph(data): 81 | pos = data.pos.detach().cpu().numpy() 82 | ppos = pos.reshape(2, -1, 3) 83 | edge_index = data.edge_index.detach().cpu().numpy() 84 | 85 | fig = plt.figure() 86 | ax = fig.add_subplot(111, projection="3d") 87 | ax.scatter(*(ppos[0].T), s=1.0, c="r") 88 | ax.scatter(*(ppos[1].T), s=1.0, c="g") 89 | 90 | num_edges = edge_index.shape[1]//2 91 | for i in range(num_edges): 92 | c = 'k' if data.edge_attr[i] == 1 else 'b' 93 | A, B = pos[edge_index[0][i]], pos[edge_index[1][i]] 94 | ax.plot([A[0],B[0]], [A[1],B[1]], [A[2],B[2]], c=c, lw=1.0) 95 | 96 | plt.show() 97 | 98 | def plot_matrix(mat, path=None): 99 | if isinstance(mat, torch.Tensor): 100 | mat = mat.detach().cpu().numpy() 101 | 102 | fig = plt.figure() 103 | ax = fig.add_subplot(111) 104 | ax.matshow(mat) 105 | 106 | if path is None: 107 | plt.show() 108 | else: 109 | plt.savefig(path) 110 | 111 | plt.close(fig) 112 | 113 | 114 | def plot_confusion_matrix(y_true, y_pred, 115 | normalize=False, 116 | title=None, 117 | cmap=plt.cm.Blues): 118 | """ 119 | This function prints and plots the confusion matrix. 120 | Normalization can be applied by setting `normalize=True`. 121 | """ 122 | if not title: 123 | if normalize: 124 | title = 'Normalized confusion matrix' 125 | else: 126 | title = 'Confusion matrix, without normalization' 127 | 128 | # Compute confusion matrix 129 | cm = confusion_matrix(y_true, y_pred) 130 | # Only use the labels that appear in the data 131 | if normalize: 132 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 133 | print("Normalized confusion matrix") 134 | else: 135 | print('Confusion matrix, without normalization') 136 | 137 | print(cm) 138 | 139 | fig, ax = plt.subplots() 140 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 141 | ax.figure.colorbar(im, ax=ax) 142 | 143 | # Loop over data dimensions and create text annotations. 144 | fmt = '.2f' if normalize else 'd' 145 | thresh = cm.max() / 2. 
146 |     for i in range(cm.shape[0]):
147 |         for j in range(cm.shape[1]):
148 |             ax.text(j, i, format(cm[i, j], fmt),
149 |                     ha="center", va="center",
150 |                     color="white" if cm[i, j] > thresh else "black")
151 |     fig.tight_layout()
152 |     plt.show()
153 |     #return ax
154 | 
155 | def view_batch(imgs, lbls, labels=['image', 'label'], stack=False):
156 |     '''
157 |     imgs: [D, H, W, C], the depth or batch dimension should be the first.
158 |     '''
159 |     fig = plt.figure()
160 |     ax1 = fig.add_subplot(121)
161 |     ax2 = fig.add_subplot(122)
162 |     ax1.set_title(labels[0])
163 |     ax2.set_title(labels[1])
164 |     """
165 |     if initialized with zeros, the animation may not update; this seems to be a bug in matplotlib animation.
166 |     """
167 |     if stack:
168 |         lbls = np.stack((lbls, imgs, imgs), -1)
169 |     img1 = ax1.imshow(np.random.rand(*imgs.shape[1:]))
170 |     img2 = ax2.imshow(np.random.rand(*lbls.shape[1:]))
171 |     def update(i):
172 |         plt.suptitle(str(i))
173 |         img1.set_data(imgs[i])
174 |         img2.set_data(lbls[i])
175 |         return img1, img2
176 |     ani = animation.FuncAnimation(fig, update, frames=len(imgs), interval=10, blit=False, repeat_delay=0)
177 |     plt.show()
178 | 
--------------------------------------------------------------------------------
/crucible/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | 
5 | """
6 | How to add new metrics:
7 | * define a function that accepts numpy (pred, truth) and returns a scalar/vector metric
8 | * wrap it with ScalarMeter/VectorMeter
9 | """
10 | 
11 | 
12 | def IoU(preds, truths, num_classes):
13 |     """
14 |     preds: [B, H, W], integer class labels
15 |     truths: [B, H, W], integer class labels in [0, num_classes)
16 |     """
17 |     batch_size = truths.shape[0]
18 |     ious = []
19 | 
20 |     for batch in range(batch_size):
21 |         for part in range(num_classes):
22 |             I = np.sum(np.logical_and(preds[batch] == part, truths[batch] == part))
23 |             U = np.sum(np.logical_or(preds[batch] == part, truths[batch] == part))
24 |             if U == 0:
25 |                 continue
26 |             else:
27 |                 ious.append(I/U)
28 | 
29 |     return np.mean(ious)
30 | 
31 | class ScalarMeter:
32 |     def __init__(self, name, core, larger=True, reduction="mean"):
33 |         self.core = core
34 |         self.data = []
35 |         self.larger = larger
36 |         self.reduction = reduction
37 |         self.name = name
38 | 
39 |     def clear(self):
40 |         self.data = []
41 | 
42 |     def prepare_inputs(self, outputs, truths):
43 |         """
44 |         outputs and truths are pytorch tensors or numpy ndarrays.
45 | """ 46 | if torch.is_tensor(outputs): 47 | outputs = outputs.detach().cpu().numpy() 48 | if torch.is_tensor(truths): 49 | truths = truths.detach().cpu().numpy() 50 | 51 | return outputs, truths 52 | 53 | def update(self, outputs, truths): 54 | outputs, truths = self.prepare_inputs(outputs, truths) 55 | res = self.core(outputs, truths) 56 | self.data.append(res) 57 | 58 | def measure(self): 59 | if self.reduction == "mean": 60 | return np.mean(self.data) 61 | elif self.reduction == "sum": 62 | return np.sum(self.data) 63 | 64 | def better(self, A, B): 65 | if self.larger: 66 | return A > B 67 | else: 68 | return A < B 69 | 70 | def write(self, writer, global_step, prefix=""): 71 | writer.add_scalar(os.path.join(prefix, self.name), self.measure(), global_step) 72 | 73 | def report(self): 74 | text = f"{self.name} = {self.measure():.4f}\n" 75 | return text 76 | 77 | class VectorMeter: 78 | def __init__(self, name, core, larger=True, reduction="mean"): 79 | self.core = core 80 | """ 81 | core: lambda outputs, truths -> [vector] 82 | assume core function accepts torch.Tensor 83 | Classification needs three meters: accuracy, precision, recall 84 | 85 | larger: True 86 | the larger, the better 87 | """ 88 | self.data = [] 89 | self.larger = larger 90 | self.reduction = reduction 91 | self.name = name 92 | 93 | def clear(self): 94 | self.data = [] 95 | 96 | def prepare_inputs(self, outputs, truths): 97 | """ 98 | outputs and truths are pytorch tensors or numpy ndarrays. 99 | """ 100 | if torch.is_tensor(outputs): 101 | outputs = outputs.detach().cpu().numpy() 102 | if torch.is_tensor(truths): 103 | truths = truths.detach().cpu().numpy() 104 | 105 | return outputs, truths 106 | 107 | def update(self, outputs, truths): 108 | outputs, truths = self.prepare_inputs(outputs, truths) 109 | res = self.core(outputs, truths) 110 | self.data.append(res) 111 | 112 | def measure(self): 113 | if self.reduction == "mean": 114 | return np.mean(self.data) 115 | elif self.reduction == "sum": 116 | return np.sum(self.data) 117 | 118 | def better(self, A, B): 119 | if self.larger: 120 | return A > B 121 | else: 122 | return A < B 123 | 124 | def write(self, writer, global_step, prefix=""): 125 | writer.add_scalar(os.path.join(prefix, self.name), self.measure(), global_step) 126 | 127 | def report(self): 128 | res = np.mean(self.data, axis=0) 129 | text = f"{self.name}: mean = {np.mean(res):.4f}\n" 130 | for i in range(len(res)): 131 | text += f"\tClass {i} = {res[i]:.4f}\n" 132 | return text 133 | 134 | class ClassificationMeter: 135 | """ statistics for classification """ 136 | def __init__(self, nCls, eps=1e-5, names=None, keep_history=False): 137 | self.nCls = nCls 138 | self.names = names 139 | self.eps = eps 140 | self.N = 0 141 | self.table = np.zeros((self.nCls, 4), dtype=np.int32) 142 | self.keep_history = keep_history 143 | if keep_history: 144 | self.hist_preds = [] 145 | self.hist_truths = [] 146 | 147 | def clear(self): 148 | self.N = 0 149 | self.table = np.zeros((self.nCls, 4), dtype=np.int32) 150 | if self.keep_history: 151 | self.hist_preds = [] 152 | self.hist_truths = [] 153 | 154 | def prepare_inputs(self, outputs, truths): 155 | """ 156 | outputs and truths are pytorch tensors or numpy ndarrays. 
157 |         """
158 |         if torch.is_tensor(outputs):
159 |             outputs = outputs.detach().cpu().numpy()
160 |         if torch.is_tensor(truths):
161 |             truths = truths.detach().cpu().numpy()
162 | 
163 |         return outputs, truths
164 | 
165 |     def update(self, preds, truths):
166 | 
167 |         preds, truths = self.prepare_inputs(preds, truths)
168 | 
169 |         if self.keep_history:
170 |             self.hist_preds.extend(preds.tolist())
171 |             self.hist_truths.extend(truths.tolist())
172 | 
173 |         self.N += np.prod(truths.shape)
174 |         for Cls in range(self.nCls):
175 |             true_positive = np.count_nonzero(np.bitwise_and(preds == Cls, truths == Cls))
176 |             true_negative = np.count_nonzero(np.bitwise_and(preds != Cls, truths != Cls))
177 |             false_positive = np.count_nonzero(np.bitwise_and(preds == Cls, truths != Cls))
178 |             false_negative = np.count_nonzero(np.bitwise_and(preds != Cls, truths == Cls))
179 |             self.table[Cls] += [true_positive, true_negative, false_positive, false_negative]
180 | 
181 |     def measure(self):
182 |         """Overall Accuracy"""
183 |         total_TP = np.sum(self.table[:, 0]) # all true positives
184 |         accuracy = total_TP/self.N
185 |         return accuracy
186 | 
187 |     def better(self, A, B):
188 |         return A > B
189 | 
190 |     def write(self, writer, global_step, prefix=""):
191 |         writer.add_scalar(os.path.join(prefix, "Accuracy"), self.measure(), global_step)
192 | 
193 |     def plot_conf_mat(self):
194 |         if not self.keep_history:
195 |             print("[ERROR]: classification meter not keeping history.")
196 |             return
197 |         #mat = confusion_matrix(self.hist_truths, self.hist_preds)
198 |         from .vision import plot_confusion_matrix
199 |         plot_confusion_matrix(self.hist_truths, self.hist_preds)
200 | 
201 |     def report(self, each_class=True, conf_mat=False):
202 |         precisions = []
203 |         recalls = []
204 |         for Cls in range(self.nCls):
205 |             precision = self.table[Cls,0] / (self.table[Cls,0] + self.table[Cls,2] + self.eps) # TP / (TP + FP)
206 |             recall = self.table[Cls,0] / (self.table[Cls,0] + self.table[Cls,3] + self.eps) # TP / (TP + FN)
207 |             precisions.append(precision)
208 |             recalls.append(recall)
209 |         total_TP = np.sum(self.table[:, 0]) # all true positives
210 |         accuracy = total_TP/self.N
211 |         accuracy_mean_class = np.mean(recalls) # per-class accuracy is the recall of that class
212 | 
213 |         text = f"Overall Accuracy = {accuracy:.4f}({total_TP}/{self.N})\n"
214 |         text += f"\tMean-class Accuracy = {accuracy_mean_class:.4f}\n"
215 | 
216 |         if each_class:
217 |             for Cls in range(self.nCls):
218 |                 #if precisions[Cls] != 0 or recalls[Cls] != 0:
219 |                 text += f"\tClass {str(Cls)+'('+self.names[Cls]+')' if self.names is not None else Cls}: precision = {precisions[Cls]:.3f} recall = {recalls[Cls]:.3f}\n"
220 |         if conf_mat:
221 |             self.plot_conf_mat()
222 | 
223 |         return text
224 | 
--------------------------------------------------------------------------------
/crucible/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import signal
5 | import smtplib
6 | import shutil
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from collections import OrderedDict
11 | 
12 | from email.mime.multipart import MIMEMultipart
13 | from email.mime.base import MIMEBase
14 | from email.mime.text import MIMEText
15 | from email.utils import COMMASPACE, formatdate
16 | from email import encoders
17 | 
18 | from .io import load_json
19 | 
20 | class DelayedKeyboardInterrupt(object):
21 |     """ Delayed SIGINT
22 |     reference: 
https://stackoverflow.com/questions/842557/how-to-prevent-a-block-of-code-from-being-interrupted-by-keyboardinterrupt-in-py/21919644 23 | with statement: 24 | * call __init__, instantiate context manager class with parameters. 25 | * call __enter__, return of __enter__ is assigned to variable after `as`. 26 | * run `with` body. 27 | * call __exit__(exc_type, exc_value, exc_traceback). If return false, then raise error, else omit. 28 | * call __del__ 29 | signal.signal(sig, handler) 30 | set handler for sig, return the old handler(signal.SIG_DFL) 31 | """ 32 | def __enter__(self): 33 | self.signal_received = False 34 | self.old_handler = signal.signal(signal.SIGINT, self.handler) 35 | 36 | def handler(self, sig, frame): 37 | self.signal_received = (sig, frame) 38 | print('SIGINT received. Delaying KeyboardInterrupt.') 39 | 40 | def __exit__(self, type, value, traceback): 41 | signal.signal(signal.SIGINT, self.old_handler) 42 | if self.signal_received: 43 | self.old_handler(*self.signal_received) 44 | 45 | 46 | class EmailSender(object): 47 | def __init__(self, subject="model", config="/home/hawkey/my-python-modules/hawtorch/email.json"): 48 | args = load_json(config) 49 | self.subject = subject 50 | self.username = args["username"] 51 | self.password = args["password"] 52 | self.send_from = args["send_from"] 53 | self.send_to = args["send_to"] 54 | self.server = args["server"] 55 | self.port = args["port"] 56 | 57 | 58 | def send(self, files=[], subject=None, message="report", use_tls=True): 59 | msg = MIMEMultipart() 60 | msg['From'] = self.send_from 61 | msg['To'] = COMMASPACE.join(self.send_to) 62 | msg['Date'] = formatdate(localtime=True) 63 | subject = self.subject if subject is None else subject 64 | msg['Subject'] = subject 65 | 66 | message = f"Report from your model {subject}, attachments are: \n {files}" 67 | 68 | msg.attach(MIMEText(message)) 69 | 70 | for path in files: 71 | part = MIMEBase('application', "octet-stream") 72 | with open(path, 'rb') as file: 73 | part.set_payload(file.read()) 74 | encoders.encode_base64(part) 75 | part.add_header('Content-Disposition', 76 | 'attachment; filename="{}"'.format(os.path.basename(path))) 77 | msg.attach(part) 78 | 79 | smtp = smtplib.SMTP(self.server, self.port) 80 | if use_tls: 81 | smtp.starttls() 82 | smtp.login(self.username, self.password) 83 | smtp.sendmail(self.send_from, self.send_to, msg.as_string()) 84 | smtp.quit() 85 | 86 | def fix_random_seed(seed=42, cudnn=True): 87 | torch.manual_seed(seed) 88 | np.random.seed(seed) 89 | if cudnn: 90 | torch.backends.cudnn.deterministic = True 91 | torch.backends.cudnn.benchmark = False 92 | 93 | 94 | 95 | def summary(model, input_size, batch_size=-1, device="cuda", logger=None): 96 | # redirect to write in file 97 | if logger is not None: 98 | print = logger._print 99 | 100 | def register_hook(module): 101 | 102 | def hook(module, input, output): 103 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 104 | module_idx = len(summary) 105 | 106 | m_key = "%s-%i" % (class_name, module_idx + 1) 107 | summary[m_key] = OrderedDict() 108 | summary[m_key]["input_shape"] = list(input[0].size()) 109 | summary[m_key]["input_shape"][0] = batch_size 110 | if isinstance(output, (list, tuple)): 111 | summary[m_key]["output_shape"] = [ 112 | [-1] + list(o.size())[1:] for o in output 113 | ] 114 | else: 115 | summary[m_key]["output_shape"] = list(output.size()) 116 | summary[m_key]["output_shape"][0] = batch_size 117 | 118 | params = 0 119 | if hasattr(module, "weight") and 
hasattr(module.weight, "size"): 120 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 121 | summary[m_key]["trainable"] = module.weight.requires_grad 122 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 123 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 124 | summary[m_key]["nb_params"] = params 125 | 126 | if ( 127 | not isinstance(module, nn.Sequential) 128 | and not isinstance(module, nn.ModuleList) 129 | and not (module == model) 130 | ): 131 | hooks.append(module.register_forward_hook(hook)) 132 | 133 | device = device.lower() 134 | assert device in [ 135 | "cuda", 136 | "cpu", 137 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 138 | 139 | if device == "cuda" and torch.cuda.is_available(): 140 | dtype = torch.cuda.FloatTensor 141 | else: 142 | dtype = torch.FloatTensor 143 | 144 | # multiple inputs to the network 145 | if isinstance(input_size, tuple): 146 | input_size = [input_size] 147 | 148 | # batch_size of 2 for batchnorm 149 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 150 | # print(type(x[0])) 151 | 152 | # create properties 153 | summary = OrderedDict() 154 | hooks = [] 155 | 156 | # register hook 157 | model.apply(register_hook) 158 | 159 | # make a forward pass 160 | # print(x.shape) 161 | model(*x) 162 | 163 | # remove these hooks 164 | for h in hooks: 165 | h.remove() 166 | 167 | print("----------------------------------------------------------------") 168 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 169 | print(line_new) 170 | print("================================================================") 171 | total_params = 0 172 | total_output = 0 173 | trainable_params = 0 174 | for layer in summary: 175 | # input_shape, output_shape, trainable, nb_params 176 | line_new = "{:>20} {:>25} {:>15}".format( 177 | layer, 178 | str(summary[layer]["output_shape"]), 179 | "{0:,}".format(summary[layer]["nb_params"]), 180 | ) 181 | total_params += summary[layer]["nb_params"] 182 | total_output += np.prod(summary[layer]["output_shape"]) 183 | if "trainable" in summary[layer]: 184 | if summary[layer]["trainable"] == True: 185 | trainable_params += summary[layer]["nb_params"] 186 | print(line_new) 187 | 188 | # assume 4 bytes/number (float on cuda). 189 | total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.)) 190 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients 191 | total_params_size = abs(total_params.numpy() * 4. 
/ (1024 ** 2.))
192 |     total_size = total_params_size + total_output_size + total_input_size
193 | 
194 |     print("================================================================")
195 |     print("Total params: {0:,}".format(total_params))
196 |     print("Trainable params: {0:,}".format(trainable_params))
197 |     print("Non-trainable params: {0:,}".format(total_params - trainable_params))
198 |     print("----------------------------------------------------------------")
199 |     print("Input size (MB): %0.2f" % total_input_size)
200 |     print("Forward/backward pass size (MB): %0.2f" % total_output_size)
201 |     print("Params size (MB): %0.2f" % total_params_size)
202 |     print("Estimated Total Size (MB): %0.2f" % total_size)
203 |     print("----------------------------------------------------------------")
204 |     # return summary
205 | 
206 | def backup(workspace, patterns=["*.py", "*.json", "*.sh"], recursive=True):
207 |     save_path = os.path.join(workspace, "backup")
208 |     os.makedirs(save_path, exist_ok=True)
209 |     for patt in patterns:
210 |         files = glob.glob(patt, recursive=recursive)
211 |         for f in files:
212 |             if os.path.exists(f):
213 |                 ff = os.path.join(save_path, f.split('/')[-1])
214 |                 shutil.copy2(f, ff)
215 |     print("[INFO] backed up files.")
216 | 
--------------------------------------------------------------------------------
/Explaination.md:
--------------------------------------------------------------------------------
# Classification Losses & Focal Loss

In PyTorch, all classification losses take predictions (x, input) and ground truth (y, target), and compute a list L:

$$
l(x, y) = L = \{l_i\}_{i=0,1,\dots}
$$

which is then reduced to L.sum() or L.mean() according to the reduction parameter.

### NLLLoss

Negative Log Likelihood Loss.

Despite the "Log" in its name, it does not compute a logarithm itself; it expects log-probabilities (e.g., the output of LogSoftmax) as input and simply picks out the entry of the target class:

$$
l_i = -x_i[y_i]
$$

```python
import torch
import torch.nn as nn

preds = torch.Tensor([[0.2, 0.3, 0.5]])
target = torch.Tensor([0]).to(torch.int64)

loss = nn.NLLLoss()
print(loss(preds, target))
# tensor(-0.2000)
```

### CrossEntropyLoss

Simply apply LogSoftmax, then pass the result to NLLLoss:

`CrossEntropyLoss(x, y) = NLLLoss(LogSoftmax(x), y)`

$$
l_i = -\log\left(\frac {\exp(x_i[y_i])} {\sum_j \exp(x_i[j])}\right)
$$

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

preds = torch.Tensor([[0.2, 0.3, 0.4]])
target = torch.Tensor([0]).to(torch.int64)

loss = nn.NLLLoss()
print(loss(F.log_softmax(preds, dim=1), target))
# tensor(1.2019)

loss = nn.CrossEntropyLoss()
print(loss(preds, target))
# tensor(1.2019)
```

### BCELoss

Binary Cross Entropy Loss. Note that this differs from CrossEntropyLoss even when there are only two classes.

No logarithm or softmax is applied internally; the input is expected to already be a probability.

Input is `(N, ), float32`, and Target is `(N, ), float32`.

(Note for CE, Input is `(N, C), float32`, and Target is `(N, ), int64`.)

$$
l_i = -(y_i \log x_i + (1-y_i) \log(1-x_i))
$$

```python
import numpy as np
import torch
import torch.nn as nn

preds = torch.Tensor([0.2])
target = torch.Tensor([1])

loss = nn.BCELoss()
print(loss(preds, target))
# tensor(1.6094)
print(-np.log(0.2))
# 1.6094379124341003
```
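
In practice, if the model outputs raw scores instead of probabilities, `nn.BCEWithLogitsLoss` fuses the sigmoid and the BCE computation and is numerically more stable. A minimal sketch of the equivalence (the numbers here are illustrative):

```python
import torch
import torch.nn as nn

logits = torch.Tensor([-1.5])   # raw, unnormalized score
target = torch.Tensor([1.])

print(nn.BCEWithLogitsLoss()(logits, target))
print(nn.BCELoss()(torch.sigmoid(logits), target))
# both: tensor(1.7014)
```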


### FocalLoss

\[[Paper](https://arxiv.org/abs/1708.02002)\]

Focal Loss was first introduced as an improvement of Binary Cross Entropy Loss to address **the imbalanced classification problem**:

$$
l_i = -(y_i(1-x_i)^{\gamma}\log x_i + (1-y_i)x_i^{\gamma}\log(1-x_i))
$$

Based on this, we can write the multi-class form as:

$$
s_i = \frac {\exp(x_i[y_i])} {\sum_j \exp(x_i[j])} \\
l_i = -(1-s_i)^{\gamma}\log(s_i)
$$

Note that the original paper also has an alpha parameter that assigns a different weight to each class; this can in fact be realized through the weight parameter of PyTorch's NLLLoss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    '''
    Multi-class Focal Loss
    '''
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """
        input: [N, C], float32
        target: [N, ], int64
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1-pt)**self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight)
        return loss
```
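
A quick numeric check (illustrative, using the `FocalLoss` class above): with `gamma=0` the modulating factor `(1-pt)**gamma` is identically 1, so FocalLoss reduces exactly to CrossEntropyLoss; with `gamma=2`, an easy, well-classified example is strongly down-weighted:

```python
logits = torch.Tensor([[3.0, 0.0]])  # pt = softmax(logits)[0, 0] ~ 0.95 for class 0
target = torch.LongTensor([0])

print(nn.CrossEntropyLoss()(logits, target))  # tensor(0.0486)
print(FocalLoss(gamma=0)(logits, target))     # tensor(0.0486), identical to CE
print(FocalLoss(gamma=2)(logits, target))     # tensor(0.0001), ~(1-0.95)^2 smaller
```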


Simple code for test:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

Device = torch.device("cuda")
Epoch = 32

np.random.seed(42)

class fcdr(nn.Module):
    def __init__(self, Fin, Fout, dp=0.5):
        super(fcdr, self).__init__()
        self.fc = nn.Linear(Fin, Fout)
        self.dp = nn.Dropout(dp)
        self.ac = nn.ReLU(True)
    def forward(self, x):
        x = self.fc(x)
        x = self.dp(x)
        x = self.ac(x)
        return x

class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        self.fc0 = fcdr(10, 256)
        self.fc1 = fcdr(256, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x) # [B, 2]
        return x

class FocalLoss(nn.Module):
    '''
    Multi-class Focal loss implementation
    '''
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1-pt)**self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight)
        return loss

class Averager:
    """ statistics for classification """
    def __init__(self, nCls=2):
        self.nCls = nCls
        self.N = 0
        self.eps = 1e-15
        self.table = np.zeros((nCls, 4), dtype=np.int32)

    def update(self, logits, truth):
        self.N += 1
        preds = torch.argmax(logits, dim=1).detach().cpu().numpy() # [B, ]
        labels = truth.detach().cpu().numpy() # [B, ]
        for Cls in range(self.nCls):
            true_positive = np.count_nonzero(np.bitwise_and(preds == Cls, labels == Cls))
            true_negative = np.count_nonzero(np.bitwise_and(preds != Cls, labels != Cls))
            false_positive = np.count_nonzero(np.bitwise_and(preds == Cls, labels != Cls))
            false_negative = np.count_nonzero(np.bitwise_and(preds != Cls, labels == Cls))
            self.table[Cls] += [true_positive, true_negative, false_positive, false_negative]

    def measure(self):
        precisions = []
        recalls = []
        for Cls in range(self.nCls):
            precision = self.table[Cls,0] / (self.table[Cls,0] + self.table[Cls,2] + self.eps) # TP / (TP + FP)
            recall = self.table[Cls,0] / (self.table[Cls,0] + self.table[Cls,3] + self.eps) # TP / (TP + FN)
            precisions.append(precision)
            recalls.append(recall)
        total_TP = np.sum(self.table[:, 0]) # all true positives
        total = np.sum(self.table[0]) # each row sums to the number of samples seen
        accuracy = total_TP/total
        return accuracy, precisions, recalls

    def report(self, intro, multiclass=True):
        accuracy, precisions, recalls = self.measure()
        text = "{}: Accuracy = {:.4f}\n".format(intro, accuracy)
        if multiclass:
            for Cls in range(self.nCls):
                text += "\tClass {}: precision = {:.3f} recall = {:.3f}\n".format(Cls, precisions[Cls], recalls[Cls])
        print(text, end='')

# non-batched
def train(X, Y, model, criterion, optimizer, epoch):
    model.train()
    avg = Averager()
    for x, y in zip(X, Y):
        x = torch.from_numpy(x).to(Device)
        y = torch.from_numpy(y).to(Device)
        preds = model(x).unsqueeze(0) # [1, 2] for [N, C], due to non-batch
        loss = criterion(preds, y) # [1] for [N, ]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg.update(preds, y)
    avg.report("==> Epoch {}".format(epoch))

def validate(X, Y, model, epoch):
    model.eval()
    avg = Averager()
    for x, y in zip(X, Y):
        x = torch.from_numpy(x).to(Device)
        y = torch.from_numpy(y).to(Device)
        with torch.no_grad():
            preds = model(x).unsqueeze(0)
        avg.update(preds, y)
    avg.report("++> Validate {}".format(epoch))

def gendata(N, p=[0.9, 0.1]):
    Y = np.random.choice(2, size=(N, 1), p=p).astype(np.int64)
    X = (np.random.rand(N, 10)+Y*0.2).astype(np.float32) # slightly related to Y
    #X = np.hstack((Y, X)).astype(np.float32)
    #Y = np.bitwise_xor(Y, np.ones_like(Y)) # reverse it
    return X, Y

if __name__ == "__main__":
    train_data = gendata(1000)
    val_data = gendata(500)
    model = FCN().to(Device)
    criterion = FocalLoss()
    #criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(Epoch):
        train(train_data[0], train_data[1], model, criterion, optimizer, epoch)
        validate(val_data[0], val_data[1], model, epoch)
```

--------------------------------------------------------------------------------
/crucible/segtrainer.py:
--------------------------------------------------------------------------------
1 | from . trainer import *
2 | from . 
vision import plot_images 3 | 4 | import skimage.transform as skt 5 | 6 | 7 | class SegTrainer3D(Trainer): 8 | 9 | def resize(self, image, new_shape, interpolation=1): 10 | B, C, H, W, D = image.shape 11 | image = image.reshape(-1, H, W, D).transpose(1,2,3,0) 12 | image = skt.resize(image, new_shape, order=interpolation, mode='constant', cval=0, clip=True, anti_aliasing=False) 13 | image = image.transpose(3,0,1,2).reshape(B, C, new_shape[0], new_shape[1], new_shape[2]) 14 | return image 15 | 16 | def pad(self, image, new_shape, border_mode="constant", value=0): 17 | ''' 18 | image: [B, C, H, W, D] 19 | new_shape: [H, W, D] 20 | ''' 21 | axes_not_pad = len(image.shape) - len(new_shape) 22 | 23 | old_shape = np.array(image.shape[-len(new_shape):]) 24 | new_shape = np.array([max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]) 25 | 26 | difference = new_shape - old_shape 27 | pad_below = difference // 2 28 | pad_above = difference - pad_below 29 | 30 | pad_list = [[0, 0]] * axes_not_pad + [list(i) for i in zip(pad_below, pad_above)] 31 | 32 | if border_mode == 'reflect': 33 | res = np.pad(image, pad_list, border_mode) 34 | elif border_mode == 'constant': 35 | res = np.pad(image, pad_list, border_mode, constant_values=value) 36 | else: 37 | raise NotImplementedError 38 | pad_list = np.array(pad_list) 39 | pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] 40 | slicer = list(slice(*i) for i in pad_list) 41 | 42 | return res, slicer 43 | 44 | def mirroring_forward(self, x): 45 | y = self.model(x) 46 | for dims in [[0],[1],[2],[0,1],[1,2],[0,2],[0,1,2]]: 47 | y = y + torch.flip(self.model(torch.flip(x, dims)), dims) 48 | y = y / 8 49 | return y 50 | 51 | def eval_step(self, data): 52 | """ 53 | image: [B, C, H, W] 54 | """ 55 | image = data["image"].numpy() 56 | 57 | if self.conf.eval_type == 'whole': 58 | B, C, H, W, D = image.shape 59 | resized_image = self.resize(image, self.conf.patch_size, 3) 60 | 61 | resized_image = torch.from_numpy(resized_image).float().to(self.device) 62 | pred = self.model(resized_image) 63 | pred = pred.detach().cpu().numpy() 64 | 65 | pred = self.resize(pred, (H, W, D), 3) 66 | pred = pred.argmax(1) 67 | 68 | elif self.conf.eval_type == 'sliding_window': 69 | 70 | image, slicer = self.pad(image, self.conf.patch_size) 71 | 72 | B, C, H, W, D = image.shape 73 | # pad to at least patch_size 74 | pred_sum = np.zeros((B, self.conf.num_classes, H, W, D)) 75 | pred_cnt = np.zeros((1, 1, H, W, D)) 76 | # slide window 77 | ph, pw, pd = self.conf.patch_size 78 | sh, sw, sd = self.conf.patch_stride 79 | for h in range(0, H, sh): 80 | for w in range(0, W, sw): 81 | for d in range(0, D, sd): 82 | hh = min(h + ph, H) 83 | ww = min(w + pw, W) 84 | dd = min(d + pd, D) 85 | h = hh - ph 86 | w = ww - pw 87 | d = dd - pd 88 | 89 | patch = image[:, :, h:hh, w:ww, d:dd] # [B, C, ph, pw, pd] 90 | patch = torch.from_numpy(patch).float().to(self.device) 91 | 92 | # not enough 93 | pred = self.model(patch) 94 | #pred = self.mirroring_forward(patch) 95 | 96 | pred = pred.detach().cpu().numpy() 97 | 98 | pred_sum[:, :, h:hh, w:ww, d:dd] += pred 99 | pred_cnt[0, 0, h:hh, w:ww, d:dd] += 1 100 | 101 | pred = pred_sum / pred_cnt 102 | pred = pred[:, :, slicer[2], slicer[3], slicer[4]] 103 | pred = pred.argmax(1) # [B, H, W, D] 104 | 105 | else: 106 | raise NotImplementedError 107 | 108 | data["pred"] = pred 109 | 110 | return data 111 | 112 | 113 | def evaluate(self, 114 | eval_set=None, 115 | save_snap=True, 116 | save_image=False, 117 | save_image_folder=None, 118 | 
show_image=False, 119 | ): 120 | 121 | """ 122 | final evaluate at the best epoch. 123 | """ 124 | eval_set = self.eval_set if eval_set is None else eval_set 125 | self.log.info(f"Evaluate at the best epoch on {eval_set} set...") 126 | 127 | # load model 128 | model_name = type(self.model).__name__ 129 | ckpt_path = os.path.join(self.workspace_path, 'checkpoints') 130 | best_path = f"{ckpt_path}/{model_name}_best.pth.tar" 131 | if not os.path.exists(best_path): 132 | self.log.error(f"Best checkpoint not found at {best_path}, load by default.") 133 | self.load_checkpoint() 134 | else: 135 | self.load_checkpoint(best_path) 136 | 137 | # turn off logging to tensorboardX 138 | self.use_tensorboardX = False 139 | self.evaluate_one_epoch(eval_set, save_snap, save_image, save_image_folder, show_image) 140 | 141 | def evaluate_one_epoch(self, 142 | eval_set, 143 | save_snap=False, 144 | save_image=False, 145 | save_image_folder=None, 146 | show_image=False, 147 | ): 148 | self.log.log(f"++> Evaluate at epoch {self.epoch} ...") 149 | 150 | for metric in self.metrics: 151 | metric.clear() 152 | 153 | self.model.eval() 154 | 155 | pbar = self.dataloaders[eval_set] 156 | if self.use_tqdm: 157 | pbar = tqdm.tqdm(pbar) 158 | 159 | epoch_start_time = self.get_time() 160 | 161 | if save_image: 162 | if save_image_folder is None: 163 | save_image_folder = 'evaluation_' + self.time_stamp 164 | save_image_folder = os.path.join(self.workspace_path, save_image_folder) 165 | os.makedirs(save_image_folder, exist_ok=True) 166 | 167 | with torch.no_grad(): 168 | self.local_step = 0 169 | start_time = self.get_time() 170 | 171 | for data in pbar: 172 | self.local_step += 1 173 | 174 | if self.max_eval_step is not None and self.local_step > self.max_eval_step: 175 | break 176 | 177 | data = self.eval_step(data) 178 | pred, mask = data["pred"], data["mask"] 179 | 180 | for metric in self.metrics: 181 | metric.update(pred, mask) 182 | 183 | if show_image: 184 | batch_size = pred.shape[0] 185 | f = plt.figure() 186 | ax0 = f.add_subplot(121) 187 | ax1 = f.add_subplot(122) 188 | for batch in range(batch_size): 189 | ax0.imshow(pred[batch]) 190 | ax1.imshow(mask[batch]) 191 | plt.show() 192 | 193 | if save_image: 194 | batch_size = pred.shape[0] 195 | for batch in range(batch_size): 196 | if 'name' in data: 197 | name = data['name'][batch] + '.npy' 198 | else: 199 | name = str(self.local_step) + '_' + str(batch) + '.npy' 200 | 201 | np.save(os.path.join(save_image_folder, name), pred[batch]) 202 | self.log.info(f"Saved image {name} at {save_image_folder}.") 203 | 204 | 205 | total_time = self.get_time() - start_time 206 | self.log.log1(f"total_time={total_time:.2f}") 207 | 208 | self.stats["EvalResults"].append(self.metrics[0].measure()) 209 | 210 | if save_snap and self.use_tensorboardX: 211 | # only save first batch first layer 212 | self.writer.add_image("evaluate/image", data["image"][0, :, :, 0], self.epoch) 213 | self.writer.add_image("evaluate/pred", np.expand_dims(pred[0, :, :, 0], 0), self.epoch) 214 | self.writer.add_image("evaluate/mask", np.expand_dims(mask[0, :, :, 0], 0), self.epoch) 215 | 216 | for metric in self.metrics: 217 | self.log.log1(metric.report()) 218 | if self.use_tensorboardX: 219 | metric.write(self.writer, self.epoch, prefix="evaluate") 220 | metric.clear() 221 | 222 | epoch_end_time = self.get_time() 223 | self.log.log(f"++> Evaluate Finished. 
time={epoch_end_time-epoch_start_time:.4f}") 224 | 225 | def predict(self, eval_set='test', save_image=True, save_image_folder=None, show_image=False): 226 | self.log.log(f"++> Predict at epoch {self.epoch} ...") 227 | 228 | self.model.eval() 229 | 230 | pbar = self.dataloaders[eval_set] 231 | if self.use_tqdm: 232 | pbar = tqdm.tqdm(pbar) 233 | 234 | epoch_start_time = self.get_time() 235 | 236 | if save_image: 237 | if save_image_folder is None: 238 | save_image_folder = 'prediction_' + self.time_stamp 239 | save_image_folder = os.path.join(self.workspace_path, save_image_folder) 240 | os.makedirs(save_image_folder, exist_ok=True) 241 | 242 | with torch.no_grad(): 243 | self.local_step = 0 244 | start_time = self.get_time() 245 | 246 | for data in pbar: 247 | self.local_step += 1 248 | 249 | data = self.eval_step(data) 250 | pred = data["pred"] 251 | 252 | if save_image: 253 | batch_size = pred.shape[0] 254 | for batch in range(batch_size): 255 | if 'name' in data: 256 | name = data['name'][batch] + '.npy' 257 | else: 258 | name = str(self.local_step) + '_' + str(batch) + '.npy' 259 | 260 | np.save(os.path.join(save_image_folder, name), pred[batch]) 261 | self.log.info(f"Saved image {name} at {save_image_folder}.") 262 | 263 | 264 | total_time = self.get_time() - start_time 265 | self.log.log1(f"total_time={total_time:.2f}") 266 | 267 | epoch_end_time = self.get_time() 268 | self.log.log(f"++> Predict Finished. time={epoch_end_time-epoch_start_time:.4f}") -------------------------------------------------------------------------------- /crucible/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import time 5 | import tqdm 6 | import tensorboardX 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import matplotlib.pyplot as plt 12 | 13 | from . io import logger 14 | from . utils import DelayedKeyboardInterrupt, summary 15 | from . init import init_weights 16 | 17 | 18 | 19 | class Trainer(object): 20 | """Base trainer class. 
21 | """ 22 | 23 | def __init__(self, 24 | conf, 25 | model, 26 | optimizer, 27 | lr_scheduler, 28 | objective, 29 | dataloaders, 30 | logger, 31 | metrics=[], 32 | input_shape=None, 33 | use_checkpoint="latest", 34 | restart=False, 35 | max_keep_ckpt=1, 36 | eval_set="test", 37 | test_set="test", 38 | eval_interval=1, 39 | report_step_interval=200, 40 | max_eval_step=None, 41 | use_parallel=False, 42 | use_tqdm=True, 43 | use_tensorboardX=True, 44 | weight_init_function=init_weights, 45 | ): 46 | 47 | self.conf = conf 48 | self.device = conf.device 49 | self.workspace_path = conf.workspace 50 | 51 | self.model = model 52 | self.optimizer = optimizer 53 | self.lr_scheduler = lr_scheduler 54 | self.objective = objective 55 | self.dataloaders = dataloaders 56 | self.metrics = metrics 57 | self.log = logger 58 | self.use_checkpoint = use_checkpoint 59 | self.max_keep_ckpt = max_keep_ckpt 60 | self.restart = restart 61 | self.eval_set = eval_set 62 | self.test_set = test_set 63 | self.eval_interval = eval_interval 64 | self.report_step_interval = report_step_interval 65 | self.max_eval_step = max_eval_step 66 | 67 | self.use_parallel = use_parallel 68 | self.use_tqdm = use_tqdm 69 | self.use_tensorboardX = use_tensorboardX 70 | 71 | self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") 72 | self.log.info(f'Time stamp is {self.time_stamp}') 73 | 74 | self.model.to(self.device) 75 | 76 | if input_shape is not None: 77 | summary(self.model, input_shape, logger=self.log) 78 | 79 | if self.use_parallel: 80 | self.model = nn.DataParallel(self.model) 81 | 82 | if weight_init_function is not None: 83 | self.model.apply(weight_init_function) 84 | 85 | self.log.info(f'Number of model parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') 86 | 87 | self.epoch = 1 88 | self.global_step = 0 89 | self.local_step = 0 90 | self.stats = { 91 | "StepLoss": [], 92 | "EpochLoss": [], 93 | "EvalResults": [], 94 | "Checkpoints": [], 95 | "BestResult": None, 96 | } 97 | 98 | if self.workspace_path is not None: 99 | os.makedirs(self.workspace_path, exist_ok=True) 100 | if self.use_checkpoint == "latest": 101 | self.log.info("Loading latest checkpoint ...") 102 | self.load_checkpoint() 103 | elif self.use_checkpoint == "scratch": 104 | self.log.info("Train from scratch") 105 | elif self.use_checkpoint == "best": 106 | self.log.info("Loading best checkpoint ...") 107 | model_name = type(self.model).__name__ 108 | ckpt_path = os.path.join(self.workspace_path, 'checkpoints') 109 | best_path = f"{ckpt_path}/{model_name}_best.pth.tar" 110 | self.load_checkpoint(best_path) 111 | else: # path to ckpt 112 | self.log.info(f"Loading checkpoint {self.use_checkpoint} ...") 113 | self.load_checkpoint(self.use_checkpoint) 114 | 115 | 116 | ### --------------------------------------------------- 117 | ### example step function for input segmentation task 118 | ### --------------------------------------------------- 119 | 120 | 121 | def train_step(self, data): 122 | input, truth = data["input"], data["truth"] 123 | 124 | output = self.model(input) 125 | loss = self.objective(output, truth) 126 | pred = output.detach().cpu().numpy().argmax(axis=1) 127 | 128 | data["output"] = output 129 | data["pred"] = pred 130 | data["loss"] = loss 131 | 132 | return data 133 | 134 | def eval_step(self, data): 135 | input, truth = data["input"], data["truth"] 136 | 137 | output = self.model(input) 138 | pred = output.detach().cpu().numpy().argmax(axis=1) 139 | 140 | data["output"] = output 141 | data["pred"] = pred 142 | 143 

    ### ---------------------------------------------------
    ### example step functions for an input segmentation task
    ### ---------------------------------------------------

    def train_step(self, data):
        input, truth = data["input"], data["truth"]

        output = self.model(input)
        loss = self.objective(output, truth)
        pred = output.detach().cpu().numpy().argmax(axis=1)

        data["output"] = output
        data["pred"] = pred
        data["loss"] = loss

        return data

    def eval_step(self, data):
        input, truth = data["input"], data["truth"]

        output = self.model(input)
        pred = output.detach().cpu().numpy().argmax(axis=1)

        data["output"] = output
        data["pred"] = pred

        return data

    def predict_step(self, data):
        input = data["input"]

        output = self.model(input)
        pred = output.detach().cpu().numpy().argmax(axis=1)

        data["output"] = output
        data["pred"] = pred

        return data

    ### ---------------------------------------------------

    def train(self, max_epochs=None):
        """
        Run the training loop for up to max_epochs epochs.
        """
        if max_epochs is None:
            max_epochs = self.conf.max_epochs

        if self.use_tensorboardX:
            logdir = os.path.join(self.workspace_path, "run", self.time_stamp)
            self.writer = tensorboardX.SummaryWriter(logdir)

        for epoch in range(self.epoch, max_epochs + 1):
            self.epoch = epoch

            self.train_one_epoch()

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(self.eval_set)

            if self.workspace_path is not None:
                self.save_checkpoint()

        if self.use_tensorboardX:
            self.writer.close()

        self.log.info("Finished Training.")

    def evaluate(self, eval_set=None):
        """
        Final evaluation with the best checkpoint.
        """
        eval_set = self.eval_set if eval_set is None else eval_set
        self.log.info(f"Evaluate at the best epoch on the {eval_set} set ...")

        # load the best model
        ckpt_path = os.path.join(self.workspace_path, 'checkpoints')
        best_path = f"{ckpt_path}/{self.model_name}_best.pth.tar"
        if not os.path.exists(best_path):
            self.log.error(f"Best checkpoint not found! {best_path}")
            raise FileNotFoundError(best_path)
        self.load_checkpoint(best_path)

        # turn off logging to tensorboardX
        self.use_tensorboardX = False
        self.evaluate_one_epoch(eval_set)

    def get_time(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()

    def prepare_data(self, data):
        if isinstance(data, (list, tuple)):
            # tuples are immutable, so convert to a list before in-place replacement
            data = list(data)
            for i, v in enumerate(data):
                if isinstance(v, np.ndarray):
                    data[i] = torch.from_numpy(v).to(self.device)
                elif torch.is_tensor(v):
                    data[i] = v.to(self.device)
        elif isinstance(data, dict):
            for k, v in data.items():
                if isinstance(v, np.ndarray):
                    data[k] = torch.from_numpy(v).to(self.device)
                elif torch.is_tensor(v):
                    data[k] = v.to(self.device)
        elif isinstance(data, np.ndarray):
            data = torch.from_numpy(data).to(self.device)
        else:  # assume it is already a tensor
            data = data.to(self.device)

        return data
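
    # A small illustration of what `prepare_data` handles; the batch shapes
    # below are made up for the example:
    #
    #   batch = {"input": np.zeros((4, 3, 64, 64), dtype=np.float32),
    #            "truth": torch.zeros(4, dtype=torch.long)}
    #   batch = trainer.prepare_data(batch)
    #   # both values are now torch tensors on `conf.device`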

    def profile(self, steps=1):
        """
        For a full profile, run:
        ```bash
        python -m torch.utils.bottleneck train.py
        ```
        """
        self.log.log(f"==> Start Profiling for {steps} steps.")

        self.model.train()
        for metric in self.metrics:
            metric.clear()

        start_time = self.get_time()

        data_time = 0
        forward_time = 0
        backward_time = 0
        metric_time = 0

        # create the iterator once, so worker start-up and the same first
        # batch are not re-measured at every step
        train_iter = iter(self.dataloaders["train"])

        for i in range(steps):

            data_start_time = self.get_time()
            data = next(train_iter)
            data = self.prepare_data(data)
            data_time += self.get_time() - data_start_time

            forward_start_time = self.get_time()
            data = self.train_step(data)
            truth, pred, loss = data["truth"], data["pred"], data["loss"]
            forward_time += self.get_time() - forward_start_time

            backward_start_time = self.get_time()
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            backward_time += self.get_time() - backward_start_time

            metric_start_time = self.get_time()
            for metric in self.metrics:
                metric.update(pred, truth)
            metric_time += self.get_time() - metric_start_time

        self.lr_scheduler.step()

        end_time = self.get_time()

        self.log.log(f"==> Finished Profiling for {steps} steps, time={end_time-start_time:.4f} ({data_time:.4f} data + {forward_time:.4f} forward + {backward_time:.4f} backward + {metric_time:.4f} metric)")


    def train_one_epoch(self):
        self.log.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']} ...")

        for metric in self.metrics:
            metric.clear()
        total_loss = []
        self.model.train()

        pbar = self.dataloaders["train"]
        if self.use_tqdm:
            pbar = tqdm.tqdm(pbar)

        self.local_step = 0
        epoch_start_time = self.get_time()

        for data in pbar:
            start_time = self.get_time()
            self.local_step += 1
            self.global_step += 1

            data = self.prepare_data(data)

            data = self.train_step(data)
            truth, pred, loss = data["truth"], data["pred"], data["loss"]

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            for metric in self.metrics:
                metric.update(pred, truth)
                if self.use_tensorboardX:
                    metric.write(self.writer, self.global_step, prefix="train")

            if self.use_tensorboardX:
                self.writer.add_scalar("train/loss", loss.item(), self.global_step)

            total_loss.append(loss.item())
            total_time = self.get_time() - start_time

            if self.report_step_interval > 0 and self.local_step % self.report_step_interval == 0:
                self.log.log1(f"step={self.epoch}/{self.local_step}, loss={loss.item():.4f}, time={total_time:.2f}")
                for metric in self.metrics:
                    self.log.log1(metric.report())
                    metric.clear()

        # a negative report interval means: report (and reset) the metrics once per epoch
        if self.report_step_interval < 0:
            for metric in self.metrics:
                self.log.log1(metric.report())
                metric.clear()

        self.lr_scheduler.step()
        epoch_end_time = self.get_time()
        average_loss = np.mean(total_loss)
        self.stats["StepLoss"].extend(total_loss)
        self.stats["EpochLoss"].append(average_loss)

        self.log.log(f"==> Finished Epoch {self.epoch}, average_loss={average_loss:.4f}, time={epoch_end_time-epoch_start_time:.4f}")


    def evaluate_one_epoch(self, eval_set):
        self.log.log(f"++> Evaluate at epoch {self.epoch} ...")

        for metric in self.metrics:
            metric.clear()
        self.model.eval()

        pbar = self.dataloaders[eval_set]
        if self.use_tqdm:
            pbar = tqdm.tqdm(pbar)

        epoch_start_time = self.get_time()

        with torch.no_grad():
            self.local_step = 0
            start_time = self.get_time()

            for data in pbar:
                self.local_step += 1

                if self.max_eval_step is not None and self.local_step > self.max_eval_step:
                    break

                data = self.prepare_data(data)
                data = self.eval_step(data)
                pred, truth = data["pred"], data["truth"]

                for metric in self.metrics:
                    metric.update(pred, truth)

            total_time = self.get_time() - start_time
            self.log.log1(f"total_time={total_time:.2f}")

            # guard against an empty metrics list before indexing
            if self.metrics:
                self.stats["EvalResults"].append(self.metrics[0].measure())

            for metric in self.metrics:
                self.log.log1(metric.report())
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()

        epoch_end_time = self.get_time()
        self.log.log(f"++> Evaluate Finished. time={epoch_end_time-epoch_start_time:.4f}")
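
    # ------------------------------------------------------------------
    # The Trainer only assumes that every metric object exposes `clear`,
    # `update`, `measure`, `report`, `write` and `better`. A minimal sketch
    # of that interface — the `Accuracy` class below is an illustration,
    # not the implementation shipped in crucible.metrics:
    #
    #   class Accuracy:
    #       def __init__(self):
    #           self.correct, self.total = 0, 0
    #       def clear(self):
    #           self.correct, self.total = 0, 0
    #       def update(self, pred, truth):
    #           # `pred` is a numpy array; `truth` arrives as a device tensor
    #           truth = truth.detach().cpu().numpy()
    #           self.correct += (pred == truth).sum()
    #           self.total += truth.size
    #       def measure(self):
    #           return self.correct / max(self.total, 1)
    #       def better(self, new, old):
    #           return new > old  # higher accuracy is better
    #       def report(self):
    #           return f"Accuracy = {self.measure():.4f}"
    #       def write(self, writer, step, prefix=""):
    #           writer.add_scalar(f"{prefix}/accuracy", self.measure(), step)
    # ------------------------------------------------------------------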

    def save_checkpoint(self):
        """Saves a checkpoint of the network and other variables."""
        with DelayedKeyboardInterrupt():
            ckpt_path = os.path.join(self.workspace_path, 'checkpoints')
            file_path = f"{ckpt_path}/{self.model_name}_ep{self.epoch:04d}.pth.tar"
            best_path = f"{ckpt_path}/{self.model_name}_best.pth.tar"
            os.makedirs(ckpt_path, exist_ok=True)

            self.stats["Checkpoints"].append(file_path)

            if len(self.stats["Checkpoints"]) > self.max_keep_ckpt:
                old_ckpt = self.stats["Checkpoints"].pop(0)
                if os.path.exists(old_ckpt):
                    os.remove(old_ckpt)
                    self.log.info(f"Removed old checkpoint {old_ckpt}")

            state = {
                'epoch': self.epoch,
                'global_step': self.global_step,
                'model_name': self.model_name,
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'lr_scheduler': self.lr_scheduler.state_dict(),
                'stats': self.stats,
            }

            torch.save(state, file_path)
            self.log.info(f"Saved checkpoint {self.epoch} successfully.")

            # "EvalResults" is a list; check that it is non-empty before indexing
            if self.stats["EvalResults"]:
                if self.stats["BestResult"] is None or self.metrics[0].better(self.stats["EvalResults"][-1], self.stats["BestResult"]):
                    self.stats["BestResult"] = self.stats["EvalResults"][-1]
                    torch.save(state, best_path)
                    self.log.info("Saved Best checkpoint.")
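
    # A saved checkpoint can be inspected directly; a small sketch, where the
    # path below is just an example filename:
    #
    #   ckpt = torch.load("workspace/checkpoints/MyNet_ep0003.pth.tar", map_location="cpu")
    #   print(ckpt["epoch"], ckpt["global_step"], ckpt["model_name"])
    #   print(ckpt["stats"]["EpochLoss"])   # per-epoch average training loss
    #   model.load_state_dict(ckpt["model"])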

    def load_checkpoint(self, checkpoint=None):
        """Loads a network checkpoint file.

        Can be called in three different ways:
        load_checkpoint():
            Loads the latest epoch from the workspace. Use this to continue training.
        load_checkpoint(epoch_num):
            Loads the model at the given epoch number (int).
        load_checkpoint(path_to_checkpoint):
            Loads the file at the given path (str).
        """
        ckpt_path = os.path.join(self.workspace_path, 'checkpoints')

        if checkpoint is None:
            # load the most recent checkpoint
            checkpoint_list = sorted(glob.glob(f'{ckpt_path}/{self.model_name}_ep*.pth.tar'))
            if checkpoint_list:
                checkpoint_path = checkpoint_list[-1]
            else:
                self.log.info("No checkpoint found, model randomly initialized.")
                return False
        elif isinstance(checkpoint, int):
            # checkpoint is the epoch number
            checkpoint_path = f'{ckpt_path}/{self.model_name}_ep{checkpoint:04d}.pth.tar'
        elif isinstance(checkpoint, str):
            # checkpoint is the path
            checkpoint_path = os.path.expanduser(checkpoint)
        else:
            self.log.error("load_checkpoint: invalid argument")
            raise TypeError("checkpoint must be None, an int epoch number, or a path string")

        checkpoint_dict = torch.load(checkpoint_path)

        # assert self.model_name == checkpoint_dict['model_name'], 'network is not of correct type.'

        self.model.load_state_dict(checkpoint_dict['model'])
        if not self.restart:
            self.log.info("Loading epoch and other status ...")
            self.epoch = checkpoint_dict['epoch'] + 1
            self.global_step = checkpoint_dict['global_step']
            self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
            self.lr_scheduler.last_epoch = checkpoint_dict['epoch']
            self.stats = checkpoint_dict['stats']
        else:
            self.log.info("Only loading model parameters.")

        self.log.info("Checkpoint loaded successfully.")
        return True

--------------------------------------------------------------------------------