├── README.md ├── pruning.py ├── train.py └── utils └── loss.py /README.md: -------------------------------------------------------------------------------- 1 | # yolov5-pruning 2 | 3 | 4 | 5 | Update 2021/10/4: adapted to the new version of yolov5 6 | 7 | 8 | 9 | --- 10 | 11 | Channel-wise pruning of yolov5 12 | 13 | 14 | 15 | Preparation: 16 | 17 | 1. Download **yolov5** 18 | 19 | ``` 20 | git clone https://github.com/ultralytics/yolov5 21 | cd yolov5 22 | git reset --hard 59aae85a7e40701bb872df673a6ef288e99a4ae3 23 | ``` 24 | 25 | 2. Download the compatible **Torch-Pruning** revision 26 | 27 | ``` 28 | git clone https://github.com/VainF/Torch-Pruning 29 | cd Torch-Pruning 30 | git reset --hard ec12e0590aad28e607e1df9feb2baf60c8cda689 31 | ``` 32 | 33 | 3. Copy the `torch_pruning` package into `yolov5` 34 | 35 | 4. Download this repo and copy its files into `yolov5` 36 | 37 | 38 | 39 | Usage: 40 | 41 | 1. Sparse learning: train a new model with `--sl_factor`; an L1 loss will be added to the weights of all batchnorm layers 42 | 2. Pruning: `python pruning.py --shape [batchsize channel height width] --prob 0.1 --weights [xxx.pt] --save_path [xxx_pruned.pt]`; channels whose batchnorm weight is lower than the threshold will be removed 43 | 3. Fine-tuning: train the pruned model with `--ft_pruned` 44 | 45 | 46 | 47 | Reference: 48 | 49 | * https://github.com/Syencil/mobile-yolov5-pruning-distillation 50 | * https://github.com/VainF/Torch-Pruning -------------------------------------------------------------------------------- /pruning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch_pruning as tp 7 | import copy 8 | import matplotlib.pyplot as plt 9 | from models.yolo import Model 10 | from utils.torch_utils import intersect_dicts, is_parallel 11 | 12 | 13 | def load_model(weights): 14 | ckpt = torch.load(weights, map_location=device) # load checkpoint 15 | model = Model(ckpt['model'].yaml).to(device) # create 16 | state_dict = ckpt['model'].float().state_dict() # to FP32 17 | state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=[]) # intersect 18 | model.load_state_dict(state_dict, strict=False) # load 19 | print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report 20 | assert len(state_dict) == len(model.state_dict()) 21 | 22 | model.float() 23 | model.model[-1].export = True 24 | return model 25 | 26 | def bn_analyze(prunable_modules, save_path=None): 27 | bn_val = [] 28 | max_val = [] 29 | for layer_to_prune in prunable_modules: 30 | # select a layer 31 | weight = layer_to_prune.weight.data.detach().cpu().numpy() 32 | max_val.append(max(weight)) 33 | bn_val.extend(weight) 34 | bn_val = np.abs(bn_val) 35 | max_val = np.abs(max_val) 36 | bn_val = sorted(bn_val) 37 | max_val = sorted(max_val) 38 | plt.hist(bn_val, bins=101, align="mid") 39 | if save_path is not None: 40 | if os.path.isfile(save_path): 41 | os.remove(save_path) 42 | plt.savefig(save_path) 43 | return bn_val, max_val 44 | 45 | def channel_prune(ori_model, example_inputs, output_transform, pruned_prob=0.3, thres=None): 46 | model = copy.deepcopy(ori_model) 47 | model.cpu().eval() 48 | 49 | prunable_module_type = (nn.BatchNorm2d) 50 | 51 | prunable_modules = [] 52 | for i, m in enumerate(model.modules()): 53 | if isinstance(m, prunable_module_type): 54 | prunable_modules.append(m) 55 | 56 | ori_size = tp.utils.count_params(model) 57 | DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs, output_transform=output_transform)
58 | 59 | bn_val, _ = bn_analyze(prunable_modules, os.path.splitext(opt.save_path)[0] + "_before_pruning.jpg") 60 | if thres is None: 61 | print('Recalculating thresh') 62 | thres_pos = int(pruned_prob * len(bn_val)) 63 | thres_pos = min(thres_pos, len(bn_val)-1) 64 | thres_pos = max(thres_pos, 0) 65 | thres = bn_val[thres_pos] 66 | print("Min val is %f, Max val is %f, Thres is %f" % (bn_val[0], bn_val[-1], thres)) 67 | 68 | for layer_to_prune in prunable_modules: 69 | # select a layer 70 | weight = layer_to_prune.weight.data.detach().cpu().numpy() 71 | prune_fn = tp.prune_batchnorm 72 | L1_norm = np.abs(weight) 73 | 74 | pos = np.array([i for i in range(len(L1_norm))]) 75 | pruned_idx_mask = L1_norm < thres 76 | prun_index = pos[pruned_idx_mask].tolist() 77 | if len(prun_index) == len(L1_norm): 78 | del prun_index[np.argmax(L1_norm)] 79 | 80 | plan = DG.get_pruning_plan(layer_to_prune, prune_fn, prun_index) 81 | plan.exec() 82 | 83 | bn_analyze(prunable_modules, os.path.splitext(opt.save_path)[0] + "_after_pruning.jpg") 84 | 85 | model.train() 86 | ori_model.train() 87 | with torch.no_grad(): 88 | out = model(example_inputs) 89 | out2 = ori_model(example_inputs) 90 | if output_transform: 91 | out = output_transform(out) 92 | out2 = output_transform(out2) 93 | print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model))) 94 | if isinstance(out, (list, tuple)): 95 | for o, o2 in zip(out, out2): 96 | print(" Output: ", o.shape) 97 | assert o.shape == o2.shape, f'{o.shape} {o2.shape}' 98 | else: 99 | print(" Output: ", out.shape) 100 | assert out.shape == out2.shape, f'{out.shape} {out2.shape}' 101 | print("------------------------------------------------------\n") 102 | return model 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--weights', type=str, help='path of the checkpoint to prune') 107 | parser.add_argument('--save_path', default="", type=str, help='where to save the pruned model') 108 | parser.add_argument('-p', '--prob', default=0.5, type=float, help='pruning prob') 109 | parser.add_argument('-t', '--thres', default=0, type=float, help='pruning thres') 110 | parser.add_argument('--shape', nargs='+', type=int, default=[1, 3, 640, 640]) 111 | opt = parser.parse_args() 112 | 113 | weights = opt.weights 114 | if not opt.save_path.endswith('.pt'): 115 | save_dir = opt.save_path if os.path.isdir(opt.save_path) else os.path.dirname(os.path.abspath(weights)) 116 | save_name = os.path.splitext(os.path.basename(weights))[0] + '_pruned.pt' 117 | opt.save_path = os.path.join(save_dir, save_name) 118 | 119 | device = torch.device('cpu') 120 | model = load_model(weights) 121 | 122 | example_inputs = torch.zeros(opt.shape, dtype=torch.float32).to(device) 123 | output_transform = None 124 | # for prob in [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]: 125 | if opt.thres != 0: 126 | thres = opt.thres 127 | prob = None 128 | else: 129 | thres = None 130 | prob = opt.prob 131 | 132 | pruned_model = channel_prune(model, example_inputs=example_inputs, 133 | output_transform=output_transform, pruned_prob=prob, thres=thres) 134 | pruned_model.model[-1].export = False 135 | 136 | ckpt = { 137 | 'model': copy.deepcopy(pruned_model.module if is_parallel(pruned_model) else pruned_model).half(), 138 | 'optimizer': None, 139 | 'epoch': -1, 140 | } 141 | torch.save(ckpt, opt.save_path) 142 | print("Saved", opt.save_path)
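
Note: because the pruned network no longer matches any `models/*.yaml` definition, `pruning.py` saves the entire pickled module rather than a bare `state_dict`; this is also why the `--ft_pruned` branch in `train.py` below restores `ckpt['model']` directly. A minimal sanity check of a saved checkpoint could look like the sketch below (the file name is a placeholder; run it from the `yolov5` root so the pickled `models`/`utils` packages resolve on unpickling):

```
import torch

ckpt = torch.load('yolov5s_pruned.pt', map_location='cpu')  # placeholder path
model = ckpt['model'].float().eval()  # stored in FP16; cast back to FP32
with torch.no_grad():
    out = model(torch.zeros(1, 3, 640, 640))  # same shape as --shape at prune time
print(sum(p.numel() for p in model.parameters()), 'parameters after pruning')
```
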
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Train a YOLOv5 model on a custom dataset 4 | 5 | Usage: 6 | $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 7 | """ 8 | 9 | import argparse 10 | import logging 11 | import math 12 | import os 13 | import random 14 | import sys 15 | import time 16 | from copy import deepcopy 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | import torch 21 | import torch.distributed as dist 22 | import torch.nn as nn 23 | import yaml 24 | from torch.cuda import amp 25 | from torch.nn.parallel import DistributedDataParallel as DDP 26 | from torch.optim import Adam, SGD, lr_scheduler 27 | from tqdm import tqdm 28 | 29 | FILE = Path(__file__).resolve() 30 | ROOT = FILE.parents[0] # YOLOv5 root directory 31 | if str(ROOT) not in sys.path: 32 | sys.path.append(str(ROOT)) # add ROOT to PATH 33 | 34 | import val # for end-of-epoch mAP 35 | from models.experimental import attempt_load 36 | from models.yolo import Model 37 | from utils.autoanchor import check_anchors 38 | from utils.datasets import create_dataloader 39 | from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ 40 | strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \ 41 | check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods 42 | from utils.downloads import attempt_download 43 | from utils.loss import ComputeLoss 44 | from utils.plots import plot_labels, plot_evolve 45 | from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \ 46 | torch_distributed_zero_first 47 | from utils.loggers.wandb.wandb_utils import check_wandb_resume 48 | from utils.metrics import fitness 49 | from utils.loggers import Loggers 50 | from utils.callbacks import Callbacks 51 | 52 | LOGGER = logging.getLogger(__name__) 53 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html 54 | RANK = int(os.getenv('RANK', -1)) 55 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) 56 | 57 | 58 | def train(hyp, # path/to/hyp.yaml or hyp dictionary 59 | opt, 60 | device, 61 | callbacks 62 | ): 63 | save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ 64 | Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ 65 | opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze 66 | 67 | # Directories 68 | w = save_dir / 'weights' # weights dir 69 | (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir 70 | last, best = w / 'last.pt', w / 'best.pt' 71 | 72 | # Hyperparameters 73 | if isinstance(hyp, str): 74 | with open(hyp) as f: 75 | hyp = yaml.safe_load(f) # load hyps dict 76 | LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) 77 | 78 | # Save run settings 79 | with open(save_dir / 'hyp.yaml', 'w') as f: 80 | yaml.safe_dump(hyp, f, sort_keys=False) 81 | with open(save_dir / 'opt.yaml', 'w') as f: 82 | yaml.safe_dump(vars(opt), f, sort_keys=False) 83 | data_dict = None 84 | 85 | # Loggers 86 | if RANK in [-1, 0]: 87 | loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance 88 | if loggers.wandb: 89 | 
data_dict = loggers.wandb.data_dict 90 | if resume: 91 | weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp 92 | 93 | # Register actions 94 | for k in methods(loggers): 95 | callbacks.register_action(k, callback=getattr(loggers, k)) 96 | 97 | # Config 98 | plots = not evolve # create plots 99 | cuda = device.type != 'cpu' 100 | init_seeds(1 + RANK) 101 | with torch_distributed_zero_first(RANK): 102 | data_dict = data_dict or check_dataset(data) # check if None 103 | train_path, val_path = data_dict['train'], data_dict['val'] 104 | nc = 1 if single_cls else int(data_dict['nc']) # number of classes 105 | names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names 106 | assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check 107 | is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset 108 | 109 | # Model 110 | check_suffix(weights, '.pt') # check weights 111 | pretrained = weights.endswith('.pt') 112 | if pretrained and not opt.ft_pruned: 113 | with torch_distributed_zero_first(RANK): 114 | weights = attempt_download(weights) # download if not found locally 115 | ckpt = torch.load(weights, map_location=device) # load checkpoint 116 | model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create 117 | exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys 118 | csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 119 | csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect 120 | model.load_state_dict(csd, strict=False) # load 121 | LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report 122 | elif pretrained and opt.ft_pruned: # Fine-tuning pruned model 123 | print("[FT] Restoring pruned model", weights) 124 | ckpt = torch.load(weights, map_location=device) # load checkpoint 125 | model = ckpt["model"].float() 126 | model.info() 127 | csd = {} 128 | else: 129 | model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create 130 | 131 | # Freeze 132 | freeze = [f'model.{x}.' 
for x in range(freeze)] # layers to freeze 133 | for k, v in model.named_parameters(): 134 | v.requires_grad = True # train all layers 135 | if any(x in k for x in freeze): 136 | print(f'freezing {k}') 137 | v.requires_grad = False 138 | 139 | # Optimizer 140 | nbs = 64 # nominal batch size 141 | accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing 142 | hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay 143 | LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}") 144 | 145 | # Sparse learning 146 | if opt.sl_factor > 0: 147 | print("[SL] Using sparse learning") 148 | hyp['sl'] = opt.sl_factor * batch_size * accumulate / nbs 149 | 150 | # ignore_idx = [230, 260, 290] 151 | prunable_module_type = (nn.BatchNorm2d) 152 | prunable_modules = [] 153 | for i, m in enumerate(model.modules()): 154 | # if i in ignore_idx: 155 | # continue 156 | if isinstance(m, prunable_module_type): 157 | prunable_modules.append(m) 158 | 159 | g0, g1, g2 = [], [], [] # optimizer parameter groups 160 | for v in model.modules(): 161 | if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias 162 | g2.append(v.bias) 163 | if isinstance(v, nn.BatchNorm2d): # weight (no decay) 164 | g0.append(v.weight) 165 | elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay) 166 | g1.append(v.weight) 167 | 168 | if opt.adam: 169 | optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum 170 | else: 171 | optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) 172 | 173 | optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay 174 | optimizer.add_param_group({'params': g2}) # add g2 (biases) 175 | LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups " 176 | f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias") 177 | del g0, g1, g2 178 | 179 | # Scheduler 180 | if opt.linear_lr: 181 | lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear 182 | else: 183 | lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf'] 184 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs) 185 | 186 | # EMA 187 | ema = ModelEMA(model) if RANK in [-1, 0] else None 188 | 189 | # Resume 190 | start_epoch, best_fitness = 0, 0.0 191 | if pretrained: 192 | # Optimizer 193 | if ckpt['optimizer'] is not None: 194 | optimizer.load_state_dict(ckpt['optimizer']) 195 | best_fitness = ckpt['best_fitness'] 196 | 197 | # EMA 198 | if ema and ckpt.get('ema'): 199 | ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) 200 | ema.updates = ckpt['updates'] 201 | 202 | # Epochs 203 | start_epoch = ckpt['epoch'] + 1 204 | if resume: 205 | assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.' 206 | if epochs < start_epoch: 207 | LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. 
Fine-tuning for {epochs} more epochs.") 208 | epochs += ckpt['epoch'] # finetune additional epochs 209 | 210 | del ckpt, csd 211 | 212 | # Image sizes 213 | gs = max(int(model.stride.max()), 32) # grid size (max stride) 214 | nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) 215 | imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple 216 | 217 | # DP mode 218 | if cuda and RANK == -1 and torch.cuda.device_count() > 1: 219 | logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n' 220 | 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.') 221 | model = torch.nn.DataParallel(model) 222 | 223 | # SyncBatchNorm 224 | if opt.sync_bn and cuda and RANK != -1: 225 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) 226 | LOGGER.info('Using SyncBatchNorm()') 227 | 228 | # Trainloader 229 | train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, 230 | hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK, 231 | workers=workers, image_weights=opt.image_weights, quad=opt.quad, 232 | prefix=colorstr('train: ')) 233 | mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class 234 | nb = len(train_loader) # number of batches 235 | assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' 236 | 237 | # Process 0 238 | if RANK in [-1, 0]: 239 | val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls, 240 | hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1, 241 | workers=workers, pad=0.5, 242 | prefix=colorstr('val: '))[0] 243 | 244 | if not resume: 245 | labels = np.concatenate(dataset.labels, 0) 246 | # c = torch.tensor(labels[:, 0]) # classes 247 | # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency 248 | # model._initialize_biases(cf.to(device)) 249 | if plots: 250 | plot_labels(labels, names, save_dir) 251 | 252 | # Anchors 253 | if not opt.noautoanchor: 254 | check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) 255 | model.half().float() # pre-reduce anchor precision 256 | 257 | callbacks.run('on_pretrain_routine_end') 258 | 259 | # DDP mode 260 | if cuda and RANK != -1: 261 | model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) 262 | 263 | # Model parameters 264 | hyp['box'] *= 3. / nl # scale to layers 265 | hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers 266 | hyp['obj'] *= (imgsz / 640) ** 2 * 3. 
/ nl # scale to image size and layers 267 | hyp['label_smoothing'] = opt.label_smoothing 268 | model.nc = nc # attach number of classes to model 269 | model.hyp = hyp # attach hyperparameters to model 270 | model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights 271 | model.names = names 272 | 273 | # Start training 274 | t0 = time.time() 275 | nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations) 276 | # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training 277 | last_opt_step = -1 278 | maps = np.zeros(nc) # mAP per class 279 | results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) 280 | scheduler.last_epoch = start_epoch - 1 # do not move 281 | scaler = amp.GradScaler(enabled=cuda) 282 | stopper = EarlyStopping(patience=opt.patience) 283 | compute_loss = ComputeLoss(model) # init loss class 284 | LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' 285 | f'Using {train_loader.num_workers} dataloader workers\n' 286 | f"Logging results to {colorstr('bold', save_dir)}\n" 287 | f'Starting training for {epochs} epochs...') 288 | for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ 289 | model.train() 290 | 291 | # Update image weights (optional, single-GPU only) 292 | if opt.image_weights: 293 | cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights 294 | iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights 295 | dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx 296 | 297 | # Update mosaic border (optional) 298 | # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs) 299 | # dataset.mosaic_border = [b - imgsz, -b] # height, width borders 300 | 301 | mloss = torch.zeros(3, device=device) # mean losses 302 | if opt.sl_factor > 0: sl_mloss = torch.zeros(1, device=device) 303 | if RANK != -1: 304 | train_loader.sampler.set_epoch(epoch) 305 | pbar = enumerate(train_loader) 306 | LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size')) 307 | if RANK in [-1, 0]: 308 | pbar = tqdm(pbar, total=nb) # progress bar 309 | optimizer.zero_grad() 310 | for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- 311 | ni = i + nb * epoch # number integrated batches (since train start) 312 | imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0 313 | 314 | # Warmup 315 | if ni <= nw: 316 | xi = [0, nw] # x interp 317 | # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) 318 | accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round()) 319 | for j, x in enumerate(optimizer.param_groups): 320 | # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 321 | x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) 322 | if 'momentum' in x: 323 | x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']]) 324 | 325 | # Multi-scale 326 | if opt.multi_scale: 327 | sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size 328 | sf = sz / max(imgs.shape[2:]) # scale factor 329 | if sf != 1: 330 | ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) 331 | imgs = 
nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) 332 | 333 | # Forward 334 | with amp.autocast(enabled=cuda): 335 | pred = model(imgs) # forward 336 | loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size 337 | if RANK != -1: 338 | loss *= WORLD_SIZE # gradient averaged between devices in DDP mode 339 | if opt.sl_factor > 0: 340 | sl_loss, sl_loss_items = compute_loss.sl_loss(pred, prunable_modules) 341 | loss += sl_loss 342 | if opt.quad: 343 | loss *= 4. 344 | 345 | # Backward 346 | scaler.scale(loss).backward() 347 | 348 | # Optimize 349 | if ni - last_opt_step >= accumulate: 350 | scaler.step(optimizer) # optimizer.step 351 | scaler.update() 352 | optimizer.zero_grad() 353 | if ema: 354 | ema.update(model) 355 | last_opt_step = ni 356 | 357 | # Log 358 | if RANK in [-1, 0]: 359 | mloss = (mloss * i + loss_items) / (i + 1) # update mean losses 360 | if opt.sl_factor > 0: sl_mloss = (sl_mloss * i + sl_loss_items) / (i + 1) 361 | mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) 362 | pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( 363 | f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) 364 | callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn) 365 | # end batch ------------------------------------------------------------------------------------------------ 366 | 367 | # Scheduler 368 | lr = [x['lr'] for x in optimizer.param_groups] # for loggers 369 | scheduler.step() 370 | 371 | if RANK in [-1, 0]: 372 | # mAP 373 | callbacks.run('on_train_epoch_end', epoch=epoch) 374 | ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) 375 | final_epoch = (epoch + 1 == epochs) or stopper.possible_stop 376 | if not noval or final_epoch: # Calculate mAP 377 | results, maps, _ = val.run(data_dict, 378 | batch_size=batch_size // WORLD_SIZE * 2, 379 | imgsz=imgsz, 380 | model=ema.ema, 381 | single_cls=single_cls, 382 | dataloader=val_loader, 383 | save_dir=save_dir, 384 | save_json=is_coco and final_epoch, 385 | verbose=nc < 50 and final_epoch, 386 | plots=plots and final_epoch, 387 | callbacks=callbacks, 388 | compute_loss=compute_loss) 389 | if opt.sl_factor > 0: 390 | LOGGER.info(('%10s') % ('sl')) 391 | LOGGER.info(('%10.4g') % (*sl_mloss,)) 392 | 393 | # Update best mAP 394 | fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] 395 | if fi > best_fitness: 396 | best_fitness = fi 397 | log_vals = list(mloss) + list(results) + lr 398 | callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi) 399 | 400 | # Save model 401 | if (not nosave) or (final_epoch and not evolve): # if save 402 | ckpt = {'epoch': epoch, 403 | 'best_fitness': best_fitness, 404 | 'model': deepcopy(de_parallel(model)).half(), 405 | 'ema': deepcopy(ema.ema).half(), 406 | 'updates': ema.updates, 407 | 'optimizer': optimizer.state_dict(), 408 | 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None} 409 | 410 | # Save last, best and delete 411 | torch.save(ckpt, last) 412 | if best_fitness == fi: 413 | torch.save(ckpt, best) 414 | del ckpt 415 | callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) 416 | 417 | # Stop Single-GPU 418 | if RANK == -1 and stopper(epoch=epoch, fitness=fi): 419 | break 420 | 421 | # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576 422 | # stop = stopper(epoch=epoch, fitness=fi) 423 | # if 
RANK == 0: 424 | # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks 425 | 426 | # Stop DPP 427 | # with torch_distributed_zero_first(RANK): 428 | # if stop: 429 | # break # must break all DDP ranks 430 | 431 | # end epoch ---------------------------------------------------------------------------------------------------- 432 | # end training ----------------------------------------------------------------------------------------------------- 433 | if RANK in [-1, 0]: 434 | LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') 435 | if not evolve: 436 | if is_coco: # COCO dataset 437 | for m in [last, best] if best.exists() else [last]: # speed, mAP tests 438 | results, _, _ = val.run(data_dict, 439 | batch_size=batch_size // WORLD_SIZE * 2, 440 | imgsz=imgsz, 441 | model=attempt_load(m, device).half(), 442 | iou_thres=0.7, # NMS IoU threshold for best pycocotools results 443 | single_cls=single_cls, 444 | dataloader=val_loader, 445 | save_dir=save_dir, 446 | save_json=True, 447 | plots=False) 448 | # Strip optimizers 449 | for f in last, best: 450 | if f.exists(): 451 | strip_optimizer(f) # strip optimizers 452 | callbacks.run('on_train_end', last, best, plots, epoch) 453 | LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") 454 | 455 | torch.cuda.empty_cache() 456 | return results 457 | 458 | 459 | def parse_opt(known=False): 460 | parser = argparse.ArgumentParser() 461 | parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path') 462 | parser.add_argument('--cfg', type=str, default='', help='model.yaml path') 463 | parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path') 464 | parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path') 465 | parser.add_argument('--epochs', type=int, default=300) 466 | parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') 467 | parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') 468 | parser.add_argument('--rect', action='store_true', help='rectangular training') 469 | parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') 470 | parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') 471 | parser.add_argument('--noval', action='store_true', help='only validate final epoch') 472 | parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') 473 | parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') 474 | parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') 475 | parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') 476 | parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') 477 | parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 478 | parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') 479 | parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') 480 | parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') 481 | parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') 482 | parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') 483 | parser.add_argument('--project', default='runs/train', help='save to project/name') 484 | parser.add_argument('--entity', default=None, help='W&B entity') 485 | parser.add_argument('--name', default='exp', help='save to project/name') 486 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 487 | parser.add_argument('--quad', action='store_true', help='quad dataloader') 488 | parser.add_argument('--linear-lr', action='store_true', help='linear LR') 489 | parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon') 490 | parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table') 491 | parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') 492 | parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') 493 | parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') 494 | parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') 495 | parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. 
backbone=10, all=24') 496 | parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)') 497 | 498 | parser.add_argument('--sl_factor', type=float, default=0, help='sparse learning factor') 499 | parser.add_argument('--ft_pruned', action='store_true', help='fine-tune pruned model') 500 | 501 | opt = parser.parse_known_args()[0] if known else parser.parse_args() 502 | return opt 503 | 504 | 505 | def main(opt, callbacks=Callbacks()): 506 | # Checks 507 | set_logging(RANK) 508 | if RANK in [-1, 0]: 509 | print_args(FILE.stem, opt) 510 | check_git_status() 511 | check_requirements(exclude=['thop']) 512 | 513 | # Resume 514 | if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run 515 | ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path 516 | assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' 517 | with open(Path(ckpt).parent.parent / 'opt.yaml') as f: 518 | opt = argparse.Namespace(**yaml.safe_load(f)) # replace 519 | opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate 520 | LOGGER.info(f'Resuming training from {ckpt}') 521 | else: 522 | opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs 523 | assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' 524 | if opt.evolve: 525 | opt.project = 'runs/evolve' 526 | opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume 527 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 528 | 529 | # DDP mode 530 | device = select_device(opt.device, batch_size=opt.batch_size) 531 | if LOCAL_RANK != -1: 532 | from datetime import timedelta 533 | assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' 534 | assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count' 535 | assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' 536 | assert not opt.evolve, '--evolve argument is not compatible with DDP training' 537 | torch.cuda.set_device(LOCAL_RANK) 538 | device = torch.device('cuda', LOCAL_RANK) 539 | dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") 540 | 541 | # Train 542 | if not opt.evolve: 543 | train(opt.hyp, opt, device, callbacks) 544 | if WORLD_SIZE > 1 and RANK == 0: 545 | LOGGER.info('Destroying process group... 
') 546 | dist.destroy_process_group() 547 | 548 | # Evolve hyperparameters (optional) 549 | else: 550 | # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit) 551 | meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3) 552 | 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) 553 | 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1 554 | 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay 555 | 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok) 556 | 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum 557 | 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr 558 | 'box': (1, 0.02, 0.2), # box loss gain 559 | 'cls': (1, 0.2, 4.0), # cls loss gain 560 | 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight 561 | 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels) 562 | 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight 563 | 'iou_t': (0, 0.1, 0.7), # IoU training threshold 564 | 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold 565 | 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore) 566 | 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) 567 | 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction) 568 | 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction) 569 | 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction) 570 | 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg) 571 | 'translate': (1, 0.0, 0.9), # image translation (+/- fraction) 572 | 'scale': (1, 0.0, 0.9), # image scale (+/- gain) 573 | 'shear': (1, 0.0, 10.0), # image shear (+/- deg) 574 | 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 575 | 'flipud': (1, 0.0, 1.0), # image flip up-down (probability) 576 | 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) 577 | 'mosaic': (1, 0.0, 1.0), # image mosaic (probability) 578 | 'mixup': (1, 0.0, 1.0), # image mixup (probability) 579 | 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability) 580 | 581 | with open(opt.hyp) as f: 582 | hyp = yaml.safe_load(f) # load hyps dict 583 | if 'anchors' not in hyp: # anchors commented in hyp.yaml 584 | hyp['anchors'] = 3 585 | opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch 586 | # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices 587 | evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv' 588 | if opt.bucket: 589 | os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists 590 | 591 | for _ in range(opt.evolve): # generations to evolve 592 | if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate 593 | # Select parent(s) 594 | parent = 'single' # parent selection method: 'single' or 'weighted' 595 | x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1) 596 | n = min(5, len(x)) # number of previous results to consider 597 | x = x[np.argsort(-fitness(x))][:n] # top n mutations 598 | w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0) 599 | if parent == 'single' or len(x) == 1: 600 | # x = x[random.randint(0, n - 1)] # random selection 601 | x = x[random.choices(range(n), weights=w)[0]] # weighted selection 602 | elif parent == 'weighted': 603 | x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination 604 | 605 | # Mutate 606 | mp, s = 0.8, 0.2 # mutation probability, sigma 607 | npr = np.random 608 | 
npr.seed(int(time.time())) 609 | g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1 610 | ng = len(meta) 611 | v = np.ones(ng) 612 | while all(v == 1): # mutate until a change occurs (prevent duplicates) 613 | v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0) 614 | for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300) 615 | hyp[k] = float(x[i + 7] * v[i]) # mutate 616 | 617 | # Constrain to limits 618 | for k, v in meta.items(): 619 | hyp[k] = max(hyp[k], v[1]) # lower limit 620 | hyp[k] = min(hyp[k], v[2]) # upper limit 621 | hyp[k] = round(hyp[k], 5) # significant digits 622 | 623 | # Train mutation 624 | results = train(hyp.copy(), opt, device, callbacks) 625 | 626 | # Write mutation results 627 | print_mutation(results, hyp.copy(), save_dir, opt.bucket) 628 | 629 | # Plot results 630 | plot_evolve(evolve_csv) 631 | print(f'Hyperparameter evolution finished\n' 632 | f"Results saved to {colorstr('bold', save_dir)}\n" 633 | f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}') 634 | 635 | 636 | def run(**kwargs): 637 | # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt') 638 | opt = parse_opt(True) 639 | for k, v in kwargs.items(): 640 | setattr(opt, k, v) 641 | main(opt) 642 | 643 | 644 | if __name__ == "__main__": 645 | opt = parse_opt() 646 | main(opt) 647 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Loss functions 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils.metrics import bbox_iou 10 | from utils.torch_utils import is_parallel 11 | 12 | 13 | def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 14 | # return positive, negative label smoothing BCE targets 15 | return 1.0 - 0.5 * eps, 0.5 * eps 16 | 17 | 18 | class BCEBlurWithLogitsLoss(nn.Module): 19 | # BCEwithLogitLoss() with reduced missing label effects. 20 | def __init__(self, alpha=0.05): 21 | super(BCEBlurWithLogitsLoss, self).__init__() 22 | self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss() 23 | self.alpha = alpha 24 | 25 | def forward(self, pred, true): 26 | loss = self.loss_fcn(pred, true) 27 | pred = torch.sigmoid(pred) # prob from logits 28 | dx = pred - true # reduce only missing label effects 29 | # dx = (pred - true).abs() # reduce missing label and false label effects 30 | alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4)) 31 | loss *= alpha_factor 32 | return loss.mean() 33 | 34 | 35 | class FocalLoss(nn.Module): 36 | # Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) 37 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): 38 | super(FocalLoss, self).__init__() 39 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() 40 | self.gamma = gamma 41 | self.alpha = alpha 42 | self.reduction = loss_fcn.reduction 43 | self.loss_fcn.reduction = 'none' # required to apply FL to each element 44 | 45 | def forward(self, pred, true): 46 | loss = self.loss_fcn(pred, true) 47 | # p_t = torch.exp(-loss) 48 | # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability 49 | 50 | # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py 51 | pred_prob = torch.sigmoid(pred) # prob from logits 52 | p_t = true * pred_prob + (1 - true) * (1 - pred_prob) 53 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) 54 | modulating_factor = (1.0 - p_t) ** self.gamma 55 | loss *= alpha_factor * modulating_factor 56 | 57 | if self.reduction == 'mean': 58 | return loss.mean() 59 | elif self.reduction == 'sum': 60 | return loss.sum() 61 | else: # 'none' 62 | return loss 63 | 64 | 65 | class QFocalLoss(nn.Module): 66 | # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) 67 | def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): 68 | super(QFocalLoss, self).__init__() 69 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() 70 | self.gamma = gamma 71 | self.alpha = alpha 72 | self.reduction = loss_fcn.reduction 73 | self.loss_fcn.reduction = 'none' # required to apply FL to each element 74 | 75 | def forward(self, pred, true): 76 | loss = self.loss_fcn(pred, true) 77 | 78 | pred_prob = torch.sigmoid(pred) # prob from logits 79 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) 80 | modulating_factor = torch.abs(true - pred_prob) ** self.gamma 81 | loss *= alpha_factor * modulating_factor 82 | 83 | if self.reduction == 'mean': 84 | return loss.mean() 85 | elif self.reduction == 'sum': 86 | return loss.sum() 87 | else: # 'none' 88 | return loss 89 | 90 | 91 | class ComputeLoss: 92 | # Compute losses 93 | def __init__(self, model, autobalance=False): 94 | self.sort_obj_iou = False 95 | device = next(model.parameters()).device # get model device 96 | h = model.hyp # hyperparameters 97 | 98 | # Define criteria 99 | BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) 100 | BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) 101 | 102 | # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 103 | self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets 104 | 105 | # Focal loss 106 | g = h['fl_gamma'] # focal loss gamma 107 | if g > 0: 108 | BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) 109 | 110 | det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module 111 | self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7 112 | self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index 113 | self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance 114 | for k in 'na', 'nc', 'nl', 'anchors': 115 | setattr(self, k, getattr(det, k)) 116 | 117 | def __call__(self, p, targets): # predictions, targets, model 118 | device = targets.device 119 | lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, 
device=device), torch.zeros(1, device=device) 120 | tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets 121 | 122 | # Losses 123 | for i, pi in enumerate(p): # layer index, layer predictions 124 | b, a, gj, gi = indices[i] # image, anchor, gridy, gridx 125 | tobj = torch.zeros_like(pi[..., 0], device=device) # target obj 126 | 127 | n = b.shape[0] # number of targets 128 | if n: 129 | ps = pi[b, a, gj, gi] # prediction subset corresponding to targets 130 | 131 | # Regression 132 | pxy = ps[:, :2].sigmoid() * 2. - 0.5 133 | pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] 134 | pbox = torch.cat((pxy, pwh), 1) # predicted box 135 | iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target) 136 | lbox += (1.0 - iou).mean() # iou loss 137 | 138 | # Objectness 139 | score_iou = iou.detach().clamp(0).type(tobj.dtype) 140 | if self.sort_obj_iou: 141 | sort_id = torch.argsort(score_iou) 142 | b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id] 143 | tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou # iou ratio 144 | 145 | # Classification 146 | if self.nc > 1: # cls loss (only if multiple classes) 147 | t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets 148 | t[range(n), tcls[i]] = self.cp 149 | lcls += self.BCEcls(ps[:, 5:], t) # BCE 150 | 151 | # Append targets to text file 152 | # with open('targets.txt', 'a') as file: 153 | # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] 154 | 155 | obji = self.BCEobj(pi[..., 4], tobj) 156 | lobj += obji * self.balance[i] # obj loss 157 | if self.autobalance: 158 | self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item() 159 | 160 | if self.autobalance: 161 | self.balance = [x / self.balance[self.ssi] for x in self.balance] 162 | lbox *= self.hyp['box'] 163 | lobj *= self.hyp['obj'] 164 | lcls *= self.hyp['cls'] 165 | bs = tobj.shape[0] # batch size 166 | 167 | return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach() 168 | 169 | def sl_loss(self, p, prunable_modules): 170 | # Sparse Learning 171 | device = p[0].device 172 | ll1 = torch.zeros(1, device=device) 173 | 174 | if prunable_modules is not None: 175 | for m in prunable_modules: 176 | ll1 += m.weight.norm(1) 177 | ll1 /= len(prunable_modules) 178 | 179 | ll1 *= self.hyp['sl'] 180 | bs = p[0].shape[0] # batch size 181 | 182 | loss = ll1 183 | return loss * bs, ll1.detach() 184 | 185 | def build_targets(self, p, targets): 186 | # Build targets for compute_loss(), input targets(image,class,x,y,w,h) 187 | na, nt = self.na, targets.shape[0] # number of anchors, targets 188 | tcls, tbox, indices, anch = [], [], [], [] 189 | gain = torch.ones(7, device=targets.device) # normalized to gridspace gain 190 | ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) 191 | targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices 192 | 193 | g = 0.5 # bias 194 | off = torch.tensor([[0, 0], 195 | [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m 196 | # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm 197 | ], device=targets.device).float() * g # offsets 198 | 199 | for i in range(self.nl): 200 | anchors = self.anchors[i] 201 | gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain 202 | 203 | # Match targets to anchors 204 | t = targets * gain 205 | if nt: 206 | # Matches 207 | r = t[:, :, 4:6] / anchors[:, None] # wh ratio 208 | j = 
torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare 209 | # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) 210 | t = t[j] # filter 211 | 212 | # Offsets 213 | gxy = t[:, 2:4] # grid xy 214 | gxi = gain[[2, 3]] - gxy # inverse 215 | j, k = ((gxy % 1. < g) & (gxy > 1.)).T 216 | l, m = ((gxi % 1. < g) & (gxi > 1.)).T 217 | j = torch.stack((torch.ones_like(j), j, k, l, m)) 218 | t = t.repeat((5, 1, 1))[j] 219 | offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] 220 | else: 221 | t = targets[0] 222 | offsets = 0 223 | 224 | # Define 225 | b, c = t[:, :2].long().T # image, class 226 | gxy = t[:, 2:4] # grid xy 227 | gwh = t[:, 4:6] # grid wh 228 | gij = (gxy - offsets).long() 229 | gi, gj = gij.T # grid xy indices 230 | 231 | # Append 232 | a = t[:, 6].long() # anchor indices 233 | indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices 234 | tbox.append(torch.cat((gxy - gij, gwh), 1)) # box 235 | anch.append(anchors[a]) # anchors 236 | tcls.append(c) # class 237 | 238 | return tcls, tbox, indices, anch 239 | --------------------------------------------------------------------------------
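
Note (not repo code): `sl_loss` above implements the network-slimming idea — an L1 penalty on the BatchNorm scale factors (gamma) pushes many of them towards zero during sparse learning, so `channel_prune` in `pruning.py` can later cut the corresponding channels at little cost. Below is a self-contained toy sketch of that effect; the model, penalty weight, and step count are illustrative stand-ins, with the penalty weight playing the role of `hyp['sl']`:

```
import torch
import torch.nn as nn

# toy stand-in for yolov5; only the BatchNorm gammas matter here
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
bn = model[1]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 3, 16, 16)

for _ in range(400):
    optimizer.zero_grad()
    task_loss = model(x).mean() ** 2  # stand-in for the detection loss
    sl = 0.05 * bn.weight.norm(1)     # L1 on gamma, cf. hyp['sl'] in sl_loss()
    (task_loss + sl).backward()
    optimizer.step()

print(bn.weight.data)  # gammas have shrunk from their initial 1.0 towards zero
```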