├── README.md
├── assets
│   └── ops
│       └── dcn
│           ├── __init__.py
│           ├── functions
│           │   ├── __init__.py
│           │   ├── deform_conv.py
│           │   └── deform_pool.py
│           ├── modules
│           │   ├── __init__.py
│           │   ├── deform_conv.py
│           │   └── deform_pool.py
│           ├── setup.py
│           └── src
│               ├── deform_conv_cuda.cpp
│               ├── deform_conv_cuda_kernel.cu
│               ├── deform_pool_cuda.cpp
│               └── deform_pool_cuda_kernel.cu
├── dataset.py
├── db_model
│   ├── db_embedding.py
│   ├── model_v3.py
│   ├── representers
│   │   ├── __init__.py
│   │   └── seg_detector_representer.py
│   └── resnet.py
├── demo_textboxPP.py
├── requirements.txt
├── scm
│   ├── __init__.py
│   ├── datasets
│   │   ├── gen_json.py
│   │   ├── par_crop.py
│   │   └── scm_dataset.py
│   ├── experiments
│   │   └── siammask_sharp
│   │       ├── config_icdar.json
│   │       ├── custom.py
│   │       └── resnet.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── features.py
│   │   ├── mask.py
│   │   ├── rpn.py
│   │   └── siammask_sharp.py
│   ├── tools
│   │   └── track2mask.py
│   └── utils
│       ├── __init__.py
│       ├── anchors.py
│       ├── bbox_helper.py
│       ├── load_helper.py
│       ├── lr_helper.py
│       └── tracker_config.py
├── track_textboxPP.py
├── tracker
│   ├── __init__.py
│   ├── basetrack.py
│   ├── db_text_multitracker.py
│   ├── kalman_filter.py
│   └── matching.py
├── train_embedding.py
├── train_scm.py
└── utils
    ├── __init__.py
    ├── log.py
    ├── meters.py
    ├── parse_config.py
    ├── timer.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Introduction
4 | This is a PyTorch implementation of [Video Text Tracking With a Spatio-Temporal Complementary Model](https://arxiv.org/abs/2111.04987).
5 | 
6 | Part of the code is inherited from [DB](https://github.com/MhLiao/DB) and [SiamMask](https://github.com/foolwood/SiamMask).
7 | ## ToDo List
8 | 
9 | - [x] Release code
10 | - [x] Documentation for installation
11 | - [x] Documentation for training and testing
12 | 
13 | 
14 | 
15 | ## Installation
16 | 
17 | ### Requirements:
18 | - Python 3.6
19 | - PyTorch >= 1.2
20 | - GCC 5.5
21 | - CUDA 9.2
22 | 
23 | 
24 | ```bash
25 | 
26 | conda create --name scm python=3.6
27 | conda activate scm
28 | 
29 | # install PyTorch with cuda-9.2
30 | conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=9.2 -c pytorch
31 | 
32 | # clone repo
33 | git clone https://github.com/lsabrinax/VideoTextSCM
34 | cd VideoTextSCM/
35 | 
36 | # python dependencies
37 | pip install -r requirements.txt
38 | 
39 | # build the deformable convolution operator
40 | cd assets/ops/dcn/
41 | python setup.py build_ext --inplace
42 | ```
43 | 
44 | 
45 | 
46 | ## Datasets
47 | The root of the dataset directory can be `VideoTextSCM/datasets/`.
48 | Download the converted ground truth and data lists from [Baidu Drive](https://pan.baidu.com/s/1-r084b6l58Rhe__1SCBo6Q) (download code: 0e8b) or [Google Drive](https://drive.google.com/drive/folders/13GkcaSLsXxTCbuFwUAHvBfbB6DB-5Fwq?usp=sharing). The images of each dataset can be obtained from the official websites.
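For reference, one possible layout under the repo root is sketched below. The sub-directory and file names here are only an example and are not mandated by the code; pass the actual paths to the scripts, e.g. via `--input-root` when testing.

```bash
# Example only: directory names are illustrative, not required by the repo.
cd VideoTextSCM
mkdir -p datasets/icdar
# converted ground truth and data lists from the drive links above (assumed download location)
mv ~/Downloads/icdar_gts datasets/icdar/gts
mv ~/Downloads/train_list.txt datasets/icdar/
# frames obtained from the official website (assumed download location)
mv ~/Downloads/icdar_frames datasets/icdar/images
```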
49 | 
50 | 
51 | ## Testing
52 | Run the command below to get the tracking results, then submit them to the official website to get the performance numbers.
53 | 
54 | ```CUDA_VISIBLE_DEVICES=0 python demo_textboxPP.py --input-root path-to-test-dataset --output-root path-to-save-result --sub-res --dataset icdar --weight-path path-to-embedding-model --scm-config path-to-scm-config --scm-weight-path path-to-scm-model```
55 | 
56 | 
57 | ## Training
58 | ### SCM
59 | ```bash
60 | # download the pre-trained model
61 | cd VideoTextSCM/scm/experiments/siammask_sharp
62 | wget http://www.robots.ox.ac.uk/~qwang/SiamMask_VOT.pth
63 | 
64 | # train the model
65 | cd VideoTextSCM
66 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_scm.py --save-dir path-to-save-scm-model --pretrained \
67 | ./scm/experiments/siammask_sharp/SiamMask_VOT.pth --config ./scm/experiments/siammask_sharp/config_icdar.json \
68 | --batch 256 --epochs 20
69 | ```
70 | 
71 | ### Embedding
72 | Download the totaltext_resnet50 weights from [Baidu Drive](https://pan.baidu.com/s/1vxcdpOswTK6MxJyPIJlBkA) (download code: p6u3) or [Google Drive](https://drive.google.com/open?id=1T9n0HTP3X3Y_nJ0D1ekMhCQRHntORLJG).
73 | ```bash
74 | cd db_model && mkdir weights  # put totaltext_resnet50 in db_model/weights
75 | 
76 | # train embedding
77 | cd VideoTextSCM
78 | CUDA_VISIBLE_DEVICES=0 python train_embedding.py --exp_name model-name --batch_size 3 --num_workers 8 --lr 0.0005
79 | ```
80 | 
81 | 
82 | 
83 | ## Citing the related works
84 | 
85 | Please cite the related works in your publications if they help your research:
86 | 
87 |     @article{gao2021video,
88 |       title={Video Text Tracking With a Spatio-Temporal Complementary Model},
89 |       author={Gao, Yuzhe and Li, Xing and Zhang, Jiajian and Zhou, Yu and Jin, Dian and Wang, Jing and Zhu, Shenggao and Bai, Xiang},
90 |       journal={IEEE Transactions on Image Processing},
91 |       volume={30},
92 |       pages={9321--9331},
93 |       year={2021},
94 |       publisher={IEEE}
95 |     }
96 | 
97 | 
--------------------------------------------------------------------------------
/assets/ops/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions.deform_conv import deform_conv, modulated_deform_conv
2 | from .functions.deform_pool import deform_roi_pooling
3 | from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
4 |                                   DeformConvPack, ModulatedDeformConvPack)
5 | from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
6 |                                   ModulatedDeformRoIPoolingPack)
7 | 
8 | __all__ = [
9 |     'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
10 |     'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
11 |     'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
12 |     'deform_roi_pooling'
13 | ]
14 | 
--------------------------------------------------------------------------------
/assets/ops/dcn/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/assets/ops/dcn/functions/__init__.py
--------------------------------------------------------------------------------
/assets/ops/dcn/functions/deform_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.nn.modules.utils import _pair
4 | 
5 | from ..
import deform_conv_cuda 6 | 7 | 8 | class DeformConvFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, 12 | input, 13 | offset, 14 | weight, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | groups=1, 19 | deformable_groups=1, 20 | im2col_step=64): 21 | if input is not None and input.dim() != 4: 22 | raise ValueError( 23 | "Expected 4D tensor as input, got {}D tensor instead.".format( 24 | input.dim())) 25 | ctx.stride = _pair(stride) 26 | ctx.padding = _pair(padding) 27 | ctx.dilation = _pair(dilation) 28 | ctx.groups = groups 29 | ctx.deformable_groups = deformable_groups 30 | ctx.im2col_step = im2col_step 31 | 32 | ctx.save_for_backward(input, offset, weight) 33 | 34 | output = input.new_empty( 35 | DeformConvFunction._output_size(input, weight, ctx.padding, 36 | ctx.dilation, ctx.stride)) 37 | 38 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones 39 | 40 | if not input.is_cuda: 41 | raise NotImplementedError 42 | else: 43 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 44 | assert (input.shape[0] % 45 | cur_im2col_step) == 0, 'im2col step must divide batchsize' 46 | deform_conv_cuda.deform_conv_forward_cuda( 47 | input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1], 48 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], 49 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 50 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 51 | cur_im2col_step) 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | input, offset, weight = ctx.saved_tensors 57 | 58 | grad_input = grad_offset = grad_weight = None 59 | 60 | if not grad_output.is_cuda: 61 | raise NotImplementedError 62 | else: 63 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 64 | assert (input.shape[0] % 65 | cur_im2col_step) == 0, 'im2col step must divide batchsize' 66 | 67 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 68 | grad_input = torch.zeros_like(input) 69 | grad_offset = torch.zeros_like(offset) 70 | deform_conv_cuda.deform_conv_backward_input_cuda( 71 | input, offset, grad_output, grad_input, 72 | grad_offset, weight, ctx.bufs_[0], weight.size(3), 73 | weight.size(2), ctx.stride[1], ctx.stride[0], 74 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 75 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 76 | cur_im2col_step) 77 | 78 | if ctx.needs_input_grad[2]: 79 | grad_weight = torch.zeros_like(weight) 80 | deform_conv_cuda.deform_conv_backward_parameters_cuda( 81 | input, offset, grad_output, 82 | grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), 83 | weight.size(2), ctx.stride[1], ctx.stride[0], 84 | ctx.padding[1], ctx.padding[0], ctx.dilation[1], 85 | ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, 86 | cur_im2col_step) 87 | 88 | return (grad_input, grad_offset, grad_weight, None, None, None, None, 89 | None) 90 | 91 | @staticmethod 92 | def _output_size(input, weight, padding, dilation, stride): 93 | channels = weight.size(0) 94 | output_size = (input.size(0), channels) 95 | for d in range(input.dim() - 2): 96 | in_size = input.size(d + 2) 97 | pad = padding[d] 98 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 99 | stride_ = stride[d] 100 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 101 | if not all(map(lambda s: s > 0, output_size)): 102 | raise ValueError( 103 | "convolution input is too small (output would be {})".format( 104 | 'x'.join(map(str, output_size)))) 105 | return output_size 106 | 107 | 108 | class ModulatedDeformConvFunction(Function): 109 | 
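    # Modulated deformable convolution (DCNv2-style): in addition to the per-location
    # offsets used by DeformConvFunction above, each sampling point is scaled by a
    # learned modulation mask. Forward and backward both delegate to the compiled
    # `deform_conv_cuda` extension built by assets/ops/dcn/setup.py.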
110 | @staticmethod 111 | def forward(ctx, 112 | input, 113 | offset, 114 | mask, 115 | weight, 116 | bias=None, 117 | stride=1, 118 | padding=0, 119 | dilation=1, 120 | groups=1, 121 | deformable_groups=1): 122 | ctx.stride = stride 123 | ctx.padding = padding 124 | ctx.dilation = dilation 125 | ctx.groups = groups 126 | ctx.deformable_groups = deformable_groups 127 | ctx.with_bias = bias is not None 128 | if not ctx.with_bias: 129 | bias = input.new_empty(1) # fake tensor 130 | if not input.is_cuda: 131 | raise NotImplementedError 132 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \ 133 | or input.requires_grad: 134 | ctx.save_for_backward(input, offset, mask, weight, bias) 135 | output = input.new_empty( 136 | ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) 137 | ctx._bufs = [input.new_empty(0), input.new_empty(0)] 138 | deform_conv_cuda.modulated_deform_conv_cuda_forward( 139 | input, weight, bias, ctx._bufs[0], offset, mask, output, 140 | ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, 141 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, 142 | ctx.groups, ctx.deformable_groups, ctx.with_bias) 143 | return output 144 | 145 | @staticmethod 146 | def backward(ctx, grad_output): 147 | if not grad_output.is_cuda: 148 | raise NotImplementedError 149 | input, offset, mask, weight, bias = ctx.saved_tensors 150 | grad_input = torch.zeros_like(input) 151 | grad_offset = torch.zeros_like(offset) 152 | grad_mask = torch.zeros_like(mask) 153 | grad_weight = torch.zeros_like(weight) 154 | grad_bias = torch.zeros_like(bias) 155 | deform_conv_cuda.modulated_deform_conv_cuda_backward( 156 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], 157 | grad_input, grad_weight, grad_bias, grad_offset, grad_mask, 158 | grad_output, weight.shape[2], weight.shape[3], ctx.stride, 159 | ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, 160 | ctx.groups, ctx.deformable_groups, ctx.with_bias) 161 | if not ctx.with_bias: 162 | grad_bias = None 163 | 164 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, 165 | None, None, None, None, None) 166 | 167 | @staticmethod 168 | def _infer_shape(ctx, input, weight): 169 | n = input.size(0) 170 | channels_out = weight.size(0) 171 | height, width = input.shape[2:4] 172 | kernel_h, kernel_w = weight.shape[2:4] 173 | height_out = (height + 2 * ctx.padding - 174 | (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 175 | width_out = (width + 2 * ctx.padding - 176 | (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 177 | return n, channels_out, height_out, width_out 178 | 179 | 180 | deform_conv = DeformConvFunction.apply 181 | modulated_deform_conv = ModulatedDeformConvFunction.apply 182 | -------------------------------------------------------------------------------- /assets/ops/dcn/functions/deform_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from .. 
import deform_pool_cuda 5 | 6 | 7 | class DeformRoIPoolingFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, 11 | data, 12 | rois, 13 | offset, 14 | spatial_scale, 15 | out_size, 16 | out_channels, 17 | no_trans, 18 | group_size=1, 19 | part_size=None, 20 | sample_per_part=4, 21 | trans_std=.0): 22 | ctx.spatial_scale = spatial_scale 23 | ctx.out_size = out_size 24 | ctx.out_channels = out_channels 25 | ctx.no_trans = no_trans 26 | ctx.group_size = group_size 27 | ctx.part_size = out_size if part_size is None else part_size 28 | ctx.sample_per_part = sample_per_part 29 | ctx.trans_std = trans_std 30 | 31 | assert 0.0 <= ctx.trans_std <= 1.0 32 | if not data.is_cuda: 33 | raise NotImplementedError 34 | 35 | n = rois.shape[0] 36 | output = data.new_empty(n, out_channels, out_size, out_size) 37 | output_count = data.new_empty(n, out_channels, out_size, out_size) 38 | deform_pool_cuda.deform_psroi_pooling_cuda_forward( 39 | data, rois, offset, output, output_count, ctx.no_trans, 40 | ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size, 41 | ctx.part_size, ctx.sample_per_part, ctx.trans_std) 42 | 43 | if data.requires_grad or rois.requires_grad or offset.requires_grad: 44 | ctx.save_for_backward(data, rois, offset) 45 | ctx.output_count = output_count 46 | 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | if not grad_output.is_cuda: 52 | raise NotImplementedError 53 | 54 | data, rois, offset = ctx.saved_tensors 55 | output_count = ctx.output_count 56 | grad_input = torch.zeros_like(data) 57 | grad_rois = None 58 | grad_offset = torch.zeros_like(offset) 59 | 60 | deform_pool_cuda.deform_psroi_pooling_cuda_backward( 61 | grad_output, data, rois, offset, output_count, grad_input, 62 | grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels, 63 | ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part, 64 | ctx.trans_std) 65 | return (grad_input, grad_rois, grad_offset, None, None, None, None, 66 | None, None, None, None) 67 | 68 | 69 | deform_roi_pooling = DeformRoIPoolingFunction.apply 70 | -------------------------------------------------------------------------------- /assets/ops/dcn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/assets/ops/dcn/modules/__init__.py -------------------------------------------------------------------------------- /assets/ops/dcn/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.utils import _pair 6 | 7 | from ..functions.deform_conv import deform_conv, modulated_deform_conv 8 | 9 | 10 | class DeformConv(nn.Module): 11 | 12 | def __init__(self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | padding=0, 18 | dilation=1, 19 | groups=1, 20 | deformable_groups=1, 21 | bias=False): 22 | super(DeformConv, self).__init__() 23 | 24 | assert not bias 25 | assert in_channels % groups == 0, \ 26 | 'in_channels {} cannot be divisible by groups {}'.format( 27 | in_channels, groups) 28 | assert out_channels % groups == 0, \ 29 | 'out_channels {} cannot be divisible by groups {}'.format( 30 | out_channels, groups) 31 | 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.kernel_size = _pair(kernel_size) 35 | self.stride = _pair(stride) 36 | self.padding = 
_pair(padding) 37 | self.dilation = _pair(dilation) 38 | self.groups = groups 39 | self.deformable_groups = deformable_groups 40 | 41 | self.weight = nn.Parameter( 42 | torch.Tensor(out_channels, in_channels // self.groups, 43 | *self.kernel_size)) 44 | 45 | self.reset_parameters() 46 | 47 | def reset_parameters(self): 48 | n = self.in_channels 49 | for k in self.kernel_size: 50 | n *= k 51 | stdv = 1. / math.sqrt(n) 52 | self.weight.data.uniform_(-stdv, stdv) 53 | 54 | def forward(self, x, offset): 55 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 56 | self.dilation, self.groups, self.deformable_groups) 57 | 58 | 59 | class DeformConvPack(DeformConv): 60 | 61 | def __init__(self, *args, **kwargs): 62 | super(DeformConvPack, self).__init__(*args, **kwargs) 63 | 64 | self.conv_offset = nn.Conv2d( 65 | self.in_channels, 66 | self.deformable_groups * 2 * self.kernel_size[0] * 67 | self.kernel_size[1], 68 | kernel_size=self.kernel_size, 69 | stride=_pair(self.stride), 70 | padding=_pair(self.padding), 71 | bias=True) 72 | self.init_offset() 73 | 74 | def init_offset(self): 75 | self.conv_offset.weight.data.zero_() 76 | self.conv_offset.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | offset = self.conv_offset(x) 80 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 81 | self.dilation, self.groups, self.deformable_groups) 82 | 83 | 84 | class ModulatedDeformConv(nn.Module): 85 | 86 | def __init__(self, 87 | in_channels, 88 | out_channels, 89 | kernel_size, 90 | stride=1, 91 | padding=0, 92 | dilation=1, 93 | groups=1, 94 | deformable_groups=1, 95 | bias=True): 96 | super(ModulatedDeformConv, self).__init__() 97 | self.in_channels = in_channels 98 | self.out_channels = out_channels 99 | self.kernel_size = _pair(kernel_size) 100 | self.stride = stride 101 | self.padding = padding 102 | self.dilation = dilation 103 | self.groups = groups 104 | self.deformable_groups = deformable_groups 105 | self.with_bias = bias 106 | 107 | self.weight = nn.Parameter( 108 | torch.Tensor(out_channels, in_channels // groups, 109 | *self.kernel_size)) 110 | if bias: 111 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 112 | else: 113 | self.register_parameter('bias', None) 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | n = self.in_channels 118 | for k in self.kernel_size: 119 | n *= k 120 | stdv = 1. 
/ math.sqrt(n) 121 | self.weight.data.uniform_(-stdv, stdv) 122 | if self.bias is not None: 123 | self.bias.data.zero_() 124 | 125 | def forward(self, x, offset, mask): 126 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 127 | self.stride, self.padding, self.dilation, 128 | self.groups, self.deformable_groups) 129 | 130 | 131 | class ModulatedDeformConvPack(ModulatedDeformConv): 132 | 133 | def __init__(self, *args, **kwargs): 134 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 135 | 136 | self.conv_offset_mask = nn.Conv2d( 137 | self.in_channels, 138 | self.deformable_groups * 3 * self.kernel_size[0] * 139 | self.kernel_size[1], 140 | kernel_size=self.kernel_size, 141 | stride=_pair(self.stride), 142 | padding=_pair(self.padding), 143 | bias=True) 144 | self.init_offset() 145 | 146 | def init_offset(self): 147 | self.conv_offset_mask.weight.data.zero_() 148 | self.conv_offset_mask.bias.data.zero_() 149 | 150 | def forward(self, x): 151 | out = self.conv_offset_mask(x) 152 | o1, o2, mask = torch.chunk(out, 3, dim=1) 153 | offset = torch.cat((o1, o2), dim=1) 154 | mask = torch.sigmoid(mask) 155 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 156 | self.stride, self.padding, self.dilation, 157 | self.groups, self.deformable_groups) 158 | -------------------------------------------------------------------------------- /assets/ops/dcn/modules/deform_pool.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ..functions.deform_pool import deform_roi_pooling 4 | 5 | 6 | class DeformRoIPooling(nn.Module): 7 | 8 | def __init__(self, 9 | spatial_scale, 10 | out_size, 11 | out_channels, 12 | no_trans, 13 | group_size=1, 14 | part_size=None, 15 | sample_per_part=4, 16 | trans_std=.0): 17 | super(DeformRoIPooling, self).__init__() 18 | self.spatial_scale = spatial_scale 19 | self.out_size = out_size 20 | self.out_channels = out_channels 21 | self.no_trans = no_trans 22 | self.group_size = group_size 23 | self.part_size = out_size if part_size is None else part_size 24 | self.sample_per_part = sample_per_part 25 | self.trans_std = trans_std 26 | 27 | def forward(self, data, rois, offset): 28 | if self.no_trans: 29 | offset = data.new_empty(0) 30 | return deform_roi_pooling( 31 | data, rois, offset, self.spatial_scale, self.out_size, 32 | self.out_channels, self.no_trans, self.group_size, self.part_size, 33 | self.sample_per_part, self.trans_std) 34 | 35 | 36 | class DeformRoIPoolingPack(DeformRoIPooling): 37 | 38 | def __init__(self, 39 | spatial_scale, 40 | out_size, 41 | out_channels, 42 | no_trans, 43 | group_size=1, 44 | part_size=None, 45 | sample_per_part=4, 46 | trans_std=.0, 47 | num_offset_fcs=3, 48 | deform_fc_channels=1024): 49 | super(DeformRoIPoolingPack, 50 | self).__init__(spatial_scale, out_size, out_channels, no_trans, 51 | group_size, part_size, sample_per_part, trans_std) 52 | 53 | self.num_offset_fcs = num_offset_fcs 54 | self.deform_fc_channels = deform_fc_channels 55 | 56 | if not no_trans: 57 | seq = [] 58 | ic = self.out_size * self.out_size * self.out_channels 59 | for i in range(self.num_offset_fcs): 60 | if i < self.num_offset_fcs - 1: 61 | oc = self.deform_fc_channels 62 | else: 63 | oc = self.out_size * self.out_size * 2 64 | seq.append(nn.Linear(ic, oc)) 65 | ic = oc 66 | if i < self.num_offset_fcs - 1: 67 | seq.append(nn.ReLU(inplace=True)) 68 | self.offset_fc = nn.Sequential(*seq) 69 | self.offset_fc[-1].weight.data.zero_() 70 | 
self.offset_fc[-1].bias.data.zero_() 71 | 72 | def forward(self, data, rois): 73 | assert data.size(1) == self.out_channels 74 | if self.no_trans: 75 | offset = data.new_empty(0) 76 | return deform_roi_pooling( 77 | data, rois, offset, self.spatial_scale, self.out_size, 78 | self.out_channels, self.no_trans, self.group_size, 79 | self.part_size, self.sample_per_part, self.trans_std) 80 | else: 81 | n = rois.shape[0] 82 | offset = data.new_empty(0) 83 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale, 84 | self.out_size, self.out_channels, True, 85 | self.group_size, self.part_size, 86 | self.sample_per_part, self.trans_std) 87 | offset = self.offset_fc(x.view(n, -1)) 88 | offset = offset.view(n, 2, self.out_size, self.out_size) 89 | return deform_roi_pooling( 90 | data, rois, offset, self.spatial_scale, self.out_size, 91 | self.out_channels, self.no_trans, self.group_size, 92 | self.part_size, self.sample_per_part, self.trans_std) 93 | 94 | 95 | class ModulatedDeformRoIPoolingPack(DeformRoIPooling): 96 | 97 | def __init__(self, 98 | spatial_scale, 99 | out_size, 100 | out_channels, 101 | no_trans, 102 | group_size=1, 103 | part_size=None, 104 | sample_per_part=4, 105 | trans_std=.0, 106 | num_offset_fcs=3, 107 | num_mask_fcs=2, 108 | deform_fc_channels=1024): 109 | super(ModulatedDeformRoIPoolingPack, self).__init__( 110 | spatial_scale, out_size, out_channels, no_trans, group_size, 111 | part_size, sample_per_part, trans_std) 112 | 113 | self.num_offset_fcs = num_offset_fcs 114 | self.num_mask_fcs = num_mask_fcs 115 | self.deform_fc_channels = deform_fc_channels 116 | 117 | if not no_trans: 118 | offset_fc_seq = [] 119 | ic = self.out_size * self.out_size * self.out_channels 120 | for i in range(self.num_offset_fcs): 121 | if i < self.num_offset_fcs - 1: 122 | oc = self.deform_fc_channels 123 | else: 124 | oc = self.out_size * self.out_size * 2 125 | offset_fc_seq.append(nn.Linear(ic, oc)) 126 | ic = oc 127 | if i < self.num_offset_fcs - 1: 128 | offset_fc_seq.append(nn.ReLU(inplace=True)) 129 | self.offset_fc = nn.Sequential(*offset_fc_seq) 130 | self.offset_fc[-1].weight.data.zero_() 131 | self.offset_fc[-1].bias.data.zero_() 132 | 133 | mask_fc_seq = [] 134 | ic = self.out_size * self.out_size * self.out_channels 135 | for i in range(self.num_mask_fcs): 136 | if i < self.num_mask_fcs - 1: 137 | oc = self.deform_fc_channels 138 | else: 139 | oc = self.out_size * self.out_size 140 | mask_fc_seq.append(nn.Linear(ic, oc)) 141 | ic = oc 142 | if i < self.num_mask_fcs - 1: 143 | mask_fc_seq.append(nn.ReLU(inplace=True)) 144 | else: 145 | mask_fc_seq.append(nn.Sigmoid()) 146 | self.mask_fc = nn.Sequential(*mask_fc_seq) 147 | self.mask_fc[-2].weight.data.zero_() 148 | self.mask_fc[-2].bias.data.zero_() 149 | 150 | def forward(self, data, rois): 151 | assert data.size(1) == self.out_channels 152 | if self.no_trans: 153 | offset = data.new_empty(0) 154 | return deform_roi_pooling( 155 | data, rois, offset, self.spatial_scale, self.out_size, 156 | self.out_channels, self.no_trans, self.group_size, 157 | self.part_size, self.sample_per_part, self.trans_std) 158 | else: 159 | n = rois.shape[0] 160 | offset = data.new_empty(0) 161 | x = deform_roi_pooling(data, rois, offset, self.spatial_scale, 162 | self.out_size, self.out_channels, True, 163 | self.group_size, self.part_size, 164 | self.sample_per_part, self.trans_std) 165 | offset = self.offset_fc(x.view(n, -1)) 166 | offset = offset.view(n, 2, self.out_size, self.out_size) 167 | mask = self.mask_fc(x.view(n, -1)) 168 | mask = 
mask.view(n, 1, self.out_size, self.out_size) 169 | return deform_roi_pooling( 170 | data, rois, offset, self.spatial_scale, self.out_size, 171 | self.out_channels, self.no_trans, self.group_size, 172 | self.part_size, self.sample_per_part, self.trans_std) * mask 173 | -------------------------------------------------------------------------------- /assets/ops/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='deform_conv', 6 | ext_modules=[ 7 | CUDAExtension('deform_conv_cuda', [ 8 | 'src/deform_conv_cuda.cpp', 9 | 'src/deform_conv_cuda_kernel.cu', 10 | ]), 11 | CUDAExtension('deform_pool_cuda', [ 12 | 'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu' 13 | ]), 14 | ], 15 | cmdclass={'build_ext': BuildExtension}) 16 | -------------------------------------------------------------------------------- /assets/ops/dcn/src/deform_pool_cuda.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c 3 | 4 | // based on 5 | // author: Charles Shang 6 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | void DeformablePSROIPoolForward( 14 | const at::Tensor data, const at::Tensor bbox, const at::Tensor trans, 15 | at::Tensor out, at::Tensor top_count, const int batch, const int channels, 16 | const int height, const int width, const int num_bbox, 17 | const int channels_trans, const int no_trans, const float spatial_scale, 18 | const int output_dim, const int group_size, const int pooled_size, 19 | const int part_size, const int sample_per_part, const float trans_std); 20 | 21 | void DeformablePSROIPoolBackwardAcc( 22 | const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox, 23 | const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad, 24 | at::Tensor trans_grad, const int batch, const int channels, 25 | const int height, const int width, const int num_bbox, 26 | const int channels_trans, const int no_trans, const float spatial_scale, 27 | const int output_dim, const int group_size, const int pooled_size, 28 | const int part_size, const int sample_per_part, const float trans_std); 29 | 30 | void deform_psroi_pooling_cuda_forward( 31 | at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, 32 | at::Tensor top_count, const int no_trans, const float spatial_scale, 33 | const int output_dim, const int group_size, const int pooled_size, 34 | const int part_size, const int sample_per_part, const float trans_std) { 35 | TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 36 | 37 | const int batch = input.size(0); 38 | const int channels = input.size(1); 39 | const int height = input.size(2); 40 | const int width = input.size(3); 41 | const int channels_trans = no_trans ? 
2 : trans.size(1); 42 | 43 | const int num_bbox = bbox.size(0); 44 | if (num_bbox != out.size(0)) 45 | AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", 46 | out.size(0), num_bbox); 47 | 48 | DeformablePSROIPoolForward( 49 | input, bbox, trans, out, top_count, batch, channels, height, width, 50 | num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, 51 | pooled_size, part_size, sample_per_part, trans_std); 52 | } 53 | 54 | void deform_psroi_pooling_cuda_backward( 55 | at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, 56 | at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, 57 | const int no_trans, const float spatial_scale, const int output_dim, 58 | const int group_size, const int pooled_size, const int part_size, 59 | const int sample_per_part, const float trans_std) { 60 | TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); 61 | TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 62 | 63 | const int batch = input.size(0); 64 | const int channels = input.size(1); 65 | const int height = input.size(2); 66 | const int width = input.size(3); 67 | const int channels_trans = no_trans ? 2 : trans.size(1); 68 | 69 | const int num_bbox = bbox.size(0); 70 | if (num_bbox != out_grad.size(0)) 71 | AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", 72 | out_grad.size(0), num_bbox); 73 | 74 | DeformablePSROIPoolBackwardAcc( 75 | out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, 76 | channels, height, width, num_bbox, channels_trans, no_trans, 77 | spatial_scale, output_dim, group_size, pooled_size, part_size, 78 | sample_per_part, trans_std); 79 | } 80 | 81 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 82 | m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward, 83 | "deform psroi pooling forward(CUDA)"); 84 | m.def("deform_psroi_pooling_cuda_backward", 85 | &deform_psroi_pooling_cuda_backward, 86 | "deform psroi pooling backward(CUDA)"); 87 | } -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import random 4 | import os 5 | from torch.utils.data import Dataset 6 | from utils import utils 7 | import torch 8 | 9 | class VideoDataset(Dataset): 10 | def __init__(self,img_dir, gt_dir, train_list_path, size=(1280, 1280)): 11 | ''' 12 | size = (resize_w,resize_h) 13 | ''' 14 | with open(train_list_path, 'r') as fopen: 15 | train_list = [line.strip() for line in fopen.readlines()] 16 | self.train_list = train_list 17 | self.img_dir = img_dir 18 | self.gt_dir = gt_dir 19 | self.size = size 20 | self.RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793]) 21 | 22 | def __len__(self): 23 | return len(self.train_list) 24 | 25 | def load_img(self, img_name): 26 | image_path = os.path.join(self.img_dir, img_name) 27 | img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32') 28 | original_shape = img.shape[:2] #(h,w) 29 | img = cv2.resize(img, self.size) 30 | img -= self.RGB_MEAN 31 | img /= 255. 
32 | img = torch.from_numpy(img).permute(2, 0, 1).float() 33 | return img, original_shape 34 | 35 | def perturb_box(self, box, min_iou=0.5, sigma_factor=0.1): 36 | """ Perturb the input box by adding gaussian noise to the co-ordinates 37 | args: 38 | box - input box top-left and w、h 39 | min_iou - minimum IoU overlap between input box and the perturbed box 40 | sigma_factor - amount of perturbation, relative to the box size. Can be either a single element, or a list of 41 | sigma_factors, in which case one of them will be uniformly sampled. Further, each of the 42 | sigma_factor element can be either a float, or a tensor 43 | of shape (4,) specifying the sigma_factor per co-ordinate 44 | returns: 45 | torch.Tensor - the perturbed box 46 | """ 47 | def rand_uniform(a, b, shape=1): 48 | """ sample numbers uniformly between a and b. 49 | args: 50 | a - lower bound 51 | b - upper bound 52 | shape - shape of the output tensor 53 | returns: 54 | torch.Tensor - tensor of shape=shape 55 | """ 56 | return (b - a) * torch.rand(shape) + a 57 | 58 | if isinstance(sigma_factor, list): 59 | # If list, sample one sigma_factor as current sigma factor 60 | c_sigma_factor = random.choice(sigma_factor) 61 | else: 62 | c_sigma_factor = sigma_factor 63 | if not isinstance(c_sigma_factor, torch.Tensor): 64 | c_sigma_factor = c_sigma_factor * torch.ones(4) 65 | perturb_factor = torch.sqrt(box[2]*box[3])*c_sigma_factor 66 | 67 | # multiple tries to ensure that the perturbed box has iou > min_iou with the input box 68 | for i_ in range(100): 69 | c_x = box[0] + 0.5*box[2] 70 | c_y = box[1] + 0.5 * box[3] 71 | c_x_per = random.gauss(c_x, perturb_factor[0]) 72 | c_y_per = random.gauss(c_y, perturb_factor[1]) 73 | w_per = random.gauss(box[2], perturb_factor[2]) 74 | h_per = random.gauss(box[3], perturb_factor[3]) 75 | 76 | if w_per <= 1: 77 | w_per = box[2]*rand_uniform(0.15, 0.5) 78 | if h_per <= 1: 79 | h_per = box[3]*rand_uniform(0.15, 0.5) 80 | box_per = torch.Tensor([c_x_per - 0.5*w_per, c_y_per - 0.5*h_per, w_per, h_per]).round() 81 | if box_per[2] <= 1: 82 | box_per[2] = box[2]*rand_uniform(0.15, 0.5) 83 | if box_per[3] <= 1: 84 | box_per[3] = box[3]*rand_uniform(0.15, 0.5) 85 | box_iou = utils.iou(box.view(1, 4), box_per.view(1, 4)) 86 | 87 | # if there is sufficient overlap, return 88 | if box_iou > min_iou: 89 | return box_per 90 | # else reduce the perturb factor 91 | perturb_factor *= 0.9 92 | 93 | return box 94 | 95 | def get_all_boxes(self, cur_txt, next_txt ,cur_original_shape, next_original_shape): 96 | cur_txt_path = os.path.join(self.gt_dir, cur_txt) 97 | next_txt_path = os.path.join(self.gt_dir, next_txt) 98 | 99 | cur_h, cur_w = cur_original_shape 100 | next_h, next_w = next_original_shape 101 | resize_w, resize_h = self.size 102 | 103 | cur_boxes = [] 104 | cur_objs = [] 105 | with open(cur_txt_path, 'rb') as fopen: 106 | lines = [line for line in fopen.readlines()] 107 | lines = [line.decode("utf-8").strip() for line in lines] 108 | for line in lines: 109 | line_split = line.split(',') 110 | box = list(map(int, map(float, line_split[:8]))) 111 | x1,y1,x2,y2,x3,y3,x4,y4 = box 112 | region_point = np.array([[x1,y1],[x2,y2],[x3,y3],[x4,y4]]) #shape=(k,2) 113 | x1,y1,w,h = cv2.boundingRect(region_point) 114 | x1,y1,w,h = self.perturb_box(torch.Tensor([x1,y1,w,h]), 0.8) 115 | x1,y1,x3,y3 = int(x1*resize_w/cur_w), int(y1*resize_h/cur_h), int((x1+w)*resize_w/cur_w), int((y1+h)*resize_h/cur_h) 116 | x1 = min(resize_w-1, max(0, x1)) 117 | x3 = min(resize_w-1, max(0, x3)) 118 | y1 = min(resize_h-1, max(0, 
y1)) 119 | y3 = min(resize_h-1, max(0, y3)) 120 | cur_boxes.append([x1,y1,x3,y3]) 121 | cur_objs.append(int(line_split[8])) 122 | 123 | next_boxes = [] 124 | next_objs = [] 125 | with open(next_txt_path, 'rb') as fopen: 126 | lines = [line for line in fopen.readlines()] 127 | lines = [line.decode("utf-8").strip() for line in lines] 128 | for line in lines: 129 | line_split = line.split(',') 130 | box = list(map(int, map(float, line_split[:8]))) 131 | x1,y1,x2,y2,x3,y3,x4,y4 = box 132 | region_point = np.array([[x1,y1],[x2,y2],[x3,y3],[x4,y4]]) #shape=(k,2) 133 | x1,y1,w,h = cv2.boundingRect(region_point) 134 | x1,y1,w,h = self.perturb_box(torch.Tensor([x1,y1,w,h]), 0.8) 135 | x1,y1,x3,y3 = int(x1*resize_w/next_w), int(y1*resize_h/next_h), int((x1+w)*resize_w/next_w), int((y1+h)*resize_h/next_h) 136 | x1 = min(resize_w-1, max(0, x1)) 137 | x3 = min(resize_w-1, max(0, x3)) 138 | y1 = min(resize_h-1, max(0, y1)) 139 | y3 = min(resize_h-1, max(0, y3)) 140 | next_boxes.append([x1,y1,x3,y3]) 141 | cur_objs.append(int(line_split[8])) 142 | cur_boxes = torch.Tensor(cur_boxes) 143 | next_boxes = torch.Tensor(next_boxes) 144 | return cur_boxes, next_boxes, cur_objs, next_objs 145 | 146 | def get_triple_list(self,cur_objs,next_objs): 147 | def get_pair(obj_index): 148 | pairs = [] 149 | for i in range(len(obj_index)): 150 | for j in range(i+1, len(obj_index)): 151 | pairs.append([obj_index[i], obj_index[j]]) 152 | return pairs 153 | 154 | objs = cur_objs+next_objs 155 | objs = np.array(objs) 156 | obj_id_unique = np.unique(objs) 157 | 158 | objs_index = [] 159 | for obj_id in obj_id_unique: 160 | objs_index.append(np.where(objs==obj_id)[0].tolist()) 161 | 162 | triple_list = [] 163 | for i in range(len(objs_index)): 164 | obj_index = objs_index[i] 165 | if len(obj_index) < 2: 166 | continue 167 | pairs = get_pair(obj_index) 168 | for j in range(len(objs_index)): 169 | if j == i: 170 | continue 171 | for neg in objs_index[j]: 172 | for pair in pairs: 173 | triple = pair+[neg] 174 | triple_list.append(triple.copy()) 175 | return triple_list 176 | 177 | def __getitem__(self, index): 178 | img_pair = self.train_list[index] 179 | imgcur_name, imgnext_name = img_pair.split(',') 180 | cur_txt, next_txt = imgcur_name+'.txt', imgnext_name+'.txt' 181 | imgcur, cur_original_shape = self.load_img(imgcur_name) 182 | imgnext , next_original_shape= self.load_img(imgnext_name) 183 | cur_boxes, next_boxes, cur_objs, next_objs = self.get_all_boxes(cur_txt, next_txt ,cur_original_shape, next_original_shape) 184 | triple_list = self.get_triple_list(cur_objs, next_objs) 185 | triple_list = torch.Tensor(triple_list).long() 186 | return imgcur, imgnext, cur_boxes, next_boxes, triple_list 187 | 188 | class LoadVideo: 189 | def __init__(self, path): 190 | self.cap = cv2.VideoCapture(path) 191 | self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS))) 192 | self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 193 | self.vh = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 194 | self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 195 | self.count = 0 196 | base_name = os.path.basename(path).split('.')[0] 197 | self.imgdir = os.path.join(os.path.dirname(path), base_name) 198 | utils.mkdir_if_missing(self.imgdir) 199 | print('Lenth of the video: {:d} frames'.format(self.vn)) 200 | 201 | def __iter__(self): 202 | self.count = -1 203 | return self 204 | 205 | def __next__(self): 206 | self.count += 1 207 | if self.count == len(self): 208 | raise StopIteration 209 | 210 | res, img0 = self.cap.read() # BGR 211 | assert img0 is 
not None, 'Failed to load frame {:d}'.format(self.count) 212 | 213 | img_path = os.path.join(self.imgdir, str(self.count)+'.jpg') #frame_id start from 0 214 | if not os.path.isfile(img_path): 215 | cv2.imwrite(img_path, img0) 216 | 217 | return img_path, img0 218 | 219 | def __len__(self): 220 | return self.vn # frames of video -------------------------------------------------------------------------------- /db_model/db_embedding.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import math 7 | import logging 8 | from .model_v3 import DB_Embedding_Model 9 | from scm.experiments.siammask_sharp.custom import Custom 10 | from .representers.seg_detector_representer import SegDetectorRepresenter 11 | from utils.utils import make_seg_shrink 12 | from scm.tools.track2mask import track_all_objs2mask 13 | 14 | logger = logging.getLogger('root') 15 | class Demo: 16 | def __init__(self, weight_path, scm_weight_path, scm_config, img_min_size, conf_thresh): 17 | self.RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793], dtype=np.float32) 18 | if torch.cuda.is_available(): 19 | torch.cuda.set_device(0) 20 | self.device = torch.device('cuda') 21 | else: 22 | self.device = torch.device('cpu') 23 | self.weight_path = weight_path 24 | self.scm_weight_path = scm_weight_path 25 | self.scm_config = scm_config 26 | self.min_size = img_min_size 27 | self.init_torch_tensor() 28 | self.model = self.init_model() 29 | self.model.eval() 30 | self.scm = self.init_scm() 31 | self.scm.eval() 32 | self.segdetector_representer = SegDetectorRepresenter() 33 | self.box_thresh = conf_thresh #0.65 34 | 35 | def init_torch_tensor(self): 36 | # Use gpu or not 37 | torch.set_default_tensor_type('torch.FloatTensor') 38 | if torch.cuda.is_available(): 39 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 40 | 41 | def init_model(self): 42 | logger.info("model init...") 43 | model = DB_Embedding_Model().to(self.device) 44 | logger.info("model init success!") 45 | 46 | logger.info("resume from {}...".format(self.weight_path)) 47 | if not os.path.exists(self.weight_path): 48 | logger.info("Checkpoint not found: " + self.weight_path) 49 | return 50 | states = torch.load(self.weight_path, map_location=self.device) 51 | model.load_state_dict(states, strict=True) 52 | return model 53 | 54 | def init_scm(self): 55 | with open(self.scm_config, 'r') as f: 56 | cfg = json.load(f) 57 | self.scm_cfg = cfg 58 | scm= Custom(anchors=cfg['anchors']).to(self.device) 59 | logger.info('load pretrained scm model from {}'.format(self.scm_weight_path)) 60 | if not os.path.exists(self.scm_weight_path): 61 | logger.info("scm Checkpoint not found: " + self.scm_weight_path) 62 | return 63 | states = torch.load(self.scm_weight_path, map_location=self.device) 64 | if 'state_dict' in states: 65 | states = states['state_dict'] 66 | 67 | #remove module 68 | new_sate = {} 69 | for key, value in states.items(): 70 | if key.startswith('module.'): 71 | key = key.split('module.', 1)[-1] 72 | new_sate[key] = value 73 | scm.load_state_dict(new_sate, strict=True) 74 | return scm 75 | 76 | def resize_image(self, img): 77 | height, width, _ = img.shape 78 | if height < width: 79 | new_height = self.min_size 80 | new_width = int(math.ceil(new_height / height * width / 32) * 32) 81 | else: 82 | new_width = self.min_size 83 | new_height = int(math.ceil(new_width / width * height / 32) * 32) 84 | resized_img = cv2.resize(img, (new_width, 
new_height)) 85 | self.size = (new_width, new_height) 86 | return resized_img 87 | 88 | def load_image(self, img0): 89 | img = img0.astype(np.float32) 90 | original_shape = img.shape[:2] #(h,w) 91 | img = self.resize_image(img) 92 | resize_img = img.copy() 93 | img -= self.RGB_MEAN 94 | img /= 255. 95 | img = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) #shape=(1,3,h,w) 96 | return img, original_shape, resize_img 97 | 98 | def get_output(self, batch_boxes, batch_scores, feat_refine, original_shape, pp=False): 99 | ''' 100 | batch_boxes: list[np.array] 101 | batch_boxes[0]: np.array, shape=(n,4,2) 102 | batch_scores: list[np.array] 103 | batch_scores[0]: np.array, shape=(n) 104 | feat_refine: tensor, shape=(n, 256, h/4, w/4) 105 | original_shape = (h,w) 106 | ''' 107 | assert len(batch_scores) == 1 108 | assert len(batch_boxes) == 1 109 | batch_boxes = batch_boxes[0] 110 | batch_scores = batch_scores[0] 111 | 112 | #Firstly, remove boxes which score below box_thresh 113 | idx = batch_scores >= self.box_thresh 114 | boxes_8 = batch_boxes[idx].reshape(-1, 8) 115 | boxes_8 = torch.from_numpy(boxes_8).float() 116 | boxes_score = batch_scores[idx] 117 | boxes_score = torch.from_numpy(boxes_score) 118 | if boxes_score.shape[0] < 1: 119 | boxes_feat = torch.Tensor([]).cpu() 120 | boxes_4 = torch.Tensor([]).cpu() 121 | boxes8_score = torch.Tensor([]).cpu() 122 | return boxes_feat, boxes_4, boxes8_score, True 123 | 124 | #Secondly, get the min area horizontal rectangle 125 | x0s = torch.min(boxes_8[:,::2],dim=1,keepdim=True)[0] 126 | y0s = torch.min(boxes_8[:,1::2],dim=1,keepdim=True)[0] 127 | x2s = torch.max(boxes_8[:,::2],dim=1,keepdim=True)[0] 128 | y2s = torch.max(boxes_8[:,1::2],dim=1,keepdim=True)[0] 129 | boxes_4 = torch.cat((x0s,y0s,x2s,y2s),dim=1).float() 130 | if pp: 131 | boxes_area = (x2s-x0s) * (y2s -y0s) 132 | area_idx = boxes_area<5000 133 | boxes_4 = boxes_4[area_idx.expand_as(boxes_4)].reshape(-1,4) 134 | boxes_8 = boxes_8[area_idx.expand_as(boxes_8)].reshape(-1,8) 135 | boxes_score = boxes_score[area_idx.squeeze(-1)] 136 | assert boxes_8.shape[0] == boxes_score.shape[0] 137 | if boxes_score.shape[0] < 1: 138 | boxes_feat = torch.Tensor([]).cpu() 139 | boxes_4 = torch.Tensor([]).cpu() 140 | boxes8_score = torch.Tensor([]).cpu() 141 | return boxes_feat, boxes_4, boxes8_score, True 142 | 143 | #Thirdly, crop text instances feature 144 | boxes_4_resize = torch.zeros_like(boxes_4) 145 | boxes_4_resize[:,::2] = boxes_4[:,::2] / original_shape[1] * self.size[0] 146 | boxes_4_resize[:,1::2] = boxes_4[:,1::2] / original_shape[0] * self.size[1] 147 | zeros = torch.zeros((boxes_4.shape[0], 1)).cpu().float() 148 | boxes_4_resize = boxes_4_resize.float() 149 | boxes_4_resize = torch.cat((zeros, boxes_4_resize),dim=1).to(self.device) #shape=(n, 5), 5=[0,x1,y1,x2,y2] 150 | boxes_feat = self.model.roi_align(feat_refine, boxes_4_resize) #shape=(n,256,5,16) 151 | 152 | #Fourthly, pred sem and vis feature of text instances 153 | with torch.no_grad(): 154 | obj_num = boxes_feat.shape[0] 155 | # visual feat 156 | v_boxes_feat = boxes_feat.view(obj_num, -1) #shape=(n,256*5*16) 157 | v_boxes_feat = self.model.embed_layers(v_boxes_feat) #shape=(n,256) 158 | # Semantic feat 159 | s_boxes_feat = self.model.seq_conv(boxes_feat) 160 | s_boxes_feat = s_boxes_feat.squeeze(2).permute(2,0,1).contiguous() 161 | s_boxes_feat = self.model.rnn(s_boxes_feat).permute(1, 2, 0).contiguous().view(obj_num, -1) 162 | boxes_feat = torch.cat((v_boxes_feat, s_boxes_feat), -1) 163 | # boxes_feat = s_boxes_feat 
164 | 165 | boxes_feat = boxes_feat.cpu() #shaep=(n,512) 166 | boxes8_score = torch.cat((boxes_8, boxes_score.view(-1,1)), dim=1) 167 | return boxes_feat, boxes_4, boxes8_score, False 168 | 169 | def inference(self, image_path, img0, add_vot_track=False, pre_img0=None, pre_boxes=None, add_mask=True): 170 | 171 | batch = dict() 172 | batch['filename'] = [image_path] 173 | pp = '7_4' in image_path 174 | img, original_shape, resize_img = self.load_image(img0) 175 | batch['shape'] = [original_shape] 176 | 177 | with torch.no_grad(): 178 | img = img.to(self.device) 179 | batch['image'] = img 180 | pred, feat_refine = self.model.forward(img) 181 | 182 | if add_vot_track and torch.is_tensor(pre_boxes) and pre_boxes.shape[0] > 0: 183 | pre_img, pre_original_shape,pre_resize_img = self.load_image(pre_img0) 184 | pre_boxes[:,:8:2] *= pre_resize_img.shape[1]/pre_original_shape[1] 185 | pre_boxes[:,1:8:2] *= pre_resize_img.shape[0]/pre_original_shape[0] 186 | track_mask, polygons = track_all_objs2mask(pre_resize_img, resize_img, pre_boxes, self.device, self.scm, self.scm_cfg) 187 | shrink_mask = make_seg_shrink(polygons, [0 for i in range(polygons.shape[0])], resize_img.shape[0], resize_img.shape[1]) 188 | track_mask *= shrink_mask 189 | track_mask = torch.Tensor(track_mask).to(self.device) 190 | track_mask = track_mask.unsqueeze(0).unsqueeze(0) 191 | track_mask[pred<0.6]=0 #0.6 192 | pred += track_mask 193 | output = self.segdetector_representer.represent(batch, pred, is_output_polygon=False) 194 | batch_boxes, batch_scores = output 195 | else: 196 | output = self.segdetector_representer.represent(batch, pred, is_output_polygon=False) 197 | batch_boxes, batch_scores = output 198 | 199 | pred_f, boxes_4, boxes8_score, no_objs = self.get_output(batch_boxes, batch_scores, feat_refine, original_shape, pp) 200 | return pred_f, boxes_4, boxes8_score, no_objs 201 | -------------------------------------------------------------------------------- /db_model/model_v3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.ops import RoIAlign 4 | from .resnet import deformable_resnet50 5 | 6 | class BidirectionalLSTM(nn.Module): 7 | 8 | def __init__(self, nIn, nHidden, nOut): 9 | super(BidirectionalLSTM, self).__init__() 10 | 11 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 12 | self.embedding = nn.Linear(nHidden * 2, nOut) 13 | 14 | def forward(self, input): 15 | recurrent, _ = self.rnn(input) # T * B * (D*Hout) 16 | T, b, h = recurrent.size() 17 | t_rec = recurrent.view(T * b, h) 18 | 19 | output = self.embedding(t_rec) # [T * b, nOut] 20 | output = output.view(T, b, -1) 21 | 22 | return output 23 | 24 | class SegDetector(nn.Module): 25 | def __init__(self, 26 | in_channels=[64, 128, 256, 512], 27 | inner_channels=256, k=10, 28 | bias=False, adaptive=False, smooth=False, serial=False, 29 | *args, **kwargs): 30 | ''' 31 | bias: Whether conv layers have bias or not. 32 | adaptive: Whether to use adaptive threshold training or not. 33 | smooth: If true, use bilinear instead of deconv. 34 | serial: If true, thresh prediction will combine segmentation result as input. 
35 | ''' 36 | super(SegDetector, self).__init__() 37 | self.k = k 38 | self.serial = serial 39 | self.up5 = nn.Upsample(scale_factor=2, mode='nearest') 40 | self.up4 = nn.Upsample(scale_factor=2, mode='nearest') 41 | self.up3 = nn.Upsample(scale_factor=2, mode='nearest') 42 | 43 | self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias) 44 | self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias) 45 | self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias) 46 | self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias) 47 | 48 | self.out5 = nn.Sequential( 49 | nn.Conv2d(inner_channels, inner_channels // 50 | 4, 3, padding=1, bias=bias), 51 | nn.Upsample(scale_factor=8, mode='nearest')) 52 | self.out4 = nn.Sequential( 53 | nn.Conv2d(inner_channels, inner_channels // 54 | 4, 3, padding=1, bias=bias), 55 | nn.Upsample(scale_factor=4, mode='nearest')) 56 | self.out3 = nn.Sequential( 57 | nn.Conv2d(inner_channels, inner_channels // 58 | 4, 3, padding=1, bias=bias), 59 | nn.Upsample(scale_factor=2, mode='nearest')) 60 | self.out2 = nn.Conv2d( 61 | inner_channels, inner_channels//4, 3, padding=1, bias=bias) 62 | 63 | self.binarize = nn.Sequential( 64 | nn.Conv2d(inner_channels, inner_channels // 65 | 4, 3, padding=1, bias=bias), 66 | nn.BatchNorm2d(inner_channels//4), 67 | nn.ReLU(inplace=True), 68 | nn.ConvTranspose2d(inner_channels//4, inner_channels//4, 2, 2), 69 | nn.BatchNorm2d(inner_channels//4), 70 | nn.ReLU(inplace=True), 71 | nn.ConvTranspose2d(inner_channels//4, 1, 2, 2), 72 | nn.Sigmoid()) 73 | self.binarize.apply(self.weights_init) 74 | 75 | self.adaptive = adaptive 76 | if adaptive: 77 | self.thresh = self._init_thresh( 78 | inner_channels, serial=serial, smooth=smooth, bias=bias) 79 | self.thresh.apply(self.weights_init) 80 | 81 | self.in5.apply(self.weights_init) 82 | self.in4.apply(self.weights_init) 83 | self.in3.apply(self.weights_init) 84 | self.in2.apply(self.weights_init) 85 | self.out5.apply(self.weights_init) 86 | self.out4.apply(self.weights_init) 87 | self.out3.apply(self.weights_init) 88 | self.out2.apply(self.weights_init) 89 | 90 | def weights_init(self, m): 91 | classname = m.__class__.__name__ 92 | if classname.find('Conv') != -1: 93 | nn.init.kaiming_normal_(m.weight.data) 94 | elif classname.find('BatchNorm') != -1: 95 | m.weight.data.fill_(1.) 
96 | m.bias.data.fill_(1e-4) 97 | 98 | def _init_thresh(self, inner_channels, 99 | serial=False, smooth=False, bias=False): 100 | in_channels = inner_channels 101 | if serial: 102 | in_channels += 1 103 | self.thresh = nn.Sequential( 104 | nn.Conv2d(in_channels, inner_channels // 105 | 4, 3, padding=1, bias=bias), 106 | nn.BatchNorm2d(inner_channels//4), 107 | nn.ReLU(inplace=True), 108 | self._init_upsample(inner_channels // 4, inner_channels//4, smooth=smooth, bias=bias), 109 | nn.BatchNorm2d(inner_channels//4), 110 | nn.ReLU(inplace=True), 111 | self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias), 112 | nn.Sigmoid()) 113 | return self.thresh 114 | 115 | def _init_upsample(self, 116 | in_channels, out_channels, 117 | smooth=False, bias=False): 118 | if smooth: 119 | inter_out_channels = out_channels 120 | if out_channels == 1: 121 | inter_out_channels = in_channels 122 | module_list = [ 123 | nn.Upsample(scale_factor=2, mode='nearest'), 124 | nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)] 125 | if out_channels == 1: 126 | module_list.append( 127 | nn.Conv2d(in_channels, out_channels, 128 | kernel_size=1, stride=1, padding=1, bias=True)) 129 | 130 | return nn.Sequential(module_list) 131 | else: 132 | return nn.ConvTranspose2d(in_channels, out_channels, 2, 2) 133 | 134 | def forward(self, features, gt=None, masks=None): 135 | c2, c3, c4, c5 = features 136 | 137 | in5 = self.in5(c5) 138 | in4 = self.in4(c4) 139 | in3 = self.in3(c3) 140 | in2 = self.in2(c2) 141 | 142 | out4 = self.up5(in5) + in4 # 1/16 143 | out3 = self.up4(out4) + in3 # 1/8 144 | out2 = self.up3(out3) + in2 # 1/4 145 | 146 | p5 = self.out5(in5) 147 | p4 = self.out4(out4) 148 | p3 = self.out3(out3) 149 | p2 = self.out2(out2) 150 | fuse = torch.cat((p5, p4, p3, p2), 1) #shaep=(1,256,h/4,w/4) 151 | 152 | # this is the pred module, not binarization module; 153 | # We do not correct the name due to the trained model. 
154 | binary = self.binarize(fuse) 155 | return binary, fuse 156 | 157 | def step_function(self, x, y): 158 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) 159 | 160 | class BasicModel(nn.Module): 161 | def __init__(self): 162 | super(BasicModel, self).__init__() 163 | 164 | self.backbone = deformable_resnet50() 165 | self.decoder = SegDetector(adaptive=True, in_channels=[256, 512, 1024, 2048], k=50) 166 | 167 | def forward(self, data, *args, **kwargs): 168 | return self.decoder(self.backbone(data), *args, **kwargs) 169 | 170 | class SegDetectorModel(nn.Module): 171 | def __init__(self, distributed: bool = False, local_rank: int = 0): 172 | super(SegDetectorModel, self).__init__() 173 | self.model = BasicModel() 174 | 175 | def forward(self, data): 176 | data = data.float() 177 | pred, feat = self.model(data) 178 | return pred, feat 179 | 180 | class DB_Embedding_Model(nn.Module): 181 | def __init__(self, nh=256, nclass=16): 182 | super(DB_Embedding_Model, self).__init__() 183 | self.db = SegDetectorModel() 184 | self.roi_align = RoIAlign((5, 16), 0.25, 0) 185 | self.extra_conv_layers = nn.Sequential( 186 | nn.Conv2d(256, 256, 3, padding=1, bias=True), 187 | nn.BatchNorm2d(256), 188 | nn.ReLU(inplace=True), 189 | nn.Conv2d(256, 256, 3, padding=1, bias=True), 190 | nn.BatchNorm2d(256), 191 | nn.ReLU(inplace=True)) 192 | 193 | self.embed_layers = nn.Sequential(nn.Linear(256*5*16,1024), 194 | nn.ReLU(inplace=True), 195 | nn.Linear(1024,256) 196 | ) 197 | 198 | self.seq_conv = nn.Conv2d(256,256,kernel_size=(5,3), padding=(0,1)) 199 | self.rnn = nn.Sequential(BidirectionalLSTM(256, nh, nh),BidirectionalLSTM(nh, nh, nclass)) 200 | 201 | def forward(self, data, all_boxes=None): 202 | pred, feat = self.db(data) 203 | feat_refine = self.extra_conv_layers(feat.detach()) 204 | if all_boxes is None: 205 | return pred, feat_refine 206 | 207 | roi_feat = self.roi_align(feat_refine, all_boxes) 208 | ### Visual feat 209 | obj_num = roi_feat.shape[0] 210 | v_roi_feat = roi_feat.view(obj_num, -1) 211 | v_roi_feat = self.embed_layers(v_roi_feat) 212 | 213 | # Semantic feat 214 | s_roi_feat = self.seq_conv(roi_feat) 215 | s_roi_feat = s_roi_feat.squeeze(2).permute(2, 0, 1) 216 | s_roi_feat = self.rnn(s_roi_feat).permute(1, 2, 0).contiguous().view(obj_num, -1) 217 | 218 | roi_feat = torch.cat((v_roi_feat, s_roi_feat), -1) 219 | # roi_feat = s_roi_feat 220 | return pred, roi_feat 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /db_model/representers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /db_model/representers/seg_detector_representer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from shapely.geometry import Polygon 4 | import pyclipper 5 | 6 | class SegDetectorRepresenter(): 7 | def __init__(self): 8 | self.thresh = 0.3 #0.3 9 | self.box_thresh = 0.45 10 | self.max_candidates = 1000 11 | self.dest = 'binary' 12 | self.min_size = 3 13 | 14 | def represent(self, batch, _pred, is_output_polygon=False): 15 | ''' 16 | batch: a dict produced by dataloaders. 17 | image: tensor of shape (N, C, H, W). 18 | 19 | shape: the original shape of images. 20 | filename: the original filenames of images. 
21 | pred: 22 | binary: text region segmentation map, with shape (N, 1, H, W) 23 | thresh: [if exists] thresh hold prediction with shape (N, 1, H, W) 24 | thresh_binary: [if exists] binarized with threshhold, (N, 1, H, W) 25 | ''' 26 | images = batch['image'] 27 | if isinstance(_pred, dict): 28 | pred = _pred[self.dest] 29 | else: 30 | pred = _pred 31 | segmentation = self.binarize(pred) 32 | boxes_batch = [] 33 | scores_batch = [] 34 | for batch_index in range(images.size(0)): 35 | height, width = batch['shape'][batch_index] 36 | if is_output_polygon: 37 | boxes, scores = self.polygons_from_bitmap( 38 | pred[batch_index], 39 | segmentation[batch_index], width, height) 40 | else: 41 | boxes, scores = self.boxes_from_bitmap( 42 | pred[batch_index], 43 | segmentation[batch_index], width, height) 44 | boxes_batch.append(boxes) 45 | scores_batch.append(scores) 46 | return boxes_batch, scores_batch 47 | 48 | def binarize(self, pred): 49 | return pred > self.thresh 50 | 51 | def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 52 | ''' 53 | _bitmap: single map with shape (1, H, W), 54 | whose values are binarized as {0, 1} 55 | ''' 56 | assert _bitmap.size(0) == 1 57 | bitmap = _bitmap.cpu().numpy()[0] # The first channel 58 | pred = pred.cpu().detach().numpy()[0] 59 | height, width = bitmap.shape 60 | boxes = [] 61 | scores = [] 62 | 63 | contours, _ = cv2.findContours( 64 | (bitmap*255).astype(np.uint8), 65 | cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 66 | 67 | for contour in contours[:self.max_candidates]: 68 | epsilon = 0.01 * cv2.arcLength(contour, True) 69 | approx = cv2.approxPolyDP(contour, epsilon, True) 70 | points = approx.reshape((-1, 2)) 71 | if points.shape[0] < 4: 72 | continue 73 | # _, sside = self.get_mini_boxes(contour) 74 | # if sside < self.min_size: 75 | # continue 76 | score = self.box_score_fast(pred, points.reshape(-1, 2)) 77 | if self.box_thresh > score: 78 | continue 79 | 80 | if points.shape[0] > 2: 81 | box = self.unclip(points, unclip_ratio=2.0) 82 | if len(box) > 1: 83 | continue 84 | else: 85 | continue 86 | box = box.reshape(-1, 2) 87 | _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) 88 | if sside < self.min_size + 2: 89 | continue 90 | 91 | if not isinstance(dest_width, int): 92 | dest_width = dest_width.item() 93 | dest_height = dest_height.item() 94 | 95 | box[:, 0] = np.clip( 96 | np.round(box[:, 0] / width * dest_width), 0, dest_width) 97 | box[:, 1] = np.clip( 98 | np.round(box[:, 1] / height * dest_height), 0, dest_height) 99 | boxes.append(box.tolist()) 100 | scores.append(score) 101 | return boxes, scores 102 | 103 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 104 | ''' 105 | _bitmap: single map with shape (1, H, W), 106 | whose values are binarized as {0, 1} 107 | ''' 108 | 109 | assert _bitmap.size(0) == 1 110 | bitmap = _bitmap.cpu().numpy()[0] 111 | pred = pred.cpu().detach().numpy()[0] 112 | height, width = bitmap.shape 113 | contours, _ = cv2.findContours( 114 | (bitmap*255).astype(np.uint8), 115 | cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 116 | 117 | num_contours = min(len(contours), self.max_candidates) 118 | # print('number of contours', num_contours) 119 | boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) 120 | scores = np.zeros((num_contours,), dtype=np.float32) 121 | 122 | for index in range(num_contours): 123 | contour = contours[index] 124 | points, sside = self.get_mini_boxes(contour) 125 | if sside < self.min_size: 126 | continue 127 | points = np.array(points) 128 | score = 
self.box_score_fast(pred, points.reshape(-1, 2)) 129 | if score < self.box_thresh: 130 | continue 131 | 132 | box = self.unclip(points).reshape(-1, 1, 2) 133 | box, sside = self.get_mini_boxes(box) 134 | if sside < self.min_size + 2: 135 | continue 136 | box = np.array(box) 137 | if not isinstance(dest_width, int): 138 | dest_width = dest_width.item() 139 | dest_height = dest_height.item() 140 | 141 | box[:, 0] = np.clip( 142 | np.round(box[:, 0] / width * dest_width), 0, dest_width) 143 | box[:, 1] = np.clip( 144 | np.round(box[:, 1] / height * dest_height), 0, dest_height) 145 | boxes[index, :, :] = box.astype(np.int16) 146 | scores[index] = score 147 | return boxes, scores 148 | 149 | def unclip(self, box, unclip_ratio=1.5): 150 | poly = Polygon(box) 151 | distance = poly.area * unclip_ratio / poly.length 152 | offset = pyclipper.PyclipperOffset() 153 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 154 | expanded = np.array(offset.Execute(distance)) 155 | return expanded 156 | 157 | def get_mini_boxes(self, contour): 158 | bounding_box = cv2.minAreaRect(contour) 159 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 160 | 161 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 162 | if points[1][1] > points[0][1]: 163 | index_1 = 0 164 | index_4 = 1 165 | else: 166 | index_1 = 1 167 | index_4 = 0 168 | if points[3][1] > points[2][1]: 169 | index_2 = 2 170 | index_3 = 3 171 | else: 172 | index_2 = 3 173 | index_3 = 2 174 | 175 | box = [points[index_1], points[index_2], 176 | points[index_3], points[index_4]] 177 | return box, min(bounding_box[1]) 178 | 179 | def box_score_fast(self, bitmap, _box): 180 | h, w = bitmap.shape[:2] 181 | box = _box.copy() 182 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) 183 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) 184 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) 185 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) 186 | 187 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) 188 | box[:, 0] = box[:, 0] - xmin 189 | box[:, 1] = box[:, 1] - ymin 190 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) 191 | return cv2.mean(bitmap[ymin:ymax+1, xmin:xmax+1], mask)[0] 192 | -------------------------------------------------------------------------------- /db_model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 5 | 'resnet152'] 6 | 7 | model_urls = { 8 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 13 | } 14 | 15 | BatchNorm2d = nn.BatchNorm2d 16 | 17 | def constant_init(module, constant, bias=0): 18 | nn.init.constant_(module.weight, constant) 19 | if hasattr(module, 'bias'): 20 | nn.init.constant_(module.bias, bias) 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | class BasicBlock(nn.Module): 28 | 
expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 31 | super(BasicBlock, self).__init__() 32 | self.with_dcn = dcn is not None 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.with_modulated_dcn = False 37 | if self.with_dcn: 38 | fallback_on_stride = dcn.get('fallback_on_stride', False) 39 | self.with_modulated_dcn = dcn.get('modulated', False) 40 | if not self.with_dcn or fallback_on_stride: 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 42 | padding=1, bias=False) 43 | else: 44 | deformable_groups = dcn.get('deformable_groups', 1) 45 | if not self.with_modulated_dcn: 46 | from assets.ops.dcn import DeformConv 47 | conv_op = DeformConv 48 | offset_channels = 18 49 | else: 50 | from assets.ops.dcn import ModulatedDeformConv 51 | conv_op = ModulatedDeformConv 52 | offset_channels = 27 53 | self.conv2_offset = nn.Conv2d( 54 | planes, 55 | deformable_groups * offset_channels, 56 | kernel_size=3, 57 | padding=1) 58 | self.conv2 = conv_op( 59 | planes, 60 | planes, 61 | kernel_size=3, 62 | padding=1, 63 | deformable_groups=deformable_groups, 64 | bias=False) 65 | self.bn2 = BatchNorm2d(planes) 66 | self.downsample = downsample 67 | 68 | def forward(self, x): 69 | residual = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | if not self.with_dcn: 76 | out = self.conv2(out) 77 | elif self.with_modulated_dcn: 78 | offset_mask = self.conv2_offset(out) 79 | offset = offset_mask[:, :18, :, :] 80 | mask = offset_mask[:, -9:, :, :].sigmoid() 81 | out = self.conv2(out, offset, mask) 82 | else: 83 | offset = self.conv2_offset(out) 84 | out = self.conv2(out, offset) 85 | out = self.bn2(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Bottleneck(nn.Module): 96 | expansion = 4 97 | 98 | def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): 99 | super(Bottleneck, self).__init__() 100 | self.with_dcn = dcn is not None 101 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 102 | self.bn1 = BatchNorm2d(planes) 103 | fallback_on_stride = False 104 | self.with_modulated_dcn = False 105 | if self.with_dcn: 106 | fallback_on_stride = dcn.get('fallback_on_stride', False) 107 | self.with_modulated_dcn = dcn.get('modulated', False) 108 | if not self.with_dcn or fallback_on_stride: 109 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 110 | stride=stride, padding=1, bias=False) 111 | else: 112 | deformable_groups = dcn.get('deformable_groups', 1) 113 | if not self.with_modulated_dcn: 114 | from assets.ops.dcn import DeformConv 115 | conv_op = DeformConv 116 | offset_channels = 18 117 | else: 118 | from assets.ops.dcn import ModulatedDeformConv 119 | conv_op = ModulatedDeformConv 120 | offset_channels = 27 121 | self.conv2_offset = nn.Conv2d( 122 | planes, deformable_groups * offset_channels, 123 | kernel_size=3, 124 | padding=1) 125 | self.conv2 = conv_op( 126 | planes, planes, kernel_size=3, padding=1, stride=stride, 127 | deformable_groups=deformable_groups, bias=False) 128 | self.bn2 = BatchNorm2d(planes) 129 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 130 | self.bn3 = BatchNorm2d(planes * 4) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.downsample = downsample 133 | 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | 
out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | if not self.with_dcn: 143 | out = self.conv2(out) 144 | elif self.with_modulated_dcn: 145 | offset_mask = self.conv2_offset(out) 146 | offset = offset_mask[:, :18, :, :] 147 | mask = offset_mask[:, -9:, :, :].sigmoid() 148 | out = self.conv2(out, offset, mask) 149 | else: 150 | offset = self.conv2_offset(out) 151 | out = self.conv2(out, offset) 152 | 153 | out = self.bn2(out) 154 | out = self.relu(out) 155 | 156 | out = self.conv3(out) 157 | out = self.bn3(out) 158 | 159 | if self.downsample is not None: 160 | residual = self.downsample(x) 161 | 162 | out += residual 163 | out = self.relu(out) 164 | 165 | return out 166 | 167 | 168 | class ResNet(nn.Module): 169 | def __init__(self, block, layers, num_classes=1000, 170 | dcn=None, stage_with_dcn=(False, False, False, False)): 171 | super(ResNet, self).__init__() 172 | self.dcn = dcn 173 | self.stage_with_dcn = stage_with_dcn 174 | self.inplanes = 64 175 | 176 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 177 | bias=False) 178 | self.bn1 = BatchNorm2d(64) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 181 | self.layer1 = self._make_layer(block, 64, layers[0]) 182 | self.layer2 = self._make_layer( 183 | block, 128, layers[1], stride=2, dcn=dcn) 184 | self.layer3 = self._make_layer( 185 | block, 256, layers[2], stride=2, dcn=dcn) 186 | self.layer4 = self._make_layer( 187 | block, 512, layers[3], stride=2, dcn=dcn) 188 | self.avgpool = nn.AvgPool2d(7, stride=1) 189 | self.fc = nn.Linear(512 * block.expansion, num_classes) 190 | 191 | self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. / n)) 197 | elif isinstance(m, BatchNorm2d): 198 | m.weight.data.fill_(1) 199 | m.bias.data.zero_() 200 | if self.dcn is not None: 201 | for m in self.modules(): 202 | if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): 203 | if hasattr(m, 'conv2_offset'): 204 | constant_init(m.conv2_offset, 0) 205 | 206 | def _make_layer(self, block, planes, blocks, stride=1, dcn=None): 207 | downsample = None 208 | if stride != 1 or self.inplanes != planes * block.expansion: 209 | downsample = nn.Sequential( 210 | nn.Conv2d(self.inplanes, planes * block.expansion, 211 | kernel_size=1, stride=stride, bias=False), 212 | BatchNorm2d(planes * block.expansion), 213 | ) 214 | 215 | layers = [] 216 | layers.append(block(self.inplanes, planes, 217 | stride, downsample, dcn=dcn)) 218 | self.inplanes = planes * block.expansion 219 | for i in range(1, blocks): 220 | layers.append(block(self.inplanes, planes, dcn=dcn)) 221 | 222 | return nn.Sequential(*layers) 223 | 224 | def forward(self, x): 225 | x = self.conv1(x) 226 | x = self.bn1(x) 227 | x = self.relu(x) 228 | x = self.maxpool(x) 229 | 230 | x2 = self.layer1(x) 231 | x3 = self.layer2(x2) 232 | x4 = self.layer3(x3) 233 | x5 = self.layer4(x4) 234 | 235 | return x2, x3, x4, x5 236 | 237 | 238 | def resnet18(pretrained=True, **kwargs): 239 | """Constructs a ResNet-18 model. 
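    As with the other constructors in this file, forward() returns the four
    stage feature maps (x2, x3, x4, x5) rather than classification logits, and
    ImageNet weights are loaded with strict=False, so detector-specific layers
    such as `smooth` keep their default initialization.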
240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | """ 243 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 244 | if pretrained: 245 | model.load_state_dict(model_zoo.load_url( 246 | model_urls['resnet18']), strict=False) 247 | return model 248 | 249 | def deformable_resnet18(pretrained=True, **kwargs): 250 | """Constructs a ResNet-18 model. 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | """ 254 | model = ResNet(BasicBlock, [2, 2, 2, 2], 255 | dcn=dict(modulated=True, 256 | deformable_groups=1, 257 | fallback_on_stride=False), 258 | stage_with_dcn=[False, True, True, True], **kwargs) 259 | if pretrained: 260 | model.load_state_dict(model_zoo.load_url( 261 | model_urls['resnet18']), strict=False) 262 | return model 263 | 264 | 265 | def resnet34(pretrained=True, **kwargs): 266 | """Constructs a ResNet-34 model. 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | """ 270 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 271 | if pretrained: 272 | model.load_state_dict(model_zoo.load_url( 273 | model_urls['resnet34']), strict=False) 274 | return model 275 | 276 | 277 | def resnet50(pretrained=True, **kwargs): 278 | """Constructs a ResNet-50 model. 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | """ 282 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 283 | if pretrained: 284 | model.load_state_dict(model_zoo.load_url( 285 | model_urls['resnet50']), strict=False) 286 | return model 287 | 288 | 289 | def deformable_resnet50(pretrained=True, **kwargs): 290 | """Constructs a ResNet-50 model with deformable conv. 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | """ 294 | model = ResNet(Bottleneck, [3, 4, 6, 3], 295 | dcn=dict(modulated=True, 296 | deformable_groups=1, 297 | fallback_on_stride=False), 298 | stage_with_dcn=[False, True, True, True], 299 | **kwargs) 300 | if pretrained: 301 | model.load_state_dict(model_zoo.load_url( 302 | model_urls['resnet50']), strict=False) 303 | return model 304 | 305 | 306 | def resnet101(pretrained=True, **kwargs): 307 | """Constructs a ResNet-101 model. 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 312 | if pretrained: 313 | model.load_state_dict(model_zoo.load_url( 314 | model_urls['resnet101']), strict=False) 315 | return model 316 | 317 | 318 | def resnet152(pretrained=True, **kwargs): 319 | """Constructs a ResNet-152 model. 
320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | """ 323 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 324 | if pretrained: 325 | model.load_state_dict(model_zoo.load_url( 326 | model_urls['resnet152']), strict=False) 327 | return model 328 | 329 | 330 | 331 | 332 | -------------------------------------------------------------------------------- /demo_textboxPP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | from utils import log, utils 5 | from track_textboxPP import eval_seq 6 | import warnings 7 | 8 | def track(opt): 9 | 10 | videos = os.listdir(opt.input_root) 11 | videos = [v for v in videos if v.endswith(opt.suffix)] 12 | if opt.input_format == 'video': 13 | from dataset import LoadVideo as DataSet 14 | else: 15 | pass 16 | 17 | for video in videos: 18 | print(video) 19 | input_video = os.path.join(opt.input_root, video) 20 | video_name = video.split('.')[0] #Video_1_1_2 21 | result_root = os.path.join(opt.output_root, video_name) 22 | utils.mkdir_if_missing(result_root) 23 | opt.result_root = result_root 24 | dataloader = DataSet(input_video) 25 | frame_dir = None if opt.output_format=='txt' else osp.join(result_root, 'frames') 26 | logger.info('start tracking...') 27 | eval_seq(opt, dataloader, video_name, frame_dir=frame_dir) 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--suffix', type=str, default='mp4') 32 | parser.add_argument('--input-root', type=str, help='path to the input video') 33 | parser.add_argument('--input-format', type=str, default='video', help='expected input format, can be video, or image') 34 | parser.add_argument('--output-root', type=str, default='results', help='expected output root path') 35 | parser.add_argument('--output-format', type=str, default='video', help='expected output format, can be video, or text') 36 | parser.add_argument('--add-vot-track', action='store_false', help='whether use SCM Module') 37 | parser.add_argument('--show-gt', action='store_true') 38 | parser.add_argument('--gt-dir', type=str, default='') 39 | parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes') 40 | parser.add_argument('--dataset', type=str, default='icdar', help='icdar or minetoo') 41 | parser.add_argument('--sub-res', action='store_true') 42 | parser.add_argument('--sub-res-root', type=str, default='ourmodel', help='sub dir to save submit files') 43 | parser.add_argument('--conf-thresh', type=float, default=0.65, help='object confidence threshold') 44 | parser.add_argument('--weight-path', type=str, default='./db_model/db_embedding/weights/experiment9_cat_5*16_STL/db_embedding_weight_epoch100.pth', help='path to the model of DB_Embedding') 45 | parser.add_argument('--img-min-size', type=int, default=1280, help='the shorter side of input img') 46 | parser.add_argument('--scm-config', type=str, default='./scm/experiments/siammask_sharp/config.json', help='path to the config of scm') 47 | parser.add_argument('--scm-weight-path', type=str, default='./scm/experiments/siammask_sharp/snapshot/checkpoint_e19.pth', help='path to the model of scm') 48 | parser.add_argument('--eval_det', action='store_true') 49 | parser.add_argument('--save-det-out', action='store_true', help='whether to save det model output') 50 | opt = parser.parse_args() 51 | return opt 52 | 53 | if __name__ == '__main__': 54 | warnings.filterwarnings("ignore") 55 | opt = 
parse_args() 56 | global logger 57 | logger = log.init_log('root') 58 | log.add_file_handler('root', osp.join(opt.output_root, 'log.txt')) 59 | logger.info(opt) 60 | track(opt) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==21.4.0 2 | certifi==2020.6.20 3 | flake8==4.0.1 4 | flake8-import-order==0.18.1 5 | importlib-metadata==4.2.0 6 | iniconfig==1.1.1 7 | llvmlite==0.36.0 8 | mccabe==0.6.1 9 | mkl-fft==1.3.0 10 | mkl-random==1.1.1 11 | mkl-service==2.3.0 12 | motmetrics==1.2.0 13 | numba==0.53.1 14 | olefile==0.46 15 | opencv-python==4.5.5.62 16 | packaging==21.3 17 | pandas==1.1.5 18 | pluggy==1.0.0 19 | protobuf==3.19.3 20 | py==1.11.0 21 | py-cpuinfo==8.0.0 22 | pyclipper==1.3.0.post2 23 | pycodestyle==2.8.0 24 | pyflakes==2.4.0 25 | pyparsing==3.0.7 26 | pytest==7.0.1 27 | pytest-benchmark==3.4.1 28 | python-dateutil==2.8.2 29 | pytz==2021.3 30 | scipy==1.5.4 31 | Shapely==1.8.0 32 | tensorboardX==2.4.1 33 | tomli==1.2.3 34 | torch==1.5.0 35 | torchvision==0.6.0a0+82fd1c8 36 | xmltodict==0.12.0 37 | zipp==3.6.0 38 | -------------------------------------------------------------------------------- /scm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/scm/__init__.py -------------------------------------------------------------------------------- /scm/datasets/gen_json.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isfile 2 | import json 3 | import glob 4 | from sys import argv 5 | from collections import defaultdict 6 | 7 | def gen_json(data_dir): 8 | imgs = glob.glob(join(data_dir, 'Video_*/*.jpg')) 9 | n_imgs = len(imgs) 10 | dataset = defaultdict(dict) 11 | for i, img in enumerate(imgs): 12 | ann_file = img+'.txt' 13 | if not isfile(ann_file): 14 | continue 15 | imgid = int(img.split('/')[-1].split('.')[0]) 16 | crop_base_path = join('crop511', img.split('/')[-2]) 17 | 18 | with open(ann_file, 'r') as f: 19 | lines = f.readlines() 20 | assert len(lines) > 0 21 | for line in lines: 22 | line = line.strip() 23 | items = line.split(',') 24 | pts = list(map(float, items[:8])) 25 | x0, x1 = min(pts[::2]), max(pts[::2]) 26 | y0, y1 = min(pts[1::2]), max(pts[1::2]) 27 | bbox = list(map(int, [x0, y0, x1, y1])) 28 | track_id = int(items[8]) 29 | if '{:07d}'.format(track_id) not in dataset[crop_base_path]: 30 | dataset[crop_base_path]['{:07d}'.format(track_id)]={'{:04d}'.format(imgid): bbox} 31 | else: 32 | dataset[crop_base_path]['{:07d}'.format(track_id)].update({'{:04d}'.format(imgid): bbox}) 33 | print('image id: {:04d} / {:04d}'.format(i, n_imgs)) 34 | json.dump(dataset, open(join(data_dir, 'icdar2015.json'), 'w'), indent=4, sort_keys=True) 35 | print('done!') 36 | 37 | if __name__ == '__main__': 38 | gen_json(argv[1]) -------------------------------------------------------------------------------- /scm/datasets/par_crop.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from os.path import join, isdir, isfile 4 | from os import makedirs 5 | from concurrent import futures 6 | import glob 7 | import time 8 | import argparse 9 | from pathlib import Path 10 | import sys 11 | 12 | # Print iterations progress 13 | def printProgress(iteration, total, prefix='', 
suffix='', decimals=1, barLength=100): 14 | """ 15 | Call in a loop to create terminal progress bar 16 | @params: 17 | iteration - Required : current iteration (Int) 18 | total - Required : total iterations (Int) 19 | prefix - Optional : prefix string (Str) 20 | suffix - Optional : suffix string (Str) 21 | decimals - Optional : positive number of decimals in percent complete (Int) 22 | barLength - Optional : character length of bar (Int) 23 | """ 24 | formatStr = "{0:." + str(decimals) + "f}" 25 | percents = formatStr.format(100 * (iteration / float(total))) 26 | filledLength = int(round(barLength * iteration / float(total))) 27 | bar = '' * filledLength + '-' * (barLength - filledLength) 28 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 29 | if iteration == total: 30 | sys.stdout.write('\x1b[2K\r') 31 | sys.stdout.flush() 32 | 33 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 34 | a = (out_sz) / (bbox[2]-bbox[0]) 35 | b = (out_sz) / (bbox[3]-bbox[1]) 36 | c = -a * bbox[0] 37 | d = -b * bbox[1] 38 | mapping = np.array([[a, 0, c], 39 | [0, b, d]]).astype(np.float) 40 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), 41 | borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 42 | return crop 43 | 44 | def pos_s_2_bbox(pos, s): 45 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 46 | 47 | def crop_like_SiamFCx(image, bbox, exemplar_size=127, context_amount=0.5, search_size=255, padding=(0, 0, 0)): 48 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 49 | target_size = [bbox[2]-bbox[0]+1, bbox[3]-bbox[1]+1] 50 | wc_z = target_size[1] + context_amount * sum(target_size) 51 | hc_z = target_size[0] + context_amount * sum(target_size) 52 | s_z = np.sqrt(wc_z * hc_z) 53 | scale_z = exemplar_size / s_z 54 | d_search = (search_size - exemplar_size) / 2 55 | pad = d_search / scale_z 56 | s_x = s_z + 2 * pad 57 | 58 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), search_size, padding) 59 | return x 60 | 61 | def crop_img(img, set_crop_base_path,exemplar_size=127, context_amount=0.5, search_size=511, enable_mask=True): 62 | 63 | frame_crop_base_path = join(set_crop_base_path, img.split('/')[-2]) 64 | if not isdir(frame_crop_base_path): makedirs(frame_crop_base_path) 65 | # print(frame_crop_base_path) 66 | im = cv2.imread(img) 67 | avg_chans = np.mean(im, axis=(0, 1)) 68 | ann_file = img+'.txt' 69 | if not isfile(ann_file) or 'Video_46_6_4' not in ann_file: 70 | return 71 | imgid = int(img.split('/')[-1].split('.')[0]) 72 | with open(ann_file, 'r') as f: 73 | lines = f.readlines() 74 | for line in lines: 75 | line = line.strip() 76 | items = line.split(',') 77 | pts = list(map(float, items[:8])) 78 | x0, x1 = min(pts[::2]), max(pts[::2]) 79 | w = int(x1 - x0) 80 | y0, y1 = min(pts[1::2]), max(pts[1::2]) 81 | h = int(y1 - y0) 82 | if w * h <= 0: 83 | continue 84 | bbox = [x0, y0, x1, y1] 85 | track_id = int(items[8]) 86 | x = crop_like_SiamFCx(im, bbox, exemplar_size=exemplar_size, context_amount=context_amount, 87 | search_size=search_size, padding=avg_chans) 88 | cv2.imwrite(join(frame_crop_base_path, '{:04d}.{:07d}.x.jpg'.format(imgid, track_id)), x) 89 | if enable_mask: 90 | im_mask = np.zeros(im.shape) 91 | cv2.fillPoly(im_mask, [np.array(pts).reshape(-1,2).astype(np.int64)], (1,1,1)) 92 | 93 | x = (crop_like_SiamFCx(im_mask, bbox, exemplar_size=exemplar_size, context_amount=context_amount, 94 | search_size=search_size) > 0.5).astype(np.uint8) * 255 95 | cv2.imwrite(join(frame_crop_base_path, '{:04d}.{:07d}.m.png'.format(imgid, 
track_id)), x) 96 | 97 | def main(data_dir, exemplar_size=127, context_amount=0.5, search_size=511, enable_mask=True, num_threads=24): 98 | data_dir = Path(data_dir) 99 | crop_path = data_dir / 'crop{:d}'.format(search_size) 100 | if not isdir(crop_path): makedirs(crop_path) 101 | 102 | imgs = glob.glob(join(data_dir, 'Video_*/*.jpg')) 103 | n_imgs = len(imgs) 104 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 105 | fs = [executor.submit(crop_img, img,crop_path, exemplar_size, 106 | context_amount, search_size, enable_mask) for img in imgs] 107 | for i, f in enumerate(futures.as_completed(fs)): 108 | printProgress(i, n_imgs, prefix='icdar', suffix='Done ', barLength=40) 109 | print('done') 110 | 111 | if __name__ == '__main__': 112 | parser = argparse.ArgumentParser(description='ICDAR Parallel Preprocessing for SCM') 113 | parser.add_argument('--exemplar_size', type=int, default=127, help='size of exemplar') 114 | parser.add_argument('--context_amount', type=float, default=0.5, help='context amount') 115 | parser.add_argument('--search_size', type=int, default=511, help='size of cropped search region') 116 | parser.add_argument('--enable_mask', action='store_true', help='whether crop mask') 117 | parser.add_argument('--num_threads', type=int, default=16, help='number of threads') 118 | parser.add_argument('--data_dir', type=str, default='../../datasets/video_train', help='dir for data to preprocess') 119 | args = parser.parse_args() 120 | since = time.time() 121 | main(args.data_dir, args.exemplar_size, args.context_amount, \ 122 | args.search_size, args.enable_mask, args.num_threads) 123 | time_elapsed = time.time() - since 124 | print('Total complete in {:.0f}m {:.0f}s'.format( 125 | time_elapsed // 60, time_elapsed % 60)) 126 | -------------------------------------------------------------------------------- /scm/experiments/siammask_sharp/config_icdar.json: -------------------------------------------------------------------------------- 1 | { 2 | "network": { 3 | "arch": "Custom" 4 | }, 5 | "hp": { 6 | "instance_size": 255, 7 | "base_size": 8, 8 | "out_size": 127, 9 | "seg_thr": 0.35, 10 | "penalty_k": 0.04, 11 | "window_influence": 0.4, 12 | "lr": 1.0 13 | }, 14 | "lr": { 15 | "type": "log", 16 | "start_lr": 0.005, 17 | "end_lr": 0.00025 18 | }, 19 | "loss": { 20 | "weight": [0, 0, 36] 21 | }, 22 | "train_datasets": { 23 | "datasets": { 24 | "icdar": { 25 | "root": "./datasets/video_train", 26 | "anno": "./datasets/video_train/icdar2015.json", 27 | "frame_range": 20 28 | } 29 | }, 30 | "template_size": 127, 31 | "search_size": 143, 32 | "base_size": 0, 33 | "size": 3, 34 | 35 | "num" : 200000, 36 | 37 | "augmentation": { 38 | "template": { 39 | "shift": 4, "scale": 0.05 40 | }, 41 | "search": { 42 | "shift": 8, "scale": 0.18, "blur": 0.18 43 | }, 44 | "neg": 0, 45 | "gray": 0.25 46 | } 47 | }, 48 | "anchors": { 49 | "stride": 8, 50 | "ratios": [0.33, 0.5, 1, 2, 3], 51 | "scales": [8], 52 | "round_dight": 0 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /scm/experiments/siammask_sharp/custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scm.models.siammask_sharp import SiamMask 5 | from scm.models.features import MultiStageFeature 6 | from scm.models.rpn import RPN, DepthCorr 7 | from scm.models.mask import Mask 8 | from scm.utils.load_helper import load_pretrain 9 | from .resnet 
import resnet50 10 | 11 | class ResDownS(nn.Module): 12 | def __init__(self, inplane, outplane): 13 | super(ResDownS, self).__init__() 14 | self.downsample = nn.Sequential( 15 | nn.Conv2d(inplane, outplane, kernel_size=1, bias=False), 16 | nn.BatchNorm2d(outplane)) 17 | 18 | def forward(self, x): 19 | x = self.downsample(x) 20 | if x.size(3) < 20: 21 | l = 4 22 | r = -4 23 | x = x[:, :, l:r, l:r] 24 | return x 25 | 26 | class ResDown(MultiStageFeature): 27 | def __init__(self, pretrain=False): 28 | super(ResDown, self).__init__() 29 | self.features = resnet50(layer3=True, layer4=False) 30 | if pretrain: 31 | load_pretrain(self.features, 'resnet.model') 32 | self.downsample = ResDownS(1024, 256) 33 | self.layers = [self.downsample, self.features.layer2, self.features.layer3] 34 | self.train_nums = [1, 3] 35 | self.change_point = [0, 0.5] 36 | self.unfix(0.0) 37 | 38 | def param_groups(self, start_lr, feature_mult=1): 39 | lr = start_lr * feature_mult 40 | 41 | def _params(module, mult=1): 42 | params = list(filter(lambda x:x.requires_grad, module.parameters())) 43 | if len(params): 44 | return [{'params': params, 'lr': lr * mult}] 45 | else: 46 | return [] 47 | 48 | groups = [] 49 | groups += _params(self.downsample) 50 | groups += _params(self.features, 0.1) 51 | return groups 52 | 53 | def forward(self, x): 54 | output = self.features(x) 55 | p3 = self.downsample(output[-1]) 56 | return p3 57 | 58 | def forward_all(self, x): 59 | output = self.features(x) 60 | p3 = self.downsample(output[-1]) 61 | return output, p3 62 | 63 | class UP(RPN): 64 | def __init__(self, anchor_num=5, feature_in=256, feature_out=256): 65 | super(UP, self).__init__() 66 | self.anchor_num = anchor_num 67 | self.feature_in = feature_in 68 | self.feature_out = feature_out 69 | self.cls_output = 2 * self.anchor_num 70 | self.loc_output = 4 * self.anchor_num 71 | self.cls = DepthCorr(feature_in, feature_out, self.cls_output) 72 | self.loc = DepthCorr(feature_in, feature_out, self.loc_output) 73 | 74 | def forward(self, z_f, x_f): 75 | cls = self.cls(z_f, x_f) 76 | loc = self.loc(z_f, x_f) 77 | return cls, loc 78 | 79 | 80 | class MaskCorr(Mask): 81 | def __init__(self, oSz=63): 82 | super(MaskCorr, self).__init__() 83 | self.oSz = oSz 84 | self.mask = DepthCorr(256, 256, self.oSz**2) 85 | 86 | def forward(self, z, x): 87 | return self.mask(z, x) 88 | 89 | 90 | class Refine(nn.Module): 91 | def __init__(self): 92 | super(Refine, self).__init__() 93 | self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(), 94 | nn.Conv2d(16, 4, 3, padding=1),nn.ReLU()) 95 | 96 | self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(), 97 | nn.Conv2d(64, 16, 3, padding=1), nn.ReLU()) 98 | 99 | self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(), 100 | nn.Conv2d(128, 32, 3, padding=1), nn.ReLU()) 101 | 102 | self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(), 103 | nn.Conv2d(32, 32, 3, padding=1), nn.ReLU()) 104 | 105 | self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(), 106 | nn.Conv2d(16, 16, 3, padding=1), nn.ReLU()) 107 | 108 | self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(), 109 | nn.Conv2d(4, 4, 3, padding=1), nn.ReLU()) 110 | 111 | self.deconv = nn.ConvTranspose2d(256, 32, 15, 15) 112 | self.post0 = nn.Conv2d(32, 16, 3, padding=1) 113 | self.post1 = nn.Conv2d(16, 4, 3, padding=1) 114 | self.post2 = nn.Conv2d(4, 1, 3, padding=1) 115 | 116 | for modules in [self.v0, self.v1, self.v2, self.h2, self.h1, self.h0, self.deconv, self.post0, self.post1, 
self.post2,]: 117 | for l in modules.modules(): 118 | if isinstance(l, nn.Conv2d): 119 | nn.init.kaiming_uniform_(l.weight, a=1) 120 | 121 | def forward(self, f, corr_feature, pos=None, test=False): 122 | if test: 123 | p0 = torch.nn.functional.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61] 124 | p1 = torch.nn.functional.pad(f[1], [8, 8, 8, 8])[:, :, 2 * pos[0]:2 * pos[0] + 31, 2 * pos[1]:2 * pos[1] + 31] 125 | p2 = torch.nn.functional.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0] + 15, pos[1]:pos[1] + 15] 126 | else: 127 | p0 = F.unfold(f[0], (61, 61), padding=0, stride=4).permute(0, 2, 1).contiguous().view(-1, 64, 61, 61) 128 | if not (pos is None): p0 = torch.index_select(p0, 0, pos) 129 | p1 = F.unfold(f[1], (31, 31), padding=0, stride=2).permute(0, 2, 1).contiguous().view(-1, 256, 31, 31) 130 | if not (pos is None): p1 = torch.index_select(p1, 0, pos) 131 | p2 = F.unfold(f[2], (15, 15), padding=0, stride=1).permute(0, 2, 1).contiguous().view(-1, 512, 15, 15) 132 | if not (pos is None): p2 = torch.index_select(p2, 0, pos) 133 | 134 | if not(pos is None): 135 | p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1) 136 | else: 137 | p3 = corr_feature.permute(0, 2, 3, 1).contiguous().view(-1, 256, 1, 1) 138 | 139 | out = self.deconv(p3) # B * 32 * 15 * 15 140 | out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31))) 141 | out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61))) 142 | out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127))) 143 | out = out.view(-1, 127*127) 144 | return out 145 | 146 | def param_groups(self, start_lr, feature_mult=1): 147 | params = filter(lambda x:x.requires_grad, self.parameters()) 148 | params = [{'params': params, 'lr': start_lr * feature_mult}] 149 | return params 150 | 151 | class Custom(SiamMask): 152 | def __init__(self, pretrain=False, **kwargs): 153 | super(Custom, self).__init__(**kwargs) 154 | self.features = ResDown(pretrain=pretrain) 155 | self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256) 156 | self.mask_model = MaskCorr() 157 | self.refine_model = Refine() 158 | 159 | def refine(self, f, pos=None): 160 | return self.refine_model(f, pos) 161 | 162 | def template(self, template): 163 | self.zf = self.features(template) 164 | 165 | def track(self, search): 166 | search = self.features(search) 167 | rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, search) 168 | return rpn_pred_cls, rpn_pred_loc 169 | 170 | def track_mask(self, search): 171 | self.feature, self.search = self.features.forward_all(search) # feature:125\63\31\31 172 | rpn_pred_cls, rpn_pred_loc = self.rpn(self.zf, self.search) 173 | self.corr_feature = self.mask_model.mask.forward_corr(self.zf, self.search) 174 | pred_mask = self.mask_model.mask.head(self.corr_feature) 175 | return rpn_pred_cls, rpn_pred_loc, pred_mask 176 | 177 | def track_refine(self, pos): 178 | pred_mask = self.refine_model(self.feature, self.corr_feature, pos=pos, test=True) 179 | return pred_mask 180 | -------------------------------------------------------------------------------- /scm/experiments/siammask_sharp/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.autograd import Variable 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | from scm.models.features import Features 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 
model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | return out 52 | 53 | 54 | class Bottleneck(Features): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | padding = 2 - stride 62 | assert stride==1 or dilation==1, "stride and dilation must have one equals to zero at least" 63 | if dilation > 1: 64 | padding = dilation 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=padding, bias=False, dilation=dilation) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | if out.size() != residual.size(): 92 | print(out.size(), residual.size()) 93 | out += residual 94 | 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | 101 | class Bottleneck_nop(nn.Module): 102 | expansion = 4 103 | 104 | def __init__(self, inplanes, planes, stride=1, downsample=None): 105 | super(Bottleneck_nop, self).__init__() 106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(planes) 108 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 109 | padding=0, bias=False) 110 | self.bn2 = nn.BatchNorm2d(planes) 111 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(planes * 4) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.downsample = downsample 115 | self.stride = stride 116 | 117 | def forward(self, x): 118 | residual = x 119 | 120 | out = self.conv1(x) 121 | out = self.bn1(out) 122 | out = 
self.relu(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv3(out) 129 | out = self.bn3(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | s = residual.size(3) 135 | residual = residual[:, :, 1:s-1, 1:s-1] 136 | 137 | out += residual 138 | out = self.relu(out) 139 | 140 | return out 141 | 142 | 143 | class ResNet(nn.Module): 144 | 145 | def __init__(self, block, layers, layer4=False, layer3=False): 146 | self.inplanes = 64 147 | super(ResNet, self).__init__() 148 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, # 3 149 | bias=False) 150 | self.bn1 = nn.BatchNorm2d(64) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) # 31x31, 15x15 155 | 156 | self.feature_size = 128 * block.expansion 157 | 158 | if layer3: 159 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) # 31x31, 15x15 160 | self.feature_size = (256 + 128) * block.expansion 161 | else: 162 | self.layer3 = lambda x:x # identity 163 | 164 | if layer4: 165 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 166 | self.feature_size = 512 * block.expansion 167 | else: 168 | self.layer4 = lambda x:x # identity 169 | 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 173 | m.weight.data.normal_(0, math.sqrt(2. / n)) 174 | elif isinstance(m, nn.BatchNorm2d): 175 | m.weight.data.fill_(1) 176 | m.bias.data.zero_() 177 | 178 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 179 | downsample = None 180 | dd = dilation 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | if stride == 1 and dilation == 1: 183 | downsample = nn.Sequential( 184 | nn.Conv2d(self.inplanes, planes * block.expansion, 185 | kernel_size=1, stride=stride, bias=False), 186 | nn.BatchNorm2d(planes * block.expansion), 187 | ) 188 | else: 189 | if dilation > 1: 190 | dd = dilation // 2 191 | padding = dd 192 | else: 193 | dd = 1 194 | padding = 0 195 | downsample = nn.Sequential( 196 | nn.Conv2d(self.inplanes, planes * block.expansion, 197 | kernel_size=3, stride=stride, bias=False, 198 | padding=padding, dilation=dd), 199 | nn.BatchNorm2d(planes * block.expansion), 200 | ) 201 | 202 | layers = [] 203 | layers.append(block(self.inplanes, planes, stride, downsample, dilation=dd)) 204 | self.inplanes = planes * block.expansion 205 | for i in range(1, blocks): 206 | layers.append(block(self.inplanes, planes, dilation=dilation)) 207 | 208 | return nn.Sequential(*layers) 209 | 210 | def forward(self, x): 211 | x = self.conv1(x) 212 | x = self.bn1(x) 213 | p0 = self.relu(x) 214 | x = self.maxpool(p0) 215 | 216 | p1 = self.layer1(x) 217 | p2 = self.layer2(p1) 218 | p3 = self.layer3(p2) 219 | 220 | return p0, p1, p2, p3 221 | 222 | class ResAdjust(nn.Module): 223 | def __init__(self, 224 | block=Bottleneck, 225 | out_channels=256, 226 | adjust_number=1, 227 | fuse_layers=[2,3,4]): 228 | super(ResAdjust, self).__init__() 229 | self.fuse_layers = set(fuse_layers) 230 | 231 | if 2 in self.fuse_layers: 232 | self.layer2 = self._make_layer(block, 128, 1, out_channels, adjust_number) 233 | if 3 in self.fuse_layers: 234 | self.layer3 = self._make_layer(block, 256, 2, out_channels, adjust_number) 235 | if 4 
in self.fuse_layers: 236 | self.layer4 = self._make_layer(block, 512, 4, out_channels, adjust_number) 237 | 238 | self.feature_size = out_channels * len(self.fuse_layers) 239 | 240 | def _make_layer(self, block, plances, dilation, out, number=1): 241 | 242 | layers = [] 243 | 244 | for _ in range(number): 245 | layer = block(plances * block.expansion, plances, dilation=dilation) 246 | layers.append(layer) 247 | 248 | downsample = nn.Sequential( 249 | nn.Conv2d(plances * block.expansion, out, kernel_size=3, padding=1, bias=False), 250 | nn.BatchNorm2d(out) 251 | ) 252 | layers.append(downsample) 253 | 254 | return nn.Sequential(*layers) 255 | 256 | def forward(self, p2, p3, p4): 257 | outputs = [] 258 | 259 | if 2 in self.fuse_layers: 260 | outputs.append(self.layer2(p2)) 261 | if 3 in self.fuse_layers: 262 | outputs.append(self.layer3(p3)) 263 | if 4 in self.fuse_layers: 264 | outputs.append(self.layer4(p4)) 265 | return outputs 266 | 267 | def resnet18(pretrained=False, **kwargs): 268 | """Constructs a ResNet-18 model. 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 274 | if pretrained: 275 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 276 | return model 277 | 278 | def resnet34(pretrained=False, **kwargs): 279 | """Constructs a ResNet-34 model. 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | """ 284 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 285 | if pretrained: 286 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 287 | return model 288 | 289 | def resnet50(pretrained=False, **kwargs): 290 | """Constructs a ResNet-50 model. 291 | 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | """ 295 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 296 | if pretrained: 297 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 298 | return model 299 | 300 | def resnet101(pretrained=False, **kwargs): 301 | """Constructs a ResNet-101 model. 302 | 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | """ 306 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 307 | if pretrained: 308 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 309 | return model 310 | 311 | def resnet152(pretrained=False, **kwargs): 312 | """Constructs a ResNet-152 model. 
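    Like the other constructors in this file, the returned network exposes
    forward() -> (p0, p1, p2, p3) multi-stage features; the SCM backbone in
    custom.py builds on resnet50(layer3=True, layer4=False) from this module.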
313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | """ 317 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 318 | if pretrained: 319 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 320 | return model 321 | 322 | 323 | -------------------------------------------------------------------------------- /scm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/scm/models/__init__.py -------------------------------------------------------------------------------- /scm/models/features.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class Features(nn.Module): 5 | def __init__(self): 6 | super(Features, self).__init__() 7 | self.feature_size = -1 8 | 9 | def forward(self, x): 10 | raise NotImplementedError 11 | 12 | def param_groups(self, start_lr, feature_mult=1): 13 | params = filter(lambda x:x.requires_grad, self.parameters()) 14 | params = [{'params': params, 'lr': start_lr * feature_mult}] 15 | return params 16 | 17 | def load_model(self, f='pretrain.model'): 18 | with open(f) as f: 19 | pretrained_dict = torch.load(f) 20 | model_dict = self.state_dict() 21 | print(pretrained_dict.keys()) 22 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 23 | print(pretrained_dict.keys()) 24 | model_dict.update(pretrained_dict) 25 | self.load_state_dict(model_dict) 26 | 27 | class MultiStageFeature(Features): 28 | def __init__(self): 29 | super(MultiStageFeature, self).__init__() 30 | 31 | self.layers = [] 32 | self.train_num = -1 33 | self.change_point = [] 34 | self.train_nums = [] 35 | 36 | def unfix(self, ratio=0.0): 37 | if self.train_num == -1: 38 | self.train_num = 0 39 | self.unlock() 40 | self.eval() 41 | for p, t in reversed(list(zip(self.change_point, self.train_nums))): 42 | if ratio >= p: 43 | if self.train_num != t: 44 | self.train_num = t 45 | self.unlock() 46 | return True 47 | break 48 | return False 49 | 50 | def train_layers(self): 51 | return self.layers[:self.train_num] 52 | 53 | def unlock(self): 54 | for p in self.parameters(): 55 | p.requires_grad = False 56 | 57 | for m in self.train_layers(): 58 | for p in m.parameters(): 59 | p.requires_grad = True 60 | 61 | def train(self, mode): 62 | self.training = mode 63 | if mode == False: 64 | super(MultiStageFeature,self).train(False) 65 | else: 66 | for m in self.train_layers(): 67 | m.train(True) 68 | 69 | return self 70 | -------------------------------------------------------------------------------- /scm/models/mask.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Mask(nn.Module): 4 | def __init__(self): 5 | super(Mask, self).__init__() 6 | 7 | def forward(self, z_f, x_f): 8 | raise NotImplementedError 9 | 10 | def template(self, template): 11 | raise NotImplementedError 12 | 13 | def track(self, search): 14 | raise NotImplementedError 15 | 16 | def param_groups(self, start_lr, feature_mult=1): 17 | params = filter(lambda x:x.requires_grad, self.parameters()) 18 | params = [{'params': params, 'lr': start_lr * feature_mult}] 19 | return params 20 | -------------------------------------------------------------------------------- /scm/models/rpn.py: -------------------------------------------------------------------------------- 1 
| import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class RPN(nn.Module): 5 | def __init__(self): 6 | super(RPN, self).__init__() 7 | 8 | def forward(self, z_f, x_f): 9 | raise NotImplementedError 10 | 11 | def template(self, template): 12 | raise NotImplementedError 13 | 14 | def track(self, search): 15 | raise NotImplementedError 16 | 17 | def param_groups(self, start_lr, feature_mult=1, key=None): 18 | if key is None: 19 | params = filter(lambda x:x.requires_grad, self.parameters()) 20 | else: 21 | params = [v for k, v in self.named_parameters() if (key in k) and v.requires_grad] 22 | params = [{'params': params, 'lr': start_lr * feature_mult}] 23 | return params 24 | 25 | def conv2d_dw_group(x, kernel): 26 | batch, channel = kernel.shape[:2] 27 | x = x.view(1, batch*channel, x.size(2), x.size(3)) # 1 * (b*c) * k * k 28 | kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3)) # (b*c) * 1 * H * W 29 | out = F.conv2d(x, kernel, groups=batch*channel) 30 | out = out.view(batch, channel, out.size(2), out.size(3)) 31 | return out 32 | 33 | class DepthCorr(nn.Module): 34 | def __init__(self, in_channels, hidden, out_channels, kernel_size=3): 35 | super(DepthCorr, self).__init__() 36 | # adjust layer for asymmetrical features 37 | self.conv_kernel = nn.Sequential( 38 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 39 | nn.BatchNorm2d(hidden), 40 | nn.ReLU(inplace=True), 41 | ) 42 | self.conv_search = nn.Sequential( 43 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 44 | nn.BatchNorm2d(hidden), 45 | nn.ReLU(inplace=True), 46 | ) 47 | self.head = nn.Sequential( 48 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False), 49 | nn.BatchNorm2d(hidden), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(hidden, out_channels, kernel_size=1) 52 | ) 53 | 54 | def forward_corr(self, kernel, input): 55 | kernel = self.conv_kernel(kernel) #5*5 56 | input = self.conv_search(input) #29*29 57 | feature = conv2d_dw_group(input, kernel) # B * C * 25 * 25 58 | return feature 59 | 60 | def forward(self, kernel, search): 61 | feature = self.forward_corr(kernel, search) 62 | out = self.head(feature) 63 | return out # B * out_c *25 * 25 64 | -------------------------------------------------------------------------------- /scm/models/siammask_sharp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from scm.utils.anchors import Anchors 6 | 7 | class SiamMask(nn.Module): 8 | def __init__(self, anchors=None, o_sz=127, g_sz=127): 9 | super(SiamMask, self).__init__() 10 | self.anchors = anchors # anchor_cfg 11 | self.anchor_num = len(self.anchors["ratios"]) * len(self.anchors["scales"]) 12 | self.anchor = Anchors(anchors) 13 | self.features = None 14 | self.rpn_model = None 15 | self.mask_model = None 16 | self.o_sz = o_sz 17 | self.g_sz = g_sz 18 | self.upSample = nn.UpsamplingBilinear2d(size=[g_sz, g_sz]) 19 | self.all_anchors = None 20 | 21 | def set_all_anchors(self, image_center, size): 22 | # cx,cy,w,h 23 | if not self.anchor.generate_all_anchors(image_center, size): 24 | return 25 | all_anchors = self.anchor.all_anchors[1] # cx, cy, w, h 26 | self.all_anchors = torch.from_numpy(all_anchors).float().cuda() 27 | self.all_anchors = [self.all_anchors[i] for i in range(4)] 28 | 29 | def feature_extractor(self, x): 30 | return self.features(x) 31 | 32 | def rpn(self, template, search): 33 | 
pred_cls, pred_loc = self.rpn_model(template, search) 34 | return pred_cls, pred_loc 35 | 36 | def mask(self, template, search): 37 | pred_mask = self.mask_model(template, search) 38 | return pred_mask 39 | 40 | def _add_rpn_loss(self, label_cls, label_loc, lable_loc_weight, label_mask, label_mask_weight, 41 | rpn_pred_cls, rpn_pred_loc, rpn_pred_mask): 42 | rpn_loss_cls = select_cross_entropy_loss(rpn_pred_cls, label_cls) 43 | 44 | rpn_loss_loc = weight_l1_loss(rpn_pred_loc, label_loc, lable_loc_weight) 45 | 46 | rpn_loss_mask, iou_m, iou_5, iou_7 = select_mask_logistic_loss(rpn_pred_mask, label_mask, label_mask_weight) 47 | 48 | return rpn_loss_cls, rpn_loss_loc, rpn_loss_mask, iou_m, iou_5, iou_7 49 | 50 | def run(self, template, search, softmax=False): 51 | template_feature = self.feature_extractor(template) #B * 256 * 7 * 7 52 | feature, search_feature = self.features.forward_all(search)#69/35/17/17 53 | rpn_pred_cls, rpn_pred_loc = self.rpn(template_feature, search_feature) 54 | corr_feature = self.mask_model.mask.forward_corr(template_feature, search_feature) # (b, 256, w, h) 55 | rpn_pred_mask = self.refine_model(feature, corr_feature) 56 | 57 | if softmax: 58 | rpn_pred_cls = self.softmax(rpn_pred_cls) 59 | return rpn_pred_cls, rpn_pred_loc, rpn_pred_mask, template_feature, search_feature 60 | 61 | def softmax(self, cls): 62 | b, a2, h, w = cls.size() 63 | cls = cls.view(b, 2, a2//2, h, w) 64 | cls = cls.permute(0, 2, 3, 4, 1).contiguous() 65 | cls = F.log_softmax(cls, dim=4) 66 | return cls 67 | 68 | def forward(self, input): 69 | """ 70 | :param input: dict of input with keys of: 71 | 'template': [b, 3, h1, w1], input template image. 72 | 'search': [b, 3, h2, w2], input search image. 73 | 'label_cls':[b, max_num_gts, 5] or None(self.training==False), 74 | each gt contains x1,y1,x2,y2,class. 75 | :return: dict of loss, predict, accuracy 76 | """ 77 | template = input['template'] 78 | search = input['search'] 79 | if self.training: 80 | label_cls = input['label_cls'] 81 | label_loc = input['label_loc'] 82 | lable_loc_weight = input['label_loc_weight'] 83 | label_mask = input['label_mask'] 84 | label_mask_weight = input['label_mask_weight'] 85 | 86 | rpn_pred_cls, rpn_pred_loc, rpn_pred_mask, template_feature, search_feature = \ 87 | self.run(template, search, softmax=self.training) 88 | 89 | outputs = dict() 90 | outputs['predict'] = [rpn_pred_loc, rpn_pred_cls, rpn_pred_mask, template_feature, search_feature] 91 | if self.training: 92 | rpn_loss_cls, rpn_loss_loc, rpn_loss_mask, iou_acc_mean, iou_acc_5, iou_acc_7 = \ 93 | self._add_rpn_loss(label_cls, label_loc, lable_loc_weight, label_mask, label_mask_weight, 94 | rpn_pred_cls, rpn_pred_loc, rpn_pred_mask) 95 | outputs['losses'] = [rpn_loss_cls, rpn_loss_loc, rpn_loss_mask] 96 | outputs['accuracy'] = [iou_acc_mean, iou_acc_5, iou_acc_7] 97 | 98 | return outputs 99 | 100 | def template(self, z): 101 | self.zf = self.feature_extractor(z) 102 | cls_kernel, loc_kernel = self.rpn_model.template(self.zf) 103 | return cls_kernel, loc_kernel 104 | 105 | def track(self, x, cls_kernel=None, loc_kernel=None, softmax=False): 106 | xf = self.feature_extractor(x) 107 | rpn_pred_cls, rpn_pred_loc = self.rpn_model.track(xf, cls_kernel, loc_kernel) 108 | if softmax: 109 | rpn_pred_cls = self.softmax(rpn_pred_cls) 110 | return rpn_pred_cls, rpn_pred_loc 111 | 112 | def get_cls_loss(pred, label, select): 113 | if select.nelement() == 0: return pred.sum()*0. 
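    # `select` holds the flattened anchor indices (positives or negatives) chosen in
    # select_cross_entropy_loss below; only those entries contribute to the NLL loss.
    # The empty-selection branch above returns pred.sum()*0. so the result stays
    # attached to the autograd graph instead of being a plain constant zero.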
114 | pred = torch.index_select(pred, 0, select) 115 | label = torch.index_select(label, 0, select) 116 | return F.nll_loss(pred, label) 117 | 118 | def select_cross_entropy_loss(pred, label): 119 | pred = pred.view(-1, 2) 120 | label = label.view(-1) 121 | pos = Variable(label.data.eq(1).nonzero().squeeze()).cuda() 122 | neg = Variable(label.data.eq(0).nonzero().squeeze()).cuda() 123 | 124 | loss_pos = get_cls_loss(pred, label, pos) 125 | loss_neg = get_cls_loss(pred, label, neg) 126 | return loss_pos * 0.5 + loss_neg * 0.5 127 | 128 | def weight_l1_loss(pred_loc, label_loc, loss_weight): 129 | """ 130 | :param pred_loc: [b, 4k, h, w] 131 | :param label_loc: [b, 4k, h, w] 132 | :param loss_weight: [b, k, h, w] 133 | :return: loc loss value 134 | """ 135 | b, _, sh, sw = pred_loc.size() 136 | pred_loc = pred_loc.view(b, 4, -1, sh, sw) 137 | diff = (pred_loc - label_loc).abs() 138 | diff = diff.sum(dim=1).view(b, -1, sh, sw) 139 | loss = diff * loss_weight 140 | return loss.sum().div(b) 141 | 142 | def select_mask_logistic_loss(p_m, mask, weight, o_sz=63, g_sz=127): 143 | weight = weight.view(-1) 144 | pos = Variable(weight.data.eq(1).nonzero().squeeze()) 145 | if pos.nelement() == 0: return p_m.sum() * 0, p_m.sum() * 0, p_m.sum() * 0, p_m.sum() * 0 146 | 147 | if len(p_m.shape) == 4: 148 | p_m = p_m.permute(0, 2, 3, 1).contiguous().view(-1, 1, o_sz, o_sz) 149 | p_m = torch.index_select(p_m, 0, pos) 150 | p_m = nn.UpsamplingBilinear2d(size=[g_sz, g_sz])(p_m) 151 | p_m = p_m.view(-1, g_sz * g_sz) 152 | else: 153 | p_m = torch.index_select(p_m, 0, pos) 154 | 155 | mask_uf = F.unfold(mask, (g_sz, g_sz), padding=0, stride=8) 156 | mask_uf = torch.transpose(mask_uf, 1, 2).contiguous().view(-1, g_sz * g_sz) 157 | mask_uf = torch.index_select(mask_uf, 0, pos) 158 | loss = F.soft_margin_loss(p_m, mask_uf, reduction='none') 159 | iou_m, iou_5, iou_7 = iou_measure(p_m, mask_uf) 160 | return loss, iou_m, iou_5, iou_7 161 | 162 | def iou_measure(pred, label): 163 | pred = pred.ge(0).int() 164 | mask_sum = pred.eq(1).int().add(label.eq(1).int()) 165 | intxn = torch.sum(mask_sum == 2, dim=1).float() 166 | union = torch.sum(mask_sum > 0, dim=1).float() 167 | iou = intxn/(union+1e-6) 168 | return torch.mean(iou), (torch.sum(iou > 0.5).float()/iou.shape[0]), (torch.sum(iou > 0.7).float()/iou.shape[0]) 169 | 170 | -------------------------------------------------------------------------------- /scm/tools/track2mask.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from scm.utils.anchors import Anchors 7 | from scm.utils.bbox_helper import cxy_wh_2_rect 8 | from scm.utils.tracker_config import TrackerConfig 9 | 10 | def to_torch(ndarray): 11 | if type(ndarray).__module__ == 'numpy': 12 | return torch.from_numpy(ndarray) 13 | elif not torch.is_tensor(ndarray): 14 | raise ValueError("Cannot convert {} to torch tensor" 15 | .format(type(ndarray))) 16 | return ndarray 17 | 18 | def im_to_torch(img): 19 | img = np.transpose(img, (2, 0, 1)) # C*H*W 20 | img = to_torch(img).float() 21 | return img 22 | 23 | def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch'): 24 | if isinstance(pos, float): 25 | pos = [pos, pos] 26 | sz = original_sz 27 | im_sz = im.shape 28 | c = (original_sz + 1) / 2 29 | context_xmin = round(pos[0] - c) 30 | context_xmax = context_xmin + sz - 1 31 | context_ymin = round(pos[1] - c) 32 | 
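# --- Illustrative sketch: iou_measure (defined earlier in siammask_sharp.py)
# thresholds the raw mask logits at 0 and computes a per-sample IoU against the
# binary ground truth. The toy 1x4 "masks" below are assumptions. ---
import torch

pred = torch.tensor([[1.0, -1.0, 2.0, -3.0]])      # one flattened mask of logits
label = torch.tensor([[1.0, 1.0, 1.0, -1.0]])      # ground truth in {-1, 1}
p = pred.ge(0).int()
both = p.eq(1).int() + label.eq(1).int()
iou = (both == 2).sum(dim=1).float() / ((both > 0).sum(dim=1).float() + 1e-6)
# intersection = 2, union = 3  ->  IoU ~= 0.667
# --- end sketch ---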
context_ymax = context_ymin + sz - 1 33 | left_pad = int(max(0., -context_xmin)) 34 | top_pad = int(max(0., -context_ymin)) 35 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 36 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 37 | 38 | context_xmin = context_xmin + left_pad 39 | context_xmax = context_xmax + left_pad 40 | context_ymin = context_ymin + top_pad 41 | context_ymax = context_ymax + top_pad 42 | 43 | r, c, k = im.shape 44 | if any([top_pad, bottom_pad, left_pad, right_pad]): 45 | te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8) 46 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 47 | if top_pad: 48 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 49 | if bottom_pad: 50 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 51 | if left_pad: 52 | te_im[:, 0:left_pad, :] = avg_chans 53 | if right_pad: 54 | te_im[:, c + left_pad:, :] = avg_chans 55 | im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 56 | else: 57 | im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 58 | 59 | if not np.array_equal(model_sz, original_sz): 60 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) 61 | else: 62 | im_patch = im_patch_original 63 | return im_to_torch(im_patch) if out_mode in 'torch' else im_patch 64 | 65 | def generate_anchor(cfg, score_size): 66 | anchors = Anchors(cfg) 67 | anchor = anchors.anchors 68 | x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3] 69 | anchor = np.stack([(x1+x2)*0.5, (y1+y2)*0.5, x2-x1, y2-y1], 1) 70 | total_stride = anchors.stride 71 | anchor_num = anchor.shape[0] 72 | anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4)) 73 | ori = - (score_size // 2) * total_stride 74 | xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)], 75 | [ori + total_stride * dy for dy in range(score_size)]) 76 | xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \ 77 | np.tile(yy.flatten(), (anchor_num, 1)).flatten() 78 | anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32) 79 | return anchor 80 | 81 | def siamese_init(im, target_pos, target_sz, model, hp=None, device='cpu'): 82 | state = dict() 83 | state['im_h'] = im.shape[0] 84 | state['im_w'] = im.shape[1] 85 | p = TrackerConfig() 86 | p.update(hp, model.anchors) 87 | net = model 88 | p.anchor = generate_anchor(model.anchors, p.score_size) 89 | avg_chans = np.mean(im, axis=(0, 1)) 90 | 91 | wc_z = target_sz[0] + p.context_amount * sum(target_sz) 92 | hc_z = target_sz[1] + p.context_amount * sum(target_sz) 93 | s_z = round(np.sqrt(wc_z * hc_z)) 94 | 95 | # initialize the exemplar 96 | z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans) 97 | z = Variable(z_crop.unsqueeze(0)) 98 | net.template(z.to(device)) 99 | 100 | if p.windowing == 'cosine': 101 | window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size)) 102 | elif p.windowing == 'uniform': 103 | window = np.ones((p.score_size, p.score_size)) 104 | window = np.tile(window.flatten(), p.anchor_num) 105 | 106 | state['p'] = p 107 | state['net'] = net 108 | state['avg_chans'] = avg_chans 109 | state['window'] = window 110 | state['target_pos'] = target_pos 111 | state['target_sz'] = target_sz 112 | return state 113 | 114 | def siamese_track(state, im, mask_enable=False, refine_enable=False, device='cpu', debug=False): 115 | p = state['p'] 116 | net = 
state['net'] 117 | avg_chans = state['avg_chans'] 118 | window = state['window'] 119 | target_pos = state['target_pos'] 120 | target_sz = state['target_sz'] 121 | 122 | wc_x = target_sz[1] + p.context_amount * sum(target_sz) 123 | hc_x = target_sz[0] + p.context_amount * sum(target_sz) 124 | s_x = np.sqrt(wc_x * hc_x) 125 | scale_x = p.exemplar_size / s_x 126 | d_search = (p.instance_size - p.exemplar_size) / 2 127 | pad = d_search / scale_x 128 | s_x = s_x + 2 * pad 129 | crop_box = [target_pos[0] - round(s_x) / 2, target_pos[1] - round(s_x) / 2, round(s_x), round(s_x)] 130 | 131 | if debug: 132 | im_debug = im.copy() 133 | crop_box_int = np.int0(crop_box) 134 | cv2.rectangle(im_debug, (crop_box_int[0], crop_box_int[1]), 135 | (crop_box_int[0] + crop_box_int[2], crop_box_int[1] + crop_box_int[3]), (255, 0, 0), 2) 136 | cv2.imshow('search area', im_debug) 137 | cv2.waitKey(0) 138 | 139 | # extract scaled crops for search region x at previous target position 140 | x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0)) 141 | 142 | if mask_enable: 143 | score, delta, mask = net.track_mask(x_crop.to(device)) 144 | else: 145 | score, delta = net.track(x_crop.to(device)) 146 | 147 | delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy() 148 | score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1).permute(1, 0), dim=1).data[:, 149 | 1].cpu().numpy() 150 | delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0] 151 | delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1] 152 | delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2] 153 | delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3] 154 | 155 | def change(r): 156 | return np.maximum(r, 1. / r) 157 | 158 | def sz(w, h): 159 | pad = (w + h) * 0.5 160 | sz2 = (w + pad) * (h + pad) 161 | return np.sqrt(sz2) 162 | 163 | def sz_wh(wh): 164 | pad = (wh[0] + wh[1]) * 0.5 165 | sz2 = (wh[0] + pad) * (wh[1] + pad) 166 | return np.sqrt(sz2) 167 | 168 | # size penalty 169 | target_sz_in_crop = target_sz*scale_x 170 | s_c = change(sz(delta[2, :], delta[3, :]) / (sz_wh(target_sz_in_crop))) # scale penalty 171 | r_c = change((target_sz_in_crop[0] / target_sz_in_crop[1]) / (delta[2, :] / delta[3, :])) # ratio penalty 172 | 173 | penalty = np.exp(-(r_c * s_c - 1) * p.penalty_k) 174 | pscore = penalty * score 175 | 176 | # cos window (motion model) 177 | pscore = pscore * (1 - p.window_influence) + window * p.window_influence 178 | best_pscore_id = np.argmax(pscore) 179 | 180 | pred_in_crop = delta[:, best_pscore_id] / scale_x 181 | lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr 182 | 183 | res_x = pred_in_crop[0] + target_pos[0] 184 | res_y = pred_in_crop[1] + target_pos[1] 185 | 186 | res_w = target_sz[0] * (1 - lr) + pred_in_crop[2] * lr 187 | res_h = target_sz[1] * (1 - lr) + pred_in_crop[3] * lr 188 | 189 | target_pos = np.array([res_x, res_y]) 190 | target_sz = np.array([res_w, res_h]) 191 | 192 | # for Mask Branch 193 | if mask_enable: 194 | best_pscore_id_mask = np.unravel_index(best_pscore_id, (5, p.score_size, p.score_size)) 195 | delta_x, delta_y = best_pscore_id_mask[2], best_pscore_id_mask[1] 196 | 197 | if refine_enable: 198 | mask = net.track_refine((delta_y, delta_x)).to(device).sigmoid().squeeze().view( 199 | p.out_size, p.out_size).cpu().data.numpy() 200 | else: 201 | mask = mask[0, :, delta_y, delta_x].sigmoid(). 
\ 202 | squeeze().view(p.out_size, p.out_size).cpu().data.numpy() 203 | 204 | def crop_back(image, bbox, out_sz, padding=-1): 205 | a = (out_sz[0] - 1) / bbox[2] 206 | b = (out_sz[1] - 1) / bbox[3] 207 | c = -a * bbox[0] 208 | d = -b * bbox[1] 209 | mapping = np.array([[a, 0, c], 210 | [0, b, d]]).astype(np.float) 211 | crop = cv2.warpAffine(image, mapping, (out_sz[0], out_sz[1]), 212 | flags=cv2.INTER_LINEAR, 213 | borderMode=cv2.BORDER_CONSTANT, 214 | borderValue=padding) 215 | return crop 216 | 217 | s = crop_box[2] / p.instance_size 218 | sub_box = [crop_box[0] + (delta_x - p.base_size / 2) * p.total_stride * s, 219 | crop_box[1] + (delta_y - p.base_size / 2) * p.total_stride * s, 220 | s * p.exemplar_size, s * p.exemplar_size] 221 | s = p.out_size / sub_box[2] 222 | back_box = [-sub_box[0] * s, -sub_box[1] * s, state['im_w'] * s, state['im_h'] * s] 223 | mask_in_img = crop_back(mask, back_box, (state['im_w'], state['im_h'])) 224 | 225 | target_mask = (mask_in_img > p.seg_thr).astype(np.uint8) 226 | if cv2.__version__[-5] == '4': 227 | contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 228 | else: 229 | _, contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 230 | cnt_area = [cv2.contourArea(cnt) for cnt in contours] 231 | if len(contours) != 0 and np.max(cnt_area) > 100: 232 | contour = contours[np.argmax(cnt_area)] # use max area polygon 233 | polygon = contour.reshape(-1, 2) 234 | prbox = cv2.boxPoints(cv2.minAreaRect(polygon)) # Rotated Rectangle 235 | rbox_in_img = prbox 236 | else: # empty mask 237 | location = cxy_wh_2_rect(target_pos, target_sz) 238 | rbox_in_img = np.array([[location[0], location[1]], 239 | [location[0] + location[2], location[1]], 240 | [location[0] + location[2], location[1] + location[3]], 241 | [location[0], location[1] + location[3]]]) 242 | 243 | target_pos[0] = max(0, min(state['im_w'], target_pos[0])) 244 | target_pos[1] = max(0, min(state['im_h'], target_pos[1])) 245 | target_sz[0] = max(10, min(state['im_w'], target_sz[0])) 246 | target_sz[1] = max(10, min(state['im_h'], target_sz[1])) 247 | 248 | state['target_pos'] = target_pos 249 | state['target_sz'] = target_sz 250 | state['score'] = score[best_pscore_id] 251 | state['mask'] = mask_in_img if mask_enable else [] 252 | state['ploygon'] = rbox_in_img if mask_enable else [] 253 | return state 254 | 255 | def track_all_objs2mask(img0, img1, boxes, device, scm, cfg): 256 | ''' 257 | img0: previous frame 258 | img1: current frame 259 | boxes: Tensor, shape(n, 9), boxes in previous frame 260 | ''' 261 | boxes = boxes.clone().numpy() 262 | boxes_ = boxes[:,:8] 263 | mask = np.zeros((img1.shape[0], img1.shape[1])) 264 | polygons = [] 265 | 266 | for i in range(boxes_.shape[0]): 267 | init_rbox = boxes_[i] 268 | region_point = init_rbox.reshape(-1, 2) 269 | init_rect = cv2.boundingRect(region_point) 270 | x,y,w,h = init_rect 271 | # tracker init 272 | target_pos = np.array([x + w / 2, y + h / 2]) 273 | target_sz = np.array([w, h]) 274 | state = siamese_init(img0, target_pos, target_sz, scm, cfg['hp'], device=device) 275 | # track 276 | state = siamese_track(state, img1, mask_enable=True, refine_enable=True, device=device) 277 | track_mask = state['mask'] 278 | track_mask[track_mask 0: 35 | ws = round(math.sqrt(size*1. / r), self.round_dight) 36 | hs = round(ws * r, self.round_dight) 37 | else: 38 | ws = int(math.sqrt(size*1. 
/ r)) 39 | hs = int(ws * r) 40 | 41 | for s in self.scales: 42 | w = ws * s 43 | h = hs * s 44 | self.anchors[count][:] = [-w*0.5+x_offset, -h*0.5+y_offset, w*0.5+x_offset, h*0.5+y_offset][:] 45 | count += 1 46 | 47 | def generate_all_anchors(self, im_c, size): 48 | if self.image_center == im_c and self.size == size: 49 | return False 50 | self.image_center = im_c 51 | self.size = size 52 | 53 | a0x = im_c - size // 2 * self.stride 54 | ori = np.array([a0x] * 4, dtype=np.float32) 55 | zero_anchors = self.anchors + ori 56 | 57 | x1 = zero_anchors[:, 0] 58 | y1 = zero_anchors[:, 1] 59 | x2 = zero_anchors[:, 2] 60 | y2 = zero_anchors[:, 3] 61 | 62 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2]) 63 | cx, cy, w, h = corner2center([x1, y1, x2, y2]) 64 | 65 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride 66 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride 67 | 68 | cx = cx + disp_x 69 | cy = cy + disp_y 70 | 71 | # broadcast 72 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32) 73 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h]) 74 | x1, y1, x2, y2 = center2corner([cx, cy, w, h]) 75 | 76 | self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h]) 77 | return True 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /scm/utils/bbox_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import namedtuple 3 | 4 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 5 | BBox = Corner 6 | Center = namedtuple('Center', 'x y w h') 7 | 8 | def corner2center(corner): 9 | """ 10 | :param corner: Corner or np.array 4*N 11 | :return: Center or 4 np.array N 12 | """ 13 | if isinstance(corner, Corner): 14 | x1, y1, x2, y2 = corner 15 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 16 | else: 17 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 18 | x = (x1 + x2) * 0.5 19 | y = (y1 + y2) * 0.5 20 | w = x2 - x1 21 | h = y2 - y1 22 | return x, y, w, h 23 | 24 | def center2corner(center): 25 | """ 26 | :param center: Center or np.array 4*N 27 | :return: Corner or np.array 4*N 28 | """ 29 | if isinstance(center, Center): 30 | x, y, w, h = center 31 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 32 | else: 33 | x, y, w, h = center[0], center[1], center[2], center[3] 34 | x1 = x - w * 0.5 35 | y1 = y - h * 0.5 36 | x2 = x + w * 0.5 37 | y2 = y + h * 0.5 38 | return x1, y1, x2, y2 39 | 40 | def cxy_wh_2_rect(pos, sz): 41 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) 42 | 43 | def aug_apply(bbox, param, shape, inv=False, rd=False): 44 | """ 45 | apply augmentation 46 | :param bbox: original bbox in image 47 | :param param: augmentation param, shift/scale 48 | :param shape: image shape, h, w, (c) 49 | :param inv: inverse 50 | :param rd: round bbox 51 | :return: bbox(, param) 52 | bbox: augmented bbox 53 | param: real augmentation param 54 | """ 55 | if not inv: 56 | center = corner2center(bbox) 57 | original_center = center 58 | 59 | real_param = {} 60 | if 'scale' in param: 61 | scale_x, scale_y = param['scale'] 62 | imh, imw = shape[:2] 63 | h, w = center.h, center.w 64 | scale_x = min(scale_x, float(imw) / w) 65 | scale_y = min(scale_y, float(imh) / h) 66 | center = Center(center.x, center.y, center.w * scale_x, center.h * scale_y) 67 | 68 | bbox = center2corner(center) 69 | if 'shift' in param: 70 | tx, ty = 
param['shift'] 71 | x1, y1, x2, y2 = bbox 72 | imh, imw = shape[:2] 73 | tx = max(-x1, min(imw - 1 - x2, tx)) 74 | ty = max(-y1, min(imh - 1 - y2, ty)) 75 | bbox = Corner(x1 + tx, y1 + ty, x2 + tx, y2 + ty) 76 | 77 | if rd: 78 | bbox = Corner(*map(round, bbox)) 79 | 80 | current_center = corner2center(bbox) 81 | real_param['scale'] = current_center.w / original_center.w, current_center.h / original_center.h 82 | real_param['shift'] = current_center.x - original_center.x, current_center.y - original_center.y 83 | return bbox, real_param 84 | else: 85 | if 'scale' in param: 86 | scale_x, scale_y = param['scale'] 87 | else: 88 | scale_x, scale_y = 1., 1. 89 | if 'shift' in param: 90 | tx, ty = param['shift'] 91 | else: 92 | tx, ty = 0, 0 93 | center = corner2center(bbox) 94 | center = Center(center.x - tx, center.y - ty, center.w / scale_x, center.h / scale_y) 95 | return center2corner(center) 96 | 97 | def IoU(rect1, rect2): 98 | # overlap 99 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] 100 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] 101 | 102 | xx1 = np.maximum(tx1, x1) 103 | yy1 = np.maximum(ty1, y1) 104 | xx2 = np.minimum(tx2, x2) 105 | yy2 = np.minimum(ty2, y2) 106 | 107 | ww = np.maximum(0, xx2 - xx1) 108 | hh = np.maximum(0, yy2 - yy1) 109 | area = (x2-x1) * (y2-y1) 110 | target_a = (tx2-tx1) * (ty2 - ty1) 111 | inter = ww * hh 112 | overlap = inter / (area + target_a - inter + 1e-6) 113 | 114 | return overlap 115 | 116 | def get_axis_aligned_bbox(region): 117 | nv = region.size 118 | if nv == 8: 119 | cx = np.mean(region[0::2]) 120 | cy = np.mean(region[1::2]) 121 | x1 = min(region[0::2]) 122 | x2 = max(region[0::2]) 123 | y1 = min(region[1::2]) 124 | y2 = max(region[1::2]) 125 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6]) 126 | A2 = (x2 - x1) * (y2 - y1) 127 | s = np.sqrt(A1 / A2) 128 | w = s * (x2 - x1) + 1 129 | h = s * (y2 - y1) + 1 130 | else: 131 | x = region[0] 132 | y = region[1] 133 | w = region[2] 134 | h = region[3] 135 | cx = x+w/2 136 | cy = y+h/2 137 | return cx, cy, w, h -------------------------------------------------------------------------------- /scm/utils/load_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | logger = logging.getLogger('global') 4 | 5 | def check_keys(model, pretrained_state_dict): 6 | ckpt_keys = set(pretrained_state_dict.keys()) 7 | model_keys = set(model.state_dict().keys()) 8 | used_pretrained_keys = model_keys & ckpt_keys 9 | unused_pretrained_keys = ckpt_keys - model_keys 10 | missing_keys = model_keys - ckpt_keys 11 | if len(missing_keys) > 0: 12 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 13 | logger.info('missing keys:{}'.format(len(missing_keys))) 14 | if len(unused_pretrained_keys) > 0: 15 | logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys)) 16 | logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 17 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 18 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 19 | return True 20 | 21 | def remove_prefix(state_dict, prefix): 22 | ''' Old style model is stored with all names of parameters share common prefix 'module.' 
''' 23 | logger.info('remove prefix \'{}\''.format(prefix)) 24 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 25 | return {f(key): value for key, value in state_dict.items()} 26 | 27 | def load_pretrain(model, pretrained_path): 28 | logger.info('load pretrained model from {}'.format(pretrained_path)) 29 | if not torch.cuda.is_available(): 30 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 31 | else: 32 | device = torch.cuda.current_device() 33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) 34 | 35 | if "state_dict" in pretrained_dict.keys(): 36 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 37 | else: 38 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 39 | 40 | try: 41 | check_keys(model, pretrained_dict) 42 | except: 43 | logger.info('[Warning]: using pretrain as features. Adding "features." as prefix') 44 | new_dict = {} 45 | for k, v in pretrained_dict.items(): 46 | k = 'features.' + k 47 | new_dict[k] = v 48 | pretrained_dict = new_dict 49 | check_keys(model, pretrained_dict) 50 | model.load_state_dict(pretrained_dict, strict=True) 51 | return model 52 | 53 | def restore_from(model, optimizer, ckpt_path): 54 | logger.info('restore from {}'.format(ckpt_path)) 55 | device = torch.cuda.current_device() 56 | ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage.cuda(device)) 57 | epoch = ckpt['epoch'] 58 | arch = ckpt['arch'] 59 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 60 | check_keys(model, ckpt_model_dict) 61 | model.load_state_dict(ckpt_model_dict, strict=False) 62 | check_keys(optimizer, ckpt['optimizer']) 63 | optimizer.load_state_dict(ckpt['optimizer']) 64 | return model, optimizer, epoch, arch 65 | -------------------------------------------------------------------------------- /scm/utils/lr_helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import math 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | class LRScheduler(_LRScheduler): 7 | def __init__(self, optimizer, last_epoch=-1): 8 | if 'lr_spaces' not in self.__dict__: 9 | raise Exception('lr_spaces must be set in "LRSchduler"') 10 | super(LRScheduler, self).__init__(optimizer, last_epoch) 11 | 12 | def get_cur_lr(self): 13 | return self.lr_spaces[self.last_epoch] 14 | 15 | def get_lr(self): 16 | epoch = self.last_epoch 17 | return [self.lr_spaces[epoch] * pg['initial_lr'] / self.start_lr for pg in self.optimizer.param_groups] 18 | 19 | def __repr__(self): 20 | return "({}) lr spaces: \n{}".format(self.__class__.__name__, self.lr_spaces) 21 | 22 | class LogScheduler(LRScheduler): 23 | def __init__(self, optimizer, start_lr=0.03, end_lr=5e-4, epochs=50, last_epoch=-1, **kwargs): 24 | self.start_lr = start_lr 25 | self.end_lr = end_lr 26 | self.epochs = epochs 27 | self.lr_spaces = np.logspace(math.log10(start_lr), math.log10(end_lr), epochs) 28 | 29 | super(LogScheduler, self).__init__(optimizer, last_epoch) 30 | 31 | class StepScheduler(LRScheduler): 32 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, step=10, mult=0.1, epochs=50, last_epoch=-1, **kwargs): 33 | if end_lr is not None: 34 | if start_lr is None: 35 | start_lr = end_lr / (mult ** (epochs // step)) 36 | else: # for warm up policy 37 | mult = math.pow(end_lr/start_lr, 1. 
/ (epochs // step)) 38 | self.start_lr = start_lr 39 | self.lr_spaces = self.start_lr * (mult**(np.arange(epochs) // step)) 40 | self.mult = mult 41 | self._step = step 42 | 43 | super(StepScheduler, self).__init__(optimizer, last_epoch) 44 | 45 | class MultiStepScheduler(LRScheduler): 46 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, steps=[10,20,30,40], mult=0.5, epochs=50, last_epoch=-1, **kwargs): 47 | if end_lr is not None: 48 | if start_lr is None: 49 | start_lr = end_lr / (mult ** (len(steps))) 50 | else: 51 | mult = math.pow(end_lr/start_lr, 1. / len(steps)) 52 | self.start_lr = start_lr 53 | self.lr_spaces = self._build_lr(start_lr, steps, mult, epochs) 54 | self.mult = mult 55 | self.steps = steps 56 | super(MultiStepScheduler, self).__init__(optimizer, last_epoch) 57 | 58 | def _build_lr(self, start_lr, steps, mult, epochs): 59 | lr = [0] * epochs 60 | lr[0] = start_lr 61 | for i in range(1, epochs): 62 | lr[i] = lr[i-1] 63 | if i in steps: 64 | lr[i] *= mult 65 | return np.array(lr, dtype=np.float32) 66 | 67 | class LinearStepScheduler(LRScheduler): 68 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs): 69 | self.start_lr = start_lr 70 | self.end_lr = end_lr 71 | self.lr_spaces = np.linspace(start_lr, end_lr, epochs) 72 | 73 | super(LinearStepScheduler, self).__init__(optimizer, last_epoch) 74 | 75 | class CosStepScheduler(LRScheduler): 76 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, epochs=50, last_epoch=-1, **kwargs): 77 | self.start_lr = start_lr 78 | self.end_lr = end_lr 79 | self.lr_spaces = self._build_lr(start_lr, end_lr, epochs) 80 | super(CosStepScheduler, self).__init__(optimizer, last_epoch) 81 | 82 | def _build_lr(self, start_lr, end_lr, epochs): 83 | index = np.arange(epochs).astype(np.float32) 84 | lr = end_lr + (start_lr - end_lr) * (1. 
+ np.cos(index * np.pi/ epochs)) * 0.5 85 | return lr.astype(np.float32) 86 | 87 | 88 | class WarmUPScheduler(LRScheduler): 89 | def __init__(self, optimizer, warmup, normal, epochs=50, last_epoch=-1): 90 | warmup = warmup.lr_spaces # [::-1] 91 | normal = normal.lr_spaces 92 | self.lr_spaces = np.concatenate([warmup, normal]) 93 | self.start_lr = normal[0] 94 | 95 | super(WarmUPScheduler, self).__init__(optimizer, last_epoch) 96 | 97 | 98 | LRs = { 99 | 'log': LogScheduler, 100 | 'step': StepScheduler, 101 | 'multi-step': MultiStepScheduler, 102 | 'linear': LinearStepScheduler, 103 | 'cos': CosStepScheduler} 104 | 105 | 106 | def _build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1): 107 | if 'type' not in cfg: 108 | # return LogScheduler(optimizer, last_epoch=last_epoch, epochs=epochs) 109 | cfg['type'] = 'log' 110 | 111 | if cfg['type'] not in LRs: 112 | raise Exception('Unknown type of LR Scheduler "%s"'%cfg['type']) 113 | 114 | return LRs[cfg['type']](optimizer, last_epoch=last_epoch, epochs=epochs, **cfg) 115 | 116 | 117 | def _build_warm_up_scheduler(optimizer, cfg, epochs=50, last_epoch=-1): 118 | warmup_epoch = cfg['warmup']['epoch'] 119 | sc1 = _build_lr_scheduler(optimizer, cfg['warmup'], warmup_epoch, last_epoch) 120 | sc2 = _build_lr_scheduler(optimizer, cfg, epochs - warmup_epoch, last_epoch) 121 | return WarmUPScheduler(optimizer, sc1, sc2, epochs, last_epoch) 122 | 123 | 124 | def build_lr_scheduler(optimizer, cfg, epochs=50, last_epoch=-1): 125 | if 'warmup' in cfg: 126 | return _build_warm_up_scheduler(optimizer, cfg, epochs, last_epoch) 127 | else: 128 | return _build_lr_scheduler(optimizer, cfg, epochs, last_epoch) 129 | 130 | -------------------------------------------------------------------------------- /scm/utils/tracker_config.py: -------------------------------------------------------------------------------- 1 | from scm.utils.anchors import Anchors 2 | 3 | class TrackerConfig(object): 4 | # These are the default hyper-params for SiamMask 5 | penalty_k = 0.09 6 | window_influence = 0.39 7 | lr = 0.38 8 | seg_thr = 0.3 # for mask 9 | windowing = 'cosine' # to penalize large displacements [cosine/uniform] 10 | # Params from the network architecture, have to be consistent with the training 11 | exemplar_size = 127 # input z size 12 | instance_size = 255 # input x size (search region) 13 | total_stride = 8 14 | out_size = 63 # for mask 15 | base_size = 8 16 | score_size = (instance_size-exemplar_size)//total_stride+1+base_size 17 | context_amount = 0.5 # context amount for the exemplar 18 | ratios = [0.33, 0.5, 1, 2, 3] 19 | scales = [8, ] 20 | anchor_num = len(ratios) * len(scales) 21 | round_dight = 0 22 | anchor = [] 23 | 24 | def update(self, newparam=None, anchors=None): 25 | if newparam: 26 | for key, value in newparam.items(): 27 | setattr(self, key, value) 28 | if anchors is not None: 29 | if isinstance(anchors, dict): 30 | anchors = Anchors(anchors) 31 | if isinstance(anchors, Anchors): 32 | self.total_stride = anchors.stride 33 | self.ratios = anchors.ratios 34 | self.scales = anchors.scales 35 | self.round_dight = anchors.round_dight 36 | self.renew() 37 | 38 | def renew(self): 39 | self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size 40 | self.anchor_num = len(self.ratios) * len(self.scales) -------------------------------------------------------------------------------- /track_textboxPP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
os.path as osp 3 | import cv2 4 | import numpy as np 5 | import xml.etree.ElementTree as ET 6 | from utils import utils 7 | from tracker.db_text_multitracker import JDETracker 8 | 9 | def get(root, name): 10 | vars = root.findall(name) 11 | return vars 12 | 13 | def draw_gt(xml_dir, online_im, frame_id, opt): 14 | tree = ET.parse(xml_dir) 15 | root = tree.getroot() 16 | frames = get(root, 'frame') 17 | try: 18 | frame = frames[frame_id] 19 | except: 20 | return online_im 21 | if opt.dataset=='icdar': 22 | assert int(frame_id)+1==int(frame.attrib['ID']) 23 | elif opt.dataset=='minetto': 24 | assert int(frame_id)==int(frame.attrib['ID']) 25 | elif opt.dataset=='roadtext_test': 26 | assert int(frame_id)+1==int(frame.attrib['ID']) 27 | 28 | objects = get(frame, 'object') 29 | for obj in objects: 30 | try: 31 | quality = obj.attrib['Quality'] # ['MODERATE', 'LOW', 'HIGH', 'MODERTE'] 32 | if quality=='LOW': 33 | continue 34 | except: quality='HIGH' 35 | Points = get(obj, 'Point') 36 | xs = [] 37 | ys = [] 38 | for point in Points: 39 | xs.append(float(point.attrib['x'])) 40 | ys.append(float(point.attrib['y'])) 41 | cv2.polylines(online_im, [np.array([[int(xs[0]),int(ys[0])],[int(xs[1]),int(ys[1])], \ 42 | [int(xs[2]),int(ys[2])],[int(xs[3]),int(ys[3])]])], True, (255,255, 255),1) 43 | return online_im 44 | 45 | def eval_seq(opt, dataloader, video_name, frame_dir=None, show_image=False, video_writer=None, timer=None): 46 | 47 | tracker = JDETracker(opt, frame_rate=dataloader.frame_rate) 48 | results = [] 49 | frame_id = 0 50 | pre_img0 = None 51 | pre_boxes = None 52 | 53 | for img_path, img0 in dataloader: 54 | boxes=[] 55 | 56 | # run tracking 57 | online_targets, pre_img0, pre_boxes = tracker.update(img_path, img0, add_vot_track=opt.add_vot_track, \ 58 | pre_img0=pre_img0, pre_boxes=pre_boxes) 59 | online_ids = [] 60 | for t in online_targets: 61 | tlwh = t._tlwh 62 | tid = t.track_id 63 | if tlwh[2] * tlwh[3] >= opt.min_box_area: 64 | boxes.append(list(t.pt)+[int(tid)]) 65 | online_ids.append(tid) 66 | results.append(boxes) 67 | 68 | if opt.eval_det: 69 | eval_det_dir = osp.join(opt.output_root, 'det_res') 70 | if not osp.exists(eval_det_dir): 71 | os.makedirs(eval_det_dir) 72 | utils.save_det_res(img_path.split('/')[-1].replace('jpg', 'txt'), video_name, pre_boxes, eval_det_dir, opt.dataset) 73 | 74 | if show_image: 75 | utils.mkdir_if_missing(frame_dir) 76 | pred_im = img0.copy() 77 | for i in range(len(boxes)): 78 | cv2.polylines(pred_im, [np.array(boxes[i][:8]).reshape(-1, 2).astype(np.int32)], True, (127,255,0),2) 79 | cv2.putText(pred_im,'{}'.format(int(boxes[i][8])),(int(pre_boxes[i][0]), \ 80 | int(pre_boxes[i][1])),cv2.FONT_HERSHEY_SIMPLEX,0.6,(255,255,255),1,cv2.LINE_AA) 81 | cv2.imwrite(os.path.join(frame_dir, '{:05d}_pred.jpg'.format(frame_id)), pred_im) 82 | if opt.show_gt: 83 | assert opt.gt_dir != '', 'should give gt dir when show_gt is True' 84 | gt_im = img0.copy() 85 | gt_name = video_name + '_GT.xml' 86 | xml_f = os.path.join(opt.input_root, gt_name) 87 | gt_im = draw_gt(xml_f, gt_im, frame_id,opt) 88 | cv2.imwrite(os.path.join(frame_dir, '{:05d}_gt.jpg'.format(frame_id)), gt_im) 89 | 90 | frame_id += 1 91 | 92 | # save results 93 | if opt.sub_res: 94 | save_dir = os.path.join(opt.output_root, opt.sub_res_root) 95 | if not os.path.exists(save_dir): 96 | os.makedirs(save_dir) 97 | 98 | if opt.dataset == 'icdar': 99 | videonum = video_name.split('_')[1] 100 | xml_name = os.path.join(save_dir, 'res_video_'+videonum+'.xml') 101 | utils.write2xml(xml_name, results, 
change_id=True) 102 | elif opt.dataset == 'roadtext' or opt.dataset=='minetto': 103 | txt_name = os.path.join(save_dir, video_name+'.txt') 104 | utils.write2txt(txt_name, results) -------------------------------------------------------------------------------- /tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/tracker/__init__.py -------------------------------------------------------------------------------- /tracker/basetrack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | from numba import jit 4 | from . import matching 5 | 6 | class TrackState(object): 7 | New = 0 8 | Tracked = 1 9 | Lost = 2 10 | Removed = 3 11 | 12 | class BaseTrack(object): 13 | _count = 0 14 | track_id = 0 15 | is_activated = False 16 | state = TrackState.New 17 | history = OrderedDict() 18 | features = [] 19 | curr_feature = None 20 | score = 0 21 | start_frame = 0 22 | frame_id = 0 23 | time_since_update = 0 24 | # multi-camera 25 | location = (np.inf, np.inf) 26 | 27 | @property 28 | def end_frame(self): 29 | return self.frame_id 30 | 31 | @staticmethod 32 | def next_id(): 33 | BaseTrack._count += 1 34 | return BaseTrack._count 35 | 36 | def new_video(): 37 | BaseTrack._count = 0 38 | return BaseTrack._count 39 | 40 | def activate(self, *args): 41 | raise NotImplementedError 42 | 43 | def predict(self): 44 | raise NotImplementedError 45 | 46 | def update(self, *args, **kwargs): 47 | raise NotImplementedError 48 | 49 | def mark_lost(self): 50 | self.state = TrackState.Lost 51 | 52 | def mark_removed(self): 53 | self.state = TrackState.Removed 54 | 55 | class STrack(BaseTrack): 56 | 57 | def __init__(self, tlwh, score, temp_feat, pt, buffer_size=30, cur_frame=0, img_path=None): 58 | # wait activate 59 | self._tlwh = np.asarray(tlwh, dtype=np.float) 60 | self.kalman_filter = None 61 | self.mean, self.covariance = None, None 62 | self.is_activated = False 63 | self.pt = pt 64 | self.score = score 65 | self.smooth_feat = None 66 | self.update_features(temp_feat) 67 | self.alpha = 0.9 68 | self.cur_frame = cur_frame 69 | self.img_path = img_path 70 | 71 | def update_features(self, feat): 72 | self.curr_feat = feat 73 | if self.smooth_feat is None: 74 | self.smooth_feat = feat 75 | else: 76 | self.smooth_feat = self.alpha *self.smooth_feat + (1-self.alpha) * feat 77 | self.smooth_feat /= np.linalg.norm(self.smooth_feat) 78 | 79 | def predict(self): 80 | mean_state = self.mean.copy() 81 | if self.state != TrackState.Tracked: 82 | mean_state[7] = 0 83 | self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) 84 | 85 | def activate(self, kalman_filter, frame_id): 86 | """Start a new tracklet""" 87 | self.kalman_filter = kalman_filter 88 | self.track_id = self.next_id() 89 | self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) 90 | self.state = TrackState.Tracked 91 | self.is_activated = True 92 | self.frame_id = frame_id 93 | self.start_frame = frame_id 94 | 95 | def re_activate(self, new_track, frame_id, new_id=False): 96 | self.mean, self.covariance = self.kalman_filter.update( 97 | self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) 98 | ) 99 | self.pt = new_track.pt 100 | self.update_features(new_track.curr_feat) 101 | self.state = TrackState.Tracked 102 | self.is_activated = True 103 | 
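# --- Illustrative sketch: update_features above keeps an exponential moving average
# of the appearance embedding (alpha = 0.9) and re-normalises it, so distance-based
# matching stays stable across frames. The toy vectors below are assumptions. ---
import numpy as np

alpha = 0.9
smooth = np.array([1.0, 0.0, 0.0, 0.0])            # previous smoothed embedding
feat = np.array([0.0, 1.0, 0.0, 0.0])              # embedding from the current frame
smooth = alpha * smooth + (1 - alpha) * feat
smooth /= np.linalg.norm(smooth)                   # keep unit length for comparable distances
# --- end sketch ---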
self.frame_id = frame_id 104 | if new_id: 105 | self.track_id = self.next_id() 106 | self.img_path = new_track.img_path 107 | self._tlwh = self.tlwh 108 | self.cur_frame = new_track.cur_frame 109 | self.img_path = new_track.img_path 110 | 111 | def update(self, new_track, frame_id, update_feature=True): 112 | """ 113 | Update a matched track 114 | :type new_track: STrack 115 | :type frame_id: int 116 | :type update_feature: bool 117 | :return: 118 | """ 119 | self.frame_id = frame_id 120 | new_tlwh = new_track.tlwh 121 | self.pt = new_track.pt 122 | self.mean, self.covariance = self.kalman_filter.update( 123 | self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) 124 | self.state = TrackState.Tracked 125 | self.is_activated = True 126 | self.score = new_track.score 127 | if update_feature: 128 | self.update_features(new_track.curr_feat) 129 | self.img_path = new_track.img_path 130 | self._tlwh = self.tlwh 131 | self.cur_frame = new_track.cur_frame 132 | 133 | @property 134 | @jit 135 | def tlwh(self): 136 | """Get current position in bounding box format `(top left x, top left y, 137 | width, height)`. 138 | """ 139 | if self.mean is None: 140 | return self._tlwh.copy() 141 | ret = self.mean[:4].copy() 142 | ret[2] *= ret[3] 143 | ret[:2] -= ret[2:] / 2 144 | return ret 145 | 146 | @property 147 | @jit 148 | def tlbr(self): 149 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 150 | `(top left, bottom right)`. 151 | """ 152 | ret = self.tlwh.copy() 153 | ret[2:] += ret[:2] 154 | return ret 155 | 156 | @staticmethod 157 | @jit 158 | def tlwh_to_xyah(tlwh): 159 | """Convert bounding box to format `(center x, center y, aspect ratio, 160 | height)`, where the aspect ratio is `width / height`. 161 | """ 162 | ret = np.asarray(tlwh).copy() 163 | ret[:2] += ret[2:] / 2 164 | ret[2] /= ret[3] 165 | return ret 166 | 167 | def to_xyah(self): 168 | return self.tlwh_to_xyah(self.tlwh) 169 | 170 | @staticmethod 171 | @jit 172 | def tlbr_to_tlwh(tlbr): 173 | ret = np.asarray(tlbr).copy() 174 | ret[2:] -= ret[:2] 175 | return ret 176 | 177 | @staticmethod 178 | @jit 179 | def tlwh_to_tlbr(tlwh): 180 | ret = np.asarray(tlwh).copy() 181 | ret[2:] += ret[:2] 182 | return ret 183 | 184 | def __repr__(self): 185 | return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) 186 | 187 | def joint_stracks(tlista, tlistb): 188 | exists = {} 189 | res = [] 190 | for t in tlista: 191 | exists[t.track_id] = 1 192 | res.append(t) 193 | for t in tlistb: 194 | tid = t.track_id 195 | if not exists.get(tid, 0): 196 | exists[tid] = 1 197 | res.append(t) 198 | return res 199 | 200 | def sub_stracks(tlista, tlistb): 201 | stracks = {} 202 | for t in tlista: 203 | stracks[t.track_id] = t 204 | for t in tlistb: 205 | tid = t.track_id 206 | if stracks.get(tid, 0): 207 | del stracks[tid] 208 | return list(stracks.values()) 209 | 210 | def remove_duplicate_stracks(stracksa, stracksb): 211 | pdist = matching.poly_distance(stracksa, stracksb) 212 | pairs = np.where(pdist<0.15) 213 | dupa, dupb = list(), list() 214 | for p,q in zip(*pairs): 215 | timep = stracksa[p].frame_id - stracksa[p].start_frame 216 | timeq = stracksb[q].frame_id - stracksb[q].start_frame 217 | if timep > timeq: 218 | dupb.append(q) 219 | else: 220 | dupa.append(p) 221 | resa = [t for i,t in enumerate(stracksa) if not i in dupa] 222 | resb = [t for i,t in enumerate(stracksb) if not i in dupb] 223 | return resa, resb -------------------------------------------------------------------------------- 
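# --- Illustrative sketch: joint_stracks above merges two track lists by track_id,
# keeping the copy from the first list whenever an id occurs in both. The stub class
# below is an assumption standing in for STrack. ---
class _Stub:
    def __init__(self, tid):
        self.track_id = tid

tracked, lost = [_Stub(1), _Stub(2)], [_Stub(2), _Stub(3)]
seen, pool = set(), []
for t in tracked + lost:                           # same precedence rule as joint_stracks
    if t.track_id not in seen:
        seen.add(t.track_id)
        pool.append(t)
assert [t.track_id for t in pool] == [1, 2, 3]
# --- end sketch ---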
/tracker/db_text_multitracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | import os 4 | from db_model.db_embedding import Demo 5 | from .basetrack import * 6 | from .kalman_filter import KalmanFilter 7 | 8 | class JDETracker(object): 9 | def __init__(self, opt, frame_rate=30): 10 | self.save_root = osp.join(opt.result_root, 'det_out') 11 | self.save_det_out = opt.save_det_out 12 | self.model = Demo(opt.weight_path, opt.scm_weight_path, opt.scm_config,opt.img_min_size, opt.conf_thresh) 13 | self.tracked_stracks = [] # type: list[STrack] 14 | self.lost_stracks = [] # type: list[STrack] 15 | self.removed_stracks = [] # type: list[STrack] 16 | self.frame_id = -1 17 | self.buffer_size = frame_rate 18 | self.max_time_lost = self.buffer_size 19 | self.kalman_filter = KalmanFilter() 20 | STrack.new_video() 21 | 22 | def update(self, img_path, img0, add_vot_track=0, pre_img0=None, pre_boxes=None): 23 | 24 | self.frame_id += 1 25 | print('================frame_id: {}================='.format(self.frame_id)) 26 | activated_starcks = [] 27 | refind_stracks = [] 28 | lost_stracks = [] 29 | removed_stracks = [] 30 | 31 | strack_pool = joint_stracks(self.tracked_stracks, self.lost_stracks) 32 | 33 | ''' Step 1: Network forward, get detections & embeddings 34 | pred_f: shape=(n,512) 35 | pred_boxes: shape=(n, 4) [x1,y1,x3,y3] 36 | pre_boxes: shape=(n,9) [x1,y1,x2,y2,x3,y3,x4,y4,score] 37 | ''' 38 | with torch.no_grad(): 39 | pred_pt = img_path.split('/')[-1].replace('.jpg', '_pred.pt') 40 | pred_pt = osp.join(self.save_root, pred_pt) 41 | if not os.path.isfile(pred_pt): 42 | pred_f, pred_boxes, pre_boxes, no_objs = self.model.inference(img_path,img0, pre_img0=pre_img0, pre_boxes=pre_boxes, add_vot_track=add_vot_track) 43 | if self.save_det_out: 44 | if not os.path.exists(self.save_root): 45 | os.makedirs(self.save_root) 46 | save_dict = {'pred_f': pred_f, 'pred_boxes': pred_boxes, 'pre_boxes':pre_boxes, 'no_objs': no_objs} 47 | torch.save(save_dict, pred_pt) 48 | else: 49 | save_dict = torch.load(pred_pt) 50 | pred_f, pred_boxes, pre_boxes, no_objs = save_dict['pred_f'], save_dict['pred_boxes'], save_dict['pre_boxes'], save_dict['no_objs'] 51 | 52 | if no_objs == False and len(pred_f) > 0: 53 | detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s, f.numpy(), pt.clone().numpy(), self.buffer_size, self.frame_id, img_path) for 54 | (tlbr, s, f, pt) in zip(pred_boxes, pre_boxes[:, 8], pred_f, pre_boxes[:, :8])] 55 | else: 56 | detections = [] 57 | 58 | ''' Step 2: association, with embedding, iou and shape''' 59 | for strack in strack_pool: 60 | strack.predict() 61 | dists1 = matching.embedding_distance(strack_pool, detections) 62 | dists2 = matching.poly_distance(strack_pool, detections) 63 | dists3 = matching.shape_distance(strack_pool, detections) 64 | # dists = 0.6*dists1 + 0.4*dists3 65 | dists = 0.6*dists1 + 0.2*dists2 + 0.2*dists3 66 | # dists = 0.4*dists1 + 0.3*dists2 + 0.3*dists3 67 | # dists = 0.8*dists1 + 0.1*dists2 + 0.1*dists3 68 | # dists = 0.6*dists1 + 0.3*dists2 + 0.1*dists3 69 | # dists = 0.6*dists1 + 0.1*dists2 + 0.3*dists3 70 | # dists = 0.5*dists1 + 0.2*dists2 + 0.3*dists3 71 | # dists = 0.5*dists1 + 0.3*dists2 + 0.2*dists3 72 | # dists = 0.5*dists1 + 0.25*dists2 + 0.25*dists3 73 | # dists = 0.7*dists1 + 0.2*dists2 + 0.1*dists3 74 | # dists = 0.7*dists1 + 0.1*dists2 + 0.2*dists3 75 | # dists = 0.7*dists1 + 0.15*dists2 + 0.15*dists3 76 | matches, u_track, u_detection = matching.linear_assignment(dists, 
thresh=0.7) 77 | 78 | ''' step3: Add newly detected tracklets to tracked_stracks''' 79 | for itracked, idet in matches: 80 | track = strack_pool[itracked] 81 | det = detections[idet] 82 | if track.state == TrackState.Tracked: 83 | track.update(det, self.frame_id) 84 | activated_starcks.append(track) 85 | else: 86 | track.re_activate(det, self.frame_id, new_id=False) 87 | refind_stracks.append(track) 88 | 89 | '''step4: mark the losted stracks''' 90 | for it in u_track: 91 | track = strack_pool[it] 92 | if not track.state == TrackState.Lost: 93 | track.mark_lost() 94 | lost_stracks.append(track) 95 | 96 | """ Step 5: Init new stracks""" 97 | for inew in u_detection: 98 | track = detections[inew] 99 | track.activate(self.kalman_filter, self.frame_id) 100 | activated_starcks.append(track) 101 | 102 | """ Step 6: Update state""" 103 | for track in self.lost_stracks: 104 | if self.frame_id - track.end_frame > self.max_time_lost: 105 | track.mark_removed() 106 | removed_stracks.append(track) 107 | 108 | self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] 109 | self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks) 110 | self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks) 111 | self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] 112 | self.lost_stracks.extend(lost_stracks) 113 | self.removed_stracks.extend(removed_stracks) 114 | self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) 115 | output_stracks = [track for track in self.tracked_stracks if track.is_activated] 116 | 117 | return output_stracks,img0, pre_boxes -------------------------------------------------------------------------------- /tracker/kalman_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | 4 | """ 5 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 6 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 7 | function and used as Mahalanobis gating threshold. 8 | """ 9 | chi2inv95 = { 10 | 1: 3.8415, 11 | 2: 5.9915, 12 | 3: 7.8147, 13 | 4: 9.4877, 14 | 5: 11.070, 15 | 6: 12.592, 16 | 7: 14.067, 17 | 8: 15.507, 18 | 9: 16.919} 19 | 20 | class KalmanFilter(object): 21 | """ 22 | A simple Kalman filter for tracking bounding boxes in image space. 23 | The 8-dimensional state space: x, y, a, h, vx, vy, va, vh 24 | contains the bounding box center position (x, y), aspect ratio a, height h, 25 | and their respective velocities. 26 | Object motion follows a constant velocity model. The bounding box location 27 | (x, y, a, h) is taken as direct observation of the state space (linear 28 | observation model). 29 | """ 30 | 31 | def __init__(self): 32 | ndim, dt = 4, 1. 33 | 34 | # Create Kalman filter model matrices. 35 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 36 | for i in range(ndim): 37 | self._motion_mat[i, ndim + i] = dt 38 | self._update_mat = np.eye(ndim, 2 * ndim) 39 | 40 | # Motion and observation uncertainty are chosen relative to the current 41 | # state estimate. These weights control the amount of uncertainty in 42 | # the model. This is a bit hacky. 43 | self._std_weight_position = 1. / 20 44 | self._std_weight_velocity = 1. / 160 45 | 46 | def initiate(self, measurement): 47 | """Create track from unassociated measurement. 
48 | Parameters 49 | ---------- 50 | measurement : ndarray 51 | Bounding box coordinates (x, y, a, h) with center position (x, y), 52 | aspect ratio a, and height h. 53 | Returns 54 | ------- 55 | (ndarray, ndarray) 56 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 57 | dimensional) of the new track. Unobserved velocities are initialized 58 | to 0 mean. 59 | """ 60 | mean_pos = measurement 61 | mean_vel = np.zeros_like(mean_pos) 62 | mean = np.r_[mean_pos, mean_vel] 63 | 64 | std = [ 65 | 2 * self._std_weight_position * measurement[3], 66 | 2 * self._std_weight_position * measurement[3], 67 | 1e-2, 68 | 2 * self._std_weight_position * measurement[3], 69 | 10 * self._std_weight_velocity * measurement[3], 70 | 10 * self._std_weight_velocity * measurement[3], 71 | 1e-5, 72 | 10 * self._std_weight_velocity * measurement[3]] 73 | covariance = np.diag(np.square(std)) 74 | return mean, covariance 75 | 76 | def predict(self, mean, covariance): 77 | """Run Kalman filter prediction step. 78 | Parameters 79 | ---------- 80 | mean : ndarray 81 | The 8 dimensional mean vector of the object state at the previous 82 | time step. 83 | covariance : ndarray 84 | The 8x8 dimensional covariance matrix of the object state at the 85 | previous time step. 86 | Returns 87 | ------- 88 | (ndarray, ndarray) 89 | Returns the mean vector and covariance matrix of the predicted 90 | state. Unobserved velocities are initialized to 0 mean. 91 | 92 | """ 93 | std_pos = [ 94 | self._std_weight_position * mean[3], 95 | self._std_weight_position * mean[3], 96 | 1e-2, 97 | self._std_weight_position * mean[3]] 98 | std_vel = [ 99 | self._std_weight_velocity * mean[3], 100 | self._std_weight_velocity * mean[3], 101 | 1e-5, 102 | self._std_weight_velocity * mean[3]] 103 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 104 | mean = np.dot(self._motion_mat, mean) 105 | covariance = np.linalg.multi_dot(( 106 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 107 | 108 | return mean, covariance 109 | 110 | def project(self, mean, covariance): 111 | """Project state distribution to measurement space. 112 | Parameters 113 | ---------- 114 | mean : ndarray 115 | The state's mean vector (8 dimensional array). 116 | covariance : ndarray 117 | The state's covariance matrix (8x8 dimensional). 118 | Returns 119 | ------- 120 | (ndarray, ndarray) 121 | Returns the projected mean and covariance matrix of the given state 122 | estimate. 123 | """ 124 | std = [ 125 | self._std_weight_position * mean[3], 126 | self._std_weight_position * mean[3], 127 | 1e-1, 128 | self._std_weight_position * mean[3]] 129 | innovation_cov = np.diag(np.square(std)) 130 | 131 | mean = np.dot(self._update_mat, mean) 132 | covariance = np.linalg.multi_dot(( 133 | self._update_mat, covariance, self._update_mat.T)) 134 | return mean, covariance + innovation_cov 135 | 136 | def update(self, mean, covariance, measurement): 137 | """Run Kalman filter correction step. 138 | Parameters 139 | ---------- 140 | mean : ndarray 141 | The predicted state's mean vector (8 dimensional). 142 | covariance : ndarray 143 | The state's covariance matrix (8x8 dimensional). 144 | measurement : ndarray 145 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 146 | is the center position, a the aspect ratio, and h the height of the 147 | bounding box. 148 | Returns 149 | ------- 150 | (ndarray, ndarray) 151 | Returns the measurement-corrected state distribution. 
152 | 153 | """ 154 | projected_mean, projected_cov = self.project(mean, covariance) 155 | 156 | chol_factor, lower = scipy.linalg.cho_factor( 157 | projected_cov, lower=True, check_finite=False) 158 | kalman_gain = scipy.linalg.cho_solve( 159 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 160 | check_finite=False).T 161 | innovation = measurement - projected_mean 162 | 163 | new_mean = mean + np.dot(innovation, kalman_gain.T) 164 | new_covariance = covariance - np.linalg.multi_dot(( 165 | kalman_gain, projected_cov, kalman_gain.T)) 166 | return new_mean, new_covariance 167 | 168 | def gating_distance(self, mean, covariance, measurements, 169 | only_position=False): 170 | """Compute gating distance between state distribution and measurements. 171 | 172 | A suitable distance threshold can be obtained from `chi2inv95`. If 173 | `only_position` is False, the chi-square distribution has 4 degrees of 174 | freedom, otherwise 2. 175 | Parameters 176 | ---------- 177 | mean : ndarray 178 | Mean vector over the state distribution (8 dimensional). 179 | covariance : ndarray 180 | Covariance of the state distribution (8x8 dimensional). 181 | measurements : ndarray 182 | An Nx4 dimensional matrix of N measurements, each in 183 | format (x, y, a, h) where (x, y) is the bounding box center 184 | position, a the aspect ratio, and h the height. 185 | only_position : Optional[bool] 186 | If True, distance computation is done with respect to the bounding 187 | box center position only. 188 | Returns 189 | ------- 190 | ndarray 191 | Returns an array of length N, where the i-th element contains the 192 | squared Mahalanobis distance between (mean, covariance) and 193 | `measurements[i]`. 194 | """ 195 | mean, covariance = self.project(mean, covariance) 196 | if only_position: 197 | mean, covariance = mean[:2], covariance[:2, :2] 198 | measurements = measurements[:, :2] 199 | 200 | cholesky_factor = np.linalg.cholesky(covariance) 201 | d = measurements - mean 202 | z = scipy.linalg.solve_triangular( 203 | cholesky_factor, d.T, lower=True, check_finite=False, 204 | overwrite_b=True) 205 | squared_maha = np.sum(z * z, axis=0) 206 | return squared_maha -------------------------------------------------------------------------------- /tracker/matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.utils import polygon_iou 3 | import cv2 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | def _indices_to_matches(cost_matrix, indices, thresh): 7 | indices_=[] 8 | for i in zip(indices[0], indices[1]): 9 | indices_.append(i) 10 | indices=np.array(indices_) 11 | matched_cost = cost_matrix[tuple(zip(*indices))] 12 | matched_mask = (matched_cost <= thresh) 13 | matches = indices[matched_mask] 14 | unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) 15 | unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) 16 | return matches, unmatched_a, unmatched_b 17 | 18 | def linear_assignment(cost_matrix, thresh): 19 | """ 20 | Simple linear assignment 21 | :type cost_matrix: np.ndarray 22 | :type thresh: float 23 | :return: matches, unmatched_a, unmatched_b 24 | """ 25 | if cost_matrix.size == 0: 26 | return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) 27 | 28 | cost_matrix[cost_matrix > thresh] = thresh + 1e-4 29 | indices = linear_sum_assignment(cost_matrix) 30 | 31 | return _indices_to_matches(cost_matrix, indices, thresh) 32 
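# --- Illustrative sketch: linear_assignment above caps costs just above the threshold,
# runs the Hungarian solver (scipy.optimize.linear_sum_assignment), and then drops any
# assigned pair whose cost still exceeds the threshold. The toy costs are assumptions. ---
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.2, 0.9],
                 [0.8, 0.1]])
thresh = 0.7
clamped = np.minimum(cost, thresh + 1e-4)          # same effect as the in-place clamp
rows, cols = linear_sum_assignment(clamped)
matches = [(r, c) for r, c in zip(rows, cols) if clamped[r, c] <= thresh]
assert matches == [(0, 0), (1, 1)]
# --- end sketch ---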
| 33 | def poly_distance(atracks, btracks): 34 | 35 | apts = [track.pt for track in atracks] 36 | bpts = [track.pt for track in btracks] 37 | apts = np.ascontiguousarray(apts, dtype=np.float) 38 | bpts = np.ascontiguousarray(bpts, dtype=np.float) 39 | 40 | if apts.size==0 or bpts.size==0: 41 | _ious = np.zeros((len(apts), len(bpts)), dtype=np.float) 42 | else: 43 | _ious = polygon_iou(apts, bpts) 44 | cost_matrix = 1 - _ious 45 | return cost_matrix 46 | 47 | 48 | def rec_dis(apts, bpts): 49 | dis = np.zeros((len(apts), len(bpts)), dtype=np.float) 50 | if dis.size == 0: 51 | return dis 52 | apts = np.ascontiguousarray(apts, dtype=np.float) 53 | bpts = np.ascontiguousarray(bpts, dtype=np.float) 54 | apts = np.expand_dims(apts, 1) 55 | apts = np.tile(apts, (1, bpts.shape[0], 1)) 56 | bpts = np.expand_dims(bpts, 0) 57 | bpts = np.tile(bpts, (apts.shape[0], 1, 1)) 58 | 59 | dis = abs(apts - bpts) 60 | dis[:, :, :4] /= 10 61 | dis[:, :, 5][dis[:, :, 5]>45] = 90 - dis[:, :, 5][dis[:, :, 5]>45] 62 | dis[:, :, 5] /= 10 63 | dis = 0.3*(dis[:, :, :4].sum(axis=-1))/(4*dis[:,:,6]) + 0.7*dis[:,:,4:6].sum(axis=-1)/2 64 | return dis 65 | 66 | def shape_distance(atracks, btracks): 67 | 68 | apts = [] 69 | for track in atracks: 70 | ct,wh,th = cv2.minAreaRect(track.pt.reshape(4,2)) 71 | if wh[0] 2000: 35 | rand_triple = torch.randint(0,num_of_triple, size=(2000,)).long() 36 | triple_list = triple_list[rand_triple,:] 37 | 38 | triple_feat = roi_feat[triple_list] 39 | all_boxes = torch.cat([cur_boxes[i], next_boxes[i]]) 40 | triple_box = all_boxes[triple_list] 41 | 42 | pos_dist = torch.sum(torch.pow(triple_feat[:,0,:] - triple_feat[:,1,:],2), 1) 43 | neg_dist = torch.sum(torch.pow(triple_feat[:,0,:] - triple_feat[:,2,:],2), 1) 44 | anchor_center = (triple_box[:, 0, 2:] + triple_box[:, 0, :2])/2 45 | positive_center = (triple_box[:, 1, 2:] + triple_box[:, 1, :2])/2 46 | negative_center = (triple_box[:, 2, 2:] + triple_box[:, 2, :2])/2 47 | W_scale = torch.exp(1-torch.sum(torch.abs(triple_box[:, 0, 2:] - triple_box[:, 0, :2]), 1)/2) 48 | W_pos_dis = 1-torch.exp(-(torch.sqrt(torch.sum(torch.pow(anchor_center - positive_center, 2),1)))) 49 | W_neg_dis = 1-torch.exp(-(torch.sqrt(torch.sum(torch.pow(anchor_center - negative_center, 2),1)))) 50 | loss_pos = pos_dist*(W_scale+W_pos_dis) - alpha1 51 | loss_neg = alpha2 - neg_dist*W_neg_dis 52 | loss = torch.sum(torch.max(loss_pos, torch.zeros_like(loss_pos)) + torch.max(loss_neg, torch.zeros_like(loss_neg)), 0)/triple_box.shape[0] 53 | losses += loss 54 | 55 | return losses / batch_size 56 | 57 | def collect_fn(batch): 58 | cur_imgs, next_imgs, cur_boxes, next_boxes, triple_lists = zip(*batch) 59 | cur_imgs = torch.stack(cur_imgs) 60 | next_imgs = torch.stack(next_imgs) 61 | return cur_imgs, next_imgs, list(cur_boxes), list(next_boxes), list(triple_lists) 62 | 63 | def train(opt): 64 | 65 | exp_dir = Path(opt.exp_name) 66 | log_dir = 'db_model/log' / exp_dir 67 | save_dir = 'db_model/weights' / exp_dir 68 | 69 | #init and load weight for DB 70 | if torch.cuda.is_available(): 71 | device = torch.device('cuda') 72 | else: 73 | device = torch.device('cpu') 74 | net = DB_Embedding_Model() 75 | state = torch.load(opt.weight_path) 76 | state = {k.replace('module.', ''): v for k, v in state.items()} 77 | net.db.load_state_dict(state, strict=True) 78 | net = net.to(device) 79 | net.db.eval() 80 | for name,param in net.named_parameters(): 81 | if 'db' in name: 82 | param.requires_grad = False 83 | 84 | #optimizer 85 | optimizer = optim.RMSprop(net.parameters(), lr=opt.lr, 
alpha=0.9, eps=1e-4, weight_decay=0.0001) 86 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,[15,30,45,50,75], gamma=0.3) 87 | # log 88 | if not os.path.exists(save_dir): 89 | os.makedirs(save_dir) 90 | if not os.path.exists(log_dir): 91 | os.makedirs(log_dir) 92 | writer = SummaryWriter(log_dir=log_dir) 93 | logger = init_log('embed') 94 | add_file_handler('embed',os.path.join(save_dir, "log.txt")) 95 | 96 | #dataset 97 | train_data = VideoDataset(opt.img_dir, opt.gt_dir, opt.train_list_file) 98 | train_loader = DataLoader(train_data, batch_size=opt.batch_size, \ 99 | num_workers=opt.num_workers, shuffle=True, collate_fn=collect_fn) 100 | 101 | batch_time = Timer() 102 | data_time = Timer() 103 | losses = AverageMeter() 104 | 105 | global_step = 0 106 | for epoch in range(opt.epoch_num): 107 | lr = scheduler.get_last_lr()[0] 108 | logger.info(f'now learning rate is {lr}') 109 | for i, input_ in enumerate(train_loader): 110 | global_step += 1 111 | batch_time.tic() 112 | data_time.tic() 113 | cur_imgs, next_imgs, cur_boxes, next_boxes, triple_lists = input_ 114 | imgs = torch.cat([cur_imgs, next_imgs]).to(device) 115 | cur_boxes = [boxes.to(device) for boxes in cur_boxes] 116 | next_boxes = [boxes.to(device) for boxes in next_boxes] 117 | triple_lists = [triple_list.to(device) for triple_list in triple_lists] 118 | all_boxes = cur_boxes + next_boxes 119 | data_time.toc() 120 | 121 | pred, roi_feats = net(imgs, all_boxes) 122 | loss_embd = spatical_triplet_loss(roi_feats, cur_boxes, next_boxes, triple_lists) 123 | loss = loss_embd 124 | if loss <= 0: 125 | continue 126 | 127 | writer.add_scalar('Loss/train_cur', loss.item(), global_step) 128 | optimizer.zero_grad() 129 | loss.backward() 130 | optimizer.step() 131 | losses.update(loss.item()) 132 | batch_time.toc() 133 | writer.add_scalar('Loss/train_avg', losses.avg, global_step) 134 | 135 | if global_step % 50 == 0: 136 | logger.info("epoch:[{}/{}] iter:[{}/{}] data_time:{:.3f} batch_time:{:.3f} || loss:{:.4f}/{:.4f}".format( 137 | epoch, opt.epoch_num, 138 | i+1, len(train_loader), 139 | data_time.average_time, 140 | batch_time.average_time, 141 | losses.val, losses.avg 142 | )) 143 | scheduler.step() 144 | if (epoch+1) % 20 == 0: 145 | save_path = os.path.join(save_dir, "db_embedding_weight_epoch{}.pth".format(epoch+1)) 146 | torch.save(net.state_dict(), save_path) 147 | logger.info("save weight at epoch={} in {}".format(epoch, save_path)) 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('--img_dir', type=str, default='./datasets/video_train') 152 | parser.add_argument('--gt_dir', type=str, default='./datasets/video_train') 153 | parser.add_argument('--train_list_file', type=str, default='./datasets/video_train/db_train_valid_pair_list.txt') 154 | parser.add_argument('--weight_path', type=str, default="./db_model/weights/totaltext_resnet50") 155 | parser.add_argument('--exp_name', type=str, default='exp') 156 | parser.add_argument('--epoch_num', type=int, default=100) 157 | parser.add_argument('--lr', type=float, default=0.0001) 158 | parser.add_argument('--batch_size', type=int, default=1) 159 | parser.add_argument('--num_workers', type=int, default=8) 160 | opt = parser.parse_args() 161 | train(opt) 162 | 163 | 164 | -------------------------------------------------------------------------------- /train_scm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import torch 5 | 
import logging 6 | import os 7 | import cv2 8 | import shutil 9 | from torch.utils.collect_env import get_pretty_env_info 10 | from tensorboardX import SummaryWriter 11 | from torch.utils.data import DataLoader 12 | from utils.log import add_file_handler, init_log, print_speed 13 | from utils.parse_config import load_config 14 | from utils.timer import Timer 15 | from utils.meters import AverageMeter 16 | from scm.datasets.scm_dataset import DataSets 17 | from scm.experiments.siammask_sharp.custom import Custom 18 | from scm.utils.load_helper import load_pretrain, restore_from 19 | from scm.utils.lr_helper import build_lr_scheduler 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | def collect_env_info(): 24 | env_str = get_pretty_env_info() 25 | env_str += "\n OpenCV ({})".format(cv2.__version__) 26 | return env_str 27 | 28 | def build_data_loader(cfg): 29 | 30 | logger.info("build train dataset") # train_dataset 31 | train_set = DataSets(cfg['train_datasets'], cfg['anchors'], args.save_dir, args.epochs) 32 | train_set.shuffle() 33 | 34 | logger.info("build val dataset") # val_dataset 35 | if not 'val_datasets' in cfg.keys(): 36 | cfg['val_datasets'] = cfg['train_datasets'] 37 | val_set = DataSets(cfg['val_datasets'], cfg['anchors']) 38 | val_set.shuffle() 39 | 40 | train_loader = DataLoader(train_set, batch_size=args.batch, num_workers=args.workers, 41 | pin_memory=True, sampler=None) 42 | val_loader = DataLoader(val_set, batch_size=args.batch, num_workers=args.workers, 43 | pin_memory=True, sampler=None) 44 | 45 | logger.info('build dataset done') 46 | return train_loader, val_loader 47 | 48 | def build_opt_lr(model, cfg, args, epoch): 49 | trainable_params = model.mask_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['mask_lr_mult']) + \ 50 | model.refine_model.param_groups(cfg['lr']['start_lr'], cfg['lr']['mask_lr_mult']) 51 | 52 | optimizer = torch.optim.SGD(trainable_params, args.lr, 53 | momentum=args.momentum, 54 | weight_decay=args.weight_decay) 55 | 56 | lr_scheduler = build_lr_scheduler(optimizer, cfg['lr'], epochs=args.epochs) 57 | 58 | return optimizer, lr_scheduler 59 | 60 | 61 | def main(): 62 | global logger, tb_writer 63 | args = parser.parse_args() 64 | init_log('global', logging.INFO) 65 | 66 | if args.log != "": 67 | add_file_handler('global', os.path.join(args.save_dir, args.log_dir, args.log), logging.INFO) 68 | 69 | logger = logging.getLogger('global') 70 | logger.info("\n" + collect_env_info()) 71 | logger.info(args) 72 | cfg = load_config(args) 73 | logger.info("config \n{}".format(json.dumps(cfg, indent=4))) 74 | tb_writer = SummaryWriter(os.path.join(args.save_dir, args.log_dir)) 75 | 76 | 77 | # build dataset 78 | train_loader, val_loader = build_data_loader(cfg) 79 | model = Custom(anchors=cfg['anchors']) 80 | logger.info(model) 81 | if args.pretrained: 82 | model = load_pretrain(model, args.pretrained) 83 | model = model.cuda() 84 | dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda() 85 | if args.resume and args.start_epoch != 0: 86 | model.features.unfix((args.start_epoch - 1) / args.epochs) 87 | 88 | optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch) 89 | # optionally resume from a checkpoint 90 | if args.resume: 91 | assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume) 92 | model, optimizer, args.start_epoch, arch = restore_from(model, optimizer, args.resume) 93 | dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda() 94 | 95 
| logger.info(lr_scheduler) 96 | logger.info('model prepare done') 97 | logger.info('start training') 98 | train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg) 99 | logger.info('end training') 100 | 101 | def BNtoFixed(m): 102 | class_name = m.__class__.__name__ 103 | if class_name.find('BatchNorm') != -1: 104 | m.eval() 105 | 106 | def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg): 107 | 108 | cur_lr = lr_scheduler.get_cur_lr() 109 | batch_time = Timer() 110 | mask_loss = AverageMeter() 111 | iou_mean = AverageMeter() 112 | iou_at_5 = AverageMeter() 113 | iou_at_7 = AverageMeter() 114 | model.train() 115 | model.module.features.eval() 116 | model.module.rpn_model.eval() 117 | model.module.features.apply(BNtoFixed) 118 | model.module.rpn_model.apply(BNtoFixed) 119 | model.module.mask_model.train() 120 | model.module.refine_model.train() 121 | model = model.cuda() 122 | 123 | num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch 124 | start_epoch = epoch 125 | epoch = epoch 126 | for iter, input in enumerate(train_loader): 127 | batch_time.tic() 128 | if epoch != iter // num_per_epoch + start_epoch: 129 | epoch = iter // num_per_epoch + start_epoch 130 | if not os.path.exists(args.save_dir): 131 | os.makedirs(args.save_dir) 132 | save_checkpoint({ 133 | 'epoch': epoch, 134 | 'arch': args.arch, 135 | 'state_dict': model.module.state_dict(), 136 | 'optimizer': optimizer.state_dict(), 137 | 'anchor_cfg': cfg['anchors'] 138 | }, False, 139 | os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)), 140 | os.path.join(args.save_dir, 'best.pth')) 141 | if epoch == args.epochs: 142 | return 143 | lr_scheduler.step(epoch) 144 | cur_lr = lr_scheduler.get_cur_lr() 145 | logger.info('epoch:{}'.format(epoch)) 146 | 147 | if iter % num_per_epoch == 0 and iter != 0: 148 | for idx, pg in enumerate(optimizer.param_groups): 149 | logger.info("epoch {} lr {}".format(epoch, pg['lr'])) 150 | tb_writer.add_scalar('lr/group%d' % (idx+1), pg['lr'], iter) 151 | x = { 152 | 'cfg': cfg, 153 | 'template': torch.autograd.Variable(input[0]).cuda(), 154 | 'search': torch.autograd.Variable(input[1]).cuda(), 155 | 'label_cls': torch.autograd.Variable(input[2]).cuda(), 156 | 'label_loc': torch.autograd.Variable(input[3]).cuda(), 157 | 'label_loc_weight': torch.autograd.Variable(input[4]).cuda(), 158 | 'label_mask': torch.autograd.Variable(input[6]).cuda(), 159 | 'label_mask_weight': torch.autograd.Variable(input[7]).cuda(), 160 | } 161 | outputs = model(x) 162 | rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = torch.mean(outputs['losses'][0]), \ 163 | torch.mean(outputs['losses'][1]), torch.mean(outputs['losses'][2]) 164 | mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(outputs['accuracy'][0]), \ 165 | torch.mean(outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2]) 166 | cls_weight, reg_weight, mask_weight = cfg['loss']['weight'] 167 | loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight 168 | optimizer.zero_grad() 169 | loss.backward() 170 | if cfg['clip']['split']: 171 | torch.nn.utils.clip_grad_norm_(model.module.features.parameters(), cfg['clip']['feature']) 172 | torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(), cfg['clip']['rpn']) 173 | torch.nn.utils.clip_grad_norm_(model.module.mask_model.parameters(), cfg['clip']['mask']) 174 | torch.nn.utils.clip_grad_norm_(model.module.refine_model.parameters(), cfg['clip']['mask']) 175 | else: 176 | 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) # gradient clip 177 | optimizer.step() 178 | 179 | batch_time.toc() 180 | mask_loss.update(rpn_mask_loss.item()) 181 | iou_mean.update(mask_iou_mean.item()) 182 | iou_at_5.update(mask_iou_at_5.item()) 183 | iou_at_7.update(mask_iou_at_7.item()) 184 | 185 | tb_writer.add_scalar('loss/mask', rpn_mask_loss.item(), iter) 186 | tb_writer.add_scalar('mask/mIoU', mask_iou_mean.item(), iter) 187 | tb_writer.add_scalar('mask/AP@.5', mask_iou_at_5.item(), iter) 188 | tb_writer.add_scalar('mask/AP@.7', mask_iou_at_7.item(), iter) 189 | 190 | if (iter + 1) % args.print_freq == 0: 191 | logger.info('Epoch: [{0}][{1}/{2}] lr: {3:.6f}\tbatch_time:{4:.3f}' 192 | '\trpn_mask_loss:{5:.3f}\tmask_iou_mean:{6:.3f}' 193 | '\tmask_iou_at_5:{7:.3f}\tmask_iou_at_7:{8:.3f}'.format( 194 | epoch+1, (iter + 1) % num_per_epoch, num_per_epoch, cur_lr, batch_time.average_time, 195 | mask_loss.avg, iou_mean.avg, iou_at_5.avg, iou_at_7.avg)) 196 | print_speed(iter + 1, batch_time.average_time, args.epochs * num_per_epoch) 197 | 198 | def save_checkpoint(state, is_best, filename='checkpoint.pth', best_file='model_best.pth'): 199 | torch.save(state, filename) 200 | if is_best: 201 | shutil.copyfile(filename, best_file) 202 | 203 | if __name__ == '__main__': 204 | global args 205 | parser = argparse.ArgumentParser(description='PyTorch Tracking Training') 206 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 207 | help='number of data loading workers (default: 16)') 208 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 209 | help='number of total epochs to run') 210 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 211 | help='manual epoch number (useful on restarts)') 212 | parser.add_argument('-b', '--batch', default=64, type=int, 213 | metavar='N', help='mini-batch size (default: 64)') 214 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 215 | metavar='LR', help='initial learning rate') 216 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 217 | help='momentum') 218 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 219 | metavar='W', help='weight decay (default: 1e-4)') 220 | parser.add_argument('--clip', default=10.0, type=float, 221 | help='gradient clip value') 222 | parser.add_argument('--print-freq', '-p', default=10, type=int, 223 | metavar='N', help='print frequency (default: 10)') 224 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 225 | help='path to latest checkpoint (default: none)') 226 | parser.add_argument('--pretrained', dest='pretrained', default='', 227 | help='use pre-trained model') 228 | parser.add_argument('--config', dest='config', required=True, 229 | help='hyperparameter of SiamMask in json format') 230 | parser.add_argument('--arch', dest='arch', default='', choices=['Custom',''], 231 | help='architecture of pretrained model') 232 | parser.add_argument('-l', '--log', default="log.txt", type=str, 233 | help='log file') 234 | parser.add_argument('-s', '--save-dir', default='', type=str, 235 | help='save dir') 236 | parser.add_argument('--log-dir', default='board', help='TensorBoard log dir') 237 | args = parser.parse_args() 238 | main() 239 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lsabrinax/VideoTextSCM/d87ad1bbb6ada7573a02a82045ee1b9ead5861ad/utils/__init__.py -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import math 4 | 5 | def print_speed(i, i_time, n): 6 | """print_speed(index, index_time, total_iteration)""" 7 | logger = logging.getLogger('global') 8 | average_time = i_time 9 | remaining_time = (n - i) * average_time 10 | remaining_day = math.floor(remaining_time / 86400) 11 | remaining_hour = math.floor(remaining_time / 3600 - remaining_day * 24) 12 | remaining_min = math.floor(remaining_time / 60 - remaining_day * 1440 - remaining_hour * 60) 13 | logger.info('Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' % (i, n, i/n*100, average_time, remaining_day, remaining_hour, remaining_min)) 14 | 15 | class Filter: 16 | def __init__(self, flag): 17 | self.flag = flag 18 | 19 | def filter(self, x): return self.flag 20 | 21 | def get_format(logger, level): 22 | if 'SLURM_PROCID' in os.environ: 23 | rank = int(os.environ['SLURM_PROCID']) 24 | if level == logging.INFO: 25 | logger.addFilter(Filter(rank == 0)) 26 | else: 27 | rank = 0 28 | format_str = '[%(asctime)s rk{} %(filename)s#%(lineno)3d] %(message)s'.format(rank) 29 | formatter = logging.Formatter(fmt=format_str, datefmt='%Y-%m-%d %H:%M:%S') 30 | return formatter 31 | 32 | def init_log(name, level = logging.INFO, format_func=get_format): 33 | logger = logging.getLogger(name) 34 | logger.setLevel(level) 35 | ch = logging.StreamHandler() 36 | ch.setLevel(level) 37 | formatter = format_func(logger, level) 38 | ch.setFormatter(formatter) 39 | logger.addHandler(ch) 40 | return logger 41 | 42 | def add_file_handler(name, log_file, level = logging.INFO): 43 | logger = logging.getLogger(name) 44 | if not os.path.isdir(os.path.dirname(log_file)): 45 | os.makedirs(os.path.dirname(log_file)) 46 | fh = logging.FileHandler(log_file) 47 | fh.setFormatter(get_format(logger, level)) 48 | logger.addHandler(fh) -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """computes and stores the average and current value""" 3 | def __init__(self, infos=None): 4 | self.val = 0 5 | self.avg = 0 6 | self.sum = 0 7 | self.count = 0 8 | self.infos = infos 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val*n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | -------------------------------------------------------------------------------- /utils/parse_config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import json 3 | 4 | def proccess_loss(cfg): 5 | if 'reg' not in cfg: 6 | cfg['reg'] = {'loss': 'L1Loss'} 7 | else: 8 | if 'loss' not in cfg['reg']: 9 | cfg['reg']['loss'] = 'L1Loss' 10 | 11 | if 'cls' not in cfg: 12 | cfg['cls'] = {'split': True} 13 | 14 | cfg['weight'] = cfg.get('weight', [1, 1, 36]) # cls, reg, mask 15 | 16 | def load_config(args): 17 | assert osp.exists(args.config), '"{}" not exists'.format(args.config) 18 | config = json.load(open(args.config)) 19 | 20 | # deal with network 21 | if 'network' not in config: 22 | print('Warning: 
network lost in config. This will be error in next version') 23 | config['network'] = {} 24 | if not args.arch: 25 | raise Exception('no arch provided') 26 | args.arch = config['network']['arch'] 27 | 28 | # deal with loss 29 | if 'loss' not in config: 30 | config['loss'] = {} 31 | proccess_loss(config['loss']) 32 | 33 | # deal with lr 34 | if 'lr' not in config: 35 | config['lr'] = {} 36 | default = { 37 | 'feature_lr_mult': 1.0, 38 | 'rpn_lr_mult': 1.0, 39 | 'mask_lr_mult': 1.0, 40 | 'type': 'log', 41 | 'start_lr': 0.03 42 | } 43 | default.update(config['lr']) 44 | config['lr'] = default 45 | 46 | # clip 47 | if 'clip' in config or 'clip' in args.__dict__: 48 | if 'clip' not in config: 49 | config['clip'] = {} 50 | config['clip'].update({'feature': args.clip, 'rpn': args.clip, 'split': False}) 51 | if config['clip']['feature'] != config['clip']['rpn']: 52 | config['clip']['split'] = True 53 | if not config['clip']['split']: 54 | args.clip = config['clip']['feature'] 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer(object): 4 | """A simple timer.""" 5 | def __init__(self): 6 | self.total_time = 0. 7 | self.calls = 0 8 | self.start_time = 0. 9 | self.diff = 0. 10 | self.average_time = 0. 11 | self.duration = 0. 12 | 13 | def tic(self): 14 | # using time.time instead of time.clock because time time.clock 15 | # does not normalize for multithreading 16 | self.start_time = time.time() 17 | 18 | def toc(self, average=True): 19 | self.diff = time.time() - self.start_time 20 | self.total_time += self.diff 21 | self.calls += 1 22 | self.average_time = self.total_time / self.calls 23 | if average: 24 | self.duration = self.average_time 25 | else: 26 | self.duration = self.diff 27 | return self.duration 28 | 29 | def clear(self): 30 | self.total_time = 0. 31 | self.calls = 0 32 | self.start_time = 0. 33 | self.diff = 0. 34 | self.average_time = 0. 35 | self.duration = 0. 
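As a quick illustration (toy numbers, assuming the repository root is on `PYTHONPATH`), `Timer` and `AverageMeter` are combined in the training scripts roughly like this:

```python
import time
from utils.meters import AverageMeter
from utils.timer import Timer

batch_time = Timer()
loss_meter = AverageMeter()

for step in range(3):
    batch_time.tic()
    time.sleep(0.01)                      # stand-in for a forward/backward pass
    loss_meter.update(1.0 / (step + 1))   # stand-in loss value
    batch_time.toc()                      # accumulates total and average iteration time

print('avg iter time {:.3f}s, loss {:.3f}/{:.3f}'.format(
    batch_time.average_time, loss_meter.val, loss_meter.avg))
```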
36 | 37 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | from shapely.geometry import Polygon 5 | import cv2 6 | import pyclipper 7 | import torch 8 | 9 | def mkdir_if_missing(d): 10 | if not osp.exists(d): 11 | os.makedirs(d) 12 | 13 | def write2xml(file_root, results, change_id=False): 14 | xml_file = open(file_root, 'w') 15 | xml_file.write('\n') 16 | for i, result in enumerate(results): # result: pre frame; box:[pts id] 17 | if change_id: 18 | xml_file.write(' \n'.format(i+1)) 19 | else: 20 | xml_file.write(' \n'.format(i)) 21 | for box in result: 22 | x1,y1,x2,y2,x3,y3,x4,y4,oid=box 23 | xml_file.write(' \n'.format(oid)) 24 | xml_file.write(' \n'.format(int(x1),int(y1))) 25 | xml_file.write(' \n'.format(int(x2),int(y2))) 26 | xml_file.write(' \n'.format(int(x3),int(y3))) 27 | xml_file.write(' \n'.format(int(x4),int(y4))) 28 | xml_file.write(' \n') 29 | xml_file.write(' \n') 30 | xml_file.write('\n') 31 | 32 | def write2txt(filename, results, change_id=False): 33 | save_format = '{imgid},{insid},{x0},{y0},{w},{h},1,-1,-1,-1\n' 34 | with open(filename, 'w') as f: 35 | for i, result in enumerate(results): 36 | if change_id: 37 | imgid = i + 1 38 | else: 39 | imgid = i 40 | for box in result: 41 | x1,y1,x2,y2,x3,y3,x4,y4,oid=box 42 | t, l = min(y1,y2,y3,y4), min(x1,x2,x3,x4) 43 | b, r = max(y1,y2,y3,y4), max(x1,x2,x3,x4) 44 | x0, y0, w, h = l, t, r-l+1, b-t+1 45 | line = save_format.format(imgid=imgid, insid=oid, x0=x0, y0=y0, w=w, h=h) 46 | f.write(line) 47 | 48 | def save_det_res(txt_name, video_name, boxes, save_dir, dataset): 49 | if dataset == 'roadtext': 50 | save_format = '{x0},{y0},{x2},{y2}\n' 51 | elif dataset == 'minetto' or dataset == 'icdar': 52 | save_format = '{x0},{y0},{x1},{y1},{x2},{y2},{x3},{y3}\n' 53 | gt_file = os.path.join(save_dir, video_name+'_'+txt_name) 54 | if torch.is_tensor(boxes) and boxes.shape[0]>0: 55 | boxes = boxes[:,:8].int().numpy() 56 | with open(gt_file,'w') as f: 57 | for box in boxes: 58 | if dataset == 'roadtext': 59 | x0, y0 = min(box[::2]), min(box[1::2]) 60 | x2, y2 = max(box[::2]), max(box[1::2]) 61 | elif dataset == 'minetto' or dataset == 'icdar': 62 | x0, y0, x1, y1, x2, y2, x3, y3 = box.tolist() 63 | f.write(save_format.format(x0=x0,y0=y0,x1=x1,y1=y1,x2=x2,y2=y2,x3=x3,y3=y3)) 64 | f.close() 65 | else: 66 | f = open(gt_file,'w') 67 | f.close() 68 | 69 | def polygon_iou(apts, bpts): 70 | ious = np.empty((apts.shape[0], bpts.shape[0])) 71 | for i, apt in enumerate(apts): 72 | apt = apt.reshape(4, 2) 73 | polya = Polygon(apt).convex_hull 74 | for j, bpt in enumerate(bpts): 75 | bpt = bpt.reshape(4, 2) 76 | polyb = Polygon(bpt).convex_hull 77 | inter = polya.intersection(polyb).area 78 | union = polya.area + polyb.area - inter 79 | ious[i, j] = inter / (union + 1e-6) 80 | return ious 81 | 82 | def iou(reference, proposals): 83 | """Compute the IoU between a reference box with multiple proposal boxes. 84 | args: 85 | reference - Tensor of shape (1, 4). 86 | proposals - Tensor of shape (num_proposals, 4) 87 | returns: 88 | torch.Tensor - Tensor of shape (num_proposals,) containing IoU of reference box with each proposal box. 
89 | """ 90 | # Intersection box 91 | tl = torch.max(reference[:,:2], proposals[:,:2]) 92 | br = torch.min(reference[:,:2] + reference[:,2:], proposals[:,:2] + proposals[:,2:]) 93 | sz = (br - tl).clamp(0) 94 | # Area 95 | intersection = sz.prod(dim=1) 96 | union = reference[:,2:].prod(dim=1) + proposals[:,2:].prod(dim=1) - intersection 97 | 98 | return intersection / (union + 1e-6) 99 | 100 | def validate_polygons(polygons, ignore_tags, h, w): 101 | ''' 102 | polygons (numpy.array, required): of shape (num_instances, num_points, 2) 103 | ''' 104 | if polygons.shape[0] == 0: 105 | return polygons, ignore_tags 106 | assert polygons.shape[0] == len(ignore_tags) 107 | polygons[:, :, 0] = np.clip(polygons[:, :, 0], 0, w - 1) 108 | polygons[:, :, 1] = np.clip(polygons[:, :, 1], 0, h - 1) 109 | 110 | for i in range(polygons.shape[0]): 111 | area = Polygon(polygons[i]).area 112 | if abs(area) < 1: 113 | ignore_tags[i] = True 114 | return polygons, ignore_tags 115 | 116 | def make_seg_shrink(polygons, ignore_tags,h,w): 117 | min_text_size = 8 118 | shrink_ratio = 0.4 119 | polygons, ignore_tags = validate_polygons( 120 | polygons, ignore_tags, h, w) 121 | gt = np.zeros((1, h, w), dtype=np.float32) 122 | mask = np.ones((h, w), dtype=np.float32) 123 | for i in range(polygons.shape[0]): 124 | polygon = polygons[i] 125 | height = min(np.linalg.norm(polygon[0] - polygon[3]), 126 | np.linalg.norm(polygon[1] - polygon[2])) 127 | width = min(np.linalg.norm(polygon[0] - polygon[1]), 128 | np.linalg.norm(polygon[2] - polygon[3])) 129 | if ignore_tags[i] or min(height, width) < min_text_size: 130 | cv2.fillPoly(mask, polygon.astype( 131 | np.int32)[np.newaxis, :, :], 0) 132 | ignore_tags[i] = True 133 | else: 134 | polygon_shape = Polygon(polygon) 135 | distance = polygon_shape.area * \ 136 | (1 - np.power(shrink_ratio, 2)) / polygon_shape.length 137 | subject = [tuple(l) for l in polygons[i]] 138 | padding = pyclipper.PyclipperOffset() 139 | padding.AddPath(subject, pyclipper.JT_ROUND, 140 | pyclipper.ET_CLOSEDPOLYGON) 141 | shrinked = padding.Execute(-distance) 142 | if shrinked == []: 143 | cv2.fillPoly(mask, polygon.astype( 144 | np.int32)[np.newaxis, :, :], 0) 145 | ignore_tags[i] = True 146 | continue 147 | shrinked = np.array(shrinked[0]).reshape(-1, 2) 148 | cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) 149 | return np.squeeze(gt) 150 | 151 | --------------------------------------------------------------------------------
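A minimal sanity-check sketch for the geometry helpers in `utils/utils.py`, with a single made-up quadrilateral; it assumes the repository root is on `PYTHONPATH` and that `shapely`, `pyclipper`, and OpenCV from requirements.txt are installed:

```python
import numpy as np
from utils.utils import make_seg_shrink, polygon_iou

# One axis-aligned 40x20 text box on a 100x100 canvas.
polys = np.array([[[10, 10], [50, 10], [50, 30], [10, 30]]], dtype=np.float32)

# DB-style shrunk segmentation target: 1 inside the shrunk polygon, 0 elsewhere.
gt = make_seg_shrink(polys, [False], 100, 100)
print(gt.shape, int(gt.sum()))   # (100, 100) and roughly the shrunk-box area

# polygon_iou takes flat (x1, y1, ..., x4, y4) quads and returns an N x M IoU matrix.
a = polys.reshape(1, 8)
b = a + 5.0                      # the same quad shifted by (5, 5)
print(polygon_iou(a, b))         # about 0.49 for this overlap
```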