├── utils ├── __init__.py ├── aws │ ├── __init__.py │ ├── mime.sh │ ├── resume.py │ └── userdata.sh ├── wandb_logging │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── wandb_utils.cpython-37.pyc │ │ └── wandb_utils.cpython-38.pyc │ ├── log_dataset.py │ └── wandb_utils.py ├── google_app_engine │ ├── additional_requirements.txt │ ├── app.yaml │ └── Dockerfile ├── activations.py ├── google_utils.py ├── add_nms.py ├── autoanchor.py ├── metrics.py ├── torch_utils.py └── plots.py ├── models ├── __init__.py ├── experimental.py └── SwinTransformer.py ├── heatmaps ├── BloodImage_00261 │ ├── RBCs.png │ ├── WBCs.png │ ├── Platelets.png │ └── BloodImage_00261.jpg └── BloodImage_00340 │ ├── RBCs.png │ ├── WBCs.png │ └── BloodImage_00340.jpg ├── data ├── bccd.yaml ├── bcd.yaml ├── cbc.yaml ├── hyp.scratch.p5.yaml └── hyp.scratch.custom.yaml ├── requirements.txt ├── hubconf.py ├── cfg ├── ablation │ ├── w-mp.yaml │ ├── wo-welan.yaml │ ├── wo-cst.yaml │ └── wo-mcs.yaml ├── training │ └── cst-yolo.yaml └── baseline │ └── yolov7.yaml ├── export.py ├── detect.py ├── README.md └── test.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /utils/aws/__init__.py: -------------------------------------------------------------------------------- 1 | #init -------------------------------------------------------------------------------- /utils/wandb_logging/__init__.py: -------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /heatmaps/BloodImage_00261/RBCs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00261/RBCs.png -------------------------------------------------------------------------------- /heatmaps/BloodImage_00261/WBCs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00261/WBCs.png -------------------------------------------------------------------------------- /heatmaps/BloodImage_00340/RBCs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00340/RBCs.png -------------------------------------------------------------------------------- /heatmaps/BloodImage_00340/WBCs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00340/WBCs.png -------------------------------------------------------------------------------- /heatmaps/BloodImage_00261/Platelets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00261/Platelets.png -------------------------------------------------------------------------------- /data/bccd.yaml: -------------------------------------------------------------------------------- 1 | train: ./datasets/bccd/traindata 2 | val: 
./datasets/bccd/validata 3 | 4 | nc: 3 5 | names: ['WBC', 'RBC', 'Platelets'] 6 | -------------------------------------------------------------------------------- /data/bcd.yaml: -------------------------------------------------------------------------------- 1 | train: ./datasets/bcd/train/images 2 | val: ./datasets/bcd/valid/images 3 | 4 | nc: 3 5 | names: ['WBC', 'RBC', 'Platelets'] 6 | -------------------------------------------------------------------------------- /heatmaps/BloodImage_00261/BloodImage_00261.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00261/BloodImage_00261.jpg -------------------------------------------------------------------------------- /heatmaps/BloodImage_00340/BloodImage_00340.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/heatmaps/BloodImage_00340/BloodImage_00340.jpg -------------------------------------------------------------------------------- /data/cbc.yaml: -------------------------------------------------------------------------------- 1 | train: ./datasets/cbc/Training/Images 2 | val: ./datasets/cbc/Validation/Images 3 | 4 | nc: 3 5 | names: ['WBC', 'RBC', 'Platelets'] 6 | -------------------------------------------------------------------------------- /utils/wandb_logging/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/utils/wandb_logging/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/wandb_logging/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/utils/wandb_logging/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/google_app_engine/additional_requirements.txt: -------------------------------------------------------------------------------- 1 | # add these requirements in your app on top of the existing ones 2 | pip==18.1 3 | Flask==1.0.2 4 | gunicorn==19.9.0 5 | -------------------------------------------------------------------------------- /utils/wandb_logging/__pycache__/wandb_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/utils/wandb_logging/__pycache__/wandb_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/wandb_logging/__pycache__/wandb_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkang315/CST-YOLO/HEAD/utils/wandb_logging/__pycache__/wandb_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/google_app_engine/app.yaml: -------------------------------------------------------------------------------- 1 | runtime: custom 2 | env: flex 3 | 4 | service: yolorapp 5 | 6 | liveness_check: 7 | initial_delay_sec: 600 8 | 9 | manual_scaling: 10 | instances: 1 11 | resources: 12 | cpu: 1 13 | memory_gb: 4 14 | disk_size_gb: 20 -------------------------------------------------------------------------------- /utils/aws/mime.sh: 
-------------------------------------------------------------------------------- 1 | # AWS EC2 instance startup 'MIME' script https://aws.amazon.com/premiumsupport/knowledge-center/execute-user-data-ec2/ 2 | # This script will run on every instance restart, not only on first start 3 | # --- DO NOT COPY ABOVE COMMENTS WHEN PASTING INTO USERDATA --- 4 | 5 | Content-Type: multipart/mixed; boundary="//" 6 | MIME-Version: 1.0 7 | 8 | --// 9 | Content-Type: text/cloud-config; charset="us-ascii" 10 | MIME-Version: 1.0 11 | Content-Transfer-Encoding: 7bit 12 | Content-Disposition: attachment; filename="cloud-config.txt" 13 | 14 | #cloud-config 15 | cloud_final_modules: 16 | - [scripts-user, always] 17 | 18 | --// 19 | Content-Type: text/x-shellscript; charset="us-ascii" 20 | MIME-Version: 1.0 21 | Content-Transfer-Encoding: 7bit 22 | Content-Disposition: attachment; filename="userdata.txt" 23 | 24 | #!/bin/bash 25 | # --- paste contents of userdata.sh here --- 26 | --// 27 | -------------------------------------------------------------------------------- /utils/wandb_logging/log_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from wandb_utils import WandbLogger 6 | 7 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 8 | 9 | 10 | def create_dataset_artifact(opt): 11 | with open(opt.data) as f: 12 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 13 | logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation') 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 19 | parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') 20 | parser.add_argument('--project', type=str, default='YOLOR', help='name of W&B Project') 21 | opt = parser.parse_args() 22 | opt.resume = False # Explicitly disallow resume check for dataset upload job 23 | 24 | create_dataset_artifact(opt) 25 | -------------------------------------------------------------------------------- /utils/google_app_engine/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/google-appengine/python 2 | 3 | # Create a virtualenv for dependencies. This isolates these packages from 4 | # system-level packages. 5 | # Use -p python3 or -p python3.7 to select python version. Default is version 2. 6 | RUN virtualenv /env -p python3 7 | 8 | # Setting these environment variables are the same as running 9 | # source /env/bin/activate. 10 | ENV VIRTUAL_ENV /env 11 | ENV PATH /env/bin:$PATH 12 | 13 | RUN apt-get update && apt-get install -y python-opencv 14 | 15 | # Copy the application's requirements.txt and run pip to install all 16 | # dependencies into the virtualenv. 17 | ADD requirements.txt /app/requirements.txt 18 | RUN pip install -r /app/requirements.txt 19 | 20 | # Add the application source code. 21 | ADD . /app 22 | 23 | # Run a WSGI server to serve the application. gunicorn must be declared as 24 | # a dependency in requirements.txt. 
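# App Engine flexible injects the PORT environment variable (8080) that gunicorn binds to below.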
25 | CMD gunicorn -b :$PORT main:app 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Recommend: Python <= 3.8, Torch <= 1.7.1, CUDA <= 11.1 2 | # Usage: pip install -r requirements.txt 3 | 4 | # Base ---------------------------------------- 5 | matplotlib>=3.2.2 6 | numpy>=1.18.5,<1.24.0 7 | opencv-python>=4.1.1 8 | Pillow>=7.1.2 9 | PyYAML>=5.3.1 10 | requests>=2.23.0 11 | scipy>=1.4.1 12 | torch>=1.7.0,!=1.12.0 13 | torchvision>=0.8.1,!=0.13.0 14 | tqdm>=4.41.0 15 | protobuf<4.21.3 16 | 17 | # Logging ------------------------------------- 18 | tensorboard>=2.4.1 19 | # wandb 20 | 21 | # Plotting ------------------------------------ 22 | pandas>=1.1.4 23 | seaborn>=0.11.0 24 | 25 | # Export -------------------------------------- 26 | # coremltools>=4.1 # CoreML export 27 | # onnx>=1.9.0 # ONNX export 28 | # onnx-simplifier>=0.3.6 # ONNX simplifier 29 | # scikit-learn==0.19.2 # CoreML quantization 30 | # tensorflow>=2.4.1 # TFLite export 31 | # tensorflowjs>=3.9.0 # TF.js export 32 | # openvino-dev # OpenVINO export 33 | 34 | # Extras -------------------------------------- 35 | ipython # interactive notebook 36 | psutil # system utilization 37 | thop # FLOPs computation 38 | # albumentations>=1.0.3 39 | # pycocotools>=2.0 # COCO mAP 40 | # roboflow 41 | -------------------------------------------------------------------------------- /utils/aws/resume.py: -------------------------------------------------------------------------------- 1 | # Resume all interrupted trainings in yolor/ dir including DDP trainings 2 | # Usage: $ python utils/aws/resume.py 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import torch 9 | import yaml 10 | 11 | sys.path.append('./') # to run '$ python *.py' files in subdirectories 12 | 13 | port = 0 # --master_port 14 | path = Path('').resolve() 15 | for last in path.rglob('*/**/last.pt'): 16 | ckpt = torch.load(last) 17 | if ckpt['optimizer'] is None: 18 | continue 19 | 20 | # Load opt.yaml 21 | with open(last.parent.parent / 'opt.yaml') as f: 22 | opt = yaml.load(f, Loader=yaml.SafeLoader) 23 | 24 | # Get device count 25 | d = opt['device'].split(',') # devices 26 | nd = len(d) # number of devices 27 | ddp = nd > 1 or (nd == 0 and torch.cuda.device_count() > 1) # distributed data parallel 28 | 29 | if ddp: # multi-GPU 30 | port += 1 31 | cmd = f'python -m torch.distributed.launch --nproc_per_node {nd} --master_port {port} train.py --resume {last}' 32 | else: # single-GPU 33 | cmd = f'python train.py --resume {last}' 34 | 35 | cmd += ' > /dev/null 2>&1 &' # redirect output to dev/null and run in daemon thread 36 | print(cmd) 37 | os.system(cmd) 38 | -------------------------------------------------------------------------------- /utils/aws/userdata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # AWS EC2 instance startup script https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html 3 | # This script will run only once on first instance start (for a re-start script see mime.sh) 4 | # /home/ubuntu (ubuntu) or /home/ec2-user (amazon-linux) is working dir 5 | # Use >300 GB SSD 6 | 7 | cd home/ubuntu 8 | if [ ! -d yolor ]; then 9 | echo "Running first-time script." 
# install dependencies, download COCO, pull Docker 10 | git clone -b main https://github.com/WongKinYiu/yolov7 && sudo chmod -R 777 yolov7 11 | cd yolov7 12 | bash data/scripts/get_coco.sh && echo "Data done." & 13 | sudo docker pull nvcr.io/nvidia/pytorch:21.08-py3 && echo "Docker done." & 14 | python -m pip install --upgrade pip && pip install -r requirements.txt && python detect.py && echo "Requirements done." & 15 | wait && echo "All tasks done." # finish background tasks 16 | else 17 | echo "Running re-start script." # resume interrupted runs 18 | i=0 19 | list=$(sudo docker ps -qa) # container list i.e. $'one\ntwo\nthree\nfour' 20 | while IFS= read -r id; do 21 | ((i++)) 22 | echo "restarting container $i: $id" 23 | sudo docker start $id 24 | # sudo docker exec -it $id python train.py --resume # single-GPU 25 | sudo docker exec -d $id python utils/aws/resume.py # multi-scenario 26 | done <<<"$list" 27 | fi 28 | -------------------------------------------------------------------------------- /data/hyp.scratch.p5.yaml: -------------------------------------------------------------------------------- 1 | lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) 2 | lrf: 0.1 # final OneCycleLR learning rate (lr0 * lrf) 3 | momentum: 0.937 # SGD momentum/Adam beta1 4 | weight_decay: 0.0005 # optimizer weight decay 5e-4 5 | warmup_epochs: 3.0 # warmup epochs (fractions ok) 6 | warmup_momentum: 0.8 # warmup initial momentum 7 | warmup_bias_lr: 0.1 # warmup initial bias lr 8 | box: 0.05 # box loss gain 9 | cls: 0.3 # cls loss gain 10 | cls_pw: 1.0 # cls BCELoss positive_weight 11 | obj: 0.7 # obj loss gain (scale with pixels) 12 | obj_pw: 1.0 # obj BCELoss positive_weight 13 | iou_t: 0.20 # IoU training threshold 14 | anchor_t: 4.0 # anchor-multiple threshold 15 | # anchors: 3 # anchors per output layer (0 to ignore) 16 | fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) 17 | hsv_h: 0.015 # image HSV-Hue augmentation (fraction) 18 | hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) 19 | hsv_v: 0.4 # image HSV-Value augmentation (fraction) 20 | degrees: 0.0 # image rotation (+/- deg) 21 | translate: 0.2 # image translation (+/- fraction) 22 | scale: 0.9 # image scale (+/- gain) 23 | shear: 0.0 # image shear (+/- deg) 24 | perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 25 | flipud: 0.0 # image flip up-down (probability) 26 | fliplr: 0.5 # image flip left-right (probability) 27 | mosaic: 1.0 # image mosaic (probability) 28 | mixup: 0.15 # image mixup (probability) 29 | copy_paste: 0.0 # image copy paste (probability) 30 | paste_in: 0.15 # image copy paste (probability), use 0 for faster training 31 | loss_ota: 1 # use ComputeLossOTA, use 0 for faster training -------------------------------------------------------------------------------- /data/hyp.scratch.custom.yaml: -------------------------------------------------------------------------------- 1 | lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) 2 | lrf: 0.1 # final OneCycleLR learning rate (lr0 * lrf) 3 | momentum: 0.937 # SGD momentum/Adam beta1 4 | weight_decay: 0.0005 # optimizer weight decay 5e-4 5 | warmup_epochs: 3.0 # warmup epochs (fractions ok) 6 | warmup_momentum: 0.8 # warmup initial momentum 7 | warmup_bias_lr: 0.1 # warmup initial bias lr 8 | box: 0.05 # box loss gain 9 | cls: 0.3 # cls loss gain 10 | cls_pw: 1.0 # cls BCELoss positive_weight 11 | obj: 0.7 # obj loss gain (scale with pixels) 12 | obj_pw: 1.0 # obj BCELoss positive_weight 13 | iou_t: 0.20 # IoU training threshold 
14 | anchor_t: 4.0 # anchor-multiple threshold 15 | # anchors: 3 # anchors per output layer (0 to ignore) 16 | fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) 17 | hsv_h: 0.015 # image HSV-Hue augmentation (fraction) 18 | hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) 19 | hsv_v: 0.4 # image HSV-Value augmentation (fraction) 20 | degrees: 0.0 # image rotation (+/- deg) 21 | translate: 0.2 # image translation (+/- fraction) 22 | scale: 0.5 # image scale (+/- gain) 23 | shear: 0.0 # image shear (+/- deg) 24 | perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 25 | flipud: 0.0 # image flip up-down (probability) 26 | fliplr: 0.5 # image flip left-right (probability) 27 | mosaic: 1.0 # image mosaic (probability) 28 | mixup: 0.0 # image mixup (probability) 29 | copy_paste: 0.0 # image copy paste (probability) 30 | paste_in: 0.0 # image copy paste (probability), use 0 for faster training 31 | loss_ota: 1 # use ComputeLossOTA, use 0 for faster training -------------------------------------------------------------------------------- /utils/activations.py: -------------------------------------------------------------------------------- 1 | # Activation functions 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- 9 | class SiLU(nn.Module): # export-friendly version of nn.SiLU() 10 | @staticmethod 11 | def forward(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | class Hardswish(nn.Module): # export-friendly version of nn.Hardswish() 16 | @staticmethod 17 | def forward(x): 18 | # return x * F.hardsigmoid(x) # for torchscript and CoreML 19 | return x * F.hardtanh(x + 3, 0., 6.) / 6. 
# for torchscript, CoreML and ONNX 20 | 21 | 22 | class MemoryEfficientSwish(nn.Module): 23 | class F(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, x): 26 | ctx.save_for_backward(x) 27 | return x * torch.sigmoid(x) 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | x = ctx.saved_tensors[0] 32 | sx = torch.sigmoid(x) 33 | return grad_output * (sx * (1 + x * (1 - sx))) 34 | 35 | def forward(self, x): 36 | return self.F.apply(x) 37 | 38 | 39 | # Mish https://github.com/digantamisra98/Mish -------------------------------------------------------------------------- 40 | class Mish(nn.Module): 41 | @staticmethod 42 | def forward(x): 43 | return x * F.softplus(x).tanh() 44 | 45 | 46 | class MemoryEfficientMish(nn.Module): 47 | class F(torch.autograd.Function): 48 | @staticmethod 49 | def forward(ctx, x): 50 | ctx.save_for_backward(x) 51 | return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | x = ctx.saved_tensors[0] 56 | sx = torch.sigmoid(x) 57 | fx = F.softplus(x).tanh() 58 | return grad_output * (fx + x * sx * (1 - fx * fx)) 59 | 60 | def forward(self, x): 61 | return self.F.apply(x) 62 | 63 | 64 | # FReLU https://arxiv.org/abs/2007.11824 ------------------------------------------------------------------------------- 65 | class FReLU(nn.Module): 66 | def __init__(self, c1, k=3): # ch_in, kernel 67 | super().__init__() 68 | self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False) 69 | self.bn = nn.BatchNorm2d(c1) 70 | 71 | def forward(self, x): 72 | return torch.max(x, self.bn(self.conv(x))) 73 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | """PyTorch Hub models 2 | 3 | Usage: 4 | import torch 5 | model = torch.hub.load('repo', 'model') 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import torch 11 | 12 | from models.yolo import Model 13 | from utils.general import check_requirements, set_logging 14 | from utils.google_utils import attempt_download 15 | from utils.torch_utils import select_device 16 | 17 | dependencies = ['torch', 'yaml'] 18 | check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop')) 19 | set_logging() 20 | 21 | 22 | def create(name, pretrained, channels, classes, autoshape): 23 | """Creates a specified model 24 | 25 | Arguments: 26 | name (str): name of model, i.e. 
'yolov7' 27 | pretrained (bool): load pretrained weights into the model 28 | channels (int): number of input channels 29 | classes (int): number of model classes 30 | 31 | Returns: 32 | pytorch model 33 | """ 34 | try: 35 | cfg = list((Path(__file__).parent / 'cfg').rglob(f'{name}.yaml'))[0] # model.yaml path 36 | model = Model(cfg, channels, classes) 37 | if pretrained: 38 | fname = f'{name}.pt' # checkpoint filename 39 | attempt_download(fname) # download if not found locally 40 | ckpt = torch.load(fname, map_location=torch.device('cpu')) # load 41 | msd = model.state_dict() # model state_dict 42 | csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 43 | csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter 44 | model.load_state_dict(csd, strict=False) # load 45 | if len(ckpt['model'].names) == classes: 46 | model.names = ckpt['model'].names # set class names attribute 47 | if autoshape: 48 | model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS 49 | device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available 50 | return model.to(device) 51 | 52 | except Exception as e: 53 | s = 'Cache maybe be out of date, try force_reload=True.' 54 | raise Exception(s) from e 55 | 56 | 57 | def custom(path_or_model='path/to/model.pt', autoshape=True): 58 | """custom mode 59 | 60 | Arguments (3 options): 61 | path_or_model (str): 'path/to/model.pt' 62 | path_or_model (dict): torch.load('path/to/model.pt') 63 | path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] 64 | 65 | Returns: 66 | pytorch model 67 | """ 68 | model = torch.load(path_or_model, map_location=torch.device('cpu')) if isinstance(path_or_model, str) else path_or_model # load checkpoint 69 | if isinstance(model, dict): 70 | model = model['ema' if model.get('ema') else 'model'] # load model 71 | 72 | hub_model = Model(model.yaml).to(next(model.parameters()).device) # create 73 | hub_model.load_state_dict(model.float().state_dict()) # load state_dict 74 | hub_model.names = model.names # class names 75 | if autoshape: 76 | hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS 77 | device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available 78 | return hub_model.to(device) 79 | 80 | 81 | def yolov7(pretrained=True, channels=3, classes=80, autoshape=True): 82 | return create('yolov7', pretrained, channels, classes, autoshape) 83 | 84 | 85 | if __name__ == '__main__': 86 | model = custom(path_or_model='yolov7.pt') # custom example 87 | # model = create(name='yolov7', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example 88 | 89 | # Verify inference 90 | import numpy as np 91 | from PIL import Image 92 | 93 | imgs = [np.zeros((640, 480, 3))] 94 | 95 | results = model(imgs) # batched inference 96 | results.print() 97 | results.save() 98 | -------------------------------------------------------------------------------- /cfg/ablation/w-mp.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 3 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [12,16, 19,36, 40,28] # P3/8 9 | - [36,75, 76,55, 72,146] # P4/16 10 | - [142,110, 192,243, 459,401] # P5/32 11 | 12 | # cst-yolo backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, Conv, [32, 3, 1]], # 0 16 | 17 | [-1, 1, Conv, [64, 3, 2]], # 
1-P1/2 18 | [-1, 1, Conv, [64, 3, 1]], # CBS 19 | 20 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 21 | 22 | [-1, 1, Conv, [64, 1, 1]], # W-ELAN 1 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [-1, 1, Conv, [64, 3, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 29 | [-1, 1, Conv, [256, 1, 1]], # 11 30 | 31 | [-1, 1, MP, []], 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-3, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 2]], 35 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 36 | 37 | 38 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 1 39 | [-2, 1, Conv, [128, 1, 1]], 40 | [-1, 1, Conv, [128, 3, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 45 | [-1, 1, Conv, [512, 1, 1]], 46 | 47 | [-1, 1, MP, []], 48 | [-1, 1, Conv, [256, 1, 1]], 49 | [-3, 1, Conv, [256, 1, 1]], 50 | [-1, 1, Conv, [256, 3, 2]], 51 | [[-1, -3], 1, Concat, [1]], # 30-P4/16 52 | [-1, 1, CST, [512 ]], # CST 53 | 54 | 55 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 56 | [-2, 1, Conv, [256, 1, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [-1, 1, Conv, [256, 3, 1]], 60 | [-1, 1, Conv, [256, 3, 1]], 61 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 39 63 | 64 | [-1, 1, MP, []], 65 | [-1, 1, Conv, [512, 1, 1]], 66 | [-3, 1, Conv, [512, 1, 1]], 67 | [-1, 1, Conv, [512, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 44-P5/32 69 | 70 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 71 | [-2, 1, Conv, [256, 1, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [-1, 1, Conv, [256, 3, 1]], 76 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 77 | [-1, 1, Conv, [1024, 1, 1]], # 52 78 | [-1, 1, MCS, [1024 ]], # MCS 79 | ] 80 | 81 | # cst-yolo neck & head 82 | head: 83 | [[-1, 1, SPPCSPC, [512]], 84 | 85 | [-1, 1, Conv, [256, 1, 1]], # CBS 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 87 | [38, 1, Conv, [256, 1, 1]], # CBS 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [-1, 1, Conv, [128, 3, 1]], 95 | [-1, 1, Conv, [128, 3, 1]], 96 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 97 | [-1, 1, Conv, [256, 1, 1]], 98 | 99 | [-1, 1, Conv, [128, 1, 1]], # CBS 100 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 101 | [24, 1, Conv, [128, 1, 1]], # CBS 102 | [[-1, -2], 1, Concat, [1]], 103 | 104 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 2 105 | [-2, 1, Conv, [128, 1, 1]], 106 | [-1, 1, Conv, [64, 3, 1]], 107 | [-1, 1, Conv, [64, 3, 1]], 108 | [-1, 1, Conv, [64, 3, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 111 | [-1, 1, Conv, [128, 1, 1]], # 75 112 | 113 | [-1, 1, MP, []], 114 | [-1, 1, Conv, [128, 1, 1]], 115 | [-3, 1, Conv, [128, 1, 1]], 116 | [-1, 1, Conv, [128, 3, 2]], 117 | [[-1, -3, 63], 1, Concat, [1]], 118 | 119 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 120 | [-2, 1, Conv, [256, 1, 1]], 121 | [-1, 1, Conv, [128, 3, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 126 | [-1, 1, Conv, [256, 1, 1]], # 89 127 | 128 | [-1, 1, MP, []], 129 | [-1, 1, Conv, [256, 1, 1]], 130 | [-3, 1, Conv, [256, 1, 1]], 131 | [-1, 1, Conv, [256, 3, 2]], 132 | [[-1,-2, -4, 54], 1, MyConcat4, [1]], 
133 | 134 | [-1, 1, Conv, [512, 1, 1]], # W-ELAN 2 135 | [-2, 1, Conv, [512, 1, 1]], 136 | [-1, 1, Conv, [256, 3, 1]], 137 | [-1, 1, Conv, [256, 3, 1]], 138 | [-1, 1, Conv, [256, 3, 1]], 139 | [-1, 1, Conv, [256, 3, 1]], 140 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 141 | [-1, 1, Conv, [512, 1, 1]], # 103 142 | 143 | [77, 1, RepConv, [256, 3, 1]], 144 | [91, 1, RepConv, [512, 3, 1]], 145 | [105, 1, RepConv, [1024, 3, 1]], 146 | 147 | [[106,107,108], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 148 | ] 149 | -------------------------------------------------------------------------------- /cfg/ablation/wo-welan.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 3 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [12,16, 19,36, 40,28] # P3/8 9 | - [36,75, 76,55, 72,146] # P4/16 10 | - [142,110, 192,243, 459,401] # P5/32 11 | 12 | # cst-yolo backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, Conv, [32, 3, 1]], # 0 16 | 17 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 18 | [-1, 1, Conv, [64, 3, 1]], 19 | 20 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 21 | 22 | [-1, 1, Conv, [64, 1, 1]], # E-ELAN 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [-1, 1, Conv, [64, 3, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [[-1, -3, -5, -6], 1, Concat, [1]], 29 | [-1, 1, Conv, [256, 1, 1]], # 11 30 | 31 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-3, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 2]], 35 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 36 | 37 | 38 | [-1, 1, Conv, [128, 1, 1]], # E-ELAN 39 | [-2, 1, Conv, [128, 1, 1]], 40 | [-1, 1, Conv, [128, 3, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [[-1, -3, -5, -6], 1, Concat, [1]], 45 | [-1, 1, Conv, [512, 1, 1]], # 25 46 | 47 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 48 | [-1, 1, Conv, [256, 1, 1]], 49 | [-3, 1, Conv, [256, 1, 1]], 50 | [-1, 1, Conv, [256, 3, 2]], 51 | [[-1, -3], 1, Concat, [1]], # 30-P4/16 52 | [-1, 1, CST, [512 ]], # CST 53 | 54 | 55 | [-1, 1, Conv, [256, 1, 1]], # E-ELAN 56 | [-2, 1, Conv, [256, 1, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [-1, 1, Conv, [256, 3, 1]], 60 | [-1, 1, Conv, [256, 3, 1]], 61 | [[-1, -3, -5, -6], 1, Concat, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 39 63 | 64 | [-1, 1, Conv, [512, 3, 2]], # CBSConcat (MP->Conv) 65 | [-1, 1, Conv, [512, 1, 1]], 66 | [-3, 1, Conv, [512, 1, 1]], 67 | [-1, 1, Conv, [512, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 44-P5/32 69 | 70 | [-1, 1, Conv, [256, 1, 1]], # E-ELAN 71 | [-2, 1, Conv, [256, 1, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [-1, 1, Conv, [256, 3, 1]], 76 | [[-1, -3, -5, -6], 1, Concat, [1]], 77 | [-1, 1, Conv, [1024, 1, 1]], 78 | [-1, 1, MCS, [1024 ]], # MCS 79 | ] 80 | 81 | # cst-yolo neck & head 82 | head: 83 | [[-1, 1, SPPCSPC, [512]], 84 | 85 | [-1, 1, Conv, [256, 1, 1]], # CBS 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 87 | [38, 1, Conv, [256, 1, 1]], # CBS 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], # E-ELAN 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [-1, 1, Conv, [128, 3, 1]], 95 | [-1, 1, Conv, [128, 3, 
1]], 96 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 97 | [-1, 1, Conv, [256, 1, 1]], 98 | 99 | [-1, 1, Conv, [128, 1, 1]], # CBS 100 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 101 | [24, 1, Conv, [128, 1, 1]], # CBS 102 | [[-1, -2], 1, Concat, [1]], 103 | 104 | [-1, 1, Conv, [128, 1, 1]], # E-ELAN 105 | [-2, 1, Conv, [128, 1, 1]], 106 | [-1, 1, Conv, [64, 3, 1]], 107 | [-1, 1, Conv, [64, 3, 1]], 108 | [-1, 1, Conv, [64, 3, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 111 | [-1, 1, Conv, [128, 1, 1]], 112 | 113 | [-1, 1, Conv, [128, 3, 2]], # MP->Conv 114 | [-1, 1, Conv, [128, 1, 1]], 115 | [-3, 1, Conv, [128, 1, 1]], 116 | [-1, 1, Conv, [128, 3, 2]], 117 | [69, 1, Conv, [128, 3, 2]], 118 | [[-1,-2, -4, 66], 1, Concat, [1]], 119 | 120 | [-1, 1, Conv, [256, 1, 1]], E-ELAN 121 | [-2, 1, Conv, [256, 1, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [-1, 1, Conv, [128, 3, 1]], 126 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 127 | [-1, 1, Conv, [256, 1, 1]], 128 | 129 | [-1, 1, Conv, [256, 3, 2]], # MP->Conv 130 | [-1, 1, Conv, [256, 1, 1]], 131 | [-3, 1, Conv, [256, 1, 1]], 132 | [-1, 1, Conv, [256, 3, 2]], 133 | [57, 1, Conv, [256, 3, 2]], 134 | [[-1,-2, -4, 54], 1, Concat, [1]], 135 | 136 | [-1, 1, Conv, [512, 1, 1]], # E-ELAN 137 | [-2, 1, Conv, [512, 1, 1]], 138 | [-1, 1, Conv, [256, 3, 1]], 139 | [-1, 1, Conv, [256, 3, 1]], 140 | [-1, 1, Conv, [256, 3, 1]], 141 | [-1, 1, Conv, [256, 3, 1]], 142 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 143 | [-1, 1, Conv, [512, 1, 1]], # 103 144 | 145 | [77, 1, RepConv, [256, 3, 1]], 146 | [91, 1, RepConv, [512, 3, 1]], 147 | [105, 1, RepConv, [1024, 3, 1]], 148 | 149 | [[106,107,108], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 150 | ] 151 | -------------------------------------------------------------------------------- /cfg/ablation/wo-cst.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 3 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [12,16, 19,36, 40,28] # P3/8 9 | - [36,75, 76,55, 72,146] # P4/16 10 | - [142,110, 192,243, 459,401] # P5/32 11 | 12 | # cst-yolo backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, Conv, [32, 3, 1]], # 0 16 | 17 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 18 | [-1, 1, Conv, [64, 3, 1]], 19 | 20 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 21 | 22 | [-1, 1, Conv, [64, 1, 1]], # W-ELAN 1 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [-1, 1, Conv, [64, 3, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 29 | [-1, 1, Conv, [256, 1, 1]], # 11 30 | 31 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-3, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 2]], 35 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 36 | 37 | 38 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 1 39 | [-2, 1, Conv, [128, 1, 1]], 40 | [-1, 1, Conv, [128, 3, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 45 | [-1, 1, Conv, [512, 1, 1]], # 25 46 | 47 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 48 | [-1, 1, Conv, [256, 1, 1]], 49 | [-3, 1, Conv, [256, 1, 1]], 50 | [-1, 1, Conv, [256, 3, 2]], 51 | [[-1, -3], 1, Concat, [1]], # 
30-P4/16 52 | # [-1, 1, CST, [512 ]], # CST 53 | 54 | 55 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 56 | [-2, 1, Conv, [256, 1, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [-1, 1, Conv, [256, 3, 1]], 60 | [-1, 1, Conv, [256, 3, 1]], 61 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 38 63 | 64 | [-1, 1, Conv, [512, 3, 2]], # CBSConcat (MP->Conv) 65 | [-1, 1, Conv, [512, 1, 1]], 66 | [-3, 1, Conv, [512, 1, 1]], 67 | [-1, 1, Conv, [512, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 44-P5/32 69 | 70 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 71 | [-2, 1, Conv, [256, 1, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [-1, 1, Conv, [256, 3, 1]], 76 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 77 | [-1, 1, Conv, [1024, 1, 1]], # 51 78 | [-1, 1, MCS, [1024 ]], # MCS 79 | ] 80 | 81 | # cst-yolo neck & head 82 | head: 83 | [[-1, 1, SPPCSPC, [512]], 84 | 85 | [-1, 1, Conv, [256, 1, 1]], # CBS 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 87 | [38, 1, Conv, [256, 1, 1]], # CBS 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [-1, 1, Conv, [128, 3, 1]], 95 | [-1, 1, Conv, [128, 3, 1]], 96 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 97 | [-1, 1, Conv, [256, 1, 1]], 98 | 99 | [-1, 1, Conv, [128, 1, 1]], # CBS 100 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 101 | [24, 1, Conv, [128, 1, 1]], # CBS 102 | [[-1, -2], 1, Concat, [1]], 103 | 104 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 2 105 | [-2, 1, Conv, [128, 1, 1]], 106 | [-1, 1, Conv, [64, 3, 1]], 107 | [-1, 1, Conv, [64, 3, 1]], 108 | [-1, 1, Conv, [64, 3, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 111 | [-1, 1, Conv, [128, 1, 1]], 112 | 113 | [-1, 1, Conv, [128, 3, 2]], # CatConv (MP->Conv) 114 | [-1, 1, Conv, [128, 1, 1]], 115 | [-3, 1, Conv, [128, 1, 1]], 116 | [-1, 1, Conv, [128, 3, 2]], 117 | [69, 1, Conv, [128, 3, 2]], 118 | [[-1,-2, -4, 66], 1, MyConcat4, [1]], 119 | 120 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 121 | [-2, 1, Conv, [256, 1, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [-1, 1, Conv, [128, 3, 1]], 126 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 127 | [-1, 1, Conv, [256, 1, 1]], 128 | 129 | [-1, 1, Conv, [256, 3, 2]], # CatConv (MP->Conv) 130 | [-1, 1, Conv, [256, 1, 1]], 131 | [-3, 1, Conv, [256, 1, 1]], 132 | [-1, 1, Conv, [256, 3, 2]], 133 | [57, 1, Conv, [256, 3, 2]], 134 | [[-1,-2, -4, 54], 1, MyConcat4, [1]], 135 | 136 | [-1, 1, Conv, [512, 1, 1]], # W-ELAN 2 137 | [-2, 1, Conv, [512, 1, 1]], 138 | [-1, 1, Conv, [256, 3, 1]], 139 | [-1, 1, Conv, [256, 3, 1]], 140 | [-1, 1, Conv, [256, 3, 1]], 141 | [-1, 1, Conv, [256, 3, 1]], 142 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 143 | [-1, 1, Conv, [512, 1, 1]], 144 | 145 | [76, 1, RepConv, [256, 3, 1]], 146 | [90, 1, RepConv, [512, 3, 1]], 147 | [104, 1, RepConv, [1024, 3, 1]], 148 | 149 | [[105,106,107], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 150 | ] 151 | -------------------------------------------------------------------------------- /cfg/ablation/wo-mcs.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 3 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel 
multiple 5 | 6 | # anchors 7 | anchors: 8 | - [12,16, 19,36, 40,28] # P3/8 9 | - [36,75, 76,55, 72,146] # P4/16 10 | - [142,110, 192,243, 459,401] # P5/32 11 | 12 | # cst-yolo backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, Conv, [32, 3, 1]], # 0 16 | 17 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 18 | [-1, 1, Conv, [64, 3, 1]], 19 | 20 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 21 | 22 | [-1, 1, Conv, [64, 1, 1]], # W-ELAN 1 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [-1, 1, Conv, [64, 3, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 29 | [-1, 1, Conv, [256, 1, 1]], # 11 30 | 31 | [-1, 1, Conv, [256, 3, 2]], # MPConv (MP->Conv) 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-3, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 2]], 35 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 36 | 37 | 38 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 1 39 | [-2, 1, Conv, [128, 1, 1]], 40 | [-1, 1, Conv, [128, 3, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 45 | [-1, 1, Conv, [512, 1, 1]], # 25 46 | 47 | [-1, 1, Conv, [256, 3, 2]], # MPConv (MP->Conv) 48 | [-1, 1, Conv, [256, 1, 1]], 49 | [-3, 1, Conv, [256, 1, 1]], 50 | [-1, 1, Conv, [256, 3, 2]], 51 | [[-1, -3], 1, Concat, [1]], # 30-P4/16 52 | [-1, 1, CST, [512 ]], # CST 53 | 54 | 55 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 56 | [-2, 1, Conv, [256, 1, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [-1, 1, Conv, [256, 3, 1]], 60 | [-1, 1, Conv, [256, 3, 1]], 61 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 39 63 | 64 | [-1, 1, Conv, [512, 3, 2]], # MPConv (MP->Conv) 65 | [-1, 1, Conv, [512, 1, 1]], 66 | [-3, 1, Conv, [512, 1, 1]], 67 | [-1, 1, Conv, [512, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 44-P5/32 69 | 70 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 71 | [-2, 1, Conv, [256, 1, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [-1, 1, Conv, [256, 3, 1]], 76 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 77 | [-1, 1, Conv, [1024, 1, 1]], # 52 78 | # [-1, 1, MCS, [1024 ]], # MCS 79 | ] 80 | 81 | # cst-yolo neck & head 82 | head: 83 | [[-1, 1, SPPCSPC, [512]], 84 | 85 | [-1, 1, Conv, [256, 1, 1]], # CBS 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 87 | [38, 1, Conv, [256, 1, 1]], # CBS 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [-1, 1, Conv, [128, 3, 1]], 95 | [-1, 1, Conv, [128, 3, 1]], 96 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 97 | [-1, 1, Conv, [256, 1, 1]], 98 | 99 | [-1, 1, Conv, [128, 1, 1]], # CBS 100 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 101 | [24, 1, Conv, [128, 1, 1]], # CBS 102 | [[-1, -2], 1, Concat, [1]], 103 | 104 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 2 105 | [-2, 1, Conv, [128, 1, 1]], 106 | [-1, 1, Conv, [64, 3, 1]], 107 | [-1, 1, Conv, [64, 3, 1]], 108 | [-1, 1, Conv, [64, 3, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 111 | [-1, 1, Conv, [128, 1, 1]], 112 | 113 | [-1, 1, Conv, [128, 3, 2]], # CatConv (MP->Conv) 114 | [-1, 1, Conv, [128, 1, 1]], 115 | [-3, 1, Conv, [128, 1, 1]], 116 | [-1, 1, Conv, [128, 3, 2]], 117 | [69, 1, Conv, [128, 3, 2]], 118 | [[-1,-2, -4, 66], 1, MyConcat4, [1]], 119 | 120 | 
[-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 121 | [-2, 1, Conv, [256, 1, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [-1, 1, Conv, [128, 3, 1]], 126 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 127 | [-1, 1, Conv, [256, 1, 1]], 128 | 129 | [-1, 1, Conv, [256, 3, 2]], # CatConv (MP->Conv) 130 | [-1, 1, Conv, [256, 1, 1]], 131 | [-3, 1, Conv, [256, 1, 1]], 132 | [-1, 1, Conv, [256, 3, 2]], 133 | [57, 1, Conv, [256, 3, 2]], 134 | [[-1,-2, -4, 54], 1, MyConcat4, [1]], 135 | 136 | [-1, 1, Conv, [512, 1, 1]], # W-ELAN 2 137 | [-2, 1, Conv, [512, 1, 1]], 138 | [-1, 1, Conv, [256, 3, 1]], 139 | [-1, 1, Conv, [256, 3, 1]], 140 | [-1, 1, Conv, [256, 3, 1]], 141 | [-1, 1, Conv, [256, 3, 1]], 142 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 143 | [-1, 1, Conv, [512, 1, 1]], 144 | 145 | [76, 1, RepConv, [256, 3, 1]], 146 | [90, 1, RepConv, [512, 3, 1]], 147 | [104, 1, RepConv, [1024, 3, 1]], 148 | 149 | [[105,106,107], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 150 | ] 151 | -------------------------------------------------------------------------------- /cfg/training/cst-yolo.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 3 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [12,16, 19,36, 40,28] # P3/8 9 | - [36,75, 76,55, 72,146] # P4/16 10 | - [142,110, 192,243, 459,401] # P5/32 11 | 12 | # cst-yolo backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, Conv, [32, 3, 1]], # 0 16 | 17 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 18 | [-1, 1, Conv, [64, 3, 1]], 19 | 20 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 21 | 22 | [-1, 1, Conv, [64, 1, 1]], # W-ELAN 1 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [-1, 1, Conv, [64, 3, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 29 | [-1, 1, Conv, [256, 1, 1]], # 11 30 | 31 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-3, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 2]], 35 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 36 | 37 | 38 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 1 39 | [-2, 1, Conv, [128, 1, 1]], 40 | [-1, 1, Conv, [128, 3, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 45 | [-1, 1, Conv, [512, 1, 1]], # 25 46 | 47 | [-1, 1, Conv, [256, 3, 2]], # CBSConcat (MP->Conv) 48 | [-1, 1, Conv, [256, 1, 1]], 49 | [-3, 1, Conv, [256, 1, 1]], 50 | [-1, 1, Conv, [256, 3, 2]], 51 | [[-1, -3], 1, Concat, [1]], # 30-P4/16 52 | [-1, 1, CST, [512 ]], # CST 53 | 54 | 55 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 56 | [-2, 1, Conv, [256, 1, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [-1, 1, Conv, [256, 3, 1]], 60 | [-1, 1, Conv, [256, 3, 1]], 61 | [[-1, -3, -5, -6], 1, MyConcat4, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 39 63 | 64 | [-1, 1, Conv, [512, 3, 2]], # CBSConcat (MP->Conv) 65 | [-1, 1, Conv, [512, 1, 1]], 66 | [-3, 1, Conv, [512, 1, 1]], 67 | [-1, 1, Conv, [512, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 44-P5/32 69 | 70 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 1 71 | [-2, 1, Conv, [256, 1, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [-1, 1, Conv, [256, 3, 1]], 76 | [[-1, -3, -5, -6], 
1, MyConcat4, [1]], 77 | [-1, 1, Conv, [1024, 1, 1]], # 52 78 | [-1, 1, MCS, [1024 ]], # MCS 79 | ] 80 | 81 | # cst-yolo neck & head 82 | head: 83 | [[-1, 1, SPPCSPC, [512]], # 54 84 | 85 | [-1, 1, Conv, [256, 1, 1]], # CBS 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 87 | [38, 1, Conv, [256, 1, 1]], # CBS 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [-1, 1, Conv, [128, 3, 1]], 95 | [-1, 1, Conv, [128, 3, 1]], 96 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 97 | [-1, 1, Conv, [256, 1, 1]], # 65 98 | 99 | [-1, 1, Conv, [128, 1, 1]], # CBS 100 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], # Upsample 101 | [24, 1, Conv, [128, 1, 1]], # CBS 102 | [[-1, -2], 1, Concat, [1]], # 69 103 | 104 | [-1, 1, Conv, [128, 1, 1]], # W-ELAN 2 105 | [-2, 1, Conv, [128, 1, 1]], 106 | [-1, 1, Conv, [64, 3, 1]], 107 | [-1, 1, Conv, [64, 3, 1]], 108 | [-1, 1, Conv, [64, 3, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 111 | [-1, 1, Conv, [128, 1, 1]], # 75 112 | 113 | [-1, 1, Conv, [128, 3, 2]], # CatConv (MP->Conv) 114 | [-1, 1, Conv, [128, 1, 1]], 115 | [-3, 1, Conv, [128, 1, 1]], 116 | [-1, 1, Conv, [128, 3, 2]], 117 | [69, 1, Conv, [128, 3, 2]], # add 118 | [[-1,-2, -4, 66], 1, MyConcat4, [1]], 119 | 120 | [-1, 1, Conv, [256, 1, 1]], # W-ELAN 2 121 | [-2, 1, Conv, [256, 1, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [-1, 1, Conv, [128, 3, 1]], 126 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 127 | [-1, 1, Conv, [256, 1, 1]], # 89 128 | 129 | [-1, 1, Conv, [256, 3, 2]], # CatConv (MP->Conv) 130 | [-1, 1, Conv, [256, 1, 1]], 131 | [-3, 1, Conv, [256, 1, 1]], 132 | [-1, 1, Conv, [256, 3, 2]], 133 | [57, 1, Conv, [256, 3, 2]], # add 134 | [[-1,-2, -4, 54], 1, MyConcat4, [1]], 135 | 136 | [-1, 1, Conv, [512, 1, 1]], # W-ELAN 2 137 | [-2, 1, Conv, [512, 1, 1]], 138 | [-1, 1, Conv, [256, 3, 1]], 139 | [-1, 1, Conv, [256, 3, 1]], 140 | [-1, 1, Conv, [256, 3, 1]], 141 | [-1, 1, Conv, [256, 3, 1]], 142 | [[-1, -2, -3, -4, -5, -6], 1, MyConcat6, [1]], 143 | [-1, 1, Conv, [512, 1, 1]], # 103 144 | 145 | [77, 1, RepConv, [256, 3, 1]], 146 | [91, 1, RepConv, [512, 3, 1]], 147 | [105, 1, RepConv, [1024, 3, 1]], 148 | 149 | [[106,107,108], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 150 | ] 151 | -------------------------------------------------------------------------------- /cfg/baseline/yolov7.yaml: -------------------------------------------------------------------------------- 1 | # References: 2 | # [1] C.-Y. Wang, A. Bochkovskiy, and H.-Y. M. Liao, "Yolov7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors," arXiv:2207.02696v1 [cs.CV], Jul. 2022. 3 | # [2] C.-Y. Wang, "Official yolov7," GitHub, 2022, https://github.com/WongKinYiu/yolov7/blob/main/cfg/training/yolov7.yaml. 
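# Note: this baseline keeps the layout of the official training config [2]; nc is changed from the COCO default of 80 to the 3 blood cell classes (WBC, RBC, Platelets).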
4 | 5 | # parameters 6 | nc: 3 # number of classes 7 | depth_multiple: 1.0 # model depth multiple 8 | width_multiple: 1.0 # layer channel multiple 9 | 10 | # anchors 11 | anchors: 12 | - [12,16, 19,36, 40,28] # P3/8 13 | - [36,75, 76,55, 72,146] # P4/16 14 | - [142,110, 192,243, 459,401] # P5/32 15 | 16 | # yolov7 backbone 17 | backbone: 18 | # [from, number, module, args] 19 | [[-1, 1, Conv, [32, 3, 1]], # 0 20 | 21 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 22 | [-1, 1, Conv, [64, 3, 1]], 23 | 24 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 25 | [-1, 1, Conv, [64, 1, 1]], 26 | [-2, 1, Conv, [64, 1, 1]], 27 | [-1, 1, Conv, [64, 3, 1]], 28 | [-1, 1, Conv, [64, 3, 1]], 29 | [-1, 1, Conv, [64, 3, 1]], 30 | [-1, 1, Conv, [64, 3, 1]], 31 | [[-1, -3, -5, -6], 1, Concat, [1]], 32 | [-1, 1, Conv, [256, 1, 1]], # 11 33 | 34 | [-1, 1, MP, []], 35 | [-1, 1, Conv, [128, 1, 1]], 36 | [-3, 1, Conv, [128, 1, 1]], 37 | [-1, 1, Conv, [128, 3, 2]], 38 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 39 | [-1, 1, Conv, [128, 1, 1]], 40 | [-2, 1, Conv, [128, 1, 1]], 41 | [-1, 1, Conv, [128, 3, 1]], 42 | [-1, 1, Conv, [128, 3, 1]], 43 | [-1, 1, Conv, [128, 3, 1]], 44 | [-1, 1, Conv, [128, 3, 1]], 45 | [[-1, -3, -5, -6], 1, Concat, [1]], 46 | [-1, 1, Conv, [512, 1, 1]], # 24 47 | 48 | [-1, 1, MP, []], 49 | [-1, 1, Conv, [256, 1, 1]], 50 | [-3, 1, Conv, [256, 1, 1]], 51 | [-1, 1, Conv, [256, 3, 2]], 52 | [[-1, -3], 1, Concat, [1]], # 29-P4/16 53 | [-1, 1, Conv, [256, 1, 1]], 54 | [-2, 1, Conv, [256, 1, 1]], 55 | [-1, 1, Conv, [256, 3, 1]], 56 | [-1, 1, Conv, [256, 3, 1]], 57 | [-1, 1, Conv, [256, 3, 1]], 58 | [-1, 1, Conv, [256, 3, 1]], 59 | [[-1, -3, -5, -6], 1, Concat, [1]], 60 | [-1, 1, Conv, [1024, 1, 1]], # 37 61 | 62 | [-1, 1, MP, []], 63 | [-1, 1, Conv, [512, 1, 1]], 64 | [-3, 1, Conv, [512, 1, 1]], 65 | [-1, 1, Conv, [512, 3, 2]], 66 | [[-1, -3], 1, Concat, [1]], # 42-P5/32 67 | [-1, 1, Conv, [256, 1, 1]], 68 | [-2, 1, Conv, [256, 1, 1]], 69 | [-1, 1, Conv, [256, 3, 1]], 70 | [-1, 1, Conv, [256, 3, 1]], 71 | [-1, 1, Conv, [256, 3, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [[-1, -3, -5, -6], 1, Concat, [1]], 74 | [-1, 1, Conv, [1024, 1, 1]], # 50 75 | ] 76 | 77 | # yolov7 head 78 | head: 79 | [[-1, 1, SPPCSPC, [512]], # 51 80 | 81 | [-1, 1, Conv, [256, 1, 1]], 82 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 83 | [37, 1, Conv, [256, 1, 1]], # route backbone P4 84 | [[-1, -2], 1, Concat, [1]], 85 | 86 | [-1, 1, Conv, [256, 1, 1]], 87 | [-2, 1, Conv, [256, 1, 1]], 88 | [-1, 1, Conv, [128, 3, 1]], 89 | [-1, 1, Conv, [128, 3, 1]], 90 | [-1, 1, Conv, [128, 3, 1]], 91 | [-1, 1, Conv, [128, 3, 1]], 92 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 93 | [-1, 1, Conv, [256, 1, 1]], # 63 94 | 95 | [-1, 1, Conv, [128, 1, 1]], 96 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 97 | [24, 1, Conv, [128, 1, 1]], # route backbone P3 98 | [[-1, -2], 1, Concat, [1]], 99 | 100 | [-1, 1, Conv, [128, 1, 1]], 101 | [-2, 1, Conv, [128, 1, 1]], 102 | [-1, 1, Conv, [64, 3, 1]], 103 | [-1, 1, Conv, [64, 3, 1]], 104 | [-1, 1, Conv, [64, 3, 1]], 105 | [-1, 1, Conv, [64, 3, 1]], 106 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 107 | [-1, 1, Conv, [128, 1, 1]], # 75 108 | 109 | [-1, 1, MP, []], 110 | [-1, 1, Conv, [128, 1, 1]], 111 | [-3, 1, Conv, [128, 1, 1]], 112 | [-1, 1, Conv, [128, 3, 2]], 113 | [[-1, -3, 63], 1, Concat, [1]], 114 | 115 | [-1, 1, Conv, [256, 1, 1]], 116 | [-2, 1, Conv, [256, 1, 1]], 117 | [-1, 1, Conv, [128, 3, 1]], 118 | [-1, 1, Conv, [128, 3, 1]], 119 | [-1, 1, Conv, [128, 3, 1]], 120 | [-1, 1, Conv, [128, 3, 1]], 121 | [[-1, -2, 
-3, -4, -5, -6], 1, Concat, [1]], 122 | [-1, 1, Conv, [256, 1, 1]], # 88 123 | 124 | [-1, 1, MP, []], 125 | [-1, 1, Conv, [256, 1, 1]], 126 | [-3, 1, Conv, [256, 1, 1]], 127 | [-1, 1, Conv, [256, 3, 2]], 128 | [[-1, -3, 51], 1, Concat, [1]], 129 | 130 | [-1, 1, Conv, [512, 1, 1]], 131 | [-2, 1, Conv, [512, 1, 1]], 132 | [-1, 1, Conv, [256, 3, 1]], 133 | [-1, 1, Conv, [256, 3, 1]], 134 | [-1, 1, Conv, [256, 3, 1]], 135 | [-1, 1, Conv, [256, 3, 1]], 136 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 137 | [-1, 1, Conv, [512, 1, 1]], # 101 138 | 139 | [75, 1, RepConv, [256, 3, 1]], 140 | [88, 1, RepConv, [512, 3, 1]], 141 | [101, 1, RepConv, [1024, 3, 1]], 142 | 143 | [[102,103,104], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 144 | ] 145 | -------------------------------------------------------------------------------- /utils/google_utils.py: -------------------------------------------------------------------------------- 1 | # Google utils: https://cloud.google.com/storage/docs/reference/libraries 2 | 3 | import os 4 | import platform 5 | import subprocess 6 | import time 7 | from pathlib import Path 8 | 9 | import requests 10 | import torch 11 | 12 | 13 | def gsutil_getsize(url=''): 14 | # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du 15 | s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8') 16 | return eval(s.split(' ')[0]) if len(s) else 0 # bytes 17 | 18 | 19 | def attempt_download(file, repo='WongKinYiu/yolov7'): 20 | # Attempt file download if does not exist 21 | file = Path(str(file).strip().replace("'", '').lower()) 22 | 23 | if not file.exists(): 24 | try: 25 | response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api 26 | assets = [x['name'] for x in response['assets']] # release assets 27 | tag = response['tag_name'] # i.e. 'v1.0' 28 | except: # fallback plan 29 | assets = ['yolov7.pt', 'yolov7-tiny.pt', 'yolov7x.pt', 'yolov7-d6.pt', 'yolov7-e6.pt', 30 | 'yolov7-e6e.pt', 'yolov7-w6.pt'] 31 | tag = subprocess.check_output('git tag', shell=True).decode().split()[-1] 32 | 33 | name = file.name 34 | if name in assets: 35 | msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/' 36 | redundant = False # second download option 37 | try: # GitHub 38 | url = f'https://github.com/{repo}/releases/download/{tag}/{name}' 39 | print(f'Downloading {url} to {file}...') 40 | torch.hub.download_url_to_file(url, file) 41 | assert file.exists() and file.stat().st_size > 1E6 # check 42 | except Exception as e: # GCP 43 | print(f'Download error: {e}') 44 | assert redundant, 'No secondary mirror' 45 | url = f'https://storage.googleapis.com/{repo}/ckpt/{name}' 46 | print(f'Downloading {url} to {file}...') 47 | os.system(f'curl -L {url} -o {file}') # torch.hub.download_url_to_file(url, weights) 48 | finally: 49 | if not file.exists() or file.stat().st_size < 1E6: # check 50 | file.unlink(missing_ok=True) # remove partial downloads 51 | print(f'ERROR: Download failure: {msg}') 52 | print('') 53 | return 54 | 55 | 56 | def gdrive_download(id='', file='tmp.zip'): 57 | # Downloads a file from Google Drive. from yolov7.utils.google_utils import *; gdrive_download() 58 | t = time.time() 59 | file = Path(file) 60 | cookie = Path('cookie') # gdrive cookie 61 | print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... 
', end='') 62 | file.unlink(missing_ok=True) # remove existing file 63 | cookie.unlink(missing_ok=True) # remove existing cookie 64 | 65 | # Attempt file download 66 | out = "NUL" if platform.system() == "Windows" else "/dev/null" 67 | os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}') 68 | if os.path.exists('cookie'): # large file 69 | s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}' 70 | else: # small file 71 | s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"' 72 | r = os.system(s) # execute, capture return 73 | cookie.unlink(missing_ok=True) # remove existing cookie 74 | 75 | # Error check 76 | if r != 0: 77 | file.unlink(missing_ok=True) # remove partial 78 | print('Download error ') # raise Exception('Download error') 79 | return r 80 | 81 | # Unzip if archive 82 | if file.suffix == '.zip': 83 | print('unzipping... ', end='') 84 | os.system(f'unzip -q {file}') # unzip 85 | file.unlink() # remove zip to free space 86 | 87 | print(f'Done ({time.time() - t:.1f}s)') 88 | return r 89 | 90 | 91 | def get_token(cookie="./cookie"): 92 | with open(cookie) as f: 93 | for line in f: 94 | if "download" in line: 95 | return line.split()[-1] 96 | return "" 97 | 98 | # def upload_blob(bucket_name, source_file_name, destination_blob_name): 99 | # # Uploads a file to a bucket 100 | # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python 101 | # 102 | # storage_client = storage.Client() 103 | # bucket = storage_client.get_bucket(bucket_name) 104 | # blob = bucket.blob(destination_blob_name) 105 | # 106 | # blob.upload_from_filename(source_file_name) 107 | # 108 | # print('File {} uploaded to {}.'.format( 109 | # source_file_name, 110 | # destination_blob_name)) 111 | # 112 | # 113 | # def download_blob(bucket_name, source_blob_name, destination_file_name): 114 | # # Uploads a blob from a bucket 115 | # storage_client = storage.Client() 116 | # bucket = storage_client.get_bucket(bucket_name) 117 | # blob = bucket.blob(source_blob_name) 118 | # 119 | # blob.download_to_filename(destination_file_name) 120 | # 121 | # print('Blob {} downloaded to {}.'.format( 122 | # source_blob_name, 123 | # destination_file_name)) 124 | -------------------------------------------------------------------------------- /utils/add_nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | from onnx import shape_inference 4 | try: 5 | import onnx_graphsurgeon as gs 6 | except Exception as e: 7 | print('Import onnx_graphsurgeon failure: %s' % e) 8 | 9 | import logging 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | class RegisterNMS(object): 14 | def __init__( 15 | self, 16 | onnx_model_path: str, 17 | precision: str = "fp32", 18 | ): 19 | 20 | self.graph = gs.import_onnx(onnx.load(onnx_model_path)) 21 | assert self.graph 22 | LOGGER.info("ONNX graph created successfully") 23 | # Fold constants via ONNX-GS that PyTorch2ONNX may have missed 24 | self.graph.fold_constants() 25 | self.precision = precision 26 | self.batch_size = 1 27 | def infer(self): 28 | """ 29 | Sanitize the graph by cleaning any unconnected nodes, do a topological resort, 30 | and fold constant inputs values. When possible, run shape inference on the 31 | ONNX graph to determine tensor shapes. 
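Folding is attempted up to three times and stops early once the node count no longer changes.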
32 | """ 33 | for _ in range(3): 34 | count_before = len(self.graph.nodes) 35 | 36 | self.graph.cleanup().toposort() 37 | try: 38 | for node in self.graph.nodes: 39 | for o in node.outputs: 40 | o.shape = None 41 | model = gs.export_onnx(self.graph) 42 | model = shape_inference.infer_shapes(model) 43 | self.graph = gs.import_onnx(model) 44 | except Exception as e: 45 | LOGGER.info(f"Shape inference could not be performed at this time:\n{e}") 46 | try: 47 | self.graph.fold_constants(fold_shapes=True) 48 | except TypeError as e: 49 | LOGGER.error( 50 | "This version of ONNX GraphSurgeon does not support folding shapes, " 51 | f"please upgrade your onnx_graphsurgeon module. Error:\n{e}" 52 | ) 53 | raise 54 | 55 | count_after = len(self.graph.nodes) 56 | if count_before == count_after: 57 | # No new folding occurred in this iteration, so we can stop for now. 58 | break 59 | 60 | def save(self, output_path): 61 | """ 62 | Save the ONNX model to the given location. 63 | Args: 64 | output_path: Path pointing to the location where to write 65 | out the updated ONNX model. 66 | """ 67 | self.graph.cleanup().toposort() 68 | model = gs.export_onnx(self.graph) 69 | onnx.save(model, output_path) 70 | LOGGER.info(f"Saved ONNX model to {output_path}") 71 | 72 | def register_nms( 73 | self, 74 | *, 75 | score_thresh: float = 0.25, 76 | nms_thresh: float = 0.45, 77 | detections_per_img: int = 100, 78 | ): 79 | """ 80 | Register the ``EfficientNMS_TRT`` plugin node. 81 | NMS expects these shapes for its input tensors: 82 | - box_net: [batch_size, number_boxes, 4] 83 | - class_net: [batch_size, number_boxes, number_labels] 84 | Args: 85 | score_thresh (float): The scalar threshold for score (low scoring boxes are removed). 86 | nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU 87 | overlap with previously selected boxes are removed). 88 | detections_per_img (int): Number of best detections to keep after NMS. 89 | """ 90 | 91 | self.infer() 92 | # Find the concat node at the end of the network 93 | op_inputs = self.graph.outputs 94 | op = "EfficientNMS_TRT" 95 | attrs = { 96 | "plugin_version": "1", 97 | "background_class": -1, # no background class 98 | "max_output_boxes": detections_per_img, 99 | "score_threshold": score_thresh, 100 | "iou_threshold": nms_thresh, 101 | "score_activation": False, 102 | "box_coding": 0, 103 | } 104 | 105 | if self.precision == "fp32": 106 | dtype_output = np.float32 107 | elif self.precision == "fp16": 108 | dtype_output = np.float16 109 | else: 110 | raise NotImplementedError(f"Currently not supports precision: {self.precision}") 111 | 112 | # NMS Outputs 113 | output_num_detections = gs.Variable( 114 | name="num_dets", 115 | dtype=np.int32, 116 | shape=[self.batch_size, 1], 117 | ) # A scalar indicating the number of valid detections per batch image. 118 | output_boxes = gs.Variable( 119 | name="det_boxes", 120 | dtype=dtype_output, 121 | shape=[self.batch_size, detections_per_img, 4], 122 | ) 123 | output_scores = gs.Variable( 124 | name="det_scores", 125 | dtype=dtype_output, 126 | shape=[self.batch_size, detections_per_img], 127 | ) 128 | output_labels = gs.Variable( 129 | name="det_classes", 130 | dtype=np.int32, 131 | shape=[self.batch_size, detections_per_img], 132 | ) 133 | 134 | op_outputs = [output_num_detections, output_boxes, output_scores, output_labels] 135 | 136 | # Create the NMS Plugin node with the selected inputs. The outputs of the node will also 137 | # become the final outputs of the graph. 
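# Note: gs.Graph.layer() builds a Node with the op type and attributes given above, wires it to the
# listed input/output tensors, and appends it to the graph's node list; the plugin's tensors only
# become the model outputs once self.graph.outputs is reassigned to op_outputs a few lines below.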
138 | self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs) 139 | LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}") 140 | 141 | self.graph.outputs = op_outputs 142 | 143 | self.infer() 144 | 145 | def save(self, output_path): 146 | """ 147 | Save the ONNX model to the given location. 148 | Args: 149 | output_path: Path pointing to the location where to write 150 | out the updated ONNX model. 151 | """ 152 | self.graph.cleanup().toposort() 153 | model = gs.export_onnx(self.graph) 154 | onnx.save(model, output_path) 155 | LOGGER.info(f"Saved ONNX model to {output_path}") 156 | -------------------------------------------------------------------------------- /utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | import numpy as np 4 | import torch 5 | import yaml 6 | from scipy.cluster.vq import kmeans 7 | from tqdm import tqdm 8 | 9 | from utils.general import colorstr 10 | 11 | 12 | def check_anchor_order(m): 13 | # Check anchor order against stride order for YOLO Detect() module m, and correct if necessary 14 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 15 | da = a[-1] - a[0] # delta a 16 | ds = m.stride[-1] - m.stride[0] # delta s 17 | if da.sign() != ds.sign(): # same order 18 | print('Reversing anchor order') 19 | m.anchors[:] = m.anchors.flip(0) 20 | m.anchor_grid[:] = m.anchor_grid.flip(0) 21 | 22 | 23 | def check_anchors(dataset, model, thr=4.0, imgsz=640): 24 | # Check anchor fit to data, recompute if necessary 25 | prefix = colorstr('autoanchor: ') 26 | print(f'\n{prefix}Analyzing anchors... ', end='') 27 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() 28 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) 29 | scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale 30 | wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh 31 | 32 | def metric(k): # compute metric 33 | r = wh[:, None] / k[None] 34 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 35 | best = x.max(1)[0] # best_x 36 | aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold 37 | bpr = (best > 1. / thr).float().mean() # best possible recall 38 | return bpr, aat 39 | 40 | anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors 41 | bpr, aat = metric(anchors) 42 | print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') 43 | if bpr < 0.98: # threshold to recompute 44 | print('. Attempting to improve anchors, please wait...') 45 | na = m.anchor_grid.numel() // 2 # number of anchors 46 | try: 47 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) 48 | except Exception as e: 49 | print(f'{prefix}ERROR: {e}') 50 | new_bpr = metric(anchors)[0] 51 | if new_bpr > bpr: # replace anchors 52 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) 53 | m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference 54 | check_anchor_order(m) 55 | m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss 56 | print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') 57 | else: 58 | print(f'{prefix}Original anchors better than new anchors. 
Proceeding with original anchors.') 59 | print('') # newline 60 | 61 | 62 | def kmean_anchors(path='./data/coco.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True): 63 | """ Creates kmeans-evolved anchors from training dataset 64 | 65 | Arguments: 66 | path: path to dataset *.yaml, or a loaded dataset 67 | n: number of anchors 68 | img_size: image size used for training 69 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 70 | gen: generations to evolve anchors using genetic algorithm 71 | verbose: print all results 72 | 73 | Return: 74 | k: kmeans evolved anchors 75 | 76 | Usage: 77 | from utils.autoanchor import *; _ = kmean_anchors() 78 | """ 79 | thr = 1. / thr 80 | prefix = colorstr('autoanchor: ') 81 | 82 | def metric(k, wh): # compute metrics 83 | r = wh[:, None] / k[None] 84 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 85 | # x = wh_iou(wh, torch.tensor(k)) # iou metric 86 | return x, x.max(1)[0] # x, best_x 87 | 88 | def anchor_fitness(k): # mutation fitness 89 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh) 90 | return (best * (best > thr).float()).mean() # fitness 91 | 92 | def print_results(k): 93 | k = k[np.argsort(k.prod(1))] # sort small to large 94 | x, best = metric(k, wh0) 95 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr 96 | print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr') 97 | print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' 98 | f'past_thr={x[x > thr].mean():.3f}-mean: ', end='') 99 | for i, x in enumerate(k): 100 | print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg 101 | return k 102 | 103 | if isinstance(path, str): # *.yaml file 104 | with open(path) as f: 105 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # model dict 106 | from utils.datasets import LoadImagesAndLabels 107 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True) 108 | else: 109 | dataset = path # dataset 110 | 111 | # Get label wh 112 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) 113 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh 114 | 115 | # Filter 116 | i = (wh0 < 3.0).any(1).sum() 117 | if i: 118 | print(f'{prefix}WARNING: Extremely small objects found. 
{i} of {len(wh0)} labels are < 3 pixels in size.') 119 | wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels 120 | # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 121 | 122 | # Kmeans calculation 123 | print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...') 124 | s = wh.std(0) # sigmas for whitening 125 | k, dist = kmeans(wh / s, n, iter=30) # points, mean distance 126 | assert len(k) == n, print(f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}') 127 | k *= s 128 | wh = torch.tensor(wh, dtype=torch.float32) # filtered 129 | wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered 130 | k = print_results(k) 131 | 132 | # Plot 133 | # k, d = [None] * 20, [None] * 20 134 | # for i in tqdm(range(1, 21)): 135 | # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance 136 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) 137 | # ax = ax.ravel() 138 | # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') 139 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh 140 | # ax[0].hist(wh[wh[:, 0]<100, 0],400) 141 | # ax[1].hist(wh[wh[:, 1]<100, 1],400) 142 | # fig.savefig('wh.png', dpi=200) 143 | 144 | # Evolve 145 | npr = np.random 146 | f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma 147 | pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar 148 | for _ in pbar: 149 | v = np.ones(sh) 150 | while (v == 1).all(): # mutate until a change occurs (prevent duplicates) 151 | v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) 152 | kg = (k.copy() * v).clip(min=2.0) 153 | fg = anchor_fitness(kg) 154 | if fg > f: 155 | f, k = fg, kg.copy() 156 | pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' 157 | if verbose: 158 | print_results(k) 159 | 160 | return print_results(k) 161 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import time 4 | import warnings 5 | 6 | sys.path.append('./') # to run '$ python *.py' files in subdirectories 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.mobile_optimizer import optimize_for_mobile 11 | 12 | import models 13 | from models.experimental import attempt_load, End2End 14 | from utils.activations import Hardswish, SiLU 15 | from utils.general import set_logging, check_img_size 16 | from utils.torch_utils import select_device 17 | from utils.add_nms import RegisterNMS 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--weights', type=str, default='./yolor-csp-c.pt', help='weights path') 22 | parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width 23 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 24 | parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') 25 | parser.add_argument('--dynamic-batch', action='store_true', help='dynamic batch onnx for tensorrt and onnx-runtime') 26 | parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') 27 | parser.add_argument('--end2end', action='store_true', help='export end2end onnx') 28 | parser.add_argument('--max-wh', type=int, default=None, help='None for tensorrt nms, int value for 
onnx-runtime nms') 29 | parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images') 30 | parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS') 31 | parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS') 32 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 33 | parser.add_argument('--simplify', action='store_true', help='simplify onnx model') 34 | parser.add_argument('--include-nms', action='store_true', help='export end2end onnx') 35 | parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export') 36 | parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization') 37 | opt = parser.parse_args() 38 | opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand 39 | opt.dynamic = opt.dynamic and not opt.end2end 40 | opt.dynamic = False if opt.dynamic_batch else opt.dynamic 41 | print(opt) 42 | set_logging() 43 | t = time.time() 44 | 45 | # Load PyTorch model 46 | device = select_device(opt.device) 47 | model = attempt_load(opt.weights, map_location=device) # load FP32 model 48 | labels = model.names 49 | 50 | # Checks 51 | gs = int(max(model.stride)) # grid size (max stride) 52 | opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples 53 | 54 | # Input 55 | img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection 56 | 57 | # Update model 58 | for k, m in model.named_modules(): 59 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 60 | if isinstance(m, models.common.Conv): # assign export-friendly activations 61 | if isinstance(m.act, nn.Hardswish): 62 | m.act = Hardswish() 63 | elif isinstance(m.act, nn.SiLU): 64 | m.act = SiLU() 65 | # elif isinstance(m, models.yolo.Detect): 66 | # m.forward = m.forward_export # assign forward (optional) 67 | model.model[-1].export = not opt.grid # set Detect() layer grid export 68 | y = model(img) # dry run 69 | if opt.include_nms: 70 | model.model[-1].include_nms = True 71 | y = None 72 | 73 | # TorchScript export 74 | try: 75 | print('\nStarting TorchScript export with torch %s...' % torch.__version__) 76 | f = opt.weights.replace('.pt', '.torchscript.pt') # filename 77 | ts = torch.jit.trace(model, img, strict=False) 78 | ts.save(f) 79 | print('TorchScript export success, saved as %s' % f) 80 | except Exception as e: 81 | print('TorchScript export failure: %s' % e) 82 | 83 | # CoreML export 84 | try: 85 | import coremltools as ct 86 | 87 | print('\nStarting CoreML export with coremltools %s...' 
% ct.__version__) 88 | # convert model from torchscript and apply pixel scaling as per detect.py 89 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) 90 | bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None) 91 | if bits < 32: 92 | if sys.platform.lower() == 'darwin': # quantization only supported on macOS 93 | with warnings.catch_warnings(): 94 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning 95 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) 96 | else: 97 | print('quantization only supported on macOS, skipping...') 98 | 99 | f = opt.weights.replace('.pt', '.mlmodel') # filename 100 | ct_model.save(f) 101 | print('CoreML export success, saved as %s' % f) 102 | except Exception as e: 103 | print('CoreML export failure: %s' % e) 104 | 105 | # TorchScript-Lite export 106 | try: 107 | print('\nStarting TorchScript-Lite export with torch %s...' % torch.__version__) 108 | f = opt.weights.replace('.pt', '.torchscript.ptl') # filename 109 | tsl = torch.jit.trace(model, img, strict=False) 110 | tsl = optimize_for_mobile(tsl) 111 | tsl._save_for_lite_interpreter(f) 112 | print('TorchScript-Lite export success, saved as %s' % f) 113 | except Exception as e: 114 | print('TorchScript-Lite export failure: %s' % e) 115 | 116 | # ONNX export 117 | try: 118 | import onnx 119 | 120 | print('\nStarting ONNX export with onnx %s...' % onnx.__version__) 121 | f = opt.weights.replace('.pt', '.onnx') # filename 122 | model.eval() 123 | output_names = ['classes', 'boxes'] if y is None else ['output'] 124 | dynamic_axes = None 125 | if opt.dynamic: 126 | dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) 127 | 'output': {0: 'batch', 2: 'y', 3: 'x'}} 128 | if opt.dynamic_batch: 129 | opt.batch_size = 'batch' 130 | dynamic_axes = { 131 | 'images': { 132 | 0: 'batch', 133 | }, } 134 | if opt.end2end and opt.max_wh is None: 135 | output_axes = { 136 | 'num_dets': {0: 'batch'}, 137 | 'det_boxes': {0: 'batch'}, 138 | 'det_scores': {0: 'batch'}, 139 | 'det_classes': {0: 'batch'}, 140 | } 141 | else: 142 | output_axes = { 143 | 'output': {0: 'batch'}, 144 | } 145 | dynamic_axes.update(output_axes) 146 | if opt.grid: 147 | if opt.end2end: 148 | print('\nStarting export end2end onnx model for %s...' 
% 'TensorRT' if opt.max_wh is None else 'onnxruntime') 149 | model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels)) 150 | if opt.end2end and opt.max_wh is None: 151 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes'] 152 | shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4, 153 | opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all] 154 | else: 155 | output_names = ['output'] 156 | else: 157 | model.model[-1].concat = True 158 | 159 | torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], 160 | output_names=output_names, 161 | dynamic_axes=dynamic_axes) 162 | 163 | # Checks 164 | onnx_model = onnx.load(f) # load onnx model 165 | onnx.checker.check_model(onnx_model) # check onnx model 166 | 167 | if opt.end2end and opt.max_wh is None: 168 | for i in onnx_model.graph.output: 169 | for j in i.type.tensor_type.shape.dim: 170 | j.dim_param = str(shapes.pop(0)) 171 | 172 | # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model 173 | 174 | # # Metadata 175 | # d = {'stride': int(max(model.stride))} 176 | # for k, v in d.items(): 177 | # meta = onnx_model.metadata_props.add() 178 | # meta.key, meta.value = k, str(v) 179 | # onnx.save(onnx_model, f) 180 | 181 | if opt.simplify: 182 | try: 183 | import onnxsim 184 | 185 | print('\nStarting to simplify ONNX...') 186 | onnx_model, check = onnxsim.simplify(onnx_model) 187 | assert check, 'assert check failed' 188 | except Exception as e: 189 | print(f'Simplifier failure: {e}') 190 | 191 | # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model 192 | onnx.save(onnx_model,f) 193 | print('ONNX export success, saved as %s' % f) 194 | 195 | if opt.include_nms: 196 | print('Registering NMS plugin for ONNX...') 197 | mo = RegisterNMS(f) 198 | mo.register_nms() 199 | mo.save(f) 200 | 201 | except Exception as e: 202 | print('ONNX export failure: %s' % e) 203 | 204 | # Finish 205 | print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' 
% (time.time() - t)) 206 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from pathlib import Path 4 | 5 | import cv2 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from numpy import random 9 | 10 | from models.experimental import attempt_load 11 | from utils.datasets import LoadStreams, LoadImages 12 | from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ 13 | scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path 14 | from utils.plots import plot_one_box 15 | from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel 16 | 17 | 18 | def detect(save_img=False): 19 | source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace 20 | save_img = not opt.nosave and not source.endswith('.txt') # save inference images 21 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 22 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 23 | 24 | # Directories 25 | save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run 26 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir 27 | 28 | # Initialize 29 | set_logging() 30 | device = select_device(opt.device) 31 | half = device.type != 'cpu' # half precision only supported on CUDA 32 | 33 | # Load model 34 | model = attempt_load(weights, map_location=device) # load FP32 model 35 | stride = int(model.stride.max()) # model stride 36 | imgsz = check_img_size(imgsz, s=stride) # check img_size 37 | 38 | if False: 39 | model = TracedModel(model, device, opt.img_size) 40 | 41 | if half: 42 | model.half() # to FP16 43 | 44 | # Second-stage classifier 45 | classify = False 46 | if classify: 47 | modelc = load_classifier(name='resnet101', n=2) # initialize 48 | modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval() 49 | 50 | # Set Dataloader 51 | vid_path, vid_writer = None, None 52 | if webcam: 53 | view_img = check_imshow() 54 | cudnn.benchmark = True # set True to speed up constant image size inference 55 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 56 | else: 57 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 58 | 59 | # Get names and colors 60 | names = model.module.names if hasattr(model, 'module') else model.names 61 | #colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 62 | colors = [(0,255,0),(0,0,255),(255,0,0)] 63 | # Run inference 64 | if device.type != 'cpu': 65 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 66 | old_img_w = old_img_h = imgsz 67 | old_img_b = 1 68 | 69 | t0 = time.time() 70 | for path, img, im0s, vid_cap in dataset: 71 | img = torch.from_numpy(img).to(device) 72 | img = img.half() if half else img.float() # uint8 to fp16/32 73 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 74 | if img.ndimension() == 3: 75 | img = img.unsqueeze(0) 76 | 77 | # Warmup 78 | if device.type != 'cpu' and (old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]): 79 | old_img_b = img.shape[0] 80 | old_img_h = img.shape[2] 81 | old_img_w = img.shape[3] 82 | for i in range(3): 83 | model(img, augment=opt.augment)[0] 84 
| 85 | # Inference 86 | t1 = time_synchronized() 87 | with torch.no_grad(): # Calculating gradients would cause a GPU memory leak 88 | pred = model(img, augment=opt.augment)[0] 89 | t2 = time_synchronized() 90 | 91 | # Apply NMS 92 | pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) 93 | t3 = time_synchronized() 94 | 95 | # Apply Classifier 96 | if classify: 97 | pred = apply_classifier(pred, modelc, img, im0s) 98 | 99 | # Process detections 100 | for i, det in enumerate(pred): # detections per image 101 | if webcam: # batch_size >= 1 102 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 103 | else: 104 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 105 | 106 | p = Path(p) # to Path 107 | save_path = str(save_dir / p.name) # img.jpg 108 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 109 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 110 | if len(det): 111 | # Rescale boxes from img_size to im0 size 112 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 113 | 114 | # Print results 115 | for c in det[:, -1].unique(): 116 | n = (det[:, -1] == c).sum() # detections per class 117 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 118 | 119 | # Write results 120 | for *xyxy, conf, cls in reversed(det): 121 | if save_txt: # Write to file 122 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 123 | line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format 124 | with open(txt_path + '.txt', 'a') as f: 125 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 126 | 127 | if save_img or view_img: # Add bbox to image 128 | label = f'{names[int(cls)]} {conf:.2f}' 129 | plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=1) 130 | 131 | # Print time (inference + NMS) 132 | print(f'{s}Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference, ({(1E3 * (t3 - t2)):.1f}ms) NMS') 133 | 134 | # Stream results 135 | if view_img: 136 | cv2.imshow(str(p), im0) 137 | cv2.waitKey(1) # 1 millisecond 138 | 139 | # Save results (image with detections) 140 | if save_img: 141 | if dataset.mode == 'image': 142 | cv2.imwrite(save_path, im0) 143 | print(f" The image with the result is saved in: {save_path}") 144 | else: # 'video' or 'stream' 145 | if vid_path != save_path: # new video 146 | vid_path = save_path 147 | if isinstance(vid_writer, cv2.VideoWriter): 148 | vid_writer.release() # release previous video writer 149 | if vid_cap: # video 150 | fps = vid_cap.get(cv2.CAP_PROP_FPS) 151 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 152 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 153 | else: # stream 154 | fps, w, h = 30, im0.shape[1], im0.shape[0] 155 | save_path += '.mp4' 156 | vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) 157 | vid_writer.write(im0) 158 | 159 | if save_txt or save_img: 160 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' 161 | #print(f"Results saved to {save_dir}{s}") 162 | 163 | print(f'Done. 
({time.time() - t0:.3f}s)') 164 | 165 | 166 | if __name__ == '__main__': 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument('--weights', nargs='+', type=str, default='runs/train/exp/weights/best.pt', help='model.pt path(s)') 169 | parser.add_argument('--source', type=str, default=r'data/cbc/valdata', help='source') # file/folder, 0 for webcam 170 | parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') 171 | parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') 172 | parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') 173 | parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 174 | parser.add_argument('--view-img', action='store_true', help='display results') 175 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') 176 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') 177 | parser.add_argument('--nosave', action='store_true', help='do not save images/videos') 178 | parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') 179 | parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') 180 | parser.add_argument('--augment', action='store_true', help='augmented inference') 181 | parser.add_argument('--update', action='store_true', help='update all models') 182 | parser.add_argument('--project', default='runs/detect', help='save results to project/name') 183 | parser.add_argument('--name', default='exp', help='save results to project/name') 184 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 185 | parser.add_argument('--no-trace', action='store_true', help='don`t trace model') 186 | opt = parser.parse_args() 187 | print(opt) 188 | #check_requirements(exclude=('pycocotools', 'thop')) 189 | 190 | with torch.no_grad(): 191 | if opt.update: # update all models (to fix SourceChangeWarning) 192 | for opt.weights in ['yolov7.pt']: 193 | detect() 194 | strip_optimizer(opt.weights) 195 | else: 196 | detect() 197 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Model validation metrics 2 | 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | 9 | from . import general 10 | 11 | 12 | def fitness(x): 13 | # Model fitness as a weighted combination of metrics 14 | w = [0.0, 0.0, 1.0, 0.0] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] 15 | return (x[:, :4] * w).sum(1) 16 | 17 | 18 | def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=()): 19 | """ Compute the average precision, given the recall and precision curves. 20 | Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. 21 | # Arguments 22 | tp: True positives (nparray, nx1 or nx10). 23 | conf: Objectness value from 0-1 (nparray). 24 | pred_cls: Predicted object classes (nparray). 25 | target_cls: True object classes (nparray). 26 | plot: Plot precision-recall curve at mAP@0.5 27 | save_dir: Plot save directory 28 | # Returns 29 | The average precision as computed in py-faster-rcnn. 
30 | """ 31 | 32 | # Sort by objectness 33 | i = np.argsort(-conf) 34 | tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] 35 | 36 | # Find unique classes 37 | unique_classes = np.unique(target_cls) 38 | nc = unique_classes.shape[0] # number of classes, number of detections 39 | 40 | # Create Precision-Recall curve and compute AP for each class 41 | px, py = np.linspace(0, 1, 1000), [] # for plotting 42 | ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) 43 | for ci, c in enumerate(unique_classes): 44 | i = pred_cls == c 45 | n_l = (target_cls == c).sum() # number of labels 46 | n_p = i.sum() # number of predictions 47 | 48 | if n_p == 0 or n_l == 0: 49 | continue 50 | else: 51 | # Accumulate FPs and TPs 52 | fpc = (1 - tp[i]).cumsum(0) 53 | tpc = tp[i].cumsum(0) 54 | 55 | # Recall 56 | recall = tpc / (n_l + 1e-16) # recall curve 57 | r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases 58 | 59 | # Precision 60 | precision = tpc / (tpc + fpc) # precision curve 61 | p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score 62 | 63 | # AP from recall-precision curve 64 | for j in range(tp.shape[1]): 65 | ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j], v5_metric=v5_metric) 66 | if plot and j == 0: 67 | py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 68 | 69 | # Compute F1 (harmonic mean of precision and recall) 70 | f1 = 2 * p * r / (p + r + 1e-16) 71 | if plot: 72 | plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) 73 | plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') 74 | plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') 75 | plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') 76 | 77 | i = f1.mean(0).argmax() # max F1 index 78 | return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') 79 | 80 | 81 | def compute_ap(recall, precision, v5_metric=False): 82 | """ Compute the average precision, given the recall and precision curves 83 | # Arguments 84 | recall: The recall curve (list) 85 | precision: The precision curve (list) 86 | v5_metric: Assume maximum recall to be 1.0, as in YOLOv5, MMDetetion etc. 87 | # Returns 88 | Average precision, precision curve, recall curve 89 | """ 90 | 91 | # Append sentinel values to beginning and end 92 | if v5_metric: # New YOLOv5 metric, same as MMDetection and Detectron2 repositories 93 | mrec = np.concatenate(([0.], recall, [1.0])) 94 | else: # Old YOLOv5 metric, i.e. 
default YOLOv7 metric 95 | mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01])) 96 | mpre = np.concatenate(([1.], precision, [0.])) 97 | 98 | # Compute the precision envelope 99 | mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) 100 | 101 | # Integrate area under curve 102 | method = 'interp' # methods: 'continuous', 'interp' 103 | if method == 'interp': 104 | x = np.linspace(0, 1, 101) # 101-point interp (COCO) 105 | ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate 106 | else: # 'continuous' 107 | i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes 108 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve 109 | 110 | return ap, mpre, mrec 111 | 112 | 113 | class ConfusionMatrix: 114 | # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix 115 | def __init__(self, nc, conf=0.25, iou_thres=0.45): 116 | self.matrix = np.zeros((nc + 1, nc + 1)) 117 | self.nc = nc # number of classes 118 | self.conf = conf 119 | self.iou_thres = iou_thres 120 | 121 | def process_batch(self, detections, labels): 122 | """ 123 | Return intersection-over-union (Jaccard index) of boxes. 124 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 125 | Arguments: 126 | detections (Array[N, 6]), x1, y1, x2, y2, conf, class 127 | labels (Array[M, 5]), class, x1, y1, x2, y2 128 | Returns: 129 | None, updates confusion matrix accordingly 130 | """ 131 | detections = detections[detections[:, 4] > self.conf] 132 | gt_classes = labels[:, 0].int() 133 | detection_classes = detections[:, 5].int() 134 | iou = general.box_iou(labels[:, 1:], detections[:, :4]) 135 | 136 | x = torch.where(iou > self.iou_thres) 137 | if x[0].shape[0]: 138 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() 139 | if x[0].shape[0] > 1: 140 | matches = matches[matches[:, 2].argsort()[::-1]] 141 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]] 142 | matches = matches[matches[:, 2].argsort()[::-1]] 143 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]] 144 | else: 145 | matches = np.zeros((0, 3)) 146 | 147 | n = matches.shape[0] > 0 148 | m0, m1, _ = matches.transpose().astype(np.int16) 149 | for i, gc in enumerate(gt_classes): 150 | j = m0 == i 151 | if n and sum(j) == 1: 152 | self.matrix[gc, detection_classes[m1[j]]] += 1 # correct 153 | else: 154 | self.matrix[self.nc, gc] += 1 # background FP 155 | 156 | if n: 157 | for i, dc in enumerate(detection_classes): 158 | if not any(m1 == i): 159 | self.matrix[dc, self.nc] += 1 # background FN 160 | 161 | def matrix(self): 162 | return self.matrix 163 | 164 | def plot(self, save_dir='', names=()): 165 | try: 166 | import seaborn as sn 167 | 168 | array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize 169 | array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) 170 | 171 | fig = plt.figure(figsize=(12, 9), tight_layout=True) 172 | sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size 173 | labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels 174 | sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, 175 | xticklabels=names + ['background FP'] if labels else "auto", 176 | yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) 177 | fig.axes[0].set_xlabel('True') 178 | fig.axes[0].set_ylabel('Predicted') 179 | fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) 180 | 
except Exception as e: 181 | pass 182 | 183 | def print(self): 184 | for i in range(self.nc + 1): 185 | print(' '.join(map(str, self.matrix[i]))) 186 | 187 | 188 | # Plots ---------------------------------------------------------------------------------------------------------------- 189 | 190 | def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): 191 | # Precision-recall curve 192 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 193 | py = np.stack(py, axis=1) 194 | 195 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 196 | for i, y in enumerate(py.T): 197 | ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) 198 | else: 199 | ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) 200 | 201 | ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) 202 | ax.set_xlabel('Recall') 203 | ax.set_ylabel('Precision') 204 | ax.set_xlim(0, 1) 205 | ax.set_ylim(0, 1) 206 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 207 | fig.savefig(Path(save_dir), dpi=250) 208 | 209 | 210 | def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): 211 | # Metric-confidence curve 212 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 213 | 214 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 215 | for i, y in enumerate(py): 216 | ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) 217 | else: 218 | ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) 219 | 220 | y = py.mean(0) 221 | ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') 222 | ax.set_xlabel(xlabel) 223 | ax.set_ylabel(ylabel) 224 | ax.set_xlim(0, 1) 225 | ax.set_ylim(0, 1) 226 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 227 | fig.savefig(Path(save_dir), dpi=250) 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official CST-YOLO 2 |
3 | 4 | Build 5 | 6 | Build 7 |
8 | 9 | ## Description 10 | This is the source code for the paper titled "CST-YOLO: A Novel Method for Blood Cell Detection Based on Improved YOLOv7 and CNN-Swin Transformer" accepted by and [presented orally](https://cmsworkshops.com/ICIP2024/view_paper.php?PaperNum=2542&bare=1) at the 2024 IEEE International Conference on Image Processing ([ICIP 2024](https://2024.ieeeicip.org)), of which I am the first author. This paper is available to download from [IEEE Xplore](https://ieeexplore.ieee.org/document/10647618) or [arXiv](https://arxiv.org/abs/2306.14590). 11 | 12 | ## Model 13 | The CNN-Swin Transformer You Only Look Once (CST-YOLO) model configuration (i.e., network construction) file is cst-yolo.yaml in the directory [./cfg/training/](https://github.com/mkang315/CST-YOLO/tree/main/cfg/training). 14 | 15 | #### Installation 16 | Install the dependencies listed in requirements.txt in a Python >= 3.8 environment, including Torch <= 1.7.1 and CUDA <= 11.1: 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | #### Training 22 | 23 | The hyperparameter setting file is hyp.scratch.p5.yaml in the directory [./data/](https://github.com/mkang315/CST-YOLO/tree/main/data). 24 | 25 | ###### Single GPU training 26 | ``` 27 | python train.py --workers 8 --device 0 --batch-size 32 --data data/cbc.yaml --img 640 640 --cfg cfg/training/cst-yolo.yaml --weights '' --name cst-yolo --hyp data/hyp.scratch.p5.yaml 28 | ``` 29 | 30 | ###### Multiple GPU training 31 | ``` 32 | python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch-size 128 --data data/cbc.yaml --img 640 640 --cfg cfg/training/cst-yolo.yaml --weights '' --name cst-yolo --hyp data/hyp.scratch.p5.yaml 33 | ``` 34 | 35 | #### Testing 36 | 37 | ``` 38 | python test.py --data data/cbc.yaml --img 640 --batch 32 --conf 0.001 --iou 0.65 --device 0 --weights runs/train/exp/weights/best.pt --name val 39 | ``` 40 | 41 | ## Evaluation 42 | We trained and evaluated CST-YOLO on three blood cell detection datasets: [Blood Cell Count and Detection (BCCD)](https://github.com/Shenggan/BCCD_Dataset), [Complete Blood Count (CBC)](https://github.com/MahmudulAlam/Complete-Blood-Cell-Count-Dataset), and [Blood Cell Detection (BCD)](https://www.kaggle.com/datasets/adhoppin/blood-cell-detection-datatset). In the CBC dataset, the 60 validation samples are duplicates of samples from the training set. Each image includes three types of blood cells: Red Blood Cells (RBCs), White Blood Cells (WBCs), and platelets. 43 | 44 | ## Referencing Guide 45 | Please cite the paper if you use this repository. Here is a guide to referencing this work in various styles for formatting your references:
46 | 47 | > Plain Text
48 | - **IEEE Reference Style**
49 | M. Kang, C.-M. Ting, F. F. Ting, and R. C.-W. Phan, "Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer," in *Proc. IEEE Int. Conf. Image Process. (ICIP)*, Abu Dhabi, UAE, Oct. 27–30, 2024, pp. 3024–3029.
50 | **NOTE:** City of Conf., Abbrev. State, Country, Month & day(s) are optional. 51 | 52 | - **IEEE Full Name Reference Style**
53 | Ming Kang, Chee-Ming Ting, Fung Fung Ting, and Raphaël C.-W. Phan. Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer. In *ICIP*, pages 3024–3029, 2024.
54 | **NOTE:** This is a modification of the standard IEEE Reference Style and is used by most IEEE/CVF conferences, including **CVPR**, **ICCV**, and **WACV**, to render first names in the bibliography as "Firstname Lastname" rather than "F. Lastname" or "Lastname, F.".
55 |  - **IJCAI Full Name-Year Variation**
56 | \[Kang *et al.*, 2024\] Ming Kang, Chee-Ming Ting, Fung Fung Ting, and Raphaël C.-W. Phan. Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer. In *Proceedings of the 2024 IEEE International Conference on Image Processing*, pages 3024–3029, Piscataway, NJ, October 2024. IEEE.
57 |  - **ACL Full Name-Year Variation**
58 | Ming Kang, Chee-Ming Ting, Fung Fung Ting, and Raphaël C.-W. Phan. 2024. Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer. In *Proceedings of the 2024 IEEE International Conference on Image Processing*, pages 3024–3029, Piscataway, NJ. IEEE.
59 | 60 | - **Nature Referencing Style**
61 | Kang, M., Ting, C.-M., Ting, F. F. & Phan, R. C.-W. CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In *2024 IEEE International Conference on Image Processing (ICIP)* 3024–3029 (IEEE, 2024).
62 | 63 | - **Springer Reference Style**
64 | Kang, M., Ting, C.-M., Ting, F.F., Phan, R.C.-W.: CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In: 2024 IEEE International Conference on Image Processing (ICIP), pp. 3024–3029. IEEE, Piscataway (2024)
65 | **NOTE:** *ECCV* and *MICCAI* conference proceedings are part of the Lecture Notes in Computer Science (LNCS) book series, in which Springer's format for bibliographical references is strictly enforced. 66 | 67 | - **Elsevier Numbered Style**
68 | M. Kang, C.-M. Ting, F.F. Ting, R.C.-W. Phan, CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer, in: Proceedings of the IEEE International Conference on Image Processing (ICIP), 2024, pp. 3024–3029.
69 | **NOTE:** Day(s) Month Year, City, Abbrev. State, Country of Conference, Publisher, and Place of Publication are optional and omitted here. 70 | 71 | - **Elsevier Name–Date (Harvard) Style**
72 | Kang, M., Ting, C.-M., Ting, F.F., Phan, R.C.-W., 2024. CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In: Proceedings of the IEEE International Conference on Image Processing (ICIP), 27–30 October 2024, Abu Dhabi, UAE. IEEE, Piscataway, New Jersey, USA, pp. 3024–3029.
73 | **NOTE:** Day(s) Month Year, City, Abbrev. State, Country of Conference, Publisher, and Place of Publication are optional. 74 | 75 | - **Elsevier Vancouver Style**
76 | Kang M, Ting C-M, Ting FF, Phan RC-W. CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In: Proceedings of the IEEE International Conference on Image Processing (ICIP); 2024 Oct 27–30; Abu Dhabi, UAE. Piscataway: IEEE; 2024. p. 3024–9.
77 | 78 | - **Elsevier Embellished Vancouver Style**
79 | Kang M, Ting C-M, Ting FF, Phan RC-W. CST-YOLO: a novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In: *Proceedings of the IEEE International Conference on Image Processing (ICIP)*; 2024 Oct 27–30; Abu Dhabi, UAE. Piscataway: IEEE; 2024. p. 3024–9.
80 | 81 | - **APA7 (Author–Date) Style**
82 | Kang, M., Ting, C.-M., Ting, F. F., & Phan, R. C.-W. (2024). CST-YOLO: A novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In *Proceedings of the 2024 IEEE International Conference on Image Processing (ICIP)* (pp. 3024–3029). IEEE. https://doi.org/10.1109/ICIP51287.2024.10647618
83 |  - **AAAI (Author–Year) Variation**
84 | Kang, M.; Ting, C.-M.; Ting, F. F.; and Phan, R. C.-W. 2024. CST-YOLO: A Novel Method for Blood Cell Detection Based on Improved YOLOv7 and CNN-Swin Transformer. In *Proceedings of the 2024 IEEE International Conference on Image Processing (ICIP)*, 3024–3029. Piscataway, NJ: IEEE.
85 |  - **ICML (Author–Year) Variation**
86 | Kang, M., Ting, C.-M., Ting, F. F., and Phan, R. C.-W. CST-YOLO: A novel method for blood cell detection based on improved YOLOv7 and CNN-Swin Transformer. In *Proceedings of the 2024 IEEE International Conference on Image Processing (ICIP)*, pp. 3024–3029, Piscataway, NJ, 2024. IEEE.
87 | **NOTE:** For **NeurIPS** and **ICLR**, any reference/citation style is acceptable as long as it is used consistently. The sample references in *Formatting Instructions For NeurIPS* largely follow the APA7 (author–date) style, and those in *Formatting Instructions For ICLR Conference Submissions* are similar to the IJCAI full name–year variation. 88 | 89 | > BibTeX Format
90 | ``` 91 | \begin{thebibliography}{1} 92 | \bibitem{Kang24Cstyolo} M. Kang, C.-M. Ting, F. F. Ting, and R. C.-W. Phan, "Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer," in \emph{Proc. IEEE Int. Conf. Image Process. (ICIP)}, Abu Dhabi, UAE, Oct. 27--30, 2024, pp. 3024--3029. 93 | \end{thebibliography} 94 | ``` 95 | ``` 96 | @inproceedings{Kang24Cstyolo, 97 | author = "Ming Kang and Chee-Ming Ting and Fung Fung Ting and Rapha{\"e}l C.-W. Phan", 98 | title = "Cst-yolo: A novel method for blood cell detection based on improved yolov7 and cnn-swin transformer", 99 | booktitle = "Proc. IEEE Int. Conf. Image Process. (ICIP)", 100 | % booktitle = ICIP, %% IEEE Full Name Reference Style 101 | address = "Abu Dhabi, UAE, Oct. 27--30", 102 | pages = "3024--3029", 103 | year = "2024" 104 | } 105 | ``` 106 | ``` 107 | @inproceedings{Kang24Cstyolo, 108 | author = "Kang, Ming and Ting, Chee-Ming and Ting, Fung Fung and Phan, Rapha{\"e}l C.-W.", 109 | title = "{CST-YOLO}: a novel method for blood cell detection based on improved {YOLO}v7 and {CNN}-{S}win {T}ransformer", 110 | editor = "", 111 | booktitle = "2024 IEEE International Conference on Image Processing (ICIP)", 112 | series = "", 113 | volume = "", 114 | pages = "3024--3029", 115 | publisher = "IEEE", 116 | address = "Piscataway", 117 | year = "2024", 118 | doi = "10.1109/ICIP51287.2024.10647618", 119 | url = "https://doi.org/10.1109/ICIP51287.2024.10647618" 120 | } 121 | ``` 122 | **NOTE:** Please remove optional *BibTeX* fields/tags such as `series`, `volume`, `address`, and `url` if the *LaTeX* compiler produces an error. Author names may be manually modified if they are not automatically abbreviated by the compiler under the control of the bibliography/reference style (i.e., .bst) file. The *BibTeX* citation key may be `bib1`, `b1`, or `ref1` when references appear in the numbered style in which they are cited. The quotation mark pair `""` around a field value may be replaced by braces `{}`, whereas braces `{}` inside the *BibTeX* `title` field keep the original letter casing (lower/uppercase, sentence or capitalized case) of the enclosed text unchanged when using Springer Nature bibliography style files such as sn-nature.bst. 123 | 124 | ## License 125 | CST-YOLO is released under the GNU General Public License v3.0. Please see the [LICENSE](https://github.com/mkang315/CST-YOLO/blob/main/LICENSE) file for more information. 126 | 127 | ## Copyright Notice 128 | Many of the utility codes in our project are based on the code of the [YOLOv7](https://github.com/WongKinYiu/yolov7) and [Swin Transformer](https://github.com/microsoft/Swin-Transformer) repositories. 
129 | -------------------------------------------------------------------------------- /models/experimental.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.common import Conv, DWConv 7 | from utils.google_utils import attempt_download 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super(CrossConv, self).__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class Sum(nn.Module): 25 | # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 26 | def __init__(self, n, weight=False): # n: number of inputs 27 | super(Sum, self).__init__() 28 | self.weight = weight # apply weights boolean 29 | self.iter = range(n - 1) # iter object 30 | if weight: 31 | self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights 32 | 33 | def forward(self, x): 34 | y = x[0] # no weight 35 | if self.weight: 36 | w = torch.sigmoid(self.w) * 2 37 | for i in self.iter: 38 | y = y + x[i + 1] * w[i] 39 | else: 40 | for i in self.iter: 41 | y = y + x[i + 1] 42 | return y 43 | 44 | 45 | class MixConv2d(nn.Module): 46 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 47 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 48 | super(MixConv2d, self).__init__() 49 | groups = len(k) 50 | if equal_ch: # equal c_ per group 51 | i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices 52 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 53 | else: # equal weight.numel() per group 54 | b = [c2] + [0] * groups 55 | a = np.eye(groups + 1, groups, k=-1) 56 | a -= np.roll(a, 1, axis=1) 57 | a *= np.array(k) ** 2 58 | a[0] = 1 59 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 60 | 61 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 62 | self.bn = nn.BatchNorm2d(c2) 63 | self.act = nn.LeakyReLU(0.1, inplace=True) 64 | 65 | def forward(self, x): 66 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 67 | 68 | 69 | class Ensemble(nn.ModuleList): 70 | # Ensemble of models 71 | def __init__(self): 72 | super(Ensemble, self).__init__() 73 | 74 | def forward(self, x, augment=False): 75 | y = [] 76 | for module in self: 77 | y.append(module(x, augment)[0]) 78 | # y = torch.stack(y).max(0)[0] # max ensemble 79 | # y = torch.stack(y).mean(0) # mean ensemble 80 | y = torch.cat(y, 1) # nms ensemble 81 | return y, None # inference, train output 82 | 83 | 84 | 85 | 86 | 87 | class ORT_NMS(torch.autograd.Function): 88 | '''ONNX-Runtime NMS operation''' 89 | @staticmethod 90 | def forward(ctx, 91 | boxes, 92 | scores, 93 | max_output_boxes_per_class=torch.tensor([100]), 94 | iou_threshold=torch.tensor([0.45]), 95 | score_threshold=torch.tensor([0.25])): 96 | device = boxes.device 97 | batch = scores.shape[0] 98 | num_det = random.randint(0, 100) 99 | batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device) 100 | idxs = torch.arange(100, 100 + num_det).to(device) 101 | zeros = 
torch.zeros((num_det,), dtype=torch.int64).to(device) 102 | selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous() 103 | selected_indices = selected_indices.to(torch.int64) 104 | return selected_indices 105 | 106 | @staticmethod 107 | def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): 108 | return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) 109 | 110 | 111 | class TRT_NMS(torch.autograd.Function): 112 | '''TensorRT NMS operation''' 113 | @staticmethod 114 | def forward( 115 | ctx, 116 | boxes, 117 | scores, 118 | background_class=-1, 119 | box_coding=1, 120 | iou_threshold=0.45, 121 | max_output_boxes=100, 122 | plugin_version="1", 123 | score_activation=0, 124 | score_threshold=0.25, 125 | ): 126 | batch_size, num_boxes, num_classes = scores.shape 127 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 128 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 129 | det_scores = torch.randn(batch_size, max_output_boxes) 130 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 131 | return num_det, det_boxes, det_scores, det_classes 132 | 133 | @staticmethod 134 | def symbolic(g, 135 | boxes, 136 | scores, 137 | background_class=-1, 138 | box_coding=1, 139 | iou_threshold=0.45, 140 | max_output_boxes=100, 141 | plugin_version="1", 142 | score_activation=0, 143 | score_threshold=0.25): 144 | out = g.op("TRT::EfficientNMS_TRT", 145 | boxes, 146 | scores, 147 | background_class_i=background_class, 148 | box_coding_i=box_coding, 149 | iou_threshold_f=iou_threshold, 150 | max_output_boxes_i=max_output_boxes, 151 | plugin_version_s=plugin_version, 152 | score_activation_i=score_activation, 153 | score_threshold_f=score_threshold, 154 | outputs=4) 155 | nums, boxes, scores, classes = out 156 | return nums, boxes, scores, classes 157 | 158 | 159 | class ONNX_ORT(nn.Module): 160 | '''onnx module with ONNX-Runtime NMS operation.''' 161 | def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80): 162 | super().__init__() 163 | self.device = device if device else torch.device("cpu") 164 | self.max_obj = torch.tensor([max_obj]).to(device) 165 | self.iou_threshold = torch.tensor([iou_thres]).to(device) 166 | self.score_threshold = torch.tensor([score_thres]).to(device) 167 | self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic 168 | self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], 169 | dtype=torch.float32, 170 | device=self.device) 171 | self.n_classes=n_classes 172 | 173 | def forward(self, x): 174 | boxes = x[:, :, :4] 175 | conf = x[:, :, 4:5] 176 | scores = x[:, :, 5:] 177 | if self.n_classes == 1: 178 | scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, 179 | # so there is no need to multiplicate. 
180 | else: 181 | scores *= conf # conf = obj_conf * cls_conf 182 | boxes @= self.convert_matrix 183 | max_score, category_id = scores.max(2, keepdim=True) 184 | dis = category_id.float() * self.max_wh 185 | nmsbox = boxes + dis 186 | max_score_tp = max_score.transpose(1, 2).contiguous() 187 | selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold) 188 | X, Y = selected_indices[:, 0], selected_indices[:, 2] 189 | selected_boxes = boxes[X, Y, :] 190 | selected_categories = category_id[X, Y, :].float() 191 | selected_scores = max_score[X, Y, :] 192 | X = X.unsqueeze(1).float() 193 | return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1) 194 | 195 | class ONNX_TRT(nn.Module): 196 | '''onnx module with TensorRT NMS operation.''' 197 | def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80): 198 | super().__init__() 199 | assert max_wh is None 200 | self.device = device if device else torch.device('cpu') 201 | self.background_class = -1, 202 | self.box_coding = 1, 203 | self.iou_threshold = iou_thres 204 | self.max_obj = max_obj 205 | self.plugin_version = '1' 206 | self.score_activation = 0 207 | self.score_threshold = score_thres 208 | self.n_classes=n_classes 209 | 210 | def forward(self, x): 211 | boxes = x[:, :, :4] 212 | conf = x[:, :, 4:5] 213 | scores = x[:, :, 5:] 214 | if self.n_classes == 1: 215 | scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, 216 | # so there is no need to multiplicate. 217 | else: 218 | scores *= conf # conf = obj_conf * cls_conf 219 | num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding, 220 | self.iou_threshold, self.max_obj, 221 | self.plugin_version, self.score_activation, 222 | self.score_threshold) 223 | return num_det, det_boxes, det_scores, det_classes 224 | 225 | 226 | class End2End(nn.Module): 227 | '''export onnx or tensorrt model with NMS operation.''' 228 | def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80): 229 | super().__init__() 230 | device = device if device else torch.device('cpu') 231 | assert isinstance(max_wh,(int)) or max_wh is None 232 | self.model = model.to(device) 233 | self.model.model[-1].end2end = True 234 | self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT 235 | self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes) 236 | self.end2end.eval() 237 | 238 | def forward(self, x): 239 | x = self.model(x) 240 | x = self.end2end(x) 241 | return x 242 | 243 | 244 | 245 | 246 | 247 | def attempt_load(weights, map_location=None): 248 | # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a 249 | model = Ensemble() 250 | for w in weights if isinstance(weights, list) else [weights]: 251 | attempt_download(w) 252 | ckpt = torch.load(w, map_location=map_location) # load 253 | model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model 254 | 255 | # Compatibility updates 256 | for m in model.modules(): 257 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 258 | m.inplace = True # pytorch 1.7.0 compatibility 259 | elif type(m) is nn.Upsample: 260 | m.recompute_scale_factor = None # torch 1.11.0 compatibility 261 | elif type(m) is Conv: 262 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 263 | 264 | if 
len(model) == 1: 265 | return model[-1] # return model 266 | else: 267 | print('Ensemble created with %s\n' % weights) 268 | for k in ['names', 'stride']: 269 | setattr(model, k, getattr(model[-1], k)) 270 | return model # return ensemble 271 | 272 | 273 | -------------------------------------------------------------------------------- /models/SwinTransformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import warnings 4 | from copy import copy 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import requests 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from PIL import Image 14 | from typing import Optional 15 | from torch.cuda import amp 16 | def drop_path_f(x, drop_prob: float = 0., training: bool = False): 17 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 18 | 19 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 20 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 21 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 22 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 23 | 'survival rate' as the argument. 24 | 25 | """ 26 | if drop_prob == 0. or not training: 27 | return x 28 | keep_prob = 1 - drop_prob 29 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 30 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 31 | random_tensor.floor_() # binarize 32 | output = x.div(keep_prob) * random_tensor 33 | return output 34 | 35 | 36 | class DropPath(nn.Module): 37 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 38 | """ 39 | def __init__(self, drop_prob=None): 40 | super(DropPath, self).__init__() 41 | self.drop_prob = drop_prob 42 | 43 | def forward(self, x): 44 | return drop_path_f(x, self.drop_prob, self.training) 45 | 46 | def window_partition(x, window_size: int): 47 | """ 48 | # Feature map is divided into one without overlapping windows in term of window_size. 49 | Args: 50 | x: (B, H, W, C) 51 | window_size (int): window size(M) 52 | 53 | Returns: 54 | windows: (num_windows*B, window_size, window_size, C) 55 | """ 56 | B, H, W, C = x.shape 57 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 58 | # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C] 59 | # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C] 60 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 61 | return windows 62 | 63 | def window_reverse(windows, window_size: int, H: int, W: int): 64 | """ 65 | # Every window reverted to one feature map. 
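    # This is the inverse of window_partition: windows of shape
    # (num_windows*B, window_size, window_size, C) are stitched back into one (B, H, W, C) feature map.
    #
    # Shape round-trip sketch (assumes window_size divides H and W; tensors are illustrative only):
    #   x = torch.randn(2, 56, 56, 96)
    #   windows = window_partition(x, 7)        # -> (2*64, 7, 7, 96)
    #   y = window_reverse(windows, 7, 56, 56)  # -> (2, 56, 56, 96)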
66 | Args: 67 | windows: (num_windows*B, window_size, window_size, C) 68 | window_size (int): Window size(M) 69 | H (int): Height of image 70 | W (int): Width of image 71 | 72 | Returns: 73 | x: (B, H, W, C) 74 | """ 75 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 76 | # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C] 77 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 78 | # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C] 79 | # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C] 80 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 81 | return x 82 | 83 | class Mlp(nn.Module): 84 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 85 | """ 86 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 87 | super().__init__() 88 | out_features = out_features or in_features 89 | hidden_features = hidden_features or in_features 90 | 91 | self.fc1 = nn.Linear(in_features, hidden_features) 92 | self.act = act_layer() 93 | self.drop1 = nn.Dropout(drop) 94 | self.fc2 = nn.Linear(hidden_features, out_features) 95 | self.drop2 = nn.Dropout(drop) 96 | 97 | def forward(self, x): 98 | x = self.fc1(x) 99 | x = self.act(x) 100 | x = self.drop1(x) 101 | x = self.fc2(x) 102 | x = self.drop2(x) 103 | return x 104 | 105 | class WindowAttention(nn.Module): 106 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 107 | It supports both of shifted and non-shifted window. 108 | 109 | Args: 110 | dim (int): Number of input channels. 111 | window_size (tuple[int]): The height and width of the window. 112 | num_heads (int): Number of attention heads. 113 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 114 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 115 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 116 | """ 117 | 118 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): 119 | 120 | super().__init__() 121 | self.dim = dim 122 | self.window_size = window_size # [Mh, Mw] 123 | self.num_heads = num_heads 124 | head_dim = dim // num_heads 125 | self.scale = head_dim ** -0.5 126 | 127 | # define a parameter table of relative position bias 128 | self.relative_position_bias_table = nn.Parameter( 129 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH] 130 | 131 | # get pair-wise relative position index for each token inside the window 132 | coords_h = torch.arange(self.window_size[0]) 133 | coords_w = torch.arange(self.window_size[1]) 134 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # [2, Mh, Mw] 135 | coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw] 136 | # [2, Mh*Mw, 1] - [2, 1, Mh*Mw] 137 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw] 138 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2] 139 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 140 | relative_coords[:, :, 1] += self.window_size[1] - 1 141 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 142 | relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw] 143 | self.register_buffer("relative_position_index", relative_position_index) 144 | 145 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 146 | self.attn_drop = nn.Dropout(attn_drop) 147 | self.proj = nn.Linear(dim, dim) 148 | self.proj_drop = nn.Dropout(proj_drop) 149 | 150 | nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) 151 | self.softmax = nn.Softmax(dim=-1) 152 | 153 | def forward(self, x, mask: Optional[torch.Tensor] = None): 154 | """ 155 | Args: 156 | x: input features with shape of (num_windows*B, Mh*Mw, C) 157 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 158 | """ 159 | # [batch_size*num_windows, Mh*Mw, total_embed_dim] 160 | B_, N, C = x.shape 161 | # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim] 162 | # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head] 163 | # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] 164 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 165 | # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] 166 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 167 | 168 | # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw] 169 | # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw] 170 | q = q * self.scale 171 | attn = (q @ k.transpose(-2, -1)) 172 | 173 | # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH] 174 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 175 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 176 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw] 177 | attn = attn + relative_position_bias.unsqueeze(0) 178 | 179 | if mask is not None: 180 | # mask: [nW, Mh*Mw, Mh*Mw] 181 | nW = mask.shape[0] # num_windows 182 | # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw] 183 | # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw] 184 | attn = attn.view(B_ // nW, nW, 
self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 185 | attn = attn.view(-1, self.num_heads, N, N) 186 | attn = self.softmax(attn) 187 | else: 188 | attn = self.softmax(attn) 189 | 190 | attn = self.attn_drop(attn) 191 | 192 | # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] 193 | # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head] 194 | # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim] 195 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 196 | x = self.proj(x) 197 | x = self.proj_drop(x) 198 | return x 199 | 200 | class SwinTransformerLayer(nn.Module): 201 | # Vision Transformer (ViT) 202 | # https://arxiv.org/abs/2010.11929 203 | # https://github.com/google-research/vision_transformer 204 | def __init__(self, c, num_heads, window_size=7, shift_size=0, 205 | mlp_ratio = 4, qkv_bias=False, drop=0., attn_drop=0., drop_path=0., 206 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 207 | super().__init__() 208 | if num_heads > 10: 209 | drop_path = 0.1 210 | self.window_size = window_size 211 | self.shift_size = shift_size 212 | self.mlp_ratio = mlp_ratio 213 | 214 | self.norm1 = norm_layer(c) 215 | self.attn = WindowAttention( 216 | c, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, 217 | attn_drop=attn_drop, proj_drop=drop) 218 | 219 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 220 | self.norm2 = norm_layer(c) 221 | mlp_hidden_dim = int(c * mlp_ratio) 222 | self.mlp = Mlp(in_features=c, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 223 | 224 | def create_mask(self, x, H, W): 225 | # calculate attention mask for SW-MSA 226 | # Hp and Wp are integer multiples of window_size 227 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 228 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 229 | # The channels are arranged in the same order as the feature map to facilitate subsequent window_partitions 230 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1] 231 | h_slices = ( (0, -self.window_size), 232 | slice(-self.window_size, -self.shift_size), 233 | slice(-self.shift_size, None)) 234 | w_slices = (slice(0, -self.window_size), 235 | slice(-self.window_size, -self.shift_size), 236 | slice(-self.shift_size, None)) 237 | cnt = 0 238 | for h in h_slices: 239 | for w in w_slices: 240 | img_mask[:, h, w, :] = cnt 241 | cnt += 1 242 | 243 | mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1] 244 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw] 245 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] 246 | # [nW, Mh*Mw, Mh*Mw] 247 | attn_mask = attn_mask.masked_fill(attn_mask != 0, torch.tensor(-100.0)).masked_fill(attn_mask == 0, torch.tensor(0.0)) 248 | return attn_mask 249 | 250 | def forward(self, x): 251 | b, c, w, h = x.shape 252 | x = x.permute(0, 3, 2, 1).contiguous() # [b,h,w,c] 253 | 254 | attn_mask = self.create_mask(x, h, w) # [nW, Mh*Mw, Mh*Mw] 255 | shortcut = x 256 | x = self.norm1(x) 257 | 258 | pad_l = pad_t = 0 259 | pad_r = (self.window_size - w % self.window_size) % self.window_size 260 | pad_b = (self.window_size - h % self.window_size) % self.window_size 261 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 262 | _, hp, wp, _ = x.shape 263 | 264 | if self.shift_size > 0: 265 | # print(f"shift size: {self.shift_size}") 266 | shifted_x = torch.roll(x, 
shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 267 | else: 268 | shifted_x = x 269 | attn_mask = None 270 | 271 | x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C] 272 | x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # [nW*B, Mh*Mw, C] 273 | 274 | attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C] 275 | 276 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) # [nW*B, Mh, Mw, C] 277 | shifted_x = window_reverse(attn_windows, self.window_size, hp, wp) # [B, H', W', C] 278 | 279 | if self.shift_size > 0: 280 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 281 | else: 282 | x = shifted_x 283 | 284 | if pad_r > 0 or pad_b > 0: 285 | # Remove the data from pad 286 | x = x[:, :h, :w, :].contiguous() 287 | 288 | x = shortcut + self.drop_path(x) 289 | x = x + self.drop_path(self.mlp(self.norm2(x))) 290 | 291 | x = x.permute(0, 3, 2, 1).contiguous() 292 | return x # (b, self.c2, w, h) 293 | 294 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # YOLOR PyTorch utils 2 | 3 | import datetime 4 | import logging 5 | import math 6 | import os 7 | import platform 8 | import subprocess 9 | import time 10 | from contextlib import contextmanager 11 | from copy import deepcopy 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | 20 | try: 21 | import thop # for FLOPS computation 22 | except ImportError: 23 | thop = None 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @contextmanager 28 | def torch_distributed_zero_first(local_rank: int): 29 | """ 30 | Decorator to make all processes in distributed training wait for each local_master to do something. 31 | """ 32 | if local_rank not in [-1, 0]: 33 | torch.distributed.barrier() 34 | yield 35 | if local_rank == 0: 36 | torch.distributed.barrier() 37 | 38 | 39 | def init_torch_seeds(seed=0): 40 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 41 | torch.manual_seed(seed) 42 | if seed == 0: # slower, more reproducible 43 | cudnn.benchmark, cudnn.deterministic = False, True 44 | else: # faster, less reproducible 45 | cudnn.benchmark, cudnn.deterministic = True, False 46 | 47 | 48 | def date_modified(path=__file__): 49 | # return human-readable file modification date, i.e. '2021-3-26' 50 | t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime) 51 | return f'{t.year}-{t.month}-{t.day}' 52 | 53 | 54 | def git_describe(path=Path(__file__).parent): # path must be a directory 55 | # return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe 56 | s = f'git -C {path} describe --tags --long --always' 57 | try: 58 | return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] 59 | except subprocess.CalledProcessError as e: 60 | return '' # not a git repository 61 | 62 | 63 | def select_device(device='', batch_size=None): 64 | # device = 'cpu' or '0' or '0,1,2,3' 65 | s = f'YOLOR 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string 66 | cpu = device.lower() == 'cpu' 67 | if cpu: 68 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False 69 | elif device: # non-cpu device requested 70 | os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable 71 | assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability 72 | 73 | cuda = not cpu and torch.cuda.is_available() 74 | if cuda: 75 | n = torch.cuda.device_count() 76 | if n > 1 and batch_size: # check that batch_size is compatible with device_count 77 | assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' 78 | space = ' ' * len(s) 79 | for i, d in enumerate(device.split(',') if device else range(n)): 80 | p = torch.cuda.get_device_properties(i) 81 | s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB 82 | else: 83 | s += 'CPU\n' 84 | 85 | logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe 86 | return torch.device('cuda:0' if cuda else 'cpu') 87 | 88 | 89 | def time_synchronized(): 90 | # pytorch-accurate time 91 | if torch.cuda.is_available(): 92 | torch.cuda.synchronize() 93 | return time.time() / 1.2 94 | 95 | 96 | def profile(x, ops, n=100, device=None): 97 | # profile a pytorch module or list of modules. Example usage: 98 | # x = torch.randn(16, 3, 640, 640) # input 99 | # m1 = lambda x: x * torch.sigmoid(x) 100 | # m2 = nn.SiLU() 101 | # profile(x, [m1, m2], n=100) # profile speed over 100 iterations 102 | 103 | device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 104 | x = x.to(device) 105 | x.requires_grad = True 106 | print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '') 107 | print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}") 108 | for m in ops if isinstance(ops, list) else [ops]: 109 | m = m.to(device) if hasattr(m, 'to') else m # device 110 | m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type 111 | dtf, dtb, t = 0., 0., [0., 0., 0.] 
# dt forward, backward 112 | try: 113 | flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS 114 | except: 115 | flops = 0 116 | 117 | for _ in range(n): 118 | t[0] = time_synchronized() 119 | y = m(x) 120 | t[1] = time_synchronized() 121 | try: 122 | _ = y.sum().backward() 123 | t[2] = time_synchronized() 124 | except: # no backward method 125 | t[2] = float('nan') 126 | dtf += (t[1] - t[0]) * 1000 / n # ms per op forward 127 | dtb += (t[2] - t[1]) * 1000 / n # ms per op backward 128 | 129 | s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' 130 | s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list' 131 | p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters 132 | print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}') 133 | 134 | 135 | def is_parallel(model): 136 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 137 | 138 | 139 | def intersect_dicts(da, db, exclude=()): 140 | # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values 141 | return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} 142 | 143 | 144 | def initialize_weights(model): 145 | for m in model.modules(): 146 | t = type(m) 147 | if t is nn.Conv2d: 148 | pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 149 | elif t is nn.BatchNorm2d: 150 | m.eps = 1e-3 151 | m.momentum = 0.03 152 | elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: 153 | m.inplace = True 154 | 155 | 156 | def find_modules(model, mclass=nn.Conv2d): 157 | # Finds layer indices matching module class 'mclass' 158 | return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] 159 | 160 | 161 | def sparsity(model): 162 | # Return global model sparsity 163 | a, b = 0., 0. 164 | for p in model.parameters(): 165 | a += p.numel() 166 | b += (p == 0).sum() 167 | return b / a 168 | 169 | 170 | def prune(model, amount=0.3): 171 | # Prune model to requested global sparsity 172 | import torch.nn.utils.prune as prune 173 | print('Pruning model... 
', end='') 174 | for name, m in model.named_modules(): 175 | if isinstance(m, nn.Conv2d): 176 | prune.l1_unstructured(m, name='weight', amount=amount) # prune 177 | prune.remove(m, 'weight') # make permanent 178 | print(' %.3g global sparsity' % sparsity(model)) 179 | 180 | 181 | def fuse_conv_and_bn(conv, bn): 182 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 183 | fusedconv = nn.Conv2d(conv.in_channels, 184 | conv.out_channels, 185 | kernel_size=conv.kernel_size, 186 | stride=conv.stride, 187 | padding=conv.padding, 188 | groups=conv.groups, 189 | bias=True).requires_grad_(False).to(conv.weight.device) 190 | 191 | # prepare filters 192 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 193 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 194 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) 195 | 196 | # prepare spatial bias 197 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 198 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 199 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 200 | 201 | return fusedconv 202 | 203 | 204 | def model_info(model, verbose=False, img_size=640): 205 | # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] 206 | n_p = sum(x.numel() for x in model.parameters()) # number parameters 207 | n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients 208 | if verbose: 209 | print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma')) 210 | for i, (name, p) in enumerate(model.named_parameters()): 211 | name = name.replace('module_list.', '') 212 | print('%5g %40s %9s %12g %20s %10.3g %10.3g' % 213 | (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) 214 | 215 | try: # FLOPS 216 | from thop import profile 217 | stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 218 | img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input 219 | flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS 220 | img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float 221 | fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS 222 | except (ImportError, Exception): 223 | fs = '' 224 | 225 | logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}") 226 | 227 | 228 | def load_classifier(name='resnet101', n=2): 229 | # Loads a pretrained model reshaped to n-class output 230 | model = torchvision.models.__dict__[name](pretrained=True) 231 | 232 | # ResNet model properties 233 | # input_size = [3, 224, 224] 234 | # input_space = 'RGB' 235 | # input_range = [0, 1] 236 | # mean = [0.485, 0.456, 0.406] 237 | # std = [0.229, 0.224, 0.225] 238 | 239 | # Reshape output to n classes 240 | filters = model.fc.weight.shape[1] 241 | model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True) 242 | model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True) 243 | model.fc.out_features = n 244 | return model 245 | 246 | 247 | def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) 248 | # scales img(bs,3,y,x) by ratio constrained to gs-multiple 249 | if 
ratio == 1.0: 250 | return img 251 | else: 252 | h, w = img.shape[2:] 253 | s = (int(h * ratio), int(w * ratio)) # new size 254 | img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize 255 | if not same_shape: # pad/crop img 256 | h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] 257 | return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean 258 | 259 | 260 | def copy_attr(a, b, include=(), exclude=()): 261 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 262 | for k, v in b.__dict__.items(): 263 | if (len(include) and k not in include) or k.startswith('_') or k in exclude: 264 | continue 265 | else: 266 | setattr(a, k, v) 267 | 268 | 269 | class ModelEMA: 270 | """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models 271 | Keep a moving average of everything in the model state_dict (parameters and buffers). 272 | This is intended to allow functionality like 273 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 274 | A smoothed version of the weights is necessary for some training schemes to perform well. 275 | This class is sensitive where it is initialized in the sequence of model init, 276 | GPU assignment and distributed training wrappers. 277 | """ 278 | 279 | def __init__(self, model, decay=0.9999, updates=0): 280 | # Create EMA 281 | self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA 282 | # if next(model.parameters()).device.type != 'cpu': 283 | # self.ema.half() # FP16 EMA 284 | self.updates = updates # number of EMA updates 285 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) 286 | for p in self.ema.parameters(): 287 | p.requires_grad_(False) 288 | 289 | def update(self, model): 290 | # Update EMA parameters 291 | with torch.no_grad(): 292 | self.updates += 1 293 | d = self.decay(self.updates) 294 | 295 | msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict 296 | for k, v in self.ema.state_dict().items(): 297 | if v.dtype.is_floating_point: 298 | v *= d 299 | v += (1. - d) * msd[k].detach() 300 | 301 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): 302 | # Update EMA attributes 303 | copy_attr(self.ema, model, include, exclude) 304 | 305 | 306 | class BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): 307 | def _check_input_dim(self, input): 308 | # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc 309 | # is this method that is overwritten by the sub-class 310 | # This original goal of this method was for tensor sanity checks 311 | # If you're ok bypassing those sanity checks (eg. 
if you trust your inference 312 | # to provide the right dimensional inputs), then you can just use this method 313 | # for easy conversion from SyncBatchNorm 314 | # (unfortunately, SyncBatchNorm does not store the original class - if it did 315 | # we could return the one that was originally created) 316 | return 317 | 318 | def revert_sync_batchnorm(module): 319 | # this is very similar to the function that it is trying to revert: 320 | # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679 321 | module_output = module 322 | if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm): 323 | new_cls = BatchNormXd 324 | module_output = BatchNormXd(module.num_features, 325 | module.eps, module.momentum, 326 | module.affine, 327 | module.track_running_stats) 328 | if module.affine: 329 | with torch.no_grad(): 330 | module_output.weight = module.weight 331 | module_output.bias = module.bias 332 | module_output.running_mean = module.running_mean 333 | module_output.running_var = module.running_var 334 | module_output.num_batches_tracked = module.num_batches_tracked 335 | if hasattr(module, "qconfig"): 336 | module_output.qconfig = module.qconfig 337 | for name, child in module.named_children(): 338 | module_output.add_module(name, revert_sync_batchnorm(child)) 339 | del module 340 | return module_output 341 | 342 | 343 | class TracedModel(nn.Module): 344 | 345 | def __init__(self, model=None, device=None, img_size=(640,640)): 346 | super(TracedModel, self).__init__() 347 | 348 | print(" Convert model to Traced-model... ") 349 | self.stride = model.stride 350 | self.names = model.names 351 | self.model = model 352 | 353 | self.model = revert_sync_batchnorm(self.model) 354 | self.model.to('cpu') 355 | self.model.eval() 356 | 357 | self.detect_layer = self.model.model[-1] 358 | self.model.traced = True 359 | 360 | rand_example = torch.rand(1, 3, img_size, img_size) 361 | 362 | traced_script_module = torch.jit.trace(self.model, rand_example, strict=False) 363 | #traced_script_module = torch.jit.script(self.model) 364 | traced_script_module.save("traced_model.pt") 365 | print(" traced_script_module saved! ") 366 | self.model = traced_script_module 367 | self.model.to(device) 368 | self.detect_layer.to(device) 369 | print(" model is traced! 
\n") 370 | 371 | def forward(self, x, augment=False, profile=False): 372 | out = self.model(x) 373 | out = self.detect_layer(out) 374 | return out -------------------------------------------------------------------------------- /utils/wandb_logging/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from pathlib import Path 4 | 5 | import torch 6 | import yaml 7 | from tqdm import tqdm 8 | 9 | sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path 10 | from utils.datasets import LoadImagesAndLabels 11 | from utils.datasets import img2label_paths 12 | from utils.general import colorstr, xywh2xyxy, check_dataset 13 | 14 | try: 15 | import wandb 16 | from wandb import init, finish 17 | except ImportError: 18 | wandb = None 19 | 20 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 21 | 22 | 23 | def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX): 24 | return from_string[len(prefix):] 25 | 26 | 27 | def check_wandb_config_file(data_config_file): 28 | wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path 29 | if Path(wandb_config).is_file(): 30 | return wandb_config 31 | return data_config_file 32 | 33 | 34 | def get_run_info(run_path): 35 | run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX)) 36 | run_id = run_path.stem 37 | project = run_path.parent.stem 38 | model_artifact_name = 'run_' + run_id + '_model' 39 | return run_id, project, model_artifact_name 40 | 41 | 42 | def check_wandb_resume(opt): 43 | process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None 44 | if isinstance(opt.resume, str): 45 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 46 | if opt.global_rank not in [-1, 0]: # For resuming DDP runs 47 | run_id, project, model_artifact_name = get_run_info(opt.resume) 48 | api = wandb.Api() 49 | artifact = api.artifact(project + '/' + model_artifact_name + ':latest') 50 | modeldir = artifact.download() 51 | opt.weights = str(Path(modeldir) / "last.pt") 52 | return True 53 | return None 54 | 55 | 56 | def process_wandb_config_ddp_mode(opt): 57 | with open(opt.data) as f: 58 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict 59 | train_dir, val_dir = None, None 60 | if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX): 61 | api = wandb.Api() 62 | train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias) 63 | train_dir = train_artifact.download() 64 | train_path = Path(train_dir) / 'data/images/' 65 | data_dict['train'] = str(train_path) 66 | 67 | if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX): 68 | api = wandb.Api() 69 | val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias) 70 | val_dir = val_artifact.download() 71 | val_path = Path(val_dir) / 'data/images/' 72 | data_dict['val'] = str(val_path) 73 | if train_dir or val_dir: 74 | ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml') 75 | with open(ddp_data_path, 'w') as f: 76 | yaml.dump(data_dict, f) 77 | opt.data = ddp_data_path 78 | 79 | 80 | class WandbLogger(): 81 | def __init__(self, opt, name, run_id, data_dict, job_type='Training'): 82 | # Pre-training routine -- 83 | self.job_type = job_type 84 | self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict 85 | # It's more elegant to stick to 1 wandb.init call, but useful config data is 
overwritten in the WandbLogger's wandb.init call 86 | if isinstance(opt.resume, str): # checks resume from artifact 87 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 88 | run_id, project, model_artifact_name = get_run_info(opt.resume) 89 | model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name 90 | assert wandb, 'install wandb to resume wandb runs' 91 | # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config 92 | self.wandb_run = wandb.init(id=run_id, project=project, resume='allow') 93 | opt.resume = model_artifact_name 94 | elif self.wandb: 95 | self.wandb_run = wandb.init(config=opt, 96 | resume="allow", 97 | project='YOLOR' if opt.project == 'runs/train' else Path(opt.project).stem, 98 | name=name, 99 | job_type=job_type, 100 | id=run_id) if not wandb.run else wandb.run 101 | if self.wandb_run: 102 | if self.job_type == 'Training': 103 | if not opt.resume: 104 | wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict 105 | # Info useful for resuming from artifacts 106 | self.wandb_run.config.opt = vars(opt) 107 | self.wandb_run.config.data_dict = wandb_data_dict 108 | self.data_dict = self.setup_training(opt, data_dict) 109 | if self.job_type == 'Dataset Creation': 110 | self.data_dict = self.check_and_upload_dataset(opt) 111 | else: 112 | prefix = colorstr('wandb: ') 113 | print(f"{prefix}Install Weights & Biases for YOLOR logging with 'pip install wandb' (recommended)") 114 | 115 | def check_and_upload_dataset(self, opt): 116 | assert wandb, 'Install wandb to upload dataset' 117 | check_dataset(self.data_dict) 118 | config_path = self.log_dataset_artifact(opt.data, 119 | opt.single_cls, 120 | 'YOLOR' if opt.project == 'runs/train' else Path(opt.project).stem) 121 | print("Created dataset config file ", config_path) 122 | with open(config_path) as f: 123 | wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader) 124 | return wandb_data_dict 125 | 126 | def setup_training(self, opt, data_dict): 127 | self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants 128 | self.bbox_interval = opt.bbox_interval 129 | if isinstance(opt.resume, str): 130 | modeldir, _ = self.download_model_artifact(opt) 131 | if modeldir: 132 | self.weights = Path(modeldir) / "last.pt" 133 | config = self.wandb_run.config 134 | opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str( 135 | self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \ 136 | config.opt['hyp'] 137 | data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume 138 | if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download 139 | self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), 140 | opt.artifact_alias) 141 | self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'), 142 | opt.artifact_alias) 143 | self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None 144 | if self.train_artifact_path is not None: 145 | train_path = Path(self.train_artifact_path) / 'data/images/' 146 | data_dict['train'] = str(train_path) 147 | if self.val_artifact_path is not None: 148 | val_path = Path(self.val_artifact_path) / 'data/images/' 149 | data_dict['val'] = str(val_path) 150 | self.val_table = self.val_artifact.get("val") 151 | self.map_val_table_path() 152 | if 
self.val_artifact is not None: 153 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 154 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 155 | if opt.bbox_interval == -1: 156 | self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1 157 | return data_dict 158 | 159 | def download_dataset_artifact(self, path, alias): 160 | if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX): 161 | dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias) 162 | assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'" 163 | datadir = dataset_artifact.download() 164 | return datadir, dataset_artifact 165 | return None, None 166 | 167 | def download_model_artifact(self, opt): 168 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 169 | model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") 170 | assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' 171 | modeldir = model_artifact.download() 172 | epochs_trained = model_artifact.metadata.get('epochs_trained') 173 | total_epochs = model_artifact.metadata.get('total_epochs') 174 | assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % ( 175 | total_epochs) 176 | return modeldir, model_artifact 177 | return None, None 178 | 179 | def log_model(self, path, opt, epoch, fitness_score, best_model=False): 180 | model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={ 181 | 'original_url': str(path), 182 | 'epochs_trained': epoch + 1, 183 | 'save period': opt.save_period, 184 | 'project': opt.project, 185 | 'total_epochs': opt.epochs, 186 | 'fitness_score': fitness_score 187 | }) 188 | model_artifact.add_file(str(path / 'last.pt'), name='last.pt') 189 | wandb.log_artifact(model_artifact, 190 | aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else '']) 191 | print("Saving model artifact on epoch ", epoch + 1) 192 | 193 | def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): 194 | with open(data_file) as f: 195 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 196 | nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) 197 | names = {k: v for k, v in enumerate(names)} # to index dictionary 198 | self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( 199 | data['train']), names, name='train') if data.get('train') else None 200 | self.val_artifact = self.create_dataset_table(LoadImagesAndLabels( 201 | data['val']), names, name='val') if data.get('val') else None 202 | if data.get('train'): 203 | data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train') 204 | if data.get('val'): 205 | data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val') 206 | path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path 207 | data.pop('download', None) 208 | with open(path, 'w') as f: 209 | yaml.dump(data, f) 210 | 211 | if self.job_type == 'Training': # builds correct artifact pipeline graph 212 | self.wandb_run.use_artifact(self.val_artifact) 213 | self.wandb_run.use_artifact(self.train_artifact) 214 | self.val_artifact.wait() 215 | self.val_table = self.val_artifact.get('val') 216 | self.map_val_table_path() 217 | else: 218 | self.wandb_run.log_artifact(self.train_artifact) 219 | 
self.wandb_run.log_artifact(self.val_artifact) 220 | return path 221 | 222 | def map_val_table_path(self): 223 | self.val_table_map = {} 224 | print("Mapping dataset") 225 | for i, data in enumerate(tqdm(self.val_table.data)): 226 | self.val_table_map[data[3]] = data[0] 227 | 228 | def create_dataset_table(self, dataset, class_to_id, name='dataset'): 229 | # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging 230 | artifact = wandb.Artifact(name=name, type="dataset") 231 | img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None 232 | img_files = tqdm(dataset.img_files) if not img_files else img_files 233 | for img_file in img_files: 234 | if Path(img_file).is_dir(): 235 | artifact.add_dir(img_file, name='data/images') 236 | labels_path = 'labels'.join(dataset.path.rsplit('images', 1)) 237 | artifact.add_dir(labels_path, name='data/labels') 238 | else: 239 | artifact.add_file(img_file, name='data/images/' + Path(img_file).name) 240 | label_file = Path(img2label_paths([img_file])[0]) 241 | artifact.add_file(str(label_file), 242 | name='data/labels/' + label_file.name) if label_file.exists() else None 243 | table = wandb.Table(columns=["id", "train_image", "Classes", "name"]) 244 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()]) 245 | for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)): 246 | height, width = shapes[0] 247 | labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height]) 248 | box_data, img_classes = [], {} 249 | for cls, *xyxy in labels[:, 1:].tolist(): 250 | cls = int(cls) 251 | box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 252 | "class_id": cls, 253 | "box_caption": "%s" % (class_to_id[cls]), 254 | "scores": {"acc": 1}, 255 | "domain": "pixel"}) 256 | img_classes[cls] = class_to_id[cls] 257 | boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space 258 | table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes), 259 | Path(paths).name) 260 | artifact.add(table, name) 261 | return artifact 262 | 263 | def log_training_progress(self, predn, path, names): 264 | if self.val_table and self.result_table: 265 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()]) 266 | box_data = [] 267 | total_conf = 0 268 | for *xyxy, conf, cls in predn.tolist(): 269 | if conf >= 0.25: 270 | box_data.append( 271 | {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 272 | "class_id": int(cls), 273 | "box_caption": "%s %.3f" % (names[cls], conf), 274 | "scores": {"class_score": conf}, 275 | "domain": "pixel"}) 276 | total_conf = total_conf + conf 277 | boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space 278 | id = self.val_table_map[Path(path).name] 279 | self.result_table.add_data(self.current_epoch, 280 | id, 281 | wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set), 282 | total_conf / max(1, len(box_data)) 283 | ) 284 | 285 | def log(self, log_dict): 286 | if self.wandb_run: 287 | for key, value in log_dict.items(): 288 | self.log_dict[key] = value 289 | 290 | def end_epoch(self, best_result=False): 291 | if self.wandb_run: 292 | wandb.log(self.log_dict) 293 | self.log_dict = {} 294 | if self.result_artifact: 295 | train_results = 
wandb.JoinedTable(self.val_table, self.result_table, "id") 296 | self.result_artifact.add(train_results, 'result') 297 | wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch), 298 | ('best' if best_result else '')]) 299 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 300 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 301 | 302 | def finish_run(self): 303 | if self.wandb_run: 304 | if self.log_dict: 305 | wandb.log(self.log_dict) 306 | wandb.run.finish() 307 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from threading import Thread 6 | 7 | import numpy as np 8 | import torch 9 | import yaml 10 | from tqdm import tqdm 11 | 12 | from models.experimental import attempt_load 13 | from utils.datasets import create_dataloader 14 | from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ 15 | box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr 16 | from utils.metrics import ap_per_class, ConfusionMatrix 17 | from utils.plots import plot_images, output_to_target, plot_study_txt 18 | from utils.torch_utils import select_device, time_synchronized, TracedModel 19 | 20 | 21 | def test(data, 22 | weights=None, 23 | batch_size=32, 24 | imgsz=640, 25 | conf_thres=0.001, 26 | iou_thres=0.6, # for NMS 27 | save_json=False, 28 | single_cls=False, 29 | augment=False, 30 | verbose=False, 31 | model=None, 32 | dataloader=None, 33 | save_dir=Path(''), # for saving images 34 | save_txt=False, # for auto-labelling 35 | save_hybrid=False, # for hybrid auto-labelling 36 | save_conf=False, # save auto-label confidences 37 | plots=True, 38 | wandb_logger=None, 39 | compute_loss=None, 40 | half_precision=True, 41 | trace=False, 42 | is_coco=False, 43 | v5_metric=False): 44 | # Initialize/load model and set device 45 | training = model is not None 46 | if training: # called by train.py 47 | device = next(model.parameters()).device # get model device 48 | 49 | else: # called directly 50 | set_logging() 51 | device = select_device(opt.device, batch_size=batch_size) 52 | 53 | # Directories 54 | save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run 55 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir 56 | 57 | # Load model 58 | model = attempt_load(weights, map_location=device) # load FP32 model 59 | gs = max(int(model.stride.max()), 32) # grid size (max stride) 60 | imgsz = check_img_size(imgsz, s=gs) # check img_size 61 | 62 | if False: 63 | model1 = TracedModel(model, device, imgsz) 64 | 65 | # Half 66 | half = device.type != 'cpu' and half_precision # half precision only supported on CUDA 67 | if half: 68 | model.half() 69 | 70 | # Configure 71 | model.eval() 72 | if isinstance(data, str): 73 | is_coco = data.endswith('coco.yaml') 74 | with open(data) as f: 75 | data = yaml.load(f, Loader=yaml.SafeLoader) 76 | check_dataset(data) # check 77 | nc = 1 if single_cls else int(data['nc']) # number of classes 78 | iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95 79 | niou = iouv.numel() 80 | 81 | # Logging 82 | log_imgs = 0 83 | if wandb_logger and wandb_logger.wandb: 84 | 
log_imgs = min(wandb_logger.log_imgs, 100) 85 | # Dataloader 86 | if not training: 87 | if device.type != 'cpu': 88 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 89 | task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images 90 | 91 | dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True, 92 | prefix=colorstr(f'{task}: '))[0] 93 | 94 | if v5_metric: 95 | print("Testing with YOLOv5 AP metric...") 96 | 97 | seen = 0 98 | confusion_matrix = ConfusionMatrix(nc=nc) 99 | names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)} 100 | coco91class = coco80_to_coco91_class() 101 | s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') 102 | p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0. 103 | loss = torch.zeros(3, device=device) 104 | jdict, stats, ap, ap_class, wandb_images = [], [], [], [], [] 105 | for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): 106 | img = img.to(device, non_blocking=True) 107 | img = img.half() if half else img.float() # uint8 to fp16/32 108 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 109 | targets = targets.to(device) 110 | nb, _, height, width = img.shape # batch size, channels, height, width 111 | 112 | with torch.no_grad(): 113 | # Run model 114 | t = time_synchronized() 115 | out, train_out = model(img, augment=False) # inference and training outputs 116 | t0 += time_synchronized() - t 117 | 118 | # Compute loss 119 | if compute_loss: 120 | loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls 121 | 122 | # Run NMS 123 | targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels 124 | lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling 125 | t = time_synchronized() 126 | out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True) 127 | t1 += time_synchronized() - t 128 | 129 | # Statistics per image 130 | for si, pred in enumerate(out): 131 | labels = targets[targets[:, 0] == si, 1:] 132 | nl = len(labels) 133 | tcls = labels[:, 0].tolist() if nl else [] # target class 134 | path = Path(paths[si]) 135 | seen += 1 136 | 137 | if len(pred) == 0: 138 | if nl: 139 | stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) 140 | continue 141 | 142 | # Predictions 143 | predn = pred.clone() 144 | scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred 145 | 146 | # Append to text file 147 | if save_txt: 148 | gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh 149 | for *xyxy, conf, cls in predn.tolist(): 150 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 151 | line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format 152 | with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f: 153 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 154 | 155 | # W&B logging - Media Panel Plots 156 | if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation 157 | if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0: 158 | box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 159 | "class_id": int(cls), 160 | "box_caption": "%s %.3f" % 
(names[cls], conf), 161 | "scores": {"class_score": conf}, 162 | "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()] 163 | boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space 164 | wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name)) 165 | wandb_logger.log_training_progress(predn, path, names) if wandb_logger and wandb_logger.wandb_run else None 166 | 167 | # Append to pycocotools JSON dictionary 168 | if save_json: 169 | # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... 170 | image_id = int(path.stem) if path.stem.isnumeric() else path.stem 171 | box = xyxy2xywh(predn[:, :4]) # xywh 172 | box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner 173 | for p, b in zip(pred.tolist(), box.tolist()): 174 | jdict.append({'image_id': image_id, 175 | 'category_id': coco91class[int(p[5])] if is_coco else int(p[5]), 176 | 'bbox': [round(x, 3) for x in b], 177 | 'score': round(p[4], 5)}) 178 | 179 | # Assign all predictions as incorrect 180 | correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device) 181 | if nl: 182 | detected = [] # target indices 183 | tcls_tensor = labels[:, 0] 184 | 185 | # target boxes 186 | tbox = xywh2xyxy(labels[:, 1:5]) 187 | scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels 188 | if plots: 189 | confusion_matrix.process_batch(predn, torch.cat((labels[:, 0:1], tbox), 1)) 190 | 191 | # Per target class 192 | for cls in torch.unique(tcls_tensor): 193 | ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices 194 | pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices 195 | 196 | # Search for detections 197 | if pi.shape[0]: 198 | # Prediction to target ious 199 | ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1) # best ious, indices 200 | 201 | # Append detections 202 | detected_set = set() 203 | for j in (ious > iouv[0]).nonzero(as_tuple=False): 204 | d = ti[i[j]] # detected target 205 | if d.item() not in detected_set: 206 | detected_set.add(d.item()) 207 | detected.append(d) 208 | correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn 209 | if len(detected) == nl: # all targets already located in image 210 | break 211 | 212 | # Append statistics (correct, conf, pcls, tcls) 213 | stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) 214 | 215 | # Plot images 216 | if plots and batch_i < 3: 217 | f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels 218 | Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() 219 | f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions 220 | Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start() 221 | 222 | # Compute statistics 223 | stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy 224 | if len(stats) and stats[0].any(): 225 | p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names) 226 | ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 227 | mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() 228 | nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class 229 | else: 230 | nt = torch.zeros(1) 231 | 232 | # Print results 233 | pf = '%20s' + '%12i' * 2 + '%12.3g' * 4 # print format 234 | print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) 235 | 236 | # Print results per class 237 | if (verbose or 
(nc < 50 and not training)) and nc > 1 and len(stats): 238 | for i, c in enumerate(ap_class): 239 | print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) 240 | 241 | # Print speeds 242 | t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size) # tuple 243 | if not training: 244 | print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) 245 | 246 | # Plots 247 | if plots: 248 | confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) 249 | if wandb_logger and wandb_logger.wandb: 250 | val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))] 251 | wandb_logger.log({"Validation": val_batches}) 252 | if wandb_images: 253 | wandb_logger.log({"Bounding Box Debugger/Images": wandb_images}) 254 | 255 | # Save JSON 256 | if save_json and len(jdict): 257 | w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights 258 | anno_json = './coco/annotations/instances_val2017.json' # annotations json 259 | pred_json = str(save_dir / f"{w}_predictions.json") # predictions json 260 | print('\nEvaluating pycocotools mAP... saving %s...' % pred_json) 261 | with open(pred_json, 'w') as f: 262 | json.dump(jdict, f) 263 | 264 | try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb 265 | from pycocotools.coco import COCO 266 | from pycocotools.cocoeval import COCOeval 267 | 268 | anno = COCO(anno_json) # init annotations api 269 | pred = anno.loadRes(pred_json) # init predictions api 270 | eval = COCOeval(anno, pred, 'bbox') 271 | if is_coco: 272 | eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] # image IDs to evaluate 273 | eval.evaluate() 274 | eval.accumulate() 275 | eval.summarize() 276 | map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) 277 | except Exception as e: 278 | print(f'pycocotools unable to run: {e}') 279 | 280 | # Return results 281 | model.float() # for training 282 | if not training: 283 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' 284 | print(f"Results saved to {save_dir}{s}") 285 | maps = np.zeros(nc) + map 286 | for i, c in enumerate(ap_class): 287 | maps[c] = ap[i] 288 | return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t 289 | 290 | 291 | if __name__ == '__main__': 292 | parser = argparse.ArgumentParser(prog='test.py') 293 | parser.add_argument('--weights', nargs='+', type=str, default='runs/train/exp/weights/best.pt', help='model.pt path(s)') 294 | parser.add_argument('--data', type=str, default='data/cbc.yaml', help='*.data path') 295 | parser.add_argument('--batch-size', type=int, default=20, help='size of each image batch') 296 | parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') 297 | parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold') 298 | parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS') 299 | parser.add_argument('--task', default='val', help='train, val, test, speed or study') 300 | parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 301 | parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') 302 | parser.add_argument('--augment', action='store_true', help='augmented inference') 303 | parser.add_argument('--verbose', action='store_true', help='report mAP by class') 304 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') 305 | parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') 306 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') 307 | parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') 308 | parser.add_argument('--project', default='runs/test', help='save to project/name') 309 | parser.add_argument('--name', default='exp', help='save to project/name') 310 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 311 | parser.add_argument('--no-trace', action='store_true', help='don`t trace model') 312 | parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation') 313 | opt = parser.parse_args() 314 | opt.save_json |= opt.data.endswith('coco.yaml') 315 | opt.data = check_file(opt.data) # check file 316 | print(opt) 317 | #check_requirements() 318 | 319 | if opt.task in ('train', 'val', 'test'): # run normally 320 | test(opt.data, 321 | opt.weights, 322 | opt.batch_size, 323 | opt.img_size, 324 | opt.conf_thres, 325 | opt.iou_thres, 326 | opt.save_json, 327 | opt.single_cls, 328 | opt.augment, 329 | opt.verbose, 330 | save_txt=opt.save_txt | opt.save_hybrid, 331 | save_hybrid=opt.save_hybrid, 332 | save_conf=opt.save_conf, 333 | trace=not opt.no_trace, 334 | v5_metric=opt.v5_metric 335 | ) 336 | 337 | elif opt.task == 'speed': # speed benchmarks 338 | for w in opt.weights: 339 | test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, v5_metric=opt.v5_metric) 340 | 341 | elif opt.task == 'study': # run over a range of settings and save/plot 342 | # python test.py --task study --data coco.yaml --iou 0.65 --weights yolov7.pt 343 | x = list(range(256, 1536 + 128, 128)) # x axis (image sizes) 344 | for w in opt.weights: 345 | f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to 346 | y = [] # y axis 347 | for i in x: # img-size 348 | print(f'\nRunning {f} point {i}...') 349 | r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json, 350 | plots=False, v5_metric=opt.v5_metric) 351 | y.append(r + t) # results and times 352 | np.savetxt(f, y, fmt='%10.4g') # save 353 | os.system('zip -r study.zip study_*.txt') 354 | plot_study_txt(x=x) # plot 355 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | # Plotting utils 2 | 3 | import glob 4 | import math 5 | import os 6 | import random 7 | from copy import copy 8 | from pathlib import Path 9 | 10 | import cv2 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import pandas as pd 15 | import seaborn as sns 16 | import torch 17 | import yaml 18 | from PIL import Image, ImageDraw, ImageFont 19 | from scipy.signal import butter, filtfilt 20 | 21 | from utils.general import xywh2xyxy, xyxy2xywh 22 | from utils.metrics import fitness 23 | 24 | # Settings 25 | 
matplotlib.rc('font', **{'size': 11}) 26 | matplotlib.use('Agg') # for writing to files only 27 | 28 | 29 | def color_list(): 30 | # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb 31 | def hex2rgb(h): 32 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) 33 | 34 | return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949) 35 | 36 | 37 | def hist2d(x, y, n=100): 38 | # 2d histogram used in labels.png and evolve.png 39 | xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) 40 | hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) 41 | xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) 42 | yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) 43 | return np.log(hist[xidx, yidx]) 44 | 45 | 46 | def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): 47 | # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy 48 | def butter_lowpass(cutoff, fs, order): 49 | nyq = 0.5 * fs 50 | normal_cutoff = cutoff / nyq 51 | return butter(order, normal_cutoff, btype='low', analog=False) 52 | 53 | b, a = butter_lowpass(cutoff, fs, order=order) 54 | return filtfilt(b, a, data) # forward-backward filter 55 | 56 | 57 | def plot_one_box(x, img, color=None, label=None, line_thickness=3): 58 | # Plots one bounding box on image img 59 | tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness 60 | color = color or [random.randint(0, 255) for _ in range(3)] 61 | c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) 62 | cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) 63 | if label: 64 | tf = max(tl - 1, 1) # font thickness 65 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 66 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 67 | cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled 68 | cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 69 | return img 70 | 71 | def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None): 72 | img = Image.fromarray(img) 73 | draw = ImageDraw.Draw(img) 74 | line_thickness = line_thickness or max(int(min(img.size) / 200), 2) 75 | draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot 76 | if label: 77 | fontsize = max(round(max(img.size) / 40), 12) 78 | font = ImageFont.truetype("Arial.ttf", fontsize) 79 | txt_width, txt_height = font.getsize(label) 80 | draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color)) 81 | draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font) 82 | return np.asarray(img) 83 | 84 | 85 | def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() 86 | # Compares the two methods for width-height anchor multiplication 87 | # https://github.com/ultralytics/yolov3/issues/168 88 | x = np.arange(-4.0, 4.0, .1) 89 | ya = np.exp(x) 90 | yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2 91 | 92 | fig = plt.figure(figsize=(6, 3), tight_layout=True) 93 | plt.plot(x, ya, '.-', label='YOLOv3') 94 | plt.plot(x, yb ** 2, '.-', label='YOLOR ^2') 95 | plt.plot(x, yb ** 1.6, '.-', label='YOLOR ^1.6') 96 | plt.xlim(left=-4, right=4) 97 | plt.ylim(bottom=0, top=6) 98 | plt.xlabel('input') 99 | plt.ylabel('output') 100 | plt.grid() 101 | plt.legend() 102 | fig.savefig('comparison.png', 
dpi=200) 103 | 104 | 105 | def output_to_target(output): 106 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] 107 | targets = [] 108 | for i, o in enumerate(output): 109 | for *box, conf, cls in o.cpu().numpy(): 110 | targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf]) 111 | return np.array(targets) 112 | 113 | 114 | def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): 115 | # Plot image grid with labels 116 | 117 | if isinstance(images, torch.Tensor): 118 | images = images.cpu().float().numpy() 119 | if isinstance(targets, torch.Tensor): 120 | targets = targets.cpu().numpy() 121 | 122 | # un-normalise 123 | if np.max(images[0]) <= 1: 124 | images *= 255 125 | 126 | tl = 3 # line thickness 127 | tf = max(tl - 1, 1) # font thickness 128 | bs, _, h, w = images.shape # batch size, _, height, width 129 | bs = min(bs, max_subplots) # limit plot images 130 | ns = np.ceil(bs ** 0.5) # number of subplots (square) 131 | 132 | # Check if we should resize 133 | scale_factor = max_size / max(h, w) 134 | if scale_factor < 1: 135 | h = math.ceil(scale_factor * h) 136 | w = math.ceil(scale_factor * w) 137 | 138 | colors = color_list() # list of colors 139 | mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init 140 | for i, img in enumerate(images): 141 | if i == max_subplots: # if last batch has fewer images than we expect 142 | break 143 | 144 | block_x = int(w * (i // ns)) 145 | block_y = int(h * (i % ns)) 146 | 147 | img = img.transpose(1, 2, 0) 148 | if scale_factor < 1: 149 | img = cv2.resize(img, (w, h)) 150 | 151 | mosaic[block_y:block_y + h, block_x:block_x + w, :] = img 152 | if len(targets) > 0: 153 | image_targets = targets[targets[:, 0] == i] 154 | boxes = xywh2xyxy(image_targets[:, 2:6]).T 155 | classes = image_targets[:, 1].astype('int') 156 | labels = image_targets.shape[1] == 6 # labels if no conf column 157 | conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred) 158 | 159 | if boxes.shape[1]: 160 | if boxes.max() <= 1.01: # if normalized with tolerance 0.01 161 | boxes[[0, 2]] *= w # scale to pixels 162 | boxes[[1, 3]] *= h 163 | elif scale_factor < 1: # absolute coords need scale if image scales 164 | boxes *= scale_factor 165 | boxes[[0, 2]] += block_x 166 | boxes[[1, 3]] += block_y 167 | for j, box in enumerate(boxes.T): 168 | cls = int(classes[j]) 169 | color = colors[cls % len(colors)] 170 | cls = names[cls] if names else cls 171 | if labels or conf[j] > 0.25: # 0.25 conf thresh 172 | label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j]) 173 | plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl) 174 | 175 | # Draw image filename labels 176 | if paths: 177 | label = Path(paths[i]).name[:40] # trim to 40 char 178 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 179 | cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf, 180 | lineType=cv2.LINE_AA) 181 | 182 | # Image border 183 | cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3) 184 | 185 | if fname: 186 | r = min(1280. 
/ max(h, w) / ns, 1.0) # ratio to limit image size 187 | mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA) 188 | # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save 189 | Image.fromarray(mosaic).save(fname) # PIL save 190 | return mosaic 191 | 192 | 193 | def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): 194 | # Plot LR simulating training for full epochs 195 | optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals 196 | y = [] 197 | for _ in range(epochs): 198 | scheduler.step() 199 | y.append(optimizer.param_groups[0]['lr']) 200 | plt.plot(y, '.-', label='LR') 201 | plt.xlabel('epoch') 202 | plt.ylabel('LR') 203 | plt.grid() 204 | plt.xlim(0, epochs) 205 | plt.ylim(0) 206 | plt.savefig(Path(save_dir) / 'LR.png', dpi=200) 207 | plt.close() 208 | 209 | 210 | def plot_test_txt(): # from utils.plots import *; plot_test() 211 | # Plot test.txt histograms 212 | x = np.loadtxt('test.txt', dtype=np.float32) 213 | box = xyxy2xywh(x[:, :4]) 214 | cx, cy = box[:, 0], box[:, 1] 215 | 216 | fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True) 217 | ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) 218 | ax.set_aspect('equal') 219 | plt.savefig('hist2d.png', dpi=300) 220 | 221 | fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) 222 | ax[0].hist(cx, bins=600) 223 | ax[1].hist(cy, bins=600) 224 | plt.savefig('hist1d.png', dpi=200) 225 | 226 | 227 | def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() 228 | # Plot targets.txt histograms 229 | x = np.loadtxt('targets.txt', dtype=np.float32).T 230 | s = ['x targets', 'y targets', 'width targets', 'height targets'] 231 | fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True) 232 | ax = ax.ravel() 233 | for i in range(4): 234 | ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std())) 235 | ax[i].legend() 236 | ax[i].set_title(s[i]) 237 | plt.savefig('targets.jpg', dpi=200) 238 | 239 | 240 | def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt() 241 | # Plot study.txt generated by test.py 242 | fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) 243 | # ax = ax.ravel() 244 | 245 | fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True) 246 | # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolor-p6', 'yolor-w6', 'yolor-e6', 'yolor-d6']]: 247 | for f in sorted(Path(path).glob('study*.txt')): 248 | y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T 249 | x = np.arange(y.shape[1]) if x is None else np.array(x) 250 | s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)'] 251 | # for i in range(7): 252 | # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8) 253 | # ax[i].set_title(s[i]) 254 | 255 | j = y[3].argmax() + 1 256 | ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8, 257 | label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO')) 258 | 259 | ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5], 260 | 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet') 261 | 262 | ax2.grid(alpha=0.2) 263 | ax2.set_yticks(np.arange(20, 60, 5)) 264 | ax2.set_xlim(0, 57) 265 | ax2.set_ylim(30, 55) 266 | ax2.set_xlabel('GPU Speed (ms/img)') 267 | ax2.set_ylabel('COCO AP val') 268 | ax2.legend(loc='lower right') 269 | plt.savefig(str(Path(path).name) + '.png', dpi=300) 270 | 271 | 272 | def 
plot_labels(labels, names=(), save_dir=Path(''), loggers=None): 273 | # plot dataset labels 274 | print('Plotting labels... ') 275 | c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes 276 | nc = int(c.max() + 1) # number of classes 277 | colors = color_list() 278 | x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) 279 | 280 | # seaborn correlogram 281 | sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) 282 | plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) 283 | plt.close() 284 | 285 | # matplotlib labels 286 | matplotlib.use('svg') # faster 287 | ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() 288 | ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) 289 | ax[0].set_ylabel('instances') 290 | if 0 < len(names) < 30: 291 | ax[0].set_xticks(range(len(names))) 292 | ax[0].set_xticklabels(names, rotation=90, fontsize=10) 293 | else: 294 | ax[0].set_xlabel('classes') 295 | sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) 296 | sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) 297 | 298 | # rectangles 299 | labels[:, 1:3] = 0.5 # center 300 | labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000 301 | img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255) 302 | for cls, *box in labels[:1000]: 303 | ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot 304 | ax[1].imshow(img) 305 | ax[1].axis('off') 306 | 307 | for a in [0, 1, 2, 3]: 308 | for s in ['top', 'right', 'left', 'bottom']: 309 | ax[a].spines[s].set_visible(False) 310 | 311 | plt.savefig(save_dir / 'labels.jpg', dpi=200) 312 | matplotlib.use('Agg') 313 | plt.close() 314 | 315 | # loggers 316 | for k, v in loggers.items() or {}: 317 | if k == 'wandb' and v: 318 | v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False) 319 | 320 | 321 | def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() 322 | # Plot hyperparameter evolution results in evolve.txt 323 | with open(yaml_file) as f: 324 | hyp = yaml.load(f, Loader=yaml.SafeLoader) 325 | x = np.loadtxt('evolve.txt', ndmin=2) 326 | f = fitness(x) 327 | # weights = (f - f.min()) ** 2 # for weighted results 328 | plt.figure(figsize=(10, 12), tight_layout=True) 329 | matplotlib.rc('font', **{'size': 8}) 330 | for i, (k, v) in enumerate(hyp.items()): 331 | y = x[:, i + 7] 332 | # mu = (y * weights).sum() / weights.sum() # best weighted result 333 | mu = y[f.argmax()] # best single result 334 | plt.subplot(6, 5, i + 1) 335 | plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') 336 | plt.plot(mu, f.max(), 'k+', markersize=15) 337 | plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters 338 | if i % 5 != 0: 339 | plt.yticks([]) 340 | print('%15s: %.3g' % (k, mu)) 341 | plt.savefig('evolve.png', dpi=200) 342 | print('\nPlot saved as evolve.png') 343 | 344 | 345 | def profile_idetection(start=0, stop=0, labels=(), save_dir=''): 346 | # Plot iDetection '*.txt' per-image logs. 
from utils.plots import *; profile_idetection() 347 | ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() 348 | s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS'] 349 | files = list(Path(save_dir).glob('frames*.txt')) 350 | for fi, f in enumerate(files): 351 | try: 352 | results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows 353 | n = results.shape[1] # number of rows 354 | x = np.arange(start, min(stop, n) if stop else n) 355 | results = results[:, x] 356 | t = (results[0] - results[0].min()) # set t0=0s 357 | results[0] = x 358 | for i, a in enumerate(ax): 359 | if i < len(results): 360 | label = labels[fi] if len(labels) else f.stem.replace('frames_', '') 361 | a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5) 362 | a.set_title(s[i]) 363 | a.set_xlabel('time (s)') 364 | # if fi == len(files) - 1: 365 | # a.set_ylim(bottom=0) 366 | for side in ['top', 'right']: 367 | a.spines[side].set_visible(False) 368 | else: 369 | a.remove() 370 | except Exception as e: 371 | print('Warning: Plotting error for %s; %s' % (f, e)) 372 | 373 | ax[1].legend() 374 | plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200) 375 | 376 | 377 | def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() 378 | # Plot training 'results*.txt', overlaying train and val losses 379 | s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends 380 | t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles 381 | for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')): 382 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 383 | n = results.shape[1] # number of rows 384 | x = range(start, min(stop, n) if stop else n) 385 | fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True) 386 | ax = ax.ravel() 387 | for i in range(5): 388 | for j in [i, i + 5]: 389 | y = results[j, x] 390 | ax[i].plot(x, y, marker='.', label=s[j]) 391 | # y_smooth = butter_lowpass_filtfilt(y) 392 | # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j]) 393 | 394 | ax[i].set_title(t[i]) 395 | ax[i].legend() 396 | ax[i].set_ylabel(f) if i == 0 else None # add filename 397 | fig.savefig(f.replace('.txt', '.png'), dpi=200) 398 | 399 | 400 | def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): 401 | # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp') 402 | fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) 403 | ax = ax.ravel() 404 | s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', 405 | 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] 406 | if bucket: 407 | # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] 408 | files = ['results%g.txt' % x for x in id] 409 | c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) 410 | os.system(c) 411 | else: 412 | files = list(Path(save_dir).glob('results*.txt')) 413 | assert len(files), 'No results.txt files found in %s, nothing to plot.' 
% os.path.abspath(save_dir) 414 | for fi, f in enumerate(files): 415 | try: 416 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 417 | n = results.shape[1] # number of rows 418 | x = range(start, min(stop, n) if stop else n) 419 | for i in range(10): 420 | y = results[i, x] 421 | if i in [0, 1, 2, 5, 6, 7]: 422 | y[y == 0] = np.nan # don't show zero loss values 423 | # y /= y[0] # normalize 424 | label = labels[fi] if len(labels) else f.stem 425 | ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8) 426 | ax[i].set_title(s[i]) 427 | # if i in [5, 6, 7]: # share train and val loss y axes 428 | # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) 429 | except Exception as e: 430 | print('Warning: Plotting error for %s; %s' % (f, e)) 431 | 432 | ax[1].legend() 433 | fig.savefig(Path(save_dir) / 'results.png', dpi=200) 434 | 435 | 436 | def output_to_keypoint(output): 437 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] 438 | targets = [] 439 | for i, o in enumerate(output): 440 | kpts = o[:,6:] 441 | o = o[:,:6] 442 | for index, (*box, conf, cls) in enumerate(o.detach().cpu().numpy()): 443 | targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.detach().cpu().numpy()[index])]) 444 | return np.array(targets) 445 | 446 | 447 | def plot_skeleton_kpts(im, kpts, steps, orig_shape=None): 448 | #Plot the skeleton and keypointsfor coco datatset 449 | palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], 450 | [230, 230, 0], [255, 153, 255], [153, 204, 255], 451 | [255, 102, 255], [255, 51, 255], [102, 178, 255], 452 | [51, 153, 255], [255, 153, 153], [255, 102, 102], 453 | [255, 51, 51], [153, 255, 153], [102, 255, 102], 454 | [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], 455 | [255, 255, 255]]) 456 | 457 | skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], 458 | [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], 459 | [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]] 460 | 461 | pose_limb_color = palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] 462 | pose_kpt_color = palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] 463 | radius = 5 464 | num_kpts = len(kpts) // steps 465 | 466 | for kid in range(num_kpts): 467 | r, g, b = pose_kpt_color[kid] 468 | x_coord, y_coord = kpts[steps * kid], kpts[steps * kid + 1] 469 | if not (x_coord % 640 == 0 or y_coord % 640 == 0): 470 | if steps == 3: 471 | conf = kpts[steps * kid + 2] 472 | if conf < 0.5: 473 | continue 474 | cv2.circle(im, (int(x_coord), int(y_coord)), radius, (int(r), int(g), int(b)), -1) 475 | 476 | for sk_id, sk in enumerate(skeleton): 477 | r, g, b = pose_limb_color[sk_id] 478 | pos1 = (int(kpts[(sk[0]-1)*steps]), int(kpts[(sk[0]-1)*steps+1])) 479 | pos2 = (int(kpts[(sk[1]-1)*steps]), int(kpts[(sk[1]-1)*steps+1])) 480 | if steps == 3: 481 | conf1 = kpts[(sk[0]-1)*steps+2] 482 | conf2 = kpts[(sk[1]-1)*steps+2] 483 | if conf1<0.5 or conf2<0.5: 484 | continue 485 | if pos1[0]%640 == 0 or pos1[1]%640==0 or pos1[0]<0 or pos1[1]<0: 486 | continue 487 | if pos2[0] % 640 == 0 or pos2[1] % 640 == 0 or pos2[0]<0 or pos2[1]<0: 488 | continue 489 | cv2.line(im, pos1, pos2, (int(r), int(g), int(b)), thickness=2) 490 | --------------------------------------------------------------------------------
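# --- Illustrative usage sketch; not a file from this repository ---
# The __main__ block of test.py above simply forwards the parsed arguments to test(),
# so the same evaluation can be driven from Python by mirroring those positional
# arguments. The values below are the argparse defaults shown above; the checkpoint
# path is an assumption (it only exists after a training run has produced it), and the
# image/label paths referenced by data/cbc.yaml must be present on disk.
from test import test

(mp, mr, map50, map_, *val_losses), per_class_ap, times = test(
    'data/cbc.yaml',                    # dataset config (--data default)
    'runs/train/exp/weights/best.pt',   # weights (--weights default, assumed to exist)
    20,                                 # batch size (--batch-size default)
    640,                                # image size (--img-size default)
    0.001,                              # confidence threshold (--conf-thres default)
    0.6)                                # NMS IoU threshold (--iou-thres default)
print('mAP@0.5 = %.4f, mAP@0.5:0.95 = %.4f' % (map50, map_))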
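# --- Illustrative usage sketch; not a file from this repository ---
# How the helpers in utils/plots.py above fit together: output_to_target() converts
# per-image detections given as [x1, y1, x2, y2, conf, cls] rows into the
# [batch_id, class_id, x, y, w, h, conf] rows expected by plot_images(), which then
# draws each box with plot_one_box(). The image, boxes and confidences below are
# made-up placeholder values; the class names follow the order used in the data/*.yaml files.
import torch

from utils.plots import output_to_target, plot_images

images = torch.rand(1, 3, 640, 640)  # one dummy image; values <= 1 get un-normalised to 0-255
detections = [torch.tensor([[100., 120., 220., 260., 0.91, 0.],   # x1, y1, x2, y2, conf, cls
                            [300., 310., 380., 400., 0.78, 1.]])]

targets = output_to_target(detections)        # -> rows of [batch_id, cls, x, y, w, h, conf]
names = ['WBC', 'RBC', 'Platelets']
plot_images(images, targets, fname='preview.jpg', names=names)  # writes the annotated mosaic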
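# --- Illustrative usage sketch; not a file from this repository ---
# plot_results() above aggregates the results*.txt files written during training into a
# single results.png; the run directory below follows the example given in its own
# docstring comment and is an assumption about where training output was saved.
from utils.plots import plot_results

plot_results(save_dir='runs/train/exp')  # expects runs/train/exp/results*.txt to exist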