├── README.md ├── docs ├── drift-chiance.gif ├── lab-coat.gif ├── parkour.gif └── teaser.png ├── libs ├── autoencoder.py ├── concentration_loss.py ├── data │ ├── db_info.yaml │ ├── palette.txt │ └── train.txt ├── loader.py ├── loss.py ├── model.py ├── resnet.py ├── test_utils.py ├── track_utils.py ├── train_utils.py ├── transforms_multi.py ├── transforms_pair.py ├── utils.py └── vis_utils.py ├── model.py ├── test.py ├── test_with_track.py ├── track_match_v1.py └── weights ├── checkpoint_latest.pth.tar ├── decoder_single_gpu.pth └── encoder_single_gpu.pth /README.md: -------------------------------------------------------------------------------- 1 | # Joint-task Self-supervised Learning for Temporal Correspondence 2 | 3 | [**Project**](https://sites.google.com/view/uvc2019) | [**Paper**]() 4 | 5 | # Overview 6 | 7 | ![](docs/teaser.png) 8 | 9 | [Joint-task Self-supervised Learning for Temporal Correspondence]() 10 | 11 | [Xueting Li*](https://sunshineatnoon.github.io/), [Sifei Liu*](https://www.sifeiliu.net/), [Shalini De Mello](https://research.nvidia.com/person/shalini-gupta), [Xiaolong Wang](https://www.cs.cmu.edu/~xiaolonw/), [Jan Kautz](http://jankautz.com/), [Ming-Hsuan Yang](http://faculty.ucmerced.edu/mhyang/). 12 | 13 | (* equal contributions) 14 | 15 | In Neural Information Processing Systems (NeurIPS), 2019. 16 | 17 | # Citation 18 | If you use our code in your research, please use the following BibTex: 19 | 20 | ``` 21 | @inproceedings{uvc_2019, 22 | Author = {Xueting Li and Sifei Liu and Shalini De Mello and Xiaolong Wang and Jan Kautz and Ming-Hsuan Yang}, 23 | Title = {Joint-task Self-supervised Learning for Temporal Correspondence}, 24 | Booktitle = {NeurIPS}, 25 | Year = {2019}, 26 | } 27 | ``` 28 | 29 | # Instance segmentation propagation on DAVIS2017 30 |

31 | ![](docs/drift-chiance.gif) 32 | ![](docs/lab-coat.gif) 33 | ![](docs/parkour.gif) 34 |

35 | 36 | 37 | | Method | J_mean | J_recall | J_decay | F_mean | F_recall | F_decay | 38 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | 39 | | Ours | 0.563 | 0.650 | 0.289 | 0.592 | 0.641 | 0.354 | 40 | | Ours - track | 0.577 | 0.683 | 0.263 | 0.613 | 0.698 | 0.324 | 41 | 42 | # Prerequisites 43 | The code is tested in the following environment: 44 | - Ubuntu 16.04 45 | - PyTorch 1.1.0, [tqdm](https://github.com/tqdm/tqdm), scipy 1.2.1 46 | 47 | # Testing on DAVIS2017 48 | ## Testing without tracking 49 | To test on DAVIS2017 for instance segmentation mask propagation, please run: 50 | ``` 51 | python test.py -d /workspace/DAVIS/ -s 480 52 | ``` 53 | Important parameters: 54 | - `-c`: checkpoint path. 55 | - `-o`: results path. 56 | - `-d`: DAVIS 2017 dataset path. 57 | - `-s`: test resolution; all results in the paper are tested on 480p images, i.e. `-s 480`. 58 | 59 | Please check the `test.py` file for other parameters; a fuller example invocation is given at the end of this README. 60 | 61 | ## Testing with tracking 62 | To test on DAVIS2017 with tracking & propagation, please run: 63 | ``` 64 | python test_with_track.py -d /workspace/DAVIS/ -s 480 65 | ``` 66 | It takes similar parameters to `test.py`; please see `test_with_track.py` for details. 67 | 68 | # Training on Kinetics 69 | 70 | ## Dataset 71 | 72 | We use the [kinetics dataset](https://deepmind.com/research/open-source/open-source-datasets/kinetics/) for training. 73 | 74 | ## Training command 75 | 76 | ``` 77 | python track_match_v1.py --wepoch 10 --nepoch 30 -c match_track_switch --batchsize 40 --coord_switch 0 --lc 0.3 78 | ``` 79 | 80 | # Acknowledgements 81 | - This code is based on [TPN](https://arxiv.org/pdf/1804.08758.pdf) and [TimeCycle](https://github.com/xiaolonw/TimeCycle). 82 | - For any issues, please contact xli75@ucmerced.edu or sifeil@nvidia.com.
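A fuller example invocation, as referenced in the testing section above (a sketch only: the checkpoint path points at the `weights/` folder shipped with this repository, while the output folder `results/davis_480p/` is an arbitrary placeholder):

```
python test.py -d /workspace/DAVIS/ -s 480 -c weights/checkpoint_latest.pth.tar -o results/davis_480p/
```

According to the notes above, `test_with_track.py` accepts similar parameters, so the same `-c`/`-o` flags can be used there as well.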
83 | -------------------------------------------------------------------------------- /docs/drift-chiance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/drift-chiance.gif -------------------------------------------------------------------------------- /docs/lab-coat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/lab-coat.gif -------------------------------------------------------------------------------- /docs/parkour.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/parkour.gif -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/teaser.png -------------------------------------------------------------------------------- /libs/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import vgg16 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | import torch.nn.functional as F 7 | from .resnet import resnet18, resnet50 8 | 9 | class encoder3(nn.Module): 10 | def __init__(self, reduce = False): 11 | super(encoder3,self).__init__() 12 | # vgg 13 | # 224 x 224 14 | self.conv1 = nn.Conv2d(3,3,1,1,0) 15 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) 16 | # 226 x 226 17 | 18 | self.conv2 = nn.Conv2d(3,64,3,1,0) 19 | self.relu2 = nn.ReLU(inplace=True) 20 | # 224 x 224 21 | 22 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) 23 | self.conv3 = nn.Conv2d(64,64,3,1,0) 24 | self.relu3 = nn.ReLU(inplace=True) 25 | # 224 x 224 26 | 27 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) 28 | # 112 x 112 29 | 30 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) 31 | self.conv4 = nn.Conv2d(64,128,3,1,0) 32 | self.relu4 = nn.ReLU(inplace=True) 33 | # 112 x 112 34 | 35 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) 36 | self.conv5 = nn.Conv2d(128,128,3,1,0) 37 | self.relu5 = nn.ReLU(inplace=True) 38 | # 112 x 112 39 | 40 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) 41 | # 56 x 56 42 | 43 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) 44 | self.conv6 = nn.Conv2d(128,256,3,1,0) 45 | self.relu6 = nn.ReLU(inplace=True) 46 | # 56 x 56 47 | self.reduce = reduce 48 | if reduce: 49 | self.downsample = nn.Sequential(nn.MaxPool2d(kernel_size=2,stride=2), 50 | nn.Conv2d(256,256,1,1,0), 51 | nn.ReLU(inplace=True)) 52 | 53 | def forward(self,x): 54 | out = self.conv1(x) 55 | out = self.reflecPad1(out) 56 | out = self.conv2(out) 57 | out = self.relu2(out) 58 | out = self.reflecPad3(out) 59 | out = self.conv3(out) 60 | pool1 = self.relu3(out) 61 | out,pool_idx = self.maxPool(pool1) 62 | out = self.reflecPad4(out) 63 | out = self.conv4(out) 64 | out = self.relu4(out) 65 | out = self.reflecPad5(out) 66 | out = self.conv5(out) 67 | pool2 = self.relu5(out) 68 | out,pool_idx2 = self.maxPool2(pool2) 69 | out = self.reflecPad6(out) 70 | out = self.conv6(out) 71 | out = self.relu6(out) 72 | if self.reduce: 73 | out 
= self.downsample(out) 74 | return out 75 | 76 | class decoder3(nn.Module): 77 | def __init__(self, cls=False, cls_num=32, reduce = False): 78 | """ 79 | INPUTS: 80 | - cls: if using classification. 81 | - cls_num: cluster number 82 | """ 83 | super(decoder3,self).__init__() 84 | if reduce: 85 | self.upsample = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2), 86 | nn.Conv2d(256,256,1,1,0), 87 | nn.ReLU(inplace=True)) 88 | self.reduce = reduce 89 | 90 | self.cls = cls 91 | # decoder 92 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) 93 | self.conv7 = nn.Conv2d(256,128,3,1,0) 94 | self.relu7 = nn.ReLU(inplace=True) 95 | # 56 x 56 96 | 97 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 98 | # 112 x 112 99 | 100 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) 101 | self.conv8 = nn.Conv2d(128,128,3,1,0) 102 | self.relu8 = nn.ReLU(inplace=True) 103 | # 112 x 112 104 | 105 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) 106 | self.conv9 = nn.Conv2d(128,64,3,1,0) 107 | self.relu9 = nn.ReLU(inplace=True) 108 | 109 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 110 | # 224 x 224 111 | 112 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) 113 | self.conv10 = nn.Conv2d(64,64,3,1,0) 114 | self.relu10 = nn.ReLU(inplace=True) 115 | 116 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) 117 | if not cls: 118 | self.conv11 = nn.Conv2d(64,3,3,1,0) 119 | else: 120 | self.conv11 = nn.Sequential(nn.Conv2d(64,cls_num,3,1,0), 121 | nn.LogSoftmax()) 122 | 123 | def forward(self,x): 124 | output = {} 125 | if self.reduce: 126 | x = self.upsample(x) 127 | out = self.reflecPad7(x) 128 | out = self.conv7(out) 129 | out = self.relu7(out) 130 | out = self.unpool(out) 131 | out = self.reflecPad8(out) 132 | out = self.conv8(out) 133 | out = self.relu8(out) 134 | out = self.reflecPad9(out) 135 | out = self.conv9(out) 136 | out_relu9 = self.relu9(out) 137 | out = self.unpool2(out_relu9) 138 | out = self.reflecPad10(out) 139 | out = self.conv10(out) 140 | out = self.relu10(out) 141 | out = self.reflecPad11(out) 142 | out = self.conv11(out) 143 | if not self.cls: 144 | out = torch.tanh(out) 145 | return out 146 | 147 | 148 | def encoder_res18(pretrained = True, uselayer=3): 149 | """Constructs a ResNet-18 model. 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = resnet18(uselayer=uselayer) 154 | if pretrained: 155 | print("Using pretrianed ResNet18 as guide.") 156 | model.load_state_dict(torch.load("ae_models/resnet18-5c106cde.pth"), strict = False) 157 | return model 158 | -------------------------------------------------------------------------------- /libs/concentration_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def transform(aff, frame1): 5 | """ 6 | Given aff, copy from frame1 to construct frame2. 7 | INPUTS: 8 | - aff: (h*w)*(h*w) affinity matrix 9 | - frame1: n*c*h*w feature map 10 | """ 11 | b,c,h,w = frame1.size() 12 | frame1 = frame1.view(b,c,-1) 13 | frame2 = torch.bmm(frame1, aff) 14 | return frame2.view(b,c,h,w) 15 | 16 | def aff2coord(F_size, grid, A = None, temp = None, softmax=None): 17 | """ 18 | INPUT: 19 | - A: a (H*W)*(H*W) affinity matrix 20 | - F_size: image feature size 21 | - mode: if mode is coord, return coordinates, else return flow 22 | - grid: a standard grid, see create_grid function. 23 | OUTPUT: 24 | - U: a (2*H*W) coordinate tensor, U_ij indicates the coordinates of pixel ij in target image. 
25 | """ 26 | grid = grid.permute(0,3,1,2) 27 | if A is not None: 28 | if softmax is not None: 29 | if temp is None: 30 | raise Exception("Need temp for softmax!") 31 | A = softmax(A*temp) 32 | # b x c x h x w 33 | U = transform(A, grid) 34 | else: 35 | U = grid 36 | return U 37 | 38 | def create_grid(F_size, GPU=True): 39 | b, c, h, w = F_size 40 | theta = torch.tensor([[1,0,0],[0,1,0]]) 41 | theta = theta.unsqueeze(0).repeat(b,1,1) 42 | theta = theta.float() 43 | 44 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 45 | # b * (h*w) * 2 46 | grid = nn.functional.affine_grid(theta, F_size) 47 | if(GPU): 48 | grid = grid.cuda() 49 | return grid 50 | 51 | def im2col(img, win_len, stride=1): 52 | """ 53 | INPUTS: 54 | - img: a b*c*h*w feature tensor. 55 | - win_len: each pixel compares with its neighbors within a 56 | (win_len*2+1) * (win_len*2+1) window. 57 | OUTPUT: 58 | - result: a b*c*(h*w)*(win_len*2+1)^2 tensor, unfolded neighbors for each pixel 59 | """ 60 | b,c,_,_ = img.size() 61 | # b * (c*w*w) * win_num 62 | unfold_img = torch.nn.functional.unfold(img, win_len, padding=0, stride=stride) 63 | unfold_img = unfold_img.view(b,c,win_len*win_len,-1) 64 | unfold_img = unfold_img.permute(0,1,3,2) 65 | return unfold_img 66 | 67 | class ConcentrationLoss(nn.Module): 68 | def __init__(self, win_len, stride, F_size): 69 | super(ConcentrationLoss, self).__init__() 70 | self.win_len = win_len 71 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False) 72 | self.F_size = F_size 73 | self.stride = stride 74 | 75 | def forward(self, aff): 76 | b, c, h, w = self.F_size 77 | #if aff.dim() == 4: 78 | # aff = torch.squeeze(aff) 79 | # b * 2 * h * w 80 | coord1 = aff2coord(self.F_size, self.grid, aff) 81 | # b * 2 * (h * w) * (win ^ 2) 82 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride) 83 | # b * 2 * (h * w) * 1 84 | # center = coord1_unfold[:,:,:,int((self.win_len ** 2)/2)] 85 | center = torch.mean(coord1_unfold, dim = 3) 86 | center = center.view(b, 2, -1, 1) 87 | # b * 2 * (h * w) * (win ^ 2) 88 | dis2center = (coord1_unfold - center) ** 2 89 | return torch.sum(dis2center) / dis2center.numel() 90 | 91 | class ConcentrationDetachLoss(nn.Module): 92 | def __init__(self, win_len, stride, F_size): 93 | super(ConcentrationDetachLoss, self).__init__() 94 | self.win_len = win_len 95 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False) 96 | self.F_size = F_size 97 | self.stride = stride 98 | 99 | def forward(self, aff): 100 | b, c, h, w = self.F_size 101 | if aff.dim() == 4: 102 | aff = torch.squeeze(aff) 103 | # b * 2 * h * w 104 | coord1 = aff2coord(self.F_size, self.grid, aff) 105 | # b * 2 * (h * w) * (win ^ 2) 106 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride) 107 | # b * 2 * (h * w) * 1 108 | # center = coord1_unfold[:,:,:,int((self.win_len ** 2)/2)] 109 | center = torch.mean(coord1_unfold, dim = 3).detach() 110 | center = center.view(b, 2, -1, 1) 111 | # b * 2 * (h * w) * (win ^ 2) 112 | dis2center = (coord1_unfold - center) ** 2 113 | return torch.sum(dis2center) / dis2center.numel() 114 | 115 | class ConcentrationSwitchLoss(nn.Module): 116 | def __init__(self, win_len, stride, F_size, temp): 117 | super(ConcentrationSwitchLoss, self).__init__() 118 | self.win_len = win_len 119 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False) 120 | self.F_size = F_size 121 | self.stride = stride 122 | self.temp = temp 123 | self.softmax = nn.Softmax(dim=1) 124 | 125 | def forward(self, aff): 126 | # 
aff here is not processed by softmax 127 | b, c, h, w = self.F_size 128 | if aff.dim() == 4: 129 | aff = torch.squeeze(aff) 130 | # b * 2 * h * w 131 | coord1 = aff2coord(self.F_size, self.grid, aff, self.temp, softmax=self.softmax) 132 | coord2 = aff2coord(self.F_size, self.grid, aff.permute(0,2,1), self.temp, softmax=self.softmax) 133 | # b * 2 * (h * w) * (win ^ 2) 134 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride) 135 | coord2_unfold = im2col(coord2, self.win_len, stride = self.stride) 136 | # b * 2 * (h * w) * 1 137 | center1 = torch.mean(coord1_unfold, dim = 3) 138 | center1 = center1.view(b, 2, -1, 1) 139 | # b * 2 * (h * w) * (win ^ 2) 140 | dis2center1 = (coord1_unfold - center1) ** 2 141 | # b * 2 * (h * w) * 1 142 | center2 = torch.mean(coord2_unfold, dim = 3) 143 | center2 = center2.view(b, 2, -1, 1) 144 | # b * 2 * (h * w) * (win ^ 2) 145 | dis2center2 = (coord2_unfold - center2) ** 2 146 | return (torch.sum(dis2center1) + torch.sum(dis2center2))/ dis2center1.numel() 147 | 148 | 149 | if __name__ == '__main__': 150 | cl = ConcentrationLoss(win_len=8, F_size=torch.Size((1,3,32,32)), stride=8) 151 | aff = torch.Tensor(1,1024,1024) 152 | aff.uniform_() 153 | aff = aff.cuda() 154 | print(cl(aff)) 155 | -------------------------------------------------------------------------------- /libs/data/db_info.yaml: -------------------------------------------------------------------------------- 1 | attributes: [AC, BC, CS, DB, DEF, EA, FM, HO, IO, LR, MB, OCC, OV, ROT, SC, SV] 2 | sets: [train, val, val-dev] 3 | years: [2016, 2017] 4 | 5 | sequences: 6 | 7 | - name: aerobatics 8 | attributes: [] 9 | num_frames: 71 10 | set: test-dev 11 | eval_t: False 12 | year: 2017 13 | 14 | - name: bear 15 | attributes: [DEF] 16 | num_frames: 82 17 | set: train 18 | eval_t: True 19 | year: 2016 20 | 21 | - name: bike-packing 22 | attributes: [] 23 | num_frames: 69 24 | set: val 25 | eval_t: False 26 | year: 2017 27 | 28 | - name: blackswan 29 | attributes: [] 30 | num_frames: 50 31 | set: val 32 | eval_t: True 33 | year: 2016 34 | 35 | - name: bmx-bumps 36 | attributes: [LR, SV, SC, FM, CS, IO, MB, OCC, HO, EA, OV] 37 | num_frames: 90 38 | set: train 39 | eval_t: False 40 | year: 2016 41 | 42 | - name: bmx-trees 43 | attributes: [LR, SV, SC, FM, CS, IO, MB, DEF, OCC, HO, EA, BC] 44 | num_frames: 80 45 | set: val 46 | eval_t: False 47 | year: 2016 48 | 49 | - name: boat 50 | attributes: [SC, DB, EA, BC] 51 | num_frames: 75 52 | set: train 53 | eval_t: True 54 | year: 2016 55 | 56 | - name: boxing-fisheye 57 | attributes: [] 58 | num_frames: 87 59 | set: train 60 | eval_t: False 61 | year: 2017 62 | 63 | - name: breakdance 64 | attributes: [FM, DB, MB, DEF, HO, ROT, OV, AC] 65 | num_frames: 84 66 | set: val 67 | eval_t: False 68 | year: 2016 69 | 70 | - name: breakdance-flare 71 | attributes: [FM, CS, MB, DEF, HO, ROT] 72 | num_frames: 71 73 | set: train 74 | eval_t: False 75 | year: 2016 76 | 77 | - name: bus 78 | attributes: [SC, OCC, HO, EA] 79 | num_frames: 80 80 | set: train 81 | eval_t: True 82 | year: 2016 83 | 84 | - name: camel 85 | attributes: [CS, IO, DEF, ROT] 86 | num_frames: 90 87 | set: val 88 | eval_t: True 89 | year: 2016 90 | 91 | - name: car-race 92 | attributes: [] 93 | num_frames: 31 94 | set: test-dev 95 | eval_t: False 96 | year: 2017 97 | 98 | - name: car-roundabout 99 | attributes: [ROT, BC] 100 | num_frames: 75 101 | set: val 102 | eval_t: True 103 | year: 2016 104 | 105 | - name: car-shadow 106 | attributes: [LR, EA, AC, BC] 107 | num_frames: 40 108 | set: 
val 109 | eval_t: True 110 | year: 2016 111 | 112 | - name: car-turn 113 | attributes: [SV, ROT, BC] 114 | num_frames: 80 115 | set: train 116 | eval_t: True 117 | year: 2016 118 | 119 | - name: carousel 120 | attributes: [] 121 | num_frames: 69 122 | set: test-dev 123 | eval_t: False 124 | year: 2017 125 | 126 | - name: cat-girl 127 | attributes: [] 128 | num_frames: 89 129 | set: train 130 | eval_t: False 131 | year: 2017 132 | 133 | - name: cats-car 134 | attributes: [] 135 | num_frames: 67 136 | set: test-dev 137 | eval_t: False 138 | year: 2017 139 | 140 | - name: chamaleon 141 | attributes: [] 142 | num_frames: 85 143 | set: test-dev 144 | eval_t: False 145 | year: 2017 146 | 147 | - name: classic-car 148 | attributes: [] 149 | num_frames: 63 150 | set: train 151 | eval_t: False 152 | year: 2017 153 | 154 | - name: color-run 155 | attributes: [] 156 | num_frames: 84 157 | set: train 158 | eval_t: False 159 | year: 2017 160 | 161 | - name: cows 162 | attributes: [CS, IO, DEF, OCC, HO] 163 | num_frames: 104 164 | set: val 165 | eval_t: True 166 | year: 2016 167 | 168 | - name: crossing 169 | attributes: [] 170 | num_frames: 52 171 | set: train 172 | eval_t: False 173 | year: 2017 174 | 175 | - name: dance-jump 176 | attributes: [SC, DB, MB, DEF, OCC, HO, ROT, EA] 177 | num_frames: 60 178 | set: train 179 | eval_t: True 180 | year: 2016 181 | 182 | - name: dance-twirl 183 | attributes: [SC, CS, IO, MB, DEF, HO, ROT, OV] 184 | num_frames: 90 185 | set: val 186 | eval_t: False 187 | year: 2016 188 | 189 | - name: dancing 190 | attributes: [] 191 | num_frames: 62 192 | set: train 193 | eval_t: False 194 | year: 2017 195 | 196 | - name: deer 197 | attributes: [] 198 | num_frames: 79 199 | set: test-dev 200 | eval_t: False 201 | year: 2017 202 | 203 | - name: disc-jockey 204 | attributes: [] 205 | num_frames: 76 206 | set: train 207 | eval_t: False 208 | year: 2017 209 | 210 | - name: dog 211 | attributes: [FM, CS, MB, DEF, ROT, EA] 212 | num_frames: 60 213 | set: val 214 | eval_t: False 215 | year: 2016 216 | 217 | - name: dog-agility 218 | attributes: [FM, MB, DEF, OCC, HO, EA, OV, AC] 219 | num_frames: 25 220 | set: train 221 | eval_t: False 222 | year: 2016 223 | 224 | - name: dog-gooses 225 | attributes: [] 226 | num_frames: 86 227 | set: train 228 | eval_t: False 229 | year: 2017 230 | 231 | - name: dogs-jump 232 | attributes: [] 233 | num_frames: 66 234 | set: val 235 | eval_t: False 236 | year: 2017 237 | 238 | - name: dogs-scale 239 | attributes: [] 240 | num_frames: 83 241 | set: train 242 | eval_t: False 243 | year: 2017 244 | 245 | - name: drift-chicane 246 | attributes: [LR, SV, FM, DB, HO, ROT, EA, AC] 247 | num_frames: 52 248 | set: val 249 | eval_t: False 250 | year: 2016 251 | 252 | - name: drift-straight 253 | attributes: [LR, SV, FM, CS, MB, HO, ROT, EA, OV, AC] 254 | num_frames: 50 255 | set: val 256 | eval_t: True 257 | year: 2016 258 | 259 | - name: drift-turn 260 | attributes: [SV, FM, IO, DB, HO, ROT, OV, AC] 261 | num_frames: 64 262 | set: train 263 | eval_t: True 264 | year: 2016 265 | 266 | - name: drone 267 | attributes: [] 268 | num_frames: 91 269 | set: train 270 | eval_t: False 271 | year: 2017 272 | 273 | - name: elephant 274 | attributes: [CS, DB, DEF, EA] 275 | num_frames: 80 276 | set: train 277 | eval_t: True 278 | year: 2016 279 | 280 | - name: flamingo 281 | attributes: [SC, IO, DB, DEF, HO] 282 | num_frames: 80 283 | set: train 284 | eval_t: True 285 | year: 2016 286 | 287 | - name: giant-slalom 288 | attributes: [] 289 | num_frames: 127 290 | set: 
test-dev 291 | eval_t: False 292 | year: 2017 293 | 294 | - name: girl-dog 295 | attributes: [] 296 | num_frames: 86 297 | set: test-dev 298 | eval_t: False 299 | year: 2017 300 | 301 | - name: goat 302 | attributes: [CS, DEF, EA, BC] 303 | num_frames: 90 304 | set: val 305 | eval_t: False 306 | year: 2016 307 | 308 | - name: gold-fish 309 | attributes: [] 310 | num_frames: 78 311 | set: val 312 | eval_t: False 313 | year: 2017 314 | 315 | - name: golf 316 | attributes: [] 317 | num_frames: 79 318 | set: test-dev 319 | eval_t: False 320 | year: 2017 321 | 322 | - name: guitar-violin 323 | attributes: [] 324 | num_frames: 55 325 | set: test-dev 326 | eval_t: False 327 | year: 2017 328 | 329 | - name: gym 330 | attributes: [] 331 | num_frames: 60 332 | set: test-dev 333 | eval_t: False 334 | year: 2017 335 | 336 | - name: helicopter 337 | attributes: [] 338 | num_frames: 49 339 | set: test-dev 340 | eval_t: False 341 | year: 2017 342 | 343 | - name: hike 344 | attributes: [LR, DEF, HO] 345 | num_frames: 80 346 | set: train 347 | eval_t: True 348 | year: 2016 349 | 350 | - name: hockey 351 | attributes: [SC, IO, DEF, HO, ROT] 352 | num_frames: 75 353 | set: train 354 | eval_t: True 355 | year: 2016 356 | 357 | - name: horsejump-high 358 | attributes: [SC, IO, DEF, OCC, HO] 359 | num_frames: 50 360 | set: val 361 | eval_t: False 362 | year: 2016 363 | 364 | - name: horsejump-low 365 | attributes: [SC, IO, DEF, OCC, HO, EA] 366 | num_frames: 60 367 | set: train 368 | eval_t: False 369 | year: 2016 370 | 371 | - name: horsejump-stick 372 | attributes: [] 373 | num_frames: 58 374 | set: test-dev 375 | eval_t: False 376 | year: 2017 377 | 378 | - name: hoverboard 379 | attributes: [] 380 | num_frames: 81 381 | set: test-dev 382 | eval_t: False 383 | year: 2017 384 | 385 | - name: india 386 | attributes: [] 387 | num_frames: 81 388 | set: val 389 | eval_t: False 390 | year: 2017 391 | 392 | - name: judo 393 | attributes: [] 394 | num_frames: 34 395 | set: val 396 | eval_t: False 397 | year: 2017 398 | 399 | - name: kid-football 400 | attributes: [] 401 | num_frames: 68 402 | set: train 403 | eval_t: False 404 | year: 2017 405 | 406 | - name: kite-surf 407 | attributes: [SV, SC, IO, DB, MB, OCC, HO, EA] 408 | num_frames: 50 409 | set: val 410 | eval_t: True 411 | year: 2016 412 | 413 | - name: kite-walk 414 | attributes: [SC, IO, DB, DEF, OCC, HO] 415 | num_frames: 80 416 | set: train 417 | eval_t: True 418 | year: 2016 419 | 420 | - name: koala 421 | attributes: [] 422 | num_frames: 100 423 | set: train 424 | eval_t: False 425 | year: 2017 426 | 427 | - name: lab-coat 428 | attributes: [] 429 | num_frames: 47 430 | set: val 431 | eval_t: False 432 | year: 2017 433 | 434 | - name: lady-running 435 | attributes: [] 436 | num_frames: 65 437 | set: train 438 | eval_t: False 439 | year: 2017 440 | 441 | - name: libby 442 | attributes: [SC, MB, DEF, OCC, HO, EA] 443 | num_frames: 49 444 | set: val 445 | eval_t: False 446 | year: 2016 447 | 448 | - name: lindy-hop 449 | attributes: [] 450 | num_frames: 73 451 | set: train 452 | eval_t: False 453 | year: 2017 454 | 455 | - name: loading 456 | attributes: [] 457 | num_frames: 50 458 | set: val 459 | eval_t: False 460 | year: 2017 461 | 462 | - name: lock 463 | attributes: [] 464 | num_frames: 43 465 | set: test-dev 466 | eval_t: False 467 | year: 2017 468 | 469 | - name: longboard 470 | attributes: [] 471 | num_frames: 52 472 | set: train 473 | eval_t: False 474 | year: 2017 475 | 476 | 477 | - name: lucia 478 | attributes: [DEF, OCC, HO] 479 | num_frames: 
70 480 | set: train 481 | eval_t: False 482 | year: 2016 483 | 484 | - name: mallard-fly 485 | attributes: [LR, SV, FM, DB, MB, DEF, ROT, EA, OV, AC] 486 | num_frames: 70 487 | set: train 488 | eval_t: False 489 | year: 2016 490 | 491 | - name: mallard-water 492 | attributes: [LR, IO, DB, EA] 493 | num_frames: 80 494 | set: train 495 | eval_t: True 496 | year: 2016 497 | 498 | - name: man-bike 499 | attributes: [] 500 | num_frames: 75 501 | set: test-dev 502 | eval_t: False 503 | year: 2017 504 | 505 | 506 | - name: mbike-trick 507 | attributes: [] 508 | num_frames: 79 509 | set: val 510 | eval_t: False 511 | year: 2017 512 | 513 | - name: miami-surf 514 | attributes: [] 515 | num_frames: 70 516 | set: train 517 | eval_t: False 518 | year: 2017 519 | 520 | - name: monkeys-trees 521 | attributes: [] 522 | num_frames: 83 523 | set: test-dev 524 | eval_t: False 525 | year: 2017 526 | 527 | 528 | - name: motocross-bumps 529 | attributes: [SV, FM, IO, HO, ROT, OV, AC, BC] 530 | num_frames: 60 531 | set: train 532 | eval_t: True 533 | year: 2016 534 | 535 | - name: motocross-jump 536 | attributes: [SV, SC, FM, IO, MB, DEF, HO, ROT, EA, OV, AC] 537 | num_frames: 40 538 | set: val 539 | eval_t: False 540 | year: 2016 541 | 542 | - name: motorbike 543 | attributes: [LR, SV, SC, FM, IO, OCC, HO, ROT, EA] 544 | num_frames: 43 545 | set: train 546 | eval_t: False 547 | year: 2016 548 | 549 | - name: mtb-race 550 | attributes: [] 551 | num_frames: 69 552 | set: test-dev 553 | eval_t: False 554 | year: 2017 555 | 556 | - name: night-race 557 | attributes: [] 558 | num_frames: 46 559 | set: train 560 | eval_t: False 561 | year: 2017 562 | 563 | - name: orchid 564 | attributes: [] 565 | num_frames: 57 566 | set: test-dev 567 | eval_t: False 568 | year: 2017 569 | 570 | - name: paragliding 571 | attributes: [LR, SC, IO, HO] 572 | num_frames: 70 573 | set: train 574 | eval_t: False 575 | year: 2016 576 | 577 | - name: paragliding-launch 578 | attributes: [SC, IO, DEF, HO, EA] 579 | num_frames: 80 580 | set: val 581 | eval_t: True 582 | year: 2016 583 | 584 | - name: parkour 585 | attributes: [LR, SV, FM, DEF, OCC, HO, ROT, AC] 586 | num_frames: 100 587 | set: val 588 | eval_t: False 589 | year: 2016 590 | 591 | - name: people-sunset 592 | attributes: [] 593 | num_frames: 67 594 | set: test-dev 595 | eval_t: False 596 | year: 2017 597 | 598 | - name: pigs 599 | attributes: [] 600 | num_frames: 79 601 | set: val 602 | eval_t: False 603 | year: 2017 604 | 605 | - name: planes-crossing 606 | attributes: [] 607 | num_frames: 31 608 | set: test-dev 609 | eval_t: False 610 | year: 2017 611 | 612 | - name: planes-water 613 | attributes: [] 614 | num_frames: 38 615 | set: train 616 | eval_t: False 617 | year: 2017 618 | 619 | - name: rallye 620 | attributes: [] 621 | num_frames: 50 622 | set: train 623 | eval_t: False 624 | year: 2017 625 | 626 | - name: rhino 627 | attributes: [DEF, OCC, BC] 628 | num_frames: 90 629 | set: train 630 | eval_t: True 631 | year: 2016 632 | 633 | - name: rollerblade 634 | attributes: [LR, FM, CS, MB, DEF, HO] 635 | num_frames: 35 636 | set: train 637 | eval_t: False 638 | year: 2016 639 | 640 | - name: rollercoaster 641 | attributes: [] 642 | num_frames: 70 643 | set: test-dev 644 | eval_t: False 645 | year: 2017 646 | 647 | - name: salsa 648 | attributes: [] 649 | num_frames: 86 650 | set: test-dev 651 | eval_t: False 652 | year: 2017 653 | 654 | - name: schoolgirls 655 | attributes: [] 656 | num_frames: 80 657 | set: train 658 | eval_t: False 659 | year: 2017 660 | 661 | - name: 
scooter-black 662 | attributes: [SV, IO, HO, EA] 663 | num_frames: 43 664 | set: val 665 | eval_t: True 666 | year: 2016 667 | 668 | - name: scooter-board 669 | attributes: [] 670 | num_frames: 91 671 | set: train 672 | eval_t: False 673 | year: 2017 674 | 675 | - name: scooter-gray 676 | attributes: [SC, FM, IO, OCC, HO, ROT, EA, BC] 677 | num_frames: 75 678 | set: train 679 | eval_t: False 680 | year: 2016 681 | 682 | - name: seasnake 683 | attributes: [] 684 | num_frames: 80 685 | set: test-dev 686 | eval_t: False 687 | year: 2017 688 | 689 | - name: sheep 690 | attributes: [] 691 | num_frames: 68 692 | set: train 693 | eval_t: False 694 | year: 2017 695 | 696 | - name: shooting 697 | attributes: [] 698 | num_frames: 40 699 | set: val 700 | eval_t: False 701 | year: 2017 702 | 703 | - name: skate-jump 704 | attributes: [] 705 | num_frames: 68 706 | set: test-dev 707 | eval_t: False 708 | year: 2017 709 | 710 | - name: skate-park 711 | attributes: [] 712 | num_frames: 80 713 | set: train 714 | eval_t: False 715 | year: 2017 716 | 717 | - name: slackline 718 | attributes: [] 719 | num_frames: 60 720 | set: test-dev 721 | eval_t: False 722 | year: 2017 723 | 724 | - name: snowboard 725 | attributes: [] 726 | num_frames: 66 727 | set: train 728 | eval_t: False 729 | year: 2017 730 | 731 | - name: soapbox 732 | attributes: [SV, IO, MB, DEF, HO, ROT, AC] 733 | num_frames: 99 734 | set: val 735 | eval_t: True 736 | year: 2016 737 | 738 | - name: soccerball 739 | attributes: [LR, FM, MB, OCC, HO] 740 | num_frames: 48 741 | set: train 742 | eval_t: False 743 | year: 2016 744 | 745 | - name: stroller 746 | attributes: [SC, FM, CS, IO, DEF, HO] 747 | num_frames: 91 748 | set: train 749 | eval_t: True 750 | year: 2016 751 | 752 | - name: stunt 753 | attributes: [] 754 | num_frames: 71 755 | set: train 756 | eval_t: False 757 | year: 2017 758 | 759 | - name: subway 760 | attributes: [] 761 | num_frames: 88 762 | set: test-dev 763 | eval_t: False 764 | year: 2017 765 | 766 | - name: surf 767 | attributes: [SV, FM, CS, IO, DB, HO, OV] 768 | num_frames: 55 769 | set: train 770 | eval_t: True 771 | year: 2016 772 | 773 | - name: swing 774 | attributes: [SC, FM, IO, DEF, OCC, HO] 775 | num_frames: 60 776 | set: train 777 | eval_t: False 778 | year: 2016 779 | 780 | - name: tandem 781 | attributes: [] 782 | num_frames: 72 783 | set: test-dev 784 | eval_t: False 785 | year: 2017 786 | 787 | 788 | - name: tennis 789 | attributes: [SV, FM, IO, MB, DEF, HO] 790 | num_frames: 70 791 | set: train 792 | eval_t: False 793 | year: 2016 794 | 795 | - name: tennis-vest 796 | attributes: [] 797 | num_frames: 75 798 | set: test-dev 799 | eval_t: False 800 | year: 2017 801 | 802 | - name: tractor 803 | attributes: [] 804 | num_frames: 65 805 | set: test-dev 806 | eval_t: False 807 | year: 2017 808 | 809 | - name: tractor-sand 810 | attributes: [] 811 | num_frames: 76 812 | set: train 813 | eval_t: False 814 | year: 2017 815 | 816 | - name: train 817 | attributes: [SC, HO, EA] 818 | num_frames: 80 819 | set: train 820 | eval_t: True 821 | year: 2016 822 | 823 | - name: tuk-tuk 824 | attributes: [] 825 | num_frames: 59 826 | set: train 827 | eval_t: False 828 | year: 2017 829 | 830 | - name: upside-down 831 | attributes: [] 832 | num_frames: 65 833 | set: train 834 | eval_t: False 835 | year: 2017 836 | 837 | - name: varanus-cage 838 | attributes: [] 839 | num_frames: 67 840 | set: train 841 | eval_t: False 842 | year: 2017 843 | 844 | - name: walking 845 | attributes: [] 846 | num_frames: 72 847 | set: train 848 | 
eval_t: False 849 | year: 2017 850 | -------------------------------------------------------------------------------- /libs/data/palette.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 2 | 128 0 0 3 | 0 128 0 4 | 128 128 0 5 | 0 0 128 6 | 128 0 128 7 | 0 128 128 8 | 128 128 128 9 | 64 0 0 10 | 191 0 0 11 | 64 128 0 12 | 191 128 0 13 | 64 0 128 14 | 191 0 128 15 | 64 128 128 16 | 191 128 128 17 | 0 64 0 18 | 128 64 0 19 | 0 191 0 20 | 128 191 0 21 | 0 64 128 22 | 128 64 128 23 | 22 22 22 24 | 23 23 23 25 | 24 24 24 26 | 25 25 25 27 | 26 26 26 28 | 27 27 27 29 | 28 28 28 30 | 29 29 29 31 | 30 30 30 32 | 31 31 31 33 | 32 32 32 34 | 33 33 33 35 | 34 34 34 36 | 35 35 35 37 | 36 36 36 38 | 37 37 37 39 | 38 38 38 40 | 39 39 39 41 | 40 40 40 42 | 41 41 41 43 | 42 42 42 44 | 43 43 43 45 | 44 44 44 46 | 45 45 45 47 | 46 46 46 48 | 47 47 47 49 | 48 48 48 50 | 49 49 49 51 | 50 50 50 52 | 51 51 51 53 | 52 52 52 54 | 53 53 53 55 | 54 54 54 56 | 55 55 55 57 | 56 56 56 58 | 57 57 57 59 | 58 58 58 60 | 59 59 59 61 | 60 60 60 62 | 61 61 61 63 | 62 62 62 64 | 63 63 63 65 | 64 64 64 66 | 65 65 65 67 | 66 66 66 68 | 67 67 67 69 | 68 68 68 70 | 69 69 69 71 | 70 70 70 72 | 71 71 71 73 | 72 72 72 74 | 73 73 73 75 | 74 74 74 76 | 75 75 75 77 | 76 76 76 78 | 77 77 77 79 | 78 78 78 80 | 79 79 79 81 | 80 80 80 82 | 81 81 81 83 | 82 82 82 84 | 83 83 83 85 | 84 84 84 86 | 85 85 85 87 | 86 86 86 88 | 87 87 87 89 | 88 88 88 90 | 89 89 89 91 | 90 90 90 92 | 91 91 91 93 | 92 92 92 94 | 93 93 93 95 | 94 94 94 96 | 95 95 95 97 | 96 96 96 98 | 97 97 97 99 | 98 98 98 100 | 99 99 99 101 | 100 100 100 102 | 101 101 101 103 | 102 102 102 104 | 103 103 103 105 | 104 104 104 106 | 105 105 105 107 | 106 106 106 108 | 107 107 107 109 | 108 108 108 110 | 109 109 109 111 | 110 110 110 112 | 111 111 111 113 | 112 112 112 114 | 113 113 113 115 | 114 114 114 116 | 115 115 115 117 | 116 116 116 118 | 117 117 117 119 | 118 118 118 120 | 119 119 119 121 | 120 120 120 122 | 121 121 121 123 | 122 122 122 124 | 123 123 123 125 | 124 124 124 126 | 125 125 125 127 | 126 126 126 128 | 127 127 127 129 | 128 128 128 130 | 129 129 129 131 | 130 130 130 132 | 131 131 131 133 | 132 132 132 134 | 133 133 133 135 | 134 134 134 136 | 135 135 135 137 | 136 136 136 138 | 137 137 137 139 | 138 138 138 140 | 139 139 139 141 | 140 140 140 142 | 141 141 141 143 | 142 142 142 144 | 143 143 143 145 | 144 144 144 146 | 145 145 145 147 | 146 146 146 148 | 147 147 147 149 | 148 148 148 150 | 149 149 149 151 | 150 150 150 152 | 151 151 151 153 | 152 152 152 154 | 153 153 153 155 | 154 154 154 156 | 155 155 155 157 | 156 156 156 158 | 157 157 157 159 | 158 158 158 160 | 159 159 159 161 | 160 160 160 162 | 161 161 161 163 | 162 162 162 164 | 163 163 163 165 | 164 164 164 166 | 165 165 165 167 | 166 166 166 168 | 167 167 167 169 | 168 168 168 170 | 169 169 169 171 | 170 170 170 172 | 171 171 171 173 | 172 172 172 174 | 173 173 173 175 | 174 174 174 176 | 175 175 175 177 | 176 176 176 178 | 177 177 177 179 | 178 178 178 180 | 179 179 179 181 | 180 180 180 182 | 181 181 181 183 | 182 182 182 184 | 183 183 183 185 | 184 184 184 186 | 185 185 185 187 | 186 186 186 188 | 187 187 187 189 | 188 188 188 190 | 189 189 189 191 | 190 190 190 192 | 191 191 191 193 | 192 192 192 194 | 193 193 193 195 | 194 194 194 196 | 195 195 195 197 | 196 196 196 198 | 197 197 197 199 | 198 198 198 200 | 199 199 199 201 | 200 200 200 202 | 201 201 201 203 | 202 202 202 204 | 203 203 203 205 | 204 204 204 206 | 205 205 205 207 | 206 206 206 208 | 207 
207 207 209 | 208 208 208 210 | 209 209 209 211 | 210 210 210 212 | 211 211 211 213 | 212 212 212 214 | 213 213 213 215 | 214 214 214 216 | 215 215 215 217 | 216 216 216 218 | 217 217 217 219 | 218 218 218 220 | 219 219 219 221 | 220 220 220 222 | 221 221 221 223 | 222 222 222 224 | 223 223 223 225 | 224 224 224 226 | 225 225 225 227 | 226 226 226 228 | 227 227 227 229 | 228 228 228 230 | 229 229 229 231 | 230 230 230 232 | 231 231 231 233 | 232 232 232 234 | 233 233 233 235 | 234 234 234 236 | 235 235 235 237 | 236 236 236 238 | 237 237 237 239 | 238 238 238 240 | 239 239 239 241 | 240 240 240 242 | 241 241 241 243 | 242 242 242 244 | 243 243 243 245 | 244 244 244 246 | 245 245 245 247 | 246 246 246 248 | 247 247 247 249 | 248 248 248 250 | 249 249 249 251 | 250 250 250 252 | 251 251 251 253 | 252 252 252 254 | 253 253 253 255 | 254 254 254 256 | 255 255 255 257 | -------------------------------------------------------------------------------- /libs/loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import random 4 | import scipy.io 5 | from PIL import Image 6 | import cv2 7 | import torch 8 | from os.path import exists, join, split 9 | import libs.transforms_multi as transforms 10 | from torchvision import datasets 11 | 12 | def video_loader(video_path, frame_end, step, frame_start=0): 13 | cap = cv2.VideoCapture(video_path) 14 | cap.set(1, frame_start - 1) 15 | video = [] 16 | for i in range(frame_start - 1, frame_end, step): 17 | cap.set(1, i) 18 | success, image = cap.read() 19 | if not success: 20 | raise Exception('Error while reading video {}'.format(video_path)) 21 | pil_im = image 22 | video.append(pil_im) 23 | return video 24 | 25 | 26 | def framepair_loader(video_path, frame_start, frame_end): 27 | 28 | cap = cv2.VideoCapture(video_path) 29 | 30 | pair = [] 31 | id_ = np.zeros(2) 32 | frame_num = frame_end - frame_start 33 | if frame_end > 50: 34 | id_[0] = random.randint(frame_start, frame_end-50) 35 | id_[1] = id_[0] + random.randint(1, 50) 36 | else: 37 | id_[0] = random.randint(frame_start, frame_end) 38 | id_[1] = random.randint(frame_start, frame_end) 39 | 40 | 41 | for ii in range(2): 42 | 43 | cap.set(1, id_[ii]) 44 | 45 | success, image = cap.read() 46 | 47 | if not success: 48 | print("id, frame_end:", id_, frame_end) 49 | raise Exception('Error while reading video {}'.format(video_path)) 50 | 51 | h,w,_ = image.shape 52 | h = (h // 64) * 64 53 | w = (w // 64) * 64 54 | image = cv2.resize(image, (w,h)) 55 | image = image.astype(np.uint8) 56 | pil_im = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) 57 | pair.append(pil_im) 58 | 59 | return pair 60 | 61 | def video_frame_counter(video_path): 62 | cap = cv2.VideoCapture(video_path) 63 | return cap.get(7) 64 | 65 | 66 | class VidListv1(torch.utils.data.Dataset): 67 | # for warm up, random crop both 68 | def __init__(self, video_path, list_path, patch_size, rotate = 10, scale=1.2, is_train=True, moreaug= True): 69 | super(VidListv1, self).__init__() 70 | self.data_dir = video_path 71 | self.list_path = list_path 72 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128)) 73 | 74 | t = [] 75 | if rotate > 0: 76 | t.append(transforms.RandomRotate(rotate)) 77 | if scale > 0: 78 | t.append(transforms.RandomScale(scale)) 79 | t.extend([transforms.RandomCrop(patch_size, seperate =moreaug), transforms.RandomHorizontalFlip(), transforms.ToTensor(), 80 | normalize]) 81 | 82 | self.transforms = transforms.Compose(t) 83 | 84 | 
self.is_train = is_train 85 | self.read_list() 86 | 87 | def __getitem__(self, idx): 88 | while True: 89 | video_ = self.list[idx] 90 | frame_end = video_frame_counter(video_)-1 91 | if frame_end <=0: 92 | print("Empty video {}, skip to the next".format(self.list[idx])) 93 | idx += 1 94 | else: 95 | break 96 | 97 | pair_ = framepair_loader(video_, 0, frame_end) 98 | data = list(self.transforms(*pair_)) 99 | return tuple(data) 100 | 101 | def __len__(self): 102 | return len(self.list) 103 | 104 | def read_list(self): 105 | path = join(self.list_path) 106 | root = path.partition("Kinetices/")[0] 107 | if not exists(path): 108 | raise Exception("{} does not exist in kinet_dataset.py.".format(path)) 109 | self.list = [line.replace("/Data/", root).strip() for line in open(path, 'r')] 110 | 111 | 112 | class VidListv2(torch.utils.data.Dataset): 113 | # for localization, random crop frame1 114 | def __init__(self, video_path, list_path, patch_size, window_len, rotate = 10, scale = 1.2, full_size = 640, is_train=True): 115 | super(VidListv2, self).__init__() 116 | self.data_dir = video_path 117 | self.list_path = list_path 118 | self.window_len = window_len 119 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128)) 120 | self.transforms1 = transforms.Compose([ 121 | transforms.RandomRotate(rotate), 122 | # transforms.RandomScale(scale), 123 | transforms.ResizeandPad(full_size), 124 | transforms.RandomCrop(patch_size), 125 | transforms.ToTensor(), 126 | normalize]) 127 | self.transforms2 = transforms.Compose([ 128 | transforms.ResizeandPad(full_size), 129 | transforms.ToTensor(), 130 | normalize]) 131 | self.is_train = is_train 132 | self.read_list() 133 | 134 | def __getitem__(self, idx): 135 | while True: 136 | video_ = self.list[idx] 137 | frame_end = video_frame_counter(video_)-1 138 | if frame_end <=0: 139 | print("Empty video {}, skip to the next".format(self.list[idx])) 140 | idx += 1 141 | else: 142 | break 143 | 144 | pair_ = framepair_loader(video_, 0, frame_end) 145 | data1 = list(self.transforms1(*pair_)) 146 | data2 = list(self.transforms2(*pair_)) 147 | if self.window_len == 2: 148 | data = [data1[0],data2[1]] 149 | else: 150 | data = [data1[0],data2[1], data2[2]] 151 | return tuple(data) 152 | 153 | def __len__(self): 154 | return len(self.list) 155 | 156 | def read_list(self): 157 | path = join(self.list_path) 158 | root = path.partition("Kinetices/")[0] 159 | if not exists(path): 160 | raise Exception("{} does not exist in kinet_dataset.py.".format(path)) 161 | self.list = [line.replace("/Data/", root).strip() for line in open(path, 'r')] 162 | 163 | 164 | 165 | if __name__ == '__main__': 166 | normalize = transforms.Normalize(mean = (128, 128, 128), 167 | std = (128, 128, 128)) 168 | t = [] 169 | t.extend([transforms.RandomCrop(256), 170 | transforms.RandomHorizontalFlip(), 171 | transforms.ToTensor(), 172 | normalize]) 173 | dataset_train = VidList('/home/xtli/DATA/compress/train_256/', 174 | '/home/xtli/DATA/compress/train.txt', 175 | transforms.Compose(t), window_len=2) 176 | 177 | train_loader = torch.utils.data.DataLoader(dataset_train, 178 | batch_size = 16, 179 | shuffle = True, 180 | num_workers=8, 181 | drop_last=True) 182 | 183 | start_time = time.time() 184 | for i, (frames) in enumerate(train_loader): 185 | print(i) 186 | if(i >= 1000): 187 | break 188 | end_time = time.time() 189 | print((end_time - start_time) / 1000) 190 | -------------------------------------------------------------------------------- /libs/loss.py: 
-------------------------------------------------------------------------------- 1 | from numpy.random import randint 2 | import torch 3 | import torch.nn as nn 4 | 5 | def weight_L1(pred, gt, w_a, w_b, mask=None): 6 | l1 = nn.L1Loss() 7 | if mask is not None: 8 | pred = pred * mask.repeat(1,pred.size(1),1,1) 9 | gt = gt * mask.repeat(1,gt.size(1),1,1) 10 | l_l = l1(pred[:,0,:,:], gt[:,0,:,:]) 11 | l_a = l1(pred[:,1,:,:], gt[:,1,:,:]) 12 | l_b = l1(pred[:,2,:,:], gt[:,2,:,:]) 13 | loss = l_l + l_a * w_a + l_b * w_b 14 | return loss 15 | 16 | def weightsingle_L1(pred, gt, w_a, w_b): 17 | n,c,h,w = pred.size() 18 | l_l = torch.abs(pred[:,0,:,:] - gt[:,0,:,:]) 19 | l_a = torch.abs(pred[:,1,:,:] - gt[:,1,:,:]) 20 | l_b = torch.abs(pred[:,2,:,:] - gt[:,2,:,:]) 21 | 22 | l_l = torch.mean(l_l.view(n,-1),dim=1) 23 | l_a = torch.mean(l_a.view(n,-1),dim=1) 24 | l_b = torch.mean(l_b.view(n,-1),dim=1) 25 | loss = l_l + l_a * w_a + l_b * w_b 26 | return loss 27 | 28 | def switch_L1_thr(pred2, frame2_var, pred1, frame1_var, w_a, w_b, thr): 29 | loss1 = weightsingle_L1(pred2, frame2_var, w_a, w_b) 30 | loss2 = weightsingle_L1(pred1, frame1_var, w_a, w_b) 31 | loss = loss1 + 0.1 * loss2 32 | loss[loss1 > thr] = 0 33 | return torch.mean(loss) 34 | 35 | def L1_thr(pred2, frame2_var, w_a, w_b, thr): 36 | loss = weightsingle_L1(pred2, frame2_var, w_a, w_b) 37 | loss[loss > thr] = 0 38 | return torch.mean(loss) 39 | 40 | # merge the above codes 41 | def L1_loss(pred2, frame2_var, w_a, w_b, thr=None, pred1=None, frame1_var=None): 42 | if pred1 is None: 43 | loss = weightsingle_L1(pred2, frame2_var, w_a, w_b) 44 | else: 45 | loss1 = weightsingle_L1(pred2, frame2_var, w_a, w_b) 46 | loss2 = weightsingle_L1(pred1, frame1_var, w_a, w_b) 47 | loss = loss1 + 0.1 * loss2 48 | if thr is not None: 49 | loss[loss > thr] = 0 50 | return torch.mean(loss) -------------------------------------------------------------------------------- /libs/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | 5 | from libs.utils import * 6 | from libs.autoencoder import encoder3, decoder3, encoder_res18 7 | 8 | class NLM_woSoft(nn.Module): 9 | """ 10 | Non-local mean layer w/o softmax on affinity 11 | """ 12 | def __init__(self): 13 | super(NLM_woSoft, self).__init__() 14 | 15 | def forward(self, in1, in2): 16 | n,c,h,w = in1.size() 17 | in1 = in1.view(n,c,-1) 18 | in2 = in2.view(n,c,-1) 19 | 20 | affinity = torch.bmm(in1.permute(0,2,1), in2) 21 | return affinity 22 | 23 | def transform(aff, frame1): 24 | """ 25 | Given aff, copy from frame1 to construct frame2. 26 | INPUTS: 27 | - aff: (h*w)*(h*w) affinity matrix 28 | - frame1: n*c*h*w feature map 29 | """ 30 | b,c,h,w = frame1.size() 31 | frame1 = frame1.view(b,c,-1) 32 | frame2 = torch.bmm(frame1, aff) 33 | return frame2.view(b,c,h,w) 34 | 35 | class normalize(nn.Module): 36 | """Given mean: (R, G, B) and std: (R, G, B), 37 | will normalize each channel of the torch.*Tensor, i.e. 
38 | channel = (channel - mean) / std 39 | """ 40 | 41 | def __init__(self, mean, std = (1.0,1.0,1.0)): 42 | super(normalize, self).__init__() 43 | self.mean = nn.Parameter(torch.FloatTensor(mean).cuda(), requires_grad=False) 44 | self.std = nn.Parameter(torch.FloatTensor(std).cuda(), requires_grad=False) 45 | 46 | def forward(self, frames): 47 | b,c,h,w = frames.size() 48 | frames = (frames - self.mean.view(1,3,1,1).repeat(b,1,h,w))/self.std.view(1,3,1,1).repeat(b,1,h,w) 49 | return frames 50 | 51 | def create_flat_grid(F_size, GPU=True): 52 | """ 53 | INPUTS: 54 | - F_size: feature size 55 | OUTPUT: 56 | - return a standard grid coordinate 57 | """ 58 | b, c, h, w = F_size 59 | theta = torch.tensor([[1,0,0],[0,1,0]]) 60 | theta = theta.unsqueeze(0).repeat(b,1,1) 61 | theta = theta.float() 62 | 63 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 64 | # b * (h*w) * 2 65 | grid = torch.nn.functional.affine_grid(theta, F_size) 66 | grid[:,:,:,0] = (grid[:,:,:,0]+1)/2 * w 67 | grid[:,:,:,1] = (grid[:,:,:,1]+1)/2 * h 68 | grid_flat = grid.view(b,-1,2) 69 | if(GPU): 70 | grid_flat = grid_flat.cuda() 71 | return grid_flat 72 | 73 | 74 | class track_match_comb(nn.Module): 75 | def __init__(self, pretrained, encoder_dir = None, decoder_dir = None, temp=1, color_switch=True, coord_switch=True): 76 | super(track_match_comb, self).__init__() 77 | 78 | self.gray_encoder = encoder_res18(pretrained=pretrained, uselayer=4) 79 | self.rgb_encoder = encoder3(reduce=True) 80 | self.decoder = decoder3(reduce=True) 81 | 82 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir)) 83 | self.decoder.load_state_dict(torch.load(decoder_dir)) 84 | for param in self.decoder.parameters(): 85 | param.requires_grad = False 86 | for param in self.rgb_encoder.parameters(): 87 | param.requires_grad = False 88 | 89 | self.nlm = NLM_woSoft() 90 | self.normalize = normalize(mean=[0.485, 0.456, 0.406], 91 | std=[0.229, 0.224, 0.225]) 92 | self.softmax = nn.Softmax(dim=1) 93 | self.temp = temp 94 | self.grid_flat = None 95 | self.grid_flat_crop = None 96 | self.color_switch = color_switch 97 | self.coord_switch = coord_switch 98 | 99 | 100 | def forward(self, img_ref, img_tar, warm_up=True, patch_size=None): 101 | n, c, h_ref, w_ref = img_ref.size() 102 | n, c, h_tar, w_tar = img_tar.size() 103 | gray_ref = copy.deepcopy(img_ref[:,0].view(n,1,h_ref,w_ref).repeat(1,3,1,1)) 104 | gray_tar = copy.deepcopy(img_tar[:,0].view(n,1,h_tar,w_tar).repeat(1,3,1,1)) 105 | 106 | gray_ref = (gray_ref + 1) / 2 107 | gray_tar = (gray_tar + 1) / 2 108 | 109 | gray_ref = self.normalize(gray_ref) 110 | gray_tar = self.normalize(gray_tar) 111 | 112 | Fgray1 = self.gray_encoder(gray_ref) 113 | Fgray2 = self.gray_encoder(gray_tar) 114 | Fcolor1 = self.rgb_encoder(img_ref) 115 | 116 | output = [] 117 | 118 | if warm_up: 119 | aff = self.nlm(Fgray1, Fgray2) 120 | aff_norm = self.softmax(aff) 121 | Fcolor2_est = transform(aff_norm, Fcolor1) 122 | color2_est = self.decoder(Fcolor2_est) 123 | 124 | output.append(color2_est) 125 | output.append(aff) 126 | 127 | if self.color_switch: 128 | Fcolor2 = self.rgb_encoder(img_tar) 129 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2) 130 | color1_est = self.decoder(Fcolor1_est) 131 | output.append(color1_est) 132 | else: 133 | if(self.grid_flat is None): 134 | self.grid_flat = create_flat_grid(Fgray2.size()) 135 | aff_ref_tar = self.nlm(Fgray1, Fgray2) 136 | aff_ref_tar = torch.nn.functional.softmax(aff_ref_tar * self.temp, dim = 2) 137 | coords = torch.bmm(aff_ref_tar, 
self.grid_flat) 138 | new_c = coords2bbox(coords, patch_size, h_tar, w_tar) 139 | Fgray2_crop = diff_crop(Fgray2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0]) 140 | 141 | aff_p = self.nlm(Fgray1, Fgray2_crop) 142 | aff_norm = self.softmax(aff_p * self.temp) 143 | Fcolor2_est = transform(aff_norm, Fcolor1) 144 | color2_est = self.decoder(Fcolor2_est) 145 | 146 | Fcolor2_full = self.rgb_encoder(img_tar) 147 | Fcolor2_crop = diff_crop(Fcolor2_full, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0]) 148 | 149 | output.append(color2_est) 150 | output.append(Fcolor2_crop) 151 | output.append(aff_p) 152 | output.append(new_c*8) 153 | output.append(coords) 154 | 155 | # color orthorganal 156 | if self.color_switch: 157 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2_crop) 158 | color1_est = self.decoder(Fcolor1_est) 159 | output.append(color1_est) 160 | 161 | # coord orthorganal 162 | if self.coord_switch: 163 | aff_norm_tran = self.softmax(aff_p.permute(0,2,1)*self.temp) 164 | if self.grid_flat_crop is None: 165 | self.grid_flat_crop = create_flat_grid(Fp_tar.size()).permute(0,2,1).detach() 166 | C12 = torch.bmm(self.grid_flat_crop, aff_norm) 167 | C11 = torch.bmm(C12, aff_norm_tran) 168 | output.append(self.grid_flat_crop) 169 | output.append(C11) 170 | 171 | # return pred1, pred2, aff_p, new_c * 8, self.grid_flat_crop, C11, coords 172 | return output 173 | 174 | 175 | class Model_switchGTfixdot_swCC_Res(nn.Module): 176 | def __init__(self, encoder_dir = None, decoder_dir = None, 177 | temp = None, pretrainRes = False, uselayer=4): 178 | ''' 179 | For switchable concenration loss 180 | Using Resnet18 181 | ''' 182 | super(Model_switchGTfixdot_swCC_Res, self).__init__() 183 | self.gray_encoder = encoder_res18(pretrained = pretrainRes, uselayer=uselayer) 184 | self.rgb_encoder = encoder3(reduce = True) 185 | self.nlm = NLM_woSoft() 186 | self.decoder = decoder3(reduce = True) 187 | self.temp = temp 188 | self.softmax = nn.Softmax(dim=1) 189 | self.cos_window = torch.Tensor(np.outer(np.hanning(40), np.hanning(40))).cuda() 190 | self.normalize = normalize(mean=[0.485, 0.456, 0.406], 191 | std=[0.229, 0.224, 0.225]) 192 | 193 | if(not encoder_dir is None): 194 | print("Using pretrained encoders: %s."%encoder_dir) 195 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir)) 196 | if(not decoder_dir is None): 197 | print("Using pretrained decoders: %s."%decoder_dir) 198 | self.decoder.load_state_dict(torch.load(decoder_dir)) 199 | 200 | for param in self.decoder.parameters(): 201 | param.requires_grad = False 202 | for param in self.rgb_encoder.parameters(): 203 | param.requires_grad = False 204 | 205 | def forward(self, gray1, gray2, color1=None, color2=None): 206 | gray1 = (gray1 + 1) / 2 207 | gray2 = (gray2 + 1) / 2 208 | 209 | gray1 = self.normalize(gray1) 210 | gray2 = self.normalize(gray2) 211 | 212 | 213 | Fgray1 = self.gray_encoder(gray1) 214 | Fgray2 = self.gray_encoder(gray2) 215 | 216 | aff = self.nlm(Fgray1, Fgray2) 217 | aff_norm = self.softmax(aff*self.temp) 218 | 219 | if(color1 is None): 220 | # for testing 221 | return aff_norm, Fgray1, Fgray2 222 | 223 | Fcolor1 = self.rgb_encoder(color1) 224 | Fcolor2 = self.rgb_encoder(color2) 225 | Fcolor2_est = transform(aff_norm, Fcolor1) 226 | pred2 = self.decoder(Fcolor2_est) 227 | 228 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2) 229 | pred1 = self.decoder(Fcolor1_est) 230 | 231 | return pred1, pred2, aff_norm, aff, Fgray1, Fgray2 232 | 
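Both classes above share the same three-step matching pattern: compute a dot-product affinity between the per-pixel features of two frames (`NLM_woSoft`), normalize it with a softmax over the reference dimension, and use `transform` to copy reference features (or labels) onto the target frame. The standalone sketch below reproduces only that pattern with random tensors; the helper name `dot_product_affinity`, the toy sizes, and the propagated mask are illustrative assumptions, while the `dim=1` softmax convention and the `transform` logic follow `libs/model.py`.

```
import torch
import torch.nn as nn

def dot_product_affinity(feat_ref, feat_tar):
    # Dot-product affinity between two feature maps (what NLM_woSoft computes).
    # feat_*: n x c x h x w  ->  affinity: n x (h*w) x (h*w)
    n, c, h, w = feat_ref.size()
    f_ref = feat_ref.view(n, c, -1)
    f_tar = feat_tar.view(n, c, -1)
    return torch.bmm(f_ref.permute(0, 2, 1), f_tar)

def transform(aff, frame_ref):
    # Copy a reference map (n x c x h x w) through an (h*w) x (h*w) affinity.
    b, c, h, w = frame_ref.size()
    frame_tar = torch.bmm(frame_ref.view(b, c, -1), aff)
    return frame_tar.view(b, c, h, w)

if __name__ == "__main__":
    torch.manual_seed(0)
    n, c, h, w = 1, 64, 8, 8            # toy sizes, not the real feature resolution
    feat_ref = torch.randn(n, c, h, w)  # features of the reference frame
    feat_tar = torch.randn(n, c, h, w)  # features of the target frame
    mask_ref = torch.rand(n, 3, h, w)   # e.g. a soft segmentation mask to propagate

    temperature = 1.0                   # plays the role of self.temp
    aff = dot_product_affinity(feat_ref, feat_tar)
    # Softmax over dim=1: for each target pixel, the weights over all reference
    # pixels sum to one, matching nn.Softmax(dim=1) in the models above.
    aff_norm = nn.Softmax(dim=1)(aff * temperature)
    mask_tar = transform(aff_norm, mask_ref)
    print(mask_tar.shape)               # torch.Size([1, 3, 8, 8])
```

At test time the repository applies this propagation to segmentation masks, with an additional top-k filtering of the affinity (see `transform_topk` in `libs/test_utils.py`).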
-------------------------------------------------------------------------------- /libs/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | identity = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | identity = self.downsample(x) 54 | 55 | out += identity 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | 93 | out += identity 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, uselayer=3): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, 
layers[2], stride=1) 112 | if uselayer==4: 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 114 | self.uselayer = uselayer 115 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 116 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | # Zero-initialize the last BN in each residual branch, 126 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 127 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 128 | if zero_init_residual: 129 | for m in self.modules(): 130 | if isinstance(m, Bottleneck): 131 | nn.init.constant_(m.bn3.weight, 0) 132 | elif isinstance(m, BasicBlock): 133 | nn.init.constant_(m.bn2.weight, 0) 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1): 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | conv1x1(self.inplanes, planes * block.expansion, stride), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for _ in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | if self.uselayer == 4: 161 | x = self.layer4(x) 162 | 163 | # x = self.avgpool(x) 164 | # x = x.view(x.size(0), -1) 165 | # x = self.fc(x) 166 | 167 | return x 168 | 169 | 170 | def resnet18(pretrained=False, **kwargs): 171 | """Constructs a ResNet-18 model. 172 | 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), model_dir = "ae_models") 179 | return model 180 | 181 | 182 | 183 | def resnet50(pretrained=False, **kwargs): 184 | """Constructs a ResNet-50 model. 
185 | 186 | Args: 187 | pretrained (bool): If True, returns a model pre-trained on ImageNet 188 | """ 189 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 190 | if pretrained: 191 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 192 | return model 193 | -------------------------------------------------------------------------------- /libs/test_utils.py: -------------------------------------------------------------------------------- 1 | # OS libraries 2 | import os 3 | import cv2 4 | import glob 5 | import scipy.misc 6 | import numpy as np 7 | from PIL import Image 8 | 9 | # Pytorch 10 | import torch 11 | 12 | # Customized libraries 13 | import libs.transforms_pair as transforms 14 | 15 | color_palette = np.loadtxt('libs/data/palette.txt',dtype=np.uint8).reshape(-1,3) 16 | 17 | def transform_topk(aff, frame1, k, h2=None, w2=None): 18 | """ 19 | INPUTS: 20 | - aff: affinity matrix, b * N * N 21 | - frame1: reference frame 22 | - k: only aggregate top-k pixels with highest aff(j,i) 23 | - h2, w2, frame2's height & width 24 | OUTPUT: 25 | - frame2: propagated mask from frame1 to the next frame 26 | """ 27 | b,c,h,w = frame1.size() 28 | b, N1, N2 = aff.size() 29 | # b * 20 * N 30 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k) 31 | # b * N 32 | tk_val_min,_ = torch.min(tk_val,dim=1) 33 | tk_val_min = tk_val_min.view(b,1,N2) 34 | aff[tk_val_min > aff] = 0 35 | frame1 = frame1.contiguous().view(b,c,-1) 36 | frame2 = torch.bmm(frame1, aff) 37 | if(h2 is None): 38 | return frame2.view(b,c,h,w) 39 | else: 40 | return frame2.view(b,c,h2,w2) 41 | 42 | def read_frame_list(video_dir): 43 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))] 44 | frame_list = sorted(frame_list) 45 | return frame_list 46 | 47 | def create_transforms(): 48 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128)) 49 | t = [] 50 | t.extend([transforms.ToTensor(), 51 | normalize]) 52 | return transforms.Compose(t) 53 | 54 | def read_frame(frame_dir, transforms, scale_size): 55 | """ 56 | read a single frame & preprocess 57 | """ 58 | frame = cv2.imread(frame_dir) 59 | ori_h,ori_w,_ = frame.shape 60 | # scale, makes height & width multiples of 64 61 | if(len(scale_size) == 1): 62 | if(ori_h > ori_w): 63 | tw = scale_size[0] 64 | th = (tw * ori_h) / ori_w 65 | th = int((th // 64) * 64) 66 | else: 67 | th = scale_size[0] 68 | tw = (th * ori_w) / ori_h 69 | tw = int((tw // 64) * 64) 70 | else: 71 | tw = scale_size[1] 72 | th = scale_size[0] 73 | frame = cv2.resize(frame, (tw,th)) 74 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) 75 | 76 | pair = [frame, frame] 77 | transformed = list(transforms(*pair)) 78 | return transformed[0].cuda().unsqueeze(0), ori_h, ori_w 79 | 80 | def to_one_hot(y_tensor, n_dims=None): 81 | """ 82 | Take integer y (tensor or variable) with n dims & 83 | convert it to 1-hot representation with n+1 dims. 
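For reference, a minimal sketch (not part of the original `libs/test_utils.py`) of how `transform_topk` above is typically driven: an affinity matrix is computed from the flattened encoder features of the reference and target frames, and the reference one-hot mask is propagated through the top-k entries of each affinity column. The feature size, temperature, and `k` below are placeholders rather than the values used by `test.py`; run it from the repository root so that `libs/data/palette.txt` resolves at import time.

```
import torch
import torch.nn.functional as F
from libs.test_utils import transform_topk

b, c, h, w = 1, 256, 32, 56                           # assumed stride-8 feature size
F_ref = F.normalize(torch.randn(b, c, h * w), dim=1)  # flattened reference features
F_tar = F.normalize(torch.randn(b, c, h * w), dim=1)  # flattened target features

# aff[b, i, j]: similarity between reference pixel i and target pixel j
aff = F.softmax(torch.bmm(F_ref.permute(0, 2, 1), F_tar) * 25.0, dim=1)

seg_ref = torch.zeros(b, 3, h, w)                     # one-hot mask with 3 channels
seg_ref[:, 0] = 1
seg_tar = transform_topk(aff, seg_ref, k=5, h2=h, w2=w)
print(seg_tar.shape)                                  # torch.Size([1, 3, 32, 56])
```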
84 | """ 85 | if(n_dims is None): 86 | n_dims = int(y_tensor.max()+ 1) 87 | _,h,w = y_tensor.size() 88 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1) 89 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1 90 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1) 91 | y_one_hot = y_one_hot.view(h,w,n_dims) 92 | return y_one_hot.permute(2,0,1).unsqueeze(0) 93 | 94 | def read_seg(seg_dir, scale_size): 95 | seg = Image.open(seg_dir) 96 | h,w = seg.size 97 | if(len(scale_size) == 1): 98 | if(h > w): 99 | tw = scale_size[0] 100 | th = (tw * h) / w 101 | th = int((th // 64) * 64) 102 | else: 103 | th = scale_size[0] 104 | tw = (th * w) / h 105 | tw = int((tw // 64) * 64) 106 | else: 107 | tw = scale_size[1] 108 | th = scale_size[0] 109 | seg = np.asarray(seg).reshape((w,h,1)) 110 | seg_ori = np.squeeze(seg) 111 | small_seg = scipy.misc.imresize(seg_ori, (tw//8,th//8),"nearest",mode="F") 112 | large_seg = scipy.misc.imresize(seg_ori, (tw,th),"nearest",mode="F") 113 | 114 | t = [] 115 | t.extend([transforms.ToTensor()]) 116 | trans = transforms.Compose(t) 117 | pair = [large_seg, small_seg] 118 | transformed = list(trans(*pair)) 119 | large_seg = transformed[0] 120 | small_seg = transformed[1] 121 | return to_one_hot(large_seg), to_one_hot(small_seg), seg_ori 122 | 123 | def imwrite_indexed(filename,array): 124 | """ Save indexed png for DAVIS.""" 125 | if np.atleast_3d(array).shape[2] != 1: 126 | raise Exception("Saving indexed PNGs requires 2D array.") 127 | 128 | im = Image.fromarray(array) 129 | im.putpalette(color_palette.ravel()) 130 | im.save(filename, format='PNG') 131 | -------------------------------------------------------------------------------- /libs/track_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import copy 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | import torchvision.utils as vutils 9 | 10 | from libs.vis_utils import norm_mask 11 | 12 | ############################# GLOBAL VARIABLES ######################## 13 | color_platte = [[0, 0, 0],[128, 0, 0],[0, 128, 0],[128, 128, 0], 14 | [0, 0, 128],[128, 0, 128],[0, 128, 128],[128, 128, 128], 15 | [64, 0, 0],[192, 0, 0],[64, 128, 0],[192, 128, 0],[64, 0, 128], 16 | [192, 0, 128],[64, 128, 128],[192, 128, 128],[0, 64, 0], 17 | [128, 64, 0],[0, 192, 0],[128, 192, 0],[0, 64, 128]] 18 | color_platte = np.array(color_platte) 19 | 20 | ############################# HELPER FUNCTIONS ######################## 21 | class BBox(): 22 | """ 23 | bounding box class 24 | """ 25 | def __init__(self, left, right, top, bottom, margin, h, w): 26 | if(margin > 0): 27 | bb_w = float(right - left) 28 | bb_h = float(bottom - top) 29 | margin_h = (bb_h * margin) / 2 30 | margin_w = (bb_w * margin) / 2 31 | left = left - margin_w 32 | right = right + margin_w 33 | top = top - margin_h 34 | bottom = bottom + margin_h 35 | self.left = max(math.floor(left), 0) 36 | self.right = min(math.ceil(right), w) 37 | self.top = max(math.floor(top), 0) 38 | self.bottom = min(math.ceil(bottom), h) 39 | 40 | def print(self): 41 | print("Left: {:n}, Right:{:n}, Top:{:n}, Bottom:{:n}".format(self.left, self.right, self.top, self.bottom)) 42 | 43 | def upscale(self, scale): 44 | self.left = math.floor(self.left * scale) 45 | self.right = math.floor(self.right * scale) 46 | self.top = math.floor(self.top * scale) 47 | self.bottom = math.floor(self.bottom * scale) 48 | 49 | def add(self, bbox): 50 | 
self.left += bbox.left 51 | self.right += bbox.right 52 | self.top += bbox.top 53 | self.bottom += bbox.bottom 54 | 55 | def div(self, num): 56 | self.left /= num 57 | self.right /= num 58 | self.top /= num 59 | self.bottom /= num 60 | 61 | def to_one_hot(y_tensor, n_dims=9): 62 | _,h,w = y_tensor.size() 63 | """ 64 | Take integer y (tensor or variable) with n dims and convert it to 1-hot representation with n+1 dims. 65 | """ 66 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1) 67 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1 68 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1) 69 | y_one_hot = y_one_hot.view(h,w,n_dims) 70 | return y_one_hot.permute(2,0,1).unsqueeze(0) 71 | 72 | def create_grid(F_size, GPU=True): 73 | """ 74 | INPUTS: 75 | - F_size: feature size 76 | OUTPUT: 77 | - return a standard grid coordinate 78 | """ 79 | b, c, h, w = F_size 80 | theta = torch.tensor([[1,0,0],[0,1,0]]) 81 | theta = theta.unsqueeze(0).repeat(b,1,1) 82 | theta = theta.float() 83 | 84 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 85 | # b * (h*w) * 2 86 | grid = torch.nn.functional.affine_grid(theta, F_size) 87 | if(GPU): 88 | grid = grid.cuda() 89 | return grid 90 | 91 | def seg2bbox_v2(seg, bbox_pre): 92 | """ 93 | INPUTS: 94 | - seg: segmentation mask, a c*h*w one-hot tensor 95 | OUTPUS: 96 | - bbox: a c*4 tensor, indicating bbox for each object 97 | """ 98 | seg = seg.squeeze() 99 | c,h,w = seg.size() 100 | bbox = {} 101 | bbox[0] = BBox(0, w, 0, h, 0, h, w) 102 | bbox[0].upscale(0.125) 103 | _, seg_int = torch.max(seg, dim=0) 104 | for cnt in range(1,c): # rule out background label 105 | seg_cnt = (seg_int == cnt) * 1 106 | # x * 2 107 | fg_idx = seg_cnt.nonzero().float() 108 | 109 | if(fg_idx.numel() > 0 and (bbox_pre[cnt] is not None)): 110 | fg_idx = torch.flip(fg_idx, (0,1)) 111 | 112 | bbox_tmp = copy.deepcopy(bbox_pre[cnt]) 113 | bbox_tmp.upscale(8) 114 | bbox[cnt] = coords2bbox_scale(fg_idx, h, w, bbox_tmp, margin=0.6, bandwidth=20) 115 | 116 | bbox[cnt].upscale(0.125) 117 | else: 118 | bbox[cnt] = None 119 | return bbox 120 | 121 | def seg2bbox(seg, margin=0,print_info=False): 122 | """ 123 | INPUTS: 124 | - seg: segmentation mask, a c*h*w one-hot tensor 125 | OUTPUS: 126 | - bbox: a c*4 tensor, indicating bbox for each object 127 | """ 128 | seg = seg.squeeze() 129 | c,h,w = seg.size() 130 | bbox = {} 131 | bbox[0] = BBox(0, w, 0, h, 0, h, w) 132 | for cnt in range(1,c): # rule out background label 133 | seg_cnt = seg[cnt] 134 | # x * 2 135 | fg_idx = seg_cnt.nonzero() 136 | if(fg_idx.numel() > 0): 137 | left = fg_idx[:,1].min() 138 | right = fg_idx[:,1].max() 139 | top = fg_idx[:,0].min() 140 | bottom = fg_idx[:,0].max() 141 | bbox[cnt] = BBox(left, right, top, bottom, margin, h, w) 142 | else: 143 | bbox[cnt] = None 144 | return bbox 145 | 146 | def gaussin(x, sigma): 147 | return torch.exp(- x ** 2 / (2 * (sigma ** 2))) 148 | 149 | def calc_center(arr, mode='mean', sigma=10): 150 | """ 151 | INPUTS: 152 | - arr: an array with coordinates, shape: n 153 | - mode: 'mean' to calculate Euclean center, 'mass' to calculate mass center 154 | - sigma: Gaussian parameter if calculating mass center 155 | """ 156 | eu_center = torch.mean(arr) 157 | if(mode == 'mean'): 158 | return eu_center 159 | # calculate weight center 160 | eu_center = eu_center.view(1,1).repeat(1,arr.size(0)).squeeze() 161 | diff = eu_center - arr 162 | weight = gaussin(diff, sigma) 163 | mass_center = torch.sum(weight * arr / 
(torch.sum(weight))) 164 | return mass_center 165 | 166 | def masked_softmax(vec, mask, dim=1, epsilon=1e-5): 167 | exps = torch.exp(vec) 168 | masked_exps = exps * mask.float() 169 | masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon 170 | return (masked_exps/masked_sums) 171 | 172 | def transform_topk(aff, frame1, k): 173 | """ 174 | INPUTS: 175 | - aff: affinity matrix, b * N * N 176 | - frame1: reference frame 177 | - k: only aggregate top-k pixels with highest aff(j,i) 178 | """ 179 | b,c,h,w = frame1.size() 180 | b, _, N = aff.size() 181 | # b * 20 * N 182 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k) 183 | # b * N 184 | tk_val_min,_ = torch.min(tk_val, dim=1) 185 | tk_val_min = tk_val_min.view(b,1,N) 186 | aff[tk_val_min > aff] = 0 187 | aff = masked_softmax(aff, aff > 0, dim=1) 188 | frame1 = frame1.view(b,c,-1) 189 | frame2 = torch.bmm(frame1, aff) 190 | return frame2 191 | 192 | def squeeze_all(*args): 193 | res = [] 194 | for arg in args: 195 | res.append(arg.squeeze()) 196 | return tuple(res) 197 | 198 | def decode_seg(seg): 199 | seg = seg.squeeze() 200 | h,w = seg.size() 201 | color_seg = np.zeros((h,w,3)) 202 | unis = np.unique(seg) 203 | seg = seg.numpy() 204 | for cnt,uni in enumerate(unis): 205 | xys = (seg == uni) 206 | xs,ys = np.nonzero(xys) 207 | color_seg[xs,ys,:] = color_platte[cnt].reshape(1,3) 208 | return color_seg 209 | 210 | def draw_bbox(seg,bbox,color): 211 | """ 212 | INPUTS: 213 | - segmentation, h * w * 3 numpy array 214 | - coord: left, right, top, bottom 215 | OUTPUT: 216 | - seg with a drawn bbox 217 | """ 218 | seg = seg.copy() 219 | pt1 = (bbox.left,bbox.top) 220 | pt2 = (bbox.right,bbox.bottom) 221 | color = np.array(color, dtype=np.uint8) 222 | c = tuple(map(int, color)) 223 | seg = cv2.rectangle(seg, pt1, pt2, c, 2) 224 | return seg 225 | 226 | def vis_bbox(img, bbox, upscale=1): 227 | """ 228 | INPUTS: 229 | - img: a h*w*c opencv image 230 | - bbox: a list of bounding box 231 | OUTPUT: 232 | - img: image with bbox drawn on 233 | """ 234 | #for cnt in range(len(bbox)): 235 | for cnt, bbox_cnt in bbox.items(): 236 | #bbox_cnt = bbox[cnt] 237 | if(bbox_cnt is None): 238 | continue 239 | bbox_cnt.upscale(upscale) 240 | color = color_platte[cnt+1] 241 | img = draw_bbox(img, bbox_cnt, color) 242 | return img 243 | 244 | def vis_bbox_pair(bbox, bbox_next, frame1_var, frame2_var, out_path): 245 | frame1 = frame1_var * 128 + 128 246 | frame1 = frame1.squeeze().permute(1,2,0) 247 | im1 = cv2.cvtColor(np.array(frame1, dtype = np.uint8), cv2.COLOR_LAB2BGR) 248 | 249 | frame2 = frame2_var * 128 + 128 250 | frame2 = frame2.squeeze().permute(1,2,0) 251 | im2 = cv2.cvtColor(np.array(frame2, dtype = np.uint8), cv2.COLOR_LAB2BGR) 252 | 253 | #for i, (bbox_cnt, bbox_next_cnt) in enumerate(zip(bbox, bbox_next)): 254 | for cnt, bbox_cnt in bbox.items(): 255 | bbox_next_cnt = bbox_next[cnt] 256 | bbox_cnt.upscale(8) 257 | bbox_next_cnt.upscale(8) 258 | im1 = draw_bbox(im1, bbox_cnt, color_platte[i+1]) 259 | im2 = draw_bbox(im2, bbox_next_cnt, color_platte[i+1]) 260 | 261 | im = np.concatenate((im1,im2), axis=1) 262 | cv2.imwrite(out_path, im) 263 | 264 | def clean_seg(seg, bbox, threshold): 265 | c,h,w = seg.size() 266 | fgs = {} 267 | for cnt, bbox_cnt in bbox.items(): 268 | if(bbox_cnt is not None): 269 | seg_cnt = seg[cnt] 270 | fg = seg_cnt.nonzero() 271 | fgs[cnt] = fg[:,[1,0]].float() 272 | else: 273 | fgs[cnt] = None 274 | fgs = clean_coords(fgs, bbox, threshold) 275 | seg_new = torch.zeros(seg.size()) 276 | for cnt, fg in fgs.items(): 277 | if(fg is 
not None): 278 | fg = fg.long() 279 | seg_new[cnt][fg[:,1],fg[:,0]] = seg[cnt][fg[:,1],fg[:,0]] 280 | return seg_new.view(1,c,h,w) 281 | 282 | def clean_coords(coord, bbox_pre, threshold): 283 | """ 284 | INPUT: 285 | - coord: coordinates of foreground points, a N * 2 tensor. 286 | - center: cluster center, a 2 tensor 287 | - threshold: we cut all points larger than this threshold 288 | METHOD: Rule out outliers in coord 289 | """ 290 | new_coord = {} 291 | for cnt, coord_cnt in coord.items(): 292 | bbox_pre_cnt = bbox_pre[cnt] 293 | if((bbox_pre_cnt is not None) and (coord_cnt is not None)): 294 | center_h = (bbox_pre_cnt.top + bbox_pre_cnt.bottom)/2 295 | center_w = (bbox_pre_cnt.left + bbox_pre_cnt.right)/2 296 | dis = (coord_cnt[:,1] - center_h )**2 + (coord_cnt[:,0] - center_w)**2 297 | mean_ = torch.mean(dis) 298 | dis = dis / mean_ 299 | 300 | idx_ = (dis <= threshold).nonzero() 301 | new_coord[cnt] = coord_cnt[idx_].squeeze() 302 | else: 303 | new_coord[cnt] = None 304 | return new_coord 305 | 306 | def coord2bbox(bbox_pre, coord, h, w, adaptive=False): 307 | avg_h = calc_center(coord[:,1], mode='mass') 308 | avg_w = calc_center(coord[:,0], mode='mass') 309 | 310 | if adaptive: 311 | center = torch.mean(coord, dim = 0) 312 | dis_h = coord[:,1] - center[1] 313 | dis_w = coord[:,0] - center[0] 314 | 315 | dis_h = torch.mean(dis_h * dis_h, dim = 0) 316 | bb_height = (dis_h ** 0.5) * 8 317 | 318 | dis_w = torch.mean(dis_w * dis_w, dim = 0) 319 | bb_width = (dis_w ** 0.5) * 8 320 | 321 | # the adaptive method is sentitive to outliers, let's assume there's no dramatical change within 322 | # short range, so the height should not grow larger than 1.2 times height in previous frame 323 | bb_height = torch.min(bb_height, torch.Tensor([bbox_pre.bottom - bbox_pre.top]) * 1.2) 324 | bb_width = torch.min(bb_width, torch.Tensor([bbox_pre.right - bbox_pre.left]) * 1.2) 325 | 326 | bb_height = torch.max(bb_height, torch.Tensor([bbox_pre.bottom - bbox_pre.top]) * 0.8) 327 | bb_width = torch.max(bb_width, torch.Tensor([bbox_pre.right - bbox_pre.left]) * 0.8) 328 | 329 | else: 330 | bb_width = float(bbox_pre.right - bbox_pre.left) 331 | bb_height = float(bbox_pre.bottom - bbox_pre.top) 332 | left = avg_w - bb_width/2.0 333 | right = avg_w + bb_width/2.0 334 | top = avg_h - bb_height/2.0 335 | bottom = avg_h + bb_height/2.0 336 | 337 | coord_left = coord[:,0].min() 338 | coord_right = coord[:,0].max() 339 | coord_top = coord[:,1].min() 340 | coord_bottom = coord[:,1].max() 341 | 342 | bbox_tar_ = BBox(left = int(max(left, 0)), 343 | right = int(min(right, w)), 344 | top = int(max(top, 0)), 345 | bottom = int(min(bottom, h)), 346 | margin = 0, h = h, w = w) 347 | return bbox_tar_ 348 | 349 | def post_process_seg(seg_pred): 350 | frame2_seg_bbox = torch.nn.functional.interpolate(seg_pred,scale_factor=8,mode='bilinear') 351 | frame2_seg_bbox = norm_mask(frame2_seg_bbox.squeeze()) 352 | _, frame2_seg_bbox = torch.max(frame2_seg_bbox, dim=0) 353 | return frame2_seg_bbox 354 | 355 | def post_process_bbox(seg, bbox_pre): 356 | fg_idx = seg.nonzero() 357 | 358 | bbox_pre_cnt = bbox_pre[cnt] 359 | if(bbox_pre_cnt is not None): 360 | bbox_pre_cnt.upscale(8) 361 | center_h = (bbox_pre_cnt.top + bbox_pre_cnt.bottom) / 2 362 | center_w = (bbox_pre_cnt.left + bbox_pre_cnt.right) / 2 363 | fg_idx = clean_coord(fg_idx.float().cuda(), torch.Tensor([center_h, center_w]),keep_ratio=0.9) 364 | 365 | def scatter_point(coord, name, w, h, center=None): 366 | fig, ax = plt.subplots() 367 | ax.axis((0,w,0,h)) 368 | 
ax.scatter(coord[:,0], h - coord[:,1]) 369 | ax.scatter(center[0], h - center[1], marker='^') 370 | 371 | plt.savefig(name) 372 | plt.clf() 373 | 374 | def shift_bbox(seg, bbox): 375 | c,h,w = seg.size() 376 | bbox_new = {} 377 | for cnt in range(c): 378 | seg_cnt = seg[cnt] 379 | bbox_cnt = bbox[cnt] 380 | if(bbox_cnt is not None): 381 | fg_idx = seg_cnt.nonzero() 382 | fg_idx = fg_idx[:,[1,0]].float() 383 | center_h = calc_center(fg_idx[:,1], mode='mass') 384 | center_w = calc_center(fg_idx[:,0], mode='mass') 385 | 386 | # shift bbox w.r.t new center 387 | old_h = (bbox_cnt.top + bbox_cnt.bottom) / 2 388 | old_w = (bbox_cnt.left + bbox_cnt.right) / 2 389 | bb_width = bbox_cnt.right - bbox_cnt.left 390 | bb_height = bbox_cnt.bottom - bbox_cnt.top 391 | 392 | left = center_w - bb_width/2 393 | right = center_w + bb_width/2 394 | top = center_h - bb_height/2 395 | bottom = center_h + bb_height/2 396 | bbox_new[cnt] = BBox(left = left, right = right, 397 | top = top, bottom = bottom, 398 | margin = 0, h = h, w = w) 399 | else: 400 | bbox_new[cnt] = None 401 | return bbox_new 402 | 403 | ############################# MATCHING FUNCTIONS ######################## 404 | def match_ref_tar(F_ref, F_tar, seg_ref, temp): 405 | """ 406 | INPUTS: 407 | - F_ref: feature of reference frame 408 | - F_tar: feature of target frame 409 | - seg_ref: segmentation of reference frame 410 | - temp: temperature of softmax 411 | METHOD: 412 | - take foreground pixels from the reference frame and match them to the 413 | target frame. 414 | RETURNS: 415 | - coord: a list of coordinates of foreground pixels in the target frame. 416 | """ 417 | coords = {} 418 | c, h, w = F_ref.size() 419 | F_ref_flat = F_ref.view(c, -1) 420 | F_tar_flat = F_tar.view(c, -1) 421 | 422 | grid = create_grid(F_ref.unsqueeze(0).size()).squeeze() 423 | grid[:,:,0] = (grid[:,:,0]+1)/2 * w 424 | grid[:,:,1] = (grid[:,:,1]+1)/2 * h 425 | grid_flat = grid.view(-1,2) 426 | 427 | for cnt in range(seg_ref.size(0)): 428 | seg_cnt = seg_ref[cnt, :, :].view(-1) 429 | # there's no mask for this channel 430 | if(torch.sum(seg_cnt) < 2): 431 | coords[cnt] = None 432 | continue 433 | if(cnt > 0): 434 | fg_idx = seg_cnt.nonzero() 435 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze() 436 | 437 | else: 438 | # for the background class, we just take the whole frame 439 | F_ref_cnt_flat = F_ref_flat 440 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat) 441 | aff = torch.nn.functional.softmax(aff*temp, dim = 1) 442 | coord = torch.mm(aff, grid_flat) 443 | coords[cnt] = coord 444 | return coords 445 | 446 | def weighted_center(coords, center): 447 | """ 448 | in_range = [] 449 | for cnt in range(coords.shape[1]): 450 | coord_i = coords[0,cnt,:] 451 | if(np.linalg.norm(coord_i - prev_center) < bandwidth): 452 | in_range.append(coord_i) 453 | in_range = np.array(in_range) 454 | new_center = np.mean(in_range, axis=0) 455 | """ 456 | center = center.reshape(1, 1, 2) 457 | 458 | dis_x = np.sqrt(np.power(coords[:,:,0] - center[:,:,0], 2)) 459 | weight_x = 1 / dis_x 460 | weight_x = weight_x / np.sum(weight_x) 461 | dis_y = np.sqrt(np.power(coords[:,:,1] - center[:,:,1], 2)) 462 | weight_y = 1 / dis_y 463 | weight_y = weight_y / np.sum(weight_y) 464 | 465 | new_x = np.sum(weight_x * coords[:,:,0]) 466 | new_y = np.sum(weight_y * coords[:,:,1]) 467 | 468 | return np.array([new_x, new_y]).reshape(1,1,2) 469 | 470 | def euclid_distance(x, xi): 471 | return np.sqrt(np.sum((x - xi)**2)) 472 | 473 | def neighbourhood_points(X, x_centroid, distance = 20): 474 | 
eligible_X = [] 475 | #for x in X: 476 | for cnt in range(X.shape[0]): 477 | x = X[cnt,:] 478 | distance_between = euclid_distance(x, x_centroid) 479 | if distance_between <= distance: 480 | eligible_X.append(x) 481 | return eligible_X 482 | 483 | def gaussian_kernel(distance, bandwidth): 484 | val = (1/(bandwidth*math.sqrt(2*math.pi))) * np.exp(-0.5*((distance / bandwidth))**2) 485 | return val 486 | 487 | def mean_shift_center(coords, bandwidth=20): 488 | """ 489 | INPUTS: 490 | - coords: coordinates of pixels in the next frame, 1xNx2 491 | """ 492 | # 1 x 2 493 | avg_center = np.mean(coords, axis=0) 494 | prev_center = copy.deepcopy(avg_center) 495 | 496 | # pick the most close point as the center 497 | minimum = 100000 498 | for cnt in range(coords.shape[0]): 499 | coord_i = coords[cnt,:].reshape(1,2) 500 | dis = np.linalg.norm(coord_i - avg_center) 501 | if(dis < minimum): 502 | minimum = dis 503 | prev_center = copy.deepcopy(coord_i) 504 | 505 | counter = 0 506 | while True: 507 | counter += 1 508 | neighbors = neighbourhood_points(coords, prev_center.reshape(1,2), bandwidth) 509 | numerator = 0 510 | denominator = 0 511 | for neighbor in neighbors: 512 | distance = euclid_distance(neighbor, prev_center) 513 | weight = gaussian_kernel(distance, 4) 514 | numerator += (weight * neighbor) 515 | denominator += weight 516 | new_center = numerator / denominator 517 | 518 | if(np.sum(new_center - prev_center) < 0.1): 519 | final_center = torch.from_numpy(new_center) 520 | return final_center.view(1,2) 521 | else: 522 | prev_center = copy.deepcopy(new_center) 523 | 524 | def coords2bbox_scale(coords, h_tar, w_tar, bbox_pre, margin, bandwidth, log=False): 525 | """ 526 | INPUTS: 527 | - coords: coordinates of pixels in the next frame 528 | - h_tar: target image height 529 | - w_tar: target image widthg 530 | """ 531 | b = 1 532 | center = mean_shift_center(coords.numpy(), bandwidth) 533 | 534 | center_repeat = center.repeat(coords.size(0),1) 535 | 536 | dis_x = torch.sqrt(torch.pow(coords[:,0] - center_repeat[:,0], 2)) 537 | dis_x = torch.mean(dis_x, dim=0).detach() 538 | dis_y = torch.sqrt(torch.pow(coords[:,1] - center_repeat[:,1], 2)) 539 | dis_y = torch.mean(dis_y, dim=0).detach() 540 | 541 | left = (center[:,0] - dis_x*2).view(b,1) 542 | right = (center[:,0] + dis_x*2).view(b,1) 543 | top = (center[:,1] - dis_y*2).view(b,1) 544 | bottom = (center[:,1] + dis_y*2).view(b,1) 545 | 546 | 547 | bbox_tar_ = BBox(left = max(left, 0), 548 | right = min(right, w_tar), 549 | top = max(top, 0), 550 | bottom = min(bottom, h_tar), 551 | margin = margin, h = h_tar, w = w_tar) 552 | 553 | 554 | return bbox_tar_ 555 | 556 | def coord_wrt_bbox(coord, bbox): 557 | new_coord = [] 558 | left = bbox.left 559 | right = bbox.right 560 | top = bbox.top 561 | bottom = bbox.bottom 562 | for cnt in range(coord.size(0)): 563 | coord_i = coord[cnt] 564 | if(coord_i[0] >= left and coord_i[0] <= right and coord_i[1] >= top and coord_i[1] <= bottom): 565 | new_coord.append(coord_i) 566 | new_coord = torch.stack(new_coord) 567 | return new_coord 568 | 569 | def bbox_in_tar_scale(coords_tar, bbox_ref, h, w): 570 | """ 571 | INPUTS: 572 | - coords_tar: foreground coordinates in the target frame 573 | - bbox_ref: bboxs in the reference frame 574 | METHOD: 575 | - calculate bbox in the next frame w.r.t pixels coordinates 576 | RETURNS: 577 | - each bbox in the tar frame 578 | """ 579 | bbox_tar = {} 580 | #for cnt in range(len(bbox_ref)): 581 | for cnt, bbox_cnt in bbox_ref.items(): 582 | if(cnt == 0): 583 | bbox_tar_ = 
BBox(left = 0, 584 | right = w, 585 | top = 0, 586 | bottom = h, 587 | margin = 0.6, h = h, w = w) 588 | elif(bbox_cnt is not None): 589 | if not (cnt in coords_tar): 590 | continue 591 | coord = coords_tar[cnt] 592 | if(coord is None): 593 | continue 594 | coord = coord.cpu() 595 | 596 | bbox_tar_ = coords2bbox_scale(coord, h, w, bbox_cnt, margin=1, bandwidth=5, log=(cnt==3)) 597 | else: 598 | bbox_tar_ = None 599 | 600 | bbox_tar[cnt] = bbox_tar_ 601 | return bbox_tar 602 | 603 | ############################# depreciated code, keep for safety ######################## 604 | """ 605 | def bbox_in_tar(coords_tar, bbox_ref, h, w): 606 | INPUTS: 607 | - coords_tar: foreground coordinates in the target frame 608 | - bbox_ref: bboxs in the reference frame 609 | METHOD: 610 | - calculate bbox in the next frame w.r.t pixels coordinates 611 | RETURNS: 612 | - each bbox in the tar frame 613 | bbox_tar = {} 614 | for cnt, bbox_cnt in bbox_ref.items(): 615 | if(cnt == 0): 616 | bbox_tar_ = BBox(left = 0, 617 | right = w, 618 | top = 0, 619 | bottom = h, 620 | margin = 0.5, h = h, w = w) 621 | elif(bbox_cnt is not None): 622 | if not (cnt in coords_tar): 623 | continue 624 | coord = coords_tar[cnt] 625 | center_h = (bbox_cnt.top + bbox_cnt.bottom) / 2 626 | center_w = (bbox_cnt.left + bbox_cnt.right) / 2 627 | if(coord is None): 628 | continue 629 | coord = coord.cpu() 630 | 631 | bbox_tar_ = coord2bbox(bbox_cnt, coord, h, w, adaptive=False) 632 | else: 633 | bbox_tar_ = None 634 | 635 | bbox_tar[cnt] = bbox_tar_ 636 | return bbox_tar 637 | 638 | def bbox_in_tar_v2(coords_tar, bbox_ref, h, w, seg_pre): 639 | INPUTS: 640 | - coords_tar: foreground coordinates in the target frame 641 | - bbox_ref: bboxs in the reference frame 642 | METHOD: 643 | - calculate bbox in the next frame w.r.t pixels coordinates 644 | RETURNS: 645 | - each bbox in the tar frame 646 | VERSION NOTE: 647 | - include scaling modeling 648 | bbox_tar = {} 649 | for cnt, bbox_cnt in bbox_ref.items(): 650 | # for each channel 651 | if(cnt == 0): 652 | # for background channel 653 | bbox_tar_ = BBox(left = 0, 654 | right = w, 655 | top = 0, 656 | bottom = h, 657 | margin = 0.5, h = h, w = w) 658 | elif(bbox_cnt is not None): 659 | coord = coords_tar[cnt] 660 | if(coord is None): 661 | continue 662 | coord = coord.cpu() 663 | 664 | bbox_tar_ = coord2bbox(bbox_cnt, coord, h, w, seg_pre, adaptive=True) 665 | else: 666 | bbox_tar_ = None 667 | 668 | bbox_tar[cnt] = bbox_tar_ 669 | return bbox_tar 670 | 671 | def recoginition(F_ref, F_tar, bbox_ref, bbox_tar, seg_ref, temp): 672 | - F_ref: feature of reference frame 673 | - F_tar: feature of target frame 674 | - bbox_ref: bboxes of reference frame 675 | - bbox_tar: bboxes of target frame 676 | - seg_ref: segmentation of reference frame 677 | c, h, w = F_tar.size() 678 | seg_pred = torch.zeros(seg_ref.size()) 679 | #for cnt,(br, bt) in enumerate(zip(bbox_ref, bbox_tar)): 680 | for cnt, br in bbox_ref.items(): 681 | bt = bbox_tar[cnt] 682 | if(br is None or bt is None): 683 | continue 684 | seg_cnt = seg_ref[cnt] 685 | 686 | # feature of patch in the next frame 687 | F_tar_box = F_tar[:, bt.top:bt.bottom, bt.left:bt.right] 688 | F_ref_box = F_ref[:, br.top:br.bottom, br.left:br.right] 689 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1) 690 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1) 691 | 692 | # affinity between two patches 693 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat) 694 | aff = torch.nn.functional.softmax(aff * temp, dim=0) 695 | # transfer segmentation 
from patch1 to patch2 696 | seg_ref_box = seg_cnt[br.top:br.bottom, br.left:br.right] 697 | aff = aff.cpu() 698 | if(cnt == 0): 699 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1) 700 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze() 701 | else: 702 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1) 703 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze() 704 | #seg_tar_box = transform_topk(aff.unsqueeze(0), seg_ref_box.contiguous().unsqueeze(0).unsqueeze(0), 20) 705 | seg_tar_box = seg_tar_box.view(F_tar_box.size(1), F_tar_box.size(2)) 706 | 707 | seg_pred[cnt,bt.top:bt.bottom, bt.left:bt.right] = seg_tar_box 708 | return seg_pred 709 | 710 | def bbox_next_frame_v2(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first, bbox_pre, temp, direct=False): 711 | INPUTS: 712 | - direct: rec|direct, 713 | - if False, use previous frame to locate bbox 714 | - if True, use first frame to locate bbox 715 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar) 716 | c, h, w = F_first.size() 717 | if not direct: 718 | coords_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp) 719 | else: 720 | coords_tar = match_ref_tar(F_first, F_tar, seg_first, temp) 721 | 722 | bbox_tar = bbox_in_tar(coords_tar, bbox_first, h, w) 723 | 724 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp) 725 | return seg_pred.unsqueeze(0) 726 | 727 | def bbox_next_frame_v3(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first, bbox_pre, temp, name): 728 | METHOD: combining tracking & direct recognition, calculate bbox in target frame 729 | using both first frame and previous frame. 730 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar) 731 | c, h, w = F_first.size() 732 | 733 | coords_pre_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp) 734 | coords_first_tar = match_ref_tar(F_first, F_tar, seg_first, temp) 735 | coords_tar = {} 736 | for cnt, coord_first in coords_first_tar.items(): 737 | coord_pre = coords_pre_tar[cnt] 738 | # fall-back schema 739 | if(coord_pre is None): 740 | coord_tar_ = coord_first 741 | else: 742 | coord_tar_ = coord_pre 743 | coords_tar[cnt] = coord_tar_ 744 | _, seg_pre_idx = torch.max(seg_pre, dim = 0) 745 | 746 | coords_tar = clean_coords(coords_tar, bbox_pre, threshold=4) 747 | bbox_tar = bbox_in_tar(coords_tar, bbox_first, h, w) 748 | 749 | # recoginition 750 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp) 751 | seg_cleaned = clean_seg(seg_pred, bbox_tar, threshold=1) 752 | 753 | # move bbox w.r.t cleaned seg 754 | bbox_tar = shift_bbox(seg_cleaned, bbox_tar) 755 | 756 | seg_post = post_process_seg(seg_pred.unsqueeze(0)) 757 | return seg_pred, seg_post, bbox_tar 758 | 759 | def bbox_next_frame_v4(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first, 760 | bbox_pre, temp): 761 | METHOD: combining tracking & direct recognition, calculate bbox in target frame 762 | using both first frame and previous frame. 
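The commented-out variants in this block are kept for reference only; the active pipeline earlier in `libs/track_utils.py` chains `seg2bbox`, `match_ref_tar`, and `bbox_in_tar_scale`. Below is a minimal, self-contained sketch of that chain (not part of the original file): the feature shapes, the toy segmentation, and the temperature are placeholder assumptions, and a CUDA device is assumed, as elsewhere in this repo.

```
import torch
import torch.nn.functional as F
from libs.track_utils import seg2bbox, match_ref_tar, bbox_in_tar_scale

# Random placeholders standing in for stride-8 encoder features and a
# one-hot segmentation with background + two objects.
c, h, w = 256, 30, 56
F_ref = F.normalize(torch.randn(c, h, w), dim=0).cuda()
F_tar = F.normalize(torch.randn(c, h, w), dim=0).cuda()

seg_ref = torch.zeros(3, h, w).cuda()
seg_ref[1, 5:12, 5:12] = 1
seg_ref[2, 18:26, 30:44] = 1
seg_ref[0] = 1 - seg_ref[1:].sum(0)

bbox_ref = seg2bbox(seg_ref, margin=0.6)               # per-object boxes in the reference frame
coords = match_ref_tar(F_ref, F_tar, seg_ref, temp=1)  # matched foreground coordinates in the target
bbox_tar = bbox_in_tar_scale(coords, bbox_ref, h, w)   # mean-shift based boxes in the target frame
for cnt, box in bbox_tar.items():
    if box is not None:
        box.print()
```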
763 | Version Note: include bounding box scaling 764 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar) 765 | c, h, w = F_first.size() 766 | 767 | coords_pre_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp) 768 | coords_first_tar = match_ref_tar(F_first, F_tar, seg_first, temp) 769 | coords_tar = {} 770 | for cnt, coord_first in coords_first_tar.items(): 771 | coord_pre = coords_pre_tar[cnt] 772 | # fall-back schema 773 | if(coord_pre is None): 774 | coord_tar_ = coord_first 775 | else: 776 | coord_tar_ = coord_pre 777 | coords_tar[cnt] = coord_tar_ 778 | _, seg_pre_idx = torch.max(seg_pre, dim = 0) 779 | 780 | coords_tar = clean_coords(coords_tar, bbox_pre, threshold=4) 781 | 782 | bbox_tar = bbox_in_tar_v2(coords_tar, bbox_first, h, w, seg_pre) 783 | 784 | # recoginition 785 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp) 786 | seg_cleaned = clean_seg(seg_pred, bbox_tar, threshold=1) 787 | 788 | # move bbox w.r.t cleaned seg 789 | bbox_tar = shift_bbox(seg_cleaned, bbox_tar) 790 | 791 | seg_post = post_process_seg(seg_pred.unsqueeze(0)) 792 | return seg_pred, seg_post, bbox_tar 793 | 794 | """ 795 | """ 796 | def bbox_next_frame(F_ref, seg_ref, F_tar, bbox, temp): 797 | # b * h * w * 2 798 | b, c, h, w = F_ref.size() 799 | grid = create_grid(F_ref.size()).squeeze() 800 | grid[:,:,0] = (grid[:,:,0]+1)/2 * w 801 | grid[:,:,1] = (grid[:,:,1]+1)/2 * h 802 | # grid_flat: (h * w) * 2 803 | grid_flat = grid.view(-1,2) 804 | seg_ref = seg_ref.squeeze() 805 | F_ref = F_ref.squeeze() 806 | F_ref_flat = F_ref.view(c,-1) 807 | F_tar = F_tar.squeeze() 808 | F_tar_flat = F_tar.view(c,-1) 809 | bbox_next = [] 810 | seg_pred = torch.zeros(seg_ref.size()) 811 | for i in range(0,seg_ref.size(0)): 812 | seg_cnt = seg_ref[i,:,:].contiguous().view(-1) 813 | if(seg_cnt.max() == 0): 814 | continue 815 | bbox_cnt = bbox[i] 816 | if(i > 0): 817 | fg_idx = seg_cnt.nonzero() 818 | # take pixels of this instance out 819 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze() 820 | 821 | # affinity between patch and target frame 822 | # aff: (hh * ww, h * w) 823 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat) 824 | aff = torch.nn.functional.softmax(aff*temp, dim = 1) 825 | # coord of this patch in next frame: (hh*ww) * 2 826 | coord = torch.mm(aff, grid_flat) 827 | avg_h = calc_center(coord[:,1], mode='mass').cpu().long() 828 | avg_w = calc_center(coord[:,0], mode='mass').cpu().long() 829 | bb_width = bbox_cnt.right - bbox_cnt.left 830 | bb_height = bbox_cnt.bottom - bbox_cnt.top 831 | bbox_next_ = BBox(left = max(avg_w - bb_width/2,0), 832 | right = min(avg_w + bb_width/2, w), 833 | top = max(avg_h - bb_height/2,0), 834 | bottom = min(avg_h + bb_height/2,h), 835 | margin = 0, h = h, w = w) 836 | else: 837 | bbox_next_ = BBox(left = 0, 838 | right = w, 839 | top = 0, 840 | bottom = h, 841 | margin = 0, h = h, w = w) 842 | bbox_next.append(bbox_next_) 843 | 844 | # feature of patch in the next frame 845 | F_tar_box = F_tar[:, bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] 846 | F_ref_box = F_ref[:, bbox_cnt.top:bbox_cnt.bottom, bbox_cnt.left:bbox_cnt.right] 847 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1) 848 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1) 849 | 850 | # affinity between two patches 851 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat) 852 | aff = torch.nn.functional.softmax(aff * temp, dim=0) 853 | # transfer segmentation from patch1 to patch2 854 | seg_ref_box = 
seg_ref[i, bbox_cnt.top:bbox_cnt.bottom, bbox_cnt.left:bbox_cnt.right] 855 | aff = aff.cpu() 856 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1) 857 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze() 858 | seg_tar_box = seg_tar_box.view(F_tar_box.size(1),F_tar_box.size(2)) 859 | 860 | seg_pred[i,bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] = seg_tar_box 861 | 862 | #seg_pred[0,:,:] = 1 - torch.sum(seg_pred[1:,:,:],dim=0) 863 | 864 | return seg_pred.unsqueeze(0), bbox_next 865 | def bbox_next_frame_rec(F_first, F_ref, seg_ref, seg_first, 866 | F_tar, bbox_first, bbox, temp): 867 | # b * h * w * 2 868 | b, c, h, w = F_ref.size() 869 | seg_ref = seg_ref.squeeze() 870 | F_ref = F_ref.squeeze() 871 | F_ref_flat = F_ref.view(c,-1) 872 | F_tar = F_tar.squeeze() 873 | F_tar_flat = F_tar.view(c,-1) 874 | F_first = F_first.squeeze() 875 | bbox_next = [] 876 | seg_pred = torch.zeros(seg_ref.size()) 877 | seg_first = seg_first.squeeze() 878 | for i in range(0,seg_ref.size(0)): 879 | seg_cnt = seg_ref[i,:,:].contiguous().view(-1) 880 | if(seg_cnt.max() == 0): 881 | continue 882 | if(i > len(bbox)-1): 883 | continue 884 | bbox_cnt = bbox[i] 885 | bbox_first_cnt = bbox_first[i] 886 | if(i > 0): 887 | fg_idx = seg_cnt.nonzero() 888 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze() 889 | 890 | # affinity between patch and target frame 891 | if(F_ref_cnt_flat.dim() < 2): 892 | # some small objects may miss 893 | continue 894 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat) 895 | aff = torch.nn.functional.softmax(aff*temp, dim = 1) 896 | coord = torch.mm(aff, grid_flat) 897 | #coord = transform_topk(aff.unsqueeze(0), grid.unsqueeze(0), dim=2, k=20) 898 | avg_h = calc_center(coord[:,1], mode='mass').cpu() 899 | avg_w = calc_center(coord[:,0], mode='mass').cpu() 900 | bb_width = float(bbox_first_cnt.right - bbox_first_cnt.left) 901 | bb_height = float(bbox_first_cnt.bottom - bbox_first_cnt.top) 902 | coord = coord.cpu() 903 | 904 | left = avg_w - bb_width/2.0 905 | right = avg_w + bb_width/2.0 906 | top = avg_h - bb_height/2.0 907 | bottom = avg_h + bb_height/2.0 908 | 909 | bbox_next_ = BBox(left = int(max(left, 0)), 910 | right = int(min(right, w)), 911 | top = int(max(top, 0)), 912 | bottom = int(min(bottom, h)), 913 | margin = 0, h = h, w = w) 914 | else: 915 | bbox_next_ = BBox(left = 0, 916 | right = w, 917 | top = 0, 918 | bottom = h, 919 | margin = 0, h = h, w = w) 920 | 921 | bbox_next.append(bbox_next_) 922 | 923 | # feature of patch in the next frame 924 | F_tar_box = F_tar[:, bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] 925 | F_ref_box = F_first[:, bbox_first_cnt.top:bbox_first_cnt.bottom, bbox_first_cnt.left:bbox_first_cnt.right] 926 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1) 927 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1) 928 | print('================') 929 | 930 | # affinity between two patches 931 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat) 932 | aff = torch.nn.functional.softmax(aff * temp, dim=0) 933 | # transfer segmentation from patch1 to patch2 934 | seg_ref_box = seg_first[i, bbox_first_cnt.top:bbox_first_cnt.bottom, bbox_first_cnt.left:bbox_first_cnt.right] 935 | aff = aff.cpu() 936 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1) 937 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze() 938 | #seg_tar_box = transform_topk(aff.unsqueeze(0), seg_ref_box.contiguous().unsqueeze(0).unsqueeze(0)) 939 | seg_tar_box = 
seg_tar_box.view(F_tar_box.size(1),F_tar_box.size(2)) 940 | 941 | seg_pred[i,bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] = seg_tar_box 942 | 943 | return seg_pred.unsqueeze(0), bbox_next 944 | 945 | def transform_topk(aff, frame1, k=20, dim=1): 946 | b,c,h,w = frame1.size() 947 | b, N1, N2 = aff.size() 948 | # b * 20 * N 949 | tk_val, tk_idx = torch.topk(aff, dim = dim, k=k) 950 | # b * N 951 | tk_val_min,_ = torch.min(tk_val,dim=dim) 952 | if(dim == 1): 953 | tk_val_min = tk_val_min.view(b,1,N2) 954 | else: 955 | tk_val_min = tk_val_min.view(b,N1,1) 956 | aff[tk_val_min > aff] = 0 957 | frame1 = frame1.view(b,c,-1) 958 | if(dim == 1): 959 | frame2 = torch.bmm(frame1, aff) 960 | return frame2 961 | else: 962 | frame2 = torch.bmm(aff, frame1.permute(0,2,1)) 963 | return frame2.squeeze() 964 | """ 965 | -------------------------------------------------------------------------------- /libs/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import shutil 5 | # import visdom 6 | import numpy as np 7 | from os.path import join 8 | 9 | def save_vis(pred, gt2, gt1, out_dir, gt_grey=False, prefix=0): 10 | """ 11 | INPUTS: 12 | - pred: predicted Lab image, a 3xhxw tensor 13 | - gt2: second GT frame, a 3xhxw tensor 14 | - gt1: first GT frame, a 3xhxw tensor 15 | - out_dir: output image save path 16 | - gt_grey: whether to use ground trught L channel in predicted image 17 | """ 18 | b = pred.size(0) 19 | pred = pred * 128 + 128 20 | gt1 = gt1 * 128 + 128 21 | gt2 = gt2 * 128 + 128 22 | 23 | if(gt_grey): 24 | pred[:,0,:,:] = gt2[:,0,:,:] 25 | for cnt in range(b): 26 | im = pred[cnt].cpu().detach().numpy().transpose( 1, 2, 0) 27 | im_bgr = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR) 28 | im_pred = np.clip(im_bgr, 0, 255) 29 | 30 | im = gt2[cnt].cpu().detach().numpy().transpose( 1, 2, 0) 31 | im_gt2 = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR) 32 | 33 | im = gt1[cnt].cpu().detach().numpy().transpose( 1, 2, 0) 34 | im_gt1 = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR) 35 | 36 | im = np.concatenate((im_gt1, im_gt2, im_pred), axis = 1) 37 | print(out_dir, "{:02d}{:02d}.png".format(prefix, cnt)) 38 | cv2.imwrite(join(out_dir, "{:02d}{:02d}.png".format(prefix, cnt)), im) 39 | 40 | def save_vis_ae(pred, gt, savepath): 41 | b = pred.size(0) 42 | for cnt in range(b): 43 | im = pred[cnt].cpu().detach() * 128 + 128 44 | im = im.numpy().transpose(1,2,0) 45 | im_pred = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR) 46 | 47 | im = gt[cnt].cpu().detach() * 128 + 128 48 | im = im.numpy().transpose(1,2,0) 49 | im_gt = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR) 50 | 51 | im = np.concatenate((im_gt, im_pred), axis = 1) 52 | cv2.imwrite(os.path.join(savepath, "{:02d}.png".format(cnt)), im) 53 | 54 | class AverageMeter(object): 55 | """Computes and stores the average and current value""" 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | 71 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar", savedir="models"): 72 | torch.save(state, filename) 73 | if is_best: 74 | shutil.copyfile(filename, os.path.join(savedir, 'model_best.pth.tar')) 75 | 76 | def 
sample_patch(b,h,w,patch_size): 77 | left = randint(0, max(w - patch_size,1)) 78 | top = randint(0, max(h - patch_size,1)) 79 | right = left + patch_size 80 | bottom = top + patch_size 81 | return torch.Tensor([left, right, top, bottom]).view(1,4).repeat(b,1).cuda() 82 | 83 | 84 | def diff_crop(F, x1, y1, x2, y2, ph, pw): 85 | """ 86 | Differatiable cropping 87 | INPUTS: 88 | - F: frame feature 89 | - x1,y1,x2,y2: top left and bottom right points of the patch 90 | - theta is defined as : 91 | a b c 92 | d e f 93 | """ 94 | bs, ch, h, w = F.size() 95 | a = ((x2-x1)/w).view(bs,1,1) 96 | b = torch.zeros(a.size()).cuda() 97 | c = (-1+(x1+x2)/w).view(bs,1,1) 98 | d = torch.zeros(a.size()).cuda() 99 | e = ((y2-y1)/h).view(bs,1,1) 100 | f = (-1+(y2+y1)/h).view(bs,1,1) 101 | theta_row1 = torch.cat((a,b,c),dim=2) 102 | theta_row2 = torch.cat((d,e,f),dim=2) 103 | theta = torch.cat((theta_row1, theta_row2),dim=1).cuda() 104 | size = torch.Size((bs,ch,pw,ph)) 105 | grid = FUNC.affine_grid(theta, size) 106 | patch = FUNC.grid_sample(F,grid) 107 | return patch 108 | 109 | def log_current(epoch, loss_ave, best_loss, filename = "log_current.txt", savedir="models"): 110 | file = join(savedir, filename) 111 | with open(file, "a") as text_file: 112 | print("epoch: {}".format(epoch), file=text_file) 113 | print("best_loss: {}".format(best_loss), file=text_file) 114 | print("current_loss: {}".format(loss_ave), file=text_file) 115 | 116 | def print_options(opt): 117 | message = '' 118 | message += '----------------- Options ---------------\n' 119 | for k, v in sorted(vars(opt).items()): 120 | comment = '' 121 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 122 | message += '----------------- End -------------------' 123 | print(message) 124 | 125 | # save to the disk 126 | expr_dir = os.path.join(opt.savedir) 127 | file_name = os.path.join(expr_dir, 'opt.txt') 128 | with open(file_name, 'wt') as opt_file: 129 | opt_file.write(message) 130 | opt_file.write('\n') 131 | -------------------------------------------------------------------------------- /libs/transforms_multi.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import scipy.io 4 | import cv2 5 | import numpy as np 6 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import torch 12 | 13 | class Scale(object): 14 | """docstring for Scale""" 15 | def __init__(self, short_side): 16 | self.short_side = short_side 17 | 18 | def __call__(self, pair_0, pair_1): 19 | if(type(self.short_side) == int): 20 | h,w,c = pair_0.shape 21 | if(h > w): 22 | tw = self.short_side 23 | th = (tw * h) / w 24 | th = int((th // 64) * 64) 25 | else: 26 | th = self.short_side 27 | tw = (th * w) / h 28 | tw = int((tw // 64) * 64) 29 | elif(type(self.short_side) == list): 30 | th = self.short_side[0] 31 | tw = self.short_side[1] 32 | 33 | interpolation = cv2.INTER_NEAREST 34 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation) 35 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation) 36 | 37 | return pair_0, pair_1 38 | 39 | class RandomCrop(object): 40 | """for pair of frames""" 41 | def __init__(self, size, seperate = False): 42 | if isinstance(size, numbers.Number): 43 | self.size = (int(size), int(size)) 44 | else: 45 | self.size = size 46 | self.seperate = seperate 47 | 48 | def __call__(self, *frames): 49 | frames = list(frames) 50 | h, w, c = 
frames[0].shape 51 | th, tw = self.size 52 | top = bottom = left = right = 0 53 | 54 | if w == tw and h == th: 55 | return frames 56 | 57 | if w < tw: 58 | left = (tw - w) // 2 59 | right = tw - w - left 60 | if h < th: 61 | top = (th - h) // 2 62 | bottom = th - h - top 63 | if left > 0 or right > 0 or top > 0 or bottom > 0: 64 | for i in range(len(frames)): 65 | frames[i] = pad_image( 66 | 'reflection', frames[i], top, bottom, left, right) 67 | if w > tw: 68 | x1 = np.array([random.randint(0, w - tw)]) 69 | x1 = np.concatenate((x1,x1)) 70 | if self.seperate: 71 | #print("True") 72 | x1[1] = np.array([random.randint(0, w - tw)]) 73 | for i in range(len(frames)): 74 | frames[i] = frames[i][:, x1[i]:x1[i]+tw] 75 | if h > th: 76 | y1 = np.array([random.randint(0, h - th)]) 77 | y1 = np.concatenate((y1,y1)) 78 | 79 | if self.seperate: 80 | y1[1] = np.array([random.randint(0, h - th)]) 81 | 82 | for i in range(len(frames)): 83 | frames[i] = frames[i][y1[i]:y1[i]+th] 84 | 85 | return frames 86 | 87 | class CenterCrop(object): 88 | """for pair of frames""" 89 | def __init__(self, size): 90 | if isinstance(size, numbers.Number): 91 | self.size = (int(size), int(size)) 92 | else: 93 | self.size = size 94 | 95 | def __call__(self, *frames): 96 | frames = list(frames) 97 | h, w, c = frames[0].shape 98 | th, tw = self.size 99 | top = bottom = left = right = 0 100 | 101 | if w == tw and h == th: 102 | return frames 103 | 104 | if w < tw: 105 | left = (tw - w) // 2 106 | right = tw - w - left 107 | if h < th: 108 | top = (th - h) // 2 109 | bottom = th - h - top 110 | if left > 0 or right > 0 or top > 0 or bottom > 0: 111 | for i in range(len(frames)): 112 | frames[i] = pad_image( 113 | 'reflection', frames[i], top, bottom, left, right) 114 | 115 | if w > tw: 116 | #x1 = random.randint(0, w - tw) 117 | x1 = (w - tw) // 2 118 | for i in range(len(frames)): 119 | frames[i] = frames[i][:, x1:x1+tw] 120 | if h > th: 121 | #y1 = random.randint(0, h - th) 122 | y1 = (h - th) // 2 123 | for i in range(len(frames)): 124 | frames[i] = frames[i][y1:y1+th] 125 | 126 | return frames 127 | 128 | class RandomScale(object): 129 | """docstring for RandomScale""" 130 | def __init__(self, scale, seperate = False): 131 | if isinstance(scale, numbers.Number): 132 | scale = [1 / scale, scale] 133 | self.scale = scale 134 | self.seperate = seperate 135 | 136 | def __call__(self, *frames): 137 | h,w,c = frames[0].shape 138 | results = [] 139 | if self.seperate: 140 | ratio1 = random.uniform(self.scale[0], self.scale[1]) 141 | ratio2 = random.uniform(self.scale[0], self.scale[1]) 142 | tw1 = int(ratio1*w) 143 | th1 = int(ratio1*h) 144 | tw2 = int(ratio2*w) 145 | th2 = int(ratio2*h) 146 | 147 | if ratio1 == 1: 148 | results.append(frames[0]) 149 | elif ratio1 < 1: 150 | interpolation = cv2.INTER_LANCZOS4 151 | elif ratio1 > 1: 152 | interpolation = cv2.INTER_CUBIC 153 | frame = cv2.resize(frames[0], dsize = (tw1, th1), interpolation=interpolation) 154 | results.append(frame) 155 | 156 | if ratio2 == 1: 157 | results.append(frames[1]) 158 | elif ratio2 < 1: 159 | interpolation = cv2.INTER_LANCZOS4 160 | elif ratio2 > 1: 161 | interpolation = cv2.INTER_CUBIC 162 | frame = cv2.resize(frames[1], dsize = (tw2, th2), interpolation=interpolation) 163 | results.append(frame) 164 | else: 165 | ratio = random.uniform(self.scale[0], self.scale[1]) 166 | tw = int(ratio*w) 167 | th = int(ratio*h) 168 | if ratio == 1: 169 | return frames 170 | elif ratio < 1: 171 | interpolation = cv2.INTER_LANCZOS4 172 | elif ratio > 1: 173 | interpolation 
= cv2.INTER_CUBIC 174 | for frame in frames: 175 | frame = cv2.resize(frame, dsize = (tw, th), interpolation=interpolation) 176 | results.append(frame) 177 | # print(results[0].shape,type(results[1])) 178 | return results 179 | 180 | class RandomRotate(object): 181 | """docstring for RandomRotate""" 182 | def __init__(self, angle, seperate = False): 183 | self.angle = angle 184 | self.seperate = seperate 185 | 186 | #def __call__(self, pair_0, pair_1): 187 | def __call__(self, *frames): 188 | results = [] 189 | if self.seperate: 190 | angle = random.randint(0, self.angle * 2) - self.angle 191 | h,w,c = frames[0].shape 192 | p = max((h, w)) 193 | frame = pad_image('reflection', frames[0], h,h,w,w) 194 | frame = rotatenumpy(frame, angle) 195 | frame = frame[h : h + h, w : w + w] 196 | results.append(frame) 197 | 198 | angle = random.randint(0, self.angle * 2) - self.angle 199 | h,w,c = frames[1].shape 200 | p = max((h, w)) 201 | frame = pad_image('reflection', frames[1], h,h,w,w) 202 | frame = rotatenumpy(frame, angle) 203 | frame = frame[h : h + h, w : w + w] 204 | results.append(frame) 205 | else: 206 | angle = random.randint(0, self.angle * 2) - self.angle 207 | for frame in frames: 208 | h,w,c = frame.shape 209 | p = max((h, w)) 210 | frame = pad_image('reflection', frame, h,h,w,w) 211 | frame = rotatenumpy(frame, angle) 212 | frame = frame[h : h + h, w : w + w] 213 | results.append(frame) 214 | 215 | return results 216 | 217 | class RandomHorizontalFlip(object): 218 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 219 | """ 220 | 221 | def __call__(self, *frames): 222 | results = [] 223 | if random.random() < 0.5: 224 | for frame in frames: 225 | results.append(cv2.flip(frame, 1)) 226 | else: 227 | results = frames 228 | return results 229 | 230 | class Resize(object): 231 | """Resize the input PIL Image to the given size. 232 | Args: 233 | size (sequence or int): Desired output size. If size is a sequence like 234 | (h, w), output size will be matched to this. If size is an int, 235 | smaller edge of the image will be matched to this number. 236 | i.e, if height > width, then image will be rescaled to 237 | (size * height / width, size) 238 | interpolation (int, optional): Desired interpolation. 
Default is 239 | ``PIL.Image.BILINEAR`` 240 | """ 241 | 242 | def __init__(self, size, interpolation=cv2.INTER_NEAREST): 243 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 244 | self.size = size 245 | self.interpolation = interpolation 246 | 247 | def __call__(self, pair_0, pair_1): 248 | 249 | return resize(pair_0, self.size, self.interpolation), \ 250 | resize(pair_1, self.size, self.interpolation) 251 | 252 | 253 | class Pad(object): 254 | 255 | def __init__(self, padding, fill=0): 256 | assert isinstance(padding, numbers.Number) 257 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \ 258 | isinstance(fill, tuple) 259 | self.padding = padding 260 | self.fill = fill 261 | 262 | def __call__(self, pair_0, pair_1): 263 | if self.fill == -1: 264 | pair_0 = pad_image('reflection', pair_0, 265 | self.padding, self.padding, self.padding, self.padding) 266 | pair_1 = pad_image('reflection', pair_1, 267 | self.padding, self.padding, self.padding, self.padding) 268 | else: 269 | pair_0 = pad_image('constant', pair_0, 270 | self.padding, self.padding, self.padding, self.padding, 271 | value=self.fill) 272 | pair_1 = pad_image('constant', pair_1, 273 | self.padding, self.padding, self.padding, self.padding, 274 | value=self.fill) 275 | 276 | return pair_0, pair_1 277 | 278 | 279 | class ResizeandPad(object): 280 | """ 281 | resize the larger boundary to the desized eva_size; 282 | pad the smaller one to square 283 | """ 284 | def __init__(self, size, interpolation=cv2.INTER_NEAREST): 285 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 286 | self.size = size 287 | self.interpolation = interpolation 288 | 289 | def __call__(self, pair_0, pair_1): 290 | """ 291 | Resize and Pad 292 | """ 293 | pair_0 = resize_large(pair_0, self.size, self.interpolation) 294 | pair_1 = resize_large(pair_1, self.size, self.interpolation) 295 | h,w,_ = pair_0.shape 296 | if w > h: 297 | bd = int((w - h) / 2) 298 | pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0) 299 | # pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0) 300 | elif h > w: 301 | bd = int((h-w) / 2) 302 | pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd) 303 | # pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd) 304 | 305 | h,w,_ = pair_1.shape 306 | if w > h: 307 | bd = int((w - h) / 2) 308 | # pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0) 309 | pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0) 310 | elif h > w: 311 | bd = int((h-w) / 2) 312 | # pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd) 313 | pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd) 314 | 315 | return pair_0, pair_1 316 | 317 | 318 | class Normalize(object): 319 | """Given mean: (R, G, B) and std: (R, G, B), 320 | will normalize each channel of the torch.*Tensor, i.e. 321 | channel = (channel - mean) / std 322 | """ 323 | 324 | def __init__(self, mean, std = (1.0,1.0,1.0)): 325 | self.mean = torch.FloatTensor(mean) 326 | self.std = torch.FloatTensor(std) 327 | 328 | def __call__(self, *frames): 329 | results = [] 330 | for frame in frames: 331 | for t, m, s in zip(frame, self.mean, self.std): 332 | t.sub_(m).div_(s) 333 | results.append(frame) 334 | return results 335 | 336 | 337 | class ToTensor(object): 338 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 339 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 
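To illustrate how these multi-frame transforms compose (a sketch, not part of the original `libs/transforms_multi.py`): the frames are passed through `Compose` as a variable-length tuple, so every transform receives and returns the whole set. The crop size, rotation angle, and scale range below are illustrative assumptions; the `(128, 128, 128)` mean/std matches the normalization in `libs/test_utils.py`, and the random arrays stand in for a pair of Lab frames.

```
import numpy as np
import libs.transforms_multi as transforms

trans = transforms.Compose([
    transforms.RandomScale(1.2),          # rescale both frames by the same random ratio
    transforms.RandomRotate(10),          # rotate both frames by a shared angle in [-10, 10]
    transforms.RandomCrop(256),           # crop the same 256x256 window from both frames
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(128, 128, 128), std=(128, 128, 128)),
])

frame_a = np.random.randint(0, 255, (480, 854, 3)).astype(np.float32)
frame_b = np.random.randint(0, 255, (480, 854, 3)).astype(np.float32)
frame_a, frame_b = trans(frame_a, frame_b)  # two 3x256x256 float tensors, roughly in [-1, 1]
```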
340 | """ 341 | def __call__(self, *frames): 342 | results = [] 343 | for frame in frames: 344 | frame = torch.from_numpy(frame.transpose(2,0,1).copy()).contiguous().float() 345 | results.append(frame) 346 | return results 347 | 348 | 349 | class Compose(object): 350 | """ 351 | Composes several transforms together. 352 | """ 353 | def __init__(self, transforms): 354 | self.transforms = transforms 355 | 356 | def __call__(self, *args): 357 | for t in self.transforms: 358 | args = t(*args) 359 | return args 360 | 361 | 362 | #=============================functions=============================== 363 | 364 | def resize(img, size, interpolation=cv2.INTER_NEAREST): 365 | 366 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 367 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 368 | 369 | h, w = img.shape 370 | 371 | if isinstance(size, int): 372 | if (w <= h and w == size) or (h <= w and h == size): 373 | return img 374 | if w < h: 375 | ow = size 376 | oh = int(size * h / w) 377 | return cv2.resize(img, (ow, oh), interpolation) 378 | else: 379 | oh = size 380 | ow = int(size * w / h) 381 | return cv2.resize(img, (ow, oh), interpolation) 382 | else: 383 | return cv2.resize(img, size[::-1], interpolation) 384 | 385 | 386 | def resize_large(img, size, interpolation=cv2.INTER_NEAREST): 387 | 388 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 389 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 390 | 391 | h, w,_ = img.shape 392 | 393 | if isinstance(size, int): 394 | if (w >= h and w == size) or (h >= w and h == size): 395 | return img 396 | if w > h: 397 | ow = size 398 | oh = int(size * h / w) 399 | return cv2.resize(img, (ow, oh), interpolation) 400 | else: 401 | oh = size 402 | ow = int(size * w / h) 403 | return cv2.resize(img, (ow, oh), interpolation) 404 | else: 405 | return cv2.resize(img, size[::-1], interpolation) 406 | 407 | 408 | def rotatenumpy(image, angle, interpolation=cv2.INTER_NEAREST): 409 | rot_mat = cv2.getRotationMatrix2D((image.shape[1]/2, image.shape[0]/2), angle, 1.0) 410 | result = cv2.warpAffine(image, rot_mat, (image.shape[1],image.shape[0]), flags=interpolation) 411 | return result 412 | 413 | # good, written with numpy 414 | def pad_reflection(image, top, bottom, left, right): 415 | if top == 0 and bottom == 0 and left == 0 and right == 0: 416 | return image 417 | h, w = image.shape[:2] 418 | next_top = next_bottom = next_left = next_right = 0 419 | if top > h - 1: 420 | next_top = top - h + 1 421 | top = h - 1 422 | if bottom > h - 1: 423 | next_bottom = bottom - h + 1 424 | bottom = h - 1 425 | if left > w - 1: 426 | next_left = left - w + 1 427 | left = w - 1 428 | if right > w - 1: 429 | next_right = right - w + 1 430 | right = w - 1 431 | new_shape = list(image.shape) 432 | new_shape[0] += top + bottom 433 | new_shape[1] += left + right 434 | new_image = np.empty(new_shape, dtype=image.dtype) 435 | new_image[top:top+h, left:left+w] = image 436 | new_image[:top, left:left+w] = image[top:0:-1, :] 437 | new_image[top+h:, left:left+w] = image[-1:-bottom-1:-1, :] 438 | new_image[:, :left] = new_image[:, left*2:left:-1] 439 | new_image[:, left+w:] = new_image[:, -right-1:-right*2-1:-1] 440 | return pad_reflection(new_image, next_top, next_bottom, 441 | next_left, next_right) 442 | 443 | # good, writen with numpy 444 | def pad_constant(image, top, bottom, left, right, value): 445 | if top == 0 and bottom == 0 and left == 0 and right == 0: 446 
| return image 447 | h, w = image.shape[:2] 448 | new_shape = list(image.shape) 449 | new_shape[0] += top + bottom 450 | new_shape[1] += left + right 451 | new_image = np.empty(new_shape, dtype=image.dtype) 452 | new_image.fill(value) 453 | new_image[top:top+h, left:left+w] = image 454 | return new_image 455 | 456 | # change to np/non-np options 457 | def pad_image(mode, image, top, bottom, left, right, value=0): 458 | if mode == 'reflection': 459 | return pad_reflection(image, top, bottom, left, right) 460 | elif mode == 'constant': 461 | return pad_constant(image, top, bottom, left, right, value) 462 | else: 463 | raise ValueError('Unknown mode {}'.format(mode)) 464 | -------------------------------------------------------------------------------- /libs/transforms_pair.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import scipy.io 4 | import cv2 5 | import numpy as np 6 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import torch 12 | 13 | class CenterCrop(object): 14 | """for pair of frames""" 15 | def __init__(self, size): 16 | if isinstance(size, numbers.Number): 17 | self.size = (int(size), int(size)) 18 | else: 19 | self.size = size 20 | 21 | def __call__(self, *frames): 22 | frames = list(frames) 23 | h, w, c = frames[0].shape 24 | th, tw = self.size 25 | top = bottom = left = right = 0 26 | 27 | if w == tw and h == th: 28 | return frames 29 | 30 | if w < tw: 31 | left = (tw - w) // 2 32 | right = tw - w - left 33 | if h < th: 34 | top = (th - h) // 2 35 | bottom = th - h - top 36 | if left > 0 or right > 0 or top > 0 or bottom > 0: 37 | for i in range(len(frames)): 38 | frames[i] = pad_image( 39 | 'reflection', frames[i], top, bottom, left, right) 40 | 41 | if w > tw: 42 | #x1 = random.randint(0, w - tw) 43 | x1 = (w - tw) // 2 44 | for i in range(len(frames)): 45 | frames[i] = frames[i][:, x1:x1+tw] 46 | if h > th: 47 | #y1 = random.randint(0, h - th) 48 | y1 = (h - th) // 2 49 | for i in range(len(frames)): 50 | frames[i] = frames[i][y1:y1+th] 51 | 52 | return frames 53 | 54 | class RandomCrop(object): 55 | """for pair of frames""" 56 | def __init__(self, size): 57 | if isinstance(size, numbers.Number): 58 | self.size = (int(size), int(size)) 59 | else: 60 | self.size = size 61 | 62 | def __call__(self, pair_0, pair_1): 63 | h, w, c = pair_0.shape 64 | th, tw = self.size 65 | top = bottom = left = right = 0 66 | 67 | if w == tw and h == th: 68 | return pair_0, pair_1 69 | 70 | if w < tw: 71 | left = (tw - w) // 2 72 | right = tw - w - left 73 | if h < th: 74 | top = (th - h) // 2 75 | bottom = th - h - top 76 | if left > 0 or right > 0 or top > 0 or bottom > 0: 77 | pair_0 = pad_image( 78 | 'reflection', pair_0, top, bottom, left, right) 79 | pair_1 = pad_image( 80 | 'reflection', pair_1, top, bottom, left, right) 81 | 82 | if w > tw: 83 | x1 = random.randint(0, w - tw) 84 | pair_0 = pair_0[:, x1:x1+tw] 85 | pair_1 = pair_1[:, x1:x1+tw] 86 | if h > th: 87 | y1 = random.randint(0, h - th) 88 | pair_0 = pair_0[y1:y1+th] 89 | pair_1 = pair_1[y1:y1+th] 90 | 91 | return pair_0, pair_1 92 | 93 | class RandomScale(object): 94 | """docstring for RandomScale""" 95 | def __init__(self, scale): 96 | if isinstance(scale, numbers.Number): 97 | scale = [1 / scale, scale] 98 | self.scale = scale 99 | 100 | def __call__(self, pair_0, pair_1): 101 | ratio = random.uniform(self.scale[0], self.scale[1]) 102 | h,w,c = 
pair_0.shape 103 | tw = int(ratio*w) 104 | th = int(ratio*h) 105 | if ratio == 1: 106 | return pair_0, pair_1 107 | elif ratio < 1: 108 | interpolation = cv2.INTER_LANCZOS4 109 | elif ratio > 1: 110 | interpolation = cv2.INTER_CUBIC 111 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation) 112 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation) 113 | 114 | return pair_0, pair_1 115 | 116 | class Scale(object): 117 | """docstring for Scale""" 118 | def __init__(self, short_side): 119 | self.short_side = short_side 120 | 121 | def __call__(self, pair_0, pair_1): 122 | if(type(self.short_side) == int): 123 | h,w,c = pair_0.shape 124 | if(h > w): 125 | tw = self.short_side 126 | th = (tw * h) / w 127 | th = int((th // 64) * 64) 128 | else: 129 | th = self.short_side 130 | tw = (th * w) / h 131 | tw = int((tw // 64) * 64) 132 | elif(type(self.short_side) == list): 133 | th = self.short_side[0] 134 | tw = self.short_side[1] 135 | 136 | interpolation = cv2.INTER_NEAREST 137 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation) 138 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation) 139 | 140 | return pair_0, pair_1 141 | 142 | class RandomRotate(object): 143 | """docstring for RandomRotate""" 144 | def __init__(self, angle): 145 | self.angle = angle 146 | 147 | def __call__(self, pair_0, pair_1): 148 | h,w,c = pair_0.shape 149 | p = max((h, w)) 150 | angle = random.randint(0, self.angle * 2) - self.angle 151 | pair_0 = pad_image('reflection', pair_0, h,h,w,w) 152 | pair_0 = rotatenumpy(pair_0, angle) 153 | pair_0 = pair_0[h : h + h, w : w + w] 154 | pair_1 = pad_image('reflection', pair_1, h,h,w,w) 155 | pair_1 = rotatenumpy(pair_1, angle) 156 | pair_1 = pair_1[h : h + h, w : w + w] 157 | 158 | return pair_0, pair_1 159 | 160 | class RandomHorizontalFlip(object): 161 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 162 | """ 163 | 164 | def __call__(self, pair_0, pair_1): 165 | if random.random() < 0.5: 166 | results = [cv2.flip(pair_0, 1), cv2.flip(pair_1, 1)] 167 | else: 168 | results = [pair_0, pair_1] 169 | return results 170 | 171 | class Resize(object): 172 | """Resize the input PIL Image to the given size. 173 | Args: 174 | size (sequence or int): Desired output size. If size is a sequence like 175 | (h, w), output size will be matched to this. If size is an int, 176 | smaller edge of the image will be matched to this number. 177 | i.e, if height > width, then image will be rescaled to 178 | (size * height / width, size) 179 | interpolation (int, optional): Desired interpolation. 
Default is 180 | ``PIL.Image.BILINEAR`` 181 | """ 182 | 183 | def __init__(self, size, interpolation=cv2.INTER_NEAREST): 184 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 185 | self.size = size 186 | self.interpolation = interpolation 187 | 188 | def __call__(self, pair_0, pair_1): 189 | 190 | return resize(pair_0, self.size, self.interpolation), \ 191 | resize(pair_1, self.size, self.interpolation) 192 | 193 | 194 | class Pad(object): 195 | 196 | def __init__(self, padding, fill=0): 197 | assert isinstance(padding, numbers.Number) 198 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \ 199 | isinstance(fill, tuple) 200 | self.padding = padding 201 | self.fill = fill 202 | 203 | def __call__(self, pair_0, pair_1): 204 | if self.fill == -1: 205 | pair_0 = pad_image('reflection', pair_0, 206 | self.padding, self.padding, self.padding, self.padding) 207 | pair_1 = pad_image('reflection', pair_1, 208 | self.padding, self.padding, self.padding, self.padding) 209 | else: 210 | pair_0 = pad_image('constant', pair_0, 211 | self.padding, self.padding, self.padding, self.padding, 212 | value=self.fill) 213 | pair_1 = pad_image('constant', pair_1, 214 | self.padding, self.padding, self.padding, self.padding, 215 | value=self.fill) 216 | 217 | return pair_0, pair_1 218 | 219 | 220 | class ResizeandPad(object): 221 | """ 222 | resize the larger boundary to the desized eva_size; 223 | pad the smaller one to square 224 | """ 225 | def __init__(self, size, interpolation=cv2.INTER_NEAREST): 226 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 227 | self.size = size 228 | self.interpolation = interpolation 229 | 230 | def __call__(self, pair_0, pair_1): 231 | """ 232 | Resize and Pad 233 | """ 234 | pair_0 = resize(pair_0, self.size, self.interpolation) 235 | pair_1 = resize(pair_1, self.size, self.interpolation) 236 | h,w,_ = pair_0.shape 237 | if w > h: 238 | bd = int((w - h) / 2) 239 | pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0) 240 | pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0) 241 | elif h > w: 242 | bd = int((h-w) / 2) 243 | pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd) 244 | pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd) 245 | return pair_0, pair_1 246 | 247 | 248 | class Normalize(object): 249 | """Given mean: (R, G, B) and std: (R, G, B), 250 | will normalize each channel of the torch.*Tensor, i.e. 251 | channel = (channel - mean) / std 252 | """ 253 | 254 | def __init__(self, mean, std = (1.0,1.0,1.0)): 255 | self.mean = torch.FloatTensor(mean) 256 | self.std = torch.FloatTensor(std) 257 | 258 | def __call__(self, pair_0, pair_1): 259 | for t, m, s in zip(pair_0, self.mean, self.std): 260 | t.sub_(m).div_(s) 261 | for t, m, s in zip(pair_1, self.mean, self.std): 262 | t.sub_(m).div_(s) 263 | # print("pair_0: ", pair_0.size()) 264 | # print("pair_1: ", pair_1.size()) 265 | return pair_0, pair_1 266 | 267 | 268 | class ToTensor(object): 269 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 270 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 
271 | """ 272 | def __call__(self, pair_0, pair_1): 273 | if(pair_0.ndim == 2): 274 | pair_0 = torch.from_numpy(pair_0.copy()).contiguous().float().unsqueeze(0) 275 | pair_1 = torch.from_numpy(pair_1.copy()).contiguous().float().unsqueeze(0) 276 | else: 277 | pair_0 = torch.from_numpy(pair_0.transpose(2,0,1).copy()).contiguous().float() 278 | pair_1 = torch.from_numpy(pair_1.transpose(2,0,1).copy()).contiguous().float() 279 | return pair_0, pair_1 280 | 281 | 282 | class Compose(object): 283 | """ 284 | Composes several transforms together. 285 | """ 286 | def __init__(self, transforms): 287 | self.transforms = transforms 288 | 289 | def __call__(self, *args): 290 | for t in self.transforms: 291 | args = t(*args) 292 | return args 293 | 294 | 295 | #=============================functions=============================== 296 | 297 | def resize(img, size, interpolation=cv2.INTER_NEAREST): 298 | 299 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 300 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 301 | 302 | h, w, _ = img.shape 303 | 304 | if isinstance(size, int): 305 | if (w <= h and w == size) or (h <= w and h == size): 306 | return img 307 | if w < h: 308 | ow = size 309 | oh = int(size * h / w) 310 | return cv2.resize(img, (ow, oh), interpolation) 311 | else: 312 | oh = size 313 | ow = int(size * w / h) 314 | return cv2.resize(img, (ow, oh), interpolation) 315 | else: 316 | return cv2.resize(img, size[::-1], interpolation) 317 | 318 | 319 | def rotatenumpy(image, angle, interpolation=cv2.INTER_NEAREST): 320 | rot_mat = cv2.getRotationMatrix2D((image.shape[1]/2, image.shape[0]/2), angle, 1.0) 321 | result = cv2.warpAffine(image, rot_mat, (image.shape[1],image.shape[0]), flags=interpolation) 322 | return result 323 | 324 | # good, written with numpy 325 | def pad_reflection(image, top, bottom, left, right): 326 | if top == 0 and bottom == 0 and left == 0 and right == 0: 327 | return image 328 | h, w = image.shape[:2] 329 | next_top = next_bottom = next_left = next_right = 0 330 | if top > h - 1: 331 | next_top = top - h + 1 332 | top = h - 1 333 | if bottom > h - 1: 334 | next_bottom = bottom - h + 1 335 | bottom = h - 1 336 | if left > w - 1: 337 | next_left = left - w + 1 338 | left = w - 1 339 | if right > w - 1: 340 | next_right = right - w + 1 341 | right = w - 1 342 | new_shape = list(image.shape) 343 | new_shape[0] += top + bottom 344 | new_shape[1] += left + right 345 | new_image = np.empty(new_shape, dtype=image.dtype) 346 | new_image[top:top+h, left:left+w] = image 347 | new_image[:top, left:left+w] = image[top:0:-1, :] 348 | new_image[top+h:, left:left+w] = image[-1:-bottom-1:-1, :] 349 | new_image[:, :left] = new_image[:, left*2:left:-1] 350 | new_image[:, left+w:] = new_image[:, -right-1:-right*2-1:-1] 351 | return pad_reflection(new_image, next_top, next_bottom, 352 | next_left, next_right) 353 | 354 | # good, writen with numpy 355 | def pad_constant(image, top, bottom, left, right, value): 356 | if top == 0 and bottom == 0 and left == 0 and right == 0: 357 | return image 358 | h, w = image.shape[:2] 359 | new_shape = list(image.shape) 360 | new_shape[0] += top + bottom 361 | new_shape[1] += left + right 362 | new_image = np.empty(new_shape, dtype=image.dtype) 363 | new_image.fill(value) 364 | new_image[top:top+h, left:left+w] = image 365 | return new_image 366 | 367 | # change to np/non-np options 368 | def pad_image(mode, image, top, bottom, left, right, value=0): 369 | if mode == 'reflection': 370 | return 
pad_reflection(image, top, bottom, left, right) 371 | elif mode == 'constant': 372 | return pad_constant(image, top, bottom, left, right, value) 373 | else: 374 | raise ValueError('Unknown mode {}'.format(mode)) 375 | -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import scipy.misc 6 | import numpy as np 7 | from PIL import Image 8 | import libs.transforms_multi as transforms 9 | 10 | def print_options(opt, test=False): 11 | message = '' 12 | message += '----------------- Options ---------------\n' 13 | for k, v in sorted(vars(opt).items()): 14 | comment = '' 15 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 16 | message += '----------------- End -------------------' 17 | print(message) 18 | 19 | # save to the disk 20 | if(test): 21 | expr_dir = os.path.join(opt.out_dir) 22 | os.makedirs(expr_dir,exist_ok=True) 23 | file_name = os.path.join(expr_dir, 'test_opt.txt') 24 | else: 25 | expr_dir = os.path.join(opt.checkpoint_dir) 26 | os.makedirs(expr_dir,exist_ok=True) 27 | file_name = os.path.join(expr_dir, 'train_opt.txt') 28 | with open(file_name, 'wt') as opt_file: 29 | opt_file.write(message) 30 | opt_file.write('\n') 31 | 32 | def read_frame_list(video_dir): 33 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))] 34 | frame_list = sorted(frame_list) 35 | return frame_list 36 | 37 | def read_frame(frame_dir, transforms): 38 | frame = cv2.imread(frame_dir) 39 | ori_h,ori_w,_ = frame.shape 40 | if(ori_h > ori_w): 41 | tw = ori_w 42 | th = (tw * ori_h) / ori_w 43 | th = int((th // 64) * 64) 44 | else: 45 | th = ori_h 46 | tw = (th * ori_w) / ori_h 47 | tw = int((tw // 64) * 64) 48 | #h = (ori_h // 64) * 64 49 | #w = (ori_w // 64) * 64 50 | frame = cv2.resize(frame, (tw,th)) 51 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) 52 | 53 | pair = [frame, frame] 54 | transformed = list(transforms(*pair)) 55 | return transformed[0].cuda().unsqueeze(0), ori_h, ori_w 56 | 57 | def to_one_hot(y_tensor, n_dims=9): 58 | _,h,w = y_tensor.size() 59 | """ Take integer y (tensor or variable) with n dims and convert it to 1-hot representation with n+1 dims. 
""" 60 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1) 61 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1 62 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1) 63 | y_one_hot = y_one_hot.view(h,w,n_dims) 64 | return y_one_hot.permute(2,0,1).unsqueeze(0) 65 | 66 | def read_seg(seg_dir, crop_size): 67 | seg = Image.open(seg_dir) 68 | h,w = seg.size 69 | if(h > w): 70 | tw = crop_size 71 | th = (tw * h) / w 72 | th = int((th // 64) * 64) 73 | else: 74 | th = crop_size 75 | tw = (th * w) / h 76 | tw = int((tw // 64) * 64) 77 | seg = np.asarray(seg).reshape((w,h,1)) 78 | seg = np.squeeze(seg) 79 | seg = scipy.misc.imresize(seg, (tw//8,th//8),"nearest",mode="F") 80 | 81 | seg = torch.from_numpy(seg).view(1,tw//8,th//8) 82 | return to_one_hot(seg) 83 | 84 | def create_transforms(crop_size): 85 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128)) 86 | t = [] 87 | t.extend([transforms.ToTensor(), 88 | normalize]) 89 | return transforms.Compose(t) 90 | 91 | def transform_topk(aff, frame1, k): 92 | """ 93 | INPUTS: 94 | - aff: affinity matrix, b * N * N 95 | - frame1: reference frame 96 | - k: only aggregate top-k pixels with highest aff(j,i) 97 | """ 98 | b,c,h,w = frame1.size() 99 | b, N, _ = aff.size() 100 | # b * 20 * N 101 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k) 102 | # b * N 103 | tk_val_min,_ = torch.min(tk_val,dim=1) 104 | tk_val_min = tk_val_min.view(b,1,N) 105 | aff[tk_val_min > aff] = 0 106 | frame1 = frame1.view(b,c,-1) 107 | frame2 = torch.bmm(frame1, aff) 108 | return frame2.view(b,c,h,w) 109 | 110 | def norm_mask(mask): 111 | """ 112 | INPUTS: 113 | - mask: segmentation mask 114 | """ 115 | c,h,w = mask.size() 116 | for cnt in range(c): 117 | mask_cnt = mask[cnt,:,:] 118 | if(mask_cnt.max() > 0): 119 | mask_cnt = (mask_cnt - mask_cnt.min()) 120 | mask_cnt = mask_cnt/mask_cnt.max() 121 | mask[cnt,:,:] = mask_cnt 122 | return mask 123 | 124 | def diff_crop(F, x1, y1, x2, y2, ph, pw): 125 | """ 126 | Differatiable cropping 127 | INPUTS: 128 | - F: frame feature 129 | - x1,y1,x2,y2: top left and bottom right points of the patch 130 | - theta is defined as : 131 | a b c 132 | d e f 133 | """ 134 | bs, ch, h, w = F.size() 135 | a = ((x2-x1)/w).view(bs,1,1) 136 | b = torch.zeros(a.size()).cuda() 137 | c = (-1+(x1+x2)/w).view(bs,1,1) 138 | d = torch.zeros(a.size()).cuda() 139 | e = ((y2-y1)/h).view(bs,1,1) 140 | f = (-1+(y2+y1)/h).view(bs,1,1) 141 | theta_row1 = torch.cat((a,b,c),dim=2) 142 | theta_row2 = torch.cat((d,e,f),dim=2) 143 | theta = torch.cat((theta_row1, theta_row2),dim=1).cuda() 144 | size = torch.Size((bs,ch,pw,ph)) 145 | grid = torch.nn.functional.affine_grid(theta, size) 146 | patch = torch.nn.functional.grid_sample(F,grid) 147 | return patch 148 | 149 | def center2bbox(center, patch_size, h, w): 150 | b = center.size(0) 151 | if(isinstance(patch_size,int)): 152 | new_l = center[:,0] - patch_size/2 153 | else: 154 | new_l = center[:,0] - patch_size[1]/2 155 | new_l[new_l < 0] = 0 156 | new_l = new_l.view(b,1) 157 | 158 | if(isinstance(patch_size,int)): 159 | new_r = new_l + patch_size 160 | else: 161 | new_r = new_l + patch_size[1] 162 | new_r[new_r > w] = w 163 | 164 | if(isinstance(patch_size,int)): 165 | new_t = center[:,1] - patch_size/2 166 | else: 167 | new_t = center[:,1] - patch_size[0]/2 168 | new_t[new_t < 0] = 0 169 | new_t = new_t.view(b,1) 170 | 171 | if(isinstance(patch_size,int)): 172 | new_b = new_t + patch_size 173 | else: 174 | new_b = new_t + patch_size[0] 175 
| new_b[new_b > h] = h 176 | 177 | new_center = torch.cat((new_l,new_r,new_t,new_b),dim=1) 178 | return new_center 179 | -------------------------------------------------------------------------------- /libs/vis_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | #import seaborn as sns 5 | from .model import transform 6 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 7 | import torchvision.utils as vutils 8 | 9 | UNKNOWN_FLOW_THRESH = 1e7 10 | SMALLFLOW = 0.0 11 | LARGEFLOW = 1e8 12 | 13 | def prepare_img(img): 14 | 15 | if img.ndim == 3: 16 | img = img[:, :, ::-1] ### RGB to BGR 17 | 18 | ## clip to [0, 1] 19 | img = np.clip(img, 0, 1) 20 | 21 | ## quantize to [0, 255] 22 | img = np.uint8(img * 255.0) 23 | 24 | #cv2.imwrite(filename, img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 25 | return img 26 | 27 | def flow_to_rgb(flow, mr = None): 28 | """ 29 | Convert flow into middlebury color code image 30 | :param flow: optical flow map 31 | :return: optical flow image in middlebury color 32 | """ 33 | u = flow[:, :, 0] 34 | v = flow[:, :, 1] 35 | 36 | maxu = -999. 37 | maxv = -999. 38 | minu = 999. 39 | minv = 999. 40 | 41 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 42 | u[idxUnknow] = 0 43 | v[idxUnknow] = 0 44 | 45 | maxu = max(maxu, np.max(u)) 46 | minu = min(minu, np.min(u)) 47 | 48 | maxv = max(maxv, np.max(v)) 49 | minv = min(minv, np.min(v)) 50 | 51 | rad = np.sqrt(u ** 2 + v ** 2) 52 | maxrad = max(-1, np.max(rad)) 53 | if(mr is not None): 54 | maxrad = mr 55 | 56 | #print "max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv) 57 | 58 | u = u/(maxrad + np.finfo(float).eps) 59 | v = v/(maxrad + np.finfo(float).eps) 60 | 61 | img = compute_color(u, v) 62 | 63 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 64 | img[idx] = 0 65 | 66 | return np.float32(img) / 255.0, maxrad 67 | 68 | 69 | def compute_color(u, v): 70 | """ 71 | compute optical flow color map 72 | :param u: optical flow horizontal map 73 | :param v: optical flow vertical map 74 | :return: optical flow in color code 75 | """ 76 | [h, w] = u.shape 77 | img = np.zeros([h, w, 3]) 78 | nanIdx = np.isnan(u) | np.isnan(v) 79 | u[nanIdx] = 0 80 | v[nanIdx] = 0 81 | 82 | colorwheel = make_color_wheel() 83 | ncols = np.size(colorwheel, 0) 84 | 85 | rad = np.sqrt(u**2+v**2) 86 | 87 | a = np.arctan2(-v, -u) / np.pi 88 | 89 | fk = (a+1) / 2 * (ncols - 1) + 1 90 | 91 | k0 = np.floor(fk).astype(int) 92 | 93 | k1 = k0 + 1 94 | k1[k1 == ncols+1] = 1 95 | f = fk - k0 96 | 97 | for i in range(0, np.size(colorwheel,1)): 98 | tmp = colorwheel[:, i] 99 | col0 = tmp[k0-1] / 255 100 | col1 = tmp[k1-1] / 255 101 | col = (1-f) * col0 + f * col1 102 | 103 | idx = rad <= 1 104 | col[idx] = 1-rad[idx]*(1-col[idx]) 105 | notidx = np.logical_not(idx) 106 | 107 | col[notidx] *= 0.75 108 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 109 | 110 | return img 111 | 112 | 113 | def make_color_wheel(): 114 | """ 115 | Generate color wheel according Middlebury color code 116 | :return: Color wheel 117 | """ 118 | RY = 15 119 | YG = 6 120 | GC = 4 121 | CB = 11 122 | BM = 13 123 | MR = 6 124 | 125 | ncols = RY + YG + GC + CB + BM + MR 126 | 127 | colorwheel = np.zeros([ncols, 3]) 128 | 129 | col = 0 130 | 131 | # RY 132 | colorwheel[0:RY, 0] = 255 133 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 134 | col += RY 135 | 
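# --- Illustration (not part of the original repository file) ---
# A minimal usage sketch for `flow_to_rgb` above, assuming the repository root
# is on PYTHONPATH so that `libs.vis_utils` is importable. The input is an
# (H, W, 2) array of (u, v) displacements; the output is a float RGB image in
# [0, 1] plus the maximum flow radius used for normalization, which can be
# passed back in as `mr` to color several flow maps on a common scale.
import numpy as np
from libs.vis_utils import flow_to_rgb

h, w = 64, 64
ys, xs = np.mgrid[0:h, 0:w]
flow = np.stack([xs - w / 2.0, ys - h / 2.0], axis=2).astype(np.float32)

rgb, maxrad = flow_to_rgb(flow)
assert rgb.shape == (h, w, 3)
assert 0.0 <= rgb.min() and rgb.max() <= 1.0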
136 | # YG 137 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 138 | colorwheel[col:col+YG, 1] = 255 139 | col += YG 140 | 141 | # GC 142 | colorwheel[col:col+GC, 1] = 255 143 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 144 | col += GC 145 | 146 | # CB 147 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 148 | colorwheel[col:col+CB, 2] = 255 149 | col += CB 150 | 151 | # BM 152 | colorwheel[col:col+BM, 2] = 255 153 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 154 | col += + BM 155 | 156 | # MR 157 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 158 | colorwheel[col:col+MR, 0] = 255 159 | 160 | return colorwheel 161 | 162 | def aff2flow(A, F_size, GPU=True): 163 | """ 164 | INPUT: 165 | - A: a (H*W)*(H*W) affinity matrix 166 | - F_size: image feature size 167 | OUTPUT: 168 | - U: a (2*H*W) flow tensor, U_ij indicates the coordinates of pixel ij in target image. 169 | """ 170 | b,c,h,w = F_size 171 | theta = torch.tensor([[1,0,0],[0,1,0]]) 172 | theta = theta.unsqueeze(0).repeat(b,1,1) 173 | theta = theta.float() 174 | 175 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 176 | # b * (h*w) * 2 177 | grid = torch.nn.functional.affine_grid(theta, F_size) 178 | #grid = grid.view(b,h*w,2) 179 | if(GPU): 180 | grid = grid.cuda() 181 | # b * (h*w) * 2 182 | # A: 1x1024x1024 183 | grid = grid.permute(0,3,1,2) 184 | U = transform(A, grid) 185 | return (U - grid).permute(0,2,3,1) 186 | 187 | def draw_certainty_map(map, normalize=False): 188 | """ 189 | INPUTS: 190 | - map: certainty map of flow 191 | """ 192 | map = map.squeeze() 193 | # normalization 194 | if(normalize): 195 | map = (map - map.min())/(map.max() - map.min()) 196 | # draw heat map 197 | ax = sns.heatmap(map, yticklabels=False, xticklabels=False, cbar=True) 198 | else: 199 | ax = sns.heatmap(map, yticklabels=False, xticklabels=False, vmin=0.0, vmax=1.5, cbar=True) 200 | figure = ax.get_figure() 201 | width, height = figure.get_size_inches() * figure.get_dpi() 202 | 203 | canvas = FigureCanvas(figure) 204 | canvas.draw() 205 | image = np.fromstring(canvas.tostring_rgb(), dtype='uint8') 206 | image = image.reshape(int(height), int(width), 3) 207 | # crop border 208 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 209 | gray[gray == 255] = 0 210 | gray[gray > 0] = 255 211 | coords = cv2.findNonZero(gray) 212 | x,y,w,h = cv2.boundingRect(coords) 213 | #rect = image[y:y+h, x:x+w] 214 | rect = image 215 | figure.clf() 216 | return rect 217 | 218 | def draw_foreground(aff, mask, temp = 1, reverse=False): 219 | """ 220 | INPUTS: 221 | - aff: a b*N*N affinity matrix, without applying any softmax 222 | - mask: a b*h*w segmentation mask of the first frame 223 | - temp: temperature 224 | - reverse: if False, calculates where does each pixel go 225 | if True, calculates where does each pixel come from 226 | """ 227 | mask = torch.argmax(mask, dim = 0) 228 | h,w = mask.size() 229 | res = torch.zeros(h,w) 230 | if(reverse): 231 | # apply softmax to each column 232 | aff = torch.nn.functional.softmax(aff*temp, dim = 0) 233 | else: 234 | # apply softmax to each row 235 | aff = torch.nn.functional.softmax(aff*temp, dim = 1) 236 | # extract forground affinity 237 | # N 238 | mask_flat = mask.view(-1) 239 | # x 240 | fgcmask = mask_flat.nonzero().squeeze() 241 | if(reverse): 242 | # N * x 243 | Faff = aff[:,fgcmask] 244 | else: 245 | # x * N 246 | Faff 
= aff[fgcmask,:] 247 | 248 | theta = torch.tensor([[1,0,0],[0,1,0]]) 249 | theta = theta.unsqueeze(0).repeat(1,1,1) 250 | theta = theta.float() 251 | 252 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 253 | # 1 * (h*w) * 2 254 | grid = torch.nn.functional.affine_grid(theta, torch.Size((1,1,h,w))) 255 | # N*2 256 | grid = grid.squeeze() 257 | grid = grid.view(-1,2) 258 | grid = (grid + 1)/2 259 | grid[:,0] *= w 260 | grid[:,1] *= h 261 | if(reverse): 262 | grid = grid.permute(1,0) 263 | # 2 * x 264 | Fcoord = torch.mm(grid, Faff) 265 | Fcoord = Fcoord.long() 266 | res[Fcoord[1,:], Fcoord[0,:]] = 1 267 | else: 268 | # x * 2 269 | grid -= 1 270 | Fcoord = torch.mm(Faff, grid) 271 | Fcoord = Fcoord.long() 272 | res[Fcoord[:,1], Fcoord[:,0]] = 1 273 | res = torch.nn.functional.interpolate(res.unsqueeze(0).unsqueeze(0), scale_factor=8, mode='bilinear') 274 | return res 275 | 276 | def norm_mask(mask): 277 | """ 278 | INPUTS: 279 | - mask: segmentation mask 280 | """ 281 | c,h,w = mask.size() 282 | for cnt in range(c): 283 | mask_cnt = mask[cnt,:,:] 284 | if(mask_cnt.max() > 0): 285 | mask_cnt = (mask_cnt - mask_cnt.min()) 286 | mask_cnt = mask_cnt/mask_cnt.max() 287 | mask[cnt,:,:] = mask_cnt 288 | return mask 289 | 290 | def read_flo(filename): 291 | FLO_TAG = 202021.25 292 | 293 | with open(filename, 'rb') as f: 294 | tag = np.fromfile(f, np.float32, count=1) 295 | 296 | if tag != FLO_TAG: 297 | sys.exit('Wrong tag. Invalid .flo file %s' %filename) 298 | else: 299 | w = int(np.fromfile(f, np.int32, count=1)) 300 | h = int(np.fromfile(f, np.int32, count=1)) 301 | #print 'Reading %d x %d flo file' % (w, h) 302 | 303 | data = np.fromfile(f, np.float32, count=2*w*h) 304 | 305 | # Reshape data into 3D array (columns, rows, bands) 306 | flow = np.resize(data, (h, w, 2)) 307 | 308 | return flow 309 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | from libs.net_utils import NLM, NLM_dot, NLM_woSoft 5 | from torchvision.models import resnet18 6 | from autoencoder import encoder3, decoder3, encoder_res18, encoder_res50 7 | from torch.utils.serialization import load_lua 8 | from libs.utils import * 9 | 10 | def transform(aff, frame1): 11 | """ 12 | Given aff, copy from frame1 to construct frame2. 13 | INPUTS: 14 | - aff: (h*w)*(h*w) affinity matrix 15 | - frame1: n*c*h*w feature map 16 | """ 17 | b,c,h,w = frame1.size() 18 | frame1 = frame1.view(b,c,-1) 19 | frame2 = torch.bmm(frame1, aff) 20 | return frame2.view(b,c,h,w) 21 | 22 | class normalize(nn.Module): 23 | """Given mean: (R, G, B) and std: (R, G, B), 24 | will normalize each channel of the torch.*Tensor, i.e. 
25 | channel = (channel - mean) / std 26 | """ 27 | 28 | def __init__(self, mean, std = (1.0,1.0,1.0)): 29 | super(normalize, self).__init__() 30 | self.mean = nn.Parameter(torch.FloatTensor(mean).cuda(), requires_grad=False) 31 | self.std = nn.Parameter(torch.FloatTensor(std).cuda(), requires_grad=False) 32 | 33 | def forward(self, frames): 34 | b,c,h,w = frames.size() 35 | frames = (frames - self.mean.view(1,3,1,1).repeat(b,1,h,w))/self.std.view(1,3,1,1).repeat(b,1,h,w) 36 | return frames 37 | 38 | def create_flat_grid(F_size, GPU=True): 39 | """ 40 | INPUTS: 41 | - F_size: feature size 42 | OUTPUT: 43 | - return a standard grid coordinate 44 | """ 45 | b, c, h, w = F_size 46 | theta = torch.tensor([[1,0,0],[0,1,0]]) 47 | theta = theta.unsqueeze(0).repeat(b,1,1) 48 | theta = theta.float() 49 | 50 | # grid is a uniform grid with left top (-1,1) and right bottom (1,1) 51 | # b * (h*w) * 2 52 | grid = torch.nn.functional.affine_grid(theta, F_size) 53 | grid[:,:,:,0] = (grid[:,:,:,0]+1)/2 * w 54 | grid[:,:,:,1] = (grid[:,:,:,1]+1)/2 * h 55 | grid_flat = grid.view(b,-1,2) 56 | if(GPU): 57 | grid_flat = grid_flat.cuda() 58 | return grid_flat 59 | 60 | 61 | def coords2bbox(coords, patch_size, h_tar, w_tar): 62 | """ 63 | INPUTS: 64 | - coords: coordinates of pixels in the next frame 65 | - patch_size: patch size 66 | - h_tar: target image height 67 | - w_tar: target image widthg 68 | """ 69 | b = coords.size(0) 70 | center = torch.mean(coords, dim=1) # b * 2 71 | center_repeat = center.unsqueeze(1).repeat(1,coords.size(1),1) 72 | dis_x = torch.sqrt(torch.pow(coords[:,:,0] - center_repeat[:,:,0], 2)) 73 | dis_x = torch.mean(dis_x, dim=1).detach() 74 | dis_y = torch.sqrt(torch.pow(coords[:,:,1] - center_repeat[:,:,1], 2)) 75 | dis_y = torch.mean(dis_y, dim=1).detach() 76 | left = (center[:,0] - dis_x*2).view(b,1) 77 | left[left < 0] = 0 78 | right = (center[:,0] + dis_x*2).view(b,1) 79 | right[right > w_tar] = w_tar 80 | top = (center[:,1] - dis_y*2).view(b,1) 81 | top[top < 0] = 0 82 | bottom = (center[:,1] + dis_y*2).view(b,1) 83 | bottom[bottom > h_tar] = h_tar 84 | new_center = torch.cat((left,right,top,bottom),dim=1) 85 | return new_center 86 | 87 | 88 | 89 | class track_match_comb(nn.Module): 90 | def __init__(self, pretrained, encoder_dir = None, decoder_dir = None, temp=1, Resnet = "r18", color_switch=True, coord_switch=True): 91 | super(track_match_comb, self).__init__() 92 | 93 | if Resnet in "r18": 94 | self.gray_encoder = encoder_res18(pretrained=pretrained, uselayer=4) 95 | elif Resnet in "r50": 96 | self.gray_encoder = encoder_res50(pretrained=pretrained, uselayer=4) 97 | self.rgb_encoder = encoder3(reduce=True) 98 | self.decoder = decoder3(reduce=True) 99 | 100 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir)) 101 | self.decoder.load_state_dict(torch.load(decoder_dir)) 102 | for param in self.decoder.parameters(): 103 | param.requires_grad = False 104 | for param in self.rgb_encoder.parameters(): 105 | param.requires_grad = False 106 | 107 | self.nlm = NLM_woSoft() 108 | self.normalize = normalize(mean=[0.485, 0.456, 0.406], 109 | std=[0.229, 0.224, 0.225]) 110 | self.softmax = nn.Softmax(dim=1) 111 | self.temp = temp 112 | self.grid_flat = None 113 | self.grid_flat_crop = None 114 | self.color_switch = color_switch 115 | self.coord_switch = coord_switch 116 | 117 | 118 | def forward(self, img_ref, img_tar, warm_up=True, patch_size=None): 119 | n, c, h_ref, w_ref = img_ref.size() 120 | n, c, h_tar, w_tar = img_tar.size() 121 | gray_ref = 
copy.deepcopy(img_ref[:,0].view(n,1,h_ref,w_ref).repeat(1,3,1,1)) 122 | gray_tar = copy.deepcopy(img_tar[:,0].view(n,1,h_tar,w_tar).repeat(1,3,1,1)) 123 | 124 | gray_ref = (gray_ref + 1) / 2 125 | gray_tar = (gray_tar + 1) / 2 126 | 127 | gray_ref = self.normalize(gray_ref) 128 | gray_tar = self.normalize(gray_tar) 129 | 130 | Fgray1 = self.gray_encoder(gray_ref) 131 | Fgray2 = self.gray_encoder(gray_tar) 132 | Fcolor1 = self.rgb_encoder(img_ref) 133 | 134 | output = [] 135 | 136 | if warm_up: 137 | aff = self.nlm(Fgray1, Fgray2) 138 | aff_norm = self.softmax(aff) 139 | Fcolor2_est = transform(aff_norm, Fcolor1) 140 | color2_est = self.decoder(Fcolor2_est) 141 | 142 | output.append(color2_est) 143 | output.append(aff) 144 | 145 | if self.color_switch: 146 | Fcolor2 = self.rgb_encoder(img_tar) 147 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2) 148 | color1_est = self.decoder(Fcolor1_est) 149 | output.append(color1_est) 150 | else: 151 | if(self.grid_flat is None): 152 | self.grid_flat = create_flat_grid(Fgray2.size()) 153 | aff_ref_tar = self.nlm(Fgray1, Fgray2) 154 | aff_ref_tar = torch.nn.functional.softmax(aff_ref_tar * self.temp, dim = 2) 155 | coords = torch.bmm(aff_ref_tar, self.grid_flat) 156 | center = torch.mean(coords, dim=1) # b * 2 157 | # new_c = center2bbox(center, patch_size, h_tar, w_tar) 158 | new_c = center2bbox(center, patch_size, Fgray2.size(2), Fgray2.size(3)) 159 | # print("center2bbox:", new_c, h_tar, w_tar) 160 | 161 | Fgray2_crop = diff_crop(Fgray2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0]) 162 | # print("HERE: ", Fgray2.size(), Fgray1.size(), Fgray2_crop.size()) 163 | 164 | aff_p = self.nlm(Fgray1, Fgray2_crop) 165 | aff_norm = self.softmax(aff_p * self.temp) 166 | Fcolor2_est = transform(aff_norm, Fcolor1) 167 | color2_est = self.decoder(Fcolor2_est) 168 | 169 | Fcolor2_full = self.rgb_encoder(img_tar) 170 | Fcolor2_crop = diff_crop(Fcolor2_full, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0]) 171 | 172 | output.append(color2_est) 173 | output.append(aff_p) 174 | output.append(new_c*8) 175 | output.append(coords) 176 | 177 | # color orthorganal 178 | if self.color_switch: 179 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2_crop) 180 | color1_est = self.decoder(Fcolor1_est) 181 | output.append(color1_est) 182 | 183 | # coord orthorganal 184 | if self.coord_switch: 185 | aff_norm_tran = self.softmax(aff_p.permute(0,2,1)*self.temp) 186 | if self.grid_flat_crop is None: 187 | self.grid_flat_crop = create_flat_grid(Fp_tar.size()).permute(0,2,1).detach() 188 | C12 = torch.bmm(self.grid_flat_crop, aff_norm) 189 | C11 = torch.bmm(C12, aff_norm_tran) 190 | output.append(self.grid_flat_crop) 191 | output.append(C11) 192 | 193 | return output 194 | 195 | 196 | class Model_switchGTfixdot_swCC_Res(nn.Module): 197 | def __init__(self, encoder_dir = None, decoder_dir = None, fix_dec = True, 198 | temp = None, pretrainRes = False, uselayer=3, model='resnet18'): 199 | ''' 200 | For switchable concenration loss 201 | Using Resnet18 202 | ''' 203 | super(Model_switchGTfixdot_swCC_Res, self).__init__() 204 | if(model == 'resnet18'): 205 | print('Use ResNet18.') 206 | self.gray_encoder = encoder_res18(pretrained = pretrainRes, uselayer=uselayer) 207 | else: 208 | print('Use ResNet50.') 209 | self.gray_encoder = encoder_res50(pretrained = pretrainRes, uselayer=uselayer) 210 | self.rgb_encoder = encoder3(reduce = True) 211 | self.nlm = NLM_woSoft() 212 | self.decoder = decoder3(reduce = True) 213 | 
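# --- Illustration (not part of the original repository file) ---
# A minimal sketch of the `transform` operation defined at the top of this
# file: once the affinity is softmax-normalized over dim=1 (as done by the
# models here), every column of `aff` is a convex combination over reference
# pixels, so a batched matrix product copies reference features into the
# corresponding target positions.
import torch

b, c, h, w = 2, 4, 8, 8
n = h * w
feat_ref = torch.randn(b, c, h, w)
aff = torch.softmax(torch.randn(b, n, n), dim=1)       # columns sum to 1

feat_tar_est = torch.bmm(feat_ref.view(b, c, -1), aff).view(b, c, h, w)
assert feat_tar_est.shape == (b, c, h, w)
assert torch.allclose(aff.sum(dim=1), torch.ones(b, n))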
self.temp = temp 214 | self.softmax = nn.Softmax(dim=1) 215 | self.cos_window = torch.Tensor(np.outer(np.hanning(40), np.hanning(40))).cuda() 216 | self.normalize = normalize(mean=[0.485, 0.456, 0.406], 217 | std=[0.229, 0.224, 0.225]) 218 | 219 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir)) 220 | self.decoder.load_state_dict(torch.load(decoder_dir)) 221 | 222 | for param in self.decoder.parameters(): 223 | param.requires_grad = False 224 | for param in self.rgb_encoder.parameters(): 225 | param.requires_grad = False 226 | 227 | def forward(self, gray1, gray2, color1=None, color2=None): 228 | # move gray scale image to 0-1 so that they match ImageNet pre-training 229 | gray1 = (gray1 + 1) / 2 230 | gray2 = (gray2 + 1) / 2 231 | 232 | # normalize to fit resnet 233 | b = gray1.size(0) 234 | 235 | gray1 = self.normalize(gray1) 236 | gray2 = self.normalize(gray2) 237 | 238 | Fgray1 = self.gray_encoder(gray1) 239 | Fgray2 = self.gray_encoder(gray2) 240 | 241 | aff = self.nlm(Fgray1, Fgray2) # bx4096x4096 242 | aff_norm = self.softmax(aff*self.temp) 243 | 244 | if(color1 is None): 245 | return aff_norm, Fgray1, Fgray2 246 | 247 | Fcolor1 = self.rgb_encoder(color1) 248 | Fcolor2 = self.rgb_encoder(color2) 249 | Fcolor2_est = transform(aff_norm, Fcolor1) 250 | pred2 = self.decoder(Fcolor2_est) 251 | 252 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2) 253 | pred1 = self.decoder(Fcolor1_est) 254 | 255 | return pred1, pred2, aff_norm, aff, Fgray1, Fgray2 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # OS libraries 2 | import os 3 | import copy 4 | import queue 5 | import argparse 6 | import scipy.misc 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | # Pytorch 11 | import torch 12 | import torch.nn as nn 13 | 14 | # Customized libraries 15 | from libs.test_utils import * 16 | from libs.model import transform 17 | from libs.vis_utils import norm_mask 18 | from libs.model import Model_switchGTfixdot_swCC_Res as Model 19 | 20 | ############################## helper functions ############################## 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--batch_size", type = int, default = 1, 24 | help = "batch size") 25 | parser.add_argument("-o","--out_dir",type = str,default = "results/", 26 | help = "output saving path") 27 | parser.add_argument("--device", type = int, default = 5, 28 | help = "0~4 for single GPU, 5 for dataparallel.") 29 | parser.add_argument("-c","--checkpoint_dir",type = str, 30 | default = "weights/checkpoint_latest.pth.tar", 31 | help = "checkpoints path") 32 | parser.add_argument("-s", "--scale_size", type = int, nargs = "+", 33 | help = "scale size, a single number for shorter edge, or a pair for height and width") 34 | parser.add_argument("--pre_num", type = int, default = 7, 35 | help = "preceding frame numbers") 36 | parser.add_argument("--temp", type = float,default = 1, 37 | help = "softmax temperature") 38 | parser.add_argument("--topk", type = int, default = 5, 39 | help = "accumulate label from top k neighbors") 40 | parser.add_argument("-d", "--davis_dir", type = str, 41 | default = "/workspace/DAVIS/", 42 | help = "davis dataset path") 43 | 44 | args = parser.parse_args() 45 | args.is_train = False 46 | 47 | args.multiGPU = args.device == 5 48 | if not args.multiGPU: 49 | torch.cuda.set_device(args.device) 50 | 51 | args.val_txt = os.path.join(args.davis_dir, 
"ImageSets/2017/val.txt") 52 | args.davis_dir = os.path.join(args.davis_dir, "JPEGImages/480p/") 53 | 54 | return args 55 | 56 | ############################## testing functions ############################## 57 | 58 | def forward(frame1, frame2, model, seg): 59 | """ 60 | propagate seg of frame1 to frame2 61 | """ 62 | n, c, h, w = frame1.size() 63 | frame1_gray = frame1[:,0].view(n,1,h,w) 64 | frame2_gray = frame2[:,0].view(n,1,h,w) 65 | frame1_gray = frame1_gray.repeat(1,3,1,1) 66 | frame2_gray = frame2_gray.repeat(1,3,1,1) 67 | 68 | output = model(frame1_gray, frame2_gray, frame1, frame2) 69 | aff = output[2] 70 | 71 | frame2_seg = transform_topk(aff,seg.cuda(),k=args.topk) 72 | 73 | return frame2_seg 74 | 75 | def test(model, frame_list, video_dir, first_seg, seg_ori): 76 | """ 77 | test on a video given first frame & segmentation 78 | """ 79 | video_dir = os.path.join(video_dir) 80 | video_nm = video_dir.split('/')[-1] 81 | video_folder = os.path.join(args.out_dir, video_nm) 82 | os.makedirs(video_folder, exist_ok = True) 83 | 84 | transforms = create_transforms() 85 | 86 | # The queue stores args.pre_num preceding frames 87 | que = queue.Queue(args.pre_num) 88 | 89 | # first frame 90 | frame1, ori_h, ori_w = read_frame(frame_list[0], transforms, args.scale_size) 91 | n, c, h, w = frame1.size() 92 | 93 | # saving first segmentation 94 | out_path = os.path.join(video_folder,"00000.png") 95 | imwrite_indexed(out_path, seg_ori) 96 | 97 | for cnt in tqdm(range(1,len(frame_list))): 98 | frame_tar, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size) 99 | 100 | with torch.no_grad(): 101 | # frame 1 -> frame cnt 102 | frame_tar_acc = forward(frame1, frame_tar, model, first_seg) 103 | 104 | # frame cnt - i -> frame cnt, (i = 1, ..., pre_num) 105 | tmp_queue = list(que.queue) 106 | for pair in tmp_queue: 107 | framei = pair[0] 108 | segi = pair[1] 109 | frame_tar_est_i = forward(framei, frame_tar, model, segi) 110 | frame_tar_acc += frame_tar_est_i 111 | frame_tar_avg = frame_tar_acc / (1 + len(tmp_queue)) 112 | 113 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg",".png") 114 | out_path = os.path.join(video_folder,frame_nm) 115 | 116 | # pop out oldest frame if neccessary 117 | if(que.qsize() == args.pre_num): 118 | que.get() 119 | # push current results into queue 120 | seg = copy.deepcopy(frame_tar_avg) 121 | frame, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size) 122 | que.put([frame,seg]) 123 | 124 | # upsampling & argmax 125 | frame_tar_avg = torch.nn.functional.interpolate(frame_tar_avg,scale_factor=8,mode='bilinear') 126 | frame_tar_avg = frame_tar_avg.squeeze() 127 | frame_tar_avg = norm_mask(frame_tar_avg.squeeze()) 128 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0) 129 | 130 | # saving to disk 131 | frame_tar_seg = frame_tar_seg.squeeze().cpu().numpy() 132 | frame_tar_seg = np.array(frame_tar_seg, dtype=np.uint8) 133 | frame_tar_seg = scipy.misc.imresize(frame_tar_seg, (ori_h, ori_w), "nearest") 134 | 135 | output_path = os.path.join(video_folder, frame_nm.split('.')[0]+'_seg.png') 136 | imwrite_indexed(out_path,frame_tar_seg) 137 | 138 | ############################## main function ############################## 139 | 140 | if(__name__ == '__main__'): 141 | args = parse_args() 142 | with open(args.val_txt) as f: 143 | lines = f.readlines() 144 | f.close() 145 | 146 | # loading pretrained model 147 | model = Model(pretrainRes=False, temp = args.temp, uselayer=4) 148 | if(args.multiGPU): 149 | model = nn.DataParallel(model) 
150 | checkpoint = torch.load(args.checkpoint_dir) 151 | best_loss = checkpoint['best_loss'] 152 | model.load_state_dict(checkpoint['state_dict']) 153 | print("=> loaded checkpoint '{} ({})' (epoch {})" 154 | .format(args.checkpoint_dir, best_loss, checkpoint['epoch'])) 155 | model.cuda() 156 | model.eval() 157 | 158 | 159 | # start testing 160 | for cnt,line in enumerate(lines): 161 | video_nm = line.strip() 162 | print('[{:n}/{:n}] Begin to segmentate video {}.'.format(cnt,len(lines),video_nm)) 163 | 164 | video_dir = os.path.join(args.davis_dir, video_nm) 165 | frame_list = read_frame_list(video_dir) 166 | seg_dir = frame_list[0].replace("JPEGImages","Annotations") 167 | seg_dir = seg_dir.replace("jpg","png") 168 | _, first_seg, seg_ori = read_seg(seg_dir, args.scale_size) 169 | test(model, frame_list, video_dir, first_seg, seg_ori) 170 | -------------------------------------------------------------------------------- /test_with_track.py: -------------------------------------------------------------------------------- 1 | # OS libraries 2 | import os 3 | import cv2 4 | import glob 5 | import copy 6 | import math 7 | import queue 8 | import argparse 9 | import scipy.misc 10 | import numpy as np 11 | from tqdm import tqdm 12 | from PIL import Image 13 | 14 | # Pytorch libraries 15 | import torch 16 | import torch.nn as nn 17 | 18 | # Customized libraries 19 | from libs.test_utils import * 20 | from libs.model import transform 21 | from libs.vis_utils import norm_mask 22 | import libs.transforms_pair as transforms 23 | from libs.model import Model_switchGTfixdot_swCC_Res as Model 24 | from libs.track_utils import seg2bbox, draw_bbox, match_ref_tar 25 | from libs.track_utils import squeeze_all, seg2bbox_v2, bbox_in_tar_scale 26 | 27 | ############################## helper functions ############################## 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--batch_size", type = int, default = 1, 32 | help = "batch size") 33 | parser.add_argument("-o","--out_dir", type = str,default="results_with_track/", 34 | help = "output path") 35 | parser.add_argument("--device", type = int, default = 5, 36 | help="0~4 for single GPU, 5 for dataparallel.") 37 | parser.add_argument("-c","--checkpoint_dir",type = str, 38 | default = "weights/checkpoint_latest.pth.tar", 39 | help = "checkpoints path") 40 | parser.add_argument("-s", "--scale_size", type = int, nargs = '+', 41 | help = "scale size, either a single number for short edge, or a pair for height and width") 42 | parser.add_argument("--pre_num", type = int, default = 7, 43 | help = "preceding frame numbers") 44 | parser.add_argument("--temp", type = float, default = 1, 45 | help = "softmax temperature") 46 | parser.add_argument("-t", "--topk", type = int, default = 5, 47 | help = "accumulate label from top k neighbors") 48 | parser.add_argument("-d", "--davis_dir", type = str, 49 | default = "/workspace/DAVIS/", 50 | help = "davis dataset path") 51 | 52 | print("Begin parser arguments.") 53 | args = parser.parse_args() 54 | args.is_train = False 55 | 56 | args.multiGPU = args.device == 5 57 | if not args.multiGPU: 58 | torch.cuda.set_device(args.device) 59 | 60 | args.val_txt = os.path.join(args.davis_dir, "ImageSets/2017/val.txt") 61 | args.davis_dir = os.path.join(args.davis_dir, "JPEGImages/480p/") 62 | return args 63 | 64 | def vis_bbox(im, bbox, name, coords, seg): 65 | im = im * 128 + 128 66 | im = im.squeeze().permute(1,2,0).cpu().numpy().astype(np.uint8) 67 | im = cv2.cvtColor(im, 
cv2.COLOR_LAB2BGR) 68 | fg_idx = seg.nonzero() 69 | im = draw_bbox(im, bbox, (0,0,255)) 70 | 71 | for cnt in range(coords.size(0)): 72 | coord_i = coords[cnt] 73 | 74 | cv2.circle(im, (int(coord_i[0]*8), int(coord_i[1]*8)), 2, (0,255,0), thickness=-1) 75 | cv2.imwrite(name, im) 76 | 77 | ############################## tracking functions ############################## 78 | 79 | def adjust_bbox(bbox_now, bbox_pre, a, h, w): 80 | """ 81 | Adjust a bounding box w.r.t previous frame, 82 | assuming objects don't go under abrupt changes. 83 | """ 84 | for cnt in bbox_pre.keys(): 85 | if(cnt == 0): 86 | continue 87 | if(cnt in bbox_now and bbox_pre[cnt] is not None and bbox_now[cnt] is not None): 88 | bbox_now_h = (bbox_now[cnt].top + bbox_now[cnt].bottom) / 2.0 89 | bbox_now_w = (bbox_now[cnt].left + bbox_now[cnt].right) / 2.0 90 | 91 | bbox_now_height_ = bbox_now[cnt].bottom - bbox_now[cnt].top 92 | bbox_now_width_ = bbox_now[cnt].right - bbox_now[cnt].left 93 | 94 | bbox_pre_height = bbox_pre[cnt].bottom - bbox_pre[cnt].top 95 | bbox_pre_width = bbox_pre[cnt].right - bbox_pre[cnt].left 96 | 97 | bbox_now_height = a * bbox_now_height_ + (1 - a) * bbox_pre_height 98 | bbox_now_width = a * bbox_now_width_ + (1 - a) * bbox_pre_width 99 | 100 | bbox_now[cnt].left = math.floor(bbox_now_w - bbox_now_width / 2.0) 101 | bbox_now[cnt].right = math.ceil(bbox_now_w + bbox_now_width / 2.0) 102 | bbox_now[cnt].top = math.floor(bbox_now_h - bbox_now_height / 2.0) 103 | bbox_now[cnt].bottom = math.ceil(bbox_now_h + bbox_now_height / 2.0) 104 | 105 | bbox_now[cnt].left = max(0, bbox_now[cnt].left) 106 | bbox_now[cnt].right = min(w, bbox_now[cnt].right) 107 | bbox_now[cnt].top = max(0, bbox_now[cnt].top) 108 | bbox_now[cnt].bottom = min(h, bbox_now[cnt].bottom) 109 | 110 | return bbox_now 111 | 112 | def bbox_next_frame(img_ref, seg_ref, img_tar, bbox_ref): 113 | """ 114 | Match bbox from the reference frame to the target frame 115 | """ 116 | F_ref, F_tar = forward(img_ref, img_tar, model, seg_ref, return_feature=True) 117 | seg_ref = seg_ref.squeeze(0) 118 | F_ref, F_tar = squeeze_all(F_ref, F_tar) 119 | c, h, w = F_ref.size() 120 | 121 | # get coordinates of each point in the target frame 122 | coords_ref_tar = match_ref_tar(F_ref, F_tar, seg_ref, args.temp) 123 | # coordinates -> bbox 124 | bbox_tar = bbox_in_tar_scale(coords_ref_tar, bbox_ref, h, w) 125 | # adjust bbox 126 | bbox_tar = adjust_bbox(bbox_tar, bbox_ref, 0.1, h, w) 127 | return bbox_tar, coords_ref_tar 128 | 129 | def recoginition(img_ref, img_tar, bbox_ref, bbox_tar, seg_ref, model): 130 | """ 131 | propagate from bbox in the reference frame to bbox in the target frame 132 | """ 133 | F_ref, F_tar = forward(img_ref, img_tar, model, seg_ref, return_feature=True) 134 | seg_ref = seg_ref.squeeze() 135 | _, c, h, w = F_tar.size() 136 | seg_pred = torch.zeros(seg_ref.size()) 137 | 138 | # calculate affinity only once to save time 139 | aff_whole = torch.mm(F_ref.view(c,-1).permute(1,0), F_tar.view(c,-1)) 140 | aff_whole = torch.nn.functional.softmax(aff_whole * args.temp, dim=0) 141 | 142 | for cnt, br in bbox_ref.items(): 143 | if not (cnt in bbox_tar): 144 | continue 145 | bt = bbox_tar[cnt] 146 | if(br is None or bt is None): 147 | continue 148 | seg_cnt = seg_ref[cnt] 149 | 150 | # affinity between two patches 151 | seg_ref_box = seg_cnt[br.top:br.bottom, br.left:br.right] 152 | seg_ref_box = seg_ref_box.unsqueeze(0).unsqueeze(0) 153 | 154 | h, w = F_ref.size(2), F_ref.size(3) 155 | mask = torch.zeros(h,w) 156 | mask[br.top:br.bottom, 
br.left:br.right] = 1 157 | mask = mask.view(-1) 158 | aff_row = aff_whole[mask.nonzero().squeeze(), :] 159 | 160 | h, w = F_tar.size(2), F_tar.size(3) 161 | mask = torch.zeros(h,w) 162 | mask[bt.top:bt.bottom, bt.left:bt.right] = 1 163 | mask = mask.view(-1) 164 | aff = aff_row[:, mask.nonzero().squeeze()] 165 | 166 | aff = aff.unsqueeze(0) 167 | 168 | seg_tar_box = transform_topk(aff,seg_ref_box.cuda(),k=args.topk, 169 | h2 = bt.bottom - bt.top,w2 = bt.right - bt.left) 170 | seg_pred[cnt, bt.top:bt.bottom, bt.left:bt.right] = seg_tar_box 171 | 172 | return seg_pred 173 | 174 | def disappear(seg,bbox_ref,bbox_tar=None): 175 | """ 176 | Check if bbox disappear in the target frame. 177 | """ 178 | b,c,h,w = seg.size() 179 | for cnt in range(c): 180 | if(torch.sum(seg[:,cnt,:,:]) < 3 or (not (cnt in bbox_ref))): 181 | return True 182 | if(bbox_ref[cnt] is None): 183 | return True 184 | if(bbox_ref[cnt].right - bbox_ref[cnt].left < 3 or bbox_ref[cnt].bottom - bbox_ref[cnt].top < 3): 185 | return True 186 | 187 | if(bbox_tar is not None): 188 | if(cnt not in bbox_tar.keys()): 189 | return True 190 | if(bbox_tar[cnt] is None): 191 | return True 192 | if(bbox_tar[cnt].right - bbox_tar[cnt].left < 3 or bbox_tar[cnt].bottom - bbox_tar[cnt].top < 3): 193 | return True 194 | return False 195 | 196 | ############################## testing functions ############################## 197 | 198 | def forward(frame1, frame2, model, seg, return_feature=False): 199 | n, c, h, w = frame1.size() 200 | frame1_gray = frame1[:,0].view(n,1,h,w) 201 | frame2_gray = frame2[:,0].view(n,1,h,w) 202 | frame1_gray = frame1_gray.repeat(1,3,1,1) 203 | frame2_gray = frame2_gray.repeat(1,3,1,1) 204 | 205 | output = model(frame1_gray, frame2_gray, frame1, frame2) 206 | if(return_feature): 207 | return output[-2], output[-1] 208 | 209 | aff = output[2] 210 | frame2_seg = transform_topk(aff,seg.cuda(),k=args.topk) 211 | 212 | return frame2_seg 213 | 214 | def test(model, frame_list, video_dir, first_seg, large_seg, first_bbox, seg_ori): 215 | video_dir = os.path.join(video_dir) 216 | video_nm = video_dir.split('/')[-1] 217 | video_folder = os.path.join(args.out_dir, video_nm) 218 | os.makedirs(video_folder, exist_ok = True) 219 | os.makedirs(os.path.join(video_folder, 'track'), exist_ok = True) 220 | 221 | transforms = create_transforms() 222 | 223 | # The queue stores `pre_num` preceding frames 224 | que = queue.Queue(args.pre_num) 225 | 226 | # frame 1 227 | frame1, ori_h, ori_w = read_frame(frame_list[0], transforms, args.scale_size) 228 | n, c, h, w = frame1.size() 229 | 230 | # saving first segmentation 231 | out_path = os.path.join(video_folder,"00000.png") 232 | imwrite_indexed(out_path, seg_ori) 233 | 234 | coords = first_seg[0,1].nonzero() 235 | coords = coords.flip(1) 236 | 237 | for cnt in tqdm(range(1,len(frame_list))): 238 | frame_tar, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size) 239 | 240 | with torch.no_grad(): 241 | 242 | tmp_list = list(que.queue) 243 | if(len(tmp_list) > 0): 244 | pair = tmp_list[-1] 245 | framei = pair[0] 246 | segi = pair[1] 247 | bbox_pre = pair[2] 248 | else: 249 | bbox_pre = first_bbox 250 | framei = frame1 251 | segi = first_seg 252 | _, segi_int = torch.max(segi, dim=1) 253 | segi = to_one_hot(segi_int) 254 | bbox_tar, coords_ref_tar = bbox_next_frame(framei, segi, frame_tar, bbox_pre) 255 | 256 | if(bbox_tar is not None): 257 | 258 | if(1 in bbox_tar): 259 | tmp = copy.deepcopy(bbox_tar[1]) 260 | if(tmp is not None): 261 | tmp.upscale(8) 262 | 
vis_bbox(frame_tar, tmp, os.path.join(video_folder, 'track', 'frame'+str(cnt+1)+'.png'), coords_ref_tar[1], segi[0,1,:,:]) 263 | frame_tar_acc = recoginition(frame1, frame_tar, first_bbox, bbox_tar, first_seg, model) 264 | else: 265 | frame_tar_acc = forward(frame1, frame_tar, model, first_seg) 266 | frame_tar_acc = frame_tar_acc.cpu() 267 | 268 | 269 | # previous 7 frames 270 | tmp_queue = list(que.queue) 271 | for pair in tmp_queue: 272 | framei = pair[0] 273 | segi = pair[1] 274 | bboxi = pair[2] 275 | 276 | if(bbox_tar is None or disappear(segi, bboxi, bbox_tar)): 277 | frame_tar_est_i = forward(framei, frame_tar, model, segi) 278 | frame_tar_est_i = frame_tar_est_i.cpu() 279 | else: 280 | 281 | frame_tar_est_i = recoginition(framei, frame_tar, bboxi, bbox_tar, segi, model) 282 | 283 | frame_tar_acc += frame_tar_est_i.cpu().view(frame_tar_acc.size()) 284 | frame_tar_avg = frame_tar_acc / (1 + len(tmp_queue)) 285 | 286 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg",".png") 287 | out_path = os.path.join(video_folder,frame_nm) 288 | 289 | # upsampling & argmax 290 | if(frame_tar_avg.dim() == 3): 291 | frame_tar_avg = frame_tar_avg.unsqueeze(0) 292 | elif(frame_tar_avg.dim() == 2): 293 | frame_tar_avg = frame_tar_avg.unsqueeze(0).unsqueeze(0) 294 | frame_tar_up = torch.nn.functional.interpolate(frame_tar_avg,scale_factor=8,mode='bilinear') 295 | 296 | frame_tar_up = frame_tar_up.squeeze() 297 | frame_tar_up = norm_mask(frame_tar_up.squeeze()) 298 | _, frame_tar_seg = torch.max(frame_tar_up.squeeze(), dim=0) 299 | 300 | frame_tar_seg = frame_tar_seg.squeeze().cpu().numpy() 301 | frame_tar_seg = np.array(frame_tar_seg, dtype=np.uint8) 302 | frame_tar_seg = scipy.misc.imresize(frame_tar_seg, (ori_h, ori_w), "nearest") 303 | imwrite_indexed(out_path,frame_tar_seg) 304 | 305 | if(que.qsize() == args.pre_num): 306 | que.get() 307 | seg = copy.deepcopy(frame_tar_avg.squeeze()) 308 | frame, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size) 309 | bbox_tar = seg2bbox_v2(frame_tar_up.cpu(), bbox_pre) 310 | bbox_tar = adjust_bbox(bbox_tar, bbox_pre, 0.1, h, w) 311 | que.put([frame,seg.unsqueeze(0),bbox_tar]) 312 | 313 | if(__name__ == '__main__'): 314 | args = parse_args() 315 | with open(args.val_txt) as f: 316 | lines = f.readlines() 317 | f.close() 318 | 319 | model = Model(pretrainRes=False, temp = args.temp, uselayer=4) 320 | if(args.multiGPU): 321 | model = nn.DataParallel(model) 322 | checkpoint = torch.load(args.checkpoint_dir) 323 | best_loss = checkpoint['best_loss'] 324 | model.load_state_dict(checkpoint['state_dict']) 325 | print("=> loaded checkpoint '{} ({})' (epoch {})" 326 | .format(args.checkpoint_dir, best_loss, checkpoint['epoch'])) 327 | model.cuda() 328 | model.eval() 329 | 330 | for cnt,line in enumerate(lines): 331 | video_nm = line.strip() 332 | print('[{:n}/{:n}] Begin to segmentate video {}.'.format(cnt,len(lines),video_nm)) 333 | 334 | video_dir = os.path.join(args.davis_dir, video_nm) 335 | frame_list = read_frame_list(video_dir) 336 | seg_dir = frame_list[0].replace("JPEGImages","Annotations") 337 | seg_dir = seg_dir.replace("jpg","png") 338 | large_seg, first_seg, seg_ori = read_seg(seg_dir, args.scale_size) 339 | 340 | first_bbox = seg2bbox(large_seg, margin=0.6) 341 | for k,v in first_bbox.items(): 342 | v.upscale(0.125) 343 | 344 | test(model, frame_list, video_dir, first_seg, large_seg, first_bbox, seg_ori) 345 | -------------------------------------------------------------------------------- /track_match_v1.py: 
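# --- Illustration (not part of the original repository file) ---
# A minimal sketch of the update rule in `adjust_bbox` above for a single box,
# using a plain (left, right, top, bottom) tuple instead of the repository's
# bbox objects: the new box keeps the current center while its width and height
# are blended as a * current + (1 - a) * previous (a = 0.1 in this repo), then
# clamped to the image. `smooth_box` is only an illustrative helper.
import math

def smooth_box(now, pre, a, h, w):
    l, r, t, b = now
    pl, pr, pt, pb = pre
    cx, cy = (l + r) / 2.0, (t + b) / 2.0
    bw = a * (r - l) + (1 - a) * (pr - pl)
    bh = a * (b - t) + (1 - a) * (pb - pt)
    return (max(0, math.floor(cx - bw / 2.0)), min(w, math.ceil(cx + bw / 2.0)),
            max(0, math.floor(cy - bh / 2.0)), min(h, math.ceil(cy + bh / 2.0)))

print(smooth_box((10, 30, 10, 30), (8, 40, 8, 40), a=0.1, h=64, w=64))
# -> (4, 36, 4, 36): with a = 0.1 the size stays close to the previous box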
-------------------------------------------------------------------------------- 1 | # a combination of track and match 2 | # 1. load fullres images, resize to 640**2 3 | # 2. warmup: set random location for crop 4 | # 3. loc-match: add attention 5 | import os 6 | import cv2 7 | import sys 8 | import time 9 | import torch 10 | import logging 11 | import argparse 12 | import numpy as np 13 | import torch.nn as nn 14 | from libs.loader import VidListv1, VidListv2 15 | import torch.backends.cudnn as cudnn 16 | import libs.transforms_multi as transforms 17 | from model import track_match_comb as Model 18 | from libs.loss import L1_loss 19 | from libs.concentration_loss import ConcentrationSwitchLoss as ConcentrationLoss 20 | from libs.train_utils import save_vis, AverageMeter, save_checkpoint, log_current 21 | from libs.utils import diff_crop 22 | 23 | 24 | FORMAT = "[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s" 25 | logging.basicConfig(format=FORMAT) 26 | logger = logging.getLogger(__name__) 27 | logger.setLevel(logging.DEBUG) 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='') 31 | 32 | # file/folder paths 33 | parser.add_argument("--videoRoot", type=str, default="/Data2/Kinetices/compress/train_256/", help='train video path') 34 | parser.add_argument("--videoList", type=str, default="/Data2/Kinetices/compress/train.txt", help='train video list (paths relative to videoRoot)') 35 | parser.add_argument("--encoder_dir",type=str, default='ae_small/encoder_single_gpu.pth', help="pretrained encoder") 36 | parser.add_argument("--decoder_dir",type=str, default='ae_small/decoder_single_gpu.pth', help="pretrained decoder") 37 | parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to latest checkpoint (default: none)') 38 | parser.add_argument("-c","--savedir",type=str,default="match_track_comb/",help='checkpoints path') 39 | parser.add_argument("--Resnet", type=str, default="r18", help="choose from r18 or r50") 40 | 41 | # main parameters 42 | parser.add_argument("--pretrainRes",action="store_true") 43 | parser.add_argument("--batchsize",type=int, default=1, help="batchsize") 44 | parser.add_argument('--workers', type=int, default=16) 45 | parser.add_argument("--patch_size", type=int, default=256, help="crop size for localization.") 46 | parser.add_argument("--full_size", type=int, default=640, help="full size for one frame.") 47 | parser.add_argument("--rotate",type=int,default=10,help='degree to rotate training images') 48 | parser.add_argument("--scale",type=float,default=1.2,help='random scale') 49 | parser.add_argument("--lr",type=float,default=0.0001,help='learning rate') 50 | parser.add_argument('--lr-mode', type=str, default='poly') 51 | parser.add_argument("--window_len",type=int,default=2,help='number of images (2 for pair and 3 for triple)') 52 | parser.add_argument("--log_interval",type=int,default=10,help='save visualizations every x iterations') 53 | parser.add_argument("--save_interval",type=int,default=1000,help='save a checkpoint every x iterations') 54 | parser.add_argument("--momentum",type=float,default=0.9,help='momentum') 55 | parser.add_argument("--weight_decay",type=float,default=0.005,help='weight decay') 56 | parser.add_argument("--device", type=int, default=4, help="0~device_count-1 for single GPU, device_count for dataparallel.") 57 | parser.add_argument("--temp", type=int, default=1, help="temperature for softmax.") 58 | 59 | # set epochs 60 | parser.add_argument("--wepoch",type=int,default=10,help='warmup epoch') 61 |
parser.add_argument("--nepoch",type=int,default=20,help='max epoch') 62 | 63 | # concenration regularization 64 | parser.add_argument("--lc",type=float,default=1e4, help='weight of concentration loss') 65 | parser.add_argument("--lc_win",type=int,default=8, help='win_len for concentration loss') 66 | 67 | # orthorganal regularization 68 | parser.add_argument("--color_switch",type=float,default=0.1, help='weight of color switch loss') 69 | parser.add_argument("--coord_switch",type=float,default=0.1, help='weight of color switch loss') 70 | 71 | 72 | print("Begin parser arguments.") 73 | args = parser.parse_args() 74 | assert args.videoRoot is not None 75 | assert args.videoList is not None 76 | if not os.path.exists(args.savedir): 77 | os.mkdir(args.savedir) 78 | args.savepatch = os.path.join(args.savedir,'savepatch') 79 | args.logfile = open(os.path.join(args.savedir,"logargs.txt"),"w") 80 | args.multiGPU = args.device == torch.cuda.device_count() 81 | 82 | if not args.multiGPU: 83 | torch.cuda.set_device(args.device) 84 | if not os.path.exists(args.savepatch): 85 | os.mkdir(args.savepatch) 86 | 87 | args.vis = True 88 | if args.color_switch > 0: 89 | args.color_switch_flag = True 90 | else: 91 | args.color_switch_flag = False 92 | if args.coord_switch > 0: 93 | args.coord_switch_flag = True 94 | else: 95 | args.coord_switch_flag = False 96 | 97 | try: 98 | from tensorboardX import SummaryWriter 99 | global writer 100 | writer = SummaryWriter() 101 | except ImportError: 102 | args.vis = False 103 | print(' '.join(sys.argv)) 104 | print('\n') 105 | args.logfile.write(' '.join(sys.argv)) 106 | args.logfile.write('\n') 107 | 108 | for k, v in args.__dict__.items(): 109 | print(k, ':', v) 110 | args.logfile.write('{}:{}\n'.format(k,v)) 111 | args.logfile.close() 112 | return args 113 | 114 | 115 | def adjust_learning_rate(args, optimizer, epoch): 116 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 117 | if args.lr_mode == 'step': 118 | lr = args.lr * (0.1 ** (epoch // args.step)) 119 | elif args.lr_mode == 'poly': 120 | lr = args.lr * (1 - epoch / args.nepoch) ** 0.9 121 | else: 122 | raise ValueError('Unknown lr mode {}'.format(args.lr_mode)) 123 | 124 | for param_group in optimizer.param_groups: 125 | param_group['lr'] = lr 126 | return lr 127 | 128 | 129 | def create_loader(args): 130 | dataset_train_warm = VidListv1(args.videoRoot, args.videoList, args.patch_size, args.rotate, args.scale) 131 | dataset_train = VidListv2(args.videoRoot, args.videoList, args.patch_size, args.window_len, args.rotate, args.scale, args.full_size) 132 | 133 | if args.multiGPU: 134 | train_loader_warm = torch.utils.data.DataLoader( 135 | dataset_train_warm, batch_size=args.batchsize, shuffle = True, num_workers=args.workers, pin_memory=True, drop_last=True) 136 | train_loader = torch.utils.data.DataLoader( 137 | dataset_train, batch_size=args.batchsize, shuffle = True, num_workers=args.workers, pin_memory=True, drop_last=True) 138 | else: 139 | train_loader_warm = torch.utils.data.DataLoader( 140 | dataset_train_warm, batch_size=args.batchsize, shuffle = True, num_workers=0, drop_last=True) 141 | train_loader = torch.utils.data.DataLoader( 142 | dataset_train, batch_size=args.batchsize, shuffle = True, num_workers=0, drop_last=True) 143 | return train_loader_warm, train_loader 144 | 145 | 146 | def train(args): 147 | loader_warm, loader = create_loader(args) 148 | cudnn.benchmark = True 149 | best_loss = 1e10 150 | start_epoch = 0 151 | 152 | model = Model(args.pretrainRes, 
args.encoder_dir, args.decoder_dir, temp = args.temp, Resnet = args.Resnet, color_switch = args.color_switch_flag, coord_switch = args.coord_switch_flag) 153 | 154 | if args.multiGPU: 155 | model = torch.nn.DataParallel(model).cuda() 156 | closs = ConcentrationLoss(win_len=args.lc_win, stride=args.lc_win, 157 | F_size=torch.Size((args.batchsize//torch.cuda.device_count(),2, args.patch_size//8, args.patch_size//8)), temp = args.temp) 158 | closs = nn.DataParallel(closs).cuda() 159 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model._modules['module'].parameters()),args.lr) 160 | else: 161 | closs = ConcentrationLoss(win_len=args.lc_win, stride=args.lc_win, 162 | F_size=torch.Size((args.batchsize,2, 163 | args.patch_size//8, 164 | args.patch_size//8)), temp = args.temp) 165 | model.cuda() 166 | closs.cuda() 167 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),args.lr) 168 | 169 | if args.resume: 170 | if os.path.isfile(args.resume): 171 | print("=> loading checkpoint '{}'".format(args.resume)) 172 | checkpoint = torch.load(args.resume) 173 | start_epoch = checkpoint['epoch'] 174 | best_loss = checkpoint['best_loss'] 175 | model.load_state_dict(checkpoint['state_dict']) 176 | print("=> loaded checkpoint '{} ({})' (epoch {})" 177 | .format(args.resume, best_loss, checkpoint['epoch'])) 178 | else: 179 | print("=> no checkpoint found at '{}'".format(args.resume)) 180 | 181 | for epoch in range(start_epoch, args.nepoch): 182 | if epoch < args.wepoch: 183 | lr = adjust_learning_rate(args, optimizer, epoch) 184 | print("Base lr for epoch {}: {}.".format(epoch, optimizer.param_groups[0]['lr'])) 185 | best_loss = train_iter(args, loader_warm, model, closs, optimizer, epoch, best_loss) 186 | else: 187 | lr = adjust_learning_rate(args, optimizer, epoch-args.wepoch) 188 | print("Base lr for epoch {}: {}.".format(epoch, optimizer.param_groups[0]['lr'])) 189 | best_loss = train_iter(args, loader, model, closs, optimizer, epoch, best_loss) 190 | 191 | 192 | def forward(frame1, frame2, model, warm_up, patch_size=None): 193 | n, c, h, w = frame1.size() 194 | if warm_up: 195 | output = model(frame1, frame2) 196 | else: 197 | output = model(frame1, frame2, warm_up=False, patch_size=[patch_size//8, patch_size//8]) 198 | new_c = output[2] 199 | # gt patch 200 | # print("HERE2: ", frame2.size(), new_c, patch_size) 201 | color2_gt = diff_crop(frame2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], 202 | patch_size, patch_size) 203 | output.append(color2_gt) 204 | return output 205 | 206 | 207 | def train_iter(args, loader, model, closs, optimizer, epoch, best_loss): 208 | losses = AverageMeter() 209 | batch_time = AverageMeter() 210 | losses = AverageMeter() 211 | c_losses = AverageMeter() 212 | model.train() 213 | end = time.time() 214 | if args.coord_switch_flag: 215 | coord_switch_loss = nn.L1Loss() 216 | sc_losses = AverageMeter() 217 | 218 | if epoch < 1 or (epoch>=args.wepoch and epoch< args.wepoch+2): 219 | thr = None 220 | else: 221 | thr = 2.5 222 | 223 | for i,frames in enumerate(loader): 224 | frame1_var = frames[0].cuda() 225 | frame2_var = frames[1].cuda() 226 | 227 | if epoch < args.wepoch: 228 | output = forward(frame1_var, frame2_var, model, warm_up=True) 229 | color2_est = output[0] 230 | aff = output[1] 231 | b,x,_ = aff.size() 232 | color1_est = None 233 | if args.color_switch_flag: 234 | color1_est = output[2] 235 | loss_ = L1_loss(color2_est, frame2_var, 10, 10, thr=thr, pred1=color1_est, frame1_var = frame1_var) 236 | 237 | if epoch >=1 
and args.lc > 0: 238 | constraint_loss = torch.sum(closs(aff.view(b,1,x,x))) * args.lc 239 | c_losses.update(constraint_loss.item(), frame1_var.size(0)) 240 | loss = loss_ + constraint_loss 241 | else: 242 | loss = loss_ 243 | if(i % args.log_interval == 0): 244 | save_vis(color2_est, frame2_var, frame1_var, frame2_var, args.savepatch) 245 | else: 246 | # print("input: ", frame1_var.size(), frame2_var.size()) 247 | output = forward(frame1_var, frame2_var, model, warm_up=False, patch_size = args.patch_size) 248 | color2_est = output[0] 249 | aff = output[1] 250 | new_c = output[2] 251 | coords = output[3] 252 | Fcolor2_crop = output[-1] 253 | 254 | b,x,x = aff.size() 255 | color1_est = None 256 | count = 3 257 | 258 | constraint_loss = torch.sum(closs(aff.view(b,1,x,x))) * args.lc 259 | c_losses.update(constraint_loss.item(), frame1_var.size(0)) 260 | 261 | if args.color_switch_flag: 262 | count += 1 263 | color1_est = output[count] 264 | 265 | loss_color = L1_loss(color2_est, Fcolor2_crop, 10, 10, thr=thr, pred1=color1_est, frame1_var = frame1_var) 266 | loss_ = loss_color + constraint_loss 267 | 268 | if args.coord_switch_flag: 269 | count += 1 270 | grids = output[count] 271 | C11 = output[count+1] 272 | loss_coord = args.coord_switch * coord_switch_loss(C11, grids) 273 | loss = loss_ + loss_coord 274 | sc_losses.update(loss_coord.item(), frame1_var.size(0)) 275 | else: 276 | loss = loss_ 277 | 278 | if(i % args.log_interval == 0): 279 | save_vis(color2_est, Fcolor2_crop, frame1_var, frame2_var, args.savepatch, new_c) 280 | 281 | losses.update(loss.item(), frame1_var.size(0)) 282 | optimizer.zero_grad() 283 | loss.backward() 284 | optimizer.step() 285 | batch_time.update(time.time() - end) 286 | end = time.time() 287 | 288 | if epoch >= args.wepoch and args.coord_switch_flag: 289 | logger.info('Epoch: [{0}][{1}/{2}]\t' 290 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 291 | 'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t ' 292 | 'Coord switch Loss {scloss.val:.4f} ({scloss.avg:.4f})\t ' 293 | 'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format( 294 | epoch, i+1, len(loader), batch_time=batch_time, loss=losses, scloss=sc_losses, c_loss= c_losses)) 295 | else: 296 | logger.info('Epoch: [{0}][{1}/{2}]\t' 297 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 298 | 'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t ' 299 | 'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format( 300 | epoch, i+1, len(loader), batch_time=batch_time, loss=losses, c_loss= c_losses)) 301 | 302 | if((i + 1) % args.save_interval == 0): 303 | is_best = losses.avg < best_loss 304 | best_loss = min(losses.avg, best_loss) 305 | checkpoint_path = os.path.join(args.savedir, 'checkpoint_latest.pth.tar') 306 | save_checkpoint({ 307 | 'epoch': epoch + 1, 308 | 'state_dict': model.state_dict(), 309 | 'best_loss': best_loss, 310 | }, is_best, filename=checkpoint_path, savedir = args.savedir) 311 | log_current(epoch, losses.avg, best_loss, filename = "log_current.txt", savedir=args.savedir) 312 | 313 | return best_loss 314 | 315 | 316 | if __name__ == '__main__': 317 | args = parse_args() 318 | train(args) 319 | writer.close() -------------------------------------------------------------------------------- /weights/checkpoint_latest.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/checkpoint_latest.pth.tar 
-------------------------------------------------------------------------------- /weights/decoder_single_gpu.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/decoder_single_gpu.pth -------------------------------------------------------------------------------- /weights/encoder_single_gpu.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/encoder_single_gpu.pth --------------------------------------------------------------------------------
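
For reference, the snippet below is a minimal sketch of restoring the released `weights/checkpoint_latest.pth.tar` for evaluation, mirroring the checkpoint-loading logic in the `__main__` block of `test_with_track.py` above. The import path of `Model` and the constructor arguments (`pretrainRes=False`, `temp=1`, `uselayer=4`) are assumptions taken from that script; adjust them to match your configuration and test resolution.

```
# Minimal sketch (not part of the repository): load the released checkpoint
# for evaluation, following the loading logic in test_with_track.py.
# The Model import path and constructor arguments below are assumptions.
import torch
from model import Model  # hypothetical import; use the Model class that test_with_track.py uses

model = Model(pretrainRes=False, temp=1, uselayer=4)
# If the checkpoint was saved from a multi-GPU run, wrap the model first so the
# "module." prefixes in the state_dict keys match:
# model = torch.nn.DataParallel(model)

checkpoint = torch.load("weights/checkpoint_latest.pth.tar", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
print("=> loaded checkpoint (epoch {}, best loss {})".format(
    checkpoint["epoch"], checkpoint["best_loss"]))

model.cuda()
model.eval()
```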