├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── dataset.py
├── eval_existingOnes.py
├── evaluation
│   └── metrics.py
├── gen_best_ep.py
├── image_proc.py
├── inference.py
├── loss.py
├── make_a_copy.sh
├── models
│   ├── backbones
│   │   ├── build_backbone.py
│   │   ├── pvt_v2.py
│   │   └── swin_v1.py
│   ├── birefnet.py
│   ├── modules
│   │   ├── aspp.py
│   │   ├── decoder_blocks.py
│   │   ├── deform_conv.py
│   │   ├── lateral_blocks.py
│   │   └── utils.py
│   └── refinement
│       ├── refiner.py
│       └── stem_layer.py
├── requirements.txt
├── rm_cache.sh
├── sub.sh
├── test.sh
├── train.py
├── train.sh
├── train_test.sh
├── tutorials
│   ├── BiRefNet_inference.ipynb
│   ├── BiRefNet_inference_video.ipynb
│   └── BiRefNet_pth2onnx.ipynb
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | comparisons/
3 | preds-*/
4 | *.onnx
5 | deform_conv2d_onnx_exporter*
6 | frames-video_*/
7 | images_todo/
8 | predictions/
9 | *.mp4
10 | *.avi
11 | e_*
12 | .vscode
13 | ckpt
14 | preds
15 | evaluation/eval-*
16 | nohup.*
17 | tmp*
18 | *.pth
19 | core-*-python-*
20 | .DS_Store
21 | __MACOSX/
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | pip-wheel-metadata/
46 | share/python-wheels/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 | MANIFEST
51 |
52 | # PyInstaller
53 | # Usually these files are written by a python script from a template
54 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
55 | *.manifest
56 | *.spec
57 |
58 | # Installer logs
59 | pip-log.txt
60 | pip-delete-this-directory.txt
61 |
62 | # Unit test / coverage reports
63 | htmlcov/
64 | .tox/
65 | .nox/
66 | .coverage
67 | .coverage.*
68 | .cache
69 | nosetests.xml
70 | coverage.xml
71 | *.cover
72 | *.py,cover
73 | .hypothesis/
74 | .pytest_cache/
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 | local_settings.py
83 | db.sqlite3
84 | db.sqlite3-journal
85 |
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 |
90 | # Scrapy stuff:
91 | .scrapy
92 |
93 | # Sphinx documentation
94 | docs/_build/
95 |
96 | # PyBuilder
97 | target/
98 |
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 |
102 | # IPython
103 | profile_default/
104 | ipython_config.py
105 |
106 | # pyenv
107 | .python-version
108 |
109 | # pipenv
110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
113 | # install all needed dependencies.
114 | #Pipfile.lock
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ZhengPeng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | 5 | class Config(): 6 | def __init__(self) -> None: 7 | # PATH settings 8 | # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx 9 | self.sys_home_dir = [os.path.expanduser('~'), '/workspace'][1] # Default, custom 10 | self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') 11 | 12 | # TASK settings 13 | self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0] 14 | self.testsets = { 15 | # Benchmarks 16 | 'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4'][:1]), 17 | 'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']), 18 | 'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']), 19 | # Practical use 20 | 'General': ','.join(['DIS-VD', 'TE-P3M-500-NP']), 21 | 'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']), 22 | 'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']), 23 | }[self.task] 24 | datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.testsets.split(',')]) 25 | self.training_set = { 26 | 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 27 | 'COD': 'TR-COD10K+TR-CAMO', 28 | 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], 29 | 'General': datasets_all, 30 | 'General-2K': datasets_all, 31 | 'Matting': datasets_all, 32 | }[self.task] 33 | 34 | # Data settings 35 | self.size = (1024, 1024) if self.task not in ['General-2K'] else (2560, 1440) # wid, hei. Can be overwritten by dynamic_size in training. 36 | self.dynamic_size = [None, ((512-256, 2048+256), (512-256, 2048+256))][0] # wid, hei. It might cause errors in using compile. 37 | self.background_color_synthesis = False # whether to use pure bg color to replace the original backgrounds. 38 | 39 | # Faster-Training settings 40 | self.mixed_precision = ['no', 'fp16', 'bf16', 'fp8'][1] 41 | self.load_all = False and self.dynamic_size is None # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data. 42 | self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch. 43 | # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting. 44 | # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607. 45 | # 3. But compile in 2.0.1 < Pytorch < 2.5.0 seems to bring no acceleration for training. 
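        # Illustrative aside: a minimal sketch (assumed usage, not copied from this repo's train.py,
        # which may wire these differently) of how the `mixed_precision` and `compile` flags above and
        # `precisionHigh` below are typically consumed in a PyTorch training loop:
        #     if config.precisionHigh:
        #         torch.set_float32_matmul_precision('high')
        #     if config.compile:
        #         model = torch.compile(model)  # PyTorch >= 2.0
        #     amp_dtype = torch.float16 if config.mixed_precision == 'fp16' else torch.bfloat16
        #     with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=config.mixed_precision != 'no'):
        #         loss = compute_loss(model(images), labels)  # `compute_loss` is hypothetical here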
46 | self.precisionHigh = True 47 | 48 | # MODEL settings 49 | self.ms_supervision = True 50 | self.out_ref = self.ms_supervision and True 51 | self.dec_ipt = True 52 | self.dec_ipt_split = True 53 | self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder 54 | self.mul_scl_ipt = ['', 'add', 'cat'][2] 55 | self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2] 56 | self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1] 57 | self.dec_blk = ['BasicDecBlk', 'ResBlk'][0] 58 | 59 | # TRAINING settings 60 | self.batch_size = 4 61 | self.finetune_last_epochs = [ 62 | 0, 63 | { 64 | 'DIS5K': -40, 65 | 'COD': -20, 66 | 'HRSOD': -20, 67 | 'General': -20, 68 | 'General-2K': -20, 69 | 'Matting': -10, 70 | }[self.task] 71 | ][1] # choose 0 to skip 72 | self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly 73 | self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader 74 | 75 | # Backbone settings 76 | self.bb = [ 77 | 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2 78 | 'swin_v1_t', 'swin_v1_s', # 3, 4 79 | 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4 80 | 'pvt_v2_b0', 'pvt_v2_b1', # 7, 8 81 | 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5 82 | ][6] 83 | self.lateral_channels_in_collection = { 84 | 'vgg16': [512, 512, 256, 128], 'vgg16bn': [512, 512, 256, 128], 'resnet50': [2048, 1024, 512, 256], 85 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 86 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 87 | 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96], 88 | 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64], 89 | }[self.bb] 90 | if self.mul_scl_ipt == 'cat': 91 | self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection] 92 | self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else [] 93 | 94 | # MODEL settings - inactive 95 | self.lat_blk = ['BasicLatBlk'][0] 96 | self.dec_channels_inter = ['fixed', 'adap'][0] 97 | self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0] 98 | self.progressive_ref = self.refine and True 99 | self.ender = self.progressive_ref and False 100 | self.scale = self.progressive_ref and 2 101 | self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`. 102 | self.refine_iteration = 1 103 | self.freeze_bb = False 104 | self.model = [ 105 | 'BiRefNet', 106 | 'BiRefNetC2F', 107 | ][0] 108 | 109 | # TRAINING settings - inactive 110 | self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4 if not self.background_color_synthesis else 1] 111 | self.optimizer = ['Adam', 'AdamW'][1] 112 | self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch. 
113 | self.lr_decay_rate = 0.5 114 | # Loss 115 | if self.task in ['Matting']: 116 | self.lambdas_pix_last = { 117 | 'bce': 30 * 1, 118 | 'iou': 0.5 * 0, 119 | 'iou_patch': 0.5 * 0, 120 | 'mae': 100 * 1, 121 | 'mse': 30 * 0, 122 | 'triplet': 3 * 0, 123 | 'reg': 100 * 0, 124 | 'ssim': 10 * 1, 125 | 'cnt': 5 * 0, 126 | 'structure': 5 * 0, 127 | } 128 | elif self.task in ['General', 'General-2K']: 129 | self.lambdas_pix_last = { 130 | 'bce': 30 * 1, 131 | 'iou': 0.5 * 1, 132 | 'iou_patch': 0.5 * 0, 133 | 'mae': 100 * 1, 134 | 'mse': 30 * 0, 135 | 'triplet': 3 * 0, 136 | 'reg': 100 * 0, 137 | 'ssim': 10 * 1, 138 | 'cnt': 5 * 0, 139 | 'structure': 5 * 0, 140 | } 141 | else: 142 | self.lambdas_pix_last = { 143 | # not 0 means opening this loss 144 | # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 145 | 'bce': 30 * 1, # high performance 146 | 'iou': 0.5 * 1, # 0 / 255 147 | 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) 148 | 'mae': 30 * 0, 149 | 'mse': 30 * 0, # can smooth the saliency map 150 | 'triplet': 3 * 0, 151 | 'reg': 100 * 0, 152 | 'ssim': 10 * 1, # help contours, 153 | 'cnt': 5 * 0, # help contours 154 | 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4. 155 | } 156 | self.lambdas_cls = { 157 | 'ce': 5.0 158 | } 159 | 160 | # PATH settings - inactive 161 | self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights/cv') 162 | self.weights = { 163 | 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'), 164 | 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), 165 | 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), 166 | 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), 167 | 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), 168 | 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), 169 | 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]), 170 | 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]), 171 | } 172 | 173 | # Callbacks - inactive 174 | self.verbose_eval = True 175 | self.only_S_MAE = False 176 | self.SDPA_enabled = False # Bugs. Slower and errors occur in multi-GPUs 177 | 178 | # others 179 | self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0') 180 | 181 | self.batch_size_valid = 1 182 | self.rand_seed = 7 183 | run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f] 184 | if run_sh_file: 185 | with open(run_sh_file[0], 'r') as f: 186 | lines = f.readlines() 187 | self.save_last = int([l.strip() for l in lines if "'{}')".format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0]) 188 | self.save_step = int([l.strip() for l in lines if "'{}')".format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0]) 189 | 190 | 191 | # Return task for choosing settings in shell scripts. 
192 | if __name__ == '__main__': 193 | import argparse 194 | 195 | 196 | parser = argparse.ArgumentParser(description='Only choose one argument to activate.') 197 | parser.add_argument('--print_task', action='store_true', help='print task name') 198 | parser.add_argument('--print_testsets', action='store_true', help='print validation set') 199 | args = parser.parse_args() 200 | 201 | config = Config() 202 | for arg_name, arg_value in args._get_kwargs(): 203 | if arg_value: 204 | print(config.__getattribute__(arg_name[len('print_'):])) 205 | 206 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | from PIL import Image 7 | from torch.utils import data 8 | from torchvision import transforms 9 | 10 | from image_proc import preproc 11 | from config import Config 12 | from utils import path_to_image 13 | 14 | 15 | Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning 16 | config = Config() 17 | _class_labels_TR_sorted = ( 18 | 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, ' 19 | 'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, ' 20 | 'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, ' 21 | 'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, ' 22 | 'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, ' 23 | 'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, ' 24 | 'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, ' 25 | 'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, ' 26 | 'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, ' 27 | 'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ' 28 | 'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, ' 29 | 'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, ' 30 | 'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, ' 31 | 'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht' 32 | ) 33 | class_labels_TR_sorted = _class_labels_TR_sorted.split(', ') 34 | 35 | 36 | class MyData(data.Dataset): 37 | def 
__init__(self, datasets, data_size, is_train=True): 38 | # data_size is None when using dynamic_size or data_size is manually set to None (for inference in the original size). 39 | self.is_train = is_train 40 | self.data_size = data_size 41 | self.load_all = config.load_all 42 | self.device = config.device 43 | valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG'] 44 | 45 | if self.is_train and config.auxiliary_classification: 46 | self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)} 47 | self.transform_image = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 50 | ]) 51 | self.transform_label = transforms.Compose([ 52 | transforms.ToTensor(), 53 | ]) 54 | dataset_root = os.path.join(config.data_root_dir, config.task) 55 | # datasets can be a list of different datasets for training on combined sets. 56 | self.image_paths = [] 57 | for dataset in datasets.split('+'): 58 | image_root = os.path.join(dataset_root, dataset, 'im') 59 | self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root) if any(p.endswith(ext) for ext in valid_extensions)] 60 | self.label_paths = [] 61 | for p in self.image_paths: 62 | for ext in valid_extensions: 63 | ## 'im' and 'gt' may need modifying 64 | p_gt = p.replace('/im/', '/gt/')[:-(len(p.split('.')[-1])+1)] + ext 65 | file_exists = False 66 | if os.path.exists(p_gt): 67 | self.label_paths.append(p_gt) 68 | file_exists = True 69 | break 70 | if not file_exists: 71 | print('Not exists:', p_gt) 72 | 73 | if len(self.label_paths) != len(self.image_paths): 74 | set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths]) 75 | set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths]) 76 | print('Path diff:', set_image_paths - set_label_paths) 77 | raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})") 78 | 79 | if self.load_all: 80 | self.images_loaded, self.labels_loaded = [], [] 81 | self.class_labels_loaded = [] 82 | # for image_path, label_path in zip(self.image_paths, self.label_paths): 83 | for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)): 84 | _image = path_to_image(image_path, size=self.data_size, color_type='rgb') 85 | _label = path_to_image(label_path, size=self.data_size, color_type='gray') 86 | self.images_loaded.append(_image) 87 | self.labels_loaded.append(_label) 88 | self.class_labels_loaded.append( 89 | self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 90 | ) 91 | 92 | def __getitem__(self, index): 93 | if self.load_all: 94 | image = self.images_loaded[index] 95 | label = self.labels_loaded[index] 96 | class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1 97 | else: 98 | image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb') 99 | label = path_to_image(self.label_paths[index], size=self.data_size, color_type='gray') 100 | class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 101 | 102 | # loading image and label 103 | if self.is_train: 104 | if config.background_color_synthesis: 105 | image.putalpha(label) 106 | array_image = np.array(image) 107 | array_foreground = array_image[:, :, 
:3].astype(np.float32) 108 | array_mask = (array_image[:, :, 3:] / 255).astype(np.float32) 109 | array_background = np.zeros_like(array_foreground) 110 | choice = random.random() 111 | if choice < 0.4: 112 | # Black/Gray/White backgrounds 113 | array_background[:, :, :] = random.randint(0, 255) 114 | elif choice < 0.8: 115 | # Background color that similar to the foreground object. Hard negative samples. 116 | foreground_pixel_number = np.sum(array_mask > 0) 117 | color_foreground_mean = np.mean(array_foreground * array_mask, axis=(0, 1)) * (np.prod(array_foreground.shape[:2]) / foreground_pixel_number) 118 | color_up_or_down = random.choice((-1, 1)) 119 | # Up or down for 20% range from 255 or 0, respectively. 120 | color_foreground_mean += (255 - color_foreground_mean if color_up_or_down == 1 else color_foreground_mean) * (random.random() * 0.2) * color_up_or_down 121 | array_background[:, :, :] = color_foreground_mean 122 | else: 123 | # Any color 124 | for idx_channel in range(3): 125 | array_background[:, :, idx_channel] = random.randint(0, 255) 126 | array_foreground_background = array_foreground * array_mask + array_background * (1 - array_mask) 127 | image = Image.fromarray(array_foreground_background.astype(np.uint8)) 128 | image, label = preproc(image, label, preproc_methods=config.preproc_methods) 129 | # else: 130 | # if _label.shape[0] > 2048 or _label.shape[1] > 2048: 131 | # _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR) 132 | # _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR) 133 | 134 | # At present, we use fixed sizes in inference, instead of consistent dynamic size with training. 135 | if self.is_train: 136 | if config.dynamic_size is None: 137 | image, label = self.transform_image(image), self.transform_label(label) 138 | else: 139 | size_div_32 = (int(image.size[0] // 32 * 32), int(image.size[1] // 32 * 32)) 140 | if image.size != size_div_32: 141 | image = image.resize(size_div_32) 142 | label = label.resize(size_div_32) 143 | image, label = self.transform_image(image), self.transform_label(label) 144 | 145 | if self.is_train: 146 | return image, label, class_label 147 | else: 148 | return image, label, self.label_paths[index] 149 | 150 | def __len__(self): 151 | return len(self.image_paths) 152 | 153 | 154 | def custom_collate_fn(batch): 155 | if config.dynamic_size: 156 | dynamic_size = tuple(sorted(config.dynamic_size)) 157 | dynamic_size_batch = (random.randint(dynamic_size[0][0], dynamic_size[0][1]) // 32 * 32, random.randint(dynamic_size[1][0], dynamic_size[1][1]) // 32 * 32) # select a value randomly in the range of [dynamic_size[0/1][0], dynamic_size[0/1][1]]. 
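        # Worked example (illustrative numbers, assuming config.dynamic_size = ((512-256, 2048+256), (512-256, 2048+256)),
        # i.e. ((256, 2304), (256, 2304))): a batch might draw 1583 and 1010, which are floored to multiples of 32,
        # so every sample in that batch is resized to dynamic_size_batch = (1568, 992).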
158 | data_size = dynamic_size_batch 159 | else: 160 | data_size = config.size 161 | new_batch = [] 162 | transform_image = transforms.Compose([ 163 | transforms.Resize(data_size[::-1]), 164 | transforms.ToTensor(), 165 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 166 | ]) 167 | transform_label = transforms.Compose([ 168 | transforms.Resize(data_size[::-1]), 169 | transforms.ToTensor(), 170 | ]) 171 | for image, label, class_label in batch: 172 | new_batch.append((transform_image(image), transform_label(label), class_label)) 173 | return data._utils.collate.default_collate(new_batch) 174 | -------------------------------------------------------------------------------- /eval_existingOnes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | import prettytable as pt 5 | 6 | from evaluation.metrics import evaluator 7 | from config import Config 8 | 9 | 10 | config = Config() 11 | 12 | 13 | def do_eval(args): 14 | # evaluation for whole dataset 15 | # dataset first in evaluation 16 | for _data_name in args.data_lst.split('+'): 17 | pred_data_dir = sorted(glob(os.path.join(args.pred_root, args.model_lst[0], _data_name))) 18 | if not pred_data_dir: 19 | print('Skip dataset {}.'.format(_data_name)) 20 | continue 21 | gt_src = os.path.join(args.gt_root, _data_name) 22 | gt_paths = sorted(glob(os.path.join(gt_src, 'gt', '*'))) 23 | print('#' * 20, _data_name, '#' * 20) 24 | filename = os.path.join(args.save_dir, '{}_eval.txt'.format(_data_name)) 25 | tb = pt.PrettyTable() 26 | tb.vertical_char = '&' 27 | if config.task == 'DIS5K': 28 | tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'] 29 | elif config.task == 'COD': 30 | tb.field_names = ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] 31 | elif config.task == 'HRSOD': 32 | tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] 33 | elif config.task == 'General': 34 | tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'] 35 | elif config.task == 'General-2K': 36 | tb.field_names = ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'] 37 | elif config.task == 'Matting': 38 | tb.field_names = ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MSE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] 39 | else: 40 | tb.field_names = ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'] 41 | for _model_name in args.model_lst[:]: 42 | print('\t', 'Evaluating model: {}...'.format(_model_name)) 43 | pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, _model_name)).replace('/gt/', '/') for p in gt_paths] 44 | # print(pred_paths[:1], gt_paths[:1]) 45 | em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator( 46 | gt_paths=gt_paths, 47 | pred_paths=pred_paths, 48 | metrics=args.metrics.split('+'), 49 | verbose=config.verbose_eval 50 | ) 51 | if config.task == 'DIS5K': 52 | scores = [ 
53 | fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()), 54 | em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3), 55 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 56 | ] 57 | elif config.task == 'COD': 58 | scores = [ 59 | sm.round(3), wfm.round(3), fm['curve'].mean().round(3), em['curve'].mean().round(3), em['curve'].max().round(3), mae.round(3), 60 | fm['curve'].max().round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), 61 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 62 | ] 63 | elif config.task == 'HRSOD': 64 | scores = [ 65 | sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mae.round(3), 66 | em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), 67 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 68 | ] 69 | elif config.task == 'General': 70 | scores = [ 71 | fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()), 72 | em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3), 73 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 74 | ] 75 | elif config.task == 'General-2K': 76 | scores = [ 77 | fm['curve'].max().round(3), wfm.round(3), mae.round(3), sm.round(3), em['curve'].mean().round(3), int(hce.round()), 78 | em['curve'].max().round(3), fm['curve'].mean().round(3), em['adp'].round(3), fm['adp'].round(3), 79 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 80 | ] 81 | elif config.task == 'Matting': 82 | scores = [ 83 | sm.round(3), fm['curve'].max().round(3), em['curve'].mean().round(3), mse.round(5), 84 | em['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), em['adp'].round(3), fm['adp'].round(3), int(hce.round()), 85 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 86 | ] 87 | else: 88 | scores = [ 89 | sm.round(3), mae.round(3), em['curve'].max().round(3), em['curve'].mean().round(3), 90 | fm['curve'].max().round(3), fm['curve'].mean().round(3), wfm.round(3), 91 | em['adp'].round(3), fm['adp'].round(3), int(hce.round()), 92 | mba.round(3), biou['curve'].max().round(3), biou['curve'].mean().round(3), 93 | ] 94 | 95 | for idx_score, score in enumerate(scores): 96 | scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4') 97 | records = [_data_name, _model_name] + scores 98 | tb.add_row(records) 99 | # Write results after every check. 
100 | with open(filename, 'w+') as file_to_write: 101 | file_to_write.write(str(tb)+'\n') 102 | print(tb) 103 | 104 | 105 | if __name__ == '__main__': 106 | # set parameters 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument( 109 | '--gt_root', type=str, help='ground-truth root', 110 | default=os.path.join(config.data_root_dir, config.task)) 111 | parser.add_argument( 112 | '--pred_root', type=str, help='prediction root', 113 | default='./e_preds') 114 | parser.add_argument( 115 | '--data_lst', type=str, help='test dataset', 116 | default=config.testsets.replace(',', '+')) 117 | parser.add_argument( 118 | '--save_dir', type=str, help='directory to save evaluation results', 119 | default='e_results') 120 | parser.add_argument( 121 | '--check_integrity', type=bool, help='whether to check the file integrity', 122 | default=False) 123 | parser.add_argument( 124 | '--metrics', type=str, help='evaluation metrics to compute', 125 | default='+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE'][:100 if 'DIS5K' in config.task else -1])) 126 | args = parser.parse_args() 127 | args.metrics = '+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE'][:100 if sum(['DIS-' in _data for _data in args.data_lst.split('+')]) else -1]) 128 | 129 | os.makedirs(args.save_dir, exist_ok=True) 130 | try: 131 | args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1].split('-')[0]), reverse=True) if int(m.split('epoch_')[-1].split('-')[0]) % 1 == 0] 132 | except Exception as e: 133 | print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}") 134 | args.model_lst = [m for m in sorted(os.listdir(args.pred_root))] 135 | 136 | # check the integrity of each candidate 137 | if args.check_integrity: 138 | for _data_name in args.data_lst.split('+'): 139 | for _model_name in args.model_lst: 140 | gt_pth = os.path.join(args.gt_root, _data_name) 141 | pred_pth = os.path.join(args.pred_root, _model_name, _data_name) 142 | if not sorted(os.listdir(gt_pth)) == sorted(os.listdir(pred_pth)): 143 | print(len(sorted(os.listdir(gt_pth))), len(sorted(os.listdir(pred_pth)))) 144 | print('The {} dataset of model {} does not match the ground truth'.format(_data_name, _model_name)) 145 | else: 146 | print('>>> skip checking the integrity of each candidate') 147 | 148 | # start engine 149 | do_eval(args) 150 | -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import cv2 4 | from PIL import Image 5 | import numpy as np 6 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 7 | from skimage.morphology import skeletonize 8 | from skimage.morphology import disk 9 | from skimage.measure import label 10 | 11 | 12 | _EPS = np.spacing(1) 13 | _TYPE = np.float64 14 | 15 | 16 | def evaluator(gt_paths, pred_paths, metrics=['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE'], verbose=False): 17 | # define measures 18 | if 'E' in metrics: 19 | EM = EMeasure() 20 | if 'S' in metrics: 21 | SM = SMeasure() 22 | if 'F' in metrics: 23 | FM = FMeasure() 24 | if 'MAE' in metrics: 25 | MAE = MAEMeasure() 26 | if 'MSE' in metrics: 27 | MSE = MSEMeasure() 28 | if 'WF' in metrics: 29 | WFM = WeightedFMeasure() 30 | if 'HCE' in metrics: 31 | HCE = HCEMeasure() 32 | if 'MBA' in metrics: 33 | MBA = MBAMeasure() 34 | if 'BIoU' in metrics: 35 | BIoU = BIoUMeasure() 36 |
37 | if isinstance(gt_paths, list) and isinstance(pred_paths, list): 38 | # print(len(gt_paths), len(pred_paths)) 39 | assert len(gt_paths) == len(pred_paths) 40 | 41 | for idx_sample in tqdm(range(len(gt_paths)), total=len(gt_paths)) if verbose else range(len(gt_paths)): 42 | gt = gt_paths[idx_sample] 43 | pred = pred_paths[idx_sample] 44 | 45 | pred = pred[:-4] + '.png' 46 | valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG'] 47 | file_exists = False 48 | for ext in valid_extensions: 49 | if os.path.exists(pred[:-4] + ext): 50 | pred = pred[:-4] + ext 51 | file_exists = True 52 | break 53 | if file_exists: 54 | pred_ary = cv2.imread(pred, cv2.IMREAD_GRAYSCALE) 55 | else: 56 | print('Not exists:', pred) 57 | 58 | gt_ary = cv2.imread(gt, cv2.IMREAD_GRAYSCALE) 59 | pred_ary = cv2.resize(pred_ary, (gt_ary.shape[1], gt_ary.shape[0])) 60 | 61 | if 'E' in metrics: 62 | EM.step(pred=pred_ary, gt=gt_ary) 63 | if 'S' in metrics: 64 | SM.step(pred=pred_ary, gt=gt_ary) 65 | if 'F' in metrics: 66 | FM.step(pred=pred_ary, gt=gt_ary) 67 | if 'MAE' in metrics: 68 | MAE.step(pred=pred_ary, gt=gt_ary) 69 | if 'MSE' in metrics: 70 | MSE.step(pred=pred_ary, gt=gt_ary) 71 | if 'WF' in metrics: 72 | WFM.step(pred=pred_ary, gt=gt_ary) 73 | if 'HCE' in metrics: 74 | ske_path = gt.replace('/gt/', '/ske/') 75 | if os.path.exists(ske_path): 76 | ske_ary = cv2.imread(ske_path, cv2.IMREAD_GRAYSCALE) 77 | ske_ary = ske_ary > 128 78 | else: 79 | ske_ary = skeletonize(gt_ary > 128) 80 | ske_save_dir = os.path.join(*ske_path.split(os.sep)[:-1]) 81 | if ske_path[0] == os.sep: 82 | ske_save_dir = os.sep + ske_save_dir 83 | os.makedirs(ske_save_dir, exist_ok=True) 84 | cv2.imwrite(ske_path, ske_ary.astype(np.uint8) * 255) 85 | HCE.step(pred=pred_ary, gt=gt_ary, gt_ske=ske_ary) 86 | if 'MBA' in metrics: 87 | MBA.step(pred=pred_ary, gt=gt_ary) 88 | if 'BIoU' in metrics: 89 | BIoU.step(pred=pred_ary, gt=gt_ary) 90 | 91 | if 'E' in metrics: 92 | em = EM.get_results()['em'] 93 | else: 94 | em = {'curve': np.array([np.float64(-1)]), 'adp': np.float64(-1)} 95 | if 'S' in metrics: 96 | sm = SM.get_results()['sm'] 97 | else: 98 | sm = np.float64(-1) 99 | if 'F' in metrics: 100 | fm = FM.get_results()['fm'] 101 | else: 102 | fm = {'curve': np.array([np.float64(-1)]), 'adp': np.float64(-1)} 103 | if 'MAE' in metrics: 104 | mae = MAE.get_results()['mae'] 105 | else: 106 | mae = np.float64(-1) 107 | if 'MSE' in metrics: 108 | mse = MSE.get_results()['mse'] 109 | else: 110 | mse = np.float64(-1) 111 | if 'WF' in metrics: 112 | wfm = WFM.get_results()['wfm'] 113 | else: 114 | wfm = np.float64(-1) 115 | if 'HCE' in metrics: 116 | hce = HCE.get_results()['hce'] 117 | else: 118 | hce = np.float64(-1) 119 | if 'MBA' in metrics: 120 | mba = MBA.get_results()['mba'] 121 | else: 122 | mba = np.float64(-1) 123 | if 'BIoU' in metrics: 124 | biou = BIoU.get_results()['biou'] 125 | else: 126 | biou = {'curve': np.array([np.float64(-1)])} 127 | 128 | return em, sm, fm, mae, mse, wfm, hce, mba, biou 129 | 130 | 131 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: 132 | gt = gt > 128 133 | pred = pred / 255 134 | if pred.max() != pred.min(): 135 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 136 | return pred, gt 137 | 138 | 139 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 140 | return min(2 * matrix.mean(), max_value) 141 | 142 | 143 | class FMeasure(object): 144 | def __init__(self, beta: float = 0.3): 145 | self.beta = beta 146 | self.precisions = [] 147 | self.recalls = [] 148 
| self.adaptive_fms = [] 149 | self.changeable_fms = [] 150 | 151 | def step(self, pred: np.ndarray, gt: np.ndarray): 152 | pred, gt = _prepare_data(pred, gt) 153 | 154 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 155 | self.adaptive_fms.append(adaptive_fm) 156 | 157 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 158 | self.precisions.append(precisions) 159 | self.recalls.append(recalls) 160 | self.changeable_fms.append(changeable_fms) 161 | 162 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 163 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 164 | binary_predcition = pred >= adaptive_threshold 165 | area_intersection = binary_predcition[gt].sum() 166 | if area_intersection == 0: 167 | adaptive_fm = 0 168 | else: 169 | pre = area_intersection / np.count_nonzero(binary_predcition) 170 | rec = area_intersection / np.count_nonzero(gt) 171 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 172 | return adaptive_fm 173 | 174 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 175 | pred = (pred * 255).astype(np.uint8) 176 | bins = np.linspace(0, 256, 257) 177 | fg_hist, _ = np.histogram(pred[gt], bins=bins) 178 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) 179 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 180 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 181 | TPs = fg_w_thrs 182 | Ps = fg_w_thrs + bg_w_thrs 183 | Ps[Ps == 0] = 1 184 | T = max(np.count_nonzero(gt), 1) 185 | precisions = TPs / Ps 186 | recalls = TPs / T 187 | numerator = (1 + self.beta) * precisions * recalls 188 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) 189 | changeable_fms = numerator / denominator 190 | return precisions, recalls, changeable_fms 191 | 192 | def get_results(self) -> dict: 193 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) 194 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) 195 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 196 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 197 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), 198 | pr=dict(p=precision, r=recall)) 199 | 200 | 201 | class MAEMeasure(object): 202 | def __init__(self): 203 | self.maes = [] 204 | 205 | def step(self, pred: np.ndarray, gt: np.ndarray): 206 | pred, gt = _prepare_data(pred, gt) 207 | 208 | mae = self.cal_mae(pred, gt) 209 | self.maes.append(mae) 210 | 211 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float: 212 | mae = np.mean(np.abs(pred - gt)) 213 | return mae 214 | 215 | def get_results(self) -> dict: 216 | mae = np.mean(np.array(self.maes, _TYPE)) 217 | return dict(mae=mae) 218 | 219 | 220 | class MSEMeasure(object): 221 | def __init__(self): 222 | self.mses = [] 223 | 224 | def step(self, pred: np.ndarray, gt: np.ndarray): 225 | pred, gt = _prepare_data(pred, gt) 226 | 227 | mse = self.cal_mse(pred, gt) 228 | self.mses.append(mse) 229 | 230 | def cal_mse(self, pred: np.ndarray, gt: np.ndarray) -> float: 231 | mse = np.mean((pred - gt) ** 2) 232 | return mse 233 | 234 | def get_results(self) -> dict: 235 | mse = np.mean(np.array(self.mses, _TYPE)) 236 | return dict(mse=mse) 237 | 238 | 239 | class SMeasure(object): 240 | def __init__(self, alpha: float = 0.5): 241 | self.sms = [] 242 | self.alpha = alpha 243 | 244 | def step(self, pred: np.ndarray, gt: np.ndarray): 245 | pred, gt = _prepare_data(pred=pred, gt=gt) 246 | 247 | sm = self.cal_sm(pred, gt) 248 | 
self.sms.append(sm) 249 | 250 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: 251 | y = np.mean(gt) 252 | if y == 0: 253 | sm = 1 - np.mean(pred) 254 | elif y == 1: 255 | sm = np.mean(pred) 256 | else: 257 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 258 | sm = max(0, sm) 259 | return sm 260 | 261 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float: 262 | fg = pred * gt 263 | bg = (1 - pred) * (1 - gt) 264 | u = np.mean(gt) 265 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) 266 | return object_score 267 | 268 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: 269 | x = np.mean(pred[gt == 1]) 270 | sigma_x = np.std(pred[gt == 1], ddof=1) 271 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) 272 | return score 273 | 274 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float: 275 | x, y = self.centroid(gt) 276 | part_info = self.divide_with_xy(pred, gt, x, y) 277 | w1, w2, w3, w4 = part_info['weight'] 278 | pred1, pred2, pred3, pred4 = part_info['pred'] 279 | gt1, gt2, gt3, gt4 = part_info['gt'] 280 | score1 = self.ssim(pred1, gt1) 281 | score2 = self.ssim(pred2, gt2) 282 | score3 = self.ssim(pred3, gt3) 283 | score4 = self.ssim(pred4, gt4) 284 | 285 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 286 | 287 | def centroid(self, matrix: np.ndarray) -> tuple: 288 | h, w = matrix.shape 289 | area_object = np.count_nonzero(matrix) 290 | if area_object == 0: 291 | x = np.round(w / 2) 292 | y = np.round(h / 2) 293 | else: 294 | # More details can be found at: https://www.yuque.com/lart/blog/gpbigm 295 | y, x = np.argwhere(matrix).mean(axis=0).round() 296 | return int(x) + 1, int(y) + 1 297 | 298 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: 299 | h, w = gt.shape 300 | area = h * w 301 | 302 | gt_LT = gt[0:y, 0:x] 303 | gt_RT = gt[0:y, x:w] 304 | gt_LB = gt[y:h, 0:x] 305 | gt_RB = gt[y:h, x:w] 306 | 307 | pred_LT = pred[0:y, 0:x] 308 | pred_RT = pred[0:y, x:w] 309 | pred_LB = pred[y:h, 0:x] 310 | pred_RB = pred[y:h, x:w] 311 | 312 | w1 = x * y / area 313 | w2 = y * (w - x) / area 314 | w3 = (h - y) * x / area 315 | w4 = 1 - w1 - w2 - w3 316 | 317 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), 318 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 319 | weight=(w1, w2, w3, w4)) 320 | 321 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 322 | h, w = pred.shape 323 | N = h * w 324 | 325 | x = np.mean(pred) 326 | y = np.mean(gt) 327 | 328 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 329 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 330 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 331 | 332 | alpha = 4 * x * y * sigma_xy 333 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 334 | 335 | if alpha != 0: 336 | score = alpha / (beta + _EPS) 337 | elif alpha == 0 and beta == 0: 338 | score = 1 339 | else: 340 | score = 0 341 | return score 342 | 343 | def get_results(self) -> dict: 344 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 345 | return dict(sm=sm) 346 | 347 | 348 | class EMeasure(object): 349 | def __init__(self): 350 | self.adaptive_ems = [] 351 | self.changeable_ems = [] 352 | 353 | def step(self, pred: np.ndarray, gt: np.ndarray): 354 | pred, gt = _prepare_data(pred=pred, gt=gt) 355 | self.gt_fg_numel = np.count_nonzero(gt) 356 | self.gt_size = gt.shape[0] * gt.shape[1] 357 | 358 | changeable_ems = self.cal_changeable_em(pred, gt) 359 | self.changeable_ems.append(changeable_ems) 360 | adaptive_em = 
self.cal_adaptive_em(pred, gt) 361 | self.adaptive_ems.append(adaptive_em) 362 | 363 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 364 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 365 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 366 | return adaptive_em 367 | 368 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 369 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 370 | return changeable_ems 371 | 372 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 373 | binarized_pred = pred >= threshold 374 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 375 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 376 | 377 | fg___numel = fg_fg_numel + fg_bg_numel 378 | bg___numel = self.gt_size - fg___numel 379 | 380 | if self.gt_fg_numel == 0: 381 | enhanced_matrix_sum = bg___numel 382 | elif self.gt_fg_numel == self.gt_size: 383 | enhanced_matrix_sum = fg___numel 384 | else: 385 | parts_numel, combinations = self.generate_parts_numel_combinations( 386 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, 387 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, 388 | ) 389 | 390 | results_parts = [] 391 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 392 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 393 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 394 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 395 | results_parts.append(enhanced_matrix_value * part_numel) 396 | enhanced_matrix_sum = sum(results_parts) 397 | 398 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 399 | return em 400 | 401 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 402 | pred = (pred * 255).astype(np.uint8) 403 | bins = np.linspace(0, 256, 257) 404 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 405 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 406 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 407 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 408 | 409 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 410 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 411 | 412 | if self.gt_fg_numel == 0: 413 | enhanced_matrix_sum = bg___numel_w_thrs 414 | elif self.gt_fg_numel == self.gt_size: 415 | enhanced_matrix_sum = fg___numel_w_thrs 416 | else: 417 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 418 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, 419 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, 420 | ) 421 | 422 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 423 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 424 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 425 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 426 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 427 | results_parts[i] = enhanced_matrix_value * part_numel 428 | enhanced_matrix_sum = results_parts.sum(axis=0) 429 | 430 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 431 | return em 432 | 433 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): 434 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 435 | bg_bg_numel = pred_bg_numel - bg_fg_numel 436 | 437 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, 
bg_bg_numel] 438 | 439 | mean_pred_value = pred_fg_numel / self.gt_size 440 | mean_gt_value = self.gt_fg_numel / self.gt_size 441 | 442 | demeaned_pred_fg_value = 1 - mean_pred_value 443 | demeaned_pred_bg_value = 0 - mean_pred_value 444 | demeaned_gt_fg_value = 1 - mean_gt_value 445 | demeaned_gt_bg_value = 0 - mean_gt_value 446 | 447 | combinations = [ 448 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 449 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 450 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 451 | (demeaned_pred_bg_value, demeaned_gt_bg_value) 452 | ] 453 | return parts_numel, combinations 454 | 455 | def get_results(self) -> dict: 456 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) 457 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) 458 | return dict(em=dict(adp=adaptive_em, curve=changeable_em)) 459 | 460 | 461 | class WeightedFMeasure(object): 462 | def __init__(self, beta: float = 1): 463 | self.beta = beta 464 | self.weighted_fms = [] 465 | 466 | def step(self, pred: np.ndarray, gt: np.ndarray): 467 | pred, gt = _prepare_data(pred=pred, gt=gt) 468 | 469 | if np.all(~gt): 470 | wfm = 0 471 | else: 472 | wfm = self.cal_wfm(pred, gt) 473 | self.weighted_fms.append(wfm) 474 | 475 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: 476 | # [Dst,IDXT] = bwdist(dGT); 477 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 478 | 479 | # %Pixel dependency 480 | # E = abs(FG-dGT); 481 | E = np.abs(pred - gt) 482 | Et = np.copy(E) 483 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 484 | 485 | # K = fspecial('gaussian',7,5); 486 | # EA = imfilter(Et,K); 487 | K = self.matlab_style_gauss2D((7, 7), sigma=5) 488 | EA = convolve(Et, weights=K, mode="constant", cval=0) 489 | # MIN_E_EA = E; 490 | # MIN_E_EA(GT & EA np.ndarray: 510 | """ 511 | 2D gaussian mask - should give the same result as MATLAB's 512 | fspecial('gaussian',[shape],[sigma]) 513 | """ 514 | m, n = [(ss - 1) / 2 for ss in shape] 515 | y, x = np.ogrid[-m: m + 1, -n: n + 1] 516 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 517 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 518 | sumh = h.sum() 519 | if sumh != 0: 520 | h /= sumh 521 | return h 522 | 523 | def get_results(self) -> dict: 524 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) 525 | return dict(wfm=weighted_fm) 526 | 527 | 528 | class HCEMeasure(object): 529 | def __init__(self): 530 | self.hces = [] 531 | 532 | def step(self, pred: np.ndarray, gt: np.ndarray, gt_ske): 533 | # pred, gt = _prepare_data(pred, gt) 534 | 535 | hce = self.cal_hce(pred, gt, gt_ske) 536 | self.hces.append(hce) 537 | 538 | def get_results(self) -> dict: 539 | hce = np.mean(np.array(self.hces, _TYPE)) 540 | return dict(hce=hce) 541 | 542 | 543 | def cal_hce(self, pred: np.ndarray, gt: np.ndarray, gt_ske: np.ndarray, relax=5, epsilon=2.0) -> float: 544 | # Binarize gt 545 | if(len(gt.shape)>2): 546 | gt = gt[:, :, 0] 547 | 548 | epsilon_gt = 128#(np.amin(gt)+np.amax(gt))/2.0 549 | gt = (gt>epsilon_gt).astype(np.uint8) 550 | 551 | # Binarize pred 552 | if(len(pred.shape)>2): 553 | pred = pred[:, :, 0] 554 | epsilon_pred = 128#(np.amin(pred)+np.amax(pred))/2.0 555 | pred = (pred>epsilon_pred).astype(np.uint8) 556 | 557 | Union = np.logical_or(gt, pred) 558 | TP = np.logical_and(gt, pred) 559 | FP = pred - TP 560 | FN = gt - TP 561 | 562 | # relax the Union of gt and pred 563 | Union_erode = Union.copy() 564 | Union_erode = cv2.erode(Union_erode.astype(np.uint8), disk(1), iterations=relax) 565 | 
566 | # --- get the relaxed False Positive regions for computing the human efforts in correcting them --- 567 | FP_ = np.logical_and(FP, Union_erode) # get the relaxed FP 568 | for i in range(0, relax): 569 | FP_ = cv2.dilate(FP_.astype(np.uint8), disk(1)) 570 | FP_ = np.logical_and(FP_, 1-np.logical_or(TP, FN)) 571 | FP_ = np.logical_and(FP, FP_) 572 | 573 | # --- get the relaxed False Negative regions for computing the human efforts in correcting them --- 574 | FN_ = np.logical_and(FN, Union_erode) # preserve the structural components of FN 575 | ## recover the FN, where pixels are not close to the TP borders 576 | for i in range(0, relax): 577 | FN_ = cv2.dilate(FN_.astype(np.uint8), disk(1)) 578 | FN_ = np.logical_and(FN_, 1-np.logical_or(TP, FP)) 579 | FN_ = np.logical_and(FN, FN_) 580 | FN_ = np.logical_or(FN_, np.logical_xor(gt_ske, np.logical_and(TP, gt_ske))) # preserve the structural components of FN 581 | 582 | ## 2. =============Find exact polygon control points and independent regions============== 583 | ## find contours from FP_ 584 | ctrs_FP, hier_FP = cv2.findContours(FP_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 585 | ## find control points and independent regions for human correction 586 | bdies_FP, indep_cnt_FP = self.filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_)) 587 | ## find contours from FN_ 588 | ctrs_FN, hier_FN = cv2.findContours(FN_.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 589 | ## find control points and independent regions for human correction 590 | bdies_FN, indep_cnt_FN = self.filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP, FP_), FN_)) 591 | 592 | poly_FP, poly_FP_len, poly_FP_point_cnt = self.approximate_RDP(bdies_FP, epsilon=epsilon) 593 | poly_FN, poly_FN_len, poly_FN_point_cnt = self.approximate_RDP(bdies_FN, epsilon=epsilon) 594 | 595 | # FP_points+FP_indep+FN_points+FN_indep 596 | return poly_FP_point_cnt+indep_cnt_FP+poly_FN_point_cnt+indep_cnt_FN 597 | 598 | def filter_bdy_cond(self, bdy_, mask, cond): 599 | 600 | cond = cv2.dilate(cond.astype(np.uint8), disk(1)) 601 | labels = label(mask) # find the connected regions 602 | lbls = np.unique(labels) # the indices of the connected regions 603 | indep = np.ones(lbls.shape[0]) # the label of each connected regions 604 | indep[0] = 0 # 0 indicate the background region 605 | 606 | boundaries = [] 607 | h,w = cond.shape[0:2] 608 | ind_map = np.zeros((h, w)) 609 | indep_cnt = 0 610 | 611 | for i in range(0, len(bdy_)): 612 | tmp_bdies = [] 613 | tmp_bdy = [] 614 | for j in range(0, bdy_[i].shape[0]): 615 | r, c = bdy_[i][j,0,1],bdy_[i][j,0,0] 616 | 617 | if(np.sum(cond[r, c])==0 or ind_map[r, c]!=0): 618 | if(len(tmp_bdy)>0): 619 | tmp_bdies.append(tmp_bdy) 620 | tmp_bdy = [] 621 | continue 622 | tmp_bdy.append([c, r]) 623 | ind_map[r, c] = ind_map[r, c] + 1 624 | indep[labels[r, c]] = 0 # indicates part of the boundary of this region needs human correction 625 | if(len(tmp_bdy)>0): 626 | tmp_bdies.append(tmp_bdy) 627 | 628 | # check if the first and the last boundaries are connected 629 | # if yes, invert the first boundary and attach it after the last boundary 630 | if(len(tmp_bdies)>1): 631 | first_x, first_y = tmp_bdies[0][0] 632 | last_x, last_y = tmp_bdies[-1][-1] 633 | if((abs(first_x-last_x)==1 and first_y==last_y) or 634 | (first_x==last_x and abs(first_y-last_y)==1) or 635 | (abs(first_x-last_x)==1 and abs(first_y-last_y)==1) 636 | ): 637 | tmp_bdies[-1].extend(tmp_bdies[0][::-1]) 638 | del tmp_bdies[0] 639 | 640 | for k in range(0, len(tmp_bdies)): 
641 | tmp_bdies[k] = np.array(tmp_bdies[k])[:, np.newaxis, :] 642 | if(len(tmp_bdies)>0): 643 | boundaries.extend(tmp_bdies) 644 | 645 | return boundaries, np.sum(indep) 646 | 647 | # this function approximate each boundary by DP algorithm 648 | # https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm 649 | def approximate_RDP(self, boundaries, epsilon=1.0): 650 | 651 | boundaries_ = [] 652 | boundaries_len_ = [] 653 | pixel_cnt_ = 0 654 | 655 | # polygon approximate of each boundary 656 | for i in range(0, len(boundaries)): 657 | boundaries_.append(cv2.approxPolyDP(boundaries[i], epsilon, False)) 658 | 659 | # count the control points number of each boundary and the total control points number of all the boundaries 660 | for i in range(0, len(boundaries_)): 661 | boundaries_len_.append(len(boundaries_[i])) 662 | pixel_cnt_ = pixel_cnt_ + len(boundaries_[i]) 663 | 664 | return boundaries_, boundaries_len_, pixel_cnt_ 665 | 666 | 667 | class MBAMeasure(object): 668 | def __init__(self): 669 | self.bas = [] 670 | self.all_h = 0 671 | self.all_w = 0 672 | self.all_max = 0 673 | 674 | def step(self, pred: np.ndarray, gt: np.ndarray): 675 | # pred, gt = _prepare_data(pred, gt) 676 | 677 | refined = gt.copy() 678 | 679 | rmin = cmin = 0 680 | rmax, cmax = gt.shape 681 | 682 | self.all_h += rmax 683 | self.all_w += cmax 684 | self.all_max += max(rmax, cmax) 685 | 686 | refined_h, refined_w = refined.shape 687 | if refined_h != cmax: 688 | refined = np.array(Image.fromarray(pred).resize((cmax, rmax), Image.BILINEAR)) 689 | 690 | if not(gt.sum() < 32*32): 691 | if not((cmax==cmin) or (rmax==rmin)): 692 | class_refined_prob = np.array(Image.fromarray(pred).resize((cmax-cmin, rmax-rmin), Image.BILINEAR)) 693 | refined[rmin:rmax, cmin:cmax] = class_refined_prob 694 | 695 | pred = pred > 128 696 | gt = gt > 128 697 | 698 | ba = self.cal_ba(pred, gt) 699 | self.bas.append(ba) 700 | 701 | def get_disk_kernel(self, radius): 702 | return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius*2+1, radius*2+1)) 703 | 704 | def cal_ba(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 705 | """ 706 | Calculate the mean absolute error. 
707 | 708 | :return: ba 709 | """ 710 | 711 | gt = gt.astype(np.uint8) 712 | pred = pred.astype(np.uint8) 713 | 714 | h, w = gt.shape 715 | 716 | min_radius = 1 717 | max_radius = (w+h)/300 718 | num_steps = 5 719 | 720 | pred_acc = [None] * num_steps 721 | 722 | for i in range(num_steps): 723 | curr_radius = min_radius + int((max_radius-min_radius)/num_steps*i) 724 | 725 | kernel = self.get_disk_kernel(curr_radius) 726 | boundary_region = cv2.morphologyEx(gt, cv2.MORPH_GRADIENT, kernel) > 0 727 | 728 | gt_in_bound = gt[boundary_region] 729 | pred_in_bound = pred[boundary_region] 730 | 731 | num_edge_pixels = (boundary_region).sum() 732 | num_pred_gd_pix = ((gt_in_bound) * (pred_in_bound) + (1-gt_in_bound) * (1-pred_in_bound)).sum() 733 | 734 | pred_acc[i] = num_pred_gd_pix / num_edge_pixels 735 | 736 | ba = sum(pred_acc)/num_steps 737 | return ba 738 | 739 | def get_results(self) -> dict: 740 | mba = np.mean(np.array(self.bas, _TYPE)) 741 | return dict(mba=mba) 742 | 743 | 744 | class BIoUMeasure(object): 745 | def __init__(self, dilation_ratio=0.02): 746 | self.bious = [] 747 | self.dilation_ratio = dilation_ratio 748 | 749 | def mask_to_boundary(self, mask): 750 | h, w = mask.shape 751 | img_diag = np.sqrt(h ** 2 + w ** 2) 752 | dilation = int(round(self.dilation_ratio * img_diag)) 753 | if dilation < 1: 754 | dilation = 1 755 | # Pad image so mask truncated by the image border is also considered as boundary. 756 | new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) 757 | kernel = np.ones((3, 3), dtype=np.uint8) 758 | new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) 759 | mask_erode = new_mask_erode[1 : h + 1, 1 : w + 1] 760 | # G_d intersects G in the paper. 761 | return mask - mask_erode 762 | 763 | def step(self, pred: np.ndarray, gt: np.ndarray): 764 | pred, gt = _prepare_data(pred, gt) 765 | 766 | bious = self.cal_biou(pred=pred, gt=gt) 767 | self.bious.append(bious) 768 | 769 | def cal_biou(self, pred, gt): 770 | pred = (pred * 255).astype(np.uint8) 771 | pred = self.mask_to_boundary(pred) 772 | gt = (gt * 255).astype(np.uint8) 773 | gt = self.mask_to_boundary(gt) 774 | gt = gt > 128 775 | 776 | bins = np.linspace(0, 256, 257) 777 | fg_hist, _ = np.histogram(pred[gt], bins=bins) # ture positive 778 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) # false positive 779 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 780 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 781 | TPs = fg_w_thrs 782 | Ps = fg_w_thrs + bg_w_thrs # positives 783 | Ps[Ps == 0] = 1 784 | T = max(np.count_nonzero(gt), 1) 785 | 786 | ious = TPs / (T + bg_w_thrs) 787 | return ious 788 | 789 | def get_results(self) -> dict: 790 | biou = np.mean(np.array(self.bious, dtype=_TYPE), axis=0) 791 | return dict(biou=dict(curve=biou)) 792 | -------------------------------------------------------------------------------- /gen_best_ep.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | from config import Config 5 | 6 | 7 | config = Config() 8 | 9 | eval_txts = sorted(glob('e_results/*_eval.txt')) 10 | print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts]) 11 | score_panel = {} 12 | sep = '&' 13 | metrics = ['sm', 'wfm', 'hce'] # we used HCE for DIS and wFm for others. 
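# Illustrative sketch of the assumed input format (values below are made up, not from any real run):
# each line of an e_results/*_eval.txt file is expected to be '&'-separated, roughly like
#   DIS-VD & BiRefNet--epoch_244 & 0.905 & 0.894 & ... & 1045 &
# i.e. dataset name, checkpoint name, then one column per metric; `targe_idx` further down
# picks the column that corresponds to the chosen metric for the current config.task.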
14 | if 'DIS5K' not in config.task: 15 | metrics.remove('hce') 16 | 17 | for metric in metrics: 18 | print('Metric:', metric) 19 | current_line_nums = [] 20 | for idx_et, eval_txt in enumerate(eval_txts): 21 | with open(eval_txt, 'r') as f: 22 | lines = [l for l in f.readlines()[3:] if '.' in l] 23 | current_line_nums.append(len(lines)) 24 | for idx_et, eval_txt in enumerate(eval_txts): 25 | with open(eval_txt, 'r') as f: 26 | lines = [l for l in f.readlines()[3:] if '.' in l] 27 | for idx_line, line in enumerate(lines[:min(current_line_nums)]): # Consist line numbers by the minimal result file. 28 | properties = line.strip().strip(sep).split(sep) 29 | dataset = properties[0].strip() 30 | ckpt = properties[1].strip() 31 | if int(ckpt.split('--epoch_')[-1].strip()) < 0: 32 | continue 33 | targe_idx = { 34 | 'sm': [5, 2, 2, 5, 5, 2], 35 | 'wfm': [3, 3, 8, 3, 3, 8], 36 | 'hce': [7, -1, -1, 7, 7, -1] 37 | }[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'].index(config.task)] 38 | if metric != 'hce': 39 | score_sm = float(properties[targe_idx].strip()) 40 | else: 41 | score_sm = int(properties[targe_idx].strip().strip('.')) 42 | if idx_et == 0: 43 | score_panel[ckpt] = [] 44 | score_panel[ckpt].append(score_sm) 45 | 46 | metrics_min = ['hce', 'mae'] 47 | max_or_min = min if metric in metrics_min else max 48 | score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x)) 49 | 50 | good_models = [] 51 | for k, v in score_panel.items(): 52 | if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)): 53 | print(k, v) 54 | good_models.append(k) 55 | 56 | # Write 57 | with open(eval_txt, 'r') as f: 58 | lines = f.readlines() 59 | info4good_models = lines[:3] 60 | metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]] 61 | testset_mean_values = {metric_name: [] for metric_name in metric_names} 62 | for good_model in good_models: 63 | for idx_et, eval_txt in enumerate(eval_txts): 64 | with open(eval_txt, 'r') as f: 65 | lines = f.readlines() 66 | for line in lines: 67 | if set([good_model]) & set([_.strip() for _ in line.split(sep)]): 68 | info4good_models.append(line) 69 | metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]] 70 | for idx_score, metric_score in enumerate(metric_scores): 71 | testset_mean_values[metric_names[idx_score]].append(metric_score) 72 | 73 | if 'DIS5K' in config.task: 74 | testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()] # [:-1] to remove DIS-VD 75 | sample_line_for_placing_mean_values = info4good_models[-2] 76 | numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:] 77 | for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)): 78 | numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value) 79 | testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n' 80 | info4good_models.append(testset_mean_line) 81 | info4good_models.append(lines[-1]) 82 | info = ''.join(info4good_models) 83 | print(info) 84 | with open(os.path.join('e_results', 
'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f: 85 | f.write(info + '\n') 86 | -------------------------------------------------------------------------------- /image_proc.py: -------------------------------------------------------------------------------- 1 | import random 2 | from PIL import Image, ImageEnhance 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def refine_foreground(image, mask): 8 | if mask.size != image.size: 9 | mask = mask.resize(image.size) 10 | image = np.array(image, dtype=np.float32) / 255.0 11 | mask = np.array(mask, dtype=np.float32) / 255.0 12 | estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=int(sum(image.shape[:2]) / 2 * 0.1)) 13 | image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) 14 | return image_masked 15 | 16 | 17 | def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): 18 | # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation 19 | alpha = alpha[:, :, None] 20 | F, blur_B = FB_blur_fusion_foreground_estimator( 21 | image, image, image, alpha, r) 22 | return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] 23 | 24 | 25 | def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): 26 | if isinstance(image, Image.Image): 27 | image = np.array(image) / 255.0 28 | blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] 29 | 30 | blurred_FA = cv2.blur(F * alpha, (r, r)) 31 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 32 | 33 | blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) 34 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 35 | F = blurred_F + alpha * \ 36 | (image - alpha * blurred_F - (1 - alpha) * blurred_B) 37 | F = np.clip(F, 0, 1) 38 | return F, blurred_B 39 | 40 | 41 | def preproc(image, label, preproc_methods=['flip']): 42 | if 'flip' in preproc_methods: 43 | image, label = cv_random_flip(image, label) 44 | if 'crop' in preproc_methods: 45 | image, label = random_crop(image, label) 46 | if 'rotate' in preproc_methods: 47 | image, label = random_rotate(image, label) 48 | if 'enhance' in preproc_methods: 49 | image = color_enhance(image) 50 | if 'pepper' in preproc_methods: 51 | image = random_pepper(image) 52 | return image, label 53 | 54 | 55 | def cv_random_flip(img, label): 56 | if random.random() > 0.5: 57 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 58 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 59 | return img, label 60 | 61 | 62 | def random_crop(image, label): 63 | border = 30 64 | image_width = image.size[0] 65 | image_height = image.size[1] 66 | border = int(min(image_width, image_height) * 0.1) 67 | crop_win_width = np.random.randint(image_width - border, image_width) 68 | crop_win_height = np.random.randint(image_height - border, image_height) 69 | random_region = ( 70 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 71 | (image_height + crop_win_height) >> 1) 72 | return image.crop(random_region), label.crop(random_region) 73 | 74 | 75 | def random_rotate(image, label, angle=15): 76 | mode = Image.BICUBIC 77 | if random.random() > 0.8: 78 | random_angle = np.random.randint(-angle, angle) 79 | image = image.rotate(random_angle, mode) 80 | label = label.rotate(random_angle, mode) 81 | return image, label 82 | 83 | 84 | def color_enhance(image): 85 | bright_intensity = random.randint(5, 15) / 10.0 86 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 87 | contrast_intensity = random.randint(5, 15) / 10.0 88 | 
image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 89 | color_intensity = random.randint(0, 20) / 10.0 90 | image = ImageEnhance.Color(image).enhance(color_intensity) 91 | sharp_intensity = random.randint(0, 30) / 10.0 92 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 93 | return image 94 | 95 | 96 | def random_gaussian(image, mean=0.1, sigma=0.35): 97 | def gaussianNoisy(im, mean=mean, sigma=sigma): 98 | for _i in range(len(im)): 99 | im[_i] += random.gauss(mean, sigma) 100 | return im 101 | 102 | img = np.asarray(image) 103 | width, height = img.shape 104 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 105 | img = img.reshape([width, height]) 106 | return Image.fromarray(np.uint8(img)) 107 | 108 | 109 | def random_pepper(img, N=0.0015): 110 | img = np.array(img) 111 | noiseNum = int(N * img.shape[0] * img.shape[1]) 112 | for i in range(noiseNum): 113 | randX = random.randint(0, img.shape[0] - 1) 114 | randY = random.randint(0, img.shape[1] - 1) 115 | img[randX, randY] = random.randint(0, 1) * 255 116 | return Image.fromarray(img) 117 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | from tqdm import tqdm 5 | import cv2 6 | import torch 7 | from contextlib import nullcontext 8 | 9 | from dataset import MyData 10 | from models.birefnet import BiRefNet, BiRefNetC2F 11 | from utils import save_tensor_img, check_state_dict 12 | from config import Config 13 | 14 | 15 | config = Config() 16 | 17 | mixed_precision = config.mixed_precision 18 | if mixed_precision == 'fp16': 19 | mixed_dtype = torch.float16 20 | elif mixed_precision == 'bf16': 21 | mixed_dtype = torch.bfloat16 22 | else: 23 | mixed_dtype = None 24 | 25 | autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=mixed_dtype) if mixed_dtype else nullcontext() 26 | 27 | 28 | def inference(model, data_loader_test, pred_root, method, testset, device=0): 29 | model_training = model.training 30 | if model_training: 31 | model.eval() 32 | for batch in tqdm(data_loader_test, total=len(data_loader_test)) if config.verbose_eval else data_loader_test: 33 | inputs = batch[0].to(device) 34 | label_paths = batch[-1] 35 | with autocast_ctx, torch.no_grad(): 36 | scaled_preds = model(inputs)[-1].sigmoid().to(torch.float32) 37 | 38 | os.makedirs(os.path.join(pred_root, method, testset), exist_ok=True) 39 | 40 | for idx_sample in range(scaled_preds.shape[0]): 41 | res = torch.nn.functional.interpolate( 42 | scaled_preds[idx_sample].unsqueeze(0), 43 | size=cv2.imread(label_paths[idx_sample], cv2.IMREAD_GRAYSCALE).shape[:2], 44 | mode='bilinear', 45 | align_corners=True 46 | ) 47 | save_tensor_img(res, os.path.join(os.path.join(pred_root, method, testset), label_paths[idx_sample].replace('\\', '/').split('/')[-1])) # test set dir + file name 48 | if model_training: 49 | model.train() 50 | return None 51 | 52 | 53 | def main(args): 54 | device = config.device 55 | if args.ckpt_folder: 56 | print('Testing with models in {}'.format(args.ckpt_folder)) 57 | else: 58 | print('Testing with model {}'.format(args.ckpt)) 59 | 60 | if config.model == 'BiRefNet': 61 | model = BiRefNet(bb_pretrained=False) 62 | elif config.model == 'BiRefNetC2F': 63 | model = BiRefNetC2F(bb_pretrained=False) 64 | weights_lst = sorted( 65 | glob(os.path.join(args.ckpt_folder, '*.pth')) if args.ckpt_folder else [args.ckpt], 66 | key=lambda x: 
int(x.split('epoch_')[-1].split('.pth')[0]), 67 | reverse=True 68 | ) 69 | try: 70 | if args.resolution in [None, 'None', 0, '']: 71 | # Use original resolution for inference. 72 | data_size = None 73 | elif args.resolution in ['config.size']: 74 | data_size = config.size 75 | else: 76 | data_size = [int(l) for l in args.resolution.split('x')] 77 | except Exception as e: 78 | print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}") 79 | # Fall back to config.size. 80 | data_size = config.size 81 | 82 | for testset in args.testsets.split('+'): 83 | print('>>>> Testset: {}...'.format(testset)) 84 | data_loader_test = torch.utils.data.DataLoader( 85 | dataset=MyData(testset, data_size=data_size, is_train=False), 86 | batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True 87 | ) 88 | for weights in weights_lst: 89 | if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0: 90 | continue 91 | print('\tInferencing {}...'.format(weights)) 92 | state_dict = torch.load(weights, map_location='cpu', weights_only=True) 93 | state_dict = check_state_dict(state_dict) 94 | model.load_state_dict(state_dict) 95 | model = model.to(device) 96 | inference( 97 | model, data_loader_test=data_loader_test, pred_root=args.pred_root, 98 | method='--'.join([w.rstrip('.pth') for w in weights.split(os.sep)[-2:]]) + '-reso_{}'.format('x'.join([str(s) for s in data_size])), 99 | testset=testset, device=config.device 100 | ) 101 | 102 | 103 | if __name__ == '__main__': 104 | # Parameters from the command line 105 | parser = argparse.ArgumentParser(description='') 106 | parser.add_argument('--ckpt', type=str, help='path to a single checkpoint file') 107 | parser.add_argument('--ckpt_folder', default=sorted(glob(os.path.join('ckpt', '*')))[-1], type=str, help='folder of model checkpoints') 108 | parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder') 109 | parser.add_argument('--resolution', default='default', type=str, help='Resolution as WIDTHxHEIGHT, e.g. 1024x1024') 110 | parser.add_argument('--testsets', 111 | default=config.testsets.replace(',', '+'), 112 | type=str, 113 | help="Test all sets: DIS5K -> 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'") 114 | 115 | args = parser.parse_args() 116 | 117 | if config.precisionHigh: 118 | torch.set_float32_matmul_precision('high') 119 | main(args) 120 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from math import exp 6 | from config import Config 7 | 8 | 9 | class ContourLoss(torch.nn.Module): 10 | def __init__(self): 11 | super(ContourLoss, self).__init__() 12 | 13 | def forward(self, pred, target, weight=10): 14 | ''' 15 | target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1, 16 | target[:,:,region_out_contour] == 0. 17 | weight: scalar, length term weight. 18 | ''' 19 | # length term 20 | delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # gradient along rows (B, C, H-1, W) 21 | delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # gradient along columns (B, C, H, W-1) 22 | 23 | delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2) 24 | delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2) 25 | delta_pred = torch.abs(delta_r + delta_c) 26 | 27 | epsilon = 1e-8 # small constant to avoid taking the square root of zero in practice. 
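# In formula form, the length term computed on the next line is (roughly)
#   length = mean_{i,j} sqrt( (d pred / d row)_{ij}^2 + (d pred / d col)_{ij}^2 + epsilon ),
# a smoothed total-variation penalty that discourages long or ragged contours.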
28 | length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum. 29 | 30 | c_in = torch.ones_like(pred) 31 | c_out = torch.zeros_like(pred) 32 | 33 | region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum. 34 | region_out = torch.mean( (1-pred) * (target - c_out)**2 ) 35 | region = region_in + region_out 36 | 37 | loss = weight * length + region 38 | 39 | return loss 40 | 41 | 42 | class IoULoss(torch.nn.Module): 43 | def __init__(self): 44 | super(IoULoss, self).__init__() 45 | 46 | def forward(self, pred, target): 47 | b = pred.shape[0] 48 | IoU = 0.0 49 | for i in range(0, b): 50 | # compute the IoU of the foreground 51 | Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :]) 52 | Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1 53 | IoU1 = Iand1 / Ior1 54 | # IoU loss is (1-IoU1) 55 | IoU = IoU + (1-IoU1) 56 | # return IoU/b 57 | return IoU 58 | 59 | 60 | class StructureLoss(torch.nn.Module): 61 | def __init__(self): 62 | super(StructureLoss, self).__init__() 63 | 64 | def forward(self, pred, target): 65 | weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target) 66 | wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none') 67 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 68 | 69 | pred = torch.sigmoid(pred) 70 | inter = ((pred * target) * weit).sum(dim=(2, 3)) 71 | union = ((pred + target) * weit).sum(dim=(2, 3)) 72 | wiou = 1-(inter+1)/(union-inter+1) 73 | 74 | return (wbce+wiou).mean() 75 | 76 | 77 | class PatchIoULoss(torch.nn.Module): 78 | def __init__(self): 79 | super(PatchIoULoss, self).__init__() 80 | self.iou_loss = IoULoss() 81 | 82 | def forward(self, pred, target): 83 | win_y, win_x = 64, 64 84 | iou_loss = 0. 85 | for anchor_y in range(0, target.shape[0], win_y): 86 | for anchor_x in range(0, target.shape[1], win_y): 87 | patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x] 88 | patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x] 89 | patch_iou_loss = self.iou_loss(patch_pred, patch_target) 90 | iou_loss += patch_iou_loss 91 | return iou_loss 92 | 93 | 94 | class ThrReg_loss(torch.nn.Module): 95 | def __init__(self): 96 | super(ThrReg_loss, self).__init__() 97 | 98 | def forward(self, pred, gt=None): 99 | return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2)) 100 | 101 | 102 | class ClsLoss(nn.Module): 103 | """ 104 | Auxiliary classification loss for each refined class output. 105 | """ 106 | def __init__(self): 107 | super(ClsLoss, self).__init__() 108 | self.config = Config() 109 | self.lambdas_cls = self.config.lambdas_cls 110 | 111 | self.criterions_last = { 112 | 'ce': nn.CrossEntropyLoss() 113 | } 114 | 115 | def forward(self, preds, gt): 116 | loss = 0. 117 | for _, pred_lvl in enumerate(preds): 118 | if pred_lvl is None: 119 | continue 120 | for criterion_name, criterion in self.criterions_last.items(): 121 | loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name] 122 | return loss 123 | 124 | 125 | class PixLoss(nn.Module): 126 | """ 127 | Pixel loss for each refined map output. 
128 | """ 129 | def __init__(self): 130 | super(PixLoss, self).__init__() 131 | self.config = Config() 132 | self.lambdas_pix_last = self.config.lambdas_pix_last 133 | 134 | self.criterions_last = {} 135 | if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']: 136 | self.criterions_last['bce'] = nn.BCELoss() 137 | if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']: 138 | self.criterions_last['iou'] = IoULoss() 139 | if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']: 140 | self.criterions_last['iou_patch'] = PatchIoULoss() 141 | if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']: 142 | self.criterions_last['ssim'] = SSIMLoss() 143 | if 'mae' in self.lambdas_pix_last and self.lambdas_pix_last['mae']: 144 | self.criterions_last['mae'] = nn.L1Loss() 145 | if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']: 146 | self.criterions_last['mse'] = nn.MSELoss() 147 | if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']: 148 | self.criterions_last['reg'] = ThrReg_loss() 149 | if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']: 150 | self.criterions_last['cnt'] = ContourLoss() 151 | if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']: 152 | self.criterions_last['structure'] = StructureLoss() 153 | 154 | def forward(self, scaled_preds, gt, pix_loss_lambda=1.0): 155 | loss = 0. 156 | loss_dict = {} 157 | for _, pred_lvl in enumerate(scaled_preds): 158 | if pred_lvl.shape != gt.shape: 159 | pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True) 160 | for criterion_name, criterion in self.criterions_last.items(): 161 | _loss = criterion(pred_lvl.sigmoid(), gt) * self.lambdas_pix_last[criterion_name] * pix_loss_lambda 162 | loss += _loss 163 | loss_dict[criterion_name] = loss_dict.get(criterion_name, 0.) 
+ _loss.item() / len(scaled_preds) 164 | # print(criterion_name, _loss.item()) 165 | return loss, loss_dict 166 | 167 | 168 | class SSIMLoss(torch.nn.Module): 169 | def __init__(self, window_size=11, size_average=True): 170 | super(SSIMLoss, self).__init__() 171 | self.window_size = window_size 172 | self.size_average = size_average 173 | self.channel = 1 174 | self.window = create_window(window_size, self.channel) 175 | 176 | def forward(self, img1, img2): 177 | (_, channel, _, _) = img1.size() 178 | if channel == self.channel and self.window.data.type() == img1.data.type(): 179 | window = self.window 180 | else: 181 | window = create_window(self.window_size, channel) 182 | if img1.is_cuda: 183 | window = window.cuda(img1.get_device()) 184 | window = window.type_as(img1) 185 | self.window = window 186 | self.channel = channel 187 | return 1 - (1 + _ssim(img1, img2, window, self.window_size, channel, self.size_average)) / 2 188 | 189 | 190 | def gaussian(window_size, sigma): 191 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 192 | return gauss/gauss.sum() 193 | 194 | 195 | def create_window(window_size, channel): 196 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 197 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 198 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 199 | return window 200 | 201 | 202 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 203 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel) 204 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel) 205 | 206 | mu1_sq = mu1.pow(2) 207 | mu2_sq = mu2.pow(2) 208 | mu1_mu2 = mu1*mu2 209 | 210 | sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq 211 | sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq 212 | sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 213 | 214 | C1 = 0.01**2 215 | C2 = 0.03**2 216 | 217 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 218 | 219 | if size_average: 220 | return ssim_map.mean() 221 | else: 222 | return ssim_map.mean(1).mean(1).mean(1) 223 | 224 | 225 | def SSIM(x, y): 226 | C1 = 0.01 ** 2 227 | C2 = 0.03 ** 2 228 | 229 | mu_x = nn.AvgPool2d(3, 1, 1)(x) 230 | mu_y = nn.AvgPool2d(3, 1, 1)(y) 231 | mu_x_mu_y = mu_x * mu_y 232 | mu_x_sq = mu_x.pow(2) 233 | mu_y_sq = mu_y.pow(2) 234 | 235 | sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq 236 | sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq 237 | sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y 238 | 239 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 240 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 241 | SSIM = SSIM_n / SSIM_d 242 | 243 | return torch.clamp((1 - SSIM) / 2, 0, 1) 244 | 245 | 246 | def saliency_structure_consistency(x, y): 247 | ssim = torch.mean(SSIM(x,y)) 248 | return ssim 249 | -------------------------------------------------------------------------------- /make_a_copy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Set dst repo here. 
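# Usage sketch (the destination name below is only an example):
#   bash make_a_copy.sh BiRefNet_copy
# This recreates the folder layout under ../BiRefNet_copy and copies the *.sh, *.py
# and .git* files into it.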
3 | repo=$1 4 | mkdir ../${repo} 5 | mkdir ../${repo}/evaluation 6 | mkdir ../${repo}/models 7 | mkdir ../${repo}/models/backbones 8 | mkdir ../${repo}/models/modules 9 | mkdir ../${repo}/models/refinement 10 | 11 | cp ./*.sh ../${repo} 12 | cp ./*.py ../${repo} 13 | cp ./evaluation/*.py ../${repo}/evaluation 14 | cp ./models/*.py ../${repo}/models 15 | cp ./models/backbones/*.py ../${repo}/models/backbones 16 | cp ./models/modules/*.py ../${repo}/models/modules 17 | cp ./models/refinement/*.py ../${repo}/models/refinement 18 | cp -r ./.git* ../${repo} 19 | -------------------------------------------------------------------------------- /models/backbones/build_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights 5 | from models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 6 | from models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l 7 | from config import Config 8 | 9 | 10 | config = Config() 11 | 12 | def build_backbone(bb_name, pretrained=True, params_settings=''): 13 | if bb_name == 'vgg16': 14 | bb_net = list(vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).children())[0] 15 | bb = nn.Sequential(OrderedDict({'conv1': bb_net[:10], 'conv2': bb_net[10:17], 'conv3': bb_net[17:24], 'conv4': bb_net[24:31]})) 16 | elif bb_name == 'vgg16bn': 17 | bb_net = list(vgg16_bn(weights=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0] 18 | bb = nn.Sequential(OrderedDict({'conv1': bb_net[:14], 'conv2': bb_net[14:24], 'conv3': bb_net[24:34], 'conv4': bb_net[34:44]})) 19 | elif bb_name == 'resnet50': 20 | bb_net = list(resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None).children()) 21 | bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:4], bb_net[4]), 'conv2': bb_net[5], 'conv3': bb_net[6], 'conv4': bb_net[7]})) 22 | else: 23 | bb = eval('{}({})'.format(bb_name, params_settings)) 24 | if pretrained: 25 | bb = load_weights(bb, bb_name) 26 | return bb 27 | 28 | def load_weights(model, model_name): 29 | save_model = torch.load(config.weights[model_name], map_location='cpu', weights_only=True) 30 | model_dict = model.state_dict() 31 | state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} 32 | # Ignore weights whose shapes do not match the current model (useful when the backbone itself has been modified). 33 | if not state_dict: 34 | save_model_keys = list(save_model.keys()) 35 | sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None 36 | state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()} 37 | if not state_dict or not sub_item: 38 | print('Weights are not successfully loaded. 
Check the state dict of weights file.') 39 | return None 40 | else: 41 | print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item)) 42 | model_dict.update(state_dict) 43 | model.load_state_dict(model_dict) 44 | return model 45 | -------------------------------------------------------------------------------- /models/backbones/pvt_v2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | import torch 4 | import torch.nn as nn 5 | 6 | from timm.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | from config import Config 9 | 10 | config = Config() 11 | 12 | 13 | class Mlp(nn.Module): 14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = hidden_features or in_features 18 | self.fc1 = nn.Linear(in_features, hidden_features) 19 | self.dwconv = DWConv(hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | self.apply(self._init_weights) 25 | 26 | def _init_weights(self, m): 27 | if isinstance(m, nn.Linear): 28 | trunc_normal_(m.weight, std=.02) 29 | if isinstance(m, nn.Linear) and m.bias is not None: 30 | nn.init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.LayerNorm): 32 | nn.init.constant_(m.bias, 0) 33 | nn.init.constant_(m.weight, 1.0) 34 | elif isinstance(m, nn.Conv2d): 35 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 36 | fan_out //= m.groups 37 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x, H, W): 42 | x = self.fc1(x) 43 | x = self.dwconv(x, H, W) 44 | x = self.act(x) 45 | x = self.drop(x) 46 | x = self.fc2(x) 47 | x = self.drop(x) 48 | return x 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
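# Descriptive note: this is PVTv2's spatial-reduction attention. When sr_ratio > 1,
# keys/values are computed from a feature map downsampled by the strided conv `self.sr`
# defined below, so the attention matrix has shape N x (N / sr_ratio^2) instead of N x N,
# which keeps self-attention affordable at high input resolutions.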
55 | 56 | self.dim = dim 57 | self.num_heads = num_heads 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | self.attn_drop_prob = attn_drop 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | self.apply(self._init_weights) 74 | 75 | def _init_weights(self, m): 76 | if isinstance(m, nn.Linear): 77 | trunc_normal_(m.weight, std=.02) 78 | if isinstance(m, nn.Linear) and m.bias is not None: 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.LayerNorm): 81 | nn.init.constant_(m.bias, 0) 82 | nn.init.constant_(m.weight, 1.0) 83 | elif isinstance(m, nn.Conv2d): 84 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | fan_out //= m.groups 86 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x, H, W): 91 | B, N, C = x.shape 92 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | 94 | if self.sr_ratio > 1: 95 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 96 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 97 | x_ = self.norm(x_) 98 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | else: 100 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 101 | k, v = kv[0], kv[1] 102 | 103 | if config.SDPA_enabled: 104 | x = torch.nn.functional.scaled_dot_product_attention( 105 | q, k, v, 106 | attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False 107 | ).transpose(1, 2).reshape(B, N, C) 108 | else: 109 | attn = (q @ k.transpose(-2, -1)) * self.scale 110 | attn = attn.softmax(dim=-1) 111 | attn = self.attn_drop(attn) 112 | 113 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 114 | x = self.proj(x) 115 | x = self.proj_drop(x) 116 | 117 | return x 118 | 119 | 120 | class Block(nn.Module): 121 | 122 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 123 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 124 | super().__init__() 125 | self.norm1 = norm_layer(dim) 126 | self.attn = Attention( 127 | dim, 128 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 129 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 130 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 131 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 132 | self.norm2 = norm_layer(dim) 133 | mlp_hidden_dim = int(dim * mlp_ratio) 134 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 135 | 136 | self.apply(self._init_weights) 137 | 138 | def _init_weights(self, m): 139 | if isinstance(m, nn.Linear): 140 | trunc_normal_(m.weight, std=.02) 141 | if isinstance(m, nn.Linear) and m.bias is not None: 142 | nn.init.constant_(m.bias, 0) 143 | elif isinstance(m, nn.LayerNorm): 144 | nn.init.constant_(m.bias, 0) 145 | nn.init.constant_(m.weight, 1.0) 146 | elif isinstance(m, nn.Conv2d): 147 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 148 | fan_out //= m.groups 149 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 150 | if m.bias is not None: 151 | m.bias.data.zero_() 152 | 153 | def forward(self, x, H, W): 154 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 155 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 156 | 157 | return x 158 | 159 | 160 | class OverlapPatchEmbed(nn.Module): 161 | """ Image to Patch Embedding 162 | """ 163 | 164 | def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): 165 | super().__init__() 166 | img_size = to_2tuple(img_size) 167 | patch_size = to_2tuple(patch_size) 168 | 169 | self.img_size = img_size 170 | self.patch_size = patch_size 171 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 172 | self.num_patches = self.H * self.W 173 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, 174 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 175 | self.norm = nn.LayerNorm(embed_dim) 176 | 177 | self.apply(self._init_weights) 178 | 179 | def _init_weights(self, m): 180 | if isinstance(m, nn.Linear): 181 | trunc_normal_(m.weight, std=.02) 182 | if isinstance(m, nn.Linear) and m.bias is not None: 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.LayerNorm): 185 | nn.init.constant_(m.bias, 0) 186 | nn.init.constant_(m.weight, 1.0) 187 | elif isinstance(m, nn.Conv2d): 188 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 189 | fan_out //= m.groups 190 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 191 | if m.bias is not None: 192 | m.bias.data.zero_() 193 | 194 | def forward(self, x): 195 | x = self.proj(x) 196 | _, _, H, W = x.shape 197 | x = x.flatten(2).transpose(1, 2) 198 | x = self.norm(x) 199 | 200 | return x, H, W 201 | 202 | 203 | class PyramidVisionTransformerImpr(nn.Module): 204 | def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 205 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 206 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 207 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 208 | super().__init__() 209 | self.num_classes = num_classes 210 | self.depths = depths 211 | 212 | # patch_embed 213 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels, 214 | embed_dim=embed_dims[0]) 215 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0], 216 | embed_dim=embed_dims[1]) 217 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1], 218 | embed_dim=embed_dims[2]) 219 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2], 220 | 
embed_dim=embed_dims[3]) 221 | 222 | # transformer encoder 223 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 224 | cur = 0 225 | self.block1 = nn.ModuleList([Block( 226 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 227 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 228 | sr_ratio=sr_ratios[0]) 229 | for i in range(depths[0])]) 230 | self.norm1 = norm_layer(embed_dims[0]) 231 | 232 | cur += depths[0] 233 | self.block2 = nn.ModuleList([Block( 234 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 235 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 236 | sr_ratio=sr_ratios[1]) 237 | for i in range(depths[1])]) 238 | self.norm2 = norm_layer(embed_dims[1]) 239 | 240 | cur += depths[1] 241 | self.block3 = nn.ModuleList([Block( 242 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 243 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 244 | sr_ratio=sr_ratios[2]) 245 | for i in range(depths[2])]) 246 | self.norm3 = norm_layer(embed_dims[2]) 247 | 248 | cur += depths[2] 249 | self.block4 = nn.ModuleList([Block( 250 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 251 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 252 | sr_ratio=sr_ratios[3]) 253 | for i in range(depths[3])]) 254 | self.norm4 = norm_layer(embed_dims[3]) 255 | 256 | # classification head 257 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 258 | 259 | self.apply(self._init_weights) 260 | 261 | def _init_weights(self, m): 262 | if isinstance(m, nn.Linear): 263 | trunc_normal_(m.weight, std=.02) 264 | if isinstance(m, nn.Linear) and m.bias is not None: 265 | nn.init.constant_(m.bias, 0) 266 | elif isinstance(m, nn.LayerNorm): 267 | nn.init.constant_(m.bias, 0) 268 | nn.init.constant_(m.weight, 1.0) 269 | elif isinstance(m, nn.Conv2d): 270 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 271 | fan_out //= m.groups 272 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 273 | if m.bias is not None: 274 | m.bias.data.zero_() 275 | 276 | def init_weights(self, pretrained=None): 277 | if isinstance(pretrained, str): 278 | logger = 1 279 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 280 | 281 | def reset_drop_path(self, drop_path_rate): 282 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 283 | cur = 0 284 | for i in range(self.depths[0]): 285 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 286 | 287 | cur += self.depths[0] 288 | for i in range(self.depths[1]): 289 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 290 | 291 | cur += self.depths[1] 292 | for i in range(self.depths[2]): 293 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 294 | 295 | cur += self.depths[2] 296 | for i in range(self.depths[3]): 297 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 298 | 299 | def freeze_patch_emb(self): 300 | self.patch_embed1.requires_grad = False 301 | 302 | @torch.jit.ignore 303 | def no_weight_decay(self): 304 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 305 | 306 | def 
get_classifier(self): 307 | return self.head 308 | 309 | def reset_classifier(self, num_classes, global_pool=''): 310 | self.num_classes = num_classes 311 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 312 | 313 | def forward_features(self, x): 314 | B = x.shape[0] 315 | outs = [] 316 | 317 | # stage 1 318 | x, H, W = self.patch_embed1(x) 319 | for i, blk in enumerate(self.block1): 320 | x = blk(x, H, W) 321 | x = self.norm1(x) 322 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 323 | outs.append(x) 324 | 325 | # stage 2 326 | x, H, W = self.patch_embed2(x) 327 | for i, blk in enumerate(self.block2): 328 | x = blk(x, H, W) 329 | x = self.norm2(x) 330 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 331 | outs.append(x) 332 | 333 | # stage 3 334 | x, H, W = self.patch_embed3(x) 335 | for i, blk in enumerate(self.block3): 336 | x = blk(x, H, W) 337 | x = self.norm3(x) 338 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 339 | outs.append(x) 340 | 341 | # stage 4 342 | x, H, W = self.patch_embed4(x) 343 | for i, blk in enumerate(self.block4): 344 | x = blk(x, H, W) 345 | x = self.norm4(x) 346 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 347 | outs.append(x) 348 | 349 | return outs 350 | 351 | # return x.mean(dim=1) 352 | 353 | def forward(self, x): 354 | x = self.forward_features(x) 355 | # x = self.head(x) 356 | 357 | return x 358 | 359 | 360 | class DWConv(nn.Module): 361 | def __init__(self, dim=768): 362 | super(DWConv, self).__init__() 363 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 364 | 365 | def forward(self, x, H, W): 366 | B, N, C = x.shape 367 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 368 | x = self.dwconv(x) 369 | x = x.flatten(2).transpose(1, 2) 370 | 371 | return x 372 | 373 | 374 | def _conv_filter(state_dict, patch_size=16): 375 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 376 | out_dict = {} 377 | for k, v in state_dict.items(): 378 | if 'patch_embed.proj.weight' in k: 379 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 380 | out_dict[k] = v 381 | 382 | return out_dict 383 | 384 | 385 | class pvt_v2_b0(PyramidVisionTransformerImpr): 386 | def __init__(self, **kwargs): 387 | super(pvt_v2_b0, self).__init__( 388 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 389 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 390 | drop_rate=0.0, drop_path_rate=0.1) 391 | 392 | 393 | class pvt_v2_b1(PyramidVisionTransformerImpr): 394 | def __init__(self, **kwargs): 395 | super(pvt_v2_b1, self).__init__( 396 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 397 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 398 | drop_rate=0.0, drop_path_rate=0.1) 399 | 400 | 401 | class pvt_v2_b2(PyramidVisionTransformerImpr): 402 | def __init__(self, in_channels=3, **kwargs): 403 | super(pvt_v2_b2, self).__init__( 404 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 405 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 406 | drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels) 407 | 408 | 409 | class pvt_v2_b3(PyramidVisionTransformerImpr): 410 | def __init__(self, **kwargs): 411 | super(pvt_v2_b3, 
self).__init__( 412 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 413 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 414 | drop_rate=0.0, drop_path_rate=0.1) 415 | 416 | 417 | class pvt_v2_b4(PyramidVisionTransformerImpr): 418 | def __init__(self, **kwargs): 419 | super(pvt_v2_b4, self).__init__( 420 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 421 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 422 | drop_rate=0.0, drop_path_rate=0.1) 423 | 424 | 425 | class pvt_v2_b5(PyramidVisionTransformerImpr): 426 | def __init__(self, **kwargs): 427 | super(pvt_v2_b5, self).__init__( 428 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 429 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 430 | drop_rate=0.0, drop_path_rate=0.1) 431 | -------------------------------------------------------------------------------- /models/backbones/swin_v1.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu, Yutong Lin, Yixuan Wei 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint as checkpoint 12 | import numpy as np 13 | from timm.layers import DropPath, to_2tuple, trunc_normal_ 14 | 15 | from config import Config 16 | 17 | 18 | config = Config() 19 | 20 | class Mlp(nn.Module): 21 | """ Multilayer perceptron.""" 22 | 23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.act = act_layer() 29 | self.fc2 = nn.Linear(hidden_features, out_features) 30 | self.drop = nn.Dropout(drop) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.drop(x) 36 | x = self.fc2(x) 37 | x = self.drop(x) 38 | return x 39 | 40 | 41 | def window_partition(x, window_size): 42 | """ 43 | Args: 44 | x: (B, H, W, C) 45 | window_size (int): window size 46 | 47 | Returns: 48 | windows: (num_windows*B, window_size, window_size, C) 49 | """ 50 | B, H, W, C = x.shape 51 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 52 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 53 | return windows 54 | 55 | 56 | def window_reverse(windows, window_size, H, W): 57 | """ 58 | Args: 59 | windows: (num_windows*B, window_size, window_size, C) 60 | window_size (int): Window size 61 | H (int): Height of image 62 | W (int): Width of image 63 | 64 | Returns: 65 | x: (B, H, W, C) 66 | """ 67 | C = int(windows.shape[-1]) 68 | x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) 69 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) 70 | return x 71 | 72 | 73 | class WindowAttention(nn.Module): 74 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 
75 | It supports both of shifted and non-shifted window. 76 | 77 | Args: 78 | dim (int): Number of input channels. 79 | window_size (tuple[int]): The height and width of the window. 80 | num_heads (int): Number of attention heads. 81 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 82 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 83 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 84 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 85 | """ 86 | 87 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 88 | 89 | super().__init__() 90 | self.dim = dim 91 | self.window_size = window_size # Wh, Ww 92 | self.num_heads = num_heads 93 | head_dim = dim // num_heads 94 | self.scale = qk_scale or head_dim ** -0.5 95 | 96 | # define a parameter table of relative position bias 97 | self.relative_position_bias_table = nn.Parameter( 98 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 99 | 100 | # get pair-wise relative position index for each token inside the window 101 | coords_h = torch.arange(self.window_size[0]) 102 | coords_w = torch.arange(self.window_size[1]) 103 | coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww 104 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 105 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 106 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 107 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 108 | relative_coords[:, :, 1] += self.window_size[1] - 1 109 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 110 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 111 | self.register_buffer("relative_position_index", relative_position_index) 112 | 113 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 114 | self.attn_drop_prob = attn_drop 115 | self.attn_drop = nn.Dropout(attn_drop) 116 | self.proj = nn.Linear(dim, dim) 117 | self.proj_drop = nn.Dropout(proj_drop) 118 | 119 | trunc_normal_(self.relative_position_bias_table, std=.02) 120 | self.softmax = nn.Softmax(dim=-1) 121 | 122 | def forward(self, x, mask=None): 123 | """ Forward function. 
124 | 125 | Args: 126 | x: input features with shape of (num_windows*B, N, C) 127 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 128 | """ 129 | B_, N, C = x.shape 130 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 131 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 132 | 133 | q = q * self.scale 134 | 135 | if config.SDPA_enabled: 136 | x = torch.nn.functional.scaled_dot_product_attention( 137 | q, k, v, 138 | attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False 139 | ).transpose(1, 2).reshape(B_, N, C) 140 | else: 141 | attn = (q @ k.transpose(-2, -1)) 142 | 143 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 144 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 145 | ) # Wh*Ww, Wh*Ww, nH 146 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 147 | attn = attn + relative_position_bias.unsqueeze(0) 148 | 149 | if mask is not None: 150 | nW = mask.shape[0] 151 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 152 | attn = attn.view(-1, self.num_heads, N, N) 153 | attn = self.softmax(attn) 154 | else: 155 | attn = self.softmax(attn) 156 | 157 | attn = self.attn_drop(attn) 158 | 159 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 160 | x = self.proj(x) 161 | x = self.proj_drop(x) 162 | return x 163 | 164 | 165 | class SwinTransformerBlock(nn.Module): 166 | """ Swin Transformer Block. 167 | 168 | Args: 169 | dim (int): Number of input channels. 170 | num_heads (int): Number of attention heads. 171 | window_size (int): Window size. 172 | shift_size (int): Shift size for SW-MSA. 173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 176 | drop (float, optional): Dropout rate. Default: 0.0 177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 180 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 181 | """ 182 | 183 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 184 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 185 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 186 | super().__init__() 187 | self.dim = dim 188 | self.num_heads = num_heads 189 | self.window_size = window_size 190 | self.shift_size = shift_size 191 | self.mlp_ratio = mlp_ratio 192 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 193 | 194 | self.norm1 = norm_layer(dim) 195 | self.attn = WindowAttention( 196 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 197 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 198 | 199 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 200 | self.norm2 = norm_layer(dim) 201 | mlp_hidden_dim = int(dim * mlp_ratio) 202 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 203 | 204 | self.H = None 205 | self.W = None 206 | 207 | def forward(self, x, mask_matrix): 208 | """ Forward function. 
209 | 210 | Args: 211 | x: Input feature, tensor size (B, H*W, C). 212 | H, W: Spatial resolution of the input feature. 213 | mask_matrix: Attention mask for cyclic shift. 214 | """ 215 | B, L, C = x.shape 216 | H, W = self.H, self.W 217 | assert L == H * W, "input feature has wrong size" 218 | 219 | shortcut = x 220 | x = self.norm1(x) 221 | x = x.view(B, H, W, C) 222 | 223 | # pad feature maps to multiples of window size 224 | pad_l = pad_t = 0 225 | pad_r = (self.window_size - W % self.window_size) % self.window_size 226 | pad_b = (self.window_size - H % self.window_size) % self.window_size 227 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 228 | _, Hp, Wp, _ = x.shape 229 | 230 | # cyclic shift 231 | if self.shift_size > 0: 232 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 233 | attn_mask = mask_matrix 234 | else: 235 | shifted_x = x 236 | attn_mask = None 237 | 238 | # partition windows 239 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 240 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 241 | 242 | # W-MSA/SW-MSA 243 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C 244 | 245 | # merge windows 246 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 247 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 248 | 249 | # reverse cyclic shift 250 | if self.shift_size > 0: 251 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 252 | else: 253 | x = shifted_x 254 | 255 | if pad_r > 0 or pad_b > 0: 256 | x = x[:, :H, :W, :].contiguous() 257 | 258 | x = x.view(B, H * W, C) 259 | 260 | # FFN 261 | x = shortcut + self.drop_path(x) 262 | x = x + self.drop_path(self.mlp(self.norm2(x))) 263 | 264 | return x 265 | 266 | 267 | class PatchMerging(nn.Module): 268 | """ Patch Merging Layer 269 | 270 | Args: 271 | dim (int): Number of input channels. 272 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 273 | """ 274 | def __init__(self, dim, norm_layer=nn.LayerNorm): 275 | super().__init__() 276 | self.dim = dim 277 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 278 | self.norm = norm_layer(4 * dim) 279 | 280 | def forward(self, x, H, W): 281 | """ Forward function. 282 | 283 | Args: 284 | x: Input feature, tensor size (B, H*W, C). 285 | H, W: Spatial resolution of the input feature. 286 | """ 287 | B, L, C = x.shape 288 | assert L == H * W, "input feature has wrong size" 289 | 290 | x = x.view(B, H, W, C) 291 | 292 | # padding 293 | pad_input = (H % 2 == 1) or (W % 2 == 1) 294 | if pad_input: 295 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 296 | 297 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 298 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 299 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 300 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 301 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 302 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 303 | 304 | x = self.norm(x) 305 | x = self.reduction(x) 306 | 307 | return x 308 | 309 | 310 | class BasicLayer(nn.Module): 311 | """ A basic Swin Transformer layer for one stage. 312 | 313 | Args: 314 | dim (int): Number of feature channels 315 | depth (int): Depths of this stage. 316 | num_heads (int): Number of attention head. 317 | window_size (int): Local window size. Default: 7. 318 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 
319 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 320 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 321 | drop (float, optional): Dropout rate. Default: 0.0 322 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 323 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 324 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 325 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 326 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 327 | """ 328 | 329 | def __init__(self, 330 | dim, 331 | depth, 332 | num_heads, 333 | window_size=7, 334 | mlp_ratio=4., 335 | qkv_bias=True, 336 | qk_scale=None, 337 | drop=0., 338 | attn_drop=0., 339 | drop_path=0., 340 | norm_layer=nn.LayerNorm, 341 | downsample=None, 342 | use_checkpoint=False): 343 | super().__init__() 344 | self.window_size = window_size 345 | self.shift_size = window_size // 2 346 | self.depth = depth 347 | self.use_checkpoint = use_checkpoint 348 | 349 | # build blocks 350 | self.blocks = nn.ModuleList([ 351 | SwinTransformerBlock( 352 | dim=dim, 353 | num_heads=num_heads, 354 | window_size=window_size, 355 | shift_size=0 if (i % 2 == 0) else window_size // 2, 356 | mlp_ratio=mlp_ratio, 357 | qkv_bias=qkv_bias, 358 | qk_scale=qk_scale, 359 | drop=drop, 360 | attn_drop=attn_drop, 361 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 362 | norm_layer=norm_layer) 363 | for i in range(depth)]) 364 | 365 | # patch merging layer 366 | if downsample is not None: 367 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 368 | else: 369 | self.downsample = None 370 | 371 | def forward(self, x, H, W): 372 | """ Forward function. 373 | 374 | Args: 375 | x: Input feature, tensor size (B, H*W, C). 376 | H, W: Spatial resolution of the input feature. 377 | """ 378 | 379 | # calculate attention mask for SW-MSA 380 | # Turn int to torch.tensor for the compatiability with torch.compile in PyTorch 2.5. 
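        # NOTE (added annotation, not in the upstream swin_v1.py): the lines below build the
        # attention mask used by the shifted-window (SW-MSA) blocks. Hp/Wp round H and W up to a
        # multiple of window_size, mirroring the padding done inside each SwinTransformerBlock.
        # img_mask labels the 3x3 grid of h/w slices (the regions created by the cyclic shift)
        # with indices 0..8; after window_partition, any pair of positions whose labels differ
        # gets -100.0 added to its logit so the softmax effectively ignores that pair.
        # Illustration (assumed values): with window_size=7 and shift_size=3, a 14x14 map gives
        # a 2x2 grid of windows, and the bottom-right window mixes 4 of the 9 regions.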
381 | Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size 382 | Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size 383 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 384 | h_slices = (slice(0, -self.window_size), 385 | slice(-self.window_size, -self.shift_size), 386 | slice(-self.shift_size, None)) 387 | w_slices = (slice(0, -self.window_size), 388 | slice(-self.window_size, -self.shift_size), 389 | slice(-self.shift_size, None)) 390 | cnt = 0 391 | for h in h_slices: 392 | for w in w_slices: 393 | img_mask[:, h, w, :] = cnt 394 | cnt += 1 395 | 396 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 397 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 398 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 399 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype) 400 | 401 | for blk in self.blocks: 402 | blk.H, blk.W = H, W 403 | if self.use_checkpoint: 404 | x = checkpoint.checkpoint(blk, x, attn_mask) 405 | else: 406 | x = blk(x, attn_mask) 407 | if self.downsample is not None: 408 | x_down = self.downsample(x, H, W) 409 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 410 | return x, H, W, x_down, Wh, Ww 411 | else: 412 | return x, H, W, x, H, W 413 | 414 | 415 | class PatchEmbed(nn.Module): 416 | """ Image to Patch Embedding 417 | 418 | Args: 419 | patch_size (int): Patch token size. Default: 4. 420 | in_channels (int): Number of input image channels. Default: 3. 421 | embed_dim (int): Number of linear projection output channels. Default: 96. 422 | norm_layer (nn.Module, optional): Normalization layer. Default: None 423 | """ 424 | 425 | def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None): 426 | super().__init__() 427 | patch_size = to_2tuple(patch_size) 428 | self.patch_size = patch_size 429 | 430 | self.in_channels = in_channels 431 | self.embed_dim = embed_dim 432 | 433 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) 434 | if norm_layer is not None: 435 | self.norm = norm_layer(embed_dim) 436 | else: 437 | self.norm = None 438 | 439 | def forward(self, x): 440 | """Forward function.""" 441 | # padding 442 | _, _, H, W = x.size() 443 | if W % self.patch_size[1] != 0: 444 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 445 | if H % self.patch_size[0] != 0: 446 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 447 | 448 | x = self.proj(x) # B C Wh Ww 449 | if self.norm is not None: 450 | Wh, Ww = x.size(2), x.size(3) 451 | x = x.flatten(2).transpose(1, 2) 452 | x = self.norm(x) 453 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 454 | 455 | return x 456 | 457 | 458 | class SwinTransformer(nn.Module): 459 | """ Swin Transformer backbone. 460 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 461 | https://arxiv.org/pdf/2103.14030 462 | 463 | Args: 464 | pretrain_img_size (int): Input image size for training the pretrained model, 465 | used in absolute postion embedding. Default 224. 466 | patch_size (int | tuple(int)): Patch size. Default: 4. 467 | in_channels (int): Number of input image channels. Default: 3. 468 | embed_dim (int): Number of linear projection output channels. Default: 96. 469 | depths (tuple[int]): Depths of each Swin Transformer stage. 
470 | num_heads (tuple[int]): Number of attention head of each stage. 471 | window_size (int): Window size. Default: 7. 472 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 473 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 474 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 475 | drop_rate (float): Dropout rate. 476 | attn_drop_rate (float): Attention dropout rate. Default: 0. 477 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 478 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 479 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 480 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 481 | out_indices (Sequence[int]): Output from which stages. 482 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 483 | -1 means not freezing any parameters. 484 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 485 | """ 486 | 487 | def __init__(self, 488 | pretrain_img_size=224, 489 | patch_size=4, 490 | in_channels=3, 491 | embed_dim=96, 492 | depths=[2, 2, 6, 2], 493 | num_heads=[3, 6, 12, 24], 494 | window_size=7, 495 | mlp_ratio=4., 496 | qkv_bias=True, 497 | qk_scale=None, 498 | drop_rate=0., 499 | attn_drop_rate=0., 500 | drop_path_rate=0.2, 501 | norm_layer=nn.LayerNorm, 502 | ape=False, 503 | patch_norm=True, 504 | out_indices=(0, 1, 2, 3), 505 | frozen_stages=-1, 506 | use_checkpoint=False): 507 | super().__init__() 508 | 509 | self.pretrain_img_size = pretrain_img_size 510 | self.num_layers = len(depths) 511 | self.embed_dim = embed_dim 512 | self.ape = ape 513 | self.patch_norm = patch_norm 514 | self.out_indices = out_indices 515 | self.frozen_stages = frozen_stages 516 | 517 | # split image into non-overlapping patches 518 | self.patch_embed = PatchEmbed( 519 | patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, 520 | norm_layer=norm_layer if self.patch_norm else None) 521 | 522 | # absolute position embedding 523 | if self.ape: 524 | pretrain_img_size = to_2tuple(pretrain_img_size) 525 | patch_size = to_2tuple(patch_size) 526 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 527 | 528 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 529 | trunc_normal_(self.absolute_pos_embed, std=.02) 530 | 531 | self.pos_drop = nn.Dropout(p=drop_rate) 532 | 533 | # stochastic depth 534 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 535 | 536 | # build layers 537 | self.layers = nn.ModuleList() 538 | for i_layer in range(self.num_layers): 539 | layer = BasicLayer( 540 | dim=int(embed_dim * 2 ** i_layer), 541 | depth=depths[i_layer], 542 | num_heads=num_heads[i_layer], 543 | window_size=window_size, 544 | mlp_ratio=mlp_ratio, 545 | qkv_bias=qkv_bias, 546 | qk_scale=qk_scale, 547 | drop=drop_rate, 548 | attn_drop=attn_drop_rate, 549 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 550 | norm_layer=norm_layer, 551 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 552 | use_checkpoint=use_checkpoint) 553 | self.layers.append(layer) 554 | 555 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 556 | self.num_features = num_features 557 | 558 | # add a norm layer for each output 559 | for i_layer in 
out_indices: 560 | layer = norm_layer(num_features[i_layer]) 561 | layer_name = f'norm{i_layer}' 562 | self.add_module(layer_name, layer) 563 | 564 | self._freeze_stages() 565 | 566 | def _freeze_stages(self): 567 | if self.frozen_stages >= 0: 568 | self.patch_embed.eval() 569 | for param in self.patch_embed.parameters(): 570 | param.requires_grad = False 571 | 572 | if self.frozen_stages >= 1 and self.ape: 573 | self.absolute_pos_embed.requires_grad = False 574 | 575 | if self.frozen_stages >= 2: 576 | self.pos_drop.eval() 577 | for i in range(0, self.frozen_stages - 1): 578 | m = self.layers[i] 579 | m.eval() 580 | for param in m.parameters(): 581 | param.requires_grad = False 582 | 583 | 584 | def forward(self, x): 585 | """Forward function.""" 586 | x = self.patch_embed(x) 587 | 588 | Wh, Ww = x.size(2), x.size(3) 589 | if self.ape: 590 | # interpolate the position embedding to the corresponding size 591 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 592 | x = (x + absolute_pos_embed) # B Wh*Ww C 593 | 594 | outs = []#x.contiguous()] 595 | x = x.flatten(2).transpose(1, 2) 596 | x = self.pos_drop(x) 597 | for i in range(self.num_layers): 598 | layer = self.layers[i] 599 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 600 | 601 | if i in self.out_indices: 602 | norm_layer = getattr(self, f'norm{i}') 603 | x_out = norm_layer(x_out) 604 | 605 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() 606 | outs.append(out) 607 | 608 | return tuple(outs) 609 | 610 | def train(self, mode=True): 611 | """Convert the model into training mode while keep layers freezed.""" 612 | super(SwinTransformer, self).train(mode) 613 | self._freeze_stages() 614 | 615 | def swin_v1_t(): 616 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7) 617 | return model 618 | 619 | def swin_v1_s(): 620 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7) 621 | return model 622 | 623 | def swin_v1_b(): 624 | model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12) 625 | return model 626 | 627 | def swin_v1_l(): 628 | model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12) 629 | return model 630 | -------------------------------------------------------------------------------- /models/birefnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from kornia.filters import laplacian 6 | from huggingface_hub import PyTorchModelHubMixin 7 | 8 | from config import Config 9 | from dataset import class_labels_TR_sorted 10 | from models.backbones.build_backbone import build_backbone 11 | from models.modules.decoder_blocks import BasicDecBlk, ResBlk 12 | from models.modules.lateral_blocks import BasicLatBlk 13 | from models.modules.aspp import ASPP, ASPPDeformable 14 | from models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet 15 | from models.refinement.stem_layer import StemLayer 16 | 17 | 18 | def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'): 19 | if patch_ref is not None: 20 | grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1] 21 | patches = rearrange(image, transformation, hg=grid_h, 
wg=grid_w) 22 | return patches 23 | 24 | def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'): 25 | if patch_ref is not None: 26 | grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1] 27 | image = rearrange(patches, transformation, hg=grid_h, wg=grid_w) 28 | return image 29 | 30 | class BiRefNet( 31 | nn.Module, 32 | PyTorchModelHubMixin, 33 | library_name="birefnet", 34 | repo_url="https://github.com/ZhengPeng7/BiRefNet", 35 | tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection'] 36 | ): 37 | def __init__(self, bb_pretrained=True): 38 | super(BiRefNet, self).__init__() 39 | self.config = Config() 40 | self.epoch = 1 41 | self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained) 42 | 43 | channels = self.config.lateral_channels_in_collection 44 | 45 | if self.config.auxiliary_classification: 46 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 47 | self.cls_head = nn.Sequential( 48 | nn.Linear(channels[0], len(class_labels_TR_sorted)) 49 | ) 50 | 51 | if self.config.squeeze_block: 52 | self.squeeze_module = nn.Sequential(*[ 53 | eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0]) 54 | for _ in range(eval(self.config.squeeze_block.split('_x')[1])) 55 | ]) 56 | 57 | self.decoder = Decoder(channels) 58 | 59 | if self.config.ender: 60 | self.dec_end = nn.Sequential( 61 | nn.Conv2d(1, 16, 3, 1, 1), 62 | nn.Conv2d(16, 1, 3, 1, 1), 63 | nn.ReLU(inplace=True), 64 | ) 65 | 66 | # refine patch-level segmentation 67 | if self.config.refine: 68 | if self.config.refine == 'itself': 69 | self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') 70 | else: 71 | self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1')) 72 | 73 | if self.config.freeze_bb: 74 | # Freeze the backbone... 75 | print(self.named_parameters()) 76 | for key, value in self.named_parameters(): 77 | if 'bb.' in key and 'refiner.' 
not in key: 78 | value.requires_grad = False 79 | 80 | def forward_enc(self, x): 81 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 82 | x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3) 83 | else: 84 | x1, x2, x3, x4 = self.bb(x) 85 | if self.config.mul_scl_ipt: 86 | B, C, H, W = x.shape 87 | x_pyramid = F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True) 88 | if self.config.mul_scl_ipt == 'cat': 89 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 90 | x1_ = self.bb.conv1(x_pyramid); x2_ = self.bb.conv2(x1_); x3_ = self.bb.conv3(x2_); x4_ = self.bb.conv4(x3_) 91 | else: 92 | x1_, x2_, x3_, x4_ = self.bb(x_pyramid) 93 | x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1) 94 | x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1) 95 | x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1) 96 | x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1) 97 | elif self.config.mul_scl_ipt == 'add': 98 | x1_, x2_, x3_, x4_ = self.bb(x_pyramid) 99 | x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True) 100 | x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True) 101 | x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True) 102 | x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True) 103 | class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None 104 | if self.config.cxt: 105 | x4 = torch.cat( 106 | ( 107 | *[ 108 | F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True), 109 | F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True), 110 | F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True), 111 | ][-len(self.config.cxt):], 112 | x4 113 | ), 114 | dim=1 115 | ) 116 | return (x1, x2, x3, x4), class_preds 117 | 118 | def forward_ori(self, x): 119 | ########## Encoder ########## 120 | (x1, x2, x3, x4), class_preds = self.forward_enc(x) 121 | if self.config.squeeze_block: 122 | x4 = self.squeeze_module(x4) 123 | ########## Decoder ########## 124 | features = [x, x1, x2, x3, x4] 125 | if self.training and self.config.out_ref: 126 | features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5)) 127 | scaled_preds = self.decoder(features) 128 | return scaled_preds, class_preds 129 | 130 | def forward(self, x): 131 | scaled_preds, class_preds = self.forward_ori(x) 132 | class_preds_lst = [class_preds] 133 | return [scaled_preds, class_preds_lst] if self.training else scaled_preds 134 | 135 | 136 | class Decoder(nn.Module): 137 | def __init__(self, channels): 138 | super(Decoder, self).__init__() 139 | self.config = Config() 140 | DecoderBlock = eval(self.config.dec_blk) 141 | LateralBlock = eval(self.config.lat_blk) 142 | 143 | if self.config.dec_ipt: 144 | self.split = self.config.dec_ipt_split 145 | N_dec_ipt = 64 146 | DBlock = SimpleConvs 147 | ic = 64 148 | ipt_cha_opt = 1 149 | self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) 150 | self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) 151 | self.ipt_blk3 = DBlock(2**6*3 if self.split else 
3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic) 152 | self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic) 153 | self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic) 154 | else: 155 | self.split = None 156 | 157 | self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1]) 158 | self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2]) 159 | self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]) 160 | self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2) 161 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0)) 162 | 163 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) 164 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) 165 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) 166 | 167 | if self.config.ms_supervision: 168 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) 169 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) 170 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) 171 | 172 | if self.config.out_ref: 173 | _N = 16 174 | self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) 175 | self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) 176 | self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) 177 | 178 | self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 179 | self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 180 | self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 181 | 182 | self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 183 | self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 184 | self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 185 | 186 | def forward(self, features): 187 | if self.training and self.config.out_ref: 188 | outs_gdt_pred = [] 189 | outs_gdt_label = [] 190 | x, x1, x2, x3, x4, gdt_gt = features 191 | else: 192 | x, x1, x2, x3, x4 = features 193 | outs = [] 194 | 195 | if self.config.dec_ipt: 196 | patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x 197 | x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) 198 | p4 = self.decoder_block4(x4) 199 | m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None 200 | if self.config.out_ref: 201 | p4_gdt = self.gdt_convs_4(p4) 202 | if self.training: 203 | # >> GT: 204 | m4_dia = m4 205 | gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) 206 | outs_gdt_label.append(gdt_label_main_4) 207 | # >> Pred: 208 | gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt) 209 | outs_gdt_pred.append(gdt_pred_4) 210 | 
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid() 211 | # >> Finally: 212 | p4 = p4 * gdt_attn_4 213 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) 214 | _p3 = _p4 + self.lateral_block4(x3) 215 | 216 | if self.config.dec_ipt: 217 | patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x 218 | _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) 219 | p3 = self.decoder_block3(_p3) 220 | m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None 221 | if self.config.out_ref: 222 | p3_gdt = self.gdt_convs_3(p3) 223 | if self.training: 224 | # >> GT: 225 | # m3 --dilation--> m3_dia 226 | # G_3^gt * m3_dia --> G_3^m, which is the label of gradient 227 | m3_dia = m3 228 | gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) 229 | outs_gdt_label.append(gdt_label_main_3) 230 | # >> Pred: 231 | # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx 232 | # F_3^G --sigmoid--> A_3^G 233 | gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt) 234 | outs_gdt_pred.append(gdt_pred_3) 235 | gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() 236 | # >> Finally: 237 | # p3 = p3 * A_3^G 238 | p3 = p3 * gdt_attn_3 239 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) 240 | _p2 = _p3 + self.lateral_block3(x2) 241 | 242 | if self.config.dec_ipt: 243 | patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x 244 | _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) 245 | p2 = self.decoder_block2(_p2) 246 | m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None 247 | if self.config.out_ref: 248 | p2_gdt = self.gdt_convs_2(p2) 249 | if self.training: 250 | # >> GT: 251 | m2_dia = m2 252 | gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) 253 | outs_gdt_label.append(gdt_label_main_2) 254 | # >> Pred: 255 | gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt) 256 | outs_gdt_pred.append(gdt_pred_2) 257 | gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid() 258 | # >> Finally: 259 | p2 = p2 * gdt_attn_2 260 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) 261 | _p1 = _p2 + self.lateral_block2(x1) 262 | 263 | if self.config.dec_ipt: 264 | patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x 265 | _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) 266 | _p1 = self.decoder_block1(_p1) 267 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) 268 | 269 | if self.config.dec_ipt: 270 | patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x 271 | _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) 272 | p1_out = self.conv_out1(_p1) 273 | 274 | if self.config.ms_supervision and self.training: 275 | outs.append(m4) 276 | outs.append(m3) 277 | outs.append(m2) 278 | outs.append(p1_out) 279 | return outs if not (self.config.out_ref and self.training) else 
([outs_gdt_pred, outs_gdt_label], outs) 280 | 281 | 282 | class SimpleConvs(nn.Module): 283 | def __init__( 284 | self, in_channels: int, out_channels: int, inter_channels=64 285 | ) -> None: 286 | super().__init__() 287 | self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) 288 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1) 289 | 290 | def forward(self, x): 291 | return self.conv_out(self.conv1(x)) 292 | 293 | 294 | ########### 295 | 296 | 297 | class BiRefNetC2F( 298 | nn.Module, 299 | PyTorchModelHubMixin, 300 | library_name="birefnet_c2f", 301 | repo_url="https://github.com/ZhengPeng7/BiRefNet_C2F", 302 | tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection'] 303 | ): 304 | def __init__(self, bb_pretrained=True): 305 | super(BiRefNetC2F, self).__init__() 306 | self.config = Config() 307 | self.epoch = 1 308 | self.grid = 4 309 | self.model_coarse = BiRefNet(bb_pretrained=True) 310 | self.model_fine = BiRefNet(bb_pretrained=True) 311 | self.input_mixer = nn.Conv2d(4, 3, 1, 1, 0) 312 | self.output_mixer_merge_post = nn.Sequential(nn.Conv2d(1, 16, 3, 1, 1), nn.Conv2d(16, 1, 3, 1, 1)) 313 | 314 | def forward(self, x): 315 | x_ori = x.clone() 316 | ########## Coarse ########## 317 | x = F.interpolate(x, size=[s//self.grid for s in self.config.size[::-1]], mode='bilinear', align_corners=True) 318 | 319 | if self.training: 320 | scaled_preds, class_preds_lst = self.model_coarse(x) 321 | else: 322 | scaled_preds = self.model_coarse(x) 323 | ########## Fine ########## 324 | x_HR_patches = image2patches(x_ori, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w') 325 | pred = F.interpolate(scaled_preds[-1] if not (self.config.out_ref and self.training) else scaled_preds[1][-1], size=x_ori.shape[2:], mode='bilinear', align_corners=True) 326 | pred_patches = image2patches(pred, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w') 327 | t = torch.cat([x_HR_patches, pred_patches], dim=1) 328 | x_HR = self.input_mixer(t) 329 | 330 | pred_patches = image2patches(pred, patch_ref=x_HR, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') 331 | if self.training: 332 | scaled_preds_HR, class_preds_lst_HR = self.model_fine(x_HR) 333 | else: 334 | scaled_preds_HR = self.model_fine(x_HR) 335 | if self.training: 336 | if self.config.out_ref: 337 | [outs_gdt_pred, outs_gdt_label], outs = scaled_preds 338 | [outs_gdt_pred_HR, outs_gdt_label_HR], outs_HR = scaled_preds_HR 339 | for idx_out, out_HR in enumerate(outs_HR): 340 | outs_HR[idx_out] = self.output_mixer_merge_post(patches2image(out_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) 341 | return [([outs_gdt_pred + outs_gdt_pred_HR, outs_gdt_label + outs_gdt_label_HR], outs + outs_HR), class_preds_lst] # handle gt here 342 | else: 343 | return [ 344 | scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR], 345 | class_preds_lst 346 | ] 347 | else: 348 | return scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR] 349 | -------------------------------------------------------------------------------- /models/modules/aspp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.modules.deform_conv import DeformableConv2d 5 | from config import Config 6 | 7 | 8 | config = Config() 9 | 10 | 11 | class _ASPPModule(nn.Module): 12 | def __init__(self, in_channels, planes, kernel_size, padding, dilation): 13 | super(_ASPPModule, self).__init__() 14 | self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size, 15 | stride=1, padding=padding, dilation=dilation, bias=False) 16 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | x = self.atrous_conv(x) 21 | x = self.bn(x) 22 | 23 | return self.relu(x) 24 | 25 | 26 | class ASPP(nn.Module): 27 | def __init__(self, in_channels=64, out_channels=None, output_stride=16): 28 | super(ASPP, self).__init__() 29 | self.down_scale = 1 30 | if out_channels is None: 31 | out_channels = in_channels 32 | self.in_channelster = 256 // self.down_scale 33 | if output_stride == 16: 34 | dilations = [1, 6, 12, 18] 35 | elif output_stride == 8: 36 | dilations = [1, 12, 24, 36] 37 | else: 38 | raise NotImplementedError 39 | 40 | self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) 41 | self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) 42 | self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) 43 | self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) 44 | 45 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 46 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 47 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 48 | nn.ReLU(inplace=True)) 49 | self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 51 | self.relu = nn.ReLU(inplace=True) 52 | self.dropout = nn.Dropout(0.5) 53 | 54 | def forward(self, x): 55 | x1 = self.aspp1(x) 56 | x2 = self.aspp2(x) 57 | x3 = self.aspp3(x) 58 | x4 = self.aspp4(x) 59 | x5 = self.global_avg_pool(x) 60 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 61 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 62 | 63 | x = self.conv1(x) 64 | x = self.bn1(x) 65 | x = self.relu(x) 66 | 67 | return self.dropout(x) 68 | 69 | 70 | ##################### Deformable 71 | class _ASPPModuleDeformable(nn.Module): 72 | def __init__(self, in_channels, planes, kernel_size, padding): 73 | super(_ASPPModuleDeformable, self).__init__() 74 | self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, 75 | stride=1, padding=padding, bias=False) 76 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() 77 | self.relu = nn.ReLU(inplace=True) 78 | 79 | def forward(self, x): 80 | x = self.atrous_conv(x) 81 | x = self.bn(x) 82 | 83 | return self.relu(x) 84 | 85 | 86 | class ASPPDeformable(nn.Module): 87 | def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]): 88 | super(ASPPDeformable, self).__init__() 89 | self.down_scale = 1 90 | if out_channels is None: 91 | out_channels = in_channels 92 | self.in_channelster = 256 // self.down_scale 93 | 94 | self.aspp1 = 
_ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0) 95 | self.aspp_deforms = nn.ModuleList([ 96 | _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes 97 | ]) 98 | 99 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 100 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 101 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 102 | nn.ReLU(inplace=True)) 103 | self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 105 | self.relu = nn.ReLU(inplace=True) 106 | self.dropout = nn.Dropout(0.5) 107 | 108 | def forward(self, x): 109 | x1 = self.aspp1(x) 110 | x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] 111 | x5 = self.global_avg_pool(x) 112 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 113 | x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) 114 | 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.relu(x) 118 | 119 | return self.dropout(x) 120 | -------------------------------------------------------------------------------- /models/modules/decoder_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.modules.aspp import ASPP, ASPPDeformable 4 | from config import Config 5 | 6 | 7 | config = Config() 8 | 9 | 10 | class BasicDecBlk(nn.Module): 11 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 12 | super(BasicDecBlk, self).__init__() 13 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 14 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 15 | self.relu_in = nn.ReLU(inplace=True) 16 | if config.dec_att == 'ASPP': 17 | self.dec_att = ASPP(in_channels=inter_channels) 18 | elif config.dec_att == 'ASPPDeformable': 19 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 20 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 21 | self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() 22 | self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 23 | 24 | def forward(self, x): 25 | x = self.conv_in(x) 26 | x = self.bn_in(x) 27 | x = self.relu_in(x) 28 | if hasattr(self, 'dec_att'): 29 | x = self.dec_att(x) 30 | x = self.conv_out(x) 31 | x = self.bn_out(x) 32 | return x 33 | 34 | 35 | class ResBlk(nn.Module): 36 | def __init__(self, in_channels=64, out_channels=None, inter_channels=64): 37 | super(ResBlk, self).__init__() 38 | if out_channels is None: 39 | out_channels = in_channels 40 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 41 | 42 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 43 | self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() 44 | self.relu_in = nn.ReLU(inplace=True) 45 | 46 | if config.dec_att == 'ASPP': 47 | self.dec_att = ASPP(in_channels=inter_channels) 48 | elif config.dec_att == 'ASPPDeformable': 49 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 50 | 51 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 52 | self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 53 | 54 | 
self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 55 | 56 | def forward(self, x): 57 | _x = self.conv_resi(x) 58 | x = self.conv_in(x) 59 | x = self.bn_in(x) 60 | x = self.relu_in(x) 61 | if hasattr(self, 'dec_att'): 62 | x = self.dec_att(x) 63 | x = self.conv_out(x) 64 | x = self.bn_out(x) 65 | return x + _x 66 | -------------------------------------------------------------------------------- /models/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.ops import deform_conv2d 4 | 5 | 6 | class DeformableConv2d(nn.Module): 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size=3, 11 | stride=1, 12 | padding=1, 13 | bias=False): 14 | 15 | super(DeformableConv2d, self).__init__() 16 | 17 | assert type(kernel_size) == tuple or type(kernel_size) == int 18 | 19 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) 20 | self.stride = stride if type(stride) == tuple else (stride, stride) 21 | self.padding = padding 22 | 23 | self.offset_conv = nn.Conv2d(in_channels, 24 | 2 * kernel_size[0] * kernel_size[1], 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=self.padding, 28 | bias=True) 29 | 30 | nn.init.constant_(self.offset_conv.weight, 0.) 31 | nn.init.constant_(self.offset_conv.bias, 0.) 32 | 33 | self.modulator_conv = nn.Conv2d(in_channels, 34 | 1 * kernel_size[0] * kernel_size[1], 35 | kernel_size=kernel_size, 36 | stride=stride, 37 | padding=self.padding, 38 | bias=True) 39 | 40 | nn.init.constant_(self.modulator_conv.weight, 0.) 41 | nn.init.constant_(self.modulator_conv.bias, 0.) 42 | 43 | self.regular_conv = nn.Conv2d(in_channels, 44 | out_channels=out_channels, 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=self.padding, 48 | bias=bias) 49 | 50 | def forward(self, x): 51 | #h, w = x.shape[2:] 52 | #max_offset = max(h, w)/4. 53 | 54 | offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) 55 | modulator = 2. 
* torch.sigmoid(self.modulator_conv(x)) 56 | 57 | x = deform_conv2d( 58 | input=x, 59 | offset=offset, 60 | weight=self.regular_conv.weight, 61 | bias=self.regular_conv.bias, 62 | padding=self.padding, 63 | mask=modulator, 64 | stride=self.stride, 65 | ) 66 | return x 67 | -------------------------------------------------------------------------------- /models/modules/lateral_blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from config import Config 8 | 9 | 10 | config = Config() 11 | 12 | 13 | class BasicLatBlk(nn.Module): 14 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 15 | super(BasicLatBlk, self).__init__() 16 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 17 | self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | -------------------------------------------------------------------------------- /models/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def build_act_layer(act_layer): 5 | if act_layer == 'ReLU': 6 | return nn.ReLU(inplace=True) 7 | elif act_layer == 'SiLU': 8 | return nn.SiLU(inplace=True) 9 | elif act_layer == 'GELU': 10 | return nn.GELU() 11 | 12 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 13 | 14 | 15 | def build_norm_layer(dim, 16 | norm_layer, 17 | in_format='channels_last', 18 | out_format='channels_last', 19 | eps=1e-6): 20 | layers = [] 21 | if norm_layer == 'BN': 22 | if in_format == 'channels_last': 23 | layers.append(to_channels_first()) 24 | layers.append(nn.BatchNorm2d(dim)) 25 | if out_format == 'channels_last': 26 | layers.append(to_channels_last()) 27 | elif norm_layer == 'LN': 28 | if in_format == 'channels_first': 29 | layers.append(to_channels_last()) 30 | layers.append(nn.LayerNorm(dim, eps=eps)) 31 | if out_format == 'channels_first': 32 | layers.append(to_channels_first()) 33 | else: 34 | raise NotImplementedError( 35 | f'build_norm_layer does not support {norm_layer}') 36 | return nn.Sequential(*layers) 37 | 38 | 39 | class to_channels_first(nn.Module): 40 | 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, x): 45 | return x.permute(0, 3, 1, 2) 46 | 47 | 48 | class to_channels_last(nn.Module): 49 | 50 | def __init__(self): 51 | super().__init__() 52 | 53 | def forward(self, x): 54 | return x.permute(0, 2, 3, 1) 55 | -------------------------------------------------------------------------------- /models/refinement/refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.models import vgg16, vgg16_bn 8 | from torchvision.models import resnet50 9 | 10 | from config import Config 11 | from dataset import class_labels_TR_sorted 12 | from models.backbones.build_backbone import build_backbone 13 | from models.modules.decoder_blocks import BasicDecBlk 14 | from models.modules.lateral_blocks import BasicLatBlk 15 | from models.refinement.stem_layer import StemLayer 16 | 17 | 18 | class RefinerPVTInChannels4(nn.Module): 19 | def __init__(self, in_channels=3+1): 20 | 
super(RefinerPVTInChannels4, self).__init__() 21 | self.config = Config() 22 | self.epoch = 1 23 | self.bb = build_backbone(self.config.bb, params_settings='in_channels=4') 24 | 25 | lateral_channels_in_collection = { 26 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 27 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 28 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 29 | } 30 | channels = lateral_channels_in_collection[self.config.bb] 31 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 32 | 33 | self.decoder = Decoder(channels) 34 | 35 | if 0: 36 | for key, value in self.named_parameters(): 37 | if 'bb.' in key: 38 | value.requires_grad = False 39 | 40 | def forward(self, x): 41 | if isinstance(x, list): 42 | x = torch.cat(x, dim=1) 43 | ########## Encoder ########## 44 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 45 | x1 = self.bb.conv1(x) 46 | x2 = self.bb.conv2(x1) 47 | x3 = self.bb.conv3(x2) 48 | x4 = self.bb.conv4(x3) 49 | else: 50 | x1, x2, x3, x4 = self.bb(x) 51 | 52 | x4 = self.squeeze_module(x4) 53 | 54 | ########## Decoder ########## 55 | 56 | features = [x, x1, x2, x3, x4] 57 | scaled_preds = self.decoder(features) 58 | 59 | return scaled_preds 60 | 61 | 62 | class Refiner(nn.Module): 63 | def __init__(self, in_channels=3+1): 64 | super(Refiner, self).__init__() 65 | self.config = Config() 66 | self.epoch = 1 67 | self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') 68 | self.bb = build_backbone(self.config.bb) 69 | 70 | lateral_channels_in_collection = { 71 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 72 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 73 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 74 | } 75 | channels = lateral_channels_in_collection[self.config.bb] 76 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 77 | 78 | self.decoder = Decoder(channels) 79 | 80 | if 0: 81 | for key, value in self.named_parameters(): 82 | if 'bb.' 
in key: 83 | value.requires_grad = False 84 | 85 | def forward(self, x): 86 | if isinstance(x, list): 87 | x = torch.cat(x, dim=1) 88 | x = self.stem_layer(x) 89 | ########## Encoder ########## 90 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 91 | x1 = self.bb.conv1(x) 92 | x2 = self.bb.conv2(x1) 93 | x3 = self.bb.conv3(x2) 94 | x4 = self.bb.conv4(x3) 95 | else: 96 | x1, x2, x3, x4 = self.bb(x) 97 | 98 | x4 = self.squeeze_module(x4) 99 | 100 | ########## Decoder ########## 101 | 102 | features = [x, x1, x2, x3, x4] 103 | scaled_preds = self.decoder(features) 104 | 105 | return scaled_preds 106 | 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, channels): 110 | super(Decoder, self).__init__() 111 | self.config = Config() 112 | DecoderBlock = eval('BasicDecBlk') 113 | LateralBlock = eval('BasicLatBlk') 114 | 115 | self.decoder_block4 = DecoderBlock(channels[0], channels[1]) 116 | self.decoder_block3 = DecoderBlock(channels[1], channels[2]) 117 | self.decoder_block2 = DecoderBlock(channels[2], channels[3]) 118 | self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2) 119 | 120 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) 121 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) 122 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) 123 | 124 | if self.config.ms_supervision: 125 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) 126 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) 127 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) 128 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0)) 129 | 130 | def forward(self, features): 131 | x, x1, x2, x3, x4 = features 132 | outs = [] 133 | p4 = self.decoder_block4(x4) 134 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) 135 | _p3 = _p4 + self.lateral_block4(x3) 136 | 137 | p3 = self.decoder_block3(_p3) 138 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) 139 | _p2 = _p3 + self.lateral_block3(x2) 140 | 141 | p2 = self.decoder_block2(_p2) 142 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) 143 | _p1 = _p2 + self.lateral_block2(x1) 144 | 145 | _p1 = self.decoder_block1(_p1) 146 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) 147 | p1_out = self.conv_out1(_p1) 148 | 149 | if self.config.ms_supervision: 150 | outs.append(self.conv_ms_spvn_4(p4)) 151 | outs.append(self.conv_ms_spvn_3(p3)) 152 | outs.append(self.conv_ms_spvn_2(p2)) 153 | outs.append(p1_out) 154 | return outs 155 | 156 | 157 | class RefUNet(nn.Module): 158 | # Refinement 159 | def __init__(self, in_channels=3+1): 160 | super(RefUNet, self).__init__() 161 | self.encoder_1 = nn.Sequential( 162 | nn.Conv2d(in_channels, 64, 3, 1, 1), 163 | nn.Conv2d(64, 64, 3, 1, 1), 164 | nn.BatchNorm2d(64), 165 | nn.ReLU(inplace=True) 166 | ) 167 | 168 | self.encoder_2 = nn.Sequential( 169 | nn.MaxPool2d(2, 2, ceil_mode=True), 170 | nn.Conv2d(64, 64, 3, 1, 1), 171 | nn.BatchNorm2d(64), 172 | nn.ReLU(inplace=True) 173 | ) 174 | 175 | self.encoder_3 = nn.Sequential( 176 | nn.MaxPool2d(2, 2, ceil_mode=True), 177 | nn.Conv2d(64, 64, 3, 1, 1), 178 | nn.BatchNorm2d(64), 179 | nn.ReLU(inplace=True) 180 | ) 181 | 182 | self.encoder_4 = nn.Sequential( 183 | nn.MaxPool2d(2, 2, ceil_mode=True), 184 | nn.Conv2d(64, 64, 3, 1, 1), 185 | nn.BatchNorm2d(64), 186 | nn.ReLU(inplace=True) 187 | ) 188 | 189 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 190 | ##### 191 | 
self.decoder_5 = nn.Sequential( 192 | nn.Conv2d(64, 64, 3, 1, 1), 193 | nn.BatchNorm2d(64), 194 | nn.ReLU(inplace=True) 195 | ) 196 | ##### 197 | self.decoder_4 = nn.Sequential( 198 | nn.Conv2d(128, 64, 3, 1, 1), 199 | nn.BatchNorm2d(64), 200 | nn.ReLU(inplace=True) 201 | ) 202 | 203 | self.decoder_3 = nn.Sequential( 204 | nn.Conv2d(128, 64, 3, 1, 1), 205 | nn.BatchNorm2d(64), 206 | nn.ReLU(inplace=True) 207 | ) 208 | 209 | self.decoder_2 = nn.Sequential( 210 | nn.Conv2d(128, 64, 3, 1, 1), 211 | nn.BatchNorm2d(64), 212 | nn.ReLU(inplace=True) 213 | ) 214 | 215 | self.decoder_1 = nn.Sequential( 216 | nn.Conv2d(128, 64, 3, 1, 1), 217 | nn.BatchNorm2d(64), 218 | nn.ReLU(inplace=True) 219 | ) 220 | 221 | self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1) 222 | 223 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 224 | 225 | def forward(self, x): 226 | outs = [] 227 | if isinstance(x, list): 228 | x = torch.cat(x, dim=1) 229 | hx = x 230 | 231 | hx1 = self.encoder_1(hx) 232 | hx2 = self.encoder_2(hx1) 233 | hx3 = self.encoder_3(hx2) 234 | hx4 = self.encoder_4(hx3) 235 | 236 | hx = self.decoder_5(self.pool4(hx4)) 237 | hx = torch.cat((self.upscore2(hx), hx4), 1) 238 | 239 | d4 = self.decoder_4(hx) 240 | hx = torch.cat((self.upscore2(d4), hx3), 1) 241 | 242 | d3 = self.decoder_3(hx) 243 | hx = torch.cat((self.upscore2(d3), hx2), 1) 244 | 245 | d2 = self.decoder_2(hx) 246 | hx = torch.cat((self.upscore2(d2), hx1), 1) 247 | 248 | d1 = self.decoder_1(hx) 249 | 250 | x = self.conv_d0(d1) 251 | outs.append(x) 252 | return outs 253 | -------------------------------------------------------------------------------- /models/refinement/stem_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.modules.utils import build_act_layer, build_norm_layer 3 | 4 | 5 | class StemLayer(nn.Module): 6 | r""" Stem layer of InternImage 7 | Args: 8 | in_channels (int): number of input channels 9 | out_channels (int): number of output channels 10 | act_layer (str): activation layer 11 | norm_layer (str): normalization layer 12 | """ 13 | 14 | def __init__(self, 15 | in_channels=3+1, 16 | inter_channels=48, 17 | out_channels=96, 18 | act_layer='GELU', 19 | norm_layer='BN'): 20 | super().__init__() 21 | self.conv1 = nn.Conv2d(in_channels, 22 | inter_channels, 23 | kernel_size=3, 24 | stride=1, 25 | padding=1) 26 | self.norm1 = build_norm_layer( 27 | inter_channels, norm_layer, 'channels_first', 'channels_first' 28 | ) 29 | self.act = build_act_layer(act_layer) 30 | self.conv2 = nn.Conv2d(inter_channels, 31 | out_channels, 32 | kernel_size=3, 33 | stride=1, 34 | padding=1) 35 | self.norm2 = build_norm_layer( 36 | out_channels, norm_layer, 'channels_first', 'channels_first' 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.conv1(x) 41 | x = self.norm1(x) 42 | x = self.act(x) 43 | x = self.conv2(x) 44 | x = self.norm2(x) 45 | return x 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.5.0 2 | torchvision 3 | numpy<2 4 | opencv-python 5 | timm 6 | scipy 7 | scikit-image 8 | kornia 9 | einops 10 | 11 | tqdm 12 | prettytable 13 | tabulate 14 | 15 | huggingface-hub>0.25 16 | accelerate 17 | -------------------------------------------------------------------------------- /rm_cache.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 
rm -rf __pycache__ */__pycache__ */*/__pycache__ 3 | 4 | # Val 5 | rm -r tmp* 6 | 7 | # Train 8 | rm slurm* 9 | rm -r ckpt 10 | rm nohup.out* 11 | rm nohup.log* 12 | 13 | # Eval 14 | rm -r evaluation/eval-* 15 | rm -r tmp* 16 | rm -r e_logs/ 17 | 18 | # System 19 | rm core-*-python-* 20 | 21 | # Inference cache 22 | rm -rf images_todo/ 23 | rm -rf predictions/ 24 | 25 | clear 26 | -------------------------------------------------------------------------------- /sub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Example: ./sub.sh tmp_proj 0,1,2,3 3 --> Use 0,1,2,3 for training, release GPUs, use GPU:3 for inference. 3 | 4 | # module load gcc/11.2.0 cuda/11.8 cudnn/8.6.0_cu11x && cpu_core_num=6 5 | module load compilers/cuda/11.8 compilers/gcc/12.2.0 cudnn/8.4.0.27_cuda11.x && cpu_core_num=32 6 | 7 | export PYTHONUNBUFFERED=1 8 | 9 | method=${1:-"BSL"} 10 | devices=${2:-"0,1"} 11 | gpu_num=$(($(echo ${devices%%,} | grep -o "," | wc -l)+1)) 12 | 13 | sbatch --nodes=1 -p vip_gpu_ailab -A ai4bio \ 14 | --gres=gpu:${gpu_num} --ntasks-per-node=1 --cpus-per-task=$((gpu_num*cpu_core_num)) \ 15 | ./train_test.sh ${method} ${devices} 16 | 17 | hostname 18 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | devices=${1:-0} 2 | pred_root=${2:-e_preds} 3 | resolutions=${3:-"config.size"} 4 | 5 | # Inference 6 | # resolutions="1024x1024 None" 7 | for resolution in ${resolutions}; do 8 | CUDA_VISIBLE_DEVICES=${devices} python inference.py --pred_root ${pred_root} --resolution ${resolution} 9 | done 10 | 11 | echo Inference finished at $(date) 12 | 13 | # Evaluation 14 | log_dir=e_logs && mkdir ${log_dir} 15 | 16 | task=$(python3 config.py --print_task) 17 | testsets=$(python3 config.py --print_testsets) 18 | 19 | testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]} 20 | 21 | for testset in ${testsets}; do 22 | # python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} > ${log_dir}/eval_${testset}.out 23 | nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} > ${log_dir}/eval_${testset}.out 2>&1 & 24 | done 25 | 26 | 27 | echo Evaluation started at $(date) 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | from contextlib import nullcontext 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | if tuple(map(int, torch.__version__.split('+')[0].split(".")[:3])) >= (2, 5, 0): 9 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' 10 | 11 | from config import Config 12 | from loss import PixLoss, ClsLoss 13 | from dataset import MyData 14 | from models.birefnet import BiRefNet, BiRefNetC2F 15 | from utils import Logger, AverageMeter, set_seed, check_state_dict 16 | 17 | from torch.utils.data.distributed import DistributedSampler 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from torch.distributed import init_process_group, destroy_process_group 20 | 21 | 22 | parser = argparse.ArgumentParser(description='') 23 | parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint') 24 | parser.add_argument('--epochs', default=120, type=int) 25 | parser.add_argument('--ckpt_dir', default='ckpt/tmp', 
help='Temporary folder') 26 | parser.add_argument('--dist', default=False, type=lambda x: x == 'True') 27 | parser.add_argument('--use_accelerate', action='store_true', help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...') 28 | args = parser.parse_args() 29 | 30 | config = Config() 31 | 32 | if args.use_accelerate: 33 | from accelerate import Accelerator, utils 34 | mixed_precision = config.mixed_precision 35 | accelerator = Accelerator( 36 | mixed_precision=mixed_precision, 37 | gradient_accumulation_steps=1, 38 | kwargs_handlers=[ 39 | utils.InitProcessGroupKwargs(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)), 40 | utils.DistributedDataParallelKwargs(find_unused_parameters=False), 41 | utils.GradScalerKwargs(backoff_factor=0.5)], 42 | ) 43 | args.dist = False 44 | 45 | # DDP 46 | to_be_distributed = args.dist 47 | if to_be_distributed: 48 | init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)) 49 | device = int(os.environ["LOCAL_RANK"]) 50 | else: 51 | if args.use_accelerate: 52 | device = accelerator.local_process_index 53 | else: 54 | device = config.device 55 | 56 | if config.rand_seed: 57 | set_seed(config.rand_seed + device) 58 | 59 | epoch_st = 1 60 | # make dir for ckpt 61 | os.makedirs(args.ckpt_dir, exist_ok=True) 62 | 63 | # Init log file 64 | logger = Logger(os.path.join(args.ckpt_dir, "log.txt")) 65 | logger_loss_idx = 1 66 | 67 | # log model and optimizer params 68 | # logger.info("Model details:"); logger.info(model) 69 | # if args.use_accelerate and accelerator.mixed_precision != 'no': 70 | # config.compile = False 71 | logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile)) 72 | logger.info("Other hyperparameters:"); logger.info(args) 73 | print('batch size:', config.batch_size) 74 | 75 | from dataset import custom_collate_fn 76 | 77 | def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True): 78 | # Prepare dataloaders 79 | if to_be_distributed: 80 | return torch.utils.data.DataLoader( 81 | dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True, 82 | shuffle=False, sampler=DistributedSampler(dataset), drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None 83 | ) 84 | else: 85 | return torch.utils.data.DataLoader( 86 | dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True, 87 | shuffle=is_train, sampler=None, drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None 88 | ) 89 | 90 | 91 | def init_data_loaders(to_be_distributed): 92 | # Prepare datasets 93 | train_loader = prepare_dataloader( 94 | MyData(datasets=config.training_set, data_size=None if config.dynamic_size else config.size, is_train=True), 95 | config.batch_size, to_be_distributed=to_be_distributed, is_train=True 96 | ) 97 | print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set)) 98 | return train_loader 99 | 100 | 101 | def init_models_optimizers(epochs, to_be_distributed): 102 | # Init models 103 | if config.model == 'BiRefNet': 104 | model = BiRefNet(bb_pretrained=True and not os.path.isfile(str(args.resume))) 105 | elif config.model == 'BiRefNetC2F': 106 | model = BiRefNetC2F(bb_pretrained=True and not os.path.isfile(str(args.resume))) 107 | if args.resume: 108 | if os.path.isfile(args.resume): 
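            # NOTE (added annotation, not in the upstream train.py): resuming assumes checkpoints
            # named like 'epoch_<N>.pth'. The weights are loaded on CPU with weights_only=True,
            # check_state_dict() from utils.py presumably remaps wrapper-prefixed keys, and the
            # starting epoch is parsed from the filename, e.g.
            #   int('ckpt/tmp/epoch_37.pth'.rstrip('.pth').split('epoch_')[-1]) + 1  ->  38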
109 | logger.info("=> loading checkpoint '{}'".format(args.resume)) 110 | state_dict = torch.load(args.resume, map_location='cpu', weights_only=True) 111 | state_dict = check_state_dict(state_dict) 112 | model.load_state_dict(state_dict) 113 | global epoch_st 114 | epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1 115 | else: 116 | logger.info("=> no checkpoint found at '{}'".format(args.resume)) 117 | if not args.use_accelerate: 118 | if to_be_distributed: 119 | model = model.to(device) 120 | model = DDP(model, device_ids=[device]) 121 | else: 122 | model = model.to(device) 123 | if config.compile: 124 | model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0]) 125 | if config.precisionHigh: 126 | torch.set_float32_matmul_precision('high') 127 | 128 | # Setting optimizer 129 | if config.optimizer == 'AdamW': 130 | optimizer = optim.AdamW(params=model.parameters(), lr=config.lr, weight_decay=1e-2) 131 | elif config.optimizer == 'Adam': 132 | optimizer = optim.Adam(params=model.parameters(), lr=config.lr, weight_decay=0) 133 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 134 | optimizer, 135 | milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs], 136 | gamma=config.lr_decay_rate 137 | ) 138 | # logger.info("Optimizer details:"); logger.info(optimizer) 139 | 140 | return model, optimizer, lr_scheduler 141 | 142 | 143 | class Trainer: 144 | def __init__( 145 | self, data_loaders, model_opt_lrsch, 146 | ): 147 | self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch 148 | self.train_loader = data_loaders 149 | if args.use_accelerate: 150 | self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer) 151 | if config.out_ref: 152 | self.criterion_gdt = nn.BCELoss() 153 | 154 | # Setting Losses 155 | self.pix_loss = PixLoss() 156 | self.cls_loss = ClsLoss() 157 | 158 | # Others 159 | self.loss_log = AverageMeter() 160 | 161 | def _train_batch(self, batch): 162 | if args.use_accelerate: 163 | inputs = batch[0]#.to(device) 164 | gts = batch[1]#.to(device) 165 | class_labels = batch[2]#.to(device) 166 | else: 167 | inputs = batch[0].to(device) 168 | gts = batch[1].to(device) 169 | class_labels = batch[2].to(device) 170 | self.optimizer.zero_grad() 171 | scaled_preds, class_preds_lst = self.model(inputs) 172 | if config.out_ref: 173 | (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds 174 | for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)): 175 | _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid() 176 | _gdt_label = _gdt_label.sigmoid() 177 | loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt 178 | # self.loss_dict['loss_gdt'] = loss_gdt.item() 179 | if None in class_preds_lst: 180 | loss_cls = 0. 
181 | else: 182 | loss_cls = self.cls_loss(class_preds_lst, class_labels) 183 | self.loss_dict['loss_cls'] = loss_cls.item() 184 | 185 | # Loss 186 | loss_pix, loss_dict_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1), pix_loss_lambda=1.0) 187 | self.loss_dict.update(loss_dict_pix) 188 | self.loss_dict['loss_pix'] = loss_pix.item() 189 | # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py 190 | loss = loss_pix + loss_cls 191 | if config.out_ref: 192 | loss = loss + loss_gdt * 1.0 193 | 194 | self.loss_log.update(loss.item(), inputs.size(0)) 195 | if args.use_accelerate: 196 | loss = loss / accelerator.gradient_accumulation_steps 197 | accelerator.backward(loss) 198 | else: 199 | loss.backward() 200 | self.optimizer.step() 201 | 202 | def train_epoch(self, epoch): 203 | global logger_loss_idx 204 | self.model.train() 205 | self.loss_dict = {} 206 | if epoch > args.epochs + config.finetune_last_epochs: 207 | if config.task == 'Matting': 208 | self.pix_loss.lambdas_pix_last['mae'] *= 1 209 | self.pix_loss.lambdas_pix_last['mse'] *= 0.9 210 | self.pix_loss.lambdas_pix_last['ssim'] *= 0.9 211 | else: 212 | self.pix_loss.lambdas_pix_last['bce'] *= 0 213 | self.pix_loss.lambdas_pix_last['ssim'] *= 1 214 | self.pix_loss.lambdas_pix_last['iou'] *= 0.5 215 | self.pix_loss.lambdas_pix_last['mae'] *= 0.9 216 | 217 | for batch_idx, batch in enumerate(self.train_loader): 218 | # with nullcontext if not args.use_accelerate or accelerator.gradient_accumulation_steps <= 1 else accelerator.accumulate(self.model): 219 | self._train_batch(batch) 220 | # Logger 221 | if (epoch < 2 and batch_idx < 100 and batch_idx % 20 == 0) or batch_idx % max(100, len(self.train_loader) / 100 // 100 * 100) == 0: 222 | info_progress = f'Epoch[{epoch}/{args.epochs}] Iter[{batch_idx}/{len(self.train_loader)}].' 223 | info_loss = 'Training Losses:' 224 | for loss_name, loss_value in self.loss_dict.items(): 225 | info_loss += f' {loss_name}: {loss_value:.5g} |' 226 | logger.info(' '.join((info_progress, info_loss))) 227 | info_loss = f'@==Final== Epoch[{epoch}/{args.epochs}] Training Loss: {self.loss_log.avg:.5g} ' 228 | logger.info(info_loss) 229 | 230 | self.lr_scheduler.step() 231 | return self.loss_log.avg 232 | 233 | 234 | def main(): 235 | 236 | trainer = Trainer( 237 | data_loaders=init_data_loaders(to_be_distributed), 238 | model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed) 239 | ) 240 | 241 | for epoch in range(epoch_st, args.epochs+1): 242 | train_loss = trainer.train_epoch(epoch) 243 | # Save checkpoint 244 | if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0: 245 | if args.use_accelerate: 246 | state_dict = trainer.model.state_dict() 247 | else: 248 | state_dict = trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict() 249 | torch.save(state_dict, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))) 250 | if to_be_distributed: 251 | destroy_process_group() 252 | 253 | 254 | if __name__ == '__main__': 255 | main() 256 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run script 3 | # Settings of training & test for different tasks. 
4 | method="$1" 5 | task=$(python3 config.py --print_task) 6 | case "${task}" in 7 | 'DIS5K') epochs=500 && val_last=50 && step=5 ;; 8 | 'COD') epochs=150 && val_last=50 && step=5 ;; 9 | 'HRSOD') epochs=150 && val_last=50 && step=5 ;; 10 | 'General') epochs=200 && val_last=50 && step=5 ;; 11 | 'General-2K') epochs=250 && val_last=30 && step=2 ;; 12 | 'Matting') epochs=150 && val_last=50 && step=5 ;; 13 | esac 14 | 15 | # Train 16 | devices=$2 17 | nproc_per_node=$(echo ${devices%%,} | grep -o "," | wc -l) 18 | 19 | to_be_distributed=`echo ${nproc_per_node} | awk '{if($1 > 0) print "True"; else print "False";}'` 20 | 21 | echo Training started at $(date) 22 | resume_weights_path='path_to_a_pth' 23 | if [ ${to_be_distributed} == "True" ] 24 | then 25 | # Adapt nproc_per_node to the number of GPUs in ${devices}; with --standalone, torchrun picks the rendezvous port itself. 26 | echo "Multi-GPU mode received..." 27 | CUDA_VISIBLE_DEVICES=${devices} \ 28 | torchrun --standalone --nproc_per_node $((nproc_per_node+1)) \ 29 | train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \ 30 | --dist ${to_be_distributed} \ 31 | --resume ${resume_weights_path} \ 32 | --use_accelerate 33 | else 34 | echo "Single-GPU mode received..." 35 | CUDA_VISIBLE_DEVICES=${devices} \ 36 | python train.py --ckpt_dir ckpt/${method} --epochs ${epochs} \ 37 | --dist ${to_be_distributed} \ 38 | --resume ${resume_weights_path} \ 39 | --use_accelerate 40 | fi 41 | 42 | echo Training finished at $(date) 43 | -------------------------------------------------------------------------------- /train_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Example: `setsid nohup ./train_test.sh BiRefNet 0,1,2,3,4,5,6,7 0 &>nohup.log &` 3 | 4 | method=${1:-"BSL"} 5 | devices=${2:-"0,1,2,3,4,5,6,7"} 6 | 7 | bash train.sh ${method} ${devices} 8 | 9 | devices_test=${3:-0} 10 | bash test.sh ${devices_test} 11 | 12 | hostname 13 | -------------------------------------------------------------------------------- /tutorials/BiRefNet_pth2onnx.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "LTj2A0RUQFNo" 7 | }, 8 | "source": [ 9 | "# Convert our BiRefNet weights to ONNX format.\n", 10 | "\n", 11 | "> This Colab file is modified from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.\n", 12 | "\n", 13 | "> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample \n", 14 | "> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb\n", 15 | "\n", 16 | "+ Converting a standard BiRefNet on GPU requires **19.7GB** of GPU memory.\n", 17 | "+ Currently, Colab with 12.7GB RAM / 15GB GPU memory cannot handle the conversion of BiRefNet in its default setting, so BiRefNet with the swin_v1_tiny backbone is used as the example on Colab."
18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", 37 | "\u001b[0m" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "!pip install -q onnx onnxscript onnxruntime-gpu==1.18.1" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "/root/autodl-tmp/BiRefNet\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "cd .." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": { 66 | "id": "781JHjLJmveh" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "import torch\n", 71 | "\n", 72 | "\n", 73 | "weights_file = 'BiRefNet-matting-epoch_100.pth' # https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth\n", 74 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "with open('config.py') as fp:\n", 84 | " file_lines = fp.read()\n", 85 | "if 'swin_v1_tiny' in weights_file:\n", 86 | " print('Set `swin_v1_tiny` as the backbone.')\n", 87 | " file_lines = file_lines.replace(\n", 88 | " '''\n", 89 | " 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n", 90 | " ][6]\n", 91 | " ''',\n", 92 | " '''\n", 93 | " 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n", 94 | " ][3]\n", 95 | " ''',\n", 96 | " )\n", 97 | " with open('config.py', mode=\"w\") as fp:\n", 98 | " fp.write(file_lines)\n", 99 | "else:\n", 100 | " file_lines = file_lines.replace(\n", 101 | " '''\n", 102 | " 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n", 103 | " ][3]\n", 104 | " ''',\n", 105 | " '''\n", 106 | " 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5\n", 107 | " ][6]\n", 108 | " ''',\n", 109 | " )\n", 110 | " with open('config.py', mode=\"w\") as fp:\n", 111 | " fp.write(file_lines)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 5, 117 | "metadata": { 118 | "id": "7lFgKfPS8Icy" 119 | }, 120 | "outputs": [ 121 | { 122 | "name": "stderr", 123 | "output_type": "stream", 124 | "text": [ 125 | "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 126 | " from .autonotebook import tqdm as notebook_tqdm\n", 127 | "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", 128 | " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", 129 | "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", 130 | " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "from utils import check_state_dict\n", 136 | "from models.birefnet import BiRefNet\n", 137 | "\n", 138 | "\n", 139 | "birefnet = BiRefNet(bb_pretrained=False)\n", 140 | "state_dict = torch.load('./{}'.format(weights_file), map_location=device, weights_only=True)\n", 141 | "state_dict = check_state_dict(state_dict)\n", 142 | "birefnet.load_state_dict(state_dict)\n", 143 | "\n", 144 | "torch.set_float32_matmul_precision(['high', 'highest'][0])\n", 145 | "\n", 146 | "birefnet.to(device)\n", 147 | "_ = birefnet.eval()" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": { 153 | "id": "JVgJAdgxQVJW" 154 | }, 155 | "source": [ 156 | "# Process deform_conv2d in the conversion to ONNX" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Cloning into 'deform_conv2d_onnx_exporter'...\n", 169 | "remote: Enumerating objects: 205, done.\u001b[K\n", 170 | "remote: Counting objects: 100% (7/7), done.\u001b[K\n", 171 | "remote: Total 205 (delta 6), reused 6 (delta 6), pack-reused 198 (from 1)\u001b[K\n", 172 | "Receiving objects: 100% (205/205), 36.21 KiB | 170.00 KiB/s, done.\n", 173 | "Resolving deltas: 100% (102/102), done.\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter\n", 179 | "%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .\n", 180 | "!rm -rf deform_conv2d_onnx_exporter" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 7, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "with open('deform_conv2d_onnx_exporter.py') as fp:\n", 190 | " file_lines = fp.read()\n", 191 | "\n", 192 | "file_lines = file_lines.replace(\n", 193 | " \"return sym_help._get_tensor_dim_size(tensor, dim)\",\n", 194 | " '''\n", 195 | " tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)\n", 196 | " if tensor_dim_size == None and (dim == 2 or dim == 3):\n", 197 | " import typing\n", 198 | " from torch import _C\n", 199 | "\n", 200 | " x_type = typing.cast(_C.TensorType, tensor.type())\n", 201 | " x_strides = x_type.strides()\n", 202 | "\n", 203 | " tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]\n", 204 | " elif tensor_dim_size == None and (dim == 0):\n", 205 | " import typing\n", 206 | " from torch import _C\n", 207 | "\n", 208 | " x_type = typing.cast(_C.TensorType, tensor.type())\n", 209 | " x_strides = x_type.strides()\n", 210 | " tensor_dim_size = x_strides[3]\n", 211 | "\n", 212 | " return tensor_dim_size\n", 213 | " ''',\n", 214 | ")\n", 215 | "\n", 216 | "with 
open('deform_conv2d_onnx_exporter.py', mode=\"w\") as fp:\n", 217 | " fp.write(file_lines)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "metadata": { 224 | "id": "vJiZv0L75kTe" 225 | }, 226 | "outputs": [ 227 | { 228 | "name": "stderr", 229 | "output_type": "stream", 230 | "text": [ 231 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:441: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 232 | " if W % self.patch_size[1] != 0:\n", 233 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:443: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 234 | " if H % self.patch_size[0] != 0:\n", 235 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:379: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 236 | " Hp = int(np.ceil(H / self.window_size)) * self.window_size\n", 237 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:380: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 238 | " Wp = int(np.ceil(W / self.window_size)) * self.window_size\n", 239 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:216: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 240 | " assert L == H * W, \"input feature has wrong size\"\n", 241 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:67: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 242 | " B = int(windows.shape[0] / (H * W / window_size / window_size))\n", 243 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:254: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 244 | " if pad_r > 0 or pad_b > 0:\n", 245 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:287: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs!\n", 246 | " assert L == H * W, \"input feature has wrong size\"\n", 247 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:292: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 248 | " pad_input = (H % 2 == 1) or (W % 2 == 1)\n", 249 | "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:293: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", 250 | " if pad_input:\n", 251 | "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n", 252 | " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" 253 | ] 254 | }, 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", 260 | "verbose: False, log level: Level.ERROR\n", 261 | "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", 262 | "\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "from torchvision.ops.deform_conv import DeformConv2d\n", 268 | "import deform_conv2d_onnx_exporter\n", 269 | "\n", 270 | "# register deform_conv2d operator\n", 271 | "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n", 272 | "\n", 273 | "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n", 274 | " input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n", 275 | "\n", 276 | " input_layer_names = ['input_image']\n", 277 | " output_layer_names = ['output_image']\n", 278 | "\n", 279 | " torch.onnx.export(\n", 280 | " net,\n", 281 | " input,\n", 282 | " file_name,\n", 283 | " verbose=False,\n", 284 | " opset_version=17,\n", 285 | " input_names=input_layer_names,\n", 286 | " output_names=output_layer_names,\n", 287 | " )\n", 288 | "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": { 294 | "id": "-eU-g40P1zS-" 295 | }, 296 | "source": [ 297 | "# Load ONNX weights and do the inference." 
298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "from PIL import Image\n", 307 | "from torchvision import transforms\n", 308 | "\n", 309 | "\n", 310 | "transform_image = transforms.Compose([\n", 311 | " transforms.Resize((1024, 1024)),\n", 312 | " transforms.ToTensor(),\n", 313 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 314 | "])\n", 315 | "\n", 316 | "imagepath = './Helicopter-HR.jpg'\n", 317 | "image = Image.open(imagepath)\n", 318 | "image = image.convert(\"RGB\") if image.mode != \"RGB\" else image\n", 319 | "input_images = transform_image(image).unsqueeze(0).to(device)\n", 320 | "input_images_numpy = input_images.cpu().numpy()" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "id": "rwzdKX1EfYkd" 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "import onnxruntime\n", 332 | "import matplotlib.pyplot as plt\n", 333 | "\n", 334 | "\n", 335 | "providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n", 336 | "onnx_session = onnxruntime.InferenceSession(\n", 337 | " weights_file.replace('.pth', '.onnx'),\n", 338 | " providers=providers\n", 339 | ")\n", 340 | "input_name = onnx_session.get_inputs()[0].name\n", 341 | "print(onnxruntime.get_device(), onnx_session.get_providers())" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "id": "DJVtxZUZum4-" 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "from time import time\n", 353 | "import matplotlib.pyplot as plt\n", 354 | "\n", 355 | "time_st = time()\n", 356 | "pred_onnx = torch.tensor(\n", 357 | " onnx_session.run(None, {input_name: input_images_numpy if device == 'cpu' else input_images_numpy})[-1]\n", 358 | ").squeeze(0).sigmoid().cpu()\n", 359 | "print(time() - time_st)\n", 360 | "\n", 361 | "plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "with torch.no_grad():\n", 371 | " preds = birefnet(input_images)[-1].sigmoid().to(torch.float32).cpu()\n", 372 | "plt.imshow(preds.squeeze(), cmap='gray'); plt.show()" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "diff = abs(preds - pred_onnx)\n", 382 | "print('sum(diff):', diff.sum())\n", 383 | "plt.imshow((diff).squeeze(), cmap='gray'); plt.show()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": { 389 | "id": "qzYHflt92Bjd" 390 | }, 391 | "source": [ 392 | "# Efficiency Comparison between .pth and .onnx" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "colab": { 400 | "base_uri": "https://localhost:8080/" 401 | }, 402 | "id": "A5IYfT-uzphA", 403 | "outputId": "2999e345-950e-41b3-ddd3-9f58a71a3f21" 404 | }, 405 | "outputs": [], 406 | "source": [ 407 | "%%timeit\n", 408 | "with torch.no_grad():\n", 409 | " preds = birefnet(input_images)[-1].sigmoid().to(torch.float32).cpu()" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": { 416 | "id": "G0Ul4rfNg1za" 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | "%%timeit\n", 421 | "pred_onnx = torch.tensor(\n", 422 | " onnx_session.run(None, {input_name: 
input_images_numpy})[-1]\n", 423 | ").squeeze(0).sigmoid().cpu()" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [] 432 | } 433 | ], 434 | "metadata": { 435 | "accelerator": "GPU", 436 | "colab": { 437 | "gpuType": "T4", 438 | "provenance": [] 439 | }, 440 | "kernelspec": { 441 | "display_name": "Python 3 (ipykernel)", 442 | "language": "python", 443 | "name": "python3" 444 | }, 445 | "language_info": { 446 | "codemirror_mode": { 447 | "name": "ipython", 448 | "version": 3 449 | }, 450 | "file_extension": ".py", 451 | "mimetype": "text/x-python", 452 | "name": "python", 453 | "nbconvert_exporter": "python", 454 | "pygments_lexer": "ipython3", 455 | "version": "3.10.16" 456 | } 457 | }, 458 | "nbformat": 4, 459 | "nbformat_minor": 4 460 | } 461 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | import random 7 | import cv2 8 | from PIL import Image 9 | 10 | 11 | def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]): 12 | if color_type.lower() == 'rgb': 13 | image = cv2.imread(path) 14 | elif color_type.lower() == 'gray': 15 | image = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 16 | else: 17 | print('Select the color_type to return, either to RGB or gray image.') 18 | return 19 | if size: 20 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 21 | if color_type.lower() == 'rgb': 22 | image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB') 23 | else: 24 | image = Image.fromarray(image).convert('L') 25 | return image 26 | 27 | 28 | 29 | def check_state_dict(state_dict, unwanted_prefixes=['module.', '_orig_mod.']): 30 | for k, v in list(state_dict.items()): 31 | prefix_length = 0 32 | for unwanted_prefix in unwanted_prefixes: 33 | if k[prefix_length:].startswith(unwanted_prefix): 34 | prefix_length += len(unwanted_prefix) 35 | state_dict[k[prefix_length:]] = state_dict.pop(k) 36 | return state_dict 37 | 38 | 39 | def generate_smoothed_gt(gts): 40 | epsilon = 0.001 41 | new_gts = (1-epsilon)*gts+epsilon/2 42 | return new_gts 43 | 44 | 45 | class Logger(): 46 | def __init__(self, path="log.txt"): 47 | self.logger = logging.getLogger('BiRefNet') 48 | self.file_handler = logging.FileHandler(path, "w") 49 | self.stdout_handler = logging.StreamHandler() 50 | self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 51 | self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 52 | self.logger.addHandler(self.file_handler) 53 | self.logger.addHandler(self.stdout_handler) 54 | self.logger.setLevel(logging.INFO) 55 | self.logger.propagate = False 56 | 57 | def info(self, txt): 58 | self.logger.info(txt) 59 | 60 | def close(self): 61 | self.file_handler.close() 62 | self.stdout_handler.close() 63 | 64 | 65 | class AverageMeter(object): 66 | """Computes and stores the average and current value""" 67 | def __init__(self): 68 | self.reset() 69 | 70 | def reset(self): 71 | self.val = 0.0 72 | self.avg = 0.0 73 | self.sum = 0.0 74 | self.count = 0.0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self.avg = self.sum / self.count 81 | 82 | 83 | def save_checkpoint(state, path, filename="latest.pth"): 84 
| torch.save(state, os.path.join(path, filename)) 85 | 86 | 87 | def save_tensor_img(tensor_im, path): 88 | im = tensor_im.cpu().clone() 89 | im = im.squeeze(0) 90 | tensor2pil = transforms.ToPILImage() 91 | im = tensor2pil(im) 92 | im.save(path) 93 | 94 | 95 | def set_seed(seed): 96 | torch.manual_seed(seed) 97 | torch.cuda.manual_seed_all(seed) 98 | np.random.seed(seed) 99 | random.seed(seed) 100 | torch.backends.cudnn.deterministic = True 101 | --------------------------------------------------------------------------------
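A minimal usage sketch (not a file in this repository) tying the utilities above together: it loads a locally trained checkpoint with `check_state_dict`, runs BiRefNet on one image with the same 1024x1024 preprocessing used in the tutorials, and writes the prediction with `save_tensor_img`. The checkpoint path `ckpt/BiRefNet/epoch_100.pth` and the image `example.jpg` are placeholders.

```python
import torch
from PIL import Image
from torchvision import transforms

from models.birefnet import BiRefNet
from utils import check_state_dict, save_tensor_img

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the model without downloading backbone weights, then load a local checkpoint.
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load('ckpt/BiRefNet/epoch_100.pth', map_location='cpu', weights_only=True)
birefnet.load_state_dict(check_state_dict(state_dict))  # strip 'module.' / '_orig_mod.' prefixes
birefnet.to(device).eval()

# Same 1024x1024 preprocessing as in the tutorials.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open('example.jpg').convert('RGB')   # placeholder input image
inputs = transform_image(image).unsqueeze(0).to(device)

with torch.no_grad():
    pred = birefnet(inputs)[-1].sigmoid().cpu()    # last output is the final map

save_tensor_img(pred, 'example_pred.png')          # saved as a grayscale mask
```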