├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── bisenet_customer.py ├── bisenetv1_ade20k.py ├── bisenetv1_city.py ├── bisenetv1_coco.py ├── bisenetv2_ade20k.py ├── bisenetv2_city.py └── bisenetv2_coco.py ├── datasets ├── ade20k │ ├── annotations │ └── images ├── cityscapes │ ├── gtFine │ ├── leftImg8bit │ ├── train.txt │ └── val.txt └── coco │ ├── images │ ├── train2017 │ └── val2017 │ └── labels │ ├── train2017 │ └── val2017 ├── dist_train.sh ├── example.png ├── lib ├── __init__.py ├── data │ ├── __init__.py │ ├── ade20k.py │ ├── base_dataset.py │ ├── cityscapes_cv2.py │ ├── coco.py │ ├── customer_dataset.py │ ├── get_dataloader.py │ ├── sampler.py │ └── transform_cv2.py ├── logger.py ├── lr_scheduler.py ├── meters.py ├── models │ ├── __init__.py │ ├── bisenetv1.py │ ├── bisenetv2.py │ └── resnet.py └── ohem_ce_loss.py ├── ncnn ├── CMakeLists.txt ├── README.md └── segment.cpp ├── old ├── README.md ├── bisenetv2 │ ├── __init__.py │ ├── bisenetv2.py │ ├── cityscapes_cv2.py │ ├── evaluatev2.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── meters.py │ ├── ohem_ce_loss.py │ ├── sampler.py │ ├── train.py │ └── transform_cv2.py ├── cityscapes.py ├── cityscapes_info.json ├── demo.py ├── diss │ ├── __init__.py │ ├── evaluate.py │ ├── model.py │ └── train.py ├── evaluate.py ├── fp16 │ ├── __init__.py │ ├── evaluate.py │ ├── model.py │ ├── resnet.py │ └── train.py ├── logger.py ├── loss.py ├── model.py ├── modules │ ├── __init__.py │ ├── bn.py │ ├── deeplab.py │ ├── dense.py │ ├── functions.py │ ├── misc.py │ ├── residual.py │ └── src │ │ ├── checks.h │ │ ├── inplace_abn.cpp │ │ ├── inplace_abn.h │ │ ├── inplace_abn_cpu.cpp │ │ ├── inplace_abn_cuda.cu │ │ ├── inplace_abn_cuda_half.cu │ │ └── utils │ │ ├── checks.h │ │ ├── common.h │ │ └── cuda.cuh ├── optimizer.py ├── pic.jpg ├── resnet.py ├── train.py └── transform.py ├── openvino ├── CMakeLists.txt ├── README.md └── main.cpp ├── tensorrt ├── CMakeLists.txt ├── README.md ├── batch_stream.hpp ├── entropy_calibrator.hpp ├── plugins │ ├── CMakeLists.txt │ ├── argmax_plugin.cu │ ├── argmax_plugin.h │ └── kernels.hpp ├── read_img.cpp ├── read_img.hpp ├── segment.cu ├── segment.py ├── trt_dep.cu └── trt_dep.hpp ├── tis ├── README.md ├── client_backend.py ├── client_grpc.py ├── client_http.py ├── cpp_client │ ├── CMakeLists.txt │ └── main.cpp ├── models │ ├── bisenetv1 │ │ └── config.pbtxt │ ├── bisenetv1_model │ │ └── config.pbtxt │ ├── bisenetv2 │ │ └── config.pbtxt │ ├── bisenetv2_model │ │ └── config.pbtxt │ ├── preprocess_cpp │ │ ├── 1 │ │ │ └── .gitkeep │ │ └── config.pbtxt │ └── preprocess_py │ │ ├── 1 │ │ └── model.py │ │ └── config.pbtxt └── self_backend │ ├── CMakeLists.txt │ ├── cmake │ └── TutorialRecommendedBackendConfig.cmake.in │ └── src │ ├── libtriton_recommended.ldscript │ └── recommended.cc ├── tools ├── __init__.py ├── check_dataset_info.py ├── conver_to_trt.py ├── demo.py ├── demo_video.py ├── evaluate.py ├── export_libtorch.py ├── export_onnx.py ├── gen_dataset_annos.py └── train_amp.py └── video.mp4 /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | 27 | # PyInstaller 28 | # Usually these files are 
written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | db.sqlite3 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | 106 | ## Coin: 107 | play.py 108 | preprocess_data.py 109 | res/ 110 | adj.md 111 | tensorrt/build/* 112 | datasets/coco/train.txt 113 | datasets/coco/val.txt 114 | datasets/ade20k/train.txt 115 | datasets/ade20k/val.txt 116 | pretrained/* 117 | run.sh 118 | openvino/build/* 119 | openvino/output* 120 | ncnn/models/* 121 | *.onnx 122 | *.pth 123 | tis/cpp_client/build/* 124 | log*txt 125 | 000000320425.jpg 126 | 127 | tvm/ 128 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 CoinCheung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import importlib 4 | 5 | 6 | class cfg_dict(object): 7 | 8 | def __init__(self, d): 9 | self.__dict__ = d 10 | self.get = d.get 11 | 12 | 13 | def set_cfg_from_file(cfg_path): 14 | spec = importlib.util.spec_from_file_location('cfg_file', cfg_path) 15 | cfg_file = importlib.util.module_from_spec(spec) 16 | spec_loader = spec.loader.exec_module(cfg_file) 17 | cfg = cfg_file.cfg 18 | return cfg_dict(cfg) 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/bisenet_customer.py: -------------------------------------------------------------------------------- 1 | 2 | cfg = dict( 3 | model_type='bisenetv1', 4 | n_cats=20, 5 | num_aux_heads=2, 6 | lr_start=1e-2, 7 | weight_decay=5e-4, 8 | warmup_iters=1000, 9 | max_iter=80000, 10 | dataset='CustomerDataset', 11 | im_root='./datasets/cityscapes', 12 | train_im_anns='./datasets/cityscapes/train.txt', 13 | val_im_anns='./datasets/cityscapes/val.txt', 14 | scales=[0.75, 2.], 15 | cropsize=[512, 512], 16 | eval_crop=[512, 512], 17 | eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 18 | ims_per_gpu=8, 19 | eval_ims_per_gpu=2, 20 | use_fp16=True, 21 | use_sync_bn=False, 22 | respth='./res', 23 | ) 24 | -------------------------------------------------------------------------------- /configs/bisenetv1_ade20k.py: -------------------------------------------------------------------------------- 1 | 2 | cfg = dict( 3 | model_type='bisenetv1', 4 | n_cats=150, 5 | num_aux_heads=2, 6 | lr_start=4e-2, 7 | weight_decay=1e-4, 8 | warmup_iters=1000, 9 | max_iter=40000, 10 | dataset='ADE20k', 11 | im_root='./datasets/ade20k', 12 | train_im_anns='./datasets/ade20k/train.txt', 13 | val_im_anns='./datasets/ade20k/val.txt', 14 | scales=[0.5, 2.], 15 | cropsize=[512, 512], 16 | eval_crop=[512, 512], 17 | eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 18 | eval_start_shortside=512, 19 | ims_per_gpu=8, 20 | eval_ims_per_gpu=1, 21 | use_fp16=True, 22 | use_sync_bn=True, 23 | respth='./res', 24 | ) 25 | -------------------------------------------------------------------------------- /configs/bisenetv1_city.py: -------------------------------------------------------------------------------- 1 | 2 | cfg = dict( 3 | model_type='bisenetv1', 4 | n_cats=19, 5 | num_aux_heads=2, 6 | lr_start=1e-2, 7 | weight_decay=5e-4, 8 | warmup_iters=1000, 9 | max_iter=80000, 10 | dataset='CityScapes', 11 | im_root='./datasets/cityscapes', 12 | train_im_anns='./datasets/cityscapes/train.txt', 13 | val_im_anns='./datasets/cityscapes/val.txt', 14 | scales=[0.75, 2.], 15 | cropsize=[1024, 1024], 16 | eval_crop=[1024, 1024], 17 | eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 18 | ims_per_gpu=8, 19 | eval_ims_per_gpu=2, 20 | use_fp16=True, 21 | use_sync_bn=False, 22 | respth='./res', 23 | ) 24 | -------------------------------------------------------------------------------- /configs/bisenetv1_coco.py: -------------------------------------------------------------------------------- 1 | 2 | cfg = dict( 3 | model_type='bisenetv1', 4 | n_cats=171, 5 | num_aux_heads=2, 6 | lr_start=1e-2, 7 | weight_decay=1e-4, 8 | warmup_iters=1000, 9 | max_iter=90000, 10 | dataset='CocoStuff', 11 | im_root='./datasets/coco', 12 | train_im_anns='./datasets/coco/train.txt', 13 | val_im_anns='./datasets/coco/val.txt', 14 | scales=[0.5, 2.], 15 | cropsize=[512, 512], 16 | 
eval_crop=[512, 512], 17 | eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 18 | ims_per_gpu=4, 19 | eval_ims_per_gpu=1, 20 | use_fp16=True, 21 | use_sync_bn=True, 22 | respth='./res', 23 | ) 24 | -------------------------------------------------------------------------------- /configs/bisenetv2_ade20k.py: -------------------------------------------------------------------------------- 1 | 2 | ## bisenetv2 3 | cfg = dict( 4 | model_type='bisenetv2', 5 | n_cats=150, 6 | num_aux_heads=4, 7 | lr_start=5e-3, 8 | weight_decay=1e-4, 9 | warmup_iters=1000, 10 | max_iter=160000, 11 | dataset='ADE20k', 12 | im_root='./datasets/ade20k', 13 | train_im_anns='./datasets/ade20k/train.txt', 14 | val_im_anns='./datasets/ade20k/val.txt', 15 | scales=[0.5, 2.], 16 | cropsize=[640, 640], 17 | eval_crop=[640, 640], 18 | eval_start_shortside=640, 19 | eval_scales=[0.5, 0.75, 1, 1.25, 1.5, 1.75], 20 | ims_per_gpu=2, 21 | eval_ims_per_gpu=1, 22 | use_fp16=True, 23 | use_sync_bn=True, 24 | respth='./res', 25 | ) 26 | -------------------------------------------------------------------------------- /configs/bisenetv2_city.py: -------------------------------------------------------------------------------- 1 | 2 | ## bisenetv2 3 | cfg = dict( 4 | model_type='bisenetv2', 5 | n_cats=19, 6 | num_aux_heads=4, 7 | lr_start=5e-3, 8 | weight_decay=5e-4, 9 | warmup_iters=1000, 10 | max_iter=150000, 11 | dataset='CityScapes', 12 | im_root='./datasets/cityscapes', 13 | train_im_anns='./datasets/cityscapes/train.txt', 14 | val_im_anns='./datasets/cityscapes/val.txt', 15 | scales=[0.25, 2.], 16 | cropsize=[512, 1024], 17 | eval_crop=[1024, 1024], 18 | eval_scales=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 19 | ims_per_gpu=8, 20 | eval_ims_per_gpu=2, 21 | use_fp16=True, 22 | use_sync_bn=True, 23 | respth='./res', 24 | ) 25 | -------------------------------------------------------------------------------- /configs/bisenetv2_coco.py: -------------------------------------------------------------------------------- 1 | 2 | ## bisenetv2 3 | cfg = dict( 4 | model_type='bisenetv2', 5 | n_cats=171, 6 | num_aux_heads=4, 7 | lr_start=5e-3, 8 | weight_decay=1e-4, 9 | warmup_iters=1000, 10 | max_iter=180000, 11 | dataset='CocoStuff', 12 | im_root='./datasets/coco', 13 | train_im_anns='./datasets/coco/train.txt', 14 | val_im_anns='./datasets/coco/val.txt', 15 | scales=[0.75, 2.], 16 | cropsize=[640, 640], 17 | eval_crop=[640, 640], 18 | eval_scales=[0.5, 0.75, 1, 1.25, 1.5, 1.75], 19 | ims_per_gpu=2, 20 | eval_ims_per_gpu=1, 21 | use_fp16=True, 22 | use_sync_bn=True, 23 | respth='./res', 24 | ) 25 | -------------------------------------------------------------------------------- /datasets/ade20k/annotations: -------------------------------------------------------------------------------- 1 | /data/zzy/.datasets/ADEChallengeData2016/annotations/ -------------------------------------------------------------------------------- /datasets/ade20k/images: -------------------------------------------------------------------------------- 1 | /data/zzy/.datasets/ADEChallengeData2016/images/ -------------------------------------------------------------------------------- /datasets/cityscapes/gtFine: -------------------------------------------------------------------------------- 1 | /data/zzy/zzy/cityscapes/cityscapes/gtFine/ -------------------------------------------------------------------------------- /datasets/cityscapes/leftImg8bit: -------------------------------------------------------------------------------- 1 | 
/data/zzy/zzy/cityscapes/cityscapes/leftImg8bit/
--------------------------------------------------------------------------------
/datasets/coco/images/train2017:
--------------------------------------------------------------------------------
1 | /data/zzy/zzy/coco/images/train2017/
--------------------------------------------------------------------------------
/datasets/coco/images/val2017:
--------------------------------------------------------------------------------
1 | /data/zzy/zzy/coco/images/val2017/
--------------------------------------------------------------------------------
/datasets/coco/labels/train2017:
--------------------------------------------------------------------------------
1 | /data/zzy/zzy/coco/labels/train2017/
--------------------------------------------------------------------------------
/datasets/coco/labels/val2017:
--------------------------------------------------------------------------------
1 | /data/zzy/zzy/coco/labels/val2017/
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
1 | 
2 | # NOTE: replace torchrun with torch.distributed.launch if you use an older version of pytorch.
3 | # I suggest you use the same version as I do, since I have not tested compatibility with older versions after updating.
4 | 
5 | 
6 | 
7 | ## bisenetv1 cityscapes
8 | export CUDA_VISIBLE_DEVICES=0,1
9 | cfg_file=configs/bisenetv1_city.py
10 | NGPUS=2
11 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
12 | 
13 | 
14 | ## bisenetv2 cityscapes
15 | export CUDA_VISIBLE_DEVICES=0,1
16 | cfg_file=configs/bisenetv2_city.py
17 | NGPUS=2
18 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
19 | 
20 | 
21 | ## bisenetv1 cocostuff
22 | export CUDA_VISIBLE_DEVICES=0,1,2,3
23 | cfg_file=configs/bisenetv1_coco.py
24 | NGPUS=4
25 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
26 | 
27 | 
28 | ## bisenetv2 cocostuff
29 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
30 | cfg_file=configs/bisenetv2_coco.py
31 | NGPUS=8
32 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
33 | 
34 | 
35 | ## bisenetv1 ade20k
36 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
37 | cfg_file=configs/bisenetv1_ade20k.py
38 | NGPUS=8
39 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
40 | 
41 | 
42 | ## bisenetv2 ade20k
43 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
44 | cfg_file=configs/bisenetv2_ade20k.py
45 | NGPUS=8
46 | torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
47 | 
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/example.png
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/lib/__init__.py
--------------------------------------------------------------------------------
/lib/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from .get_dataloader import get_data_loader
3 | 
--------------------------------------------------------------------------------
/lib/data/ade20k.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os 5 | import os.path as osp 6 | import json 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import torch.distributed as dist 11 | import cv2 12 | import numpy as np 13 | 14 | import lib.data.transform_cv2 as T 15 | from lib.data.base_dataset import BaseDataset 16 | 17 | ''' 18 | proportion of each class label pixels: 19 | [0.1692778570779725, 0.11564757275917185, 0.0952101638485813, 0.06663867349694136, 0.05213595836428788, 0.04856869977177328, 0.04285300460652723, 0.024667459730413076, 0.021459432596108052, 0.01951911788079975, 0.019458422169334556, 0.017972951662770457, 0.017102797922112795, 0.016127154995430226, 0.012743318904507446, 0.011871312183986243, 0.01169223174996906, 0.010873715499098895, 0.01119535711707017, 0.01106824347921356, 0.010700814956159628, 0.00792769980935508, 0.007320940186670243, 0.007101978087028939, 0.006652130884336369, 0.0065129268341813954, 0.005905601374046595, 0.005655465856321791, 0.00485152244584825, 0.004812313401121428, 0.004808430157907591, 0.004852065319115992, 0.0035166264746248105, 0.0034049293812196796, 0.0031501695661207163, 0.003200865983720736, 0.0027563053654176255, 0.0026019635559833536, 0.002535207367187799, 0.0024709898687369503, 0.002511264681160722, 0.002349575022340693, 0.0022952289072600395, 0.0021756144527500325, 0.0020667410351909894, 20 | 0.002019785482875027, 0.001971430263652598, 0.0019830032929254865, 0.0019170129596070547, 0.0019400873699042965, 0.0019177214046286212, 0.001992758707175458, 0.0019064211898405371, 0.001794991169874655, 0.0017086228805355563, 0.001816450049952539, 0.0018115561530790863, 0.0017526224833158293, 0.0016693853602227783, 0.001690968246884664, 0.001672815290479542, 0.0016435338913693607, 0.0015994805524026869, 0.001415586825791652, 0.0015309535955159497, 0.0015066783881302896, 0.0015584265652761034, 0.0014294452504793305, 0.0014381224963739522, 0.0013854752714941247, 0.001299217899155161, 0.0012526667460881378, 0.0013178209535318454, 0.0012941402888239277, 0.0010893388225083507, 0.0011300189527483507, 0.0010488809855522653, 0.0009206912461167046, 0.0009957668988478528, 0.0009413381127111981, 0.0009365154048026355, 0.0009059601825045681, 0.0008541199189880419, 0.0008971791385063005, 0.0008428502465623139, 0.0008056902958152122, 0.0008098830962054097, 0.0007822564960661871, 0.0007982742428082544, 0.0007502832355158758, 0.0007779780392762995, 0.0007712568824233966, 0.0007453305503359334, 0.0006837047894907241, 0.0007144561259049724, 0.0006892632697976981, 21 | 0.0006652429648347085, 0.0006708271650257716, 0.0006737982709217282, 0.0006266153732017621, 0.0006591083131957701, 0.0006729084088606035, 0.0006615025588342957, 0.0005978453864296776, 0.0005662905332794616, 0.0005832571600309656, 0.000558171776296493, 0.0005270943484946844, 0.0005918616094679417, 0.0005653340750898915, 0.0005626451989934503, 0.0005906185582842337, 0.0005217418569022469, 0.0005282586325333688, 0.0005198277923139954, 0.0004861910064034809, 0.0005218504774841597, 0.0005172358250665335, 0.0005247616468645153, 0.0005357304885031275, 0.0004276964118043196, 0.0004607179872730913, 0.00041193838996318965, 0.00042133234798497776, 0.000374820234027733, 0.00041071531761801536, 0.0003664373889492048, 0.00043033958917813777, 0.00037797413481418125, 0.0004129435322190717, 0.00037504252731164754, 0.0003633328611545351, 0.00039741354470741193, 
0.0003815260048785467, 0.00037395769934345317, 0.00037914990094397704, 0.000360210650939554, 0.0003641708241638368, 0.0003354311501122861, 0.0003386525655944687, 0.0003593692433029189, 0.00034422115014162057, 0.00032131529694189243, 0.00031263024322531515, 0.0003252564098949305, 0.00034751306566322646, 0.0002711341955909471, 0.00022987904222809388, 0.000242549759411221, 0.0002045743505533957] 22 | ''' 23 | 24 | 25 | 26 | class ADE20k(BaseDataset): 27 | 28 | def __init__(self, dataroot, annpath, trans_func=None, mode='train'): 29 | super(ADE20k, self).__init__( 30 | dataroot, annpath, trans_func, mode) 31 | self.n_cats = 150 32 | self.lb_ignore = 255 33 | self.lb_map = np.arange(200) - 1 # label range from 1 to 149, 0 is ignored 34 | self.lb_map[0] = 255 35 | 36 | self.to_tensor = T.ToTensor( 37 | mean=(0.49343230, 0.46819794, 0.43106043), # ade20k, rgb 38 | std=(0.25680755, 0.25506608, 0.27422913), 39 | ) 40 | 41 | -------------------------------------------------------------------------------- /lib/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os 5 | import os.path as osp 6 | import json 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import torch.distributed as dist 11 | import cv2 12 | import numpy as np 13 | 14 | 15 | 16 | 17 | class BaseDataset(Dataset): 18 | ''' 19 | ''' 20 | def __init__(self, dataroot, annpath, trans_func=None, mode='train'): 21 | super(BaseDataset, self).__init__() 22 | assert mode in ('train', 'val', 'test') 23 | self.mode = mode 24 | self.trans_func = trans_func 25 | 26 | self.lb_ignore = -100 27 | self.lb_map = None 28 | 29 | with open(annpath, 'r') as fr: 30 | pairs = fr.read().splitlines() 31 | self.img_paths, self.lb_paths = [], [] 32 | for pair in pairs: 33 | imgpth, lbpth = pair.split(',') 34 | self.img_paths.append(osp.join(dataroot, imgpth)) 35 | self.lb_paths.append(osp.join(dataroot, lbpth)) 36 | 37 | assert len(self.img_paths) == len(self.lb_paths) 38 | self.len = len(self.img_paths) 39 | 40 | def __getitem__(self, idx): 41 | impth, lbpth = self.img_paths[idx], self.lb_paths[idx] 42 | img, label = self.get_image(impth, lbpth) 43 | if not self.lb_map is None: 44 | label = self.lb_map[label] 45 | im_lb = dict(im=img, lb=label) 46 | if not self.trans_func is None: 47 | im_lb = self.trans_func(im_lb) 48 | im_lb = self.to_tensor(im_lb) 49 | img, label = im_lb['im'], im_lb['lb'] 50 | return img.detach(), label.unsqueeze(0).detach() 51 | 52 | def get_image(self, impth, lbpth): 53 | img = cv2.imread(impth)[:, :, ::-1].copy() 54 | label = cv2.imread(lbpth, 0) 55 | return img, label 56 | 57 | def __len__(self): 58 | return self.len 59 | 60 | 61 | if __name__ == "__main__": 62 | from tqdm import tqdm 63 | from torch.utils.data import DataLoader 64 | ds = CityScapes('./data/', mode='val') 65 | dl = DataLoader(ds, 66 | batch_size = 4, 67 | shuffle = True, 68 | num_workers = 4, 69 | drop_last = True) 70 | for imgs, label in dl: 71 | print(len(imgs)) 72 | for el in imgs: 73 | print(el.size()) 74 | break 75 | -------------------------------------------------------------------------------- /lib/data/coco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os 5 | import os.path as osp 6 | import json 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import torch.distributed as dist 11 | 
import cv2 12 | import numpy as np 13 | 14 | import lib.data.transform_cv2 as T 15 | from lib.data.base_dataset import BaseDataset 16 | 17 | ''' 18 | 91(thing) + 91(stuff) = 182 classes, label proportions are: 19 | [0.0901445377, 0.00157896236, 0.00611962763, 0.00494526505, 0.00335260064, 0.00765355955, 0.00772972804, 0.00631509744, 20 | 0.00270457286, 0.000697793344, 0.00114085574, 0.0, 0.00114084131, 0.000705729068, 0.00359758029, 0.00162208938, 0.00598373796, 21 | 0.00440213609, 0.00362085441, 0.00193052224, 0.00271001196, 0.00492864603, 0.00186985393, 0.00332902228, 0.00334420294, 0.0, 22 | 0.000922751106, 0.00298028204, 0.0, 0.0, 0.0010437561, 0.000285608411, 0.00318569535, 0.000314216755, 0.000313060076, 0.000364755975, 23 | 0.000135920434, 0.000678980469, 0.000145436185, 0.000187677684, 0.000640885889, 0.00121345742, 0.000586313048, 0.00160106929, 0.0, 24 | 0.000887093272, 0.00252332669, 0.000283407598, 0.000423017189, 0.000247005886, 0.00607086751, 0.002264644, 0.00108296684, 0.00299262899, 25 | 0.0013542901, 0.0018255991, 0.000719220519, 0.00127748254, 0.00743539745, 0.0018222117, 0.00368625641, 0.00644224839, 0.00576837542, 26 | 0.00234158491, 0.0102560197, 0.0, 0.0310601945, 0.0, 0.0, 0.00321417022, 0.0, 0.00343909654, 0.00366968441, 0.000223077284, 27 | 0.000549851977, 0.00142833996, 0.000976368198, 0.000932849475, 0.00367802183, 6.33631941e-05, 0.00179415878, 0.00384408865, 0.0, 28 | 0.00178728429, 0.00131955324, 0.00172710316, 0.000355333114, 0.00323052075, 3.45024606e-05, 0.000159319051, 0.0, 0.00233498927, 29 | 0.00115535012, 0.00216354199, 0.00122636929, 0.0297802789, 0.00599919161, 0.00792527951, 0.00446247753, 0.00229155615, 30 | 0.00481623284, 0.00928416394, 0.000292110971, 0.00100709844, 0.0036950065, 0.0238653594, 0.00318962423, 0.000957967243, 0.00491549702, 31 | 0.00305316147, 0.0142686986, 0.00667806178, 0.00940045853, 0.000994700392, 0.00697502858, 0.00163056828, 0.00655119369, 0.00599044442, 32 | 0.00200317424, 0.00546109479, 0.00496814246, 0.00128356119, 0.00893122042, 0.0423373213, 0.00275267517, 0.00730936505, 0.00231434982, 33 | 0.00435102045, 0.00276966794, 0.00141028174, 0.000251683147, 0.00878006131, 0.00357672108, 0.000183633027, 0.00514584856, 34 | 0.000848967739, 0.000662099529, 0.00186883821, 0.00417270686, 0.0224302911, 0.000551947753, 0.00799009014, 0.00379765772, 35 | 0.00226731642, 0.0181341982, 0.000835227067, 0.00287355753, 0.00546769461, 0.0242787139, 0.00318951861, 0.00147349686, 36 | 0.00167046288, 0.000520877717, 0.0101631583, 0.0234788756, 0.00283978366, 0.0624405778, 0.00258472693, 0.0204314774, 0.000550128266, 37 | 0.00112924659, 0.001457768, 0.00190406757, 0.00173232644, 0.0116980759, 0.000850599027, 0.00565381261, 0.000787379463, 0.0577763754, 38 | 0.00214883711, 0.00553984356, 0.0443605019, 0.0218570174, 0.0027310644, 0.00225446528, 0.00903008323, 0.00644298871, 0.00442167269, 39 | 0.000129279566, 0.00176047379, 0.0101637834, 0.00255549522] 40 | 41 | 11 thing classes has no annos, proportions are 0: 42 | [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90] 43 | ''' 44 | 45 | 46 | 47 | class CocoStuff(BaseDataset): 48 | 49 | def __init__(self, dataroot, annpath, trans_func=None, mode='train'): 50 | super(CocoStuff, self).__init__( 51 | dataroot, annpath, trans_func, mode) 52 | self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos 53 | self.lb_ignore = 255 54 | 55 | ## label mapping, remove non-existing labels 56 | missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90] 57 | remain = [ind for ind in range(182) if not ind in missing] 58 
| self.lb_map = np.arange(256) 59 | for ind in remain: 60 | self.lb_map[ind] = remain.index(ind) 61 | 62 | self.to_tensor = T.ToTensor( 63 | mean=(0.46962251, 0.4464104, 0.40718787), # coco, rgb 64 | std=(0.27469736, 0.27012361, 0.28515933), 65 | ) 66 | 67 | 68 | -------------------------------------------------------------------------------- /lib/data/customer_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import lib.data.transform_cv2 as T 6 | from lib.data.base_dataset import BaseDataset 7 | 8 | 9 | class CustomerDataset(BaseDataset): 10 | 11 | def __init__(self, dataroot, annpath, trans_func=None, mode='train'): 12 | super(CustomerDataset, self).__init__( 13 | dataroot, annpath, trans_func, mode) 14 | self.lb_ignore = 255 15 | 16 | self.to_tensor = T.ToTensor( 17 | mean=(0.4, 0.4, 0.4), # rgb 18 | std=(0.2, 0.2, 0.2), 19 | ) 20 | 21 | 22 | -------------------------------------------------------------------------------- /lib/data/get_dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torch.distributed as dist 5 | 6 | import lib.data.transform_cv2 as T 7 | from lib.data.sampler import RepeatedDistSampler 8 | 9 | from lib.data.cityscapes_cv2 import CityScapes 10 | from lib.data.coco import CocoStuff 11 | from lib.data.ade20k import ADE20k 12 | from lib.data.customer_dataset import CustomerDataset 13 | 14 | 15 | 16 | 17 | 18 | def get_data_loader(cfg, mode='train'): 19 | if mode == 'train': 20 | trans_func = T.TransformationTrain(cfg.scales, cfg.cropsize) 21 | batchsize = cfg.ims_per_gpu 22 | annpath = cfg.train_im_anns 23 | shuffle = True 24 | drop_last = True 25 | elif mode == 'val': 26 | trans_func = T.TransformationVal() 27 | batchsize = cfg.eval_ims_per_gpu 28 | annpath = cfg.val_im_anns 29 | shuffle = False 30 | drop_last = False 31 | 32 | ds = eval(cfg.dataset)(cfg.im_root, annpath, trans_func=trans_func, mode=mode) 33 | 34 | if dist.is_initialized(): 35 | assert dist.is_available(), "dist should be initialzed" 36 | if mode == 'train': 37 | assert not cfg.max_iter is None 38 | n_train_imgs = cfg.ims_per_gpu * dist.get_world_size() * cfg.max_iter 39 | sampler = RepeatedDistSampler(ds, n_train_imgs, shuffle=shuffle) 40 | else: 41 | sampler = torch.utils.data.distributed.DistributedSampler( 42 | ds, shuffle=shuffle) 43 | batchsampler = torch.utils.data.sampler.BatchSampler( 44 | sampler, batchsize, drop_last=drop_last 45 | ) 46 | dl = DataLoader( 47 | ds, 48 | batch_sampler=batchsampler, 49 | num_workers=4, 50 | pin_memory=True, 51 | ) 52 | else: 53 | dl = DataLoader( 54 | ds, 55 | batch_size=batchsize, 56 | shuffle=shuffle, 57 | drop_last=drop_last, 58 | num_workers=4, 59 | pin_memory=True, 60 | ) 61 | return dl 62 | -------------------------------------------------------------------------------- /lib/data/sampler.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | import torch.distributed as dist 6 | 7 | 8 | class RepeatedDistSampler(Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | 11 | It is especially useful in conjunction with 12 | :class:`torch.nn.parallel.DistributedDataParallel`. 
In such case, each 13 | process can pass a DistributedSampler instance as a DataLoader sampler, 14 | and load a subset of the original dataset that is exclusive to it. 15 | 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | 19 | Arguments: 20 | dataset: Dataset used for sampling. 21 | num_replicas (optional): Number of processes participating in 22 | distributed training. 23 | rank (optional): Rank of the current process within num_replicas. 24 | shuffle (optional): If true (default), sampler will shuffle the indices 25 | """ 26 | 27 | def __init__(self, dataset, num_imgs, num_replicas=None, rank=None, shuffle=True, ba=False): 28 | if num_replicas is None: 29 | if not dist.is_available(): 30 | raise RuntimeError("Requires distributed package to be available") 31 | num_replicas = dist.get_world_size() 32 | if rank is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | rank = dist.get_rank() 36 | self.dataset = dataset 37 | self.num_replicas = num_replicas 38 | self.rank = rank 39 | self.num_imgs_rank = int(math.ceil(num_imgs * 1.0 / self.num_replicas)) 40 | self.total_size = self.num_imgs_rank * self.num_replicas 41 | self.num_imgs = num_imgs 42 | self.shuffle = shuffle 43 | self.ba = ba 44 | 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | n_repeats = self.num_imgs // len(self.dataset) + 1 50 | indices = [] 51 | for n in range(n_repeats): 52 | if self.shuffle: 53 | g.manual_seed(n) 54 | indices += torch.randperm(len(self.dataset), generator=g).tolist() 55 | else: 56 | indices += [i for i in range(len(self.dataset))] 57 | 58 | # add extra samples to make it evenly divisible 59 | indices = indices[:self.total_size] 60 | assert len(indices) == self.total_size 61 | 62 | if self.ba: 63 | n_rep = max(4, self.num_replicas) 64 | len_ind = len(indices) // n_rep + 1 65 | indices = indices[:len_ind] 66 | indices = [ind for ind in indices for _ in range(n_rep)] 67 | 68 | # subsample 69 | indices = indices[self.rank:self.total_size:self.num_replicas] 70 | assert len(indices) == self.num_imgs_rank 71 | 72 | return iter(indices) 73 | 74 | def __len__(self): 75 | return self.num_imgs_rank 76 | 77 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import logging 8 | 9 | import torch.distributed as dist 10 | 11 | 12 | def setup_logger(name, logpth): 13 | logfile = '{}-{}.log'.format(name, time.strftime('%Y-%m-%d-%H-%M-%S')) 14 | logfile = osp.join(logpth, logfile) 15 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 16 | log_level = logging.INFO 17 | if dist.is_initialized() and dist.get_rank() != 0: 18 | log_level = logging.WARNING 19 | try: 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile, force=True) 21 | except Exception: 22 | for hl in logging.root.handlers: logging.root.removeHandler(hl) 23 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 24 | logging.root.addHandler(logging.StreamHandler()) 25 | 26 | 27 | def log_msg(it, max_iter, lr, time_meter, loss_meter, loss_pre_meter, 28 | loss_aux_meters): 29 | t_intv, eta = time_meter.get() 30 | loss_avg, _ = loss_meter.get() 31 | loss_pre_avg, _ = loss_pre_meter.get() 32 | loss_aux_avg = ', '.join(['{}: {:.4f}'.format(el.name, 
el.get()[0]) for el in loss_aux_meters]) 33 | msg = ', '.join([ 34 | f'iter: {it+1}/{max_iter}', 35 | f'lr: {lr:4f}', 36 | f'eta: {eta}', 37 | f'time: {t_intv:.2f}', 38 | f'loss: {loss_avg:.4f}', 39 | f'loss_pre: {loss_pre_avg:.4f}', 40 | ]) 41 | msg += ', ' + loss_aux_avg 42 | 43 | return msg 44 | -------------------------------------------------------------------------------- /lib/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import math 5 | from bisect import bisect_right 6 | import torch 7 | 8 | 9 | class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler): 10 | 11 | def __init__( 12 | self, 13 | optimizer, 14 | warmup_iter=500, 15 | warmup_ratio=5e-4, 16 | warmup='exp', 17 | last_epoch=-1, 18 | ): 19 | self.warmup_iter = warmup_iter 20 | self.warmup_ratio = warmup_ratio 21 | self.warmup = warmup 22 | super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) 23 | 24 | def get_lr(self): 25 | ratio = self.get_lr_ratio() 26 | lrs = [ratio * lr for lr in self.base_lrs] 27 | return lrs 28 | 29 | def get_lr_ratio(self): 30 | if self.last_epoch < self.warmup_iter: 31 | ratio = self.get_warmup_ratio() 32 | else: 33 | ratio = self.get_main_ratio() 34 | return ratio 35 | 36 | def get_main_ratio(self): 37 | raise NotImplementedError 38 | 39 | def get_warmup_ratio(self): 40 | assert self.warmup in ('linear', 'exp') 41 | alpha = self.last_epoch / self.warmup_iter 42 | if self.warmup == 'linear': 43 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 44 | elif self.warmup == 'exp': 45 | ratio = self.warmup_ratio ** (1. - alpha) 46 | return ratio 47 | 48 | 49 | class WarmupPolyLrScheduler(WarmupLrScheduler): 50 | 51 | def __init__( 52 | self, 53 | optimizer, 54 | power, 55 | max_iter, 56 | warmup_iter=500, 57 | warmup_ratio=5e-4, 58 | warmup='exp', 59 | last_epoch=-1, 60 | ): 61 | self.power = power 62 | self.max_iter = max_iter 63 | super(WarmupPolyLrScheduler, self).__init__( 64 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 65 | 66 | def get_main_ratio(self): 67 | real_iter = self.last_epoch - self.warmup_iter 68 | real_max_iter = self.max_iter - self.warmup_iter 69 | alpha = real_iter / real_max_iter 70 | ratio = (1 - alpha) ** self.power 71 | return ratio 72 | 73 | 74 | class WarmupExpLrScheduler(WarmupLrScheduler): 75 | 76 | def __init__( 77 | self, 78 | optimizer, 79 | gamma, 80 | interval=1, 81 | warmup_iter=500, 82 | warmup_ratio=5e-4, 83 | warmup='exp', 84 | last_epoch=-1, 85 | ): 86 | self.gamma = gamma 87 | self.interval = interval 88 | super(WarmupExpLrScheduler, self).__init__( 89 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 90 | 91 | def get_main_ratio(self): 92 | real_iter = self.last_epoch - self.warmup_iter 93 | ratio = self.gamma ** (real_iter // self.interval) 94 | return ratio 95 | 96 | 97 | class WarmupCosineLrScheduler(WarmupLrScheduler): 98 | 99 | def __init__( 100 | self, 101 | optimizer, 102 | max_iter, 103 | eta_ratio=0, 104 | warmup_iter=500, 105 | warmup_ratio=5e-4, 106 | warmup='exp', 107 | last_epoch=-1, 108 | ): 109 | self.eta_ratio = eta_ratio 110 | self.max_iter = max_iter 111 | super(WarmupCosineLrScheduler, self).__init__( 112 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 113 | 114 | def get_main_ratio(self): 115 | real_iter = self.last_epoch - self.warmup_iter 116 | real_max_iter = self.max_iter - self.warmup_iter 117 | return self.eta_ratio + (1 - self.eta_ratio) * ( 118 | 1 + 
math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 119 | 120 | 121 | class WarmupStepLrScheduler(WarmupLrScheduler): 122 | 123 | def __init__( 124 | self, 125 | optimizer, 126 | milestones: list, 127 | gamma=0.1, 128 | warmup_iter=500, 129 | warmup_ratio=5e-4, 130 | warmup='exp', 131 | last_epoch=-1, 132 | ): 133 | self.milestones = milestones 134 | self.gamma = gamma 135 | super(WarmupStepLrScheduler, self).__init__( 136 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 137 | 138 | def get_main_ratio(self): 139 | real_iter = self.last_epoch - self.warmup_iter 140 | ratio = self.gamma ** bisect_right(self.milestones, real_iter) 141 | return ratio 142 | 143 | 144 | if __name__ == "__main__": 145 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 146 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 147 | 148 | max_iter = 20000 149 | lr_scheduler = WarmupPolyLrScheduler(optim, 0.9, max_iter, 200, 0.1, 'linear', -1) 150 | lrs = [] 151 | for _ in range(max_iter): 152 | lr = lr_scheduler.get_lr()[0] 153 | print(lr) 154 | lrs.append(lr) 155 | lr_scheduler.step() 156 | import matplotlib 157 | import matplotlib.pyplot as plt 158 | import numpy as np 159 | lrs = np.array(lrs) 160 | n_lrs = len(lrs) 161 | plt.plot(np.arange(n_lrs), lrs) 162 | plt.grid() 163 | plt.show() 164 | 165 | 166 | -------------------------------------------------------------------------------- /lib/meters.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import datetime 4 | 5 | class TimeMeter(object): 6 | 7 | def __init__(self, max_iter): 8 | self.iter = 0 9 | self.max_iter = max_iter 10 | self.st = time.time() 11 | self.global_st = self.st 12 | self.curr = self.st 13 | 14 | def update(self): 15 | self.iter += 1 16 | 17 | def get(self): 18 | self.curr = time.time() 19 | interv = self.curr - self.st 20 | global_interv = self.curr - self.global_st 21 | eta = int((self.max_iter-self.iter) * (global_interv / (self.iter+1))) 22 | eta = str(datetime.timedelta(seconds=eta)) 23 | self.st = self.curr 24 | return interv, eta 25 | 26 | 27 | class AvgMeter(object): 28 | 29 | def __init__(self, name): 30 | self.name = name 31 | self.seq = [] 32 | self.global_seq = [] 33 | 34 | def update(self, val): 35 | self.seq.append(val) 36 | self.global_seq.append(val) 37 | 38 | def get(self): 39 | avg = sum(self.seq) / len(self.seq) 40 | global_avg = sum(self.global_seq) / len(self.global_seq) 41 | self.seq = [] 42 | return avg, global_avg 43 | 44 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .bisenetv1 import BiSeNetV1 4 | from .bisenetv2 import BiSeNetV2 5 | 6 | 7 | model_factory = { 8 | 'bisenetv1': BiSeNetV1, 9 | 'bisenetv2': BiSeNetV2, 10 | } 11 | -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 10 | 11 | 12 | from torch.nn import BatchNorm2d 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | 
padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | def __init__(self, in_chan, out_chan, stride=1): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(in_chan, out_chan, stride) 25 | self.bn1 = BatchNorm2d(out_chan) 26 | self.conv2 = conv3x3(out_chan, out_chan) 27 | self.bn2 = BatchNorm2d(out_chan) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | if in_chan != out_chan or stride != 1: 31 | self.downsample = nn.Sequential( 32 | nn.Conv2d(in_chan, out_chan, 33 | kernel_size=1, stride=stride, bias=False), 34 | BatchNorm2d(out_chan), 35 | ) 36 | 37 | def forward(self, x): 38 | residual = self.conv1(x) 39 | residual = self.bn1(residual) 40 | residual = self.relu(residual) 41 | residual = self.conv2(residual) 42 | residual = self.bn2(residual) 43 | 44 | shortcut = x 45 | if self.downsample is not None: 46 | shortcut = self.downsample(x) 47 | 48 | out = shortcut + residual 49 | out = self.relu(out) 50 | return out 51 | 52 | 53 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 54 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 55 | for i in range(bnum-1): 56 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 57 | return nn.Sequential(*layers) 58 | 59 | 60 | class Resnet18(nn.Module): 61 | def __init__(self): 62 | super(Resnet18, self).__init__() 63 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 64 | bias=False) 65 | self.bn1 = BatchNorm2d(64) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 68 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 69 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 70 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 71 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 72 | self.init_weight() 73 | 74 | def forward(self, x): 75 | x = self.conv1(x) 76 | x = self.bn1(x) 77 | x = self.relu(x) 78 | x = self.maxpool(x) 79 | 80 | x = self.layer1(x) 81 | feat8 = self.layer2(x) # 1/8 82 | feat16 = self.layer3(feat8) # 1/16 83 | feat32 = self.layer4(feat16) # 1/32 84 | return feat8, feat16, feat32 85 | 86 | def init_weight(self): 87 | state_dict = modelzoo.load_url(resnet18_url) 88 | self_state_dict = self.state_dict() 89 | for k, v in state_dict.items(): 90 | if 'fc' in k: continue 91 | self_state_dict.update({k: v}) 92 | self.load_state_dict(self_state_dict) 93 | 94 | def get_params(self): 95 | wd_params, nowd_params = [], [] 96 | for name, module in self.named_modules(): 97 | if isinstance(module, (nn.Linear, nn.Conv2d)): 98 | wd_params.append(module.weight) 99 | if not module.bias is None: 100 | nowd_params.append(module.bias) 101 | elif isinstance(module, nn.modules.batchnorm._BatchNorm): 102 | nowd_params += list(module.parameters()) 103 | return wd_params, nowd_params 104 | 105 | 106 | if __name__ == "__main__": 107 | net = Resnet18() 108 | x = torch.randn(16, 3, 224, 224) 109 | out = net(x) 110 | print(out[0].size()) 111 | print(out[1].size()) 112 | print(out[2].size()) 113 | net.get_params() 114 | -------------------------------------------------------------------------------- /lib/ohem_ce_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | # import ohem_cpp 11 | # class OhemCELoss(nn.Module): 12 | # 13 | # def __init__(self, thresh, lb_ignore=255): 14 | # 
super(OhemCELoss, self).__init__()
15 | #         self.score_thresh = thresh
16 | #         self.lb_ignore = lb_ignore
17 | #         self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='mean')
18 | #
19 | #     def forward(self, logits, labels):
20 | #         n_min = labels[labels != self.lb_ignore].numel() // 16
21 | #         labels = ohem_cpp.score_ohem_label(
22 | #                 logits, labels, self.lb_ignore, self.score_thresh, n_min).detach()
23 | #         loss = self.criteria(logits, labels)
24 | #         return loss
25 | 
26 | 
27 | class OhemCELoss(nn.Module):
28 | 
29 |     def __init__(self, thresh, lb_ignore=255):
30 |         super(OhemCELoss, self).__init__()
31 |         self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda()
32 |         self.lb_ignore = lb_ignore
33 |         self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
34 | 
35 |     def forward(self, logits, labels):
36 |         n_min = labels[labels != self.lb_ignore].numel() // 16
37 |         loss = self.criteria(logits, labels).view(-1)
38 |         loss_hard = loss[loss > self.thresh]
39 |         if loss_hard.numel() < n_min:
40 |             loss_hard, _ = loss.topk(n_min)
41 |         return torch.mean(loss_hard)
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     pass
46 | 
47 | 
--------------------------------------------------------------------------------
/ncnn/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | CMAKE_MINIMUM_REQUIRED(VERSION 3.15)
2 | 
3 | PROJECT(segment)
4 | 
5 | set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
6 | 
7 | 
8 | set (ncnn_DIR ${NCNN_ROOT}/lib/cmake/ncnn)
9 | find_package(OpenCV REQUIRED)
10 | find_package(OpenMP REQUIRED)
11 | find_package(ncnn REQUIRED)
12 | 
13 | 
14 | add_executable(segment segment.cpp)
15 | target_include_directories(segment PUBLIC ${OpenCV_INCLUDE_DIRS})
16 | target_link_libraries(segment ${OpenCV_LIBRARIES} ncnn OpenMP::OpenMP_CXX)
17 | 
--------------------------------------------------------------------------------
/ncnn/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ### My platform
3 | 
4 | * raspberry pi 3b
5 | * 2022-04-04-raspios-bullseye-armhf-lite.img
6 | * cpu: 4 core armv8, memory: 1G
7 | 
8 | 
9 | 
10 | ### Install ncnn
11 | 
12 | Just follow the ncnn official tutorial of [build-for-linux](https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-linux) to install ncnn. The following steps are all carried out on my raspberry pi:
13 | 
14 | **step 1:** install dependencies
15 | ```
16 | $ sudo apt install build-essential git cmake libprotobuf-dev protobuf-compiler libopencv-dev
17 | ```
18 | 
19 | **step 2:** (optional) install vulkan
20 | 
21 | **step 3:** build
22 | I am using commit `6869c81ed3e7170dc0`, and I have not tested other commits.
23 | ```
24 | $ git clone https://github.com/Tencent/ncnn.git
25 | $ cd ncnn
26 | $ git reset --hard 6869c81ed3e7170dc0
27 | $ git submodule update --init
28 | $ mkdir -p build && cd build
29 | $ cmake -DCMAKE_BUILD_TYPE=Release -DNCNN_VULKAN=OFF -DNCNN_BUILD_TOOLS=ON -DCMAKE_TOOLCHAIN_FILE=../toolchains/pi3.toolchain.cmake ..
30 | $ make -j2
31 | $ make install
32 | ```
33 | 
34 | ### Convert pytorch model to ncnn model
35 | 
36 | #### 1. dependencies
37 | ```
38 | $ python -m pip install onnx-simplifier
39 | ```
40 | 
41 | #### 2.
convert pytorch model to ncnn model via onnx 42 | On your training platform: 43 | ``` 44 | $ cd BiSeNet/ 45 | $ python tools/export_onnx.py --aux-mode eval --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model_v2.onnx 46 | $ python -m onnxsim model_v2.onnx model_v2_sim.onnx 47 | ``` 48 | 49 | Then copy your `model_v2_sim.onnx` from training platform to raspberry device. 50 | 51 | On raspberry device: 52 | ``` 53 | $ /path/to/ncnn/build/tools/onnx/onnx2ncnn model_v2_sim.onnx model_v2_sim.param model_v2_sim.bin 54 | ``` 55 | 56 | You can optimize the ncnn model by fusing the layers and save the weights with fp16 datatype. 57 | On raspberry device: 58 | ``` 59 | $ /path/to/ncnn/build/tools/ncnnoptimize model_v2_sim.param model_v2_sim.bin model_v2_sim_opt.param model_v2_sim_opt.bin 65536 60 | $ mv model_v2_sim_opt.param model_v2_sim.param 61 | $ mv model_v2_sim_opt.bin model_v2_sim.bin 62 | ``` 63 | 64 | You can also quantize the model for int8 inference, following this [tutorial](https://github.com/Tencent/ncnn/wiki/quantized-int8-inference). Make sure your device support int8 inference. 65 | 66 | 67 | ### build and run the demo 68 | #### 1. compile demo code 69 | On raspberry device: 70 | ``` 71 | $ mkdir -p BiSeNet/ncnn/build 72 | $ cd BiSeNet/ncnn/build 73 | $ cmake .. -DNCNN_ROOT=/path/to/ncnn/build/install 74 | $ make 75 | ``` 76 | 77 | #### 2. run demo 78 | ``` 79 | ./segment 80 | ``` 81 | -------------------------------------------------------------------------------- /ncnn/segment.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "net.h" 3 | #include "mat.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | 18 | using std::string; 19 | using std::vector; 20 | using cv::Mat; 21 | 22 | 23 | vector> get_color_map(); 24 | void inference(); 25 | 26 | 27 | int main(int argc, char** argv) { 28 | inference(); 29 | return 0; 30 | } 31 | 32 | 33 | void inference() { 34 | int nthreads = 4; 35 | string mod_param = "../models/model_v2_sim.param"; 36 | string mod_model = "../models/model_v2_sim.bin"; 37 | int oH{512}, oW{1024}, n_classes{19}; 38 | float mean[3] = {0.3257f, 0.3690f, 0.3223f}; 39 | float var[3] = {0.2112f, 0.2148f, 0.2115f}; 40 | string impth = "../../example.png"; 41 | string savepth = "out.png"; 42 | 43 | // load model 44 | ncnn::Net mod; 45 | #if NCNN_VULKAN 46 | int gpu_count = ncnn::get_gpu_count(); 47 | if (gpu_count <= 0) { 48 | fprintf(stderr, "we do not have gpu device\n"); 49 | return; 50 | } 51 | mod.opt.use_vulkan_compute = 1; 52 | mod.set_vulkan_device(1); 53 | #endif 54 | //// switch off fp16 55 | // bool use_fp16 = false; 56 | // mod.opt.use_fp16_packed = use_fp16; 57 | // mod.opt.use_fp16_storage = use_fp16; 58 | // mod.opt.use_fp16_arithmetic = use_fp16; 59 | //// switch on bf16 60 | // mod.opt.use_packing_layout = true; 61 | // mod.opt.use_ff16_storage = true; 62 | //// reduce cpu usage 63 | // net.opt.openmp_blocktime = 0; 64 | mod.opt.use_winograd_convolution = true; 65 | 66 | // we should set opt before load model 67 | mod.load_param(mod_param.c_str()); 68 | mod.load_model(mod_model.c_str()); 69 | 70 | // load image, and copy to ncnn mat 71 | cv::Mat im = cv::imread(impth); 72 | if (im.empty()) { 73 | fprintf(stderr, "cv::imread failed\n"); 74 | return; 75 | } 76 | 77 | ncnn::Mat inp = ncnn::Mat::from_pixels_resize( 78 | im.data, ncnn::Mat::PIXEL_BGR, im.cols, im.rows, oW, 
oH); 79 | for (float &el : mean) el *= 255.; 80 | for (float &el : var) el = 1. / (255. * el); 81 | inp.substract_mean_normalize(mean, var); 82 | 83 | // set input, run, get output 84 | ncnn::Extractor ex = mod.create_extractor(); 85 | ex.set_light_mode(true); 86 | ex.set_num_threads(nthreads); 87 | #if NCNN_VULKAN 88 | ex.set_vulkan_compute(true); 89 | #endif 90 | 91 | ex.input("input_image", inp); 92 | ncnn::Mat out; 93 | ex.extract("preds", out); // output is nchw, as onnx, where here n=1 94 | 95 | // generate colorful output, and dump 96 | vector> color_map = get_color_map(); 97 | Mat pred(cv::Size(oW, oH), CV_8UC3); 98 | int offset = oH * oW; 99 | omp_set_num_threads(omp_get_max_threads()); 100 | #pragma omp parallel for 101 | for (int i=0; i < oH; ++i) { 102 | uint8_t *ptr = pred.ptr(i); 103 | for (int j{0}; j < oW; ++j) { 104 | // compute argmax 105 | int idx, argmax{0}; 106 | float max; 107 | idx = i * oW + j; 108 | max = out[idx]; 109 | for (int k{1}; k < n_classes; ++k) { 110 | idx += offset; 111 | if (max < out[idx]) { 112 | max = out[idx]; 113 | argmax = k; 114 | } 115 | } 116 | // color the result 117 | ptr[0] = color_map[argmax][0]; 118 | ptr[1] = color_map[argmax][1]; 119 | ptr[2] = color_map[argmax][2]; 120 | ptr += 3; 121 | } 122 | } 123 | cv::imwrite(savepth, pred); 124 | 125 | ex.clear(); // must have this, or error 126 | mod.clear(); 127 | 128 | } 129 | 130 | 131 | vector> get_color_map() { 132 | vector> color_map(256, vector(3)); 133 | std::minstd_rand rand_eng(123); 134 | std::uniform_int_distribution u(0, 255); 135 | for (int i{0}; i < 256; ++i) { 136 | for (int j{0}; j < 3; ++j) { 137 | color_map[i][j] = u(rand_eng); 138 | } 139 | } 140 | return color_map; 141 | } 142 | -------------------------------------------------------------------------------- /old/bisenetv2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/old/bisenetv2/__init__.py -------------------------------------------------------------------------------- /old/bisenetv2/evaluatev2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import sys 5 | sys.path.insert(0, '.') 6 | import os 7 | import os.path as osp 8 | import logging 9 | import argparse 10 | import math 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | import cv2 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.distributed as dist 20 | 21 | from bisenetv2.bisenetv2 import BiSeNetV2 22 | from bisenetv2.logger import setup_logger 23 | from bisenetv2.cityscapes_cv2 import get_data_loader 24 | 25 | 26 | 27 | 28 | class MscEvalV0(object): 29 | 30 | def __init__(self, ignore_label=255): 31 | self.ignore_label = ignore_label 32 | 33 | def __call__(self, net, dl, n_classes): 34 | ## evaluate 35 | hist = torch.zeros(n_classes, n_classes).cuda().detach() 36 | if dist.is_initialized() and dist.get_rank() != 0: 37 | diter = enumerate(dl) 38 | else: 39 | diter = enumerate(tqdm(dl)) 40 | for i, (imgs, label) in diter: 41 | N, _, H, W = label.shape 42 | label = label.squeeze(1).cuda() 43 | size = label.size()[-2:] 44 | imgs = imgs.cuda() 45 | logits = net(imgs)[0] 46 | logits = F.interpolate(logits, size=size, 47 | mode='bilinear', align_corners=True) 48 | probs = torch.softmax(logits, dim=1) 49 | preds = torch.argmax(probs, dim=1) 50 | keep = label != 
self.ignore_label 51 | hist += torch.bincount( 52 | label[keep] * n_classes + preds[keep], 53 | minlength=n_classes ** 2 54 | ).view(n_classes, n_classes) 55 | if dist.is_initialized(): 56 | dist.all_reduce(hist, dist.ReduceOp.SUM) 57 | ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag()) 58 | miou = ious.mean() 59 | return miou.item() 60 | 61 | 62 | 63 | def eval_model(net, ims_per_gpu): 64 | is_dist = dist.is_initialized() 65 | dl = get_data_loader('./data', ims_per_gpu, mode='val', distributed=is_dist) 66 | net.eval() 67 | 68 | with torch.no_grad(): 69 | single_scale = MscEvalV0() 70 | mIOU = single_scale(net, dl, 19) 71 | logger = logging.getLogger() 72 | logger.info('mIOU is: %s\n', mIOU) 73 | 74 | 75 | def evaluate(weight_pth): 76 | logger = logging.getLogger() 77 | 78 | ## model 79 | logger.info('setup and restore model') 80 | net = BiSeNetV2(19) 81 | net.load_state_dict(torch.load(weight_pth)) 82 | net.cuda() 83 | 84 | is_dist = dist.is_initialized() 85 | if is_dist: 86 | local_rank = dist.get_rank() 87 | net = nn.parallel.DistributedDataParallel( 88 | net, 89 | device_ids=[local_rank, ], 90 | output_device=local_rank 91 | ) 92 | 93 | ## evaluator 94 | eval_model(net, 2) 95 | 96 | 97 | def parse_args(): 98 | parse = argparse.ArgumentParser() 99 | parse.add_argument('--local_rank', dest='local_rank', 100 | type=int, default=-1,) 101 | parse.add_argument('--weight-path', dest='weight_pth', type=str, 102 | default='model_final.pth',) 103 | parse.add_argument('--port', dest='port', type=int, default=44553,) 104 | parse.add_argument('--respth', dest='respth', type=str, default='./res',) 105 | return parse.parse_args() 106 | 107 | 108 | def main(): 109 | args = parse_args() 110 | if not args.local_rank == -1: 111 | torch.cuda.set_device(args.local_rank) 112 | dist.init_process_group(backend='nccl', 113 | init_method='tcp://127.0.0.1:{}'.format(args.port), 114 | world_size=torch.cuda.device_count(), 115 | rank=args.local_rank 116 | ) 117 | if not osp.exists(args.respth): os.makedirs(args.respth) 118 | setup_logger('BiSeNetV2-eval', args.respth) 119 | evaluate(args.weight_pth) 120 | 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /old/bisenetv2/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import logging 8 | 9 | import torch.distributed as dist 10 | 11 | 12 | def setup_logger(name, logpth): 13 | logfile = '{}-{}.log'.format(name, time.strftime('%Y-%m-%d-%H-%M-%S')) 14 | logfile = osp.join(logpth, logfile) 15 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 16 | log_level = logging.INFO 17 | if dist.is_initialized() and dist.get_rank() != 0: 18 | log_level = logging.WARNING 19 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 20 | logging.root.addHandler(logging.StreamHandler()) 21 | 22 | 23 | def print_log_msg(it, max_iter, lr, time_meter, loss_meter, loss_pre_meter, 24 | loss_aux_meters): 25 | t_intv, eta = time_meter.get() 26 | loss_avg, _ = loss_meter.get() 27 | loss_pre_avg, _ = loss_pre_meter.get() 28 | loss_aux_avg = ', '.join(['{}: {:.4f}'.format(el.name, el.get()[0]) for el in loss_aux_meters]) 29 | msg = ', '.join([ 30 | 'iter: {it}/{max_it}', 31 | 'lr: {lr:4f}', 32 | 'eta: {eta}', 33 | 'time: {time:.2f}', 34 | 'loss: {loss:.4f}', 35 | 'loss_pre: {loss_pre:.4f}', 36 | ]).format( 37 
| it=it+1, 38 | max_it=max_iter, 39 | lr=lr, 40 | time=t_intv, 41 | eta=eta, 42 | loss=loss_avg, 43 | loss_pre=loss_pre_avg, 44 | ) 45 | msg += ', ' + loss_aux_avg 46 | logger = logging.getLogger() 47 | logger.info(msg) 48 | -------------------------------------------------------------------------------- /old/bisenetv2/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import math 5 | from bisect import bisect_right 6 | import torch 7 | 8 | 9 | class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler): 10 | 11 | def __init__( 12 | self, 13 | optimizer, 14 | warmup_iter=500, 15 | warmup_ratio=5e-4, 16 | warmup='exp', 17 | last_epoch=-1, 18 | ): 19 | self.warmup_iter = warmup_iter 20 | self.warmup_ratio = warmup_ratio 21 | self.warmup = warmup 22 | super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) 23 | 24 | def get_lr(self): 25 | ratio = self.get_lr_ratio() 26 | lrs = [ratio * lr for lr in self.base_lrs] 27 | return lrs 28 | 29 | def get_lr_ratio(self): 30 | if self.last_epoch < self.warmup_iter: 31 | ratio = self.get_warmup_ratio() 32 | else: 33 | ratio = self.get_main_ratio() 34 | return ratio 35 | 36 | def get_main_ratio(self): 37 | raise NotImplementedError 38 | 39 | def get_warmup_ratio(self): 40 | assert self.warmup in ('linear', 'exp') 41 | alpha = self.last_epoch / self.warmup_iter 42 | if self.warmup == 'linear': 43 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 44 | elif self.warmup == 'exp': 45 | ratio = self.warmup_ratio ** (1. - alpha) 46 | return ratio 47 | 48 | 49 | class WarmupPolyLrScheduler(WarmupLrScheduler): 50 | 51 | def __init__( 52 | self, 53 | optimizer, 54 | power, 55 | max_iter, 56 | warmup_iter=500, 57 | warmup_ratio=5e-4, 58 | warmup='exp', 59 | last_epoch=-1, 60 | ): 61 | self.power = power 62 | self.max_iter = max_iter 63 | super(WarmupPolyLrScheduler, self).__init__( 64 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 65 | 66 | def get_main_ratio(self): 67 | real_iter = self.last_epoch - self.warmup_iter 68 | real_max_iter = self.max_iter - self.warmup_iter 69 | alpha = real_iter / real_max_iter 70 | ratio = (1 - alpha) ** self.power 71 | return ratio 72 | 73 | 74 | class WarmupExpLrScheduler(WarmupLrScheduler): 75 | 76 | def __init__( 77 | self, 78 | optimizer, 79 | gamma, 80 | interval=1, 81 | warmup_iter=500, 82 | warmup_ratio=5e-4, 83 | warmup='exp', 84 | last_epoch=-1, 85 | ): 86 | self.gamma = gamma 87 | self.interval = interval 88 | super(WarmupExpLrScheduler, self).__init__( 89 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 90 | 91 | def get_main_ratio(self): 92 | real_iter = self.last_epoch - self.warmup_iter 93 | ratio = self.gamma ** (real_iter // self.interval) 94 | return ratio 95 | 96 | 97 | class WarmupCosineLrScheduler(WarmupLrScheduler): 98 | 99 | def __init__( 100 | self, 101 | optimizer, 102 | max_iter, 103 | eta_ratio=0, 104 | warmup_iter=500, 105 | warmup_ratio=5e-4, 106 | warmup='exp', 107 | last_epoch=-1, 108 | ): 109 | self.eta_ratio = eta_ratio 110 | self.max_iter = max_iter 111 | super(WarmupCosineLrScheduler, self).__init__( 112 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 113 | 114 | def get_main_ratio(self): 115 | real_iter = self.last_epoch - self.warmup_iter 116 | real_max_iter = self.max_iter - self.warmup_iter 117 | return self.eta_ratio + (1 - self.eta_ratio) * ( 118 | 1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 119 | 120 | 121 | class 
WarmupStepLrScheduler(WarmupLrScheduler): 122 | 123 | def __init__( 124 | self, 125 | optimizer, 126 | milestones: list, 127 | gamma=0.1, 128 | warmup_iter=500, 129 | warmup_ratio=5e-4, 130 | warmup='exp', 131 | last_epoch=-1, 132 | ): 133 | self.milestones = milestones 134 | self.gamma = gamma 135 | super(WarmupStepLrScheduler, self).__init__( 136 | optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 137 | 138 | def get_main_ratio(self): 139 | real_iter = self.last_epoch - self.warmup_iter 140 | ratio = self.gamma ** bisect_right(self.milestones, real_iter) 141 | return ratio 142 | 143 | 144 | if __name__ == "__main__": 145 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 146 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 147 | 148 | max_iter = 20000 149 | lr_scheduler = WarmupPolyLrScheduler(optim, 0.9, max_iter, 200, 0.1, 'linear', -1) 150 | lrs = [] 151 | for _ in range(max_iter): 152 | lr = lr_scheduler.get_lr()[0] 153 | print(lr) 154 | lrs.append(lr) 155 | lr_scheduler.step() 156 | import matplotlib 157 | import matplotlib.pyplot as plt 158 | import numpy as np 159 | lrs = np.array(lrs) 160 | n_lrs = len(lrs) 161 | plt.plot(np.arange(n_lrs), lrs) 162 | plt.grid() 163 | plt.show() 164 | 165 | 166 | -------------------------------------------------------------------------------- /old/bisenetv2/meters.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import datetime 4 | 5 | class TimeMeter(object): 6 | 7 | def __init__(self, max_iter): 8 | self.iter = 0 9 | self.max_iter = max_iter 10 | self.st = time.time() 11 | self.global_st = self.st 12 | self.curr = self.st 13 | 14 | def update(self): 15 | self.iter += 1 16 | 17 | def get(self): 18 | self.curr = time.time() 19 | interv = self.curr - self.st 20 | global_interv = self.curr - self.global_st 21 | eta = int((self.max_iter-self.iter) * (global_interv / (self.iter+1))) 22 | eta = str(datetime.timedelta(seconds=eta)) 23 | self.st = self.curr 24 | return interv, eta 25 | 26 | 27 | class AvgMeter(object): 28 | 29 | def __init__(self, name): 30 | self.name = name 31 | self.seq = [] 32 | self.global_seq = [] 33 | 34 | def update(self, val): 35 | self.seq.append(val) 36 | self.global_seq.append(val) 37 | 38 | def get(self): 39 | avg = sum(self.seq) / len(self.seq) 40 | global_avg = sum(self.global_seq) / len(self.global_seq) 41 | self.seq = [] 42 | return avg, global_avg 43 | 44 | -------------------------------------------------------------------------------- /old/bisenetv2/ohem_ce_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | 11 | class OhemCELoss(nn.Module): 12 | 13 | def __init__(self, thresh, ignore_lb=255): 14 | super(OhemCELoss, self).__init__() 15 | self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda() 16 | self.ignore_lb = ignore_lb 17 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 18 | 19 | def forward(self, logits, labels): 20 | n_min = labels[labels != self.ignore_lb].numel() // 16 21 | loss = self.criteria(logits, labels).view(-1) 22 | loss_hard = loss[loss > self.thresh] 23 | if loss_hard.numel() < n_min: 24 | loss_hard, _ = loss.topk(n_min) 25 | return torch.mean(loss_hard) 26 | 27 | 28 | if __name__ == '__main__': 29 | pass 30 | # criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 31 | # 
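    # A rough usage sketch for the OhemCELoss defined above (sizes are made up for
    # illustration; a CUDA device is assumed since the threshold is created with
    # .cuda()): only per-pixel losses above -log(thresh) are kept, but never fewer
    # than n_min = n_valid_pixels // 16 of the hardest ones, so easy, already
    # well-classified pixels do not dominate the gradient. Note that the commented
    # call below still passes an n_min argument which this version no longer
    # accepts; n_min is now derived from the labels inside forward().
    # criteria = OhemCELoss(thresh=0.7).cuda()
    # logits = torch.randn(2, 19, 64, 64).cuda()   # N x C x H x W class scores
    # labels = torch.randint(0, 19, (2, 64, 64)).cuda()
    # loss = criteria(logits, labels)              # scalar mean over the hard pixels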
criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 32 | 33 | -------------------------------------------------------------------------------- /old/bisenetv2/sampler.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | import torch.distributed as dist 6 | 7 | 8 | class RepeatedDistSampler(Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | 11 | It is especially useful in conjunction with 12 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 13 | process can pass a DistributedSampler instance as a DataLoader sampler, 14 | and load a subset of the original dataset that is exclusive to it. 15 | 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | 19 | Arguments: 20 | dataset: Dataset used for sampling. 21 | num_replicas (optional): Number of processes participating in 22 | distributed training. 23 | rank (optional): Rank of the current process within num_replicas. 24 | shuffle (optional): If true (default), sampler will shuffle the indices 25 | """ 26 | 27 | def __init__(self, dataset, num_imgs, num_replicas=None, rank=None, shuffle=True): 28 | if num_replicas is None: 29 | if not dist.is_available(): 30 | raise RuntimeError("Requires distributed package to be available") 31 | num_replicas = dist.get_world_size() 32 | if rank is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | rank = dist.get_rank() 36 | self.dataset = dataset 37 | self.num_replicas = num_replicas 38 | self.rank = rank 39 | self.num_imgs_rank = int(math.ceil(num_imgs * 1.0 / self.num_replicas)) 40 | self.total_size = self.num_imgs_rank * self.num_replicas 41 | self.num_imgs = num_imgs 42 | self.shuffle = shuffle 43 | 44 | 45 | def __iter__(self): 46 | # deterministically shuffle based on epoch 47 | g = torch.Generator() 48 | n_repeats = self.num_imgs // len(self.dataset) + 1 49 | indices = [] 50 | for n in range(n_repeats): 51 | if self.shuffle: 52 | g.manual_seed(n) 53 | indices += torch.randperm(len(self.dataset), generator=g).tolist() 54 | else: 55 | indices += [i for i in range(len(self.dataset))] 56 | 57 | # add extra samples to make it evenly divisible 58 | indices = indices[:self.total_size] 59 | assert len(indices) == self.total_size 60 | 61 | # subsample 62 | indices = indices[self.rank:self.total_size:self.num_replicas] 63 | assert len(indices) == self.num_imgs_rank 64 | 65 | return iter(indices) 66 | 67 | def __len__(self): 68 | return self.num_imgs_rank 69 | 70 | -------------------------------------------------------------------------------- /old/cityscapes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | import os.path as osp 10 | import os 11 | from PIL import Image 12 | import numpy as np 13 | import json 14 | 15 | from transform import * 16 | 17 | 18 | 19 | class CityScapes(Dataset): 20 | def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs): 21 | super(CityScapes, self).__init__(*args, **kwargs) 22 | assert mode in ('train', 'val', 'test') 23 | self.mode = mode 24 | self.ignore_lb = 255 25 | 26 | with open('./cityscapes_info.json', 'r') as fr: 27 | labels_info = json.load(fr) 28 | self.lb_map = {el['id']: 
el['trainId'] for el in labels_info} 29 | 30 | ## parse img directory 31 | self.imgs = {} 32 | imgnames = [] 33 | impth = osp.join(rootpth, 'leftImg8bit', mode) 34 | folders = os.listdir(impth) 35 | for fd in folders: 36 | fdpth = osp.join(impth, fd) 37 | im_names = os.listdir(fdpth) 38 | names = [el.replace('_leftImg8bit.png', '') for el in im_names] 39 | impths = [osp.join(fdpth, el) for el in im_names] 40 | imgnames.extend(names) 41 | self.imgs.update(dict(zip(names, impths))) 42 | 43 | ## parse gt directory 44 | self.labels = {} 45 | gtnames = [] 46 | gtpth = osp.join(rootpth, 'gtFine', mode) 47 | folders = os.listdir(gtpth) 48 | for fd in folders: 49 | fdpth = osp.join(gtpth, fd) 50 | lbnames = os.listdir(fdpth) 51 | lbnames = [el for el in lbnames if 'labelIds' in el] 52 | names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames] 53 | lbpths = [osp.join(fdpth, el) for el in lbnames] 54 | gtnames.extend(names) 55 | self.labels.update(dict(zip(names, lbpths))) 56 | 57 | self.imnames = imgnames 58 | self.len = len(self.imnames) 59 | assert set(imgnames) == set(gtnames) 60 | assert set(self.imnames) == set(self.imgs.keys()) 61 | assert set(self.imnames) == set(self.labels.keys()) 62 | 63 | ## pre-processing 64 | self.to_tensor = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 67 | ]) 68 | self.trans_train = Compose([ 69 | ColorJitter( 70 | brightness = 0.5, 71 | contrast = 0.5, 72 | saturation = 0.5), 73 | HorizontalFlip(), 74 | RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)), 75 | RandomCrop(cropsize) 76 | ]) 77 | 78 | 79 | def __getitem__(self, idx): 80 | fn = self.imnames[idx] 81 | impth = self.imgs[fn] 82 | lbpth = self.labels[fn] 83 | img = Image.open(impth).convert('RGB') 84 | label = Image.open(lbpth) 85 | if self.mode == 'train': 86 | im_lb = dict(im = img, lb = label) 87 | im_lb = self.trans_train(im_lb) 88 | img, label = im_lb['im'], im_lb['lb'] 89 | img = self.to_tensor(img) 90 | label = np.array(label).astype(np.int64)[np.newaxis, :] 91 | label = self.convert_labels(label) 92 | return img, label 93 | 94 | 95 | def __len__(self): 96 | return self.len 97 | 98 | 99 | def convert_labels(self, label): 100 | for k, v in self.lb_map.items(): 101 | label[label == k] = v 102 | return label 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | from tqdm import tqdm 108 | ds = CityScapes('./data/', n_classes=19, mode='val') 109 | uni = [] 110 | for im, lb in tqdm(ds): 111 | lb_uni = np.unique(lb).tolist() 112 | uni.extend(lb_uni) 113 | print(uni) 114 | print(set(uni)) 115 | 116 | -------------------------------------------------------------------------------- /old/demo.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import cv2 8 | 9 | from fp16.model import BiSeNet 10 | 11 | 12 | # args 13 | parse = argparse.ArgumentParser() 14 | parse.add_argument( 15 | '--ckpt', 16 | dest='ckpt', 17 | type=str, 18 | default='./res/model_final.pth',) 19 | parse.add_argument( 20 | '--img_path', 21 | dest='img_path', 22 | type=str, 23 | default='./pic.jpg',) 24 | args = parse.parse_args() 25 | 26 | 27 | # define model 28 | net = BiSeNet(n_classes=19) 29 | net.load_state_dict(torch.load(args.ckpt, map_location='cpu')) 30 | net.eval() 31 | net.cuda() 32 | 33 | # prepare data 34 | to_tensor = transforms.Compose([ 35 | transforms.ToTensor(), 36 | 
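    # The Normalize stats below are the standard ImageNet channel means/stds (the
    # backbone is initialised from ImageNet-pretrained resnet18). Together with
    # ToTensor this is equivalent to, per channel c: x[c] = (pixel[c]/255 - mean[c]) / std[c].
    # A hand-rolled numpy sketch of the same preprocessing, assuming `im_np` is a
    # uint8 HxWx3 RGB array (illustrative only, not part of this script):
    #   im_f = im_np.astype(np.float32) / 255.0
    #   im_f = (im_f - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    #   ten  = torch.from_numpy(im_f.transpose(2, 0, 1))   # CxHxW float tensor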
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 37 | ]) 38 | im = to_tensor(Image.open(args.img_path).convert('RGB')).unsqueeze(0).cuda() 39 | 40 | # inference 41 | out = net(im)[0].argmax(dim=1).squeeze().detach().cpu().numpy() 42 | cv2.imwrite('./res.jpg', out) 43 | -------------------------------------------------------------------------------- /old/diss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/old/diss/__init__.py -------------------------------------------------------------------------------- /old/diss/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import sys 6 | sys.path.insert(0, '.') 7 | from logger import setup_logger 8 | from diss.model import BiSeNet 9 | from cityscapes import CityScapes 10 | from loss import OhemCELoss 11 | from diss.evaluate import evaluate 12 | from optimizer import Optimizer 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.utils.data import DataLoader 17 | import torch.nn.functional as F 18 | import torch.distributed as dist 19 | 20 | import os 21 | import os.path as osp 22 | import logging 23 | import time 24 | import datetime 25 | import argparse 26 | 27 | 28 | respth = './res' 29 | if not osp.exists(respth): os.makedirs(respth) 30 | logger = logging.getLogger() 31 | 32 | 33 | def parse_args(): 34 | parse = argparse.ArgumentParser() 35 | parse.add_argument( 36 | '--local_rank', 37 | dest = 'local_rank', 38 | type = int, 39 | default = -1, 40 | ) 41 | return parse.parse_args() 42 | 43 | 44 | def train(): 45 | args = parse_args() 46 | torch.cuda.set_device(args.local_rank) 47 | dist.init_process_group( 48 | backend = 'nccl', 49 | init_method = 'tcp://127.0.0.1:33241', 50 | world_size = torch.cuda.device_count(), 51 | rank=args.local_rank 52 | ) 53 | setup_logger(respth) 54 | 55 | ## dataset 56 | n_classes = 19 57 | n_img_per_gpu = 8 58 | n_workers = 4 59 | cropsize = [1024, 1024] 60 | ds = CityScapes('./data', cropsize=cropsize, mode='train') 61 | sampler = torch.utils.data.distributed.DistributedSampler(ds) 62 | dl = DataLoader(ds, 63 | batch_size = n_img_per_gpu, 64 | shuffle = False, 65 | sampler = sampler, 66 | num_workers = n_workers, 67 | pin_memory = True, 68 | drop_last = True) 69 | 70 | ## model 71 | ignore_idx = 255 72 | net = BiSeNet(n_classes=n_classes) 73 | net.cuda() 74 | net.train() 75 | net = nn.parallel.DistributedDataParallel(net, 76 | device_ids = [args.local_rank, ], 77 | output_device = args.local_rank 78 | ) 79 | score_thres = 0.7 80 | n_min = n_img_per_gpu*cropsize[0]*cropsize[1]//16 81 | LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 82 | Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 83 | Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 84 | 85 | ## optimizer 86 | momentum = 0.9 87 | weight_decay = 5e-4 88 | lr_start = 1e-2 89 | max_iter = 80000 90 | power = 0.9 91 | warmup_steps = 1000 92 | warmup_start_lr = 1e-5 93 | optim = Optimizer( 94 | model = net.module, 95 | lr0 = lr_start, 96 | momentum = momentum, 97 | wd = weight_decay, 98 | warmup_steps = warmup_steps, 99 | warmup_start_lr = warmup_start_lr, 100 | max_iter = max_iter, 101 | power = power) 102 | 103 | ## train loop 104 | msg_iter = 50 105 | loss_avg = [] 106 | st = glob_st = time.time() 107 | diter = 
iter(dl) 108 | epoch = 0 109 | for it in range(max_iter): 110 | try: 111 | im, lb = next(diter) 112 | if not im.size()[0]==n_img_per_gpu: raise StopIteration 113 | except StopIteration: 114 | epoch += 1 115 | sampler.set_epoch(epoch) 116 | diter = iter(dl) 117 | im, lb = next(diter) 118 | im = im.cuda() 119 | lb = lb.cuda() 120 | H, W = im.size()[2:] 121 | lb = torch.squeeze(lb, 1) 122 | 123 | optim.zero_grad() 124 | out, out16, out32 = net(im) 125 | lossp = LossP(out, lb) 126 | loss2 = Loss2(out16, lb) 127 | loss3 = Loss3(out32, lb) 128 | loss = lossp + loss2 + loss3 129 | loss.backward() 130 | optim.step() 131 | 132 | loss_avg.append(loss.item()) 133 | ## print training log message 134 | if (it+1)%msg_iter==0: 135 | loss_avg = sum(loss_avg) / len(loss_avg) 136 | lr = optim.lr 137 | ed = time.time() 138 | t_intv, glob_t_intv = ed - st, ed - glob_st 139 | eta = int((max_iter - it) * (glob_t_intv / it)) 140 | eta = str(datetime.timedelta(seconds=eta)) 141 | msg = ', '.join([ 142 | 'it: {it}/{max_it}', 143 | 'lr: {lr:4f}', 144 | 'loss: {loss:.4f}', 145 | 'eta: {eta}', 146 | 'time: {time:.4f}', 147 | ]).format( 148 | it = it+1, 149 | max_it = max_iter, 150 | lr = lr, 151 | loss = loss_avg, 152 | time = t_intv, 153 | eta = eta 154 | ) 155 | logger.info(msg) 156 | loss_avg = [] 157 | st = ed 158 | 159 | ## dump the final model 160 | save_pth = osp.join(respth, 'model_final_diss.pth') 161 | net.cpu() 162 | state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() 163 | if dist.get_rank()==0: torch.save(state, save_pth) 164 | logger.info('training done, model saved to: {}'.format(save_pth)) 165 | 166 | 167 | if __name__ == "__main__": 168 | train() 169 | evaluate() 170 | -------------------------------------------------------------------------------- /old/fp16/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/old/fp16/__init__.py -------------------------------------------------------------------------------- /old/fp16/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 10 | 11 | # from torch.nn import BatchNorm2d 12 | def BatchNorm2d(out_chan): 13 | return nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm2d(out_chan)) 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | def __init__(self, in_chan, out_chan, stride=1): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(in_chan, out_chan, stride) 26 | self.bn1 = BatchNorm2d(out_chan) 27 | self.conv2 = conv3x3(out_chan, out_chan) 28 | self.bn2 = BatchNorm2d(out_chan) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | if in_chan != out_chan or stride != 1: 32 | self.downsample = nn.Sequential( 33 | nn.Conv2d(in_chan, out_chan, 34 | kernel_size=1, stride=stride, bias=False), 35 | BatchNorm2d(out_chan), 36 | ) 37 | 38 | def forward(self, x): 39 | residual = self.conv1(x) 40 | residual = self.bn1(residual) 41 | residual = self.relu(residual) 42 | residual = 
self.conv2(residual) 43 | residual = self.bn2(residual) 44 | 45 | shortcut = x 46 | if self.downsample is not None: 47 | shortcut = self.downsample(x) 48 | 49 | out = shortcut + residual 50 | out = self.relu(out) 51 | return out 52 | 53 | 54 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 55 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 56 | for i in range(bnum-1): 57 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 58 | return nn.Sequential(*layers) 59 | 60 | 61 | class Resnet18(nn.Module): 62 | def __init__(self): 63 | super(Resnet18, self).__init__() 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 65 | bias=False) 66 | self.bn1 = BatchNorm2d(64) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 69 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 70 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 71 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 72 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 73 | self.init_weight() 74 | 75 | def forward(self, x): 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | x = self.maxpool(x) 80 | 81 | x = self.layer1(x) 82 | feat8 = self.layer2(x) # 1/8 83 | feat16 = self.layer3(feat8) # 1/16 84 | feat32 = self.layer4(feat16) # 1/32 85 | return feat8, feat16, feat32 86 | 87 | def init_weight(self): 88 | state_dict = modelzoo.load_url(resnet18_url) 89 | self_state_dict = self.state_dict() 90 | for k, v in state_dict.items(): 91 | if 'fc' in k: continue 92 | self_state_dict.update({k: v}) 93 | self.load_state_dict(self_state_dict) 94 | 95 | def get_params(self): 96 | wd_params, nowd_params = [], [] 97 | for name, module in self.named_modules(): 98 | if isinstance(module, (nn.Linear, nn.Conv2d)): 99 | wd_params.append(module.weight) 100 | if not module.bias is None: 101 | nowd_params.append(module.bias) 102 | elif isinstance(module, nn.modules.batchnorm._BatchNorm): 103 | nowd_params += list(module.parameters()) 104 | return wd_params, nowd_params 105 | 106 | 107 | if __name__ == "__main__": 108 | net = Resnet18() 109 | x = torch.randn(16, 3, 224, 224) 110 | out = net(x) 111 | print(out[0].size()) 112 | print(out[1].size()) 113 | print(out[2].size()) 114 | net.get_params() 115 | -------------------------------------------------------------------------------- /old/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import sys 8 | import logging 9 | 10 | import torch.distributed as dist 11 | 12 | 13 | def setup_logger(logpth): 14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 15 | logfile = osp.join(logpth, logfile) 16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 17 | log_level = logging.INFO 18 | if dist.is_initialized() and not dist.get_rank()==0: 19 | log_level = logging.ERROR 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 21 | logging.root.addHandler(logging.StreamHandler()) 22 | 23 | 24 | -------------------------------------------------------------------------------- /old/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | 12 | class 
OhemCELoss(nn.Module): 13 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): 14 | super(OhemCELoss, self).__init__() 15 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() 16 | self.n_min = n_min 17 | self.ignore_lb = ignore_lb 18 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 19 | 20 | def forward(self, logits, labels): 21 | N, C, H, W = logits.size() 22 | loss = self.criteria(logits, labels).view(-1) 23 | loss, _ = torch.sort(loss, descending=True) 24 | if loss[self.n_min] > self.thresh: 25 | loss = loss[loss>self.thresh] 26 | else: 27 | loss = loss[:self.n_min] 28 | return torch.mean(loss) 29 | 30 | 31 | class SoftmaxFocalLoss(nn.Module): 32 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs): 33 | super(FocalLoss, self).__init__() 34 | self.gamma = gamma 35 | self.nll = nn.NLLLoss(ignore_index=ignore_lb) 36 | 37 | def forward(self, logits, labels): 38 | scores = F.softmax(logits, dim=1) 39 | factor = torch.pow(1.-scores, self.gamma) 40 | log_score = F.log_softmax(logits, dim=1) 41 | log_score = factor * log_score 42 | loss = self.nll(log_score, labels) 43 | return loss 44 | 45 | 46 | if __name__ == '__main__': 47 | torch.manual_seed(15) 48 | criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 49 | criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() 50 | net1 = nn.Sequential( 51 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), 52 | ) 53 | net1.cuda() 54 | net1.train() 55 | net2 = nn.Sequential( 56 | nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), 57 | ) 58 | net2.cuda() 59 | net2.train() 60 | 61 | with torch.no_grad(): 62 | inten = torch.randn(16, 3, 20, 20).cuda() 63 | lbs = torch.randint(0, 19, [16, 20, 20]).cuda() 64 | lbs[1, :, :] = 255 65 | 66 | logits1 = net1(inten) 67 | logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear') 68 | logits2 = net2(inten) 69 | logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear') 70 | 71 | loss1 = criteria1(logits1, lbs) 72 | loss2 = criteria2(logits2, lbs) 73 | loss = loss1 + loss2 74 | print(loss.detach().cpu()) 75 | loss.backward() 76 | -------------------------------------------------------------------------------- /old/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | from .misc import GlobalAvgPool2d, SingleGPU 4 | from .residual import IdentityResidualBlock 5 | from .dense import DenseModule 6 | -------------------------------------------------------------------------------- /old/modules/bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | try: 6 | from queue import Queue 7 | except ImportError: 8 | from Queue import Queue 9 | 10 | from .functions import * 11 | 12 | 13 | class ABN(nn.Module): 14 | """Activated Batch Normalization 15 | 16 | This gathers a `BatchNorm2d` and an activation function in a single module 17 | """ 18 | 19 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 20 | """Creates an Activated Batch Normalization module 21 | 22 | Parameters 23 | ---------- 24 | num_features : int 25 | Number of feature channels in the input and output. 26 | eps : float 27 | Small constant to prevent numerical issues. 
28 | momentum : float 29 | Momentum factor applied to compute running statistics as. 30 | affine : bool 31 | If `True` apply learned scale and shift transformation after normalization. 32 | activation : str 33 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 34 | slope : float 35 | Negative slope for the `leaky_relu` activation. 36 | """ 37 | super(ABN, self).__init__() 38 | self.num_features = num_features 39 | self.affine = affine 40 | self.eps = eps 41 | self.momentum = momentum 42 | self.activation = activation 43 | self.slope = slope 44 | if self.affine: 45 | self.weight = nn.Parameter(torch.ones(num_features)) 46 | self.bias = nn.Parameter(torch.zeros(num_features)) 47 | else: 48 | self.register_parameter('weight', None) 49 | self.register_parameter('bias', None) 50 | self.register_buffer('running_mean', torch.zeros(num_features)) 51 | self.register_buffer('running_var', torch.ones(num_features)) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | nn.init.constant_(self.running_mean, 0) 56 | nn.init.constant_(self.running_var, 1) 57 | if self.affine: 58 | nn.init.constant_(self.weight, 1) 59 | nn.init.constant_(self.bias, 0) 60 | 61 | def forward(self, x): 62 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 63 | self.training, self.momentum, self.eps) 64 | 65 | if self.activation == ACT_RELU: 66 | return functional.relu(x, inplace=True) 67 | elif self.activation == ACT_LEAKY_RELU: 68 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 69 | elif self.activation == ACT_ELU: 70 | return functional.elu(x, inplace=True) 71 | else: 72 | return x 73 | 74 | def __repr__(self): 75 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 76 | ' affine={affine}, activation={activation}' 77 | if self.activation == "leaky_relu": 78 | rep += ', slope={slope})' 79 | else: 80 | rep += ')' 81 | return rep.format(name=self.__class__.__name__, **self.__dict__) 82 | 83 | 84 | class InPlaceABN(ABN): 85 | """InPlace Activated Batch Normalization""" 86 | 87 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 88 | """Creates an InPlace Activated Batch Normalization module 89 | 90 | Parameters 91 | ---------- 92 | num_features : int 93 | Number of feature channels in the input and output. 94 | eps : float 95 | Small constant to prevent numerical issues. 96 | momentum : float 97 | Momentum factor applied to compute running statistics as. 98 | affine : bool 99 | If `True` apply learned scale and shift transformation after normalization. 100 | activation : str 101 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 102 | slope : float 103 | Negative slope for the `leaky_relu` activation. 104 | """ 105 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 106 | 107 | def forward(self, x): 108 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 109 | self.training, self.momentum, self.eps, self.activation, self.slope) 110 | 111 | 112 | class InPlaceABNSync(ABN): 113 | """InPlace Activated Batch Normalization with cross-GPU synchronization 114 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`. 
115 | """ 116 | 117 | def forward(self, x): 118 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 119 | self.training, self.momentum, self.eps, self.activation, self.slope) 120 | 121 | def __repr__(self): 122 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 123 | ' affine={affine}, activation={activation}' 124 | if self.activation == "leaky_relu": 125 | rep += ', slope={slope})' 126 | else: 127 | rep += ')' 128 | return rep.format(name=self.__class__.__name__, **self.__dict__) 129 | 130 | 131 | -------------------------------------------------------------------------------- /old/modules/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | from models._util import try_index 6 | from .bn import ABN 7 | 8 | 9 | class DeeplabV3(nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | hidden_channels=256, 14 | dilations=(12, 24, 36), 15 | norm_act=ABN, 16 | pooling_size=None): 17 | super(DeeplabV3, self).__init__() 18 | self.pooling_size = pooling_size 19 | 20 | self.map_convs = nn.ModuleList([ 21 | nn.Conv2d(in_channels, hidden_channels, 1, bias=False), 22 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]), 23 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]), 24 | nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2]) 25 | ]) 26 | self.map_bn = norm_act(hidden_channels * 4) 27 | 28 | self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False) 29 | self.global_pooling_bn = norm_act(hidden_channels) 30 | 31 | self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) 32 | self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) 33 | self.red_bn = norm_act(out_channels) 34 | 35 | self.reset_parameters(self.map_bn.activation, self.map_bn.slope) 36 | 37 | def reset_parameters(self, activation, slope): 38 | gain = nn.init.calculate_gain(activation, slope) 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | nn.init.xavier_normal_(m.weight.data, gain) 42 | if hasattr(m, "bias") and m.bias is not None: 43 | nn.init.constant_(m.bias, 0) 44 | elif isinstance(m, ABN): 45 | if hasattr(m, "weight") and m.weight is not None: 46 | nn.init.constant_(m.weight, 1) 47 | if hasattr(m, "bias") and m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | 50 | def forward(self, x): 51 | # Map convolutions 52 | out = torch.cat([m(x) for m in self.map_convs], dim=1) 53 | out = self.map_bn(out) 54 | out = self.red_conv(out) 55 | 56 | # Global pooling 57 | pool = self._global_pooling(x) 58 | pool = self.global_pooling_conv(pool) 59 | pool = self.global_pooling_bn(pool) 60 | pool = self.pool_red_conv(pool) 61 | if self.training or self.pooling_size is None: 62 | pool = pool.repeat(1, 1, x.size(2), x.size(3)) 63 | 64 | out += pool 65 | out = self.red_bn(out) 66 | return out 67 | 68 | def _global_pooling(self, x): 69 | if self.training or self.pooling_size is None: 70 | pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) 71 | pool = pool.view(x.size(0), x.size(1), 1, 1) 72 | else: 73 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), 74 | min(try_index(self.pooling_size, 1), x.shape[3])) 75 | padding = ( 76 | (pooling_size[1] - 1) // 2, 77 | (pooling_size[1] - 1) // 2 if 
pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, 78 | (pooling_size[0] - 1) // 2, 79 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 80 | ) 81 | 82 | pool = functional.avg_pool2d(x, pooling_size, stride=1) 83 | pool = functional.pad(pool, pad=padding, mode="replicate") 84 | return pool 85 | -------------------------------------------------------------------------------- /old/modules/dense.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .bn import ABN 7 | 8 | 9 | class DenseModule(nn.Module): 10 | def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): 11 | super(DenseModule, self).__init__() 12 | self.in_channels = in_channels 13 | self.growth = growth 14 | self.layers = layers 15 | 16 | self.convs1 = nn.ModuleList() 17 | self.convs3 = nn.ModuleList() 18 | for i in range(self.layers): 19 | self.convs1.append(nn.Sequential(OrderedDict([ 20 | ("bn", norm_act(in_channels)), 21 | ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) 22 | ]))) 23 | self.convs3.append(nn.Sequential(OrderedDict([ 24 | ("bn", norm_act(self.growth * bottleneck_factor)), 25 | ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, 26 | dilation=dilation)) 27 | ]))) 28 | in_channels += self.growth 29 | 30 | @property 31 | def out_channels(self): 32 | return self.in_channels + self.growth * self.layers 33 | 34 | def forward(self, x): 35 | inputs = [x] 36 | for i in range(self.layers): 37 | x = torch.cat(inputs, dim=1) 38 | x = self.convs1[i](x) 39 | x = self.convs3[i](x) 40 | inputs += [x] 41 | 42 | return torch.cat(inputs, dim=1) 43 | -------------------------------------------------------------------------------- /old/modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class GlobalAvgPool2d(nn.Module): 6 | def __init__(self): 7 | """Global average pooling over the input's spatial dimensions""" 8 | super(GlobalAvgPool2d, self).__init__() 9 | 10 | def forward(self, inputs): 11 | in_size = inputs.size() 12 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 13 | 14 | class SingleGPU(nn.Module): 15 | def __init__(self, module): 16 | super(SingleGPU, self).__init__() 17 | self.module=module 18 | 19 | def forward(self, input): 20 | return self.module(input.cuda(non_blocking=True)) 21 | 22 | -------------------------------------------------------------------------------- /old/modules/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | 5 | from .bn import ABN 6 | 7 | 8 | class IdentityResidualBlock(nn.Module): 9 | def __init__(self, 10 | in_channels, 11 | channels, 12 | stride=1, 13 | dilation=1, 14 | groups=1, 15 | norm_act=ABN, 16 | dropout=None): 17 | """Configurable identity-mapping residual block 18 | 19 | Parameters 20 | ---------- 21 | in_channels : int 22 | Number of input channels. 23 | channels : list of int 24 | Number of channels in the internal feature maps. 
Can either have two or three elements: if three construct 25 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 26 | `3 x 3` then `1 x 1` convolutions. 27 | stride : int 28 | Stride of the first `3 x 3` convolution 29 | dilation : int 30 | Dilation to apply to the `3 x 3` convolutions. 31 | groups : int 32 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 33 | bottleneck blocks. 34 | norm_act : callable 35 | Function to create normalization / activation Module. 36 | dropout: callable 37 | Function to create Dropout Module. 38 | """ 39 | super(IdentityResidualBlock, self).__init__() 40 | 41 | # Check parameters for inconsistencies 42 | if len(channels) != 2 and len(channels) != 3: 43 | raise ValueError("channels must contain either two or three values") 44 | if len(channels) == 2 and groups != 1: 45 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 46 | 47 | is_bottleneck = len(channels) == 3 48 | need_proj_conv = stride != 1 or in_channels != channels[-1] 49 | 50 | self.bn1 = norm_act(in_channels) 51 | if not is_bottleneck: 52 | layers = [ 53 | ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, 54 | dilation=dilation)), 55 | ("bn2", norm_act(channels[0])), 56 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, 57 | dilation=dilation)) 58 | ] 59 | if dropout is not None: 60 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 61 | else: 62 | layers = [ 63 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)), 64 | ("bn2", norm_act(channels[0])), 65 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, 66 | groups=groups, dilation=dilation)), 67 | ("bn3", norm_act(channels[1])), 68 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)) 69 | ] 70 | if dropout is not None: 71 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 72 | self.convs = nn.Sequential(OrderedDict(layers)) 73 | 74 | if need_proj_conv: 75 | self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 76 | 77 | def forward(self, x): 78 | if hasattr(self, "proj_conv"): 79 | bn1 = self.bn1(x) 80 | shortcut = self.proj_conv(bn1) 81 | else: 82 | shortcut = x.clone() 83 | bn1 = self.bn1(x) 84 | 85 | out = self.convs(bn1) 86 | out.add_(shortcut) 87 | 88 | return out 89 | -------------------------------------------------------------------------------- /old/modules/src/checks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT 6 | #ifndef AT_CHECK 7 | #define AT_CHECK AT_ASSERT 8 | #endif 9 | 10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) -------------------------------------------------------------------------------- /old/modules/src/inplace_abn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 
4 | 5 | #include "inplace_abn.h" 6 | 7 | std::vector mean_var(at::Tensor x) { 8 | if (x.is_cuda()) { 9 | if (x.type().scalarType() == at::ScalarType::Half) { 10 | return mean_var_cuda_h(x); 11 | } else { 12 | return mean_var_cuda(x); 13 | } 14 | } else { 15 | return mean_var_cpu(x); 16 | } 17 | } 18 | 19 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 20 | bool affine, float eps) { 21 | if (x.is_cuda()) { 22 | if (x.type().scalarType() == at::ScalarType::Half) { 23 | return forward_cuda_h(x, mean, var, weight, bias, affine, eps); 24 | } else { 25 | return forward_cuda(x, mean, var, weight, bias, affine, eps); 26 | } 27 | } else { 28 | return forward_cpu(x, mean, var, weight, bias, affine, eps); 29 | } 30 | } 31 | 32 | std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 33 | bool affine, float eps) { 34 | if (z.is_cuda()) { 35 | if (z.type().scalarType() == at::ScalarType::Half) { 36 | return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps); 37 | } else { 38 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps); 39 | } 40 | } else { 41 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps); 42 | } 43 | } 44 | 45 | at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 46 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 47 | if (z.is_cuda()) { 48 | if (z.type().scalarType() == at::ScalarType::Half) { 49 | return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps); 50 | } else { 51 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); 52 | } 53 | } else { 54 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); 55 | } 56 | } 57 | 58 | void leaky_relu_forward(at::Tensor z, float slope) { 59 | at::leaky_relu_(z, slope); 60 | } 61 | 62 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { 63 | if (z.is_cuda()) { 64 | if (z.type().scalarType() == at::ScalarType::Half) { 65 | return leaky_relu_backward_cuda_h(z, dz, slope); 66 | } else { 67 | return leaky_relu_backward_cuda(z, dz, slope); 68 | } 69 | } else { 70 | return leaky_relu_backward_cpu(z, dz, slope); 71 | } 72 | } 73 | 74 | void elu_forward(at::Tensor z) { 75 | at::elu_(z); 76 | } 77 | 78 | void elu_backward(at::Tensor z, at::Tensor dz) { 79 | if (z.is_cuda()) { 80 | return elu_backward_cuda(z, dz); 81 | } else { 82 | return elu_backward_cpu(z, dz); 83 | } 84 | } 85 | 86 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 87 | m.def("mean_var", &mean_var, "Mean and variance computation"); 88 | m.def("forward", &forward, "In-place forward computation"); 89 | m.def("edz_eydz", &edz_eydz, "First part of backward computation"); 90 | m.def("backward", &backward, "Second part of backward computation"); 91 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); 92 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); 93 | m.def("elu_forward", &elu_forward, "Elu forward computation"); 94 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); 95 | } 96 | -------------------------------------------------------------------------------- /old/modules/src/inplace_abn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | std::vector mean_var_cpu(at::Tensor x); 8 | std::vector mean_var_cuda(at::Tensor x); 9 | std::vector mean_var_cuda_h(at::Tensor x); 
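// Note on the layout of this header: the kernels below come in *_cpu, *_cuda and
// (for most ops) *_cuda_h half-precision variants; the generic front-ends in
// inplace_abn.cpp above select one at runtime from x.is_cuda() and the tensor's
// scalar type, as in mean_var() and forward().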
10 | 11 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 12 | bool affine, float eps); 13 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 14 | bool affine, float eps); 15 | at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps); 17 | 18 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 19 | bool affine, float eps); 20 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 21 | bool affine, float eps); 22 | std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 23 | bool affine, float eps); 24 | 25 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 26 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 27 | at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 28 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 29 | at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 30 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 31 | 32 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); 33 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); 34 | void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope); 35 | 36 | void elu_backward_cpu(at::Tensor z, at::Tensor dz); 37 | void elu_backward_cuda(at::Tensor z, at::Tensor dz); 38 | 39 | static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { 40 | num = x.size(0); 41 | chn = x.size(1); 42 | sp = 1; 43 | for (int64_t i = 2; i < x.ndimension(); ++i) 44 | sp *= x.size(i); 45 | } 46 | 47 | /* 48 | * Specialized CUDA reduction functions for BN 49 | */ 50 | #ifdef __CUDACC__ 51 | 52 | #include "utils/cuda.cuh" 53 | 54 | template 55 | __device__ T reduce(Op op, int plane, int N, int S) { 56 | T sum = (T)0; 57 | for (int batch = 0; batch < N; ++batch) { 58 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 59 | sum += op(batch, plane, x); 60 | } 61 | } 62 | 63 | // sum over NumThreads within a warp 64 | sum = warpSum(sum); 65 | 66 | // 'transpose', and reduce within warp again 67 | __shared__ T shared[32]; 68 | __syncthreads(); 69 | if (threadIdx.x % WARP_SIZE == 0) { 70 | shared[threadIdx.x / WARP_SIZE] = sum; 71 | } 72 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 73 | // zero out the other entries in shared 74 | shared[threadIdx.x] = (T)0; 75 | } 76 | __syncthreads(); 77 | if (threadIdx.x / WARP_SIZE == 0) { 78 | sum = warpSum(shared[threadIdx.x]); 79 | if (threadIdx.x == 0) { 80 | shared[0] = sum; 81 | } 82 | } 83 | __syncthreads(); 84 | 85 | // Everyone picks it up, should be broadcast into the whole gradInput 86 | return shared[0]; 87 | } 88 | #endif 89 | -------------------------------------------------------------------------------- /old/modules/src/inplace_abn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "utils/checks.h" 6 | #include "inplace_abn.h" 7 | 8 | at::Tensor reduce_sum(at::Tensor x) { 9 | if (x.ndimension() == 2) { 10 | return x.sum(0); 11 | } else { 12 | auto x_view = x.view({x.size(0), x.size(1), -1}); 13 | return x_view.sum(-1).sum(0); 
14 | } 15 | } 16 | 17 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 18 | if (x.ndimension() == 2) { 19 | return v; 20 | } else { 21 | std::vector broadcast_size = {1, -1}; 22 | for (int64_t i = 2; i < x.ndimension(); ++i) 23 | broadcast_size.push_back(1); 24 | 25 | return v.view(broadcast_size); 26 | } 27 | } 28 | 29 | int64_t count(at::Tensor x) { 30 | int64_t count = x.size(0); 31 | for (int64_t i = 2; i < x.ndimension(); ++i) 32 | count *= x.size(i); 33 | 34 | return count; 35 | } 36 | 37 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { 38 | if (affine) { 39 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); 40 | } else { 41 | return z; 42 | } 43 | } 44 | 45 | std::vector mean_var_cpu(at::Tensor x) { 46 | auto num = count(x); 47 | auto mean = reduce_sum(x) / num; 48 | auto diff = x - broadcast_to(mean, x); 49 | auto var = reduce_sum(diff.pow(2)) / num; 50 | 51 | return {mean, var}; 52 | } 53 | 54 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 55 | bool affine, float eps) { 56 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); 57 | auto mul = at::rsqrt(var + eps) * gamma; 58 | 59 | x.sub_(broadcast_to(mean, x)); 60 | x.mul_(broadcast_to(mul, x)); 61 | if (affine) x.add_(broadcast_to(bias, x)); 62 | 63 | return x; 64 | } 65 | 66 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 67 | bool affine, float eps) { 68 | auto edz = reduce_sum(dz); 69 | auto y = invert_affine(z, weight, bias, affine, eps); 70 | auto eydz = reduce_sum(y * dz); 71 | 72 | return {edz, eydz}; 73 | } 74 | 75 | at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 76 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 77 | auto y = invert_affine(z, weight, bias, affine, eps); 78 | auto mul = affine ? 
at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); 79 | 80 | auto num = count(z); 81 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); 82 | return dx; 83 | } 84 | 85 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { 86 | CHECK_CPU_INPUT(z); 87 | CHECK_CPU_INPUT(dz); 88 | 89 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { 90 | int64_t count = z.numel(); 91 | auto *_z = z.data(); 92 | auto *_dz = dz.data(); 93 | 94 | for (int64_t i = 0; i < count; ++i) { 95 | if (_z[i] < 0) { 96 | _z[i] *= 1 / slope; 97 | _dz[i] *= slope; 98 | } 99 | } 100 | })); 101 | } 102 | 103 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) { 104 | CHECK_CPU_INPUT(z); 105 | CHECK_CPU_INPUT(dz); 106 | 107 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { 108 | int64_t count = z.numel(); 109 | auto *_z = z.data(); 110 | auto *_dz = dz.data(); 111 | 112 | for (int64_t i = 0; i < count; ++i) { 113 | if (_z[i] < 0) { 114 | _z[i] = log1p(_z[i]); 115 | _dz[i] *= (_z[i] + 1.f); 116 | } 117 | } 118 | })); 119 | } 120 | -------------------------------------------------------------------------------- /old/modules/src/utils/checks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT 6 | #ifndef AT_CHECK 7 | #define AT_CHECK AT_ASSERT 8 | #endif 9 | 10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) -------------------------------------------------------------------------------- /old/modules/src/utils/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /* 6 | * Functions to share code between CPU and GPU 7 | */ 8 | 9 | #ifdef __CUDACC__ 10 | // CUDA versions 11 | 12 | #define HOST_DEVICE __host__ __device__ 13 | #define INLINE_HOST_DEVICE __host__ __device__ inline 14 | #define FLOOR(x) floor(x) 15 | 16 | #if __CUDA_ARCH__ >= 600 17 | // Recent compute capabilities have block-level atomicAdd for all data types, so we use that 18 | #define ACCUM(x,y) atomicAdd_block(&(x),(y)) 19 | #else 20 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float 21 | // and use the known atomicCAS-based implementation for double 22 | template 23 | __device__ inline data_t atomic_add(data_t *address, data_t val) { 24 | return atomicAdd(address, val); 25 | } 26 | 27 | template<> 28 | __device__ inline double atomic_add(double *address, double val) { 29 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 30 | unsigned long long int old = *address_as_ull, assumed; 31 | do { 32 | assumed = old; 33 | old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | 38 | #define ACCUM(x,y) atomic_add(&(x),(y)) 39 | #endif // #if __CUDA_ARCH__ >= 600 40 | 41 | #else 42 | // CPU versions 43 | 44 | #define HOST_DEVICE 45 | #define 
INLINE_HOST_DEVICE inline 46 | #define FLOOR(x) std::floor(x) 47 | #define ACCUM(x,y) (x) += (y) 48 | 49 | #endif // #ifdef __CUDACC__ -------------------------------------------------------------------------------- /old/modules/src/utils/cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * General settings and functions 5 | */ 6 | const int WARP_SIZE = 32; 7 | const int MAX_BLOCK_SIZE = 1024; 8 | 9 | static int getNumThreads(int nElem) { 10 | int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE}; 11 | for (int i = 0; i < 6; ++i) { 12 | if (nElem <= threadSizes[i]) { 13 | return threadSizes[i]; 14 | } 15 | } 16 | return MAX_BLOCK_SIZE; 17 | } 18 | 19 | /* 20 | * Reduction utilities 21 | */ 22 | template 23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 24 | unsigned int mask = 0xffffffff) { 25 | #if CUDART_VERSION >= 9000 26 | return __shfl_xor_sync(mask, value, laneMask, width); 27 | #else 28 | return __shfl_xor(value, laneMask, width); 29 | #endif 30 | } 31 | 32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 33 | 34 | template 35 | struct Pair { 36 | T v1, v2; 37 | __device__ Pair() {} 38 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 39 | __device__ Pair(T v) : v1(v), v2(v) {} 40 | __device__ Pair(int v) : v1(v), v2(v) {} 41 | __device__ Pair &operator+=(const Pair &a) { 42 | v1 += a.v1; 43 | v2 += a.v2; 44 | return *this; 45 | } 46 | }; 47 | 48 | template 49 | static __device__ __forceinline__ T warpSum(T val) { 50 | #if __CUDA_ARCH__ >= 300 51 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 52 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 53 | } 54 | #else 55 | __shared__ T values[MAX_BLOCK_SIZE]; 56 | values[threadIdx.x] = val; 57 | __threadfence_block(); 58 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 59 | for (int i = 1; i < WARP_SIZE; i++) { 60 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 61 | } 62 | #endif 63 | return val; 64 | } 65 | 66 | template 67 | static __device__ __forceinline__ Pair warpSum(Pair value) { 68 | value.v1 = warpSum(value.v1); 69 | value.v2 = warpSum(value.v2); 70 | return value; 71 | } -------------------------------------------------------------------------------- /old/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import logging 7 | 8 | logger = logging.getLogger() 9 | 10 | class Optimizer(object): 11 | def __init__(self, 12 | model, 13 | lr0, 14 | momentum, 15 | wd, 16 | warmup_steps, 17 | warmup_start_lr, 18 | max_iter, 19 | power, 20 | *args, **kwargs): 21 | self.warmup_steps = warmup_steps 22 | self.warmup_start_lr = warmup_start_lr 23 | self.lr0 = lr0 24 | self.lr = self.lr0 25 | self.max_iter = float(max_iter) 26 | self.power = power 27 | self.it = 0 28 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params() 29 | param_list = [ 30 | {'params': wd_params}, 31 | {'params': nowd_params, 'weight_decay': 0}, 32 | {'params': lr_mul_wd_params, 'lr_mul': True}, 33 | {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}] 34 | self.optim = torch.optim.SGD( 35 | param_list, 36 | lr = lr0, 37 | momentum = momentum, 38 | weight_decay = wd) 39 | self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps) 40 | 41 | 42 | def get_lr(self): 43 | if self.it <= self.warmup_steps: 44 | lr = 
self.warmup_start_lr*(self.warmup_factor**self.it) 45 | else: 46 | factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power 47 | lr = self.lr0 * factor 48 | return lr 49 | 50 | 51 | def step(self): 52 | self.lr = self.get_lr() 53 | for pg in self.optim.param_groups: 54 | if pg.get('lr_mul', False): 55 | pg['lr'] = self.lr * 10 56 | else: 57 | pg['lr'] = self.lr 58 | if self.optim.defaults.get('lr_mul', False): 59 | self.optim.defaults['lr'] = self.lr * 10 60 | else: 61 | self.optim.defaults['lr'] = self.lr 62 | self.it += 1 63 | self.optim.step() 64 | if self.it == self.warmup_steps+2: 65 | logger.info('==> warmup done, start to implement poly lr strategy') 66 | 67 | def zero_grad(self): 68 | self.optim.zero_grad() 69 | 70 | -------------------------------------------------------------------------------- /old/pic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/old/pic.jpg -------------------------------------------------------------------------------- /old/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = BatchNorm2d(out_chan, activation='none') 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | BatchNorm2d(out_chan, activation='none'), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = self.bn1(residual) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 
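# init_weight() defined below loads ImageNet-pretrained resnet18 weights from resnet18_url via model_zoo, skipping the final fc layer.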
70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = self.bn1(x) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, (BatchNorm2d, nn.BatchNorm2d)): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /old/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from logger import setup_logger 6 | from model import BiSeNet 7 | from cityscapes import CityScapes 8 | from loss import OhemCELoss 9 | from evaluate import evaluate 10 | from optimizer import Optimizer 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | import torch.distributed as dist 17 | 18 | import os 19 | import os.path as osp 20 | import logging 21 | import time 22 | import datetime 23 | import argparse 24 | 25 | 26 | respth = './res' 27 | if not osp.exists(respth): os.makedirs(respth) 28 | logger = logging.getLogger() 29 | 30 | 31 | def parse_args(): 32 | parse = argparse.ArgumentParser() 33 | parse.add_argument( 34 | '--local_rank', 35 | dest = 'local_rank', 36 | type = int, 37 | default = -1, 38 | ) 39 | parse.add_argument( 40 | '--ckpt', 41 | dest = 'ckpt', 42 | type = str, 43 | default = None, 44 | ) 45 | return parse.parse_args() 46 | 47 | 48 | def train(): 49 | args = parse_args() 50 | torch.cuda.set_device(args.local_rank) 51 | dist.init_process_group( 52 | backend = 'nccl', 53 | init_method = 'tcp://127.0.0.1:33271', 54 | world_size = torch.cuda.device_count(), 55 | rank=args.local_rank 56 | ) 57 | setup_logger(respth) 58 | 59 | ## dataset 60 | n_classes = 19 61 | n_img_per_gpu = 8 62 | n_workers = 4 63 | cropsize = [1024, 1024] 64 | # cropsize = [1024, 512] 65 | ds = CityScapes('./data', cropsize=cropsize, mode='train') 66 | sampler = torch.utils.data.distributed.DistributedSampler(ds) 67 | dl = DataLoader(ds, 68 | batch_size = n_img_per_gpu, 69 | shuffle = False, 70 | sampler = sampler, 71 | num_workers = n_workers, 72 | pin_memory = True, 73 | drop_last = True) 74 | 75 | ## model 76 | ignore_idx = 255 77 | net = BiSeNet(n_classes=n_classes) 78 | if not args.ckpt is None: 79 | net.load_state_dict(torch.load(args.ckpt, map_location='cpu')) 80 | net.cuda() 81 | net.train() 82 | net = nn.parallel.DistributedDataParallel(net, 83 | device_ids = [args.local_rank, ], 84 | output_device = args.local_rank 85 | ) 86 | score_thres = 0.7 87 | n_min = 
n_img_per_gpu*cropsize[0]*cropsize[1]//16 88 | criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 89 | criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 90 | criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) 91 | 92 | ## optimizer 93 | momentum = 0.9 94 | weight_decay = 5e-4 95 | lr_start = 1e-2 96 | max_iter = 80000 97 | power = 0.9 98 | warmup_steps = 1000 99 | warmup_start_lr = 1e-5 100 | optim = Optimizer( 101 | model = net.module, 102 | lr0 = lr_start, 103 | momentum = momentum, 104 | wd = weight_decay, 105 | warmup_steps = warmup_steps, 106 | warmup_start_lr = warmup_start_lr, 107 | max_iter = max_iter, 108 | power = power) 109 | 110 | ## train loop 111 | msg_iter = 50 112 | loss_avg = [] 113 | st = glob_st = time.time() 114 | diter = iter(dl) 115 | epoch = 0 116 | for it in range(max_iter): 117 | try: 118 | im, lb = next(diter) 119 | if not im.size()[0]==n_img_per_gpu: raise StopIteration 120 | except StopIteration: 121 | epoch += 1 122 | sampler.set_epoch(epoch) 123 | diter = iter(dl) 124 | im, lb = next(diter) 125 | im = im.cuda() 126 | lb = lb.cuda() 127 | H, W = im.size()[2:] 128 | lb = torch.squeeze(lb, 1) 129 | 130 | optim.zero_grad() 131 | out, out16, out32 = net(im) 132 | lossp = criteria_p(out, lb) 133 | loss2 = criteria_16(out16, lb) 134 | loss3 = criteria_32(out32, lb) 135 | loss = lossp + loss2 + loss3 136 | loss.backward() 137 | optim.step() 138 | 139 | loss_avg.append(loss.item()) 140 | ## print training log message 141 | if (it+1)%msg_iter==0: 142 | loss_avg = sum(loss_avg) / len(loss_avg) 143 | lr = optim.lr 144 | ed = time.time() 145 | t_intv, glob_t_intv = ed - st, ed - glob_st 146 | eta = int((max_iter - it) * (glob_t_intv / it)) 147 | eta = str(datetime.timedelta(seconds=eta)) 148 | msg = ', '.join([ 149 | 'it: {it}/{max_it}', 150 | 'lr: {lr:4f}', 151 | 'loss: {loss:.4f}', 152 | 'eta: {eta}', 153 | 'time: {time:.4f}', 154 | ]).format( 155 | it = it+1, 156 | max_it = max_iter, 157 | lr = lr, 158 | loss = loss_avg, 159 | time = t_intv, 160 | eta = eta 161 | ) 162 | logger.info(msg) 163 | loss_avg = [] 164 | st = ed 165 | 166 | ## dump the final model 167 | save_pth = osp.join(respth, 'model_final.pth') 168 | net.cpu() 169 | state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() 170 | if dist.get_rank()==0: torch.save(state, save_pth) 171 | logger.info('training done, model saved to: {}'.format(save_pth)) 172 | 173 | 174 | if __name__ == "__main__": 175 | train() 176 | evaluate() 177 | -------------------------------------------------------------------------------- /old/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from PIL import Image 6 | import PIL.ImageEnhance as ImageEnhance 7 | import random 8 | 9 | 10 | class RandomCrop(object): 11 | def __init__(self, size, *args, **kwargs): 12 | self.size = size 13 | 14 | def __call__(self, im_lb): 15 | im = im_lb['im'] 16 | lb = im_lb['lb'] 17 | assert im.size == lb.size 18 | W, H = self.size 19 | w, h = im.size 20 | 21 | if (W, H) == (w, h): return dict(im=im, lb=lb) 22 | if w < W or h < H: 23 | scale = float(W) / w if w < h else float(H) / h 24 | w, h = int(scale * w + 1), int(scale * h + 1) 25 | im = im.resize((w, h), Image.BILINEAR) 26 | lb = lb.resize((w, h), Image.NEAREST) 27 | sw, sh = random.random() * (w - W), random.random() * (h - H) 28 | crop = int(sw), int(sh), int(sw) + W, int(sh) + H 29 
| return dict( 30 | im = im.crop(crop), 31 | lb = lb.crop(crop) 32 | ) 33 | 34 | 35 | class HorizontalFlip(object): 36 | def __init__(self, p=0.5, *args, **kwargs): 37 | self.p = p 38 | 39 | def __call__(self, im_lb): 40 | if random.random() > self.p: 41 | return im_lb 42 | else: 43 | im = im_lb['im'] 44 | lb = im_lb['lb'] 45 | return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), 46 | lb = lb.transpose(Image.FLIP_LEFT_RIGHT), 47 | ) 48 | 49 | 50 | class RandomScale(object): 51 | def __init__(self, scales=(1, ), *args, **kwargs): 52 | self.scales = scales 53 | 54 | def __call__(self, im_lb): 55 | im = im_lb['im'] 56 | lb = im_lb['lb'] 57 | W, H = im.size 58 | scale = random.choice(self.scales) 59 | w, h = int(W * scale), int(H * scale) 60 | return dict(im = im.resize((w, h), Image.BILINEAR), 61 | lb = lb.resize((w, h), Image.NEAREST), 62 | ) 63 | 64 | 65 | class ColorJitter(object): 66 | def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs): 67 | if not brightness is None and brightness>0: 68 | self.brightness = [max(1-brightness, 0), 1+brightness] 69 | if not contrast is None and contrast>0: 70 | self.contrast = [max(1-contrast, 0), 1+contrast] 71 | if not saturation is None and saturation>0: 72 | self.saturation = [max(1-saturation, 0), 1+saturation] 73 | 74 | def __call__(self, im_lb): 75 | im = im_lb['im'] 76 | lb = im_lb['lb'] 77 | r_brightness = random.uniform(self.brightness[0], self.brightness[1]) 78 | r_contrast = random.uniform(self.contrast[0], self.contrast[1]) 79 | r_saturation = random.uniform(self.saturation[0], self.saturation[1]) 80 | im = ImageEnhance.Brightness(im).enhance(r_brightness) 81 | im = ImageEnhance.Contrast(im).enhance(r_contrast) 82 | im = ImageEnhance.Color(im).enhance(r_saturation) 83 | return dict(im = im, 84 | lb = lb, 85 | ) 86 | 87 | 88 | class MultiScale(object): 89 | def __init__(self, scales): 90 | self.scales = scales 91 | 92 | def __call__(self, img): 93 | W, H = img.size 94 | sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales] 95 | imgs = [] 96 | [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes] 97 | return imgs 98 | 99 | 100 | class Compose(object): 101 | def __init__(self, do_list): 102 | self.do_list = do_list 103 | 104 | def __call__(self, im_lb): 105 | for comp in self.do_list: 106 | im_lb = comp(im_lb) 107 | return im_lb 108 | 109 | 110 | 111 | 112 | if __name__ == '__main__': 113 | flip = HorizontalFlip(p = 1) 114 | crop = RandomCrop((321, 321)) 115 | rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0)) 116 | img = Image.open('data/img.jpg') 117 | lb = Image.open('data/label.png') 118 | -------------------------------------------------------------------------------- /openvino/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | cmake_minimum_required (VERSION 3.10) 3 | 4 | cmake_policy(SET CMP0025 NEW) 5 | 6 | project(Samples) 7 | 8 | 9 | set (CMAKE_CXX_STANDARD 14) 10 | set(CMAKE_BUILD_TYPE "Release") 11 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wuninitialized -Winit-self") 12 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall") 13 | 14 | 15 | find_package(OpenCV REQUIRED) 16 | find_package(InferenceEngine REQUIRED) 17 | find_package(ngraph REQUIRED) 18 | 19 | 20 | include_directories( 21 | ${CMAKE_CURRENT_SOURCE_DIR} 22 | ${CMAKE_CURRENT_BINARY_DIR} 23 | ${OpenCV_INCLUDE_DIRS} 24 | ) 25 | 26 | add_executable(segment main.cpp) 27 | target_link_libraries( 28 | segment 29 | ${InferenceEngine_LIBRARIES} 30 | ${NGRAPH_LIBRARIES} 31 | 
${OpenCV_LIBS} 32 | ) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /openvino/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## A demo of using openvino to deploy 4 | 5 | Openvino is used to deploy model on intel cpus or "gpu inside cpu". 6 | 7 | My platform: 8 | * Ubuntu 18.04 9 | * Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz 10 | * openvino_2021.4.689 11 | 12 | 13 | ### preparation 14 | 15 | 1.Train the model and export it to onnx 16 | ``` 17 | $ cd BiSeNet/ 18 | $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model_v2.onnx 19 | ``` 20 | (Optional) 2.Install 'onnx-simplifier' to simplify the generated onnx model: 21 | ``` 22 | $ python -m pip install onnx-simplifier 23 | $ python -m onnxsim model_v2.onnx model_v2_sim.onnx 24 | ``` 25 | 26 | 27 | ### Install and configure openvino 28 | 29 | 1.pull docker image 30 | ``` 31 | $ docker pull openvino/ubuntu18_dev 32 | ``` 33 | 34 | 2.start a docker container and mount code into it 35 | ``` 36 | $ docker run -itu root -v /path/to/BiSeNet:/BiSeNet openvino/ubuntu18_dev --device /dev/dri:/dev/dri bash 37 | 38 | ``` 39 | If your cpu does not have intel "gpu inside of cpu" or you do not want to use it, you can remove the option of `--device /dev/dri:/dev/dri`. 40 | 41 | After running the above command, you will be in the container. 42 | 43 | (optional) 3.install gpu dependencies 44 | If you want to use gpu, you also need to install some dependencies inside the container: 45 | ``` 46 | # mkdir -p /tmp/opencl && cd /tmp/opencl 47 | # useradd -ms /bin/bash -G video,users openvino 48 | # chown openvino -R /home/openvino 49 | # apt update 50 | # apt install -y --no-install-recommends ocl-icd-libopencl1 51 | # curl -L "https://github.com/intel/compute-runtime/releases/download/19.41.14441/intel-gmmlib_19.3.2_amd64.deb" --output "intel-gmmlib_19.3.2_amd64.deb" 52 | # curl -L "https://github.com/intel/compute-runtime/releases/download/19.41.14441/intel-igc-core_1.0.2597_amd64.deb" --output "intel-igc-core_1.0.2597_amd64.deb" 53 | # curl -L "https://github.com/intel/compute-runtime/releases/download/19.41.14441/intel-igc-opencl_1.0.2597_amd64.deb" --output "intel-igc-opencl_1.0.2597_amd64.deb" 54 | # curl -L "https://github.com/intel/compute-runtime/releases/download/19.41.14441/intel-opencl_19.41.14441_amd64.deb" --output "intel-opencl_19.41.14441_amd64.deb" 55 | # curl -L "https://github.com/intel/compute-runtime/releases/download/19.41.14441/intel-ocloc_19.41.14441_amd64.deb" --output "intel-ocloc_19.04.12237_amd64.deb" 56 | # dpkg -i /tmp/opencl/*.deb 57 | # apt --fix-broken install 58 | # ldconfig 59 | ``` 60 | 61 | I got the above commands from the official docs but I did not test it since my cpu does not have integrated gpu. 
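Once the environment is configured (step 4 below), you can also confirm from Python which devices the OpenVINO runtime actually sees. This is only a minimal sketch, assuming the `openvino` python bindings that ship with the container are importable:
```
# minimal sketch: list the devices the OpenVINO runtime can use
# run this after sourcing setupvars.sh (step 4 below)
from openvino.inference_engine import IECore

ie = IECore()
print(ie.available_devices)   # e.g. ['CPU'], or ['CPU', 'GPU'] if the opencl packages above are installed
```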
62 | 63 | You can check if your platform has intel gpu with this command: 64 | ``` 65 | $ sudo lspci | grep -i vga 66 | ``` 67 | 68 | 4.configure environment 69 | Just run this script, and the environment will be ready: 70 | ``` 71 | # source /opt/intel/openvino_2021.4.689/bin/setupvars.sh 72 | ``` 73 | 74 | 75 | ### convert model and run demo 76 | 77 | 1.convert onnx to openvino IR 78 | In the docker container: 79 | ``` 80 | # cd /opt/intel/openvino_2021.4.689/deployment_tools/model_optimizer 81 | # python3 mo.py --input_model /BiSeNet/model_v2.onnx --output_dir /BiSeNet/openvino/output_v2 82 | ``` 83 | 84 | 2.compile and run the demo 85 | ``` 86 | # cd /BiSeNet/openvino 87 | # mkdir -p build && cd build 88 | # cmake .. && make 89 | # ./segment 90 | ``` 91 | After this, you will see a segmentation result image named `res.jpg` generated. 92 | 93 | 94 | 95 | ### Tips 96 | 97 | 1. GPU support: openvino supports intel cpu and intel "gpu inside cpu". Until now (2021.11), other popular discrete gpus are not supported, such as nvidia/amd gpus. Also, other integrated gpus are not supported, such as the aspeed graphics family. 98 | 99 | 2. About low-precision: precision is optimized automatically, and the model will be run in one or several precision modes. We can also manually enforce bf16, as long as our cpu has `avx512_bf16` support. If the cpu does not support bf16, it will fall back to simulation, which would slow down the inference. If neither native bf16 nor simulation is supported, an error would occur. 100 | -------------------------------------------------------------------------------- /tensorrt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 3.22) 2 | 3 | PROJECT(segment LANGUAGES CUDA CXX) 4 | 5 | # set(CMAKE_CXX_FLAGS "-std=c++17 -O2") 6 | 7 | message (${CMAKE_CUDA_ARCHITECTURES}) 8 | 9 | set (CMAKE_BUILD_TYPE Release) 10 | set (CMAKE_CUDA_FLAGS_RELEASE "-O2 -DNDEBUG") 11 | if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) 12 | set (CMAKE_CUDA_ARCHITECTURES 80) 13 | endif () 14 | 15 | # link_directories(${PROJECT_SOURCE_DIR}/build ${PROJECT_SOURCE_DIR}/build/plugins) 16 | # include_directories(/root/build/TensorRT-8.2.5.1/include) 17 | # link_directories(/root/build/TensorRT-8.2.5.1/lib) 18 | 19 | 20 | enable_language (CUDA) 21 | find_package (OpenCV REQUIRED) 22 | 23 | include (CheckLanguage) 24 | check_language (CUDA) 25 | check_language (OpenCV) 26 | 27 | add_subdirectory(./plugins/) # custom_plugin 28 | 29 | 30 | add_executable(segment segment.cu read_img.cpp trt_dep.cu) 31 | target_compile_features(segment PRIVATE cxx_std_17 cuda_std_14) 32 | target_include_directories(segment PUBLIC ${OpenCV_INCLUDE_DIRS}) 33 | target_link_libraries(segment ${OpenCV_LIBRARIES} 34 | nvinfer nvinfer_plugin nvonnxparser custom_plugin 35 | ) 36 | 37 | -------------------------------------------------------------------------------- /tensorrt/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Deploy with Tensorrt 4 | 5 | Firstly, we should export our trained model to an onnx model: 6 | ``` 7 | $ cd BiSeNet/ 8 | $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx --aux-mode eval 9 | ``` 10 | 11 | **NOTE:** I use a cropsize of `1024x2048` here in my example; you should change it according to your specific application. 
The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here. 12 | 13 | Then we can use either c++ or python to compile the model and run inference. 14 | 15 | 16 | ### Using C++ 17 | 18 | #### 1. My platform 19 | 20 | * ubuntu 22.04 21 | * nvidia A40 gpu, driver newer than 555.42.06 22 | * cuda 12.1, cudnn 8 23 | * cmake 3.22.1 24 | * opencv built from source 25 | * tensorrt 10.3.0.26 26 | 27 | 28 | 29 | #### 2. Build with source code 30 | Just use the standard cmake build method: 31 | ``` 32 | mkdir -p tensorrt/build 33 | cd tensorrt/build 34 | cmake .. 35 | make 36 | ``` 37 | This would generate a `./segment` executable in the `tensorrt/build` directory. 38 | 39 | 40 | #### 3. Convert onnx to tensorrt model 41 | If you can successfully compile the source code, you can parse the onnx model into a tensorrt model with one of the following commands. 42 | For fp32/fp16/bf16, the command is: 43 | ``` 44 | $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp32 45 | $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16 46 | $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --bf16 47 | ``` 48 | Make sure that your gpu supports acceleration with fp16/bf16 inference when you set these options.
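If you are not sure whether your gpu supports these modes, a quick check from the PyTorch environment you already use for training can help. This is a minimal sketch, not part of this repo, and the compute-capability thresholds are only general rules of thumb:
```
# rough check of fp16/bf16 support on the current gpu (assumes the training pytorch install)
import torch

print(torch.cuda.get_device_capability(0))   # fp16 tensor cores need sm_70+, bf16 needs sm_80+
print(torch.cuda.is_bf16_supported())        # False on pre-Ampere gpus
```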
49 | 50 | Building an int8 engine is also supported. Firstly, you should make sure your gpu supports int8 inference, or your model will not be faster than fp16/fp32. Then you should prepare a certain amount of images for int8 calibration. In this example, I use the train set of cityscapes for calibration. The command is like this: 51 | ``` 52 | $ rm calibrate_int8 # delete this if exists 53 | $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --int8 /path/to/BiSeNet/datasets/cityscapes /path/to/BiSeNet/datasets/cityscapes/train.txt 54 | ``` 55 | With the above commands, we will have a tensorrt engine named `saved_model.trt` generated. 56 | 57 | Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the args in the above command. 58 | 59 | 60 | #### 4. Infer with one single image 61 | Run inference like this: 62 | ``` 63 | $ ./segment run /path/to/saved_model.trt /path/to/input/image.jpg /path/to/saved_img.jpg 64 | ``` 65 | 66 | 67 | #### 5. Test speed 68 | The speed depends on the specific gpu platform you are working on; you can test the fps on your gpu like this: 69 | ``` 70 | $ ./segment test /path/to/saved_model.trt 71 | ``` 72 | 73 | 74 | #### 6. Tips: 75 | 76 | The speed (fps) is tested on a single nvidia A40 gpu with `batchsize=1` and `cropsize=(1024,2048)`, which might be different from your platform and settings. You should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed. 77 | 78 | 79 | 80 | ### Using python (this is not updated to tensorrt 10.3) 81 | 82 | You can also use the python script to compile and run inference of your model.
83 | 84 | The following still describes the usage for tensorrt 8.2.
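Before running the python demo, it helps to confirm that the installed packages actually match that version. A minimal sketch, assuming `tensorrt` and `pycuda` are installed as the script requires:
```
# quick environment check for the python demo below (written against the tensorrt 8.x api)
import tensorrt as trt
import pycuda.driver as cuda

print(trt.__version__)        # segment.py uses 8.x-era calls such as builder.max_batch_size
cuda.init()
print(cuda.Device(0).name())  # the gpu the engine will be built on
```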
85 | 86 | 87 | #### 1. Compile model to onnx 88 | 89 | 90 | With this command: 91 | ``` 92 | $ cd BiSeNet/tensorrt 93 | $ python segment.py compile --onnx /path/to/model.onnx --savepth ./model.trt --quant fp16/fp32 94 | ``` 95 | 96 | This will compile onnx model into tensorrt serialized engine, save save to `./model.trt`. 97 | 98 | 99 | #### 2. Inference with Tensorrt 100 | 101 | Run Inference like this: 102 | ``` 103 | $ python segment.py run --mdpth ./model.trt --impth ../example.png --outpth ./res.png 104 | ``` 105 | 106 | This will use the tensorrt model compiled above, and run inference with the example image. 107 | 108 | -------------------------------------------------------------------------------- /tensorrt/batch_stream.hpp: -------------------------------------------------------------------------------- 1 | 2 | #ifndef BATCH_STREAM_HPP 3 | #define BATCH_STREAM_HPP 4 | 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "NvInfer.h" 16 | #include "read_img.hpp" 17 | 18 | using nvinfer1::Dims; 19 | using nvinfer1::Dims3; 20 | using nvinfer1::Dims4; 21 | 22 | 23 | class IBatchStream 24 | { 25 | public: 26 | virtual void reset(int firstBatch) = 0; 27 | virtual bool next() = 0; 28 | virtual void skip(int skipCount) = 0; 29 | virtual float* getBatch() = 0; 30 | virtual int getBatchesRead() const = 0; 31 | virtual int getBatchSize() const = 0; 32 | virtual nvinfer1::Dims4 getDims() const = 0; 33 | }; 34 | 35 | 36 | class BatchStream : public IBatchStream 37 | { 38 | public: 39 | BatchStream(int batchSize, int maxBatches, Dims indim, 40 | const std::string& dataRoot, 41 | const std::string& dataFile) 42 | : mBatchSize{batchSize} 43 | , mMaxBatches{maxBatches} 44 | { 45 | mDims = Dims3(indim.d[1], indim.d[2], indim.d[3]); 46 | 47 | readDataFile(dataFile, dataRoot); 48 | mSampleSize = std::accumulate( 49 | mDims.d, mDims.d + mDims.nbDims, 1, std::multiplies()) * sizeof(float); 50 | mData.resize(mSampleSize * mBatchSize); 51 | } 52 | 53 | void reset(int firstBatch) override 54 | { 55 | mBatchCount = firstBatch; 56 | } 57 | 58 | bool next() override 59 | { 60 | if (mBatchCount >= mMaxBatches) 61 | { 62 | return false; 63 | } 64 | ++mBatchCount; 65 | return true; 66 | } 67 | 68 | void skip(int skipCount) override 69 | { 70 | mBatchCount += skipCount; 71 | } 72 | 73 | float* getBatch() override 74 | { 75 | int offset = mBatchCount * mBatchSize; 76 | for (int i{0}; i < mBatchSize; ++i) { 77 | int ind = offset + i; 78 | read_data(mPaths[ind], &mData[i * mSampleSize], mDims.d[1], mDims.d[2]); 79 | } 80 | return mData.data(); 81 | } 82 | 83 | int getBatchesRead() const override 84 | { 85 | return mBatchCount; 86 | } 87 | 88 | int getBatchSize() const override 89 | { 90 | return mBatchSize; 91 | } 92 | 93 | nvinfer1::Dims4 getDims() const override 94 | { 95 | return Dims4{mBatchSize, mDims.d[0], mDims.d[1], mDims.d[2]}; 96 | } 97 | 98 | private: 99 | void readDataFile(const std::string& dataFilePath, const std::string& dataRootPath) 100 | { 101 | std::ifstream file(dataFilePath, std::ios::in); 102 | if (!file.is_open()) { 103 | cout << "file open failed: " << dataFilePath << endl; 104 | std::abort(); 105 | } 106 | std::stringstream ss; 107 | file >> ss.rdbuf(); 108 | file.close(); 109 | 110 | std::string impth; 111 | int n_imgs = 0; 112 | while (std::getline(ss, impth)) ++n_imgs; 113 | ss.clear(); ss.seekg(0, std::ios::beg); 114 | if (n_imgs <= 0) { 115 | cout << "ann file is empty, cannot read image paths for int8 
calibration: " 116 | << dataFilePath << endl; 117 | std::abort(); 118 | } 119 | 120 | mPaths.resize(n_imgs); 121 | for (int i{0}; i < n_imgs; ++i) { 122 | std::getline(ss, impth, ','); 123 | mPaths[i] = dataRootPath + "/" + impth; 124 | std::getline(ss, impth); 125 | } 126 | if (mMaxBatches < 0) { 127 | mMaxBatches = n_imgs / mBatchSize - 1; 128 | } 129 | if (mMaxBatches <= 0) { 130 | cout << "must have at least 1 batch for calibration\n"; 131 | std::abort(); 132 | } 133 | cout << "mMaxBatches = " << mMaxBatches << endl; 134 | } 135 | 136 | 137 | int mBatchSize{0}; 138 | int mBatchCount{0}; 139 | int mMaxBatches{0}; 140 | Dims3 mDims{}; 141 | std::vector mPaths; 142 | std::vector mData; 143 | int mSampleSize{0}; 144 | }; 145 | 146 | 147 | #endif 148 | -------------------------------------------------------------------------------- /tensorrt/entropy_calibrator.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef ENTROPY_CALIBRATOR_HPP 18 | #define ENTROPY_CALIBRATOR_HPP 19 | 20 | #include 21 | #include 22 | #include 23 | #include "NvInfer.h" 24 | 25 | //! \class EntropyCalibratorImpl 26 | //! 27 | //! \brief Implements common functionality for Entropy calibrators. 28 | //! 
29 | template 30 | class EntropyCalibratorImpl 31 | { 32 | public: 33 | EntropyCalibratorImpl( 34 | TBatchStream stream, int firstBatch, std::string cal_table_name, const char* inputBlobName, bool readCache = true) 35 | : mStream{stream} 36 | , mCalibrationTableName(cal_table_name) 37 | , mInputBlobName(inputBlobName) 38 | , mReadCache(readCache) 39 | { 40 | nvinfer1::Dims4 dims = mStream.getDims(); 41 | mInputCount = std::accumulate( 42 | dims.d, dims.d + dims.nbDims, 1, std::multiplies()); 43 | cout << "dims.nbDims: " << dims.nbDims << endl; 44 | for (int i{0}; i < dims.nbDims; ++i) { 45 | cout << dims.d[i] << ", "; 46 | } 47 | cout << endl; 48 | 49 | cudaError_t state; 50 | state = cudaMalloc(&mDeviceInput, mInputCount * sizeof(float)); 51 | if (state) { 52 | cout << "allocate memory failed\n"; 53 | std::abort(); 54 | } 55 | cout << "mInputCount: " << mInputCount << endl; 56 | mStream.reset(firstBatch); 57 | } 58 | 59 | virtual ~EntropyCalibratorImpl() 60 | { 61 | cudaError_t state; 62 | state = cudaFree(mDeviceInput); 63 | if (state) { 64 | cout << "free memory failed\n"; 65 | std::abort(); 66 | } 67 | } 68 | 69 | int getBatchSize() const 70 | { 71 | return mStream.getBatchSize(); 72 | } 73 | 74 | bool getBatch(void* bindings[], const char* names[], int nbBindings) 75 | { 76 | if (!mStream.next()) 77 | { 78 | return false; 79 | } 80 | cudaError_t state; 81 | state = cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice); 82 | if (state) { 83 | cout << "memory copy to device failed\n"; 84 | std::abort(); 85 | } 86 | assert(!strcmp(names[0], mInputBlobName)); 87 | bindings[0] = mDeviceInput; 88 | return true; 89 | } 90 | 91 | const void* readCalibrationCache(size_t& length) 92 | { 93 | mCalibrationCache.clear(); 94 | std::ifstream input(mCalibrationTableName, std::ios::binary); 95 | input >> std::noskipws; 96 | if (mReadCache && input.good()) 97 | { 98 | std::copy(std::istream_iterator(input), std::istream_iterator(), 99 | std::back_inserter(mCalibrationCache)); 100 | } 101 | length = mCalibrationCache.size(); 102 | return length ? mCalibrationCache.data() : nullptr; 103 | } 104 | 105 | void writeCalibrationCache(const void* cache, size_t length) 106 | { 107 | std::ofstream output(mCalibrationTableName, std::ios::binary); 108 | output.write(reinterpret_cast(cache), length); 109 | } 110 | 111 | private: 112 | TBatchStream mStream; 113 | size_t mInputCount; 114 | std::string mCalibrationTableName; 115 | const char* mInputBlobName; 116 | bool mReadCache{true}; 117 | void* mDeviceInput{nullptr}; 118 | std::vector mCalibrationCache; 119 | }; 120 | 121 | //! \class Int8EntropyCalibrator2 122 | //! 123 | //! \brief Implements Entropy calibrator 2. 124 | //! CalibrationAlgoType is kENTROPY_CALIBRATION_2. 125 | //! 
126 | template 127 | class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 128 | { 129 | public: 130 | Int8EntropyCalibrator2( 131 | TBatchStream stream, int firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true) 132 | : mImpl(stream, firstBatch, networkName, inputBlobName, readCache) 133 | { 134 | } 135 | 136 | int getBatchSize() const noexcept override 137 | { 138 | return mImpl.getBatchSize(); 139 | } 140 | 141 | bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override 142 | { 143 | return mImpl.getBatch(bindings, names, nbBindings); 144 | } 145 | 146 | const void* readCalibrationCache(size_t& length) noexcept override 147 | { 148 | return mImpl.readCalibrationCache(length); 149 | } 150 | 151 | void writeCalibrationCache(const void* cache, size_t length) noexcept override 152 | { 153 | mImpl.writeCalibrationCache(cache, length); 154 | } 155 | 156 | private: 157 | EntropyCalibratorImpl mImpl; 158 | }; 159 | 160 | #endif // ENTROPY_CALIBRATOR_H 161 | -------------------------------------------------------------------------------- /tensorrt/plugins/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | add_library (custom_plugin SHARED argmax_plugin.cu) 5 | target_compile_features (custom_plugin PRIVATE cuda_std_14) 6 | target_include_directories (custom_plugin PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 7 | set_property (TARGET custom_plugin PROPERTY CUDA_ARCHITECTURES 80) # until a100 8 | 9 | -------------------------------------------------------------------------------- /tensorrt/plugins/argmax_plugin.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef TENSORRT_ARGMAX_PLUGIN_H 3 | #define TENSORRT_ARGMAX_PLUGIN_H 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include "NvInferPlugin.h" 10 | #include "NvInfer.h" 11 | 12 | 13 | using namespace nvinfer1; 14 | 15 | 16 | namespace nvinfer1 { 17 | 18 | 19 | class ArgMaxPlugin : public IPluginV3, public IPluginV3OneCore, public IPluginV3OneBuildV2, public IPluginV3OneRuntime 20 | { 21 | public: 22 | ArgMaxPlugin(ArgMaxPlugin const& p) = default; 23 | 24 | ArgMaxPlugin(int64_t axis); 25 | 26 | 27 | // IPluginV3 methods 28 | 29 | IPluginCapability* getCapabilityInterface(PluginCapabilityType type) noexcept override; 30 | 31 | IPluginV3* clone() noexcept override; 32 | 33 | // IPluginV3OneCore methods 34 | char const* getPluginName() const noexcept override; 35 | 36 | char const* getPluginVersion() const noexcept override; 37 | 38 | char const* getPluginNamespace() const noexcept override; 39 | 40 | void setPluginNamespace(char const* pluginNamespace) noexcept; 41 | 42 | // IPluginV3OneBuild methods 43 | int32_t getNbOutputs() const noexcept override; 44 | 45 | int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, 46 | int32_t nbOutputs) noexcept override; 47 | 48 | bool supportsFormatCombination( 49 | int32_t pos, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept override; 50 | 51 | int32_t getOutputDataTypes( 52 | DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, int32_t nbInputs) const noexcept override; 53 | 54 | int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, 55 | int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept override; 56 | 57 | size_t 
getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, 58 | DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; 59 | 60 | // IPluginV3OneRuntime methods 61 | int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, 62 | void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; 63 | 64 | int32_t onShapeChange( 65 | PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, int32_t nbOutputs) noexcept override; 66 | 67 | IPluginV3* attachToContext(IPluginResourceContext* context) noexcept override; 68 | 69 | PluginFieldCollection const* getFieldsToSerialize() noexcept override; 70 | 71 | 72 | private: 73 | int64_t mAxis; 74 | std::vector mDataToSerialize; 75 | nvinfer1::PluginFieldCollection mFCToSerialize; 76 | std::string mNamespace; 77 | }; 78 | 79 | 80 | 81 | 82 | class ArgMaxPluginCreator : public nvinfer1::IPluginCreatorV3One 83 | { 84 | public: 85 | ArgMaxPluginCreator(); 86 | 87 | ~ArgMaxPluginCreator() override = default; 88 | 89 | char const* getPluginName() const noexcept override; 90 | 91 | char const* getPluginVersion() const noexcept override; 92 | 93 | PluginFieldCollection const* getFieldNames() noexcept override; 94 | 95 | IPluginV3* createPlugin(char const* name, PluginFieldCollection const* fc, TensorRTPhase phase) noexcept override; 96 | 97 | char const* getPluginNamespace() const noexcept override; 98 | 99 | void setPluginNamespace(char const* libNamespace) noexcept; 100 | 101 | 102 | private: 103 | nvinfer1::PluginFieldCollection mFC; 104 | std::vector mPluginAttributes; 105 | std::string mNamespace; 106 | }; 107 | 108 | } // namespace nvinfer1 109 | 110 | #endif 111 | 112 | -------------------------------------------------------------------------------- /tensorrt/read_img.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | using std::cout; 11 | using std::endl; 12 | using std::vector; 13 | using std::string; 14 | using cv::Mat; 15 | 16 | 17 | void read_data(std::string impth, float *data, int iH, int iW, 18 | int& orgH, int& orgW) { 19 | vector mean{0.485f, 0.456f, 0.406f}; // rgb order 20 | vector variance{0.229f, 0.224f, 0.225f}; 21 | 22 | Mat im = cv::imread(impth); 23 | if (im.empty()) { 24 | cout << "cannot read image \n"; 25 | std::abort(); 26 | } 27 | 28 | orgH = im.rows; orgW = im.cols; 29 | if ((orgH != iH) || orgW != iW) { 30 | cout << "resize orignal image of (" << orgH << "," << orgW 31 | << ") to (" << iH << ", " << iW << ") according to model require\n"; 32 | cv::resize(im, im, cv::Size(iW, iH), cv::INTER_CUBIC); 33 | } 34 | 35 | // normalize and convert to rgb 36 | float scale = 1.f / 255.f; 37 | for (int i{0}; i < variance.size(); ++ i) { 38 | variance[i] = 1.f / variance[i]; 39 | } 40 | for (int h{0}; h < iH; ++h) { 41 | cv::Vec3b *p = im.ptr(h); 42 | for (int w{0}; w < iW; ++w) { 43 | for (int c{0}; c < 3; ++c) { 44 | int idx = c * iH * iW + h * iW + w; 45 | data[idx] = (p[w][2 - c] * scale - mean[c]) * variance[c]; 46 | } 47 | } 48 | } 49 | } 50 | 51 | 52 | void read_data(std::string impth, float *data, int iH, int iW) { 53 | int tmp1, tmp2; 54 | read_data(impth, data, iH, iW, tmp1, tmp2); 55 | } 56 | 57 | -------------------------------------------------------------------------------- /tensorrt/read_img.hpp: 
-------------------------------------------------------------------------------- 1 | 2 | #ifndef _READ_IMAGE_HPP_ 3 | #define _READ_IMAGE_HPP_ 4 | 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | 14 | using std::cout; 15 | using std::endl; 16 | using std::vector; 17 | using std::string; 18 | using cv::Mat; 19 | 20 | 21 | void read_data(std::string impth, float *data, 22 | int iH, int iW, int& orgH, int& orgW); 23 | void read_data(std::string impth, float *data, int iH, int iW); 24 | 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /tensorrt/segment.cu: -------------------------------------------------------------------------------- 1 | // #include "NvInfer.h" 2 | // #include "NvOnnxParser.h" 3 | // #include "NvInferPlugin.h" 4 | // #include "NvInferRuntimeCommon.h" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "trt_dep.hpp" 19 | #include "read_img.hpp" 20 | 21 | 22 | // using nvinfer1::IHostMemory; 23 | // using nvinfer1::IBuilder; 24 | // using nvinfer1::INetworkDefinition; 25 | // using nvinfer1::ICudaEngine; 26 | // using nvinfer1::IInt8Calibrator; 27 | // using nvinfer1::IBuilderConfig; 28 | // using nvinfer1::IRuntime; 29 | // using nvinfer1::IExecutionContext; 30 | // using nvinfer1::ILogger; 31 | // using nvinfer1::Dims; 32 | // using Severity = nvinfer1::ILogger::Severity; 33 | 34 | using std::string; 35 | using std::ios; 36 | using std::ofstream; 37 | using std::ifstream; 38 | using std::vector; 39 | using std::cout; 40 | using std::endl; 41 | using std::array; 42 | using std::stringstream; 43 | 44 | using cv::Mat; 45 | 46 | 47 | 48 | 49 | vector> get_color_map(); 50 | 51 | void compile_onnx(vector args); 52 | void run_with_trt(vector args); 53 | void test_speed(vector args); 54 | 55 | 56 | int main(int argc, char* argv[]) { 57 | CHECK (argc >= 3, "usage is ./segment compile/run/test"); 58 | 59 | vector args; 60 | for (int i{1}; i < argc; ++i) args.emplace_back(argv[i]); 61 | 62 | if (args[0] == "compile") { 63 | stringstream ss; 64 | ss << "usage is: ./segment compile input.onnx output.trt [--fp16|--fp32|--bf16|--fp8]\n" 65 | << "or ./segment compile input.onnx output.trt --int8 /path/to/data_root /path/to/ann_file\n"; 66 | CHECK (argc >= 5, ss.str()); 67 | compile_onnx(args); 68 | } else if (args[0] == "run") { 69 | CHECK (argc >= 5, "usage is ./segment run ./xxx.trt input.jpg result.jpg"); 70 | run_with_trt(args); 71 | } else if (args[0] == "test") { 72 | CHECK (argc >= 3, "usage is ./segment test ./xxx.trt"); 73 | test_speed(args); 74 | } else { 75 | CHECK (false, "usage is ./segment compile/run/test"); 76 | } 77 | 78 | return 0; 79 | } 80 | 81 | 82 | void compile_onnx(vector args) { 83 | 84 | string quant("fp32"); 85 | string data_root("none"); 86 | string data_file("none"); 87 | int opt_bsize = 1; 88 | 89 | std::unordered_map quant_map{ 90 | {"--fp32", "fp32"}, 91 | {"--fp16", "fp16"}, 92 | {"--bf16", "bf16"}, 93 | {"--fp8", "fp8"}, 94 | {"--int8", "int8"}, 95 | }; 96 | CHECK (quant_map.find(args[3]) != quant_map.end(), 97 | "invalid args of quantization: " + args[3]); 98 | quant = quant_map[args[3]]; 99 | if (quant == "int8") { 100 | data_root = args[4]; 101 | data_file = args[5]; 102 | } 103 | 104 | if (args[3] == "--int8") { 105 | if (args.size() > 6) opt_bsize = std::stoi(args[6]); 106 | } else { 107 | if (args.size() > 4) opt_bsize = std::stoi(args[4]); 108 
| } 109 | 110 | SemanticSegmentTrt ss_trt; 111 | ss_trt.set_opt_batch_size(opt_bsize); 112 | ss_trt.parse_to_engine(args[1], quant, data_root, data_file); 113 | ss_trt.serialize(args[2]); 114 | } 115 | 116 | 117 | void run_with_trt(vector args) { 118 | 119 | SemanticSegmentTrt ss_trt; 120 | ss_trt.deserialize(args[1]); 121 | 122 | vector i_dims = ss_trt.get_input_shape(); 123 | vector o_dims = ss_trt.get_output_shape(); 124 | 125 | const int iH{i_dims[2]}, iW{i_dims[3]}; 126 | const int oH{o_dims[1]}, oW{o_dims[2]}; 127 | 128 | // prepare image and resize 129 | vector data; data.resize(iH * iW * 3); 130 | int orgH, orgW; 131 | read_data(args[2], &data[0], iH, iW, orgH, orgW); 132 | 133 | // call engine 134 | vector res = ss_trt.inference(data); 135 | 136 | // generate colored out 137 | vector> color_map = get_color_map(); 138 | Mat pred(cv::Size(oW, oH), CV_8UC3); 139 | 140 | int idx{0}; 141 | for (int i{0}; i < oH; ++i) { 142 | uint8_t *ptr = pred.ptr(i); 143 | for (int j{0}; j < oW; ++j) { 144 | 145 | ptr[0] = color_map[res[idx]][0]; 146 | ptr[1] = color_map[res[idx]][1]; 147 | ptr[2] = color_map[res[idx]][2]; 148 | ptr += 3; 149 | ++idx; 150 | } 151 | } 152 | 153 | // resize back and save 154 | if ((orgH != oH) || (orgW != oW)) { 155 | cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_CUBIC); 156 | } 157 | cv::imwrite(args[3], pred); 158 | } 159 | 160 | 161 | vector> get_color_map() { 162 | vector> color_map(256, vector(3)); 163 | std::minstd_rand rand_eng(123); 164 | std::uniform_int_distribution u(0, 255); 165 | for (int i{0}; i < 256; ++i) { 166 | for (int j{0}; j < 3; ++j) { 167 | color_map[i][j] = u(rand_eng); 168 | } 169 | } 170 | return color_map; 171 | } 172 | 173 | 174 | void test_speed(vector args) { 175 | int opt_bsize = 1; 176 | if (args.size() > 2) opt_bsize = std::stoi(args[2]); 177 | 178 | SemanticSegmentTrt ss_trt; 179 | ss_trt.set_opt_batch_size(opt_bsize); 180 | ss_trt.deserialize(args[1]); 181 | ss_trt.test_speed_fps(); 182 | } 183 | -------------------------------------------------------------------------------- /tensorrt/segment.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as osp 4 | import cv2 5 | import numpy as np 6 | import logging 7 | import argparse 8 | 9 | import tensorrt as trt 10 | import pycuda.driver as cuda 11 | import pycuda.autoinit 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | subparsers = parser.add_subparsers(dest="command") 16 | compile_parser = subparsers.add_parser('compile') 17 | compile_parser.add_argument('--onnx') 18 | compile_parser.add_argument('--quant', default='fp32') 19 | compile_parser.add_argument('--savepth', default='./model.trt') 20 | run_parser = subparsers.add_parser('run') 21 | run_parser.add_argument('--mdpth') 22 | run_parser.add_argument('--impth') 23 | run_parser.add_argument('--outpth', default='./res.png') 24 | args = parser.parse_args() 25 | 26 | 27 | np.random.seed(123) 28 | in_datatype = trt.nptype(trt.float32) 29 | out_datatype = trt.nptype(trt.int32) 30 | palette = np.random.randint(0, 256, (256, 3)).astype(np.uint8) 31 | 32 | ctx = pycuda.autoinit.context 33 | trt.init_libnvinfer_plugins(None, "") 34 | TRT_LOGGER = trt.Logger() 35 | 36 | 37 | 38 | def get_image(impth, size): 39 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None] 40 | var = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None] 41 | iH, iW = size[0], size[1] 42 | img = cv2.imread(impth)[:, :, ::-1] 43 | orgH, orgW, _ = img.shape 
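# the next lines resize to the engine's fixed input size, convert to CHW float32 and normalize with the given mean/var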
44 | img = cv2.resize(img, (iW, iH)).astype(np.float32) 45 | img = img.transpose(2, 0, 1) / 255. 46 | img = (img - mean) / var 47 | return img, (orgH, orgW) 48 | 49 | 50 | 51 | def allocate_buffers(engine): 52 | h_input = cuda.pagelocked_empty( 53 | trt.volume(engine.get_binding_shape(0)), dtype=in_datatype) 54 | print(engine.get_binding_shape(0)) 55 | d_input = cuda.mem_alloc(h_input.nbytes) 56 | h_outputs, d_outputs = [], [] 57 | n_outs = 1 58 | for i in range(n_outs): 59 | h_output = cuda.pagelocked_empty( 60 | trt.volume(engine.get_binding_shape(i+1)), 61 | dtype=out_datatype) 62 | d_output = cuda.mem_alloc(h_output.nbytes) 63 | h_outputs.append(h_output) 64 | d_outputs.append(d_output) 65 | stream = cuda.Stream() 66 | return ( 67 | stream, 68 | h_input, 69 | d_input, 70 | h_outputs, 71 | d_outputs, 72 | ) 73 | 74 | 75 | def build_engine_from_onnx(onnx_file_path): 76 | engine = None ## add this to avoid return deleted engine 77 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 78 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime: 79 | 80 | # Parse model file 81 | print(f'Loading ONNX file from path {onnx_file_path}...') 82 | assert os.path.exists(onnx_file_path), f'cannot find {onnx_file_path}' 83 | with open(onnx_file_path, 'rb') as fr: 84 | if not parser.parse(fr.read()): 85 | print ('ERROR: Failed to parse the ONNX file.') 86 | for error in range(parser.num_errors): 87 | print (parser.get_error(error)) 88 | assert False 89 | 90 | # build settings 91 | builder.max_batch_size = 128 92 | config.max_workspace_size = 1 << 30 # 1G 93 | if args.quant == 'fp16': 94 | config.set_flag(trt.BuilderFlag.FP16) 95 | 96 | print("Start to build Engine") 97 | plan = builder.build_serialized_network(network, config) 98 | engine = runtime.deserialize_cuda_engine(plan) 99 | return engine 100 | 101 | 102 | def serialize_engine_to_file(engine, savepth): 103 | plan = engine.serialize() 104 | with open(savepth, "wb") as fw: 105 | fw.write(plan) 106 | 107 | 108 | def deserialize_engine_from_file(savepth): 109 | with open(savepth, 'rb') as fr, trt.Runtime(TRT_LOGGER) as runtime: 110 | engine = runtime.deserialize_cuda_engine(fr.read()) 111 | return engine 112 | 113 | 114 | def main(): 115 | if args.command == 'compile': 116 | engine = build_engine_from_onnx(args.onnx) 117 | serialize_engine_to_file(engine, args.savepth) 118 | 119 | elif args.command == 'run': 120 | engine = deserialize_engine_from_file(args.mdpth) 121 | 122 | ishape = engine.get_binding_shape(0) 123 | img, (orgH, orgW) = get_image(args.impth, ishape[2:]) 124 | 125 | ## create engine and allocate bffers 126 | ( 127 | stream, 128 | h_input, 129 | d_input, 130 | h_outputs, 131 | d_outputs, 132 | ) = allocate_buffers(engine) 133 | ctx.push() 134 | context = engine.create_execution_context() 135 | ctx.pop() 136 | bds = [int(d_input), ] + [int(el) for el in d_outputs] 137 | 138 | h_input = np.ascontiguousarray(img) 139 | cuda.memcpy_htod_async(d_input, h_input, stream) 140 | context.execute_async( 141 | bindings=bds, stream_handle=stream.handle) 142 | for h_output, d_output in zip(h_outputs, d_outputs): 143 | cuda.memcpy_dtoh_async(h_output, d_output, stream) 144 | stream.synchronize() 145 | 146 | oshape = engine.get_binding_shape(1) 147 | pred = np.argmax(h_outputs[0].reshape(oshape), axis=1) 148 | out = palette[pred] 149 | out = 
out.reshape(*oshape[2:], 3) 150 | out = cv2.resize(out, (orgW, orgH)) 151 | cv2.imwrite(args.outpth, out) 152 | 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | 158 | -------------------------------------------------------------------------------- /tensorrt/trt_dep.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _TRT_DEP_HPP_ 2 | #define _TRT_DEP_HPP_ 3 | 4 | #include "NvInfer.h" 5 | #include "NvOnnxParser.h" 6 | #include "NvInferPlugin.h" 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "argmax_plugin.h" 14 | 15 | 16 | using std::string; 17 | using std::vector; 18 | using std::cout; 19 | using std::endl; 20 | 21 | using nvinfer1::ICudaEngine; 22 | using nvinfer1::ILogger; 23 | using nvinfer1::IRuntime; 24 | using Severity = nvinfer1::ILogger::Severity; 25 | 26 | 27 | void CHECK(bool success, string msg); 28 | 29 | 30 | class Logger: public ILogger { 31 | public: 32 | void log(Severity severity, const char* msg) noexcept override { 33 | if (severity != Severity::kINFO) { 34 | std::cout << msg << std::endl; 35 | } 36 | } 37 | }; 38 | 39 | struct TrtDeleter { 40 | template 41 | void operator()(T* obj) const { 42 | delete obj; 43 | } 44 | }; 45 | 46 | struct CudaStreamDeleter { 47 | void operator()(cudaStream_t* stream) const { 48 | cudaStreamDestroy(*stream); 49 | } 50 | }; 51 | 52 | template 53 | using TrtUnqPtr = std::unique_ptr; 54 | using CudaStreamUnqPtr = std::unique_ptr; 55 | using TrtSharedEnginePtr = std::shared_ptr; 56 | 57 | 58 | extern Logger gLogger; 59 | 60 | 61 | struct SemanticSegmentTrt { 62 | public: 63 | TrtSharedEnginePtr engine; 64 | CudaStreamUnqPtr stream; 65 | TrtUnqPtr runtime; 66 | std::unique_ptr plugin_creator; 67 | 68 | string input_name; 69 | string output_name; 70 | int opt_bsize{1}; 71 | 72 | SemanticSegmentTrt(): 73 | engine(nullptr), runtime(nullptr), stream(nullptr) { 74 | 75 | stream.reset(new cudaStream_t); 76 | auto fail = cudaStreamCreate(stream.get()); 77 | CHECK(!fail, "create stream failed"); 78 | 79 | register_plugins(); 80 | } 81 | 82 | ~SemanticSegmentTrt() { 83 | engine.reset(); 84 | runtime.reset(); 85 | stream.reset(); 86 | } 87 | 88 | void register_plugins(); 89 | 90 | void set_opt_batch_size(int bs); 91 | 92 | void serialize(string save_path); 93 | 94 | void deserialize(string serpth); 95 | 96 | void parse_to_engine(string onnx_path, string quant, 97 | string data_root, string data_file); 98 | 99 | vector inference(vector& data); 100 | 101 | void test_speed_fps(); 102 | 103 | vector get_input_shape(); 104 | vector get_output_shape(); 105 | }; 106 | 107 | 108 | #endif 109 | -------------------------------------------------------------------------------- /tis/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## A simple demo of using trition-inference-serving 4 | 5 | ### Platform 6 | 7 | * ubuntu 18.04 8 | * cmake-3.22.0 9 | * 8 Tesla T4 gpu 10 | 11 | 12 | ### Serving Model 13 | 14 | #### 1. 
prepare model repository 15 | 16 | We need to export our model to onnx and copy it to model repository: 17 | ``` 18 | $ cd BiSeNet 19 | $ python tools/export_onnx.py --config configs/bisenetv1_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx 20 | $ cp -riv ./model.onnx tis/models/bisenetv1/1 21 | 22 | $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx 23 | $ cp -riv ./model.onnx tis/models/bisenetv2/1 24 | ``` 25 | 26 | #### 2. prepare the preprocessing backend 27 | We can use either python backend or cpp backend for preprocessing in the server side. 28 | Firstly, we pull the docker image, and start a serving container: 29 | ``` 30 | $ docker pull nvcr.io/nvidia/tritonserver:22.07-py3 31 | $ docker run -it --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models -v /path/to/BiSeNet/:/BiSeNet nvcr.io/nvidia/tritonserver:21.10-py3 bash 32 | ``` 33 | From here on, we are in the container environment. Let's prepare the backends in the container: 34 | ``` 35 | # ln -s /usr/local/bin/pip3.8 /usr/bin/pip3.8 36 | # /usr/bin/python3 -m pip install pillow 37 | # apt update && apt install rapidjson-dev libopencv-dev 38 | ``` 39 | Then we download cmake 3.22 and unzip in the container, we use this cmake 3.22 in the following operations. 40 | We compile c++ backends: 41 | ``` 42 | # cp -riv /BiSeNet/tis/self_backend /opt/tritonserver/backends 43 | # chmod 777 /opt/tritonserver/backends/self_backend 44 | # cd /opt/tritonserver/backends/self_backend 45 | # mkdir -p build && cd build 46 | # cmake .. && make -j4 47 | # mv -iuv libtriton_self_backend.so .. 48 | ``` 49 | Utils now, we should have backends prepared. 50 | 51 | 52 | 53 | #### 3. start service 54 | We start the server in the docker container, following the above steps: 55 | ``` 56 | # tritonserver --model-repository=/models 57 | ``` 58 | In general, the service would start now. You can check whether service has started by: 59 | ``` 60 | $ curl -v localhost:8000/v2/health/ready 61 | ``` 62 | 63 | By default, we use gpu 0 and gpu 1, you can change configurations in the `config.pbtxt` file. 64 | 65 | 66 | ### Request with client 67 | 68 | We call the model service with both python and c++ method. 69 | 70 | From here on, we are at the client machine, rather than the server docker container. 71 | 72 | 73 | #### 1. python method 74 | 75 | Firstly, we need to install dependency package: 76 | ``` 77 | $ python -m pip install tritonclient[all]==2.15.0 78 | ``` 79 | 80 | Then we can run the script for both http request and grpc request: 81 | ``` 82 | $ cd BiSeNet/tis 83 | $ python client_http.py # if you want to use http client 84 | $ python client_grpc.py # if you want to use grpc client 85 | ``` 86 | 87 | This would generate a result file named `res.jpg` in `BiSeNet/tis` directory. 88 | 89 | 90 | #### 2. c++ method 91 | 92 | We need to compile c++ client library from source: 93 | ``` 94 | $ apt install rapidjson-dev 95 | $ mkdir -p /data/ $$ cd /data/ 96 | $ git clone https://github.com/triton-inference-server/client.git 97 | $ cd client && git reset --hard da04158bc094925a56b 98 | $ mkdir -p build && cd build 99 | $ cmake -DCMAKE_INSTALL_PREFIX=/opt/triton_client -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_CC_GRPC=ON -DTRITON_ENABLE_PERF_ANALYZER=OFF -DTRITON_ENABLE_PYTHON_HTTP=OFF -DTRITON_ENABLE_PYTHON_GRPC=OFF -DTRITON_ENABLE_JAVA_HTTP=OFF -DTRITON_ENABLE_GPU=ON -DTRITON_ENABLE_EXAMPLES=OFF -DTRITON_ENABLE_TESTS=ON .. 
100 | $ make cc-clients 101 | ``` 102 | The above commands are exactly what I used to compile the library. I learned these commands from the official document. 103 | 104 | Also, We need to install `cmake` with version `3.22`. 105 | 106 | Optionally, I compiled opencv from source and install it to `/opt/opencv`. You can first skip this and see whether you meet problems. If you have problems about opencv in the following steps, you can compile opencv as what I do. 107 | 108 | After installing the dependencies, we can compile our c++ client: 109 | ``` 110 | $ cd BiSeNet/tis/cpp_client 111 | $ mkdir -p build && cd build 112 | $ cmake .. && make 113 | ``` 114 | 115 | Finally, we run the client and see a result file named `res.jpg` generated: 116 | ``` 117 | ./client 118 | ``` 119 | 120 | 121 | ### In the end 122 | 123 | This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and dynamic batching. If you have interests on this, you can learn more in the official document. 124 | -------------------------------------------------------------------------------- /tis/client_backend.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import sys 5 | import numpy as np 6 | import cv2 7 | import gevent.ssl 8 | 9 | import tritonclient.http as httpclient 10 | from tritonclient.utils import InferenceServerException 11 | 12 | 13 | np.random.seed(123) 14 | palette = np.random.randint(0, 256, (100, 3)) 15 | 16 | 17 | url = '10.128.61.8:8000' 18 | # url = '127.0.0.1:8000' 19 | model_name = 'preprocess_cpp' 20 | model_version = '1' 21 | inp_name = 'raw_img_bytes' 22 | outp_name = 'processed_img' 23 | inp_dtype = 'UINT8' 24 | impth = '../example.png' 25 | mean = [0.3257, 0.3690, 0.3223] # city, rgb 26 | std = [0.2112, 0.2148, 0.2115] 27 | 28 | 29 | ## prepare image and mean/std 30 | inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] 31 | mean = np.array(mean, dtype=np.float32)[None, ...] 32 | std = np.array(std, dtype=np.float32)[None, ...] 
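# build the three inputs the preprocess model expects: raw encoded image bytes plus per-channel mean and std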
33 | inputs = [] 34 | inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype)) 35 | inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32')) 36 | inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32')) 37 | inputs[0].set_data_from_numpy(inp_data, binary_data=True) 38 | inputs[1].set_data_from_numpy(mean, binary_data=True) 39 | inputs[2].set_data_from_numpy(std, binary_data=True) 40 | 41 | ## client 42 | triton_client = httpclient.InferenceServerClient( 43 | url=url, verbose=False, concurrency=32) 44 | 45 | ## infer 46 | # sync 47 | # results = triton_client.infer(model_name, inputs) 48 | 49 | 50 | # async 51 | # results = triton_client.async_infer( 52 | # model_name, 53 | # inputs, 54 | # outputs=None, 55 | # query_params=None, 56 | # headers=None, 57 | # request_compression_algorithm=None, 58 | # response_compression_algorithm=None) 59 | # results = results.get_result() # async infer only 60 | 61 | 62 | ## dynamic batching, this is not allowed, since different pictures have different raw sizes 63 | results = [] 64 | for i in range(10): 65 | r = triton_client.async_infer( 66 | model_name, 67 | inputs, 68 | outputs=None, 69 | query_params=None, 70 | headers=None, 71 | request_compression_algorithm=None, 72 | response_compression_algorithm=None) 73 | results.append(r) 74 | for i in range(10): 75 | results[i] = results[i].get_result() # wait for each async request and keep its InferResult 76 | results = results[-1] # keep one response so we can read its output below 77 | 78 | 79 | # get output 80 | outp = results.as_numpy(outp_name).squeeze() 81 | print(outp.shape) 82 | -------------------------------------------------------------------------------- /tis/client_grpc.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import cv2 4 | 5 | import grpc 6 | 7 | from tritonclient.grpc import service_pb2, service_pb2_grpc 8 | import tritonclient.grpc.model_config_pb2 as mc 9 | 10 | 11 | np.random.seed(123) 12 | palette = np.random.randint(0, 256, (100, 3)) 13 | 14 | 15 | 16 | url = '10.128.61.8:8001' 17 | # url = '127.0.0.1:8001' 18 | model_name = 'bisenetv1' 19 | model_version = '1' 20 | inp_name = 'raw_img_bytes' 21 | outp_name = 'preds' 22 | inp_dtype = 'UINT8' 23 | outp_dtype = np.int64 24 | impth = '../example.png' 25 | mean = [0.3257, 0.3690, 0.3223] # city, rgb 26 | std = [0.2112, 0.2148, 0.2115] 27 | 28 | 29 | ## input data and mean/std 30 | inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] 31 | mean = np.array(mean, dtype=np.float32)[None, ...] 32 | std = np.array(std, dtype=np.float32)[None, ...]
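## build the gRPC InferInputTensor messages by hand; the raw tensor payloads are sent separately through raw_input_contents in the same order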
33 | inputs = [service_pb2.ModelInferRequest().InferInputTensor() for _ in range(3)] 34 | inputs[0].name = inp_name 35 | inputs[0].datatype = inp_dtype 36 | inputs[0].shape.extend(inp_data.shape) 37 | inputs[1].name = 'channel_mean' 38 | inputs[1].datatype = 'FP32' 39 | inputs[1].shape.extend(mean.shape) 40 | inputs[2].name = 'channel_std' 41 | inputs[2].datatype = 'FP32' 42 | inputs[2].shape.extend(std.shape) 43 | inp_bytes = [inp_data.tobytes(), mean.tobytes(), std.tobytes()] 44 | 45 | 46 | option = [ 47 | ('grpc.max_receive_message_length', 1073741824), 48 | ('grpc.max_send_message_length', 1073741824), 49 | ] 50 | channel = grpc.insecure_channel(url, options=option) 51 | grpc_stub = service_pb2_grpc.GRPCInferenceServiceStub(channel) 52 | 53 | 54 | metadata_request = service_pb2.ModelMetadataRequest( 55 | name=model_name, version=model_version) 56 | metadata_response = grpc_stub.ModelMetadata(metadata_request) 57 | print(metadata_response) 58 | 59 | config_request = service_pb2.ModelConfigRequest( 60 | name=model_name, 61 | version=model_version) 62 | config_response = grpc_stub.ModelConfig(config_request) 63 | print(config_response) 64 | 65 | 66 | request = service_pb2.ModelInferRequest() 67 | request.model_name = model_name 68 | request.model_version = model_version 69 | 70 | request.ClearField("inputs") 71 | request.ClearField("raw_input_contents") 72 | request.inputs.extend(inputs) 73 | request.raw_input_contents.extend(inp_bytes) 74 | 75 | 76 | # sync 77 | # resp = grpc_stub.ModelInfer(request) 78 | # async 79 | resp = grpc_stub.ModelInfer.future(request) 80 | resp = resp.result() 81 | 82 | outp_bytes = resp.raw_output_contents[0] 83 | outp_shape = resp.outputs[0].shape 84 | 85 | out = np.frombuffer(outp_bytes, dtype=outp_dtype).reshape(*outp_shape).squeeze() 86 | 87 | out = palette[out] 88 | cv2.imwrite('res.png', out) 89 | -------------------------------------------------------------------------------- /tis/client_http.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import sys 5 | import numpy as np 6 | import cv2 7 | import gevent.ssl 8 | 9 | import tritonclient.http as httpclient 10 | from tritonclient.utils import InferenceServerException 11 | 12 | 13 | np.random.seed(123) 14 | palette = np.random.randint(0, 256, (100, 3)) 15 | 16 | 17 | url = '10.128.61.8:8000' 18 | # url = '127.0.0.1:8000' 19 | model_name = 'bisenetv2' 20 | model_version = '1' 21 | inp_name = 'raw_img_bytes' 22 | outp_name = 'preds' 23 | inp_dtype = 'UINT8' 24 | impth = '../example.png' 25 | mean = [0.3257, 0.3690, 0.3223] # city, rgb 26 | std = [0.2112, 0.2148, 0.2115] 27 | 28 | 29 | ## prepare image and mean/std 30 | inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] 31 | mean = np.array(mean, dtype=np.float32)[None, ...] 32 | std = np.array(std, dtype=np.float32)[None, ...] 
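## same three inputs as in client_backend.py, but here the request goes to the full ensemble model and returns the predicted label map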
33 | inputs = [] 34 | inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype)) 35 | inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32')) 36 | inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32')) 37 | inputs[0].set_data_from_numpy(inp_data, binary_data=True) 38 | inputs[1].set_data_from_numpy(mean, binary_data=True) 39 | inputs[2].set_data_from_numpy(std, binary_data=True) 40 | 41 | 42 | ## client 43 | triton_client = httpclient.InferenceServerClient( 44 | url=url, verbose=False, concurrency=32) 45 | 46 | ## infer 47 | # sync 48 | # results = triton_client.infer(model_name, inputs) 49 | 50 | # async 51 | results = triton_client.async_infer( 52 | model_name, 53 | inputs, 54 | outputs=None, 55 | query_params=None, 56 | headers=None, 57 | request_compression_algorithm=None, 58 | response_compression_algorithm=None) 59 | results = results.get_result() # async infer only 60 | 61 | # get output 62 | outp = results.as_numpy(outp_name).squeeze() 63 | out = palette[outp] 64 | cv2.imwrite('res.png', out) 65 | -------------------------------------------------------------------------------- /tis/cpp_client/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.18) 2 | 3 | project(Samples) 4 | 5 | set(CMAKE_CXX_FLAGS "-std=c++14 -O2") 6 | set(CMAKE_BUILD_TYPE Release) 7 | 8 | set(CMAKE_PREFIX_PATH 9 | /opt/triton_client/ 10 | /opt/opencv/lib/cmake/opencv4) 11 | find_package(OpenCV REQUIRED) 12 | 13 | include_directories( 14 | ${CMAKE_CURRENT_SOURCE_DIR} 15 | ${CMAKE_CURRENT_BINARY_DIR} 16 | ${OpenCV_INCLUDE_DIRS} 17 | /opt/triton_client/include 18 | ) 19 | link_directories( 20 | /opt/triton_client/lib 21 | ) 22 | 23 | 24 | add_executable(client main.cpp) 25 | target_link_libraries(client PRIVATE 26 | grpcclient 27 | ${OpenCV_LIBS} 28 | -lpthread 29 | ) 30 | -------------------------------------------------------------------------------- /tis/models/bisenetv1/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "bisenetv1" 2 | platform: "ensemble" 3 | max_batch_size: 256 4 | input [ 5 | { 6 | name: "raw_img_bytes" 7 | data_type: TYPE_UINT8 8 | dims: [ -1 ] 9 | }, 10 | { 11 | name: "channel_mean" 12 | data_type: TYPE_FP32 13 | dims: [ 3 ] 14 | }, 15 | { 16 | name: "channel_std" 17 | data_type: TYPE_FP32 18 | dims: [ 3 ] 19 | } 20 | ] 21 | output [ 22 | { 23 | name: "preds" 24 | data_type: TYPE_INT64 25 | dims: [1, 1024, 2048 ] 26 | } 27 | ] 28 | 29 | ensemble_scheduling { 30 | step [ 31 | { 32 | model_name: "preprocess_py" 33 | model_version: 1 34 | input_map { 35 | key: "raw_img_bytes" 36 | value: "raw_img_bytes" 37 | } 38 | input_map { 39 | key: "channel_mean" 40 | value: "channel_mean" 41 | } 42 | input_map { 43 | key: "channel_std" 44 | value: "channel_std" 45 | } 46 | output_map { 47 | key: "processed_img" 48 | value: "processed_img" 49 | } 50 | }, 51 | { 52 | model_name: "bisenetv1_model" 53 | model_version: 1 54 | input_map { 55 | key: "input_image" 56 | value: "processed_img" 57 | } 58 | output_map { 59 | key: "preds" 60 | value: "preds" 61 | } 62 | } 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /tis/models/bisenetv1_model/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "bisenetv1_model" 2 | platform: "onnxruntime_onnx" 3 | max_batch_size: 0 4 | input [ 5 | { 6 | name: "input_image" 7 | data_type: 
TYPE_FP32 8 | dims: [ 1, 3, 1024, 2048 ] 9 | } 10 | ] 11 | output [ 12 | { 13 | name: "preds" 14 | data_type: TYPE_INT64 15 | dims: [ 1, 1024, 2048 ] 16 | } 17 | ] 18 | optimization { execution_accelerators { # we use tensorrt backend, pure onnxruntime seems to have memory leackage problem 19 | gpu_execution_accelerator : [ { 20 | name : "tensorrt" 21 | parameters { key: "precision_mode" value: "FP16" } 22 | parameters { key: "max_workspace_size_bytes" value: "4294967296" } 23 | }] 24 | }} 25 | instance_group [ 26 | { 27 | count: 2 28 | kind: KIND_GPU 29 | gpus: [ 0, 1 ] 30 | } 31 | ] 32 | -------------------------------------------------------------------------------- /tis/models/bisenetv2/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "bisenetv2" 2 | platform: "ensemble" 3 | max_batch_size: 256 4 | input [ 5 | { 6 | name: "raw_img_bytes" 7 | data_type: TYPE_UINT8 8 | dims: [ -1 ] 9 | }, 10 | { 11 | name: "channel_mean" 12 | data_type: TYPE_FP32 13 | dims: [ 3 ] 14 | }, 15 | { 16 | name: "channel_std" 17 | data_type: TYPE_FP32 18 | dims: [ 3 ] 19 | } 20 | ] 21 | output [ 22 | { 23 | name: "preds" 24 | data_type: TYPE_INT64 25 | dims: [1, 1024, 2048 ] 26 | } 27 | ] 28 | 29 | ensemble_scheduling { 30 | step [ 31 | { 32 | model_name: "preprocess_cpp" 33 | model_version: 1 34 | input_map { 35 | key: "raw_img_bytes" 36 | value: "raw_img_bytes" 37 | } 38 | input_map { 39 | key: "channel_mean" 40 | value: "channel_mean" 41 | } 42 | input_map { 43 | key: "channel_std" 44 | value: "channel_std" 45 | } 46 | output_map { 47 | key: "processed_img" 48 | value: "processed_img" 49 | } 50 | }, 51 | { 52 | model_name: "bisenetv2_model" 53 | model_version: 1 54 | input_map { 55 | key: "input_image" 56 | value: "processed_img" 57 | } 58 | output_map { 59 | key: "preds" 60 | value: "preds" 61 | } 62 | } 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /tis/models/bisenetv2_model/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "bisenetv2_model" 2 | platform: "onnxruntime_onnx" 3 | max_batch_size: 0 4 | input [ 5 | { 6 | name: "input_image" 7 | data_type: TYPE_FP32 8 | dims: [1, 3, 1024, 2048 ] 9 | } 10 | ] 11 | output [ 12 | { 13 | name: "preds" 14 | data_type: TYPE_INT64 15 | dims: [1, 1024, 2048 ] 16 | } 17 | ] 18 | optimization { execution_accelerators { # we use tensorrt backend, pure onnxruntime seems to have memory leackage problem 19 | gpu_execution_accelerator : [ { 20 | name : "tensorrt" 21 | parameters { key: "precision_mode" value: "FP16" } 22 | parameters { key: "max_workspace_size_bytes" value: "4294967296" } 23 | }] 24 | }} 25 | instance_group [ 26 | { 27 | count: 2 28 | kind: KIND_GPU 29 | gpus: [ 0, 1 ] 30 | } 31 | ] 32 | -------------------------------------------------------------------------------- /tis/models/preprocess_cpp/1/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/tis/models/preprocess_cpp/1/.gitkeep -------------------------------------------------------------------------------- /tis/models/preprocess_cpp/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "preprocess_cpp" 2 | backend: "self_backend" 3 | max_batch_size: 256 4 | # dynamic_batching { ## this is not allowed, since we cannot know raw bytes size of each inputs from the 
server, as they just concat the bytes together 5 | # max_queue_delay_microseconds: 5000000 6 | # } 7 | input [ 8 | { 9 | name: "raw_img_bytes" 10 | data_type: TYPE_UINT8 11 | dims: [ -1 ] 12 | }, 13 | { 14 | name: "channel_mean" 15 | data_type: TYPE_FP32 16 | dims: [ 3 ] 17 | }, 18 | { 19 | name: "channel_std" 20 | data_type: TYPE_FP32 21 | dims: [ 3 ] 22 | } 23 | ] 24 | output [ 25 | { 26 | name: "processed_img" 27 | data_type: TYPE_FP32 28 | dims: [ 1, 3, 1024, 2048 ] 29 | } 30 | ] 31 | instance_group [ 32 | { 33 | kind: KIND_CPU 34 | } 35 | ] 36 | -------------------------------------------------------------------------------- /tis/models/preprocess_py/1/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import json 4 | import io 5 | 6 | # triton_python_backend_utils is available in every Triton Python model. You 7 | # need to use this module to create inference requests and responses. It also 8 | # contains some utility functions for extracting information from model_config 9 | # and converting Triton input/output types to numpy types. 10 | import triton_python_backend_utils as pb_utils 11 | 12 | from PIL import Image 13 | import os 14 | 15 | 16 | class TritonPythonModel: 17 | """Your Python model must use the same class name. Every Python model 18 | that is created must have "TritonPythonModel" as the class name. 19 | """ 20 | 21 | def initialize(self, args): 22 | """`initialize` is called only once when the model is being loaded. 23 | Implementing `initialize` function is optional. This function allows 24 | the model to intialize any state associated with this model. 25 | 26 | Parameters 27 | ---------- 28 | args : dict 29 | Both keys and values are strings. The dictionary keys and values are: 30 | * model_config: A JSON string containing the model configuration 31 | * model_instance_kind: A string containing model instance kind 32 | * model_instance_device_id: A string containing model instance device ID 33 | * model_repository: Model repository path 34 | * model_version: Model version 35 | * model_name: Model name 36 | """ 37 | 38 | # You must parse model_config. JSON string is not parsed here 39 | self.model_config = model_config = json.loads(args['model_config']) 40 | 41 | # Get OUTPUT0 configuration 42 | output0_config = pb_utils.get_output_config_by_name( 43 | model_config, "processed_img") 44 | 45 | # Convert Triton types to numpy types 46 | self.output0_dtype = pb_utils.triton_string_to_numpy( 47 | output0_config['data_type']) 48 | 49 | self.output0_shape = output0_config['dims'] 50 | 51 | def execute(self, requests): 52 | """`execute` MUST be implemented in every Python model. `execute` 53 | function receives a list of pb_utils.InferenceRequest as the only 54 | argument. This function is called when an inference request is made 55 | for this model. Depending on the batching configuration (e.g. Dynamic 56 | Batching) used, `requests` may contain multiple requests. Every 57 | Python model, must create one pb_utils.InferenceResponse for every 58 | pb_utils.InferenceRequest in `requests`. If there is an error, you can 59 | set the error argument when creating a pb_utils.InferenceResponse 60 | 61 | Parameters 62 | ---------- 63 | requests : list 64 | A list of pb_utils.InferenceRequest 65 | 66 | Returns 67 | ------- 68 | list 69 | A list of pb_utils.InferenceResponse. 
The length of this list must 70 | be the same as `requests` 71 | """ 72 | 73 | output0_dtype = self.output0_dtype 74 | N, C, H, W = self.output0_shape 75 | 76 | responses = [] 77 | 78 | # Every Python backend must iterate over everyone of the requests 79 | # and create a pb_utils.InferenceResponse for each of them. 80 | for request in requests: 81 | # Get INPUT0 82 | im_bytes = pb_utils.get_input_tensor_by_name(request, "raw_img_bytes") 83 | im_bytes = im_bytes.as_numpy().tobytes() 84 | im = Image.open(io.BytesIO(im_bytes)) 85 | im = im.resize((W, H), Image.ANTIALIAS) 86 | im = np.array(im) 87 | 88 | # Get mean/std 89 | mean = pb_utils.get_input_tensor_by_name(request, "channel_mean") 90 | std = pb_utils.get_input_tensor_by_name(request, "channel_std") 91 | mean = mean.as_numpy().reshape(1, 1, 3) 92 | std = std.as_numpy().reshape(1, 1, 3) 93 | 94 | # preprocess 95 | im = ((im / 255.) - mean) / std 96 | im = im[None, ...].transpose(0, 3, 1, 2).astype(np.float32) 97 | 98 | 99 | out_tensor_0 = pb_utils.Tensor("processed_img", im) 100 | 101 | # Create InferenceResponse. You can set an error here in case 102 | # there was a problem with handling this inference request. 103 | # Below is an example of how you can set errors in inference 104 | # response: 105 | # 106 | # pb_utils.InferenceResponse( 107 | # output_tensors=..., TritonError("An error occured")) 108 | inference_response = pb_utils.InferenceResponse( 109 | output_tensors=[out_tensor_0]) 110 | responses.append(inference_response) 111 | 112 | # You should return a list of pb_utils.InferenceResponse. Length 113 | # of this list must match the length of `requests` list. 114 | return responses 115 | 116 | def finalize(self): 117 | """`finalize` is called only once when the model is being unloaded. 118 | Implementing `finalize` function is OPTIONAL. This function allows 119 | the model to perform any necessary clean ups before exit. 120 | """ 121 | print('Cleaning up...') 122 | 123 | -------------------------------------------------------------------------------- /tis/models/preprocess_py/config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "preprocess_py" 2 | backend: "python" 3 | max_batch_size: 256 4 | input [ 5 | { 6 | name: "raw_img_bytes" 7 | data_type: TYPE_UINT8 8 | dims: [ -1 ] 9 | }, 10 | { 11 | name: "channel_mean" 12 | data_type: TYPE_FP32 13 | dims: [ 3 ] 14 | }, 15 | { 16 | name: "channel_std" 17 | data_type: TYPE_FP32 18 | dims: [ 3 ] 19 | } 20 | ] 21 | 22 | output [ 23 | { 24 | name: "processed_img" 25 | data_type: TYPE_FP32 26 | dims: [1, 3, 1024, 2048 ] 27 | } 28 | ] 29 | 30 | instance_group [{ kind: KIND_CPU }] 31 | -------------------------------------------------------------------------------- /tis/self_backend/cmake/TutorialRecommendedBackendConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | include(CMakeFindDependencyMacro) 28 | 29 | get_filename_component( 30 | TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH 31 | ) 32 | 33 | list(APPEND CMAKE_MODULE_PATH ${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}) 34 | 35 | if(NOT TARGET TutorialRecommendedBackend::triton-recommended-backend) 36 | include("${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}/TutorialRecommendedBackendTargets.cmake") 37 | endif() 38 | 39 | set(TUTORIALRECOMMENDEDBACKEND_LIBRARIES TutorialRecommendedBackend::triton-recommended-backend) 40 | -------------------------------------------------------------------------------- /tis/self_backend/src/libtriton_recommended.ldscript: -------------------------------------------------------------------------------- 1 | # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
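# export only the TRITONBACKEND_* entry points from the backend shared library; every other symbol stays local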
26 | { 27 | global: 28 | TRITONBACKEND_*; 29 | local: *; 30 | }; 31 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/tools/__init__.py -------------------------------------------------------------------------------- /tools/check_dataset_info.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as osp 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | 11 | parse = argparse.ArgumentParser() 12 | parse.add_argument('--im_root', dest='im_root', type=str, default='./datasets/cityscapes',) 13 | parse.add_argument('--im_anns', dest='im_anns', type=str, default='./datasets/cityscapes/train.txt',) 14 | parse.add_argument('--lb_ignore', dest='lb_ignore', type=int, default=255) 15 | args = parse.parse_args() 16 | 17 | lb_ignore = args.lb_ignore 18 | 19 | 20 | with open(args.im_anns, 'r') as fr: 21 | lines = fr.read().splitlines() 22 | 23 | n_pairs = len(lines) 24 | impaths, lbpaths = [], [] 25 | for l in lines: 26 | impth, lbpth = l.split(',') 27 | impth = osp.join(args.im_root, impth) 28 | lbpth = osp.join(args.im_root, lbpth) 29 | impaths.append(impth) 30 | lbpaths.append(lbpth) 31 | 32 | 33 | ## shapes 34 | max_shape_area, min_shape_area = [0, 0], [100000, 100000] 35 | max_shape_height, min_shape_height = [0, 0], [100000, 100000] 36 | max_shape_width, min_shape_width = [0, 0], [100000, 100000] 37 | max_lb_val, min_lb_val = -1, 10000000 38 | for impth, lbpth in tqdm(zip(impaths, lbpaths), total=n_pairs): 39 | im = cv2.imread(impth)[:, :, ::-1] 40 | lb = cv2.imread(lbpth, 0) 41 | assert im.shape[:2] == lb.shape 42 | 43 | shape = lb.shape 44 | area = shape[0] * shape[1] 45 | if area > max_shape_area[0] * max_shape_area[1]: 46 | max_shape_area = shape 47 | if area < min_shape_area[0] * min_shape_area[1]: 48 | min_shape_area = shape 49 | 50 | if shape[0] > max_shape_height[0]: 51 | max_shape_height = shape 52 | if shape[0] < min_shape_height[0]: 53 | min_shape_height = shape 54 | 55 | if shape[1] > max_shape_width[1]: 56 | max_shape_width = shape 57 | if shape[1] < min_shape_width[1]: 58 | min_shape_width = shape 59 | 60 | lb = lb[lb != lb_ignore] 61 | if lb.size > 0: 62 | max_lb_val = max(max_lb_val, np.max(lb)) 63 | min_lb_val = min(min_lb_val, np.min(lb)) 64 | 65 | ## label info 66 | lb_minlength = max_lb_val+1-min_lb_val 67 | lb_hist = np.zeros(lb_minlength) 68 | for lbpth in tqdm(lbpaths): 69 | lb = cv2.imread(lbpth, 0) 70 | lb = lb[lb != lb_ignore] - min_lb_val 71 | lb_hist += np.bincount(lb, minlength=lb_minlength) 72 | 73 | lb_missing_vals = [ind + min_lb_val 74 | for ind, el in enumerate(lb_hist.tolist()) if el == 0] 75 | lb_ratios = (lb_hist / lb_hist.sum()).tolist() 76 | 77 | 78 | ## pixel mean/std 79 | rgb_mean = np.zeros(3).astype(np.float32) 80 | n_pixels = 0 81 | for impth in tqdm(impaths): 82 | im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) 83 | im = im.reshape(-1, 3) / 255. 84 | n_pixels += im.shape[0] 85 | rgb_mean += im.sum(axis=0) 86 | rgb_mean = (rgb_mean / n_pixels) 87 | 88 | rgb_std = np.zeros(3).astype(np.float32) 89 | for impth in tqdm(impaths): 90 | im = cv2.imread(impth)[:, :, ::-1].astype(np.float32) 91 | im = im.reshape(-1, 3) / 255. 
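    # accumulate squared deviations from the per-channel mean computed in the previous loop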
92 | 93 | a = (im - rgb_mean.reshape(1, 3)) ** 2 94 | rgb_std += a.sum(axis=0) 95 | rgb_std = (rgb_std / n_pixels) ** 0.5 96 | 97 | rgb_mean = rgb_mean.tolist() 98 | rgb_std = rgb_std.tolist() 99 | 100 | 101 | print('\n') 102 | print(f'there are {n_pairs} lines in {args.im_anns}, which means {n_pairs} image/label image pairs') 103 | print('\n') 104 | 105 | print(f'max and min image shapes by area are: {max_shape_area}, {min_shape_area}') 106 | print(f'max and min image shapes by height are: {max_shape_height}, {min_shape_height}') 107 | print(f'max and min image shapes by width are: {max_shape_width}, {min_shape_width}') 108 | print('\n') 109 | 110 | print(f'we ignore label value of {args.lb_ignore} in label images') 111 | print(f'label values are within range of [{min_lb_val}, {max_lb_val}]') 112 | print(f'label values that are missing: {lb_missing_vals}') 113 | print('ratios of each label value(from small to big, without ignored): ') 114 | print('\t', lb_ratios) 115 | print('\n') 116 | 117 | print('pixel mean rgb: ', rgb_mean) 118 | print('pixel std rgb: ', rgb_std) 119 | -------------------------------------------------------------------------------- /tools/conver_to_trt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import sys 4 | sys.path.insert(0, '.') 5 | 6 | import torch 7 | from torch2trt import torch2trt 8 | 9 | from lib.models import model_factory 10 | from configs import set_cfg_from_file 11 | 12 | torch.set_grad_enabled(False) 13 | 14 | 15 | parse = argparse.ArgumentParser() 16 | parse.add_argument('--config', dest='config', type=str, default='configs/bisenetv2.py',) 17 | parse.add_argument('--weight-path', type=str, default='./res/model_final.pth',) 18 | parse.add_argument('--fp16', action='store_true') 19 | parse.add_argument('--outpath', dest='out_pth', type=str, 20 | default='model.trt') 21 | args = parse.parse_args() 22 | 23 | 24 | cfg = set_cfg_from_file(args.config) 25 | if cfg.use_sync_bn: cfg.use_sync_bn = False 26 | 27 | net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred') 28 | net.load_state_dict(torch.load(args.weight_path), strict=False) 29 | net.cuda() 30 | net.eval() 31 | 32 | 33 | # dummy_input = torch.randn(1, 3, *cfg.crop_size) 34 | dummy_input = torch.randn(1, 3, 1024, 2048).cuda() 35 | 36 | trt_model = torch2trt(net, [dummy_input, ], fp16_mode=args.fp16, max_workspace=1 << 30) 37 | 38 | with open(args.out_pth, 'wb') as fw: 39 | fw.write(trt_model.engine.serialize()) 40 | -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.insert(0, '.') 4 | import argparse 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | import numpy as np 11 | import cv2 12 | 13 | import lib.data.transform_cv2 as T 14 | from lib.models import model_factory 15 | from configs import set_cfg_from_file 16 | 17 | 18 | # uncomment the following line if you want to reduce cpu usage, see issue #231 19 | # torch.set_num_threads(4) 20 | 21 | torch.set_grad_enabled(False) 22 | np.random.seed(123) 23 | 24 | 25 | # args 26 | parse = argparse.ArgumentParser() 27 | parse.add_argument('--config', dest='config', type=str, default='configs/bisenetv2.py',) 28 | parse.add_argument('--weight-path', type=str, default='./res/model_final.pth',) 29 | parse.add_argument('--img-path', 
dest='img_path', type=str, default='./example.png',) 30 | args = parse.parse_args() 31 | cfg = set_cfg_from_file(args.config) 32 | 33 | 34 | palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8) 35 | 36 | # define model 37 | net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval') 38 | net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False) 39 | net.eval() 40 | net.cuda() 41 | 42 | # prepare data 43 | to_tensor = T.ToTensor( 44 | mean=(0.3257, 0.3690, 0.3223), # city, rgb 45 | std=(0.2112, 0.2148, 0.2115), 46 | ) 47 | im = cv2.imread(args.img_path)[:, :, ::-1] 48 | im = to_tensor(dict(im=im, lb=None))['im'].unsqueeze(0).cuda() 49 | 50 | # shape divisor 51 | org_size = im.size()[2:] 52 | new_size = [math.ceil(el / 32) * 32 for el in im.size()[2:]] 53 | 54 | # inference 55 | im = F.interpolate(im, size=new_size, align_corners=False, mode='bilinear') 56 | out = net(im)[0] 57 | out = F.interpolate(out, size=org_size, align_corners=False, mode='bilinear') 58 | out = out.argmax(dim=1) 59 | 60 | # visualize 61 | out = out.squeeze().detach().cpu().numpy() 62 | pred = palette[out] 63 | cv2.imwrite('./res.jpg', pred) 64 | -------------------------------------------------------------------------------- /tools/demo_video.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.insert(0, '.') 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.multiprocessing as mp 9 | import time 10 | from PIL import Image 11 | import numpy as np 12 | import cv2 13 | 14 | import lib.data.transform_cv2 as T 15 | from lib.models import model_factory 16 | from configs import set_cfg_from_file 17 | 18 | 19 | torch.set_grad_enabled(False) 20 | 21 | 22 | # args 23 | parse = argparse.ArgumentParser() 24 | parse.add_argument('--config', dest='config', type=str, default='configs/bisenetv2.py',) 25 | parse.add_argument('--weight-path', type=str, default='./res/model_final.pth',) 26 | parse.add_argument('--input', dest='input', type=str, default='./example.mp4',) 27 | parse.add_argument('--output', dest='output', type=str, default='./res.mp4',) 28 | args = parse.parse_args() 29 | cfg = set_cfg_from_file(args.config) 30 | 31 | 32 | 33 | # define model 34 | def get_model(): 35 | net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval') 36 | net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False) 37 | net.eval() 38 | net.cuda() 39 | return net 40 | 41 | 42 | # fetch frames 43 | def get_func(inpth, in_q, done): 44 | cap = cv2.VideoCapture(args.input) 45 | width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float 46 | height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float 47 | fps = cap.get(cv2.CAP_PROP_FPS) 48 | 49 | to_tensor = T.ToTensor( 50 | mean=(0.3257, 0.3690, 0.3223), # city, rgb 51 | std=(0.2112, 0.2148, 0.2115), 52 | ) 53 | 54 | while cap.isOpened(): 55 | ret, frame = cap.read() 56 | if not ret: break 57 | frame = frame[:, :, ::-1] 58 | frame = to_tensor(dict(im=frame, lb=None))['im'].unsqueeze(0) 59 | in_q.put(frame) 60 | 61 | in_q.put('quit') 62 | done.wait() 63 | 64 | cap.release() 65 | time.sleep(1) 66 | print('input queue done') 67 | 68 | 69 | # save to video 70 | def save_func(inpth, outpth, out_q): 71 | np.random.seed(123) 72 | palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8) 73 | 74 | cap = cv2.VideoCapture(args.input) 75 | width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float 76 | height = 
cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float 77 | fps = cap.get(cv2.CAP_PROP_FPS) 78 | cap.release() 79 | 80 | video_writer = cv2.VideoWriter(outpth, 81 | cv2.VideoWriter_fourcc(*"mp4v"), 82 | fps, (int(width), int(height))) 83 | 84 | while True: 85 | out = out_q.get() 86 | if out == 'quit': break 87 | out = out.numpy() 88 | preds = palette[out] 89 | for pred in preds: 90 | video_writer.write(pred) 91 | video_writer.release() 92 | print('output queue done') 93 | 94 | 95 | # inference a list of frames 96 | def infer_batch(frames): 97 | frames = torch.cat(frames, dim=0).cuda() 98 | H, W = frames.size()[2:] 99 | frames = F.interpolate(frames, size=(768, 768), mode='bilinear', 100 | align_corners=False) # must be divisible by 32 101 | out = net(frames)[0] 102 | out = F.interpolate(out, size=(H, W), mode='bilinear', 103 | align_corners=False).argmax(dim=1).detach().cpu() 104 | out_q.put(out) 105 | 106 | 107 | 108 | if __name__ == '__main__': 109 | mp.set_start_method('spawn') 110 | 111 | in_q = mp.Queue(1024) 112 | out_q = mp.Queue(1024) 113 | done = mp.Event() 114 | 115 | in_worker = mp.Process(target=get_func, 116 | args=(args.input, in_q, done)) 117 | out_worker = mp.Process(target=save_func, 118 | args=(args.input, args.output, out_q)) 119 | 120 | in_worker.start() 121 | out_worker.start() 122 | 123 | net = get_model() 124 | 125 | frames = [] 126 | while True: 127 | frame = in_q.get() 128 | if frame == 'quit': break 129 | 130 | frames.append(frame) 131 | if len(frames) == 8: 132 | infer_batch(frames) 133 | frames = [] 134 | if len(frames) > 0: 135 | infer_batch(frames) 136 | 137 | out_q.put('quit') 138 | done.set() 139 | 140 | out_worker.join() 141 | in_worker.join() 142 | -------------------------------------------------------------------------------- /tools/export_libtorch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import sys 4 | sys.path.insert(0, '.') 5 | 6 | import torch 7 | 8 | from lib.models import model_factory 9 | from configs import set_cfg_from_file 10 | 11 | torch.set_grad_enabled(False) 12 | 13 | 14 | parse = argparse.ArgumentParser() 15 | parse.add_argument('--config', dest='config', type=str, 16 | default='configs/bisenetv2.py',) 17 | parse.add_argument('--weight-path', dest='weight_pth', type=str, 18 | default='model_final.pth') 19 | parse.add_argument('--outpath', dest='out_pth', type=str, 20 | default='model.pt') 21 | args = parse.parse_args() 22 | 23 | 24 | cfg = set_cfg_from_file(args.config) 25 | if cfg.use_sync_bn: cfg.use_sync_bn = False 26 | 27 | net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred') 28 | net.load_state_dict(torch.load(args.weight_pth, map_location='cpu'), strict=False) 29 | net.eval() 30 | 31 | 32 | # dummy_input = torch.randn(1, 3, *cfg.crop_size) 33 | dummy_input = torch.randn(1, 3, 1024, 2048) 34 | script_module = torch.jit.trace(net, dummy_input) 35 | # script_module.save(args.out_pth, _use_new_zipfile_serialization=False) 36 | script_module.save(args.out_pth) 37 | 38 | -------------------------------------------------------------------------------- /tools/export_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import sys 4 | sys.path.insert(0, '.') 5 | 6 | import torch 7 | from torch.onnx import OperatorExportTypes 8 | 9 | from lib.models import model_factory 10 | from configs import set_cfg_from_file 11 | 12 | torch.set_grad_enabled(False) 13 | 14 
| 15 | parse = argparse.ArgumentParser() 16 | parse.add_argument('--config', dest='config', type=str, 17 | default='configs/bisenetv2.py',) 18 | parse.add_argument('--weight-path', dest='weight_pth', type=str, 19 | default='model_final.pth') 20 | parse.add_argument('--outpath', dest='out_pth', type=str, 21 | default='model.onnx') 22 | parse.add_argument('--aux-mode', dest='aux_mode', type=str, 23 | default='pred') 24 | args = parse.parse_args() 25 | 26 | 27 | cfg = set_cfg_from_file(args.config) 28 | if cfg.use_sync_bn: cfg.use_sync_bn = False 29 | 30 | net = model_factory[cfg.model_type](cfg.n_cats, aux_mode=args.aux_mode) 31 | net.load_state_dict(torch.load(args.weight_pth, map_location='cpu', 32 | weights_only=True), strict=False) 33 | net.eval() 34 | 35 | 36 | dummy_input = torch.randn(1, 3, *cfg.cropsize) 37 | # dummy_input = torch.randn(1, 3, 1024, 2048) 38 | input_names = ['input_image'] 39 | output_names = ['preds',] 40 | dynamic_axes = {'input_image': {0: 'batch'}, 'preds': {0: 'batch'}} 41 | 42 | torch.onnx.export(net, dummy_input, args.out_pth, 43 | input_names=input_names, output_names=output_names, 44 | verbose=False, opset_version=18, 45 | operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, 46 | dynamic_axes=dynamic_axes) 47 | 48 | -------------------------------------------------------------------------------- /tools/gen_dataset_annos.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as osp 4 | import argparse 5 | 6 | 7 | def gen_coco(): 8 | ''' 9 | root_path: 10 | |- images 11 | |- train2017 12 | |- val2017 13 | |- labels 14 | |- train2017 15 | |- val2017 16 | ''' 17 | root_path = './datasets/coco' 18 | save_path = './datasets/coco/' 19 | for mode in ('train', 'val'): 20 | im_root = osp.join(root_path, f'images/{mode}2017') 21 | lb_root = osp.join(root_path, f'labels/{mode}2017') 22 | 23 | ims = os.listdir(im_root) 24 | lbs = os.listdir(lb_root) 25 | 26 | print(len(ims)) 27 | print(len(lbs)) 28 | 29 | im_names = [el.replace('.jpg', '') for el in ims] 30 | lb_names = [el.replace('.png', '') for el in lbs] 31 | common_names = list(set(im_names) & set(lb_names)) 32 | 33 | lines = [ 34 | f'images/{mode}2017/{name}.jpg,labels/{mode}2017/{name}.png' 35 | for name in common_names 36 | ] 37 | 38 | with open(f'{save_path}/{mode}.txt', 'w') as fw: 39 | fw.write('\n'.join(lines)) 40 | 41 | 42 | def gen_ade20k(): 43 | ''' 44 | root_path: 45 | |- images 46 | |- training 47 | |- validation 48 | |- annotations 49 | |- training 50 | |- validation 51 | ''' 52 | root_path = './datasets/ade20k/' 53 | save_path = './datasets/ade20k/' 54 | folder_map = {'train': 'training', 'val': 'validation'} 55 | for mode in ('train', 'val'): 56 | folder = folder_map[mode] 57 | im_root = osp.join(root_path, f'images/{folder}') 58 | lb_root = osp.join(root_path, f'annotations/{folder}') 59 | 60 | ims = os.listdir(im_root) 61 | lbs = os.listdir(lb_root) 62 | 63 | print(len(ims)) 64 | print(len(lbs)) 65 | 66 | im_names = [el.replace('.jpg', '') for el in ims] 67 | lb_names = [el.replace('.png', '') for el in lbs] 68 | common_names = list(set(im_names) & set(lb_names)) 69 | 70 | lines = [ 71 | f'images/{folder}/{name}.jpg,annotations/{folder}/{name}.png' 72 | for name in common_names 73 | ] 74 | 75 | with open(f'{save_path}/{mode}.txt', 'w') as fw: 76 | fw.write('\n'.join(lines)) 77 | 78 | 79 | 80 | if __name__ == '__main__': 81 | parse = argparse.ArgumentParser() 82 | parse.add_argument('--dataset', dest='dataset', type=str, 
default='coco') 83 | args = parse.parse_args() 84 | 85 | if args.dataset == 'coco': 86 | gen_coco() 87 | elif args.dataset == 'ade20k': 88 | gen_ade20k() 89 | -------------------------------------------------------------------------------- /video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CoinCheung/BiSeNet/3d9b2cf592bcb1185cb52d710b476d8a1bde8120/video.mp4 --------------------------------------------------------------------------------