├── README.md
├── docs
│   ├── drift-chiance.gif
│   ├── lab-coat.gif
│   ├── parkour.gif
│   └── teaser.png
├── libs
│   ├── autoencoder.py
│   ├── concentration_loss.py
│   ├── data
│   │   ├── db_info.yaml
│   │   ├── palette.txt
│   │   └── train.txt
│   ├── loader.py
│   ├── loss.py
│   ├── model.py
│   ├── resnet.py
│   ├── test_utils.py
│   ├── track_utils.py
│   ├── train_utils.py
│   ├── transforms_multi.py
│   ├── transforms_pair.py
│   ├── utils.py
│   └── vis_utils.py
├── model.py
├── test.py
├── test_with_track.py
├── track_match_v1.py
└── weights
    ├── checkpoint_latest.pth.tar
    ├── decoder_single_gpu.pth
    └── encoder_single_gpu.pth
/README.md:
--------------------------------------------------------------------------------
1 | # Joint-task Self-supervised Learning for Temporal Correspondence
2 |
3 | [**Project**](https://sites.google.com/view/uvc2019) | [**Paper**]()
4 |
5 | # Overview
6 |
7 | ![teaser](docs/teaser.png)
8 |
9 | [Joint-task Self-supervised Learning for Temporal Correspondence]()
10 |
11 | [Xueting Li*](https://sunshineatnoon.github.io/), [Sifei Liu*](https://www.sifeiliu.net/), [Shalini De Mello](https://research.nvidia.com/person/shalini-gupta), [Xiaolong Wang](https://www.cs.cmu.edu/~xiaolonw/), [Jan Kautz](http://jankautz.com/), [Ming-Hsuan Yang](http://faculty.ucmerced.edu/mhyang/).
12 |
13 | (* equal contributions)
14 |
15 | In Neural Information Processing Systems (NeurIPS), 2019.
16 |
17 | # Citation
18 | If you use our code in your research, please cite it with the following BibTeX entry:
19 |
20 | ```
21 | @inproceedings{uvc_2019,
22 | Author = {Xueting Li and Sifei Liu and Shalini De Mello and Xiaolong Wang and Jan Kautz and Ming-Hsuan Yang},
23 | Title = {Joint-task Self-supervised Learning for Temporal Correspondence},
24 | Booktitle = {NeurIPS},
25 | Year = {2019},
26 | }
27 | ```
28 |
29 | # Instance segmentation propagation on DAVIS2017
30 |
31 | ![parkour](docs/parkour.gif)
32 | ![lab-coat](docs/lab-coat.gif)
33 | ![drift-chicane](docs/drift-chiance.gif)
34 |
35 |
36 |
37 | | Method | J_mean | J_recall | J_decay | F_mean | F_recall | F_decay |
38 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
39 | | Ours | 0.563 | 0.650 | 0.289 | 0.592 | 0.641 | 0.354 |
40 | | Ours + track | 0.577 | 0.683 | 0.263 | 0.613 | 0.698 | 0.324 |
41 |
42 | # Prerequisites
43 | The code is tested in the following environment:
44 | - Ubuntu 16.04
45 | - PyTorch 1.1.0, [tqdm](https://github.com/tqdm/tqdm), SciPy 1.2.1
46 |
47 | # Testing on DAVIS2017
48 | ## Testing without tracking
49 | To test on DAVIS2017 for instance segmentation mask propagation, please run:
50 | ```
51 | python test.py -d /workspace/DAVIS/ -s 480
52 | ```
53 | Important parameters:
54 | - `-c`: checkpoint path.
55 | - `-o`: results path.
56 | - `-d`: DAVIS 2017 dataset path.
57 | - `-s`: test resolution; all results in the paper are reported on 480p images, i.e. `-s 480`.
58 |
59 | Please check the `test.py` file for other parameters.
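
For example, assuming the released checkpoint under `weights/` and an illustrative output directory, a full invocation could look like:

```
python test.py -c weights/checkpoint_latest.pth.tar -o results/davis/ -d /workspace/DAVIS/ -s 480
```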
60 |
61 | ## Testing with tracking
62 | To test on DAVIS2017 with tracking & propagation, please run:
63 | ```
64 | python test_with_track.py -d /workspace/DAVIS/ -s 480
65 | ```
66 | It takes parameters similar to `test.py`; please see `test_with_track.py` for details.
67 |
68 | # Training on Kinetics
69 |
70 | ## Dataset
71 |
72 | We use the [Kinetics dataset](https://deepmind.com/research/open-source/open-source-datasets/kinetics/) for training.
73 |
74 | ## Training command
75 |
76 | ```
77 | python track_match_v1.py --wepoch 10 --nepoch 30 -c match_track_switch --batchsize 40 --coord_switch 0 --lc 0.3
78 | ```
79 |
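The dataloaders in `libs/loader.py` consume a plain-text list of video files (one path per line, e.g. `libs/data/train.txt`) and sample augmented frame pairs from each clip. A minimal sketch of building the warm-up loader this way, with illustrative paths and crop size and the `--batchsize 40` from the command above:

```
import torch
from libs.loader import VidListv1

# Illustrative paths: point these at your Kinetics videos and list file.
dataset = VidListv1(video_path="/path/to/kinetics/videos",
                    list_path="libs/data/train.txt",
                    patch_size=256)

loader = torch.utils.data.DataLoader(dataset, batch_size=40, shuffle=True,
                                     num_workers=8, drop_last=True)

# Each item is a pair of augmented LAB frame crops sampled from the same clip.
frame1, frame2 = next(iter(loader))
```
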
80 | # Acknowledgements
81 | - This code is based on [TPN](https://arxiv.org/pdf/1804.08758.pdf) and [TimeCycle](https://github.com/xiaolonw/TimeCycle).
82 | - For any issues, please contact xli75@ucmerced.edu or sifeil@nvidia.com.
83 |
--------------------------------------------------------------------------------
/docs/drift-chiance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/drift-chiance.gif
--------------------------------------------------------------------------------
/docs/lab-coat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/lab-coat.gif
--------------------------------------------------------------------------------
/docs/parkour.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/parkour.gif
--------------------------------------------------------------------------------
/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/docs/teaser.png
--------------------------------------------------------------------------------
/libs/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models import vgg16
4 | from torch.autograd import Variable
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 | from .resnet import resnet18, resnet50
8 |
9 | class encoder3(nn.Module):
10 | def __init__(self, reduce = False):
11 | super(encoder3,self).__init__()
12 | # vgg
13 | # 224 x 224
14 | self.conv1 = nn.Conv2d(3,3,1,1,0)
15 | self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
16 | # 226 x 226
17 |
18 | self.conv2 = nn.Conv2d(3,64,3,1,0)
19 | self.relu2 = nn.ReLU(inplace=True)
20 | # 224 x 224
21 |
22 | self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
23 | self.conv3 = nn.Conv2d(64,64,3,1,0)
24 | self.relu3 = nn.ReLU(inplace=True)
25 | # 224 x 224
26 |
27 | self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
28 | # 112 x 112
29 |
30 | self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
31 | self.conv4 = nn.Conv2d(64,128,3,1,0)
32 | self.relu4 = nn.ReLU(inplace=True)
33 | # 112 x 112
34 |
35 | self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
36 | self.conv5 = nn.Conv2d(128,128,3,1,0)
37 | self.relu5 = nn.ReLU(inplace=True)
38 | # 112 x 112
39 |
40 | self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
41 | # 56 x 56
42 |
43 | self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
44 | self.conv6 = nn.Conv2d(128,256,3,1,0)
45 | self.relu6 = nn.ReLU(inplace=True)
46 | # 56 x 56
47 | self.reduce = reduce
48 | if reduce:
49 | self.downsample = nn.Sequential(nn.MaxPool2d(kernel_size=2,stride=2),
50 | nn.Conv2d(256,256,1,1,0),
51 | nn.ReLU(inplace=True))
52 |
53 | def forward(self,x):
54 | out = self.conv1(x)
55 | out = self.reflecPad1(out)
56 | out = self.conv2(out)
57 | out = self.relu2(out)
58 | out = self.reflecPad3(out)
59 | out = self.conv3(out)
60 | pool1 = self.relu3(out)
61 | out,pool_idx = self.maxPool(pool1)
62 | out = self.reflecPad4(out)
63 | out = self.conv4(out)
64 | out = self.relu4(out)
65 | out = self.reflecPad5(out)
66 | out = self.conv5(out)
67 | pool2 = self.relu5(out)
68 | out,pool_idx2 = self.maxPool2(pool2)
69 | out = self.reflecPad6(out)
70 | out = self.conv6(out)
71 | out = self.relu6(out)
72 | if self.reduce:
73 | out = self.downsample(out)
74 | return out
75 |
76 | class decoder3(nn.Module):
77 | def __init__(self, cls=False, cls_num=32, reduce = False):
78 | """
79 | INPUTS:
80 | - cls: if using classification.
81 | - cls_num: cluster number
82 | """
83 | super(decoder3,self).__init__()
84 | if reduce:
85 | self.upsample = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2),
86 | nn.Conv2d(256,256,1,1,0),
87 | nn.ReLU(inplace=True))
88 | self.reduce = reduce
89 |
90 | self.cls = cls
91 | # decoder
92 | self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
93 | self.conv7 = nn.Conv2d(256,128,3,1,0)
94 | self.relu7 = nn.ReLU(inplace=True)
95 | # 56 x 56
96 |
97 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
98 | # 112 x 112
99 |
100 | self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
101 | self.conv8 = nn.Conv2d(128,128,3,1,0)
102 | self.relu8 = nn.ReLU(inplace=True)
103 | # 112 x 112
104 |
105 | self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
106 | self.conv9 = nn.Conv2d(128,64,3,1,0)
107 | self.relu9 = nn.ReLU(inplace=True)
108 |
109 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
110 | # 224 x 224
111 |
112 | self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
113 | self.conv10 = nn.Conv2d(64,64,3,1,0)
114 | self.relu10 = nn.ReLU(inplace=True)
115 |
116 | self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
117 | if not cls:
118 | self.conv11 = nn.Conv2d(64,3,3,1,0)
119 | else:
120 | self.conv11 = nn.Sequential(nn.Conv2d(64,cls_num,3,1,0),
121 |                                         nn.LogSoftmax(dim=1))
122 |
123 | def forward(self,x):
124 | output = {}
125 | if self.reduce:
126 | x = self.upsample(x)
127 | out = self.reflecPad7(x)
128 | out = self.conv7(out)
129 | out = self.relu7(out)
130 | out = self.unpool(out)
131 | out = self.reflecPad8(out)
132 | out = self.conv8(out)
133 | out = self.relu8(out)
134 | out = self.reflecPad9(out)
135 | out = self.conv9(out)
136 | out_relu9 = self.relu9(out)
137 | out = self.unpool2(out_relu9)
138 | out = self.reflecPad10(out)
139 | out = self.conv10(out)
140 | out = self.relu10(out)
141 | out = self.reflecPad11(out)
142 | out = self.conv11(out)
143 | if not self.cls:
144 | out = torch.tanh(out)
145 | return out
146 |
147 |
148 | def encoder_res18(pretrained = True, uselayer=3):
149 | """Constructs a ResNet-18 model.
150 | Args:
151 | pretrained (bool): If True, returns a model pre-trained on ImageNet
152 | """
153 | model = resnet18(uselayer=uselayer)
154 | if pretrained:
155 | print("Using pretrianed ResNet18 as guide.")
156 | model.load_state_dict(torch.load("ae_models/resnet18-5c106cde.pth"), strict = False)
157 | return model
158 |
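159 | # Shape sketch (assuming a 256x256 input): encoder3(reduce=True) maps a
160 | # 3x256x256 image to 256x32x32 features (two max-pools plus the extra
161 | # down-sampling step), and decoder3(reduce=True) upsamples back to a
162 | # 3x256x256 image in [-1, 1] through the final tanh.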
--------------------------------------------------------------------------------
/libs/concentration_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def transform(aff, frame1):
5 | """
6 | Given aff, copy from frame1 to construct frame2.
7 | INPUTS:
8 | - aff: (h*w)*(h*w) affinity matrix
9 | - frame1: n*c*h*w feature map
10 | """
11 | b,c,h,w = frame1.size()
12 | frame1 = frame1.view(b,c,-1)
13 | frame2 = torch.bmm(frame1, aff)
14 | return frame2.view(b,c,h,w)
15 |
16 | def aff2coord(F_size, grid, A = None, temp = None, softmax=None):
17 | """
18 | INPUT:
19 | - A: a (H*W)*(H*W) affinity matrix
20 | - F_size: image feature size
21 |     - temp, softmax: if a softmax module is given, A is first normalized as softmax(A*temp).
22 | - grid: a standard grid, see create_grid function.
23 | OUTPUT:
24 | - U: a (2*H*W) coordinate tensor, U_ij indicates the coordinates of pixel ij in target image.
25 | """
26 | grid = grid.permute(0,3,1,2)
27 | if A is not None:
28 | if softmax is not None:
29 | if temp is None:
30 | raise Exception("Need temp for softmax!")
31 | A = softmax(A*temp)
32 | # b x c x h x w
33 | U = transform(A, grid)
34 | else:
35 | U = grid
36 | return U
37 |
38 | def create_grid(F_size, GPU=True):
39 | b, c, h, w = F_size
40 | theta = torch.tensor([[1,0,0],[0,1,0]])
41 | theta = theta.unsqueeze(0).repeat(b,1,1)
42 | theta = theta.float()
43 |
44 |     # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
45 | # b * (h*w) * 2
46 | grid = nn.functional.affine_grid(theta, F_size)
47 | if(GPU):
48 | grid = grid.cuda()
49 | return grid
50 |
51 | def im2col(img, win_len, stride=1):
52 | """
53 | INPUTS:
54 | - img: a b*c*h*w feature tensor.
55 | - win_len: each pixel compares with its neighbors within a
56 | (win_len*2+1) * (win_len*2+1) window.
57 | OUTPUT:
58 | - result: a b*c*(h*w)*(win_len*2+1)^2 tensor, unfolded neighbors for each pixel
59 | """
60 | b,c,_,_ = img.size()
61 | # b * (c*w*w) * win_num
62 | unfold_img = torch.nn.functional.unfold(img, win_len, padding=0, stride=stride)
63 | unfold_img = unfold_img.view(b,c,win_len*win_len,-1)
64 | unfold_img = unfold_img.permute(0,1,3,2)
65 | return unfold_img
66 |
67 | class ConcentrationLoss(nn.Module):
68 | def __init__(self, win_len, stride, F_size):
69 | super(ConcentrationLoss, self).__init__()
70 | self.win_len = win_len
71 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False)
72 | self.F_size = F_size
73 | self.stride = stride
74 |
75 | def forward(self, aff):
76 | b, c, h, w = self.F_size
77 | #if aff.dim() == 4:
78 | # aff = torch.squeeze(aff)
79 | # b * 2 * h * w
80 | coord1 = aff2coord(self.F_size, self.grid, aff)
81 | # b * 2 * (h * w) * (win ^ 2)
82 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride)
83 | # b * 2 * (h * w) * 1
84 | # center = coord1_unfold[:,:,:,int((self.win_len ** 2)/2)]
85 | center = torch.mean(coord1_unfold, dim = 3)
86 | center = center.view(b, 2, -1, 1)
87 | # b * 2 * (h * w) * (win ^ 2)
88 | dis2center = (coord1_unfold - center) ** 2
89 | return torch.sum(dis2center) / dis2center.numel()
90 |
91 | class ConcentrationDetachLoss(nn.Module):
92 | def __init__(self, win_len, stride, F_size):
93 | super(ConcentrationDetachLoss, self).__init__()
94 | self.win_len = win_len
95 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False)
96 | self.F_size = F_size
97 | self.stride = stride
98 |
99 | def forward(self, aff):
100 | b, c, h, w = self.F_size
101 | if aff.dim() == 4:
102 | aff = torch.squeeze(aff)
103 | # b * 2 * h * w
104 | coord1 = aff2coord(self.F_size, self.grid, aff)
105 | # b * 2 * (h * w) * (win ^ 2)
106 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride)
107 | # b * 2 * (h * w) * 1
108 | # center = coord1_unfold[:,:,:,int((self.win_len ** 2)/2)]
109 | center = torch.mean(coord1_unfold, dim = 3).detach()
110 | center = center.view(b, 2, -1, 1)
111 | # b * 2 * (h * w) * (win ^ 2)
112 | dis2center = (coord1_unfold - center) ** 2
113 | return torch.sum(dis2center) / dis2center.numel()
114 |
115 | class ConcentrationSwitchLoss(nn.Module):
116 | def __init__(self, win_len, stride, F_size, temp):
117 | super(ConcentrationSwitchLoss, self).__init__()
118 | self.win_len = win_len
119 | self.grid = nn.Parameter(create_grid(F_size), requires_grad = False)
120 | self.F_size = F_size
121 | self.stride = stride
122 | self.temp = temp
123 | self.softmax = nn.Softmax(dim=1)
124 |
125 | def forward(self, aff):
126 | # aff here is not processed by softmax
127 | b, c, h, w = self.F_size
128 | if aff.dim() == 4:
129 | aff = torch.squeeze(aff)
130 | # b * 2 * h * w
131 | coord1 = aff2coord(self.F_size, self.grid, aff, self.temp, softmax=self.softmax)
132 | coord2 = aff2coord(self.F_size, self.grid, aff.permute(0,2,1), self.temp, softmax=self.softmax)
133 | # b * 2 * (h * w) * (win ^ 2)
134 | coord1_unfold = im2col(coord1, self.win_len, stride = self.stride)
135 | coord2_unfold = im2col(coord2, self.win_len, stride = self.stride)
136 | # b * 2 * (h * w) * 1
137 | center1 = torch.mean(coord1_unfold, dim = 3)
138 | center1 = center1.view(b, 2, -1, 1)
139 | # b * 2 * (h * w) * (win ^ 2)
140 | dis2center1 = (coord1_unfold - center1) ** 2
141 | # b * 2 * (h * w) * 1
142 | center2 = torch.mean(coord2_unfold, dim = 3)
143 | center2 = center2.view(b, 2, -1, 1)
144 | # b * 2 * (h * w) * (win ^ 2)
145 | dis2center2 = (coord2_unfold - center2) ** 2
146 | return (torch.sum(dis2center1) + torch.sum(dis2center2))/ dis2center1.numel()
147 |
148 |
149 | if __name__ == '__main__':
150 | cl = ConcentrationLoss(win_len=8, F_size=torch.Size((1,3,32,32)), stride=8)
151 | aff = torch.Tensor(1,1024,1024)
152 | aff.uniform_()
153 | aff = aff.cuda()
154 | print(cl(aff))
155 |
--------------------------------------------------------------------------------
/libs/data/db_info.yaml:
--------------------------------------------------------------------------------
1 | attributes: [AC, BC, CS, DB, DEF, EA, FM, HO, IO, LR, MB, OCC, OV, ROT, SC, SV]
2 | sets: [train, val, test-dev]
3 | years: [2016, 2017]
4 |
5 | sequences:
6 |
7 | - name: aerobatics
8 | attributes: []
9 | num_frames: 71
10 | set: test-dev
11 | eval_t: False
12 | year: 2017
13 |
14 | - name: bear
15 | attributes: [DEF]
16 | num_frames: 82
17 | set: train
18 | eval_t: True
19 | year: 2016
20 |
21 | - name: bike-packing
22 | attributes: []
23 | num_frames: 69
24 | set: val
25 | eval_t: False
26 | year: 2017
27 |
28 | - name: blackswan
29 | attributes: []
30 | num_frames: 50
31 | set: val
32 | eval_t: True
33 | year: 2016
34 |
35 | - name: bmx-bumps
36 | attributes: [LR, SV, SC, FM, CS, IO, MB, OCC, HO, EA, OV]
37 | num_frames: 90
38 | set: train
39 | eval_t: False
40 | year: 2016
41 |
42 | - name: bmx-trees
43 | attributes: [LR, SV, SC, FM, CS, IO, MB, DEF, OCC, HO, EA, BC]
44 | num_frames: 80
45 | set: val
46 | eval_t: False
47 | year: 2016
48 |
49 | - name: boat
50 | attributes: [SC, DB, EA, BC]
51 | num_frames: 75
52 | set: train
53 | eval_t: True
54 | year: 2016
55 |
56 | - name: boxing-fisheye
57 | attributes: []
58 | num_frames: 87
59 | set: train
60 | eval_t: False
61 | year: 2017
62 |
63 | - name: breakdance
64 | attributes: [FM, DB, MB, DEF, HO, ROT, OV, AC]
65 | num_frames: 84
66 | set: val
67 | eval_t: False
68 | year: 2016
69 |
70 | - name: breakdance-flare
71 | attributes: [FM, CS, MB, DEF, HO, ROT]
72 | num_frames: 71
73 | set: train
74 | eval_t: False
75 | year: 2016
76 |
77 | - name: bus
78 | attributes: [SC, OCC, HO, EA]
79 | num_frames: 80
80 | set: train
81 | eval_t: True
82 | year: 2016
83 |
84 | - name: camel
85 | attributes: [CS, IO, DEF, ROT]
86 | num_frames: 90
87 | set: val
88 | eval_t: True
89 | year: 2016
90 |
91 | - name: car-race
92 | attributes: []
93 | num_frames: 31
94 | set: test-dev
95 | eval_t: False
96 | year: 2017
97 |
98 | - name: car-roundabout
99 | attributes: [ROT, BC]
100 | num_frames: 75
101 | set: val
102 | eval_t: True
103 | year: 2016
104 |
105 | - name: car-shadow
106 | attributes: [LR, EA, AC, BC]
107 | num_frames: 40
108 | set: val
109 | eval_t: True
110 | year: 2016
111 |
112 | - name: car-turn
113 | attributes: [SV, ROT, BC]
114 | num_frames: 80
115 | set: train
116 | eval_t: True
117 | year: 2016
118 |
119 | - name: carousel
120 | attributes: []
121 | num_frames: 69
122 | set: test-dev
123 | eval_t: False
124 | year: 2017
125 |
126 | - name: cat-girl
127 | attributes: []
128 | num_frames: 89
129 | set: train
130 | eval_t: False
131 | year: 2017
132 |
133 | - name: cats-car
134 | attributes: []
135 | num_frames: 67
136 | set: test-dev
137 | eval_t: False
138 | year: 2017
139 |
140 | - name: chamaleon
141 | attributes: []
142 | num_frames: 85
143 | set: test-dev
144 | eval_t: False
145 | year: 2017
146 |
147 | - name: classic-car
148 | attributes: []
149 | num_frames: 63
150 | set: train
151 | eval_t: False
152 | year: 2017
153 |
154 | - name: color-run
155 | attributes: []
156 | num_frames: 84
157 | set: train
158 | eval_t: False
159 | year: 2017
160 |
161 | - name: cows
162 | attributes: [CS, IO, DEF, OCC, HO]
163 | num_frames: 104
164 | set: val
165 | eval_t: True
166 | year: 2016
167 |
168 | - name: crossing
169 | attributes: []
170 | num_frames: 52
171 | set: train
172 | eval_t: False
173 | year: 2017
174 |
175 | - name: dance-jump
176 | attributes: [SC, DB, MB, DEF, OCC, HO, ROT, EA]
177 | num_frames: 60
178 | set: train
179 | eval_t: True
180 | year: 2016
181 |
182 | - name: dance-twirl
183 | attributes: [SC, CS, IO, MB, DEF, HO, ROT, OV]
184 | num_frames: 90
185 | set: val
186 | eval_t: False
187 | year: 2016
188 |
189 | - name: dancing
190 | attributes: []
191 | num_frames: 62
192 | set: train
193 | eval_t: False
194 | year: 2017
195 |
196 | - name: deer
197 | attributes: []
198 | num_frames: 79
199 | set: test-dev
200 | eval_t: False
201 | year: 2017
202 |
203 | - name: disc-jockey
204 | attributes: []
205 | num_frames: 76
206 | set: train
207 | eval_t: False
208 | year: 2017
209 |
210 | - name: dog
211 | attributes: [FM, CS, MB, DEF, ROT, EA]
212 | num_frames: 60
213 | set: val
214 | eval_t: False
215 | year: 2016
216 |
217 | - name: dog-agility
218 | attributes: [FM, MB, DEF, OCC, HO, EA, OV, AC]
219 | num_frames: 25
220 | set: train
221 | eval_t: False
222 | year: 2016
223 |
224 | - name: dog-gooses
225 | attributes: []
226 | num_frames: 86
227 | set: train
228 | eval_t: False
229 | year: 2017
230 |
231 | - name: dogs-jump
232 | attributes: []
233 | num_frames: 66
234 | set: val
235 | eval_t: False
236 | year: 2017
237 |
238 | - name: dogs-scale
239 | attributes: []
240 | num_frames: 83
241 | set: train
242 | eval_t: False
243 | year: 2017
244 |
245 | - name: drift-chicane
246 | attributes: [LR, SV, FM, DB, HO, ROT, EA, AC]
247 | num_frames: 52
248 | set: val
249 | eval_t: False
250 | year: 2016
251 |
252 | - name: drift-straight
253 | attributes: [LR, SV, FM, CS, MB, HO, ROT, EA, OV, AC]
254 | num_frames: 50
255 | set: val
256 | eval_t: True
257 | year: 2016
258 |
259 | - name: drift-turn
260 | attributes: [SV, FM, IO, DB, HO, ROT, OV, AC]
261 | num_frames: 64
262 | set: train
263 | eval_t: True
264 | year: 2016
265 |
266 | - name: drone
267 | attributes: []
268 | num_frames: 91
269 | set: train
270 | eval_t: False
271 | year: 2017
272 |
273 | - name: elephant
274 | attributes: [CS, DB, DEF, EA]
275 | num_frames: 80
276 | set: train
277 | eval_t: True
278 | year: 2016
279 |
280 | - name: flamingo
281 | attributes: [SC, IO, DB, DEF, HO]
282 | num_frames: 80
283 | set: train
284 | eval_t: True
285 | year: 2016
286 |
287 | - name: giant-slalom
288 | attributes: []
289 | num_frames: 127
290 | set: test-dev
291 | eval_t: False
292 | year: 2017
293 |
294 | - name: girl-dog
295 | attributes: []
296 | num_frames: 86
297 | set: test-dev
298 | eval_t: False
299 | year: 2017
300 |
301 | - name: goat
302 | attributes: [CS, DEF, EA, BC]
303 | num_frames: 90
304 | set: val
305 | eval_t: False
306 | year: 2016
307 |
308 | - name: gold-fish
309 | attributes: []
310 | num_frames: 78
311 | set: val
312 | eval_t: False
313 | year: 2017
314 |
315 | - name: golf
316 | attributes: []
317 | num_frames: 79
318 | set: test-dev
319 | eval_t: False
320 | year: 2017
321 |
322 | - name: guitar-violin
323 | attributes: []
324 | num_frames: 55
325 | set: test-dev
326 | eval_t: False
327 | year: 2017
328 |
329 | - name: gym
330 | attributes: []
331 | num_frames: 60
332 | set: test-dev
333 | eval_t: False
334 | year: 2017
335 |
336 | - name: helicopter
337 | attributes: []
338 | num_frames: 49
339 | set: test-dev
340 | eval_t: False
341 | year: 2017
342 |
343 | - name: hike
344 | attributes: [LR, DEF, HO]
345 | num_frames: 80
346 | set: train
347 | eval_t: True
348 | year: 2016
349 |
350 | - name: hockey
351 | attributes: [SC, IO, DEF, HO, ROT]
352 | num_frames: 75
353 | set: train
354 | eval_t: True
355 | year: 2016
356 |
357 | - name: horsejump-high
358 | attributes: [SC, IO, DEF, OCC, HO]
359 | num_frames: 50
360 | set: val
361 | eval_t: False
362 | year: 2016
363 |
364 | - name: horsejump-low
365 | attributes: [SC, IO, DEF, OCC, HO, EA]
366 | num_frames: 60
367 | set: train
368 | eval_t: False
369 | year: 2016
370 |
371 | - name: horsejump-stick
372 | attributes: []
373 | num_frames: 58
374 | set: test-dev
375 | eval_t: False
376 | year: 2017
377 |
378 | - name: hoverboard
379 | attributes: []
380 | num_frames: 81
381 | set: test-dev
382 | eval_t: False
383 | year: 2017
384 |
385 | - name: india
386 | attributes: []
387 | num_frames: 81
388 | set: val
389 | eval_t: False
390 | year: 2017
391 |
392 | - name: judo
393 | attributes: []
394 | num_frames: 34
395 | set: val
396 | eval_t: False
397 | year: 2017
398 |
399 | - name: kid-football
400 | attributes: []
401 | num_frames: 68
402 | set: train
403 | eval_t: False
404 | year: 2017
405 |
406 | - name: kite-surf
407 | attributes: [SV, SC, IO, DB, MB, OCC, HO, EA]
408 | num_frames: 50
409 | set: val
410 | eval_t: True
411 | year: 2016
412 |
413 | - name: kite-walk
414 | attributes: [SC, IO, DB, DEF, OCC, HO]
415 | num_frames: 80
416 | set: train
417 | eval_t: True
418 | year: 2016
419 |
420 | - name: koala
421 | attributes: []
422 | num_frames: 100
423 | set: train
424 | eval_t: False
425 | year: 2017
426 |
427 | - name: lab-coat
428 | attributes: []
429 | num_frames: 47
430 | set: val
431 | eval_t: False
432 | year: 2017
433 |
434 | - name: lady-running
435 | attributes: []
436 | num_frames: 65
437 | set: train
438 | eval_t: False
439 | year: 2017
440 |
441 | - name: libby
442 | attributes: [SC, MB, DEF, OCC, HO, EA]
443 | num_frames: 49
444 | set: val
445 | eval_t: False
446 | year: 2016
447 |
448 | - name: lindy-hop
449 | attributes: []
450 | num_frames: 73
451 | set: train
452 | eval_t: False
453 | year: 2017
454 |
455 | - name: loading
456 | attributes: []
457 | num_frames: 50
458 | set: val
459 | eval_t: False
460 | year: 2017
461 |
462 | - name: lock
463 | attributes: []
464 | num_frames: 43
465 | set: test-dev
466 | eval_t: False
467 | year: 2017
468 |
469 | - name: longboard
470 | attributes: []
471 | num_frames: 52
472 | set: train
473 | eval_t: False
474 | year: 2017
475 |
476 |
477 | - name: lucia
478 | attributes: [DEF, OCC, HO]
479 | num_frames: 70
480 | set: train
481 | eval_t: False
482 | year: 2016
483 |
484 | - name: mallard-fly
485 | attributes: [LR, SV, FM, DB, MB, DEF, ROT, EA, OV, AC]
486 | num_frames: 70
487 | set: train
488 | eval_t: False
489 | year: 2016
490 |
491 | - name: mallard-water
492 | attributes: [LR, IO, DB, EA]
493 | num_frames: 80
494 | set: train
495 | eval_t: True
496 | year: 2016
497 |
498 | - name: man-bike
499 | attributes: []
500 | num_frames: 75
501 | set: test-dev
502 | eval_t: False
503 | year: 2017
504 |
505 |
506 | - name: mbike-trick
507 | attributes: []
508 | num_frames: 79
509 | set: val
510 | eval_t: False
511 | year: 2017
512 |
513 | - name: miami-surf
514 | attributes: []
515 | num_frames: 70
516 | set: train
517 | eval_t: False
518 | year: 2017
519 |
520 | - name: monkeys-trees
521 | attributes: []
522 | num_frames: 83
523 | set: test-dev
524 | eval_t: False
525 | year: 2017
526 |
527 |
528 | - name: motocross-bumps
529 | attributes: [SV, FM, IO, HO, ROT, OV, AC, BC]
530 | num_frames: 60
531 | set: train
532 | eval_t: True
533 | year: 2016
534 |
535 | - name: motocross-jump
536 | attributes: [SV, SC, FM, IO, MB, DEF, HO, ROT, EA, OV, AC]
537 | num_frames: 40
538 | set: val
539 | eval_t: False
540 | year: 2016
541 |
542 | - name: motorbike
543 | attributes: [LR, SV, SC, FM, IO, OCC, HO, ROT, EA]
544 | num_frames: 43
545 | set: train
546 | eval_t: False
547 | year: 2016
548 |
549 | - name: mtb-race
550 | attributes: []
551 | num_frames: 69
552 | set: test-dev
553 | eval_t: False
554 | year: 2017
555 |
556 | - name: night-race
557 | attributes: []
558 | num_frames: 46
559 | set: train
560 | eval_t: False
561 | year: 2017
562 |
563 | - name: orchid
564 | attributes: []
565 | num_frames: 57
566 | set: test-dev
567 | eval_t: False
568 | year: 2017
569 |
570 | - name: paragliding
571 | attributes: [LR, SC, IO, HO]
572 | num_frames: 70
573 | set: train
574 | eval_t: False
575 | year: 2016
576 |
577 | - name: paragliding-launch
578 | attributes: [SC, IO, DEF, HO, EA]
579 | num_frames: 80
580 | set: val
581 | eval_t: True
582 | year: 2016
583 |
584 | - name: parkour
585 | attributes: [LR, SV, FM, DEF, OCC, HO, ROT, AC]
586 | num_frames: 100
587 | set: val
588 | eval_t: False
589 | year: 2016
590 |
591 | - name: people-sunset
592 | attributes: []
593 | num_frames: 67
594 | set: test-dev
595 | eval_t: False
596 | year: 2017
597 |
598 | - name: pigs
599 | attributes: []
600 | num_frames: 79
601 | set: val
602 | eval_t: False
603 | year: 2017
604 |
605 | - name: planes-crossing
606 | attributes: []
607 | num_frames: 31
608 | set: test-dev
609 | eval_t: False
610 | year: 2017
611 |
612 | - name: planes-water
613 | attributes: []
614 | num_frames: 38
615 | set: train
616 | eval_t: False
617 | year: 2017
618 |
619 | - name: rallye
620 | attributes: []
621 | num_frames: 50
622 | set: train
623 | eval_t: False
624 | year: 2017
625 |
626 | - name: rhino
627 | attributes: [DEF, OCC, BC]
628 | num_frames: 90
629 | set: train
630 | eval_t: True
631 | year: 2016
632 |
633 | - name: rollerblade
634 | attributes: [LR, FM, CS, MB, DEF, HO]
635 | num_frames: 35
636 | set: train
637 | eval_t: False
638 | year: 2016
639 |
640 | - name: rollercoaster
641 | attributes: []
642 | num_frames: 70
643 | set: test-dev
644 | eval_t: False
645 | year: 2017
646 |
647 | - name: salsa
648 | attributes: []
649 | num_frames: 86
650 | set: test-dev
651 | eval_t: False
652 | year: 2017
653 |
654 | - name: schoolgirls
655 | attributes: []
656 | num_frames: 80
657 | set: train
658 | eval_t: False
659 | year: 2017
660 |
661 | - name: scooter-black
662 | attributes: [SV, IO, HO, EA]
663 | num_frames: 43
664 | set: val
665 | eval_t: True
666 | year: 2016
667 |
668 | - name: scooter-board
669 | attributes: []
670 | num_frames: 91
671 | set: train
672 | eval_t: False
673 | year: 2017
674 |
675 | - name: scooter-gray
676 | attributes: [SC, FM, IO, OCC, HO, ROT, EA, BC]
677 | num_frames: 75
678 | set: train
679 | eval_t: False
680 | year: 2016
681 |
682 | - name: seasnake
683 | attributes: []
684 | num_frames: 80
685 | set: test-dev
686 | eval_t: False
687 | year: 2017
688 |
689 | - name: sheep
690 | attributes: []
691 | num_frames: 68
692 | set: train
693 | eval_t: False
694 | year: 2017
695 |
696 | - name: shooting
697 | attributes: []
698 | num_frames: 40
699 | set: val
700 | eval_t: False
701 | year: 2017
702 |
703 | - name: skate-jump
704 | attributes: []
705 | num_frames: 68
706 | set: test-dev
707 | eval_t: False
708 | year: 2017
709 |
710 | - name: skate-park
711 | attributes: []
712 | num_frames: 80
713 | set: train
714 | eval_t: False
715 | year: 2017
716 |
717 | - name: slackline
718 | attributes: []
719 | num_frames: 60
720 | set: test-dev
721 | eval_t: False
722 | year: 2017
723 |
724 | - name: snowboard
725 | attributes: []
726 | num_frames: 66
727 | set: train
728 | eval_t: False
729 | year: 2017
730 |
731 | - name: soapbox
732 | attributes: [SV, IO, MB, DEF, HO, ROT, AC]
733 | num_frames: 99
734 | set: val
735 | eval_t: True
736 | year: 2016
737 |
738 | - name: soccerball
739 | attributes: [LR, FM, MB, OCC, HO]
740 | num_frames: 48
741 | set: train
742 | eval_t: False
743 | year: 2016
744 |
745 | - name: stroller
746 | attributes: [SC, FM, CS, IO, DEF, HO]
747 | num_frames: 91
748 | set: train
749 | eval_t: True
750 | year: 2016
751 |
752 | - name: stunt
753 | attributes: []
754 | num_frames: 71
755 | set: train
756 | eval_t: False
757 | year: 2017
758 |
759 | - name: subway
760 | attributes: []
761 | num_frames: 88
762 | set: test-dev
763 | eval_t: False
764 | year: 2017
765 |
766 | - name: surf
767 | attributes: [SV, FM, CS, IO, DB, HO, OV]
768 | num_frames: 55
769 | set: train
770 | eval_t: True
771 | year: 2016
772 |
773 | - name: swing
774 | attributes: [SC, FM, IO, DEF, OCC, HO]
775 | num_frames: 60
776 | set: train
777 | eval_t: False
778 | year: 2016
779 |
780 | - name: tandem
781 | attributes: []
782 | num_frames: 72
783 | set: test-dev
784 | eval_t: False
785 | year: 2017
786 |
787 |
788 | - name: tennis
789 | attributes: [SV, FM, IO, MB, DEF, HO]
790 | num_frames: 70
791 | set: train
792 | eval_t: False
793 | year: 2016
794 |
795 | - name: tennis-vest
796 | attributes: []
797 | num_frames: 75
798 | set: test-dev
799 | eval_t: False
800 | year: 2017
801 |
802 | - name: tractor
803 | attributes: []
804 | num_frames: 65
805 | set: test-dev
806 | eval_t: False
807 | year: 2017
808 |
809 | - name: tractor-sand
810 | attributes: []
811 | num_frames: 76
812 | set: train
813 | eval_t: False
814 | year: 2017
815 |
816 | - name: train
817 | attributes: [SC, HO, EA]
818 | num_frames: 80
819 | set: train
820 | eval_t: True
821 | year: 2016
822 |
823 | - name: tuk-tuk
824 | attributes: []
825 | num_frames: 59
826 | set: train
827 | eval_t: False
828 | year: 2017
829 |
830 | - name: upside-down
831 | attributes: []
832 | num_frames: 65
833 | set: train
834 | eval_t: False
835 | year: 2017
836 |
837 | - name: varanus-cage
838 | attributes: []
839 | num_frames: 67
840 | set: train
841 | eval_t: False
842 | year: 2017
843 |
844 | - name: walking
845 | attributes: []
846 | num_frames: 72
847 | set: train
848 | eval_t: False
849 | year: 2017
850 |
--------------------------------------------------------------------------------
/libs/data/palette.txt:
--------------------------------------------------------------------------------
1 | 0 0 0
2 | 128 0 0
3 | 0 128 0
4 | 128 128 0
5 | 0 0 128
6 | 128 0 128
7 | 0 128 128
8 | 128 128 128
9 | 64 0 0
10 | 191 0 0
11 | 64 128 0
12 | 191 128 0
13 | 64 0 128
14 | 191 0 128
15 | 64 128 128
16 | 191 128 128
17 | 0 64 0
18 | 128 64 0
19 | 0 191 0
20 | 128 191 0
21 | 0 64 128
22 | 128 64 128
23 | 22 22 22
24 | 23 23 23
25 | 24 24 24
26 | 25 25 25
27 | 26 26 26
28 | 27 27 27
29 | 28 28 28
30 | 29 29 29
31 | 30 30 30
32 | 31 31 31
33 | 32 32 32
34 | 33 33 33
35 | 34 34 34
36 | 35 35 35
37 | 36 36 36
38 | 37 37 37
39 | 38 38 38
40 | 39 39 39
41 | 40 40 40
42 | 41 41 41
43 | 42 42 42
44 | 43 43 43
45 | 44 44 44
46 | 45 45 45
47 | 46 46 46
48 | 47 47 47
49 | 48 48 48
50 | 49 49 49
51 | 50 50 50
52 | 51 51 51
53 | 52 52 52
54 | 53 53 53
55 | 54 54 54
56 | 55 55 55
57 | 56 56 56
58 | 57 57 57
59 | 58 58 58
60 | 59 59 59
61 | 60 60 60
62 | 61 61 61
63 | 62 62 62
64 | 63 63 63
65 | 64 64 64
66 | 65 65 65
67 | 66 66 66
68 | 67 67 67
69 | 68 68 68
70 | 69 69 69
71 | 70 70 70
72 | 71 71 71
73 | 72 72 72
74 | 73 73 73
75 | 74 74 74
76 | 75 75 75
77 | 76 76 76
78 | 77 77 77
79 | 78 78 78
80 | 79 79 79
81 | 80 80 80
82 | 81 81 81
83 | 82 82 82
84 | 83 83 83
85 | 84 84 84
86 | 85 85 85
87 | 86 86 86
88 | 87 87 87
89 | 88 88 88
90 | 89 89 89
91 | 90 90 90
92 | 91 91 91
93 | 92 92 92
94 | 93 93 93
95 | 94 94 94
96 | 95 95 95
97 | 96 96 96
98 | 97 97 97
99 | 98 98 98
100 | 99 99 99
101 | 100 100 100
102 | 101 101 101
103 | 102 102 102
104 | 103 103 103
105 | 104 104 104
106 | 105 105 105
107 | 106 106 106
108 | 107 107 107
109 | 108 108 108
110 | 109 109 109
111 | 110 110 110
112 | 111 111 111
113 | 112 112 112
114 | 113 113 113
115 | 114 114 114
116 | 115 115 115
117 | 116 116 116
118 | 117 117 117
119 | 118 118 118
120 | 119 119 119
121 | 120 120 120
122 | 121 121 121
123 | 122 122 122
124 | 123 123 123
125 | 124 124 124
126 | 125 125 125
127 | 126 126 126
128 | 127 127 127
129 | 128 128 128
130 | 129 129 129
131 | 130 130 130
132 | 131 131 131
133 | 132 132 132
134 | 133 133 133
135 | 134 134 134
136 | 135 135 135
137 | 136 136 136
138 | 137 137 137
139 | 138 138 138
140 | 139 139 139
141 | 140 140 140
142 | 141 141 141
143 | 142 142 142
144 | 143 143 143
145 | 144 144 144
146 | 145 145 145
147 | 146 146 146
148 | 147 147 147
149 | 148 148 148
150 | 149 149 149
151 | 150 150 150
152 | 151 151 151
153 | 152 152 152
154 | 153 153 153
155 | 154 154 154
156 | 155 155 155
157 | 156 156 156
158 | 157 157 157
159 | 158 158 158
160 | 159 159 159
161 | 160 160 160
162 | 161 161 161
163 | 162 162 162
164 | 163 163 163
165 | 164 164 164
166 | 165 165 165
167 | 166 166 166
168 | 167 167 167
169 | 168 168 168
170 | 169 169 169
171 | 170 170 170
172 | 171 171 171
173 | 172 172 172
174 | 173 173 173
175 | 174 174 174
176 | 175 175 175
177 | 176 176 176
178 | 177 177 177
179 | 178 178 178
180 | 179 179 179
181 | 180 180 180
182 | 181 181 181
183 | 182 182 182
184 | 183 183 183
185 | 184 184 184
186 | 185 185 185
187 | 186 186 186
188 | 187 187 187
189 | 188 188 188
190 | 189 189 189
191 | 190 190 190
192 | 191 191 191
193 | 192 192 192
194 | 193 193 193
195 | 194 194 194
196 | 195 195 195
197 | 196 196 196
198 | 197 197 197
199 | 198 198 198
200 | 199 199 199
201 | 200 200 200
202 | 201 201 201
203 | 202 202 202
204 | 203 203 203
205 | 204 204 204
206 | 205 205 205
207 | 206 206 206
208 | 207 207 207
209 | 208 208 208
210 | 209 209 209
211 | 210 210 210
212 | 211 211 211
213 | 212 212 212
214 | 213 213 213
215 | 214 214 214
216 | 215 215 215
217 | 216 216 216
218 | 217 217 217
219 | 218 218 218
220 | 219 219 219
221 | 220 220 220
222 | 221 221 221
223 | 222 222 222
224 | 223 223 223
225 | 224 224 224
226 | 225 225 225
227 | 226 226 226
228 | 227 227 227
229 | 228 228 228
230 | 229 229 229
231 | 230 230 230
232 | 231 231 231
233 | 232 232 232
234 | 233 233 233
235 | 234 234 234
236 | 235 235 235
237 | 236 236 236
238 | 237 237 237
239 | 238 238 238
240 | 239 239 239
241 | 240 240 240
242 | 241 241 241
243 | 242 242 242
244 | 243 243 243
245 | 244 244 244
246 | 245 245 245
247 | 246 246 246
248 | 247 247 247
249 | 248 248 248
250 | 249 249 249
251 | 250 250 250
252 | 251 251 251
253 | 252 252 252
254 | 253 253 253
255 | 254 254 254
256 | 255 255 255
257 |
--------------------------------------------------------------------------------
/libs/loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import random
4 | import scipy.io
5 | from PIL import Image
6 | import cv2
7 | import torch
8 | from os.path import exists, join, split
9 | import libs.transforms_multi as transforms
10 | from torchvision import datasets
11 |
12 | def video_loader(video_path, frame_end, step, frame_start=0):
13 | cap = cv2.VideoCapture(video_path)
14 | cap.set(1, frame_start - 1)
15 | video = []
16 | for i in range(frame_start - 1, frame_end, step):
17 | cap.set(1, i)
18 | success, image = cap.read()
19 | if not success:
20 | raise Exception('Error while reading video {}'.format(video_path))
21 | pil_im = image
22 | video.append(pil_im)
23 | return video
24 |
25 |
26 | def framepair_loader(video_path, frame_start, frame_end):
27 |
28 | cap = cv2.VideoCapture(video_path)
29 |
30 | pair = []
31 | id_ = np.zeros(2)
32 | frame_num = frame_end - frame_start
33 | if frame_end > 50:
34 | id_[0] = random.randint(frame_start, frame_end-50)
35 | id_[1] = id_[0] + random.randint(1, 50)
36 | else:
37 | id_[0] = random.randint(frame_start, frame_end)
38 | id_[1] = random.randint(frame_start, frame_end)
39 |
40 |
41 | for ii in range(2):
42 |
43 | cap.set(1, id_[ii])
44 |
45 | success, image = cap.read()
46 |
47 | if not success:
48 | print("id, frame_end:", id_, frame_end)
49 | raise Exception('Error while reading video {}'.format(video_path))
50 |
51 | h,w,_ = image.shape
52 | h = (h // 64) * 64
53 | w = (w // 64) * 64
54 | image = cv2.resize(image, (w,h))
55 | image = image.astype(np.uint8)
56 | pil_im = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
57 | pair.append(pil_im)
58 |
59 | return pair
60 |
61 | def video_frame_counter(video_path):
62 | cap = cv2.VideoCapture(video_path)
63 | return cap.get(7)
64 |
65 |
66 | class VidListv1(torch.utils.data.Dataset):
67 | # for warm up, random crop both
68 | def __init__(self, video_path, list_path, patch_size, rotate = 10, scale=1.2, is_train=True, moreaug= True):
69 | super(VidListv1, self).__init__()
70 | self.data_dir = video_path
71 | self.list_path = list_path
72 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128))
73 |
74 | t = []
75 | if rotate > 0:
76 | t.append(transforms.RandomRotate(rotate))
77 | if scale > 0:
78 | t.append(transforms.RandomScale(scale))
79 | t.extend([transforms.RandomCrop(patch_size, seperate =moreaug), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
80 | normalize])
81 |
82 | self.transforms = transforms.Compose(t)
83 |
84 | self.is_train = is_train
85 | self.read_list()
86 |
87 | def __getitem__(self, idx):
88 | while True:
89 | video_ = self.list[idx]
90 | frame_end = video_frame_counter(video_)-1
91 | if frame_end <=0:
92 | print("Empty video {}, skip to the next".format(self.list[idx]))
93 | idx += 1
94 | else:
95 | break
96 |
97 | pair_ = framepair_loader(video_, 0, frame_end)
98 | data = list(self.transforms(*pair_))
99 | return tuple(data)
100 |
101 | def __len__(self):
102 | return len(self.list)
103 |
104 | def read_list(self):
105 | path = join(self.list_path)
106 | root = path.partition("Kinetices/")[0]
107 | if not exists(path):
108 | raise Exception("{} does not exist in kinet_dataset.py.".format(path))
109 | self.list = [line.replace("/Data/", root).strip() for line in open(path, 'r')]
110 |
111 |
112 | class VidListv2(torch.utils.data.Dataset):
113 | # for localization, random crop frame1
114 | def __init__(self, video_path, list_path, patch_size, window_len, rotate = 10, scale = 1.2, full_size = 640, is_train=True):
115 | super(VidListv2, self).__init__()
116 | self.data_dir = video_path
117 | self.list_path = list_path
118 | self.window_len = window_len
119 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128))
120 | self.transforms1 = transforms.Compose([
121 | transforms.RandomRotate(rotate),
122 | # transforms.RandomScale(scale),
123 | transforms.ResizeandPad(full_size),
124 | transforms.RandomCrop(patch_size),
125 | transforms.ToTensor(),
126 | normalize])
127 | self.transforms2 = transforms.Compose([
128 | transforms.ResizeandPad(full_size),
129 | transforms.ToTensor(),
130 | normalize])
131 | self.is_train = is_train
132 | self.read_list()
133 |
134 | def __getitem__(self, idx):
135 | while True:
136 | video_ = self.list[idx]
137 | frame_end = video_frame_counter(video_)-1
138 | if frame_end <=0:
139 | print("Empty video {}, skip to the next".format(self.list[idx]))
140 | idx += 1
141 | else:
142 | break
143 |
144 | pair_ = framepair_loader(video_, 0, frame_end)
145 | data1 = list(self.transforms1(*pair_))
146 | data2 = list(self.transforms2(*pair_))
147 | if self.window_len == 2:
148 | data = [data1[0],data2[1]]
149 | else:
150 | data = [data1[0],data2[1], data2[2]]
151 | return tuple(data)
152 |
153 | def __len__(self):
154 | return len(self.list)
155 |
156 | def read_list(self):
157 | path = join(self.list_path)
158 | root = path.partition("Kinetices/")[0]
159 | if not exists(path):
160 | raise Exception("{} does not exist in kinet_dataset.py.".format(path))
161 | self.list = [line.replace("/Data/", root).strip() for line in open(path, 'r')]
162 |
163 |
164 |
165 | if __name__ == '__main__':
166 |     # Quick smoke test. VidListv2 builds its transform pipeline internally
167 |     # (rotate, resize-and-pad, random crop, normalize), so it only needs the
168 |     # video root, the list file, the crop size and the temporal window length.
169 |     dataset_train = VidListv2('/home/xtli/DATA/compress/train_256/',
170 |                               '/home/xtli/DATA/compress/train.txt',
171 |                               patch_size=256, window_len=2)
176 |
177 | train_loader = torch.utils.data.DataLoader(dataset_train,
178 | batch_size = 16,
179 | shuffle = True,
180 | num_workers=8,
181 | drop_last=True)
182 |
183 | start_time = time.time()
184 | for i, (frames) in enumerate(train_loader):
185 | print(i)
186 | if(i >= 1000):
187 | break
188 | end_time = time.time()
189 | print((end_time - start_time) / 1000)
190 |
--------------------------------------------------------------------------------
/libs/loss.py:
--------------------------------------------------------------------------------
1 | from numpy.random import randint
2 | import torch
3 | import torch.nn as nn
4 |
5 | def weight_L1(pred, gt, w_a, w_b, mask=None):
6 | l1 = nn.L1Loss()
7 | if mask is not None:
8 | pred = pred * mask.repeat(1,pred.size(1),1,1)
9 | gt = gt * mask.repeat(1,gt.size(1),1,1)
10 | l_l = l1(pred[:,0,:,:], gt[:,0,:,:])
11 | l_a = l1(pred[:,1,:,:], gt[:,1,:,:])
12 | l_b = l1(pred[:,2,:,:], gt[:,2,:,:])
13 | loss = l_l + l_a * w_a + l_b * w_b
14 | return loss
15 |
16 | def weightsingle_L1(pred, gt, w_a, w_b):
17 | n,c,h,w = pred.size()
18 | l_l = torch.abs(pred[:,0,:,:] - gt[:,0,:,:])
19 | l_a = torch.abs(pred[:,1,:,:] - gt[:,1,:,:])
20 | l_b = torch.abs(pred[:,2,:,:] - gt[:,2,:,:])
21 |
22 | l_l = torch.mean(l_l.view(n,-1),dim=1)
23 | l_a = torch.mean(l_a.view(n,-1),dim=1)
24 | l_b = torch.mean(l_b.view(n,-1),dim=1)
25 | loss = l_l + l_a * w_a + l_b * w_b
26 | return loss
27 |
28 | def switch_L1_thr(pred2, frame2_var, pred1, frame1_var, w_a, w_b, thr):
29 | loss1 = weightsingle_L1(pred2, frame2_var, w_a, w_b)
30 | loss2 = weightsingle_L1(pred1, frame1_var, w_a, w_b)
31 | loss = loss1 + 0.1 * loss2
32 | loss[loss1 > thr] = 0
33 | return torch.mean(loss)
34 |
35 | def L1_thr(pred2, frame2_var, w_a, w_b, thr):
36 | loss = weightsingle_L1(pred2, frame2_var, w_a, w_b)
37 | loss[loss > thr] = 0
38 | return torch.mean(loss)
39 |
40 | # merge the above codes
41 | def L1_loss(pred2, frame2_var, w_a, w_b, thr=None, pred1=None, frame1_var=None):
42 | if pred1 is None:
43 | loss = weightsingle_L1(pred2, frame2_var, w_a, w_b)
44 | else:
45 | loss1 = weightsingle_L1(pred2, frame2_var, w_a, w_b)
46 | loss2 = weightsingle_L1(pred1, frame1_var, w_a, w_b)
47 | loss = loss1 + 0.1 * loss2
48 | if thr is not None:
49 | loss[loss > thr] = 0
50 | return torch.mean(loss)
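51 |
52 | # Usage sketch (the weights and threshold below are illustrative, not the
53 | # values used for training): the L, a and b channels are weighted separately,
54 | # e.g.
55 | #   loss = L1_loss(pred2, frame2, w_a=0.5, w_b=0.5, thr=1.0,
56 | #                  pred1=pred1, frame1_var=frame1)
57 | # computes the bidirectional ("switch") variant with outlier truncation.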
--------------------------------------------------------------------------------
/libs/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 |
5 | from libs.utils import *
6 | from libs.autoencoder import encoder3, decoder3, encoder_res18
7 |
8 | class NLM_woSoft(nn.Module):
9 | """
10 | Non-local mean layer w/o softmax on affinity
11 | """
12 | def __init__(self):
13 | super(NLM_woSoft, self).__init__()
14 |
15 | def forward(self, in1, in2):
16 | n,c,h,w = in1.size()
17 | in1 = in1.view(n,c,-1)
18 | in2 = in2.view(n,c,-1)
19 |
20 | affinity = torch.bmm(in1.permute(0,2,1), in2)
21 | return affinity
22 |
23 | def transform(aff, frame1):
24 | """
25 | Given aff, copy from frame1 to construct frame2.
26 | INPUTS:
27 | - aff: (h*w)*(h*w) affinity matrix
28 | - frame1: n*c*h*w feature map
29 | """
30 | b,c,h,w = frame1.size()
31 | frame1 = frame1.view(b,c,-1)
32 | frame2 = torch.bmm(frame1, aff)
33 | return frame2.view(b,c,h,w)
34 |
35 | class normalize(nn.Module):
36 | """Given mean: (R, G, B) and std: (R, G, B),
37 | will normalize each channel of the torch.*Tensor, i.e.
38 | channel = (channel - mean) / std
39 | """
40 |
41 | def __init__(self, mean, std = (1.0,1.0,1.0)):
42 | super(normalize, self).__init__()
43 | self.mean = nn.Parameter(torch.FloatTensor(mean).cuda(), requires_grad=False)
44 | self.std = nn.Parameter(torch.FloatTensor(std).cuda(), requires_grad=False)
45 |
46 | def forward(self, frames):
47 | b,c,h,w = frames.size()
48 | frames = (frames - self.mean.view(1,3,1,1).repeat(b,1,h,w))/self.std.view(1,3,1,1).repeat(b,1,h,w)
49 | return frames
50 |
51 | def create_flat_grid(F_size, GPU=True):
52 | """
53 | INPUTS:
54 | - F_size: feature size
55 | OUTPUT:
56 | - return a standard grid coordinate
57 | """
58 | b, c, h, w = F_size
59 | theta = torch.tensor([[1,0,0],[0,1,0]])
60 | theta = theta.unsqueeze(0).repeat(b,1,1)
61 | theta = theta.float()
62 |
63 |     # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
64 | # b * (h*w) * 2
65 | grid = torch.nn.functional.affine_grid(theta, F_size)
66 | grid[:,:,:,0] = (grid[:,:,:,0]+1)/2 * w
67 | grid[:,:,:,1] = (grid[:,:,:,1]+1)/2 * h
68 | grid_flat = grid.view(b,-1,2)
69 | if(GPU):
70 | grid_flat = grid_flat.cuda()
71 | return grid_flat
72 |
73 |
74 | class track_match_comb(nn.Module):
75 | def __init__(self, pretrained, encoder_dir = None, decoder_dir = None, temp=1, color_switch=True, coord_switch=True):
76 | super(track_match_comb, self).__init__()
77 |
78 | self.gray_encoder = encoder_res18(pretrained=pretrained, uselayer=4)
79 | self.rgb_encoder = encoder3(reduce=True)
80 | self.decoder = decoder3(reduce=True)
81 |
82 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
83 | self.decoder.load_state_dict(torch.load(decoder_dir))
84 | for param in self.decoder.parameters():
85 | param.requires_grad = False
86 | for param in self.rgb_encoder.parameters():
87 | param.requires_grad = False
88 |
89 | self.nlm = NLM_woSoft()
90 | self.normalize = normalize(mean=[0.485, 0.456, 0.406],
91 | std=[0.229, 0.224, 0.225])
92 | self.softmax = nn.Softmax(dim=1)
93 | self.temp = temp
94 | self.grid_flat = None
95 | self.grid_flat_crop = None
96 | self.color_switch = color_switch
97 | self.coord_switch = coord_switch
98 |
99 |
100 | def forward(self, img_ref, img_tar, warm_up=True, patch_size=None):
101 | n, c, h_ref, w_ref = img_ref.size()
102 | n, c, h_tar, w_tar = img_tar.size()
103 | gray_ref = copy.deepcopy(img_ref[:,0].view(n,1,h_ref,w_ref).repeat(1,3,1,1))
104 | gray_tar = copy.deepcopy(img_tar[:,0].view(n,1,h_tar,w_tar).repeat(1,3,1,1))
105 |
106 | gray_ref = (gray_ref + 1) / 2
107 | gray_tar = (gray_tar + 1) / 2
108 |
109 | gray_ref = self.normalize(gray_ref)
110 | gray_tar = self.normalize(gray_tar)
111 |
112 | Fgray1 = self.gray_encoder(gray_ref)
113 | Fgray2 = self.gray_encoder(gray_tar)
114 | Fcolor1 = self.rgb_encoder(img_ref)
115 |
116 | output = []
117 |
118 | if warm_up:
119 | aff = self.nlm(Fgray1, Fgray2)
120 | aff_norm = self.softmax(aff)
121 | Fcolor2_est = transform(aff_norm, Fcolor1)
122 | color2_est = self.decoder(Fcolor2_est)
123 |
124 | output.append(color2_est)
125 | output.append(aff)
126 |
127 | if self.color_switch:
128 | Fcolor2 = self.rgb_encoder(img_tar)
129 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2)
130 | color1_est = self.decoder(Fcolor1_est)
131 | output.append(color1_est)
132 | else:
133 | if(self.grid_flat is None):
134 | self.grid_flat = create_flat_grid(Fgray2.size())
135 | aff_ref_tar = self.nlm(Fgray1, Fgray2)
136 | aff_ref_tar = torch.nn.functional.softmax(aff_ref_tar * self.temp, dim = 2)
137 | coords = torch.bmm(aff_ref_tar, self.grid_flat)
138 | new_c = coords2bbox(coords, patch_size, h_tar, w_tar)
139 | Fgray2_crop = diff_crop(Fgray2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0])
140 |
141 | aff_p = self.nlm(Fgray1, Fgray2_crop)
142 | aff_norm = self.softmax(aff_p * self.temp)
143 | Fcolor2_est = transform(aff_norm, Fcolor1)
144 | color2_est = self.decoder(Fcolor2_est)
145 |
146 | Fcolor2_full = self.rgb_encoder(img_tar)
147 | Fcolor2_crop = diff_crop(Fcolor2_full, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0])
148 |
149 | output.append(color2_est)
150 | output.append(Fcolor2_crop)
151 | output.append(aff_p)
152 | output.append(new_c*8)
153 | output.append(coords)
154 |
155 |             # color orthogonal term
156 | if self.color_switch:
157 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2_crop)
158 | color1_est = self.decoder(Fcolor1_est)
159 | output.append(color1_est)
160 |
161 |             # coord orthogonal term
162 | if self.coord_switch:
163 | aff_norm_tran = self.softmax(aff_p.permute(0,2,1)*self.temp)
164 | if self.grid_flat_crop is None:
165 |                     self.grid_flat_crop = create_flat_grid(Fgray1.size()).permute(0,2,1).detach()  # grid over the reference feature map (same patch-sized resolution as the crop here)
166 | C12 = torch.bmm(self.grid_flat_crop, aff_norm)
167 | C11 = torch.bmm(C12, aff_norm_tran)
168 | output.append(self.grid_flat_crop)
169 | output.append(C11)
170 |
171 | # return pred1, pred2, aff_p, new_c * 8, self.grid_flat_crop, C11, coords
172 | return output
173 |
174 |
175 | class Model_switchGTfixdot_swCC_Res(nn.Module):
176 | def __init__(self, encoder_dir = None, decoder_dir = None,
177 | temp = None, pretrainRes = False, uselayer=4):
178 | '''
179 |         For switchable concentration loss
180 | Using Resnet18
181 | '''
182 | super(Model_switchGTfixdot_swCC_Res, self).__init__()
183 | self.gray_encoder = encoder_res18(pretrained = pretrainRes, uselayer=uselayer)
184 | self.rgb_encoder = encoder3(reduce = True)
185 | self.nlm = NLM_woSoft()
186 | self.decoder = decoder3(reduce = True)
187 | self.temp = temp
188 | self.softmax = nn.Softmax(dim=1)
189 | self.cos_window = torch.Tensor(np.outer(np.hanning(40), np.hanning(40))).cuda()
190 | self.normalize = normalize(mean=[0.485, 0.456, 0.406],
191 | std=[0.229, 0.224, 0.225])
192 |
193 | if(not encoder_dir is None):
194 | print("Using pretrained encoders: %s."%encoder_dir)
195 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
196 | if(not decoder_dir is None):
197 | print("Using pretrained decoders: %s."%decoder_dir)
198 | self.decoder.load_state_dict(torch.load(decoder_dir))
199 |
200 | for param in self.decoder.parameters():
201 | param.requires_grad = False
202 | for param in self.rgb_encoder.parameters():
203 | param.requires_grad = False
204 |
205 | def forward(self, gray1, gray2, color1=None, color2=None):
206 | gray1 = (gray1 + 1) / 2
207 | gray2 = (gray2 + 1) / 2
208 |
209 | gray1 = self.normalize(gray1)
210 | gray2 = self.normalize(gray2)
211 |
212 |
213 | Fgray1 = self.gray_encoder(gray1)
214 | Fgray2 = self.gray_encoder(gray2)
215 |
216 | aff = self.nlm(Fgray1, Fgray2)
217 | aff_norm = self.softmax(aff*self.temp)
218 |
219 | if(color1 is None):
220 | # for testing
221 | return aff_norm, Fgray1, Fgray2
222 |
223 | Fcolor1 = self.rgb_encoder(color1)
224 | Fcolor2 = self.rgb_encoder(color2)
225 | Fcolor2_est = transform(aff_norm, Fcolor1)
226 | pred2 = self.decoder(Fcolor2_est)
227 |
228 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2)
229 | pred1 = self.decoder(Fcolor1_est)
230 |
231 | return pred1, pred2, aff_norm, aff, Fgray1, Fgray2
232 |
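233 | # Usage sketch for test time (paths, temperature and k are illustrative; the
234 | # actual wiring lives in test.py, which is not included in this listing).
235 | # With no color frames passed, forward() returns the normalized affinity,
236 | # which can drive mask propagation via libs.test_utils.transform_topk:
237 | #
238 | #   model = Model_switchGTfixdot_swCC_Res(encoder_dir="weights/encoder_single_gpu.pth",
239 | #                                         decoder_dir="weights/decoder_single_gpu.pth",
240 | #                                         temp=1).cuda()
241 | #   aff_norm, Fgray1, Fgray2 = model(gray1, gray2)   # test mode: color1/color2 omitted
242 | #   mask2 = transform_topk(aff_norm, mask1, k=5)     # propagate the reference mask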
--------------------------------------------------------------------------------
/libs/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152']
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | def conv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 | def __init__(self, inplanes, planes, stride=1, downsample=None):
33 | super(BasicBlock, self).__init__()
34 | self.conv1 = conv3x3(inplanes, planes, stride)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | identity = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | identity = self.downsample(x)
54 |
55 | out += identity
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = conv1x1(inplanes, planes)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = conv3x3(planes, planes, stride)
69 | self.bn2 = nn.BatchNorm2d(planes)
70 | self.conv3 = conv1x1(planes, planes * self.expansion)
71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | identity = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | identity = self.downsample(x)
92 |
93 | out += identity
94 | out = self.relu(out)
95 |
96 | return out
97 |
98 |
99 | class ResNet(nn.Module):
100 |
101 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, uselayer=3):
102 | super(ResNet, self).__init__()
103 | self.inplanes = 64
104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
105 | bias=False)
106 | self.bn1 = nn.BatchNorm2d(64)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
109 | self.layer1 = self._make_layer(block, 64, layers[0])
110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
112 | if uselayer==4:
113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
114 | self.uselayer = uselayer
115 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
116 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
117 |
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
121 | elif isinstance(m, nn.BatchNorm2d):
122 | nn.init.constant_(m.weight, 1)
123 | nn.init.constant_(m.bias, 0)
124 |
125 | # Zero-initialize the last BN in each residual branch,
126 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
127 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
128 | if zero_init_residual:
129 | for m in self.modules():
130 | if isinstance(m, Bottleneck):
131 | nn.init.constant_(m.bn3.weight, 0)
132 | elif isinstance(m, BasicBlock):
133 | nn.init.constant_(m.bn2.weight, 0)
134 |
135 | def _make_layer(self, block, planes, blocks, stride=1):
136 | downsample = None
137 | if stride != 1 or self.inplanes != planes * block.expansion:
138 | downsample = nn.Sequential(
139 | conv1x1(self.inplanes, planes * block.expansion, stride),
140 | nn.BatchNorm2d(planes * block.expansion),
141 | )
142 |
143 | layers = []
144 | layers.append(block(self.inplanes, planes, stride, downsample))
145 | self.inplanes = planes * block.expansion
146 | for _ in range(1, blocks):
147 | layers.append(block(self.inplanes, planes))
148 |
149 | return nn.Sequential(*layers)
150 |
151 | def forward(self, x):
152 | x = self.conv1(x)
153 | x = self.bn1(x)
154 | x = self.relu(x)
155 | x = self.maxpool(x)
156 |
157 | x = self.layer1(x)
158 | x = self.layer2(x)
159 | x = self.layer3(x)
160 | if self.uselayer == 4:
161 | x = self.layer4(x)
162 |
163 | # x = self.avgpool(x)
164 | # x = x.view(x.size(0), -1)
165 | # x = self.fc(x)
166 |
167 | return x
168 |
169 |
170 | def resnet18(pretrained=False, **kwargs):
171 | """Constructs a ResNet-18 model.
172 |
173 | Args:
174 | pretrained (bool): If True, returns a model pre-trained on ImageNet
175 | """
176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
177 | if pretrained:
178 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir="ae_models"))
179 | return model
180 |
181 |
182 |
183 | def resnet50(pretrained=False, **kwargs):
184 | """Constructs a ResNet-50 model.
185 |
186 | Args:
187 | pretrained (bool): If True, returns a model pre-trained on ImageNet
188 | """
189 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
190 | if pretrained:
191 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
192 | return model
193 |
--------------------------------------------------------------------------------
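
A quick note on the backbone above: `layer3` (and `layer4`, when `uselayer=4`) is built with `stride=1`, so the network stops downsampling after `layer2` and the returned feature map is 1/8 of the input resolution rather than the usual 1/32. Below is a minimal sketch of probing that output stride; it assumes the repository root is on `PYTHONPATH`, and the input size is arbitrary.

```
import torch
from libs.resnet import resnet18

# build the truncated backbone; uselayer=3 stops after layer3 and skips the fc head
net = resnet18(pretrained=False, uselayer=3)
net.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 480, 832)   # any H, W the loader would produce
    feat = net(x)

# conv1 (stride 2) + maxpool (stride 2) + layer2 (stride 2) give 1/8 resolution,
# and layer3 keeps stride 1, so the expected shape is [1, 256, 60, 104]
print(feat.shape)
```
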
/libs/test_utils.py:
--------------------------------------------------------------------------------
1 | # OS libraries
2 | import os
3 | import cv2
4 | import glob
5 | import scipy.misc
6 | import numpy as np
7 | from PIL import Image
8 |
9 | # Pytorch
10 | import torch
11 |
12 | # Customized libraries
13 | import libs.transforms_pair as transforms
14 |
15 | color_palette = np.loadtxt('libs/data/palette.txt',dtype=np.uint8).reshape(-1,3)
16 |
17 | def transform_topk(aff, frame1, k, h2=None, w2=None):
18 | """
19 | INPUTS:
20 | - aff: affinity matrix, b * N * N
21 | - frame1: reference frame
22 | - k: only aggregate top-k pixels with highest aff(j,i)
23 |     - h2, w2: frame2's height & width
24 | OUTPUT:
25 | - frame2: propagated mask from frame1 to the next frame
26 | """
27 | b,c,h,w = frame1.size()
28 | b, N1, N2 = aff.size()
29 | # b * 20 * N
30 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k)
31 | # b * N
32 | tk_val_min,_ = torch.min(tk_val,dim=1)
33 | tk_val_min = tk_val_min.view(b,1,N2)
34 | aff[tk_val_min > aff] = 0
35 | frame1 = frame1.contiguous().view(b,c,-1)
36 | frame2 = torch.bmm(frame1, aff)
37 | if(h2 is None):
38 | return frame2.view(b,c,h,w)
39 | else:
40 | return frame2.view(b,c,h2,w2)
41 |
42 | def read_frame_list(video_dir):
43 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))]
44 | frame_list = sorted(frame_list)
45 | return frame_list
46 |
47 | def create_transforms():
48 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128))
49 | t = []
50 | t.extend([transforms.ToTensor(),
51 | normalize])
52 | return transforms.Compose(t)
53 |
54 | def read_frame(frame_dir, transforms, scale_size):
55 | """
56 | read a single frame & preprocess
57 | """
58 | frame = cv2.imread(frame_dir)
59 | ori_h,ori_w,_ = frame.shape
60 |     # scale: if a single value is given, set the short side to it and round the long side down to a multiple of 64
61 | if(len(scale_size) == 1):
62 | if(ori_h > ori_w):
63 | tw = scale_size[0]
64 | th = (tw * ori_h) / ori_w
65 | th = int((th // 64) * 64)
66 | else:
67 | th = scale_size[0]
68 | tw = (th * ori_w) / ori_h
69 | tw = int((tw // 64) * 64)
70 | else:
71 | tw = scale_size[1]
72 | th = scale_size[0]
73 | frame = cv2.resize(frame, (tw,th))
74 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
75 |
76 | pair = [frame, frame]
77 | transformed = list(transforms(*pair))
78 | return transformed[0].cuda().unsqueeze(0), ori_h, ori_w
79 |
80 | def to_one_hot(y_tensor, n_dims=None):
81 | """
82 | Take integer y (tensor or variable) with n dims &
83 | convert it to 1-hot representation with n+1 dims.
84 | """
85 | if(n_dims is None):
86 | n_dims = int(y_tensor.max()+ 1)
87 | _,h,w = y_tensor.size()
88 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
89 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
90 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
91 | y_one_hot = y_one_hot.view(h,w,n_dims)
92 | return y_one_hot.permute(2,0,1).unsqueeze(0)
93 |
94 | def read_seg(seg_dir, scale_size):
95 | seg = Image.open(seg_dir)
96 |     h,w = seg.size  # note: PIL's Image.size is (width, height)
97 | if(len(scale_size) == 1):
98 | if(h > w):
99 | tw = scale_size[0]
100 | th = (tw * h) / w
101 | th = int((th // 64) * 64)
102 | else:
103 | th = scale_size[0]
104 | tw = (th * w) / h
105 | tw = int((tw // 64) * 64)
106 | else:
107 | tw = scale_size[1]
108 | th = scale_size[0]
109 | seg = np.asarray(seg).reshape((w,h,1))
110 | seg_ori = np.squeeze(seg)
111 | small_seg = scipy.misc.imresize(seg_ori, (tw//8,th//8),"nearest",mode="F")
112 | large_seg = scipy.misc.imresize(seg_ori, (tw,th),"nearest",mode="F")
113 |
114 | t = []
115 | t.extend([transforms.ToTensor()])
116 | trans = transforms.Compose(t)
117 | pair = [large_seg, small_seg]
118 | transformed = list(trans(*pair))
119 | large_seg = transformed[0]
120 | small_seg = transformed[1]
121 | return to_one_hot(large_seg), to_one_hot(small_seg), seg_ori
122 |
123 | def imwrite_indexed(filename,array):
124 | """ Save indexed png for DAVIS."""
125 | if np.atleast_3d(array).shape[2] != 1:
126 | raise Exception("Saving indexed PNGs requires 2D array.")
127 |
128 | im = Image.fromarray(array)
129 | im.putpalette(color_palette.ravel())
130 | im.save(filename, format='PNG')
131 |
--------------------------------------------------------------------------------
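
The heart of `test_utils.py` is `transform_topk`: for each target pixel, only the k reference pixels with the strongest affinity are kept, and the reference mask is transferred through that sparsified affinity with a batched matrix product. Below is a toy, self-contained sketch of the same computation, using random affinities and a made-up 4x4 one-hot mask purely for illustration (it is not the repo's function itself).

```
import torch

torch.manual_seed(0)
b, c, h, w = 1, 3, 4, 4          # c one-hot channels on a tiny 4x4 grid
N = h * w
aff = torch.rand(b, N, N)        # aff[:, j, i]: affinity of reference pixel j to target pixel i
mask_ref = torch.zeros(b, c, h, w)
mask_ref[:, 1, 1:3, 1:3] = 1     # a small square of "object 1" in the reference frame

k = 5
tk_val, _ = torch.topk(aff, k=k, dim=1)                  # k strongest reference pixels per target pixel
tk_val_min, _ = torch.min(tk_val, dim=1)
aff = aff * (aff >= tk_val_min.view(b, 1, N)).float()    # zero out everything below the k-th value

# propagate the one-hot mask through the sparsified affinity
mask_tar = torch.bmm(mask_ref.view(b, c, -1), aff).view(b, c, h, w)
print(mask_tar.argmax(dim=1))
```
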
/libs/track_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import copy
4 | import torch
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | import torchvision.utils as vutils
9 |
10 | from libs.vis_utils import norm_mask
11 |
12 | ############################# GLOBAL VARIABLES ########################
13 | color_platte = [[0, 0, 0],[128, 0, 0],[0, 128, 0],[128, 128, 0],
14 | [0, 0, 128],[128, 0, 128],[0, 128, 128],[128, 128, 128],
15 | [64, 0, 0],[192, 0, 0],[64, 128, 0],[192, 128, 0],[64, 0, 128],
16 | [192, 0, 128],[64, 128, 128],[192, 128, 128],[0, 64, 0],
17 | [128, 64, 0],[0, 192, 0],[128, 192, 0],[0, 64, 128]]
18 | color_platte = np.array(color_platte)
19 |
20 | ############################# HELPER FUNCTIONS ########################
21 | class BBox():
22 | """
23 | bounding box class
24 | """
25 | def __init__(self, left, right, top, bottom, margin, h, w):
26 | if(margin > 0):
27 | bb_w = float(right - left)
28 | bb_h = float(bottom - top)
29 | margin_h = (bb_h * margin) / 2
30 | margin_w = (bb_w * margin) / 2
31 | left = left - margin_w
32 | right = right + margin_w
33 | top = top - margin_h
34 | bottom = bottom + margin_h
35 | self.left = max(math.floor(left), 0)
36 | self.right = min(math.ceil(right), w)
37 | self.top = max(math.floor(top), 0)
38 | self.bottom = min(math.ceil(bottom), h)
39 |
40 | def print(self):
41 | print("Left: {:n}, Right:{:n}, Top:{:n}, Bottom:{:n}".format(self.left, self.right, self.top, self.bottom))
42 |
43 | def upscale(self, scale):
44 | self.left = math.floor(self.left * scale)
45 | self.right = math.floor(self.right * scale)
46 | self.top = math.floor(self.top * scale)
47 | self.bottom = math.floor(self.bottom * scale)
48 |
49 | def add(self, bbox):
50 | self.left += bbox.left
51 | self.right += bbox.right
52 | self.top += bbox.top
53 | self.bottom += bbox.bottom
54 |
55 | def div(self, num):
56 | self.left /= num
57 | self.right /= num
58 | self.top /= num
59 | self.bottom /= num
60 |
61 | def to_one_hot(y_tensor, n_dims=9):
62 |     """
63 |     Take integer y (tensor or variable) with n dims and convert it to 1-hot representation with n+1 dims.
64 |     """
65 |     _,h,w = y_tensor.size()
66 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
67 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
68 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
69 | y_one_hot = y_one_hot.view(h,w,n_dims)
70 | return y_one_hot.permute(2,0,1).unsqueeze(0)
71 |
72 | def create_grid(F_size, GPU=True):
73 | """
74 | INPUTS:
75 | - F_size: feature size
76 | OUTPUT:
77 | - return a standard grid coordinate
78 | """
79 | b, c, h, w = F_size
80 | theta = torch.tensor([[1,0,0],[0,1,0]])
81 | theta = theta.unsqueeze(0).repeat(b,1,1)
82 | theta = theta.float()
83 |
84 |     # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
85 | # b * (h*w) * 2
86 | grid = torch.nn.functional.affine_grid(theta, F_size)
87 | if(GPU):
88 | grid = grid.cuda()
89 | return grid
90 |
91 | def seg2bbox_v2(seg, bbox_pre):
92 | """
93 | INPUTS:
94 | - seg: segmentation mask, a c*h*w one-hot tensor
95 |     OUTPUTS:
96 | - bbox: a c*4 tensor, indicating bbox for each object
97 | """
98 | seg = seg.squeeze()
99 | c,h,w = seg.size()
100 | bbox = {}
101 | bbox[0] = BBox(0, w, 0, h, 0, h, w)
102 | bbox[0].upscale(0.125)
103 | _, seg_int = torch.max(seg, dim=0)
104 | for cnt in range(1,c): # rule out background label
105 | seg_cnt = (seg_int == cnt) * 1
106 | # x * 2
107 | fg_idx = seg_cnt.nonzero().float()
108 |
109 | if(fg_idx.numel() > 0 and (bbox_pre[cnt] is not None)):
110 | fg_idx = torch.flip(fg_idx, (0,1))
111 |
112 | bbox_tmp = copy.deepcopy(bbox_pre[cnt])
113 | bbox_tmp.upscale(8)
114 | bbox[cnt] = coords2bbox_scale(fg_idx, h, w, bbox_tmp, margin=0.6, bandwidth=20)
115 |
116 | bbox[cnt].upscale(0.125)
117 | else:
118 | bbox[cnt] = None
119 | return bbox
120 |
121 | def seg2bbox(seg, margin=0,print_info=False):
122 | """
123 | INPUTS:
124 | - seg: segmentation mask, a c*h*w one-hot tensor
125 |     OUTPUTS:
126 | - bbox: a c*4 tensor, indicating bbox for each object
127 | """
128 | seg = seg.squeeze()
129 | c,h,w = seg.size()
130 | bbox = {}
131 | bbox[0] = BBox(0, w, 0, h, 0, h, w)
132 | for cnt in range(1,c): # rule out background label
133 | seg_cnt = seg[cnt]
134 | # x * 2
135 | fg_idx = seg_cnt.nonzero()
136 | if(fg_idx.numel() > 0):
137 | left = fg_idx[:,1].min()
138 | right = fg_idx[:,1].max()
139 | top = fg_idx[:,0].min()
140 | bottom = fg_idx[:,0].max()
141 | bbox[cnt] = BBox(left, right, top, bottom, margin, h, w)
142 | else:
143 | bbox[cnt] = None
144 | return bbox
145 |
146 | def gaussin(x, sigma):
147 | return torch.exp(- x ** 2 / (2 * (sigma ** 2)))
148 |
149 | def calc_center(arr, mode='mean', sigma=10):
150 | """
151 | INPUTS:
152 | - arr: an array with coordinates, shape: n
153 |     - mode: 'mean' to calculate the Euclidean center, 'mass' to calculate the mass center
154 | - sigma: Gaussian parameter if calculating mass center
155 | """
156 | eu_center = torch.mean(arr)
157 | if(mode == 'mean'):
158 | return eu_center
159 | # calculate weight center
160 | eu_center = eu_center.view(1,1).repeat(1,arr.size(0)).squeeze()
161 | diff = eu_center - arr
162 | weight = gaussin(diff, sigma)
163 | mass_center = torch.sum(weight * arr / (torch.sum(weight)))
164 | return mass_center
165 |
166 | def masked_softmax(vec, mask, dim=1, epsilon=1e-5):
167 | exps = torch.exp(vec)
168 | masked_exps = exps * mask.float()
169 | masked_sums = masked_exps.sum(dim, keepdim=True) + epsilon
170 | return (masked_exps/masked_sums)
171 |
172 | def transform_topk(aff, frame1, k):
173 | """
174 | INPUTS:
175 | - aff: affinity matrix, b * N * N
176 | - frame1: reference frame
177 | - k: only aggregate top-k pixels with highest aff(j,i)
178 | """
179 | b,c,h,w = frame1.size()
180 | b, _, N = aff.size()
181 | # b * 20 * N
182 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k)
183 | # b * N
184 | tk_val_min,_ = torch.min(tk_val, dim=1)
185 | tk_val_min = tk_val_min.view(b,1,N)
186 | aff[tk_val_min > aff] = 0
187 | aff = masked_softmax(aff, aff > 0, dim=1)
188 | frame1 = frame1.view(b,c,-1)
189 | frame2 = torch.bmm(frame1, aff)
190 | return frame2
191 |
192 | def squeeze_all(*args):
193 | res = []
194 | for arg in args:
195 | res.append(arg.squeeze())
196 | return tuple(res)
197 |
198 | def decode_seg(seg):
199 | seg = seg.squeeze()
200 | h,w = seg.size()
201 | color_seg = np.zeros((h,w,3))
202 | unis = np.unique(seg)
203 | seg = seg.numpy()
204 | for cnt,uni in enumerate(unis):
205 | xys = (seg == uni)
206 | xs,ys = np.nonzero(xys)
207 | color_seg[xs,ys,:] = color_platte[cnt].reshape(1,3)
208 | return color_seg
209 |
210 | def draw_bbox(seg,bbox,color):
211 | """
212 | INPUTS:
213 |     - seg: segmentation, an h * w * 3 numpy array
214 |     - bbox: a BBox (left, right, top, bottom); color: an RGB triple
215 | OUTPUT:
216 | - seg with a drawn bbox
217 | """
218 | seg = seg.copy()
219 | pt1 = (bbox.left,bbox.top)
220 | pt2 = (bbox.right,bbox.bottom)
221 | color = np.array(color, dtype=np.uint8)
222 | c = tuple(map(int, color))
223 | seg = cv2.rectangle(seg, pt1, pt2, c, 2)
224 | return seg
225 |
226 | def vis_bbox(img, bbox, upscale=1):
227 | """
228 | INPUTS:
229 |     - img: an h*w*c OpenCV image
230 |     - bbox: a dict of bounding boxes (one per object channel)
231 | OUTPUT:
232 | - img: image with bbox drawn on
233 | """
234 | #for cnt in range(len(bbox)):
235 | for cnt, bbox_cnt in bbox.items():
236 | #bbox_cnt = bbox[cnt]
237 | if(bbox_cnt is None):
238 | continue
239 | bbox_cnt.upscale(upscale)
240 | color = color_platte[cnt+1]
241 | img = draw_bbox(img, bbox_cnt, color)
242 | return img
243 |
244 | def vis_bbox_pair(bbox, bbox_next, frame1_var, frame2_var, out_path):
245 | frame1 = frame1_var * 128 + 128
246 | frame1 = frame1.squeeze().permute(1,2,0)
247 | im1 = cv2.cvtColor(np.array(frame1, dtype = np.uint8), cv2.COLOR_LAB2BGR)
248 |
249 | frame2 = frame2_var * 128 + 128
250 | frame2 = frame2.squeeze().permute(1,2,0)
251 | im2 = cv2.cvtColor(np.array(frame2, dtype = np.uint8), cv2.COLOR_LAB2BGR)
252 |
253 | #for i, (bbox_cnt, bbox_next_cnt) in enumerate(zip(bbox, bbox_next)):
254 | for cnt, bbox_cnt in bbox.items():
255 | bbox_next_cnt = bbox_next[cnt]
256 | bbox_cnt.upscale(8)
257 | bbox_next_cnt.upscale(8)
258 |         im1 = draw_bbox(im1, bbox_cnt, color_platte[cnt+1])
259 |         im2 = draw_bbox(im2, bbox_next_cnt, color_platte[cnt+1])
260 |
261 | im = np.concatenate((im1,im2), axis=1)
262 | cv2.imwrite(out_path, im)
263 |
264 | def clean_seg(seg, bbox, threshold):
265 | c,h,w = seg.size()
266 | fgs = {}
267 | for cnt, bbox_cnt in bbox.items():
268 | if(bbox_cnt is not None):
269 | seg_cnt = seg[cnt]
270 | fg = seg_cnt.nonzero()
271 | fgs[cnt] = fg[:,[1,0]].float()
272 | else:
273 | fgs[cnt] = None
274 | fgs = clean_coords(fgs, bbox, threshold)
275 | seg_new = torch.zeros(seg.size())
276 | for cnt, fg in fgs.items():
277 | if(fg is not None):
278 | fg = fg.long()
279 | seg_new[cnt][fg[:,1],fg[:,0]] = seg[cnt][fg[:,1],fg[:,0]]
280 | return seg_new.view(1,c,h,w)
281 |
282 | def clean_coords(coord, bbox_pre, threshold):
283 | """
284 | INPUT:
285 |     - coord: per-object dict of foreground coordinates, each an N * 2 tensor
286 |     - bbox_pre: per-object dict of bboxes in the previous frame; their centers anchor the distance test
287 |     - threshold: points whose normalized squared distance to the center exceeds this are dropped
288 | METHOD: Rule out outliers in coord
289 | """
290 | new_coord = {}
291 | for cnt, coord_cnt in coord.items():
292 | bbox_pre_cnt = bbox_pre[cnt]
293 | if((bbox_pre_cnt is not None) and (coord_cnt is not None)):
294 | center_h = (bbox_pre_cnt.top + bbox_pre_cnt.bottom)/2
295 | center_w = (bbox_pre_cnt.left + bbox_pre_cnt.right)/2
296 | dis = (coord_cnt[:,1] - center_h )**2 + (coord_cnt[:,0] - center_w)**2
297 | mean_ = torch.mean(dis)
298 | dis = dis / mean_
299 |
300 | idx_ = (dis <= threshold).nonzero()
301 | new_coord[cnt] = coord_cnt[idx_].squeeze()
302 | else:
303 | new_coord[cnt] = None
304 | return new_coord
305 |
306 | def coord2bbox(bbox_pre, coord, h, w, adaptive=False):
307 | avg_h = calc_center(coord[:,1], mode='mass')
308 | avg_w = calc_center(coord[:,0], mode='mass')
309 |
310 | if adaptive:
311 | center = torch.mean(coord, dim = 0)
312 | dis_h = coord[:,1] - center[1]
313 | dis_w = coord[:,0] - center[0]
314 |
315 | dis_h = torch.mean(dis_h * dis_h, dim = 0)
316 | bb_height = (dis_h ** 0.5) * 8
317 |
318 | dis_w = torch.mean(dis_w * dis_w, dim = 0)
319 | bb_width = (dis_w ** 0.5) * 8
320 |
321 |         # the adaptive method is sensitive to outliers; assume there is no dramatic change within a
322 |         # short range, so the height should not grow larger than 1.2 times the height in the previous frame
323 | bb_height = torch.min(bb_height, torch.Tensor([bbox_pre.bottom - bbox_pre.top]) * 1.2)
324 | bb_width = torch.min(bb_width, torch.Tensor([bbox_pre.right - bbox_pre.left]) * 1.2)
325 |
326 | bb_height = torch.max(bb_height, torch.Tensor([bbox_pre.bottom - bbox_pre.top]) * 0.8)
327 | bb_width = torch.max(bb_width, torch.Tensor([bbox_pre.right - bbox_pre.left]) * 0.8)
328 |
329 | else:
330 | bb_width = float(bbox_pre.right - bbox_pre.left)
331 | bb_height = float(bbox_pre.bottom - bbox_pre.top)
332 | left = avg_w - bb_width/2.0
333 | right = avg_w + bb_width/2.0
334 | top = avg_h - bb_height/2.0
335 | bottom = avg_h + bb_height/2.0
336 |
337 | coord_left = coord[:,0].min()
338 | coord_right = coord[:,0].max()
339 | coord_top = coord[:,1].min()
340 | coord_bottom = coord[:,1].max()
341 |
342 | bbox_tar_ = BBox(left = int(max(left, 0)),
343 | right = int(min(right, w)),
344 | top = int(max(top, 0)),
345 | bottom = int(min(bottom, h)),
346 | margin = 0, h = h, w = w)
347 | return bbox_tar_
348 |
349 | def post_process_seg(seg_pred):
350 | frame2_seg_bbox = torch.nn.functional.interpolate(seg_pred,scale_factor=8,mode='bilinear')
351 | frame2_seg_bbox = norm_mask(frame2_seg_bbox.squeeze())
352 | _, frame2_seg_bbox = torch.max(frame2_seg_bbox, dim=0)
353 | return frame2_seg_bbox
354 |
355 | def post_process_bbox(seg, bbox_pre):
356 | fg_idx = seg.nonzero()
357 |
358 | bbox_pre_cnt = bbox_pre[cnt]
359 | if(bbox_pre_cnt is not None):
360 | bbox_pre_cnt.upscale(8)
361 | center_h = (bbox_pre_cnt.top + bbox_pre_cnt.bottom) / 2
362 | center_w = (bbox_pre_cnt.left + bbox_pre_cnt.right) / 2
363 | fg_idx = clean_coord(fg_idx.float().cuda(), torch.Tensor([center_h, center_w]),keep_ratio=0.9)
364 |
365 | def scatter_point(coord, name, w, h, center=None):
366 | fig, ax = plt.subplots()
367 | ax.axis((0,w,0,h))
368 | ax.scatter(coord[:,0], h - coord[:,1])
369 |     if(center is not None): ax.scatter(center[0], h - center[1], marker='^')
370 |
371 | plt.savefig(name)
372 | plt.clf()
373 |
374 | def shift_bbox(seg, bbox):
375 | c,h,w = seg.size()
376 | bbox_new = {}
377 | for cnt in range(c):
378 | seg_cnt = seg[cnt]
379 | bbox_cnt = bbox[cnt]
380 | if(bbox_cnt is not None):
381 | fg_idx = seg_cnt.nonzero()
382 | fg_idx = fg_idx[:,[1,0]].float()
383 | center_h = calc_center(fg_idx[:,1], mode='mass')
384 | center_w = calc_center(fg_idx[:,0], mode='mass')
385 |
386 | # shift bbox w.r.t new center
387 | old_h = (bbox_cnt.top + bbox_cnt.bottom) / 2
388 | old_w = (bbox_cnt.left + bbox_cnt.right) / 2
389 | bb_width = bbox_cnt.right - bbox_cnt.left
390 | bb_height = bbox_cnt.bottom - bbox_cnt.top
391 |
392 | left = center_w - bb_width/2
393 | right = center_w + bb_width/2
394 | top = center_h - bb_height/2
395 | bottom = center_h + bb_height/2
396 | bbox_new[cnt] = BBox(left = left, right = right,
397 | top = top, bottom = bottom,
398 | margin = 0, h = h, w = w)
399 | else:
400 | bbox_new[cnt] = None
401 | return bbox_new
402 |
403 | ############################# MATCHING FUNCTIONS ########################
404 | def match_ref_tar(F_ref, F_tar, seg_ref, temp):
405 | """
406 | INPUTS:
407 | - F_ref: feature of reference frame
408 | - F_tar: feature of target frame
409 | - seg_ref: segmentation of reference frame
410 | - temp: temperature of softmax
411 | METHOD:
412 | - take foreground pixels from the reference frame and match them to the
413 | target frame.
414 | RETURNS:
415 |     - coords: a dict (per object channel) of matched foreground coordinates in the target frame.
416 | """
417 | coords = {}
418 | c, h, w = F_ref.size()
419 | F_ref_flat = F_ref.view(c, -1)
420 | F_tar_flat = F_tar.view(c, -1)
421 |
422 | grid = create_grid(F_ref.unsqueeze(0).size()).squeeze()
423 | grid[:,:,0] = (grid[:,:,0]+1)/2 * w
424 | grid[:,:,1] = (grid[:,:,1]+1)/2 * h
425 | grid_flat = grid.view(-1,2)
426 |
427 | for cnt in range(seg_ref.size(0)):
428 | seg_cnt = seg_ref[cnt, :, :].view(-1)
429 | # there's no mask for this channel
430 | if(torch.sum(seg_cnt) < 2):
431 | coords[cnt] = None
432 | continue
433 | if(cnt > 0):
434 | fg_idx = seg_cnt.nonzero()
435 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze()
436 |
437 | else:
438 | # for the background class, we just take the whole frame
439 | F_ref_cnt_flat = F_ref_flat
440 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat)
441 | aff = torch.nn.functional.softmax(aff*temp, dim = 1)
442 | coord = torch.mm(aff, grid_flat)
443 | coords[cnt] = coord
444 | return coords
445 |
446 | def weighted_center(coords, center):
447 | """
448 | in_range = []
449 | for cnt in range(coords.shape[1]):
450 | coord_i = coords[0,cnt,:]
451 | if(np.linalg.norm(coord_i - prev_center) < bandwidth):
452 | in_range.append(coord_i)
453 | in_range = np.array(in_range)
454 | new_center = np.mean(in_range, axis=0)
455 | """
456 | center = center.reshape(1, 1, 2)
457 |
458 | dis_x = np.sqrt(np.power(coords[:,:,0] - center[:,:,0], 2))
459 | weight_x = 1 / dis_x
460 | weight_x = weight_x / np.sum(weight_x)
461 | dis_y = np.sqrt(np.power(coords[:,:,1] - center[:,:,1], 2))
462 | weight_y = 1 / dis_y
463 | weight_y = weight_y / np.sum(weight_y)
464 |
465 | new_x = np.sum(weight_x * coords[:,:,0])
466 | new_y = np.sum(weight_y * coords[:,:,1])
467 |
468 | return np.array([new_x, new_y]).reshape(1,1,2)
469 |
470 | def euclid_distance(x, xi):
471 | return np.sqrt(np.sum((x - xi)**2))
472 |
473 | def neighbourhood_points(X, x_centroid, distance = 20):
474 | eligible_X = []
475 | #for x in X:
476 | for cnt in range(X.shape[0]):
477 | x = X[cnt,:]
478 | distance_between = euclid_distance(x, x_centroid)
479 | if distance_between <= distance:
480 | eligible_X.append(x)
481 | return eligible_X
482 |
483 | def gaussian_kernel(distance, bandwidth):
484 | val = (1/(bandwidth*math.sqrt(2*math.pi))) * np.exp(-0.5*((distance / bandwidth))**2)
485 | return val
486 |
487 | def mean_shift_center(coords, bandwidth=20):
488 | """
489 | INPUTS:
490 |     - coords: coordinates of pixels in the next frame, an N x 2 array
491 | """
492 | # 1 x 2
493 | avg_center = np.mean(coords, axis=0)
494 | prev_center = copy.deepcopy(avg_center)
495 |
496 | # pick the most close point as the center
497 | minimum = 100000
498 | for cnt in range(coords.shape[0]):
499 | coord_i = coords[cnt,:].reshape(1,2)
500 | dis = np.linalg.norm(coord_i - avg_center)
501 | if(dis < minimum):
502 | minimum = dis
503 | prev_center = copy.deepcopy(coord_i)
504 |
505 | counter = 0
506 | while True:
507 | counter += 1
508 | neighbors = neighbourhood_points(coords, prev_center.reshape(1,2), bandwidth)
509 | numerator = 0
510 | denominator = 0
511 | for neighbor in neighbors:
512 | distance = euclid_distance(neighbor, prev_center)
513 | weight = gaussian_kernel(distance, 4)
514 | numerator += (weight * neighbor)
515 | denominator += weight
516 | new_center = numerator / denominator
517 |
518 |         if(np.linalg.norm(new_center - prev_center) < 0.1):
519 | final_center = torch.from_numpy(new_center)
520 | return final_center.view(1,2)
521 | else:
522 | prev_center = copy.deepcopy(new_center)
523 |
524 | def coords2bbox_scale(coords, h_tar, w_tar, bbox_pre, margin, bandwidth, log=False):
525 | """
526 | INPUTS:
527 | - coords: coordinates of pixels in the next frame
528 | - h_tar: target image height
529 |     - w_tar: target image width
530 | """
531 | b = 1
532 | center = mean_shift_center(coords.numpy(), bandwidth)
533 |
534 | center_repeat = center.repeat(coords.size(0),1)
535 |
536 | dis_x = torch.sqrt(torch.pow(coords[:,0] - center_repeat[:,0], 2))
537 | dis_x = torch.mean(dis_x, dim=0).detach()
538 | dis_y = torch.sqrt(torch.pow(coords[:,1] - center_repeat[:,1], 2))
539 | dis_y = torch.mean(dis_y, dim=0).detach()
540 |
541 | left = (center[:,0] - dis_x*2).view(b,1)
542 | right = (center[:,0] + dis_x*2).view(b,1)
543 | top = (center[:,1] - dis_y*2).view(b,1)
544 | bottom = (center[:,1] + dis_y*2).view(b,1)
545 |
546 |
547 | bbox_tar_ = BBox(left = max(left, 0),
548 | right = min(right, w_tar),
549 | top = max(top, 0),
550 | bottom = min(bottom, h_tar),
551 | margin = margin, h = h_tar, w = w_tar)
552 |
553 |
554 | return bbox_tar_
555 |
556 | def coord_wrt_bbox(coord, bbox):
557 | new_coord = []
558 | left = bbox.left
559 | right = bbox.right
560 | top = bbox.top
561 | bottom = bbox.bottom
562 | for cnt in range(coord.size(0)):
563 | coord_i = coord[cnt]
564 | if(coord_i[0] >= left and coord_i[0] <= right and coord_i[1] >= top and coord_i[1] <= bottom):
565 | new_coord.append(coord_i)
566 | new_coord = torch.stack(new_coord)
567 | return new_coord
568 |
569 | def bbox_in_tar_scale(coords_tar, bbox_ref, h, w):
570 | """
571 | INPUTS:
572 | - coords_tar: foreground coordinates in the target frame
573 | - bbox_ref: bboxs in the reference frame
574 | METHOD:
575 | - calculate bbox in the next frame w.r.t pixels coordinates
576 | RETURNS:
577 | - each bbox in the tar frame
578 | """
579 | bbox_tar = {}
580 | #for cnt in range(len(bbox_ref)):
581 | for cnt, bbox_cnt in bbox_ref.items():
582 | if(cnt == 0):
583 | bbox_tar_ = BBox(left = 0,
584 | right = w,
585 | top = 0,
586 | bottom = h,
587 | margin = 0.6, h = h, w = w)
588 | elif(bbox_cnt is not None):
589 | if not (cnt in coords_tar):
590 | continue
591 | coord = coords_tar[cnt]
592 | if(coord is None):
593 | continue
594 | coord = coord.cpu()
595 |
596 | bbox_tar_ = coords2bbox_scale(coord, h, w, bbox_cnt, margin=1, bandwidth=5, log=(cnt==3))
597 | else:
598 | bbox_tar_ = None
599 |
600 | bbox_tar[cnt] = bbox_tar_
601 | return bbox_tar
602 |
603 | ############################# deprecated code, kept for safety ########################
604 | """
605 | def bbox_in_tar(coords_tar, bbox_ref, h, w):
606 | INPUTS:
607 | - coords_tar: foreground coordinates in the target frame
608 | - bbox_ref: bboxs in the reference frame
609 | METHOD:
610 | - calculate bbox in the next frame w.r.t pixels coordinates
611 | RETURNS:
612 | - each bbox in the tar frame
613 | bbox_tar = {}
614 | for cnt, bbox_cnt in bbox_ref.items():
615 | if(cnt == 0):
616 | bbox_tar_ = BBox(left = 0,
617 | right = w,
618 | top = 0,
619 | bottom = h,
620 | margin = 0.5, h = h, w = w)
621 | elif(bbox_cnt is not None):
622 | if not (cnt in coords_tar):
623 | continue
624 | coord = coords_tar[cnt]
625 | center_h = (bbox_cnt.top + bbox_cnt.bottom) / 2
626 | center_w = (bbox_cnt.left + bbox_cnt.right) / 2
627 | if(coord is None):
628 | continue
629 | coord = coord.cpu()
630 |
631 | bbox_tar_ = coord2bbox(bbox_cnt, coord, h, w, adaptive=False)
632 | else:
633 | bbox_tar_ = None
634 |
635 | bbox_tar[cnt] = bbox_tar_
636 | return bbox_tar
637 |
638 | def bbox_in_tar_v2(coords_tar, bbox_ref, h, w, seg_pre):
639 | INPUTS:
640 | - coords_tar: foreground coordinates in the target frame
641 | - bbox_ref: bboxs in the reference frame
642 | METHOD:
643 | - calculate bbox in the next frame w.r.t pixels coordinates
644 | RETURNS:
645 | - each bbox in the tar frame
646 | VERSION NOTE:
647 | - include scaling modeling
648 | bbox_tar = {}
649 | for cnt, bbox_cnt in bbox_ref.items():
650 | # for each channel
651 | if(cnt == 0):
652 | # for background channel
653 | bbox_tar_ = BBox(left = 0,
654 | right = w,
655 | top = 0,
656 | bottom = h,
657 | margin = 0.5, h = h, w = w)
658 | elif(bbox_cnt is not None):
659 | coord = coords_tar[cnt]
660 | if(coord is None):
661 | continue
662 | coord = coord.cpu()
663 |
664 | bbox_tar_ = coord2bbox(bbox_cnt, coord, h, w, seg_pre, adaptive=True)
665 | else:
666 | bbox_tar_ = None
667 |
668 | bbox_tar[cnt] = bbox_tar_
669 | return bbox_tar
670 |
671 | def recoginition(F_ref, F_tar, bbox_ref, bbox_tar, seg_ref, temp):
672 | - F_ref: feature of reference frame
673 | - F_tar: feature of target frame
674 | - bbox_ref: bboxes of reference frame
675 | - bbox_tar: bboxes of target frame
676 | - seg_ref: segmentation of reference frame
677 | c, h, w = F_tar.size()
678 | seg_pred = torch.zeros(seg_ref.size())
679 | #for cnt,(br, bt) in enumerate(zip(bbox_ref, bbox_tar)):
680 | for cnt, br in bbox_ref.items():
681 | bt = bbox_tar[cnt]
682 | if(br is None or bt is None):
683 | continue
684 | seg_cnt = seg_ref[cnt]
685 |
686 | # feature of patch in the next frame
687 | F_tar_box = F_tar[:, bt.top:bt.bottom, bt.left:bt.right]
688 | F_ref_box = F_ref[:, br.top:br.bottom, br.left:br.right]
689 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1)
690 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1)
691 |
692 | # affinity between two patches
693 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat)
694 | aff = torch.nn.functional.softmax(aff * temp, dim=0)
695 | # transfer segmentation from patch1 to patch2
696 | seg_ref_box = seg_cnt[br.top:br.bottom, br.left:br.right]
697 | aff = aff.cpu()
698 | if(cnt == 0):
699 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1)
700 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze()
701 | else:
702 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1)
703 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze()
704 | #seg_tar_box = transform_topk(aff.unsqueeze(0), seg_ref_box.contiguous().unsqueeze(0).unsqueeze(0), 20)
705 | seg_tar_box = seg_tar_box.view(F_tar_box.size(1), F_tar_box.size(2))
706 |
707 | seg_pred[cnt,bt.top:bt.bottom, bt.left:bt.right] = seg_tar_box
708 | return seg_pred
709 |
710 | def bbox_next_frame_v2(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first, bbox_pre, temp, direct=False):
711 | INPUTS:
712 | - direct: rec|direct,
713 | - if False, use previous frame to locate bbox
714 | - if True, use first frame to locate bbox
715 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar)
716 | c, h, w = F_first.size()
717 | if not direct:
718 | coords_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp)
719 | else:
720 | coords_tar = match_ref_tar(F_first, F_tar, seg_first, temp)
721 |
722 | bbox_tar = bbox_in_tar(coords_tar, bbox_first, h, w)
723 |
724 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp)
725 | return seg_pred.unsqueeze(0)
726 |
727 | def bbox_next_frame_v3(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first, bbox_pre, temp, name):
728 | METHOD: combining tracking & direct recognition, calculate bbox in target frame
729 | using both first frame and previous frame.
730 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar)
731 | c, h, w = F_first.size()
732 |
733 | coords_pre_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp)
734 | coords_first_tar = match_ref_tar(F_first, F_tar, seg_first, temp)
735 | coords_tar = {}
736 | for cnt, coord_first in coords_first_tar.items():
737 | coord_pre = coords_pre_tar[cnt]
738 | # fall-back schema
739 | if(coord_pre is None):
740 | coord_tar_ = coord_first
741 | else:
742 | coord_tar_ = coord_pre
743 | coords_tar[cnt] = coord_tar_
744 | _, seg_pre_idx = torch.max(seg_pre, dim = 0)
745 |
746 | coords_tar = clean_coords(coords_tar, bbox_pre, threshold=4)
747 | bbox_tar = bbox_in_tar(coords_tar, bbox_first, h, w)
748 |
749 | # recoginition
750 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp)
751 | seg_cleaned = clean_seg(seg_pred, bbox_tar, threshold=1)
752 |
753 | # move bbox w.r.t cleaned seg
754 | bbox_tar = shift_bbox(seg_cleaned, bbox_tar)
755 |
756 | seg_post = post_process_seg(seg_pred.unsqueeze(0))
757 | return seg_pred, seg_post, bbox_tar
758 |
759 | def bbox_next_frame_v4(F_first, F_pre, seg_pre, seg_first, F_tar, bbox_first,
760 | bbox_pre, temp):
761 | METHOD: combining tracking & direct recognition, calculate bbox in target frame
762 | using both first frame and previous frame.
763 | Version Note: include bounding box scaling
764 | F_first, F_pre, seg_pre, seg_first, F_tar = squeeze_all(F_first, F_pre, seg_pre, seg_first, F_tar)
765 | c, h, w = F_first.size()
766 |
767 | coords_pre_tar = match_ref_tar(F_pre, F_tar, seg_pre, temp)
768 | coords_first_tar = match_ref_tar(F_first, F_tar, seg_first, temp)
769 | coords_tar = {}
770 | for cnt, coord_first in coords_first_tar.items():
771 | coord_pre = coords_pre_tar[cnt]
772 | # fall-back schema
773 | if(coord_pre is None):
774 | coord_tar_ = coord_first
775 | else:
776 | coord_tar_ = coord_pre
777 | coords_tar[cnt] = coord_tar_
778 | _, seg_pre_idx = torch.max(seg_pre, dim = 0)
779 |
780 | coords_tar = clean_coords(coords_tar, bbox_pre, threshold=4)
781 |
782 | bbox_tar = bbox_in_tar_v2(coords_tar, bbox_first, h, w, seg_pre)
783 |
784 | # recoginition
785 | seg_pred = recoginition(F_first, F_tar, bbox_first, bbox_tar, seg_first, temp)
786 | seg_cleaned = clean_seg(seg_pred, bbox_tar, threshold=1)
787 |
788 | # move bbox w.r.t cleaned seg
789 | bbox_tar = shift_bbox(seg_cleaned, bbox_tar)
790 |
791 | seg_post = post_process_seg(seg_pred.unsqueeze(0))
792 | return seg_pred, seg_post, bbox_tar
793 |
794 | """
795 | """
796 | def bbox_next_frame(F_ref, seg_ref, F_tar, bbox, temp):
797 | # b * h * w * 2
798 | b, c, h, w = F_ref.size()
799 | grid = create_grid(F_ref.size()).squeeze()
800 | grid[:,:,0] = (grid[:,:,0]+1)/2 * w
801 | grid[:,:,1] = (grid[:,:,1]+1)/2 * h
802 | # grid_flat: (h * w) * 2
803 | grid_flat = grid.view(-1,2)
804 | seg_ref = seg_ref.squeeze()
805 | F_ref = F_ref.squeeze()
806 | F_ref_flat = F_ref.view(c,-1)
807 | F_tar = F_tar.squeeze()
808 | F_tar_flat = F_tar.view(c,-1)
809 | bbox_next = []
810 | seg_pred = torch.zeros(seg_ref.size())
811 | for i in range(0,seg_ref.size(0)):
812 | seg_cnt = seg_ref[i,:,:].contiguous().view(-1)
813 | if(seg_cnt.max() == 0):
814 | continue
815 | bbox_cnt = bbox[i]
816 | if(i > 0):
817 | fg_idx = seg_cnt.nonzero()
818 | # take pixels of this instance out
819 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze()
820 |
821 | # affinity between patch and target frame
822 | # aff: (hh * ww, h * w)
823 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat)
824 | aff = torch.nn.functional.softmax(aff*temp, dim = 1)
825 | # coord of this patch in next frame: (hh*ww) * 2
826 | coord = torch.mm(aff, grid_flat)
827 | avg_h = calc_center(coord[:,1], mode='mass').cpu().long()
828 | avg_w = calc_center(coord[:,0], mode='mass').cpu().long()
829 | bb_width = bbox_cnt.right - bbox_cnt.left
830 | bb_height = bbox_cnt.bottom - bbox_cnt.top
831 | bbox_next_ = BBox(left = max(avg_w - bb_width/2,0),
832 | right = min(avg_w + bb_width/2, w),
833 | top = max(avg_h - bb_height/2,0),
834 | bottom = min(avg_h + bb_height/2,h),
835 | margin = 0, h = h, w = w)
836 | else:
837 | bbox_next_ = BBox(left = 0,
838 | right = w,
839 | top = 0,
840 | bottom = h,
841 | margin = 0, h = h, w = w)
842 | bbox_next.append(bbox_next_)
843 |
844 | # feature of patch in the next frame
845 | F_tar_box = F_tar[:, bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right]
846 | F_ref_box = F_ref[:, bbox_cnt.top:bbox_cnt.bottom, bbox_cnt.left:bbox_cnt.right]
847 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1)
848 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1)
849 |
850 | # affinity between two patches
851 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat)
852 | aff = torch.nn.functional.softmax(aff * temp, dim=0)
853 | # transfer segmentation from patch1 to patch2
854 | seg_ref_box = seg_ref[i, bbox_cnt.top:bbox_cnt.bottom, bbox_cnt.left:bbox_cnt.right]
855 | aff = aff.cpu()
856 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1)
857 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze()
858 | seg_tar_box = seg_tar_box.view(F_tar_box.size(1),F_tar_box.size(2))
859 |
860 | seg_pred[i,bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] = seg_tar_box
861 |
862 | #seg_pred[0,:,:] = 1 - torch.sum(seg_pred[1:,:,:],dim=0)
863 |
864 | return seg_pred.unsqueeze(0), bbox_next
865 | def bbox_next_frame_rec(F_first, F_ref, seg_ref, seg_first,
866 | F_tar, bbox_first, bbox, temp):
867 | # b * h * w * 2
868 | b, c, h, w = F_ref.size()
869 | seg_ref = seg_ref.squeeze()
870 | F_ref = F_ref.squeeze()
871 | F_ref_flat = F_ref.view(c,-1)
872 | F_tar = F_tar.squeeze()
873 | F_tar_flat = F_tar.view(c,-1)
874 | F_first = F_first.squeeze()
875 | bbox_next = []
876 | seg_pred = torch.zeros(seg_ref.size())
877 | seg_first = seg_first.squeeze()
878 | for i in range(0,seg_ref.size(0)):
879 | seg_cnt = seg_ref[i,:,:].contiguous().view(-1)
880 | if(seg_cnt.max() == 0):
881 | continue
882 | if(i > len(bbox)-1):
883 | continue
884 | bbox_cnt = bbox[i]
885 | bbox_first_cnt = bbox_first[i]
886 | if(i > 0):
887 | fg_idx = seg_cnt.nonzero()
888 | F_ref_cnt_flat = F_ref_flat[:,fg_idx].squeeze()
889 |
890 | # affinity between patch and target frame
891 | if(F_ref_cnt_flat.dim() < 2):
892 | # some small objects may miss
893 | continue
894 | aff = torch.mm(F_ref_cnt_flat.permute(1,0), F_tar_flat)
895 | aff = torch.nn.functional.softmax(aff*temp, dim = 1)
896 | coord = torch.mm(aff, grid_flat)
897 | #coord = transform_topk(aff.unsqueeze(0), grid.unsqueeze(0), dim=2, k=20)
898 | avg_h = calc_center(coord[:,1], mode='mass').cpu()
899 | avg_w = calc_center(coord[:,0], mode='mass').cpu()
900 | bb_width = float(bbox_first_cnt.right - bbox_first_cnt.left)
901 | bb_height = float(bbox_first_cnt.bottom - bbox_first_cnt.top)
902 | coord = coord.cpu()
903 |
904 | left = avg_w - bb_width/2.0
905 | right = avg_w + bb_width/2.0
906 | top = avg_h - bb_height/2.0
907 | bottom = avg_h + bb_height/2.0
908 |
909 | bbox_next_ = BBox(left = int(max(left, 0)),
910 | right = int(min(right, w)),
911 | top = int(max(top, 0)),
912 | bottom = int(min(bottom, h)),
913 | margin = 0, h = h, w = w)
914 | else:
915 | bbox_next_ = BBox(left = 0,
916 | right = w,
917 | top = 0,
918 | bottom = h,
919 | margin = 0, h = h, w = w)
920 |
921 | bbox_next.append(bbox_next_)
922 |
923 | # feature of patch in the next frame
924 | F_tar_box = F_tar[:, bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right]
925 | F_ref_box = F_first[:, bbox_first_cnt.top:bbox_first_cnt.bottom, bbox_first_cnt.left:bbox_first_cnt.right]
926 | F_tar_box_flat = F_tar_box.contiguous().view(c,-1)
927 | F_ref_box_flat = F_ref_box.contiguous().view(c,-1)
928 | print('================')
929 |
930 | # affinity between two patches
931 | aff = torch.mm(F_ref_box_flat.permute(1,0), F_tar_box_flat)
932 | aff = torch.nn.functional.softmax(aff * temp, dim=0)
933 | # transfer segmentation from patch1 to patch2
934 | seg_ref_box = seg_first[i, bbox_first_cnt.top:bbox_first_cnt.bottom, bbox_first_cnt.left:bbox_first_cnt.right]
935 | aff = aff.cpu()
936 | seg_ref_box_flat = seg_ref_box.contiguous().view(-1)
937 | seg_tar_box = torch.mm(seg_ref_box_flat.unsqueeze(0), aff).squeeze()
938 | #seg_tar_box = transform_topk(aff.unsqueeze(0), seg_ref_box.contiguous().unsqueeze(0).unsqueeze(0))
939 | seg_tar_box = seg_tar_box.view(F_tar_box.size(1),F_tar_box.size(2))
940 |
941 | seg_pred[i,bbox_next_.top:bbox_next_.bottom, bbox_next_.left:bbox_next_.right] = seg_tar_box
942 |
943 | return seg_pred.unsqueeze(0), bbox_next
944 |
945 | def transform_topk(aff, frame1, k=20, dim=1):
946 | b,c,h,w = frame1.size()
947 | b, N1, N2 = aff.size()
948 | # b * 20 * N
949 | tk_val, tk_idx = torch.topk(aff, dim = dim, k=k)
950 | # b * N
951 | tk_val_min,_ = torch.min(tk_val,dim=dim)
952 | if(dim == 1):
953 | tk_val_min = tk_val_min.view(b,1,N2)
954 | else:
955 | tk_val_min = tk_val_min.view(b,N1,1)
956 | aff[tk_val_min > aff] = 0
957 | frame1 = frame1.view(b,c,-1)
958 | if(dim == 1):
959 | frame2 = torch.bmm(frame1, aff)
960 | return frame2
961 | else:
962 | frame2 = torch.bmm(aff, frame1.permute(0,2,1))
963 | return frame2.squeeze()
964 | """
965 |
--------------------------------------------------------------------------------
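
Most of `track_utils.py` revolves around the `BBox` helper: `seg2bbox` turns a one-hot mask into one box per object, `match_ref_tar` matches foreground pixels into the target frame, and `coords2bbox_scale` re-estimates a box around a mean-shift center. The matching functions expect CUDA tensors, but the box utilities run on CPU. Below is a minimal sketch of `seg2bbox` with a made-up two-object mask; it assumes the repository root is on `PYTHONPATH` (importing `libs.track_utils` also pulls in matplotlib and `libs.vis_utils`).

```
import torch
from libs.track_utils import seg2bbox

# a made-up one-hot mask: channel 0 is background, channels 1 and 2 are objects
c, h, w = 3, 64, 96
seg = torch.zeros(c, h, w)
seg[1, 10:30, 20:50] = 1
seg[2, 40:60, 60:90] = 1

# one BBox per channel (or None if the channel is empty); margin=0.1 pads each box by 10%
bbox = seg2bbox(seg, margin=0.1)
for cnt, box in bbox.items():
    if box is not None:
        box.print()
```
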
/libs/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import shutil
5 | from random import randint
6 | import numpy as np
7 | from os.path import join
8 |
9 | def save_vis(pred, gt2, gt1, out_dir, gt_grey=False, prefix=0):
10 | """
11 | INPUTS:
12 | - pred: predicted Lab image, a 3xhxw tensor
13 | - gt2: second GT frame, a 3xhxw tensor
14 | - gt1: first GT frame, a 3xhxw tensor
15 | - out_dir: output image save path
16 |     - gt_grey: whether to use the ground-truth L channel in the predicted image
17 | """
18 | b = pred.size(0)
19 | pred = pred * 128 + 128
20 | gt1 = gt1 * 128 + 128
21 | gt2 = gt2 * 128 + 128
22 |
23 | if(gt_grey):
24 | pred[:,0,:,:] = gt2[:,0,:,:]
25 | for cnt in range(b):
26 | im = pred[cnt].cpu().detach().numpy().transpose( 1, 2, 0)
27 | im_bgr = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR)
28 | im_pred = np.clip(im_bgr, 0, 255)
29 |
30 | im = gt2[cnt].cpu().detach().numpy().transpose( 1, 2, 0)
31 | im_gt2 = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR)
32 |
33 | im = gt1[cnt].cpu().detach().numpy().transpose( 1, 2, 0)
34 | im_gt1 = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR)
35 |
36 | im = np.concatenate((im_gt1, im_gt2, im_pred), axis = 1)
37 | print(out_dir, "{:02d}{:02d}.png".format(prefix, cnt))
38 | cv2.imwrite(join(out_dir, "{:02d}{:02d}.png".format(prefix, cnt)), im)
39 |
40 | def save_vis_ae(pred, gt, savepath):
41 | b = pred.size(0)
42 | for cnt in range(b):
43 | im = pred[cnt].cpu().detach() * 128 + 128
44 | im = im.numpy().transpose(1,2,0)
45 | im_pred = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR)
46 |
47 | im = gt[cnt].cpu().detach() * 128 + 128
48 | im = im.numpy().transpose(1,2,0)
49 | im_gt = cv2.cvtColor(np.array(im, dtype = np.uint8), cv2.COLOR_LAB2BGR)
50 |
51 | im = np.concatenate((im_gt, im_pred), axis = 1)
52 | cv2.imwrite(os.path.join(savepath, "{:02d}.png".format(cnt)), im)
53 |
54 | class AverageMeter(object):
55 | """Computes and stores the average and current value"""
56 | def __init__(self):
57 | self.reset()
58 |
59 | def reset(self):
60 | self.val = 0
61 | self.avg = 0
62 | self.sum = 0
63 | self.count = 0
64 |
65 | def update(self, val, n=1):
66 | self.val = val
67 | self.sum += val * n
68 | self.count += n
69 | self.avg = self.sum / self.count
70 |
71 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar", savedir="models"):
72 | torch.save(state, filename)
73 | if is_best:
74 | shutil.copyfile(filename, os.path.join(savedir, 'model_best.pth.tar'))
75 |
76 | def sample_patch(b,h,w,patch_size):
77 | left = randint(0, max(w - patch_size,1))
78 | top = randint(0, max(h - patch_size,1))
79 | right = left + patch_size
80 | bottom = top + patch_size
81 | return torch.Tensor([left, right, top, bottom]).view(1,4).repeat(b,1).cuda()
82 |
83 |
84 | def diff_crop(F, x1, y1, x2, y2, ph, pw):
85 | """
86 |     Differentiable cropping
87 | INPUTS:
88 | - F: frame feature
89 | - x1,y1,x2,y2: top left and bottom right points of the patch
90 | - theta is defined as :
91 | a b c
92 | d e f
93 | """
94 | bs, ch, h, w = F.size()
95 | a = ((x2-x1)/w).view(bs,1,1)
96 | b = torch.zeros(a.size()).cuda()
97 | c = (-1+(x1+x2)/w).view(bs,1,1)
98 | d = torch.zeros(a.size()).cuda()
99 | e = ((y2-y1)/h).view(bs,1,1)
100 | f = (-1+(y2+y1)/h).view(bs,1,1)
101 | theta_row1 = torch.cat((a,b,c),dim=2)
102 | theta_row2 = torch.cat((d,e,f),dim=2)
103 | theta = torch.cat((theta_row1, theta_row2),dim=1).cuda()
104 | size = torch.Size((bs,ch,pw,ph))
105 |     grid = torch.nn.functional.affine_grid(theta, size)
106 |     patch = torch.nn.functional.grid_sample(F, grid)
107 | return patch
108 |
109 | def log_current(epoch, loss_ave, best_loss, filename = "log_current.txt", savedir="models"):
110 | file = join(savedir, filename)
111 | with open(file, "a") as text_file:
112 | print("epoch: {}".format(epoch), file=text_file)
113 | print("best_loss: {}".format(best_loss), file=text_file)
114 | print("current_loss: {}".format(loss_ave), file=text_file)
115 |
116 | def print_options(opt):
117 | message = ''
118 | message += '----------------- Options ---------------\n'
119 | for k, v in sorted(vars(opt).items()):
120 | comment = ''
121 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
122 | message += '----------------- End -------------------'
123 | print(message)
124 |
125 | # save to the disk
126 | expr_dir = os.path.join(opt.savedir)
127 | file_name = os.path.join(expr_dir, 'opt.txt')
128 | with open(file_name, 'wt') as opt_file:
129 | opt_file.write(message)
130 | opt_file.write('\n')
131 |
--------------------------------------------------------------------------------
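
`diff_crop` above is the differentiable counterpart of a plain crop: it builds an affine `theta` from the patch corners and samples the patch with `grid_sample`, so gradients flow back into the feature map. A minimal sketch of calling it follows; the sizes are made up, a CUDA device is required because the helper allocates its affine parameters on the GPU, and the repository root is assumed to be on `PYTHONPATH`.

```
import torch
from libs.train_utils import diff_crop

bs, ch, h, w = 2, 16, 32, 32
ph = pw = 8
feat = torch.randn(bs, ch, h, w, device='cuda', requires_grad=True)

# per-sample patch corners in pixels: top-left (x1, y1) and bottom-right (x2, y2)
x1 = torch.full((bs,), 4.0, device='cuda')
y1 = torch.full((bs,), 6.0, device='cuda')
x2 = x1 + pw
y2 = y1 + ph

patch = diff_crop(feat, x1, y1, x2, y2, ph, pw)
print(patch.shape)         # (bs, ch, pw, ph) given how the sampling grid is built
patch.mean().backward()    # gradients reach `feat` through grid_sample
```
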
/libs/transforms_multi.py:
--------------------------------------------------------------------------------
1 | import numbers, collections
2 | import random
3 | import scipy.io
4 | import cv2
5 | import numpy as np
6 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 | import torch
12 |
13 | class Scale(object):
14 |     """Resize a pair of frames so the short side equals short_side and the long side is rounded down to a multiple of 64 (or to an explicit [h, w] if a list is given)."""
15 | def __init__(self, short_side):
16 | self.short_side = short_side
17 |
18 | def __call__(self, pair_0, pair_1):
19 | if(type(self.short_side) == int):
20 | h,w,c = pair_0.shape
21 | if(h > w):
22 | tw = self.short_side
23 | th = (tw * h) / w
24 | th = int((th // 64) * 64)
25 | else:
26 | th = self.short_side
27 | tw = (th * w) / h
28 | tw = int((tw // 64) * 64)
29 | elif(type(self.short_side) == list):
30 | th = self.short_side[0]
31 | tw = self.short_side[1]
32 |
33 | interpolation = cv2.INTER_NEAREST
34 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation)
35 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation)
36 |
37 | return pair_0, pair_1
38 |
39 | class RandomCrop(object):
40 | """for pair of frames"""
41 | def __init__(self, size, seperate = False):
42 | if isinstance(size, numbers.Number):
43 | self.size = (int(size), int(size))
44 | else:
45 | self.size = size
46 | self.seperate = seperate
47 |
48 | def __call__(self, *frames):
49 | frames = list(frames)
50 | h, w, c = frames[0].shape
51 | th, tw = self.size
52 | top = bottom = left = right = 0
53 |
54 | if w == tw and h == th:
55 | return frames
56 |
57 | if w < tw:
58 | left = (tw - w) // 2
59 | right = tw - w - left
60 | if h < th:
61 | top = (th - h) // 2
62 | bottom = th - h - top
63 | if left > 0 or right > 0 or top > 0 or bottom > 0:
64 | for i in range(len(frames)):
65 | frames[i] = pad_image(
66 | 'reflection', frames[i], top, bottom, left, right)
67 | if w > tw:
68 | x1 = np.array([random.randint(0, w - tw)])
69 | x1 = np.concatenate((x1,x1))
70 | if self.seperate:
71 | #print("True")
72 | x1[1] = np.array([random.randint(0, w - tw)])
73 | for i in range(len(frames)):
74 | frames[i] = frames[i][:, x1[i]:x1[i]+tw]
75 | if h > th:
76 | y1 = np.array([random.randint(0, h - th)])
77 | y1 = np.concatenate((y1,y1))
78 |
79 | if self.seperate:
80 | y1[1] = np.array([random.randint(0, h - th)])
81 |
82 | for i in range(len(frames)):
83 | frames[i] = frames[i][y1[i]:y1[i]+th]
84 |
85 | return frames
86 |
87 | class CenterCrop(object):
88 | """for pair of frames"""
89 | def __init__(self, size):
90 | if isinstance(size, numbers.Number):
91 | self.size = (int(size), int(size))
92 | else:
93 | self.size = size
94 |
95 | def __call__(self, *frames):
96 | frames = list(frames)
97 | h, w, c = frames[0].shape
98 | th, tw = self.size
99 | top = bottom = left = right = 0
100 |
101 | if w == tw and h == th:
102 | return frames
103 |
104 | if w < tw:
105 | left = (tw - w) // 2
106 | right = tw - w - left
107 | if h < th:
108 | top = (th - h) // 2
109 | bottom = th - h - top
110 | if left > 0 or right > 0 or top > 0 or bottom > 0:
111 | for i in range(len(frames)):
112 | frames[i] = pad_image(
113 | 'reflection', frames[i], top, bottom, left, right)
114 |
115 | if w > tw:
116 | #x1 = random.randint(0, w - tw)
117 | x1 = (w - tw) // 2
118 | for i in range(len(frames)):
119 | frames[i] = frames[i][:, x1:x1+tw]
120 | if h > th:
121 | #y1 = random.randint(0, h - th)
122 | y1 = (h - th) // 2
123 | for i in range(len(frames)):
124 | frames[i] = frames[i][y1:y1+th]
125 |
126 | return frames
127 |
128 | class RandomScale(object):
129 |     """Randomly rescale frames by a ratio drawn from [scale[0], scale[1]]; a single number s is expanded to [1/s, s], and seperate=True samples one ratio per frame."""
130 | def __init__(self, scale, seperate = False):
131 | if isinstance(scale, numbers.Number):
132 | scale = [1 / scale, scale]
133 | self.scale = scale
134 | self.seperate = seperate
135 |
136 | def __call__(self, *frames):
137 | h,w,c = frames[0].shape
138 | results = []
139 | if self.seperate:
140 | ratio1 = random.uniform(self.scale[0], self.scale[1])
141 | ratio2 = random.uniform(self.scale[0], self.scale[1])
142 | tw1 = int(ratio1*w)
143 | th1 = int(ratio1*h)
144 | tw2 = int(ratio2*w)
145 | th2 = int(ratio2*h)
146 |
147 |             # frame 0: keep it unchanged if ratio1 is exactly 1, otherwise resize it
148 |             if ratio1 == 1:
149 |                 results.append(frames[0])
150 |             else:
151 |                 interpolation = cv2.INTER_LANCZOS4 if ratio1 < 1 else cv2.INTER_CUBIC
152 |                 frame = cv2.resize(frames[0], dsize = (tw1, th1), interpolation=interpolation)
153 |                 results.append(frame)
154 |
155 |             # frame 1: keep it unchanged if ratio2 is exactly 1, otherwise resize it
156 |             if ratio2 == 1:
157 |                 results.append(frames[1])
158 |             else:
159 |                 interpolation = cv2.INTER_LANCZOS4 if ratio2 < 1 else cv2.INTER_CUBIC
160 |                 frame = cv2.resize(frames[1], dsize = (tw2, th2), interpolation=interpolation)
161 |                 results.append(frame)
162 |
163 |
164 | else:
165 | ratio = random.uniform(self.scale[0], self.scale[1])
166 | tw = int(ratio*w)
167 | th = int(ratio*h)
168 | if ratio == 1:
169 | return frames
170 | elif ratio < 1:
171 | interpolation = cv2.INTER_LANCZOS4
172 | elif ratio > 1:
173 | interpolation = cv2.INTER_CUBIC
174 | for frame in frames:
175 | frame = cv2.resize(frame, dsize = (tw, th), interpolation=interpolation)
176 | results.append(frame)
177 | # print(results[0].shape,type(results[1]))
178 | return results
179 |
180 | class RandomRotate(object):
181 |     """Randomly rotate frames by an angle drawn from [-angle, angle] degrees, using reflection padding; seperate=True samples one angle per frame."""
182 | def __init__(self, angle, seperate = False):
183 | self.angle = angle
184 | self.seperate = seperate
185 |
186 | #def __call__(self, pair_0, pair_1):
187 | def __call__(self, *frames):
188 | results = []
189 | if self.seperate:
190 | angle = random.randint(0, self.angle * 2) - self.angle
191 | h,w,c = frames[0].shape
192 | p = max((h, w))
193 | frame = pad_image('reflection', frames[0], h,h,w,w)
194 | frame = rotatenumpy(frame, angle)
195 | frame = frame[h : h + h, w : w + w]
196 | results.append(frame)
197 |
198 | angle = random.randint(0, self.angle * 2) - self.angle
199 | h,w,c = frames[1].shape
200 | p = max((h, w))
201 | frame = pad_image('reflection', frames[1], h,h,w,w)
202 | frame = rotatenumpy(frame, angle)
203 | frame = frame[h : h + h, w : w + w]
204 | results.append(frame)
205 | else:
206 | angle = random.randint(0, self.angle * 2) - self.angle
207 | for frame in frames:
208 | h,w,c = frame.shape
209 | p = max((h, w))
210 | frame = pad_image('reflection', frame, h,h,w,w)
211 | frame = rotatenumpy(frame, angle)
212 | frame = frame[h : h + h, w : w + w]
213 | results.append(frame)
214 |
215 | return results
216 |
217 | class RandomHorizontalFlip(object):
218 |     """Randomly horizontally flips the given frames with a probability of 0.5
219 | """
220 |
221 | def __call__(self, *frames):
222 | results = []
223 | if random.random() < 0.5:
224 | for frame in frames:
225 | results.append(cv2.flip(frame, 1))
226 | else:
227 | results = frames
228 | return results
229 |
230 | class Resize(object):
231 | """Resize the input PIL Image to the given size.
232 | Args:
233 | size (sequence or int): Desired output size. If size is a sequence like
234 | (h, w), output size will be matched to this. If size is an int,
235 | smaller edge of the image will be matched to this number.
236 | i.e, if height > width, then image will be rescaled to
237 | (size * height / width, size)
238 | interpolation (int, optional): Desired interpolation. Default is
239 | ``PIL.Image.BILINEAR``
240 | """
241 |
242 | def __init__(self, size, interpolation=cv2.INTER_NEAREST):
243 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
244 | self.size = size
245 | self.interpolation = interpolation
246 |
247 | def __call__(self, pair_0, pair_1):
248 |
249 | return resize(pair_0, self.size, self.interpolation), \
250 | resize(pair_1, self.size, self.interpolation)
251 |
252 |
253 | class Pad(object):
254 |
255 | def __init__(self, padding, fill=0):
256 | assert isinstance(padding, numbers.Number)
257 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \
258 | isinstance(fill, tuple)
259 | self.padding = padding
260 | self.fill = fill
261 |
262 | def __call__(self, pair_0, pair_1):
263 | if self.fill == -1:
264 | pair_0 = pad_image('reflection', pair_0,
265 | self.padding, self.padding, self.padding, self.padding)
266 | pair_1 = pad_image('reflection', pair_1,
267 | self.padding, self.padding, self.padding, self.padding)
268 | else:
269 | pair_0 = pad_image('constant', pair_0,
270 | self.padding, self.padding, self.padding, self.padding,
271 | value=self.fill)
272 | pair_1 = pad_image('constant', pair_1,
273 | self.padding, self.padding, self.padding, self.padding,
274 | value=self.fill)
275 |
276 | return pair_0, pair_1
277 |
278 |
279 | class ResizeandPad(object):
280 | """
281 |     resize the larger side to the desired eval size;
282 |     pad the smaller side to make the frame square
283 | """
284 | def __init__(self, size, interpolation=cv2.INTER_NEAREST):
285 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
286 | self.size = size
287 | self.interpolation = interpolation
288 |
289 | def __call__(self, pair_0, pair_1):
290 | """
291 | Resize and Pad
292 | """
293 | pair_0 = resize_large(pair_0, self.size, self.interpolation)
294 | pair_1 = resize_large(pair_1, self.size, self.interpolation)
295 | h,w,_ = pair_0.shape
296 | if w > h:
297 | bd = int((w - h) / 2)
298 | pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0)
299 | # pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0)
300 | elif h > w:
301 | bd = int((h-w) / 2)
302 | pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd)
303 | # pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd)
304 |
305 | h,w,_ = pair_1.shape
306 | if w > h:
307 | bd = int((w - h) / 2)
308 | # pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0)
309 | pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0)
310 | elif h > w:
311 | bd = int((h-w) / 2)
312 | # pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd)
313 | pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd)
314 |
315 | return pair_0, pair_1
316 |
317 |
318 | class Normalize(object):
319 | """Given mean: (R, G, B) and std: (R, G, B),
320 | will normalize each channel of the torch.*Tensor, i.e.
321 | channel = (channel - mean) / std
322 | """
323 |
324 | def __init__(self, mean, std = (1.0,1.0,1.0)):
325 | self.mean = torch.FloatTensor(mean)
326 | self.std = torch.FloatTensor(std)
327 |
328 | def __call__(self, *frames):
329 | results = []
330 | for frame in frames:
331 | for t, m, s in zip(frame, self.mean, self.std):
332 | t.sub_(m).div_(s)
333 | results.append(frame)
334 | return results
335 |
336 |
337 | class ToTensor(object):
338 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
339 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
340 | """
341 | def __call__(self, *frames):
342 | results = []
343 | for frame in frames:
344 | frame = torch.from_numpy(frame.transpose(2,0,1).copy()).contiguous().float()
345 | results.append(frame)
346 | return results
347 |
348 |
349 | class Compose(object):
350 | """
351 | Composes several transforms together.
352 | """
353 | def __init__(self, transforms):
354 | self.transforms = transforms
355 |
356 | def __call__(self, *args):
357 | for t in self.transforms:
358 | args = t(*args)
359 | return args
360 |
361 |
362 | #=============================functions===============================
363 |
364 | def resize(img, size, interpolation=cv2.INTER_NEAREST):
365 |
366 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
367 | raise TypeError('Got inappropriate size arg: {}'.format(size))
368 |
369 | h, w = img.shape[:2]
370 |
371 | if isinstance(size, int):
372 | if (w <= h and w == size) or (h <= w and h == size):
373 | return img
374 | if w < h:
375 | ow = size
376 | oh = int(size * h / w)
377 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
378 | else:
379 | oh = size
380 | ow = int(size * w / h)
381 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
382 | else:
383 | return cv2.resize(img, size[::-1], interpolation=interpolation)
384 |
385 |
386 | def resize_large(img, size, interpolation=cv2.INTER_NEAREST):
387 |
388 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
389 | raise TypeError('Got inappropriate size arg: {}'.format(size))
390 |
391 | h, w,_ = img.shape
392 |
393 | if isinstance(size, int):
394 | if (w >= h and w == size) or (h >= w and h == size):
395 | return img
396 | if w > h:
397 | ow = size
398 | oh = int(size * h / w)
399 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
400 | else:
401 | oh = size
402 | ow = int(size * w / h)
403 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
404 | else:
405 | return cv2.resize(img, size[::-1], interpolation=interpolation)
406 |
407 |
408 | def rotatenumpy(image, angle, interpolation=cv2.INTER_NEAREST):
409 | rot_mat = cv2.getRotationMatrix2D((image.shape[1]/2, image.shape[0]/2), angle, 1.0)
410 | result = cv2.warpAffine(image, rot_mat, (image.shape[1],image.shape[0]), flags=interpolation)
411 | return result
412 |
413 | # good, written with numpy
414 | def pad_reflection(image, top, bottom, left, right):
415 | if top == 0 and bottom == 0 and left == 0 and right == 0:
416 | return image
417 | h, w = image.shape[:2]
418 | next_top = next_bottom = next_left = next_right = 0
419 | if top > h - 1:
420 | next_top = top - h + 1
421 | top = h - 1
422 | if bottom > h - 1:
423 | next_bottom = bottom - h + 1
424 | bottom = h - 1
425 | if left > w - 1:
426 | next_left = left - w + 1
427 | left = w - 1
428 | if right > w - 1:
429 | next_right = right - w + 1
430 | right = w - 1
431 | new_shape = list(image.shape)
432 | new_shape[0] += top + bottom
433 | new_shape[1] += left + right
434 | new_image = np.empty(new_shape, dtype=image.dtype)
435 | new_image[top:top+h, left:left+w] = image
436 | new_image[:top, left:left+w] = image[top:0:-1, :]
437 | new_image[top+h:, left:left+w] = image[-1:-bottom-1:-1, :]
438 | new_image[:, :left] = new_image[:, left*2:left:-1]
439 | new_image[:, left+w:] = new_image[:, -right-1:-right*2-1:-1]
440 | return pad_reflection(new_image, next_top, next_bottom,
441 | next_left, next_right)
442 |
443 | # good, written with numpy
444 | def pad_constant(image, top, bottom, left, right, value):
445 | if top == 0 and bottom == 0 and left == 0 and right == 0:
446 | return image
447 | h, w = image.shape[:2]
448 | new_shape = list(image.shape)
449 | new_shape[0] += top + bottom
450 | new_shape[1] += left + right
451 | new_image = np.empty(new_shape, dtype=image.dtype)
452 | new_image.fill(value)
453 | new_image[top:top+h, left:left+w] = image
454 | return new_image
455 |
456 | # change to np/non-np options
457 | def pad_image(mode, image, top, bottom, left, right, value=0):
458 | if mode == 'reflection':
459 | return pad_reflection(image, top, bottom, left, right)
460 | elif mode == 'constant':
461 | return pad_constant(image, top, bottom, left, right, value)
462 | else:
463 | raise ValueError('Unknown mode {}'.format(mode))
464 |
--------------------------------------------------------------------------------
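A minimal usage sketch (not part of the repository) of the reflection padding above, showing how `pad_image` handles padding wider than the image itself via the recursive `pad_reflection`; it assumes the repository root is on `PYTHONPATH` and the dependencies listed in the README are installed.

```
import numpy as np
from libs.transforms_multi import pad_image

img = np.arange(12, dtype=np.uint8).reshape(3, 4, 1)   # tiny 3x4 single-channel image
out = pad_image('reflection', img, 2, 2, 5, 5)         # left/right padding exceeds the image width
print(out.shape)                                        # (7, 14, 1)
```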
/libs/transforms_pair.py:
--------------------------------------------------------------------------------
1 | import numbers, collections  # collections is needed by the isinstance(..., collections.Iterable) checks below
2 | import random
3 | import scipy.io
4 | import cv2
5 | import numpy as np
6 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 | import torch
12 |
13 | class CenterCrop(object):
14 | """for pair of frames"""
15 | def __init__(self, size):
16 | if isinstance(size, numbers.Number):
17 | self.size = (int(size), int(size))
18 | else:
19 | self.size = size
20 |
21 | def __call__(self, *frames):
22 | frames = list(frames)
23 | h, w, c = frames[0].shape
24 | th, tw = self.size
25 | top = bottom = left = right = 0
26 |
27 | if w == tw and h == th:
28 | return frames
29 |
30 | if w < tw:
31 | left = (tw - w) // 2
32 | right = tw - w - left
33 | if h < th:
34 | top = (th - h) // 2
35 | bottom = th - h - top
36 | if left > 0 or right > 0 or top > 0 or bottom > 0:
37 | for i in range(len(frames)):
38 | frames[i] = pad_image(
39 | 'reflection', frames[i], top, bottom, left, right)
40 |
41 | if w > tw:
42 | #x1 = random.randint(0, w - tw)
43 | x1 = (w - tw) // 2
44 | for i in range(len(frames)):
45 | frames[i] = frames[i][:, x1:x1+tw]
46 | if h > th:
47 | #y1 = random.randint(0, h - th)
48 | y1 = (h - th) // 2
49 | for i in range(len(frames)):
50 | frames[i] = frames[i][y1:y1+th]
51 |
52 | return frames
53 |
54 | class RandomCrop(object):
55 | """for pair of frames"""
56 | def __init__(self, size):
57 | if isinstance(size, numbers.Number):
58 | self.size = (int(size), int(size))
59 | else:
60 | self.size = size
61 |
62 | def __call__(self, pair_0, pair_1):
63 | h, w, c = pair_0.shape
64 | th, tw = self.size
65 | top = bottom = left = right = 0
66 |
67 | if w == tw and h == th:
68 | return pair_0, pair_1
69 |
70 | if w < tw:
71 | left = (tw - w) // 2
72 | right = tw - w - left
73 | if h < th:
74 | top = (th - h) // 2
75 | bottom = th - h - top
76 | if left > 0 or right > 0 or top > 0 or bottom > 0:
77 | pair_0 = pad_image(
78 | 'reflection', pair_0, top, bottom, left, right)
79 | pair_1 = pad_image(
80 | 'reflection', pair_1, top, bottom, left, right)
81 |
82 | if w > tw:
83 | x1 = random.randint(0, w - tw)
84 | pair_0 = pair_0[:, x1:x1+tw]
85 | pair_1 = pair_1[:, x1:x1+tw]
86 | if h > th:
87 | y1 = random.randint(0, h - th)
88 | pair_0 = pair_0[y1:y1+th]
89 | pair_1 = pair_1[y1:y1+th]
90 |
91 | return pair_0, pair_1
92 |
93 | class RandomScale(object):
94 | """docstring for RandomScale"""
95 | def __init__(self, scale):
96 | if isinstance(scale, numbers.Number):
97 | scale = [1 / scale, scale]
98 | self.scale = scale
99 |
100 | def __call__(self, pair_0, pair_1):
101 | ratio = random.uniform(self.scale[0], self.scale[1])
102 | h,w,c = pair_0.shape
103 | tw = int(ratio*w)
104 | th = int(ratio*h)
105 | if ratio == 1:
106 | return pair_0, pair_1
107 | elif ratio < 1:
108 | interpolation = cv2.INTER_LANCZOS4
109 | elif ratio > 1:
110 | interpolation = cv2.INTER_CUBIC
111 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation)
112 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation)
113 |
114 | return pair_0, pair_1
115 |
116 | class Scale(object):
117 | """docstring for Scale"""
118 | def __init__(self, short_side):
119 | self.short_side = short_side
120 |
121 | def __call__(self, pair_0, pair_1):
122 | if(type(self.short_side) == int):
123 | h,w,c = pair_0.shape
124 | if(h > w):
125 | tw = self.short_side
126 | th = (tw * h) / w
127 | th = int((th // 64) * 64)
128 | else:
129 | th = self.short_side
130 | tw = (th * w) / h
131 | tw = int((tw // 64) * 64)
132 | elif(type(self.short_side) == list):
133 | th = self.short_side[0]
134 | tw = self.short_side[1]
135 |
136 | interpolation = cv2.INTER_NEAREST
137 | pair_0 = cv2.resize(pair_0, dsize = (tw, th), interpolation=interpolation)
138 | pair_1 = cv2.resize(pair_1, dsize = (tw, th), interpolation=interpolation)
139 |
140 | return pair_0, pair_1
141 |
142 | class RandomRotate(object):
143 | """docstring for RandomRotate"""
144 | def __init__(self, angle):
145 | self.angle = angle
146 |
147 | def __call__(self, pair_0, pair_1):
148 | h,w,c = pair_0.shape
149 | p = max((h, w))
150 | angle = random.randint(0, self.angle * 2) - self.angle
151 | pair_0 = pad_image('reflection', pair_0, h,h,w,w)
152 | pair_0 = rotatenumpy(pair_0, angle)
153 | pair_0 = pair_0[h : h + h, w : w + w]
154 | pair_1 = pad_image('reflection', pair_1, h,h,w,w)
155 | pair_1 = rotatenumpy(pair_1, angle)
156 | pair_1 = pair_1[h : h + h, w : w + w]
157 |
158 | return pair_0, pair_1
159 |
160 | class RandomHorizontalFlip(object):
161 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
162 | """
163 |
164 | def __call__(self, pair_0, pair_1):
165 | if random.random() < 0.5:
166 | results = [cv2.flip(pair_0, 1), cv2.flip(pair_1, 1)]
167 | else:
168 | results = [pair_0, pair_1]
169 | return results
170 |
171 | class Resize(object):
172 | """Resize the input PIL Image to the given size.
173 | Args:
174 | size (sequence or int): Desired output size. If size is a sequence like
175 | (h, w), output size will be matched to this. If size is an int,
176 | smaller edge of the image will be matched to this number.
177 | i.e, if height > width, then image will be rescaled to
178 | (size * height / width, size)
179 | interpolation (int, optional): Desired interpolation. Default is
180 | ``PIL.Image.BILINEAR``
181 | """
182 |
183 | def __init__(self, size, interpolation=cv2.INTER_NEAREST):
184 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
185 | self.size = size
186 | self.interpolation = interpolation
187 |
188 | def __call__(self, pair_0, pair_1):
189 |
190 | return resize(pair_0, self.size, self.interpolation), \
191 | resize(pair_1, self.size, self.interpolation)
192 |
193 |
194 | class Pad(object):
195 |
196 | def __init__(self, padding, fill=0):
197 | assert isinstance(padding, numbers.Number)
198 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \
199 | isinstance(fill, tuple)
200 | self.padding = padding
201 | self.fill = fill
202 |
203 | def __call__(self, pair_0, pair_1):
204 | if self.fill == -1:
205 | pair_0 = pad_image('reflection', pair_0,
206 | self.padding, self.padding, self.padding, self.padding)
207 | pair_1 = pad_image('reflection', pair_1,
208 | self.padding, self.padding, self.padding, self.padding)
209 | else:
210 | pair_0 = pad_image('constant', pair_0,
211 | self.padding, self.padding, self.padding, self.padding,
212 | value=self.fill)
213 | pair_1 = pad_image('constant', pair_1,
214 | self.padding, self.padding, self.padding, self.padding,
215 | value=self.fill)
216 |
217 | return pair_0, pair_1
218 |
219 |
220 | class ResizeandPad(object):
221 | """
222 | Resize the larger boundary to the desired eval size;
223 | pad the smaller one to make the image square.
224 | """
225 | def __init__(self, size, interpolation=cv2.INTER_NEAREST):
226 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
227 | self.size = size
228 | self.interpolation = interpolation
229 |
230 | def __call__(self, pair_0, pair_1):
231 | """
232 | Resize and Pad
233 | """
234 | pair_0 = resize(pair_0, self.size, self.interpolation)
235 | pair_1 = resize(pair_1, self.size, self.interpolation)
236 | h,w,_ = pair_0.shape
237 | if w > h:
238 | bd = int((w - h) / 2)
239 | pair_0 = pad_image('reflection', pair_0, bd, (w-h)-bd, 0, 0)
240 | pair_1 = pad_image('reflection', pair_1, bd, (w-h)-bd, 0, 0)
241 | elif h > w:
242 | bd = int((h-w) / 2)
243 | pair_0 = pad_image('reflection', pair_0, 0, 0, bd, (h-w)-bd)
244 | pair_1 = pad_image('reflection', pair_1, 0, 0, bd, (h-w)-bd)
245 | return pair_0, pair_1
246 |
247 |
248 | class Normalize(object):
249 | """Given mean: (R, G, B) and std: (R, G, B),
250 | will normalize each channel of the torch.*Tensor, i.e.
251 | channel = (channel - mean) / std
252 | """
253 |
254 | def __init__(self, mean, std = (1.0,1.0,1.0)):
255 | self.mean = torch.FloatTensor(mean)
256 | self.std = torch.FloatTensor(std)
257 |
258 | def __call__(self, pair_0, pair_1):
259 | for t, m, s in zip(pair_0, self.mean, self.std):
260 | t.sub_(m).div_(s)
261 | for t, m, s in zip(pair_1, self.mean, self.std):
262 | t.sub_(m).div_(s)
263 | # print("pair_0: ", pair_0.size())
264 | # print("pair_1: ", pair_1.size())
265 | return pair_0, pair_1
266 |
267 |
268 | class ToTensor(object):
269 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
270 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
271 | """
272 | def __call__(self, pair_0, pair_1):
273 | if(pair_0.ndim == 2):
274 | pair_0 = torch.from_numpy(pair_0.copy()).contiguous().float().unsqueeze(0)
275 | pair_1 = torch.from_numpy(pair_1.copy()).contiguous().float().unsqueeze(0)
276 | else:
277 | pair_0 = torch.from_numpy(pair_0.transpose(2,0,1).copy()).contiguous().float()
278 | pair_1 = torch.from_numpy(pair_1.transpose(2,0,1).copy()).contiguous().float()
279 | return pair_0, pair_1
280 |
281 |
282 | class Compose(object):
283 | """
284 | Composes several transforms together.
285 | """
286 | def __init__(self, transforms):
287 | self.transforms = transforms
288 |
289 | def __call__(self, *args):
290 | for t in self.transforms:
291 | args = t(*args)
292 | return args
293 |
294 |
295 | #=============================functions===============================
296 |
297 | def resize(img, size, interpolation=cv2.INTER_NEAREST):
298 |
299 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
300 | raise TypeError('Got inappropriate size arg: {}'.format(size))
301 |
302 | h, w, _ = img.shape
303 |
304 | if isinstance(size, int):
305 | if (w <= h and w == size) or (h <= w and h == size):
306 | return img
307 | if w < h:
308 | ow = size
309 | oh = int(size * h / w)
310 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
311 | else:
312 | oh = size
313 | ow = int(size * w / h)
314 | return cv2.resize(img, (ow, oh), interpolation=interpolation)
315 | else:
316 | return cv2.resize(img, size[::-1], interpolation=interpolation)
317 |
318 |
319 | def rotatenumpy(image, angle, interpolation=cv2.INTER_NEAREST):
320 | rot_mat = cv2.getRotationMatrix2D((image.shape[1]/2, image.shape[0]/2), angle, 1.0)
321 | result = cv2.warpAffine(image, rot_mat, (image.shape[1],image.shape[0]), flags=interpolation)
322 | return result
323 |
324 | # good, written with numpy
325 | def pad_reflection(image, top, bottom, left, right):
326 | if top == 0 and bottom == 0 and left == 0 and right == 0:
327 | return image
328 | h, w = image.shape[:2]
329 | next_top = next_bottom = next_left = next_right = 0
330 | if top > h - 1:
331 | next_top = top - h + 1
332 | top = h - 1
333 | if bottom > h - 1:
334 | next_bottom = bottom - h + 1
335 | bottom = h - 1
336 | if left > w - 1:
337 | next_left = left - w + 1
338 | left = w - 1
339 | if right > w - 1:
340 | next_right = right - w + 1
341 | right = w - 1
342 | new_shape = list(image.shape)
343 | new_shape[0] += top + bottom
344 | new_shape[1] += left + right
345 | new_image = np.empty(new_shape, dtype=image.dtype)
346 | new_image[top:top+h, left:left+w] = image
347 | new_image[:top, left:left+w] = image[top:0:-1, :]
348 | new_image[top+h:, left:left+w] = image[-1:-bottom-1:-1, :]
349 | new_image[:, :left] = new_image[:, left*2:left:-1]
350 | new_image[:, left+w:] = new_image[:, -right-1:-right*2-1:-1]
351 | return pad_reflection(new_image, next_top, next_bottom,
352 | next_left, next_right)
353 |
354 | # good, written with numpy
355 | def pad_constant(image, top, bottom, left, right, value):
356 | if top == 0 and bottom == 0 and left == 0 and right == 0:
357 | return image
358 | h, w = image.shape[:2]
359 | new_shape = list(image.shape)
360 | new_shape[0] += top + bottom
361 | new_shape[1] += left + right
362 | new_image = np.empty(new_shape, dtype=image.dtype)
363 | new_image.fill(value)
364 | new_image[top:top+h, left:left+w] = image
365 | return new_image
366 |
367 | # change to np/non-np options
368 | def pad_image(mode, image, top, bottom, left, right, value=0):
369 | if mode == 'reflection':
370 | return pad_reflection(image, top, bottom, left, right)
371 | elif mode == 'constant':
372 | return pad_constant(image, top, bottom, left, right, value)
373 | else:
374 | raise ValueError('Unknown mode {}'.format(mode))
375 |
--------------------------------------------------------------------------------
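A minimal sketch (not part of the repository) showing how the pair transforms above compose into a Scale/crop/ToTensor/Normalize pipeline; the sizes are illustrative and the snippet assumes the repository root is on `PYTHONPATH`.

```
import numpy as np
import libs.transforms_pair as transforms

pair_transform = transforms.Compose([
    transforms.Scale(256),        # shorter side -> 256, longer side rounded down to a multiple of 64
    transforms.RandomCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=(128, 128, 128), std=(128, 128, 128)),
])

frame_a = np.random.randint(0, 256, (480, 854, 3)).astype(np.float32)
frame_b = frame_a.copy()
t_a, t_b = pair_transform(frame_a, frame_b)
print(t_a.shape, t_b.shape)       # torch.Size([3, 256, 256]) torch.Size([3, 256, 256])
```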
/libs/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import torch
5 | import scipy.misc
6 | import numpy as np
7 | from PIL import Image
8 | import libs.transforms_multi as transforms
9 |
10 | def print_options(opt, test=False):
11 | message = ''
12 | message += '----------------- Options ---------------\n'
13 | for k, v in sorted(vars(opt).items()):
14 | comment = ''
15 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
16 | message += '----------------- End -------------------'
17 | print(message)
18 |
19 | # save to the disk
20 | if(test):
21 | expr_dir = os.path.join(opt.out_dir)
22 | os.makedirs(expr_dir,exist_ok=True)
23 | file_name = os.path.join(expr_dir, 'test_opt.txt')
24 | else:
25 | expr_dir = os.path.join(opt.checkpoint_dir)
26 | os.makedirs(expr_dir,exist_ok=True)
27 | file_name = os.path.join(expr_dir, 'train_opt.txt')
28 | with open(file_name, 'wt') as opt_file:
29 | opt_file.write(message)
30 | opt_file.write('\n')
31 |
32 | def read_frame_list(video_dir):
33 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))]
34 | frame_list = sorted(frame_list)
35 | return frame_list
36 |
37 | def read_frame(frame_dir, transforms):
38 | frame = cv2.imread(frame_dir)
39 | ori_h,ori_w,_ = frame.shape
40 | if(ori_h > ori_w):
41 | tw = ori_w
42 | th = (tw * ori_h) / ori_w
43 | th = int((th // 64) * 64)
44 | else:
45 | th = ori_h
46 | tw = (th * ori_w) / ori_h
47 | tw = int((tw // 64) * 64)
48 | #h = (ori_h // 64) * 64
49 | #w = (ori_w // 64) * 64
50 | frame = cv2.resize(frame, (tw,th))
51 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
52 |
53 | pair = [frame, frame]
54 | transformed = list(transforms(*pair))
55 | return transformed[0].cuda().unsqueeze(0), ori_h, ori_w
56 |
57 | def to_one_hot(y_tensor, n_dims=9):
58 | """ Take a (1, h, w) integer label map y (tensor) and convert it to a one-hot representation with n_dims channels. """
59 | _,h,w = y_tensor.size()
60 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
61 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
62 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
63 | y_one_hot = y_one_hot.view(h,w,n_dims)
64 | return y_one_hot.permute(2,0,1).unsqueeze(0)
65 |
66 | def read_seg(seg_dir, crop_size):
67 | seg = Image.open(seg_dir)
68 | h,w = seg.size
69 | if(h > w):
70 | tw = crop_size
71 | th = (tw * h) / w
72 | th = int((th // 64) * 64)
73 | else:
74 | th = crop_size
75 | tw = (th * w) / h
76 | tw = int((tw // 64) * 64)
77 | seg = np.asarray(seg).reshape((w,h,1))
78 | seg = np.squeeze(seg)
79 | seg = scipy.misc.imresize(seg, (tw//8,th//8),"nearest",mode="F")
80 |
81 | seg = torch.from_numpy(seg).view(1,tw//8,th//8)
82 | return to_one_hot(seg)
83 |
84 | def create_transforms(crop_size):
85 | normalize = transforms.Normalize(mean = (128, 128, 128), std = (128, 128, 128))
86 | t = []
87 | t.extend([transforms.ToTensor(),
88 | normalize])
89 | return transforms.Compose(t)
90 |
91 | def transform_topk(aff, frame1, k):
92 | """
93 | INPUTS:
94 | - aff: affinity matrix, b * N * N
95 | - frame1: reference frame
96 | - k: only aggregate top-k pixels with highest aff(j,i)
97 | """
98 | b,c,h,w = frame1.size()
99 | b, N, _ = aff.size()
100 | # b * k * N
101 | tk_val, tk_idx = torch.topk(aff, dim = 1, k=k)
102 | # b * N
103 | tk_val_min,_ = torch.min(tk_val,dim=1)
104 | tk_val_min = tk_val_min.view(b,1,N)
105 | aff[tk_val_min > aff] = 0
106 | frame1 = frame1.view(b,c,-1)
107 | frame2 = torch.bmm(frame1, aff)
108 | return frame2.view(b,c,h,w)
109 |
110 | def norm_mask(mask):
111 | """
112 | INPUTS:
113 | - mask: segmentation mask
114 | """
115 | c,h,w = mask.size()
116 | for cnt in range(c):
117 | mask_cnt = mask[cnt,:,:]
118 | if(mask_cnt.max() > 0):
119 | mask_cnt = (mask_cnt - mask_cnt.min())
120 | mask_cnt = mask_cnt/mask_cnt.max()
121 | mask[cnt,:,:] = mask_cnt
122 | return mask
123 |
124 | def diff_crop(F, x1, y1, x2, y2, ph, pw):
125 | """
126 | Differentiable cropping
127 | INPUTS:
128 | - F: frame feature
129 | - x1,y1,x2,y2: top left and bottom right points of the patch
130 | - theta is defined as :
131 | a b c
132 | d e f
133 | """
134 | bs, ch, h, w = F.size()
135 | a = ((x2-x1)/w).view(bs,1,1)
136 | b = torch.zeros(a.size()).cuda()
137 | c = (-1+(x1+x2)/w).view(bs,1,1)
138 | d = torch.zeros(a.size()).cuda()
139 | e = ((y2-y1)/h).view(bs,1,1)
140 | f = (-1+(y2+y1)/h).view(bs,1,1)
141 | theta_row1 = torch.cat((a,b,c),dim=2)
142 | theta_row2 = torch.cat((d,e,f),dim=2)
143 | theta = torch.cat((theta_row1, theta_row2),dim=1).cuda()
144 | size = torch.Size((bs,ch,pw,ph))
145 | grid = torch.nn.functional.affine_grid(theta, size)
146 | patch = torch.nn.functional.grid_sample(F,grid)
147 | return patch
148 |
149 | def center2bbox(center, patch_size, h, w):
150 | b = center.size(0)
151 | if(isinstance(patch_size,int)):
152 | new_l = center[:,0] - patch_size/2
153 | else:
154 | new_l = center[:,0] - patch_size[1]/2
155 | new_l[new_l < 0] = 0
156 | new_l = new_l.view(b,1)
157 |
158 | if(isinstance(patch_size,int)):
159 | new_r = new_l + patch_size
160 | else:
161 | new_r = new_l + patch_size[1]
162 | new_r[new_r > w] = w
163 |
164 | if(isinstance(patch_size,int)):
165 | new_t = center[:,1] - patch_size/2
166 | else:
167 | new_t = center[:,1] - patch_size[0]/2
168 | new_t[new_t < 0] = 0
169 | new_t = new_t.view(b,1)
170 |
171 | if(isinstance(patch_size,int)):
172 | new_b = new_t + patch_size
173 | else:
174 | new_b = new_t + patch_size[0]
175 | new_b[new_b > h] = h
176 |
177 | new_center = torch.cat((new_l,new_r,new_t,new_b),dim=1)
178 | return new_center
179 |
--------------------------------------------------------------------------------
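A minimal sketch (not part of the repository) of `transform_topk` above, which propagates a reference mask through a sparsified affinity matrix; shapes are illustrative and the snippet assumes the repository root is on `PYTHONPATH`.

```
import torch
from libs.utils import transform_topk

b, c, h, w = 1, 4, 8, 8                               # 4-channel (one-hot) mask on an 8x8 feature grid
seg = torch.rand(b, c, h, w)                          # stand-in for the reference-frame mask
aff = torch.softmax(torch.rand(b, h * w, h * w), 1)   # column-normalized affinity, b x N x N
prop = transform_topk(aff, seg, k=5)                  # keep only the top-5 reference pixels per target pixel
print(prop.shape)                                     # torch.Size([1, 4, 8, 8])
```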
/libs/vis_utils.py:
--------------------------------------------------------------------------------
1 | import sys, cv2  # sys is needed by read_flo below
2 | import torch
3 | import numpy as np
4 | #import seaborn as sns
5 | from .model import transform
6 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
7 | import torchvision.utils as vutils
8 |
9 | UNKNOWN_FLOW_THRESH = 1e7
10 | SMALLFLOW = 0.0
11 | LARGEFLOW = 1e8
12 |
13 | def prepare_img(img):
14 |
15 | if img.ndim == 3:
16 | img = img[:, :, ::-1] ### RGB to BGR
17 |
18 | ## clip to [0, 1]
19 | img = np.clip(img, 0, 1)
20 |
21 | ## quantize to [0, 255]
22 | img = np.uint8(img * 255.0)
23 |
24 | #cv2.imwrite(filename, img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
25 | return img
26 |
27 | def flow_to_rgb(flow, mr = None):
28 | """
29 | Convert flow into middlebury color code image
30 | :param flow: optical flow map
31 | :return: optical flow image in middlebury color
32 | """
33 | u = flow[:, :, 0]
34 | v = flow[:, :, 1]
35 |
36 | maxu = -999.
37 | maxv = -999.
38 | minu = 999.
39 | minv = 999.
40 |
41 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
42 | u[idxUnknow] = 0
43 | v[idxUnknow] = 0
44 |
45 | maxu = max(maxu, np.max(u))
46 | minu = min(minu, np.min(u))
47 |
48 | maxv = max(maxv, np.max(v))
49 | minv = min(minv, np.min(v))
50 |
51 | rad = np.sqrt(u ** 2 + v ** 2)
52 | maxrad = max(-1, np.max(rad))
53 | if(mr is not None):
54 | maxrad = mr
55 |
56 | #print "max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv)
57 |
58 | u = u/(maxrad + np.finfo(float).eps)
59 | v = v/(maxrad + np.finfo(float).eps)
60 |
61 | img = compute_color(u, v)
62 |
63 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
64 | img[idx] = 0
65 |
66 | return np.float32(img) / 255.0, maxrad
67 |
68 |
69 | def compute_color(u, v):
70 | """
71 | compute optical flow color map
72 | :param u: optical flow horizontal map
73 | :param v: optical flow vertical map
74 | :return: optical flow in color code
75 | """
76 | [h, w] = u.shape
77 | img = np.zeros([h, w, 3])
78 | nanIdx = np.isnan(u) | np.isnan(v)
79 | u[nanIdx] = 0
80 | v[nanIdx] = 0
81 |
82 | colorwheel = make_color_wheel()
83 | ncols = np.size(colorwheel, 0)
84 |
85 | rad = np.sqrt(u**2+v**2)
86 |
87 | a = np.arctan2(-v, -u) / np.pi
88 |
89 | fk = (a+1) / 2 * (ncols - 1) + 1
90 |
91 | k0 = np.floor(fk).astype(int)
92 |
93 | k1 = k0 + 1
94 | k1[k1 == ncols+1] = 1
95 | f = fk - k0
96 |
97 | for i in range(0, np.size(colorwheel,1)):
98 | tmp = colorwheel[:, i]
99 | col0 = tmp[k0-1] / 255
100 | col1 = tmp[k1-1] / 255
101 | col = (1-f) * col0 + f * col1
102 |
103 | idx = rad <= 1
104 | col[idx] = 1-rad[idx]*(1-col[idx])
105 | notidx = np.logical_not(idx)
106 |
107 | col[notidx] *= 0.75
108 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
109 |
110 | return img
111 |
112 |
113 | def make_color_wheel():
114 | """
115 | Generate color wheel according Middlebury color code
116 | :return: Color wheel
117 | """
118 | RY = 15
119 | YG = 6
120 | GC = 4
121 | CB = 11
122 | BM = 13
123 | MR = 6
124 |
125 | ncols = RY + YG + GC + CB + BM + MR
126 |
127 | colorwheel = np.zeros([ncols, 3])
128 |
129 | col = 0
130 |
131 | # RY
132 | colorwheel[0:RY, 0] = 255
133 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
134 | col += RY
135 |
136 | # YG
137 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
138 | colorwheel[col:col+YG, 1] = 255
139 | col += YG
140 |
141 | # GC
142 | colorwheel[col:col+GC, 1] = 255
143 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
144 | col += GC
145 |
146 | # CB
147 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
148 | colorwheel[col:col+CB, 2] = 255
149 | col += CB
150 |
151 | # BM
152 | colorwheel[col:col+BM, 2] = 255
153 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
154 | col += BM
155 |
156 | # MR
157 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
158 | colorwheel[col:col+MR, 0] = 255
159 |
160 | return colorwheel
161 |
162 | def aff2flow(A, F_size, GPU=True):
163 | """
164 | INPUT:
165 | - A: a (H*W)*(H*W) affinity matrix
166 | - F_size: image feature size
167 | OUTPUT:
168 | - a b*h*w*2 flow tensor; entry ij is the (normalized) offset from pixel ij to its estimated correspondence in the target image.
169 | """
170 | b,c,h,w = F_size
171 | theta = torch.tensor([[1,0,0],[0,1,0]])
172 | theta = theta.unsqueeze(0).repeat(b,1,1)
173 | theta = theta.float()
174 |
175 | # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
176 | # b * (h*w) * 2
177 | grid = torch.nn.functional.affine_grid(theta, F_size)
178 | #grid = grid.view(b,h*w,2)
179 | if(GPU):
180 | grid = grid.cuda()
181 | # b * (h*w) * 2
182 | # A: 1x1024x1024
183 | grid = grid.permute(0,3,1,2)
184 | U = transform(A, grid)
185 | return (U - grid).permute(0,2,3,1)
186 |
187 | def draw_certainty_map(map, normalize=False):
188 | """
189 | INPUTS:
190 | - map: certainty map of flow
191 | """
192 | map = map.squeeze()
193 | # normalization
194 | if(normalize):
195 | map = (map - map.min())/(map.max() - map.min())
196 | # draw heat map
197 | ax = sns.heatmap(map, yticklabels=False, xticklabels=False, cbar=True)
198 | else:
199 | ax = sns.heatmap(map, yticklabels=False, xticklabels=False, vmin=0.0, vmax=1.5, cbar=True)
200 | figure = ax.get_figure()
201 | width, height = figure.get_size_inches() * figure.get_dpi()
202 |
203 | canvas = FigureCanvas(figure)
204 | canvas.draw()
205 | image = np.fromstring(canvas.tostring_rgb(), dtype='uint8')
206 | image = image.reshape(int(height), int(width), 3)
207 | # crop border
208 | gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
209 | gray[gray == 255] = 0
210 | gray[gray > 0] = 255
211 | coords = cv2.findNonZero(gray)
212 | x,y,w,h = cv2.boundingRect(coords)
213 | #rect = image[y:y+h, x:x+w]
214 | rect = image
215 | figure.clf()
216 | return rect
217 |
218 | def draw_foreground(aff, mask, temp = 1, reverse=False):
219 | """
220 | INPUTS:
221 | - aff: a b*N*N affinity matrix, without applying any softmax
222 | - mask: a b*h*w segmentation mask of the first frame
223 | - temp: temperature
224 | - reverse: if False, calculates where each pixel goes;
225 | if True, calculates where each pixel comes from
226 | """
227 | mask = torch.argmax(mask, dim = 0)
228 | h,w = mask.size()
229 | res = torch.zeros(h,w)
230 | if(reverse):
231 | # apply softmax to each column
232 | aff = torch.nn.functional.softmax(aff*temp, dim = 0)
233 | else:
234 | # apply softmax to each row
235 | aff = torch.nn.functional.softmax(aff*temp, dim = 1)
236 | # extract foreground affinity
237 | # N
238 | mask_flat = mask.view(-1)
239 | # x
240 | fgcmask = mask_flat.nonzero().squeeze()
241 | if(reverse):
242 | # N * x
243 | Faff = aff[:,fgcmask]
244 | else:
245 | # x * N
246 | Faff = aff[fgcmask,:]
247 |
248 | theta = torch.tensor([[1,0,0],[0,1,0]])
249 | theta = theta.unsqueeze(0).repeat(1,1,1)
250 | theta = theta.float()
251 |
252 | # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
253 | # 1 * (h*w) * 2
254 | grid = torch.nn.functional.affine_grid(theta, torch.Size((1,1,h,w)))
255 | # N*2
256 | grid = grid.squeeze()
257 | grid = grid.view(-1,2)
258 | grid = (grid + 1)/2
259 | grid[:,0] *= w
260 | grid[:,1] *= h
261 | if(reverse):
262 | grid = grid.permute(1,0)
263 | # 2 * x
264 | Fcoord = torch.mm(grid, Faff)
265 | Fcoord = Fcoord.long()
266 | res[Fcoord[1,:], Fcoord[0,:]] = 1
267 | else:
268 | # x * 2
269 | grid -= 1
270 | Fcoord = torch.mm(Faff, grid)
271 | Fcoord = Fcoord.long()
272 | res[Fcoord[:,1], Fcoord[:,0]] = 1
273 | res = torch.nn.functional.interpolate(res.unsqueeze(0).unsqueeze(0), scale_factor=8, mode='bilinear')
274 | return res
275 |
276 | def norm_mask(mask):
277 | """
278 | INPUTS:
279 | - mask: segmentation mask
280 | """
281 | c,h,w = mask.size()
282 | for cnt in range(c):
283 | mask_cnt = mask[cnt,:,:]
284 | if(mask_cnt.max() > 0):
285 | mask_cnt = (mask_cnt - mask_cnt.min())
286 | mask_cnt = mask_cnt/mask_cnt.max()
287 | mask[cnt,:,:] = mask_cnt
288 | return mask
289 |
290 | def read_flo(filename):
291 | FLO_TAG = 202021.25
292 |
293 | with open(filename, 'rb') as f:
294 | tag = np.fromfile(f, np.float32, count=1)
295 |
296 | if tag != FLO_TAG:
297 | sys.exit('Wrong tag. Invalid .flo file %s' %filename)
298 | else:
299 | w = int(np.fromfile(f, np.int32, count=1))
300 | h = int(np.fromfile(f, np.int32, count=1))
301 | #print 'Reading %d x %d flo file' % (w, h)
302 |
303 | data = np.fromfile(f, np.float32, count=2*w*h)
304 |
305 | # Reshape data into 3D array (columns, rows, bands)
306 | flow = np.resize(data, (h, w, 2))
307 |
308 | return flow
309 |
--------------------------------------------------------------------------------
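A minimal sketch (not part of the repository) of `flow_to_rgb` above on a synthetic flow field; it assumes matplotlib and the `libs` package import cleanly.

```
import numpy as np
from libs.vis_utils import flow_to_rgb

h, w = 64, 64
flow = np.zeros((h, w, 2), dtype=np.float32)
flow[:, :, 0] = np.linspace(0.0, 5.0, w)[None, :]   # every pixel moves right, faster towards the right edge
rgb, maxrad = flow_to_rgb(flow)
print(rgb.shape, maxrad)                             # (64, 64, 3) 5.0
```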
/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from libs.net_utils import NLM, NLM_dot, NLM_woSoft
5 | from torchvision.models import resnet18
6 | from autoencoder import encoder3, decoder3, encoder_res18, encoder_res50
7 | # from torch.utils.serialization import load_lua  # removed in PyTorch >= 1.0 and not used in this file
8 | from libs.utils import *
9 |
10 | def transform(aff, frame1):
11 | """
12 | Given aff, copy from frame1 to construct frame2.
13 | INPUTS:
14 | - aff: (h*w)*(h*w) affinity matrix
15 | - frame1: n*c*h*w feature map
16 | """
17 | b,c,h,w = frame1.size()
18 | frame1 = frame1.view(b,c,-1)
19 | frame2 = torch.bmm(frame1, aff)
20 | return frame2.view(b,c,h,w)
21 |
22 | class normalize(nn.Module):
23 | """Given mean: (R, G, B) and std: (R, G, B),
24 | will normalize each channel of the torch.*Tensor, i.e.
25 | channel = (channel - mean) / std
26 | """
27 |
28 | def __init__(self, mean, std = (1.0,1.0,1.0)):
29 | super(normalize, self).__init__()
30 | self.mean = nn.Parameter(torch.FloatTensor(mean).cuda(), requires_grad=False)
31 | self.std = nn.Parameter(torch.FloatTensor(std).cuda(), requires_grad=False)
32 |
33 | def forward(self, frames):
34 | b,c,h,w = frames.size()
35 | frames = (frames - self.mean.view(1,3,1,1).repeat(b,1,h,w))/self.std.view(1,3,1,1).repeat(b,1,h,w)
36 | return frames
37 |
38 | def create_flat_grid(F_size, GPU=True):
39 | """
40 | INPUTS:
41 | - F_size: feature size
42 | OUTPUT:
43 | - return a standard grid coordinate
44 | """
45 | b, c, h, w = F_size
46 | theta = torch.tensor([[1,0,0],[0,1,0]])
47 | theta = theta.unsqueeze(0).repeat(b,1,1)
48 | theta = theta.float()
49 |
50 | # grid is a uniform grid with top-left (-1,-1) and bottom-right (1,1)
51 | # b * (h*w) * 2
52 | grid = torch.nn.functional.affine_grid(theta, F_size)
53 | grid[:,:,:,0] = (grid[:,:,:,0]+1)/2 * w
54 | grid[:,:,:,1] = (grid[:,:,:,1]+1)/2 * h
55 | grid_flat = grid.view(b,-1,2)
56 | if(GPU):
57 | grid_flat = grid_flat.cuda()
58 | return grid_flat
59 |
60 |
61 | def coords2bbox(coords, patch_size, h_tar, w_tar):
62 | """
63 | INPUTS:
64 | - coords: coordinates of pixels in the next frame
65 | - patch_size: patch size
66 | - h_tar: target image height
67 | - w_tar: target image width
68 | """
69 | b = coords.size(0)
70 | center = torch.mean(coords, dim=1) # b * 2
71 | center_repeat = center.unsqueeze(1).repeat(1,coords.size(1),1)
72 | dis_x = torch.sqrt(torch.pow(coords[:,:,0] - center_repeat[:,:,0], 2))
73 | dis_x = torch.mean(dis_x, dim=1).detach()
74 | dis_y = torch.sqrt(torch.pow(coords[:,:,1] - center_repeat[:,:,1], 2))
75 | dis_y = torch.mean(dis_y, dim=1).detach()
76 | left = (center[:,0] - dis_x*2).view(b,1)
77 | left[left < 0] = 0
78 | right = (center[:,0] + dis_x*2).view(b,1)
79 | right[right > w_tar] = w_tar
80 | top = (center[:,1] - dis_y*2).view(b,1)
81 | top[top < 0] = 0
82 | bottom = (center[:,1] + dis_y*2).view(b,1)
83 | bottom[bottom > h_tar] = h_tar
84 | new_center = torch.cat((left,right,top,bottom),dim=1)
85 | return new_center
86 |
87 |
88 |
89 | class track_match_comb(nn.Module):
90 | def __init__(self, pretrained, encoder_dir = None, decoder_dir = None, temp=1, Resnet = "r18", color_switch=True, coord_switch=True):
91 | super(track_match_comb, self).__init__()
92 |
93 | if Resnet in "r18":
94 | self.gray_encoder = encoder_res18(pretrained=pretrained, uselayer=4)
95 | elif Resnet in "r50":
96 | self.gray_encoder = encoder_res50(pretrained=pretrained, uselayer=4)
97 | self.rgb_encoder = encoder3(reduce=True)
98 | self.decoder = decoder3(reduce=True)
99 |
100 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
101 | self.decoder.load_state_dict(torch.load(decoder_dir))
102 | for param in self.decoder.parameters():
103 | param.requires_grad = False
104 | for param in self.rgb_encoder.parameters():
105 | param.requires_grad = False
106 |
107 | self.nlm = NLM_woSoft()
108 | self.normalize = normalize(mean=[0.485, 0.456, 0.406],
109 | std=[0.229, 0.224, 0.225])
110 | self.softmax = nn.Softmax(dim=1)
111 | self.temp = temp
112 | self.grid_flat = None
113 | self.grid_flat_crop = None
114 | self.color_switch = color_switch
115 | self.coord_switch = coord_switch
116 |
117 |
118 | def forward(self, img_ref, img_tar, warm_up=True, patch_size=None):
119 | n, c, h_ref, w_ref = img_ref.size()
120 | n, c, h_tar, w_tar = img_tar.size()
121 | gray_ref = copy.deepcopy(img_ref[:,0].view(n,1,h_ref,w_ref).repeat(1,3,1,1))
122 | gray_tar = copy.deepcopy(img_tar[:,0].view(n,1,h_tar,w_tar).repeat(1,3,1,1))
123 |
124 | gray_ref = (gray_ref + 1) / 2
125 | gray_tar = (gray_tar + 1) / 2
126 |
127 | gray_ref = self.normalize(gray_ref)
128 | gray_tar = self.normalize(gray_tar)
129 |
130 | Fgray1 = self.gray_encoder(gray_ref)
131 | Fgray2 = self.gray_encoder(gray_tar)
132 | Fcolor1 = self.rgb_encoder(img_ref)
133 |
134 | output = []
135 |
136 | if warm_up:
137 | aff = self.nlm(Fgray1, Fgray2)
138 | aff_norm = self.softmax(aff)
139 | Fcolor2_est = transform(aff_norm, Fcolor1)
140 | color2_est = self.decoder(Fcolor2_est)
141 |
142 | output.append(color2_est)
143 | output.append(aff)
144 |
145 | if self.color_switch:
146 | Fcolor2 = self.rgb_encoder(img_tar)
147 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2)
148 | color1_est = self.decoder(Fcolor1_est)
149 | output.append(color1_est)
150 | else:
151 | if(self.grid_flat is None):
152 | self.grid_flat = create_flat_grid(Fgray2.size())
153 | aff_ref_tar = self.nlm(Fgray1, Fgray2)
154 | aff_ref_tar = torch.nn.functional.softmax(aff_ref_tar * self.temp, dim = 2)
155 | coords = torch.bmm(aff_ref_tar, self.grid_flat)
156 | center = torch.mean(coords, dim=1) # b * 2
157 | # new_c = center2bbox(center, patch_size, h_tar, w_tar)
158 | new_c = center2bbox(center, patch_size, Fgray2.size(2), Fgray2.size(3))
159 | # print("center2bbox:", new_c, h_tar, w_tar)
160 |
161 | Fgray2_crop = diff_crop(Fgray2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0])
162 | # print("HERE: ", Fgray2.size(), Fgray1.size(), Fgray2_crop.size())
163 |
164 | aff_p = self.nlm(Fgray1, Fgray2_crop)
165 | aff_norm = self.softmax(aff_p * self.temp)
166 | Fcolor2_est = transform(aff_norm, Fcolor1)
167 | color2_est = self.decoder(Fcolor2_est)
168 |
169 | Fcolor2_full = self.rgb_encoder(img_tar)
170 | Fcolor2_crop = diff_crop(Fcolor2_full, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3], patch_size[1], patch_size[0])
171 |
172 | output.append(color2_est)
173 | output.append(aff_p)
174 | output.append(new_c*8)
175 | output.append(coords)
176 |
177 | # color orthogonal branch
178 | if self.color_switch:
179 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2_crop)
180 | color1_est = self.decoder(Fcolor1_est)
181 | output.append(color1_est)
182 |
183 | # coordinate orthogonal branch
184 | if self.coord_switch:
185 | aff_norm_tran = self.softmax(aff_p.permute(0,2,1)*self.temp)
186 | if self.grid_flat_crop is None:
187 | self.grid_flat_crop = create_flat_grid(Fgray2_crop.size()).permute(0,2,1).detach()  # note: the original referenced an undefined Fp_tar; the cropped target feature here is Fgray2_crop
188 | C12 = torch.bmm(self.grid_flat_crop, aff_norm)
189 | C11 = torch.bmm(C12, aff_norm_tran)
190 | output.append(self.grid_flat_crop)
191 | output.append(C11)
192 |
193 | return output
194 |
195 |
196 | class Model_switchGTfixdot_swCC_Res(nn.Module):
197 | def __init__(self, encoder_dir = None, decoder_dir = None, fix_dec = True,
198 | temp = None, pretrainRes = False, uselayer=3, model='resnet18'):
199 | '''
200 | For the switchable concentration loss,
201 | using ResNet-18 or ResNet-50 as the grayscale encoder.
202 | '''
203 | super(Model_switchGTfixdot_swCC_Res, self).__init__()
204 | if(model == 'resnet18'):
205 | print('Use ResNet18.')
206 | self.gray_encoder = encoder_res18(pretrained = pretrainRes, uselayer=uselayer)
207 | else:
208 | print('Use ResNet50.')
209 | self.gray_encoder = encoder_res50(pretrained = pretrainRes, uselayer=uselayer)
210 | self.rgb_encoder = encoder3(reduce = True)
211 | self.nlm = NLM_woSoft()
212 | self.decoder = decoder3(reduce = True)
213 | self.temp = temp
214 | self.softmax = nn.Softmax(dim=1)
215 | self.cos_window = torch.Tensor(np.outer(np.hanning(40), np.hanning(40))).cuda()
216 | self.normalize = normalize(mean=[0.485, 0.456, 0.406],
217 | std=[0.229, 0.224, 0.225])
218 |
219 | self.rgb_encoder.load_state_dict(torch.load(encoder_dir))
220 | self.decoder.load_state_dict(torch.load(decoder_dir))
221 |
222 | for param in self.decoder.parameters():
223 | param.requires_grad = False
224 | for param in self.rgb_encoder.parameters():
225 | param.requires_grad = False
226 |
227 | def forward(self, gray1, gray2, color1=None, color2=None):
228 | # rescale grayscale images to [0, 1] so that they match the ImageNet pre-training statistics
229 | gray1 = (gray1 + 1) / 2
230 | gray2 = (gray2 + 1) / 2
231 |
232 | # normalize to fit resnet
233 | b = gray1.size(0)
234 |
235 | gray1 = self.normalize(gray1)
236 | gray2 = self.normalize(gray2)
237 |
238 | Fgray1 = self.gray_encoder(gray1)
239 | Fgray2 = self.gray_encoder(gray2)
240 |
241 | aff = self.nlm(Fgray1, Fgray2) # bx4096x4096
242 | aff_norm = self.softmax(aff*self.temp)
243 |
244 | if(color1 is None):
245 | return aff_norm, Fgray1, Fgray2
246 |
247 | Fcolor1 = self.rgb_encoder(color1)
248 | Fcolor2 = self.rgb_encoder(color2)
249 | Fcolor2_est = transform(aff_norm, Fcolor1)
250 | pred2 = self.decoder(Fcolor2_est)
251 |
252 | Fcolor1_est = transform(aff_norm.transpose(1,2), Fcolor2)
253 | pred1 = self.decoder(Fcolor1_est)
254 |
255 | return pred1, pred2, aff_norm, aff, Fgray1, Fgray2
--------------------------------------------------------------------------------
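A toy check (not part of the repository) of the `transform` helper above: with an identity affinity matrix, every target pixel simply copies its own source pixel. The identical helper in `libs/model.py` is the one imported by `test.py`; the snippet assumes the repository dependencies are installed.

```
import torch
from libs.model import transform

b, c, h, w = 1, 3, 4, 4
feat = torch.rand(b, c, h, w)
aff = torch.eye(h * w).unsqueeze(0)    # identity affinity: each target pixel copies itself
copied = transform(aff, feat)
print(torch.allclose(copied, feat))    # True
```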
/test.py:
--------------------------------------------------------------------------------
1 | # OS libraries
2 | import os
3 | import copy
4 | import queue
5 | import argparse
6 | import scipy.misc
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | # Pytorch
11 | import torch
12 | import torch.nn as nn
13 |
14 | # Customized libraries
15 | from libs.test_utils import *
16 | from libs.model import transform
17 | from libs.vis_utils import norm_mask
18 | from libs.model import Model_switchGTfixdot_swCC_Res as Model
19 |
20 | ############################## helper functions ##############################
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--batch_size", type = int, default = 1,
24 | help = "batch size")
25 | parser.add_argument("-o","--out_dir",type = str,default = "results/",
26 | help = "output saving path")
27 | parser.add_argument("--device", type = int, default = 5,
28 | help = "0~4 for single GPU, 5 for dataparallel.")
29 | parser.add_argument("-c","--checkpoint_dir",type = str,
30 | default = "weights/checkpoint_latest.pth.tar",
31 | help = "checkpoints path")
32 | parser.add_argument("-s", "--scale_size", type = int, nargs = "+",
33 | help = "scale size, a single number for shorter edge, or a pair for height and width")
34 | parser.add_argument("--pre_num", type = int, default = 7,
35 | help = "preceding frame numbers")
36 | parser.add_argument("--temp", type = float,default = 1,
37 | help = "softmax temperature")
38 | parser.add_argument("--topk", type = int, default = 5,
39 | help = "accumulate label from top k neighbors")
40 | parser.add_argument("-d", "--davis_dir", type = str,
41 | default = "/workspace/DAVIS/",
42 | help = "davis dataset path")
43 |
44 | args = parser.parse_args()
45 | args.is_train = False
46 |
47 | args.multiGPU = args.device == 5
48 | if not args.multiGPU:
49 | torch.cuda.set_device(args.device)
50 |
51 | args.val_txt = os.path.join(args.davis_dir, "ImageSets/2017/val.txt")
52 | args.davis_dir = os.path.join(args.davis_dir, "JPEGImages/480p/")
53 |
54 | return args
55 |
56 | ############################## testing functions ##############################
57 |
58 | def forward(frame1, frame2, model, seg):
59 | """
60 | propagate seg of frame1 to frame2
61 | """
62 | n, c, h, w = frame1.size()
63 | frame1_gray = frame1[:,0].view(n,1,h,w)
64 | frame2_gray = frame2[:,0].view(n,1,h,w)
65 | frame1_gray = frame1_gray.repeat(1,3,1,1)
66 | frame2_gray = frame2_gray.repeat(1,3,1,1)
67 |
68 | output = model(frame1_gray, frame2_gray, frame1, frame2)
69 | aff = output[2]
70 |
71 | frame2_seg = transform_topk(aff,seg.cuda(),k=args.topk)
72 |
73 | return frame2_seg
74 |
75 | def test(model, frame_list, video_dir, first_seg, seg_ori):
76 | """
77 | test on a video given first frame & segmentation
78 | """
79 | video_dir = os.path.join(video_dir)
80 | video_nm = video_dir.split('/')[-1]
81 | video_folder = os.path.join(args.out_dir, video_nm)
82 | os.makedirs(video_folder, exist_ok = True)
83 |
84 | transforms = create_transforms()
85 |
86 | # The queue stores args.pre_num preceding frames
87 | que = queue.Queue(args.pre_num)
88 |
89 | # first frame
90 | frame1, ori_h, ori_w = read_frame(frame_list[0], transforms, args.scale_size)
91 | n, c, h, w = frame1.size()
92 |
93 | # saving first segmentation
94 | out_path = os.path.join(video_folder,"00000.png")
95 | imwrite_indexed(out_path, seg_ori)
96 |
97 | for cnt in tqdm(range(1,len(frame_list))):
98 | frame_tar, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size)
99 |
100 | with torch.no_grad():
101 | # frame 1 -> frame cnt
102 | frame_tar_acc = forward(frame1, frame_tar, model, first_seg)
103 |
104 | # frame cnt - i -> frame cnt, (i = 1, ..., pre_num)
105 | tmp_queue = list(que.queue)
106 | for pair in tmp_queue:
107 | framei = pair[0]
108 | segi = pair[1]
109 | frame_tar_est_i = forward(framei, frame_tar, model, segi)
110 | frame_tar_acc += frame_tar_est_i
111 | frame_tar_avg = frame_tar_acc / (1 + len(tmp_queue))
112 |
113 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg",".png")
114 | out_path = os.path.join(video_folder,frame_nm)
115 |
116 | # pop out the oldest frame if necessary
117 | if(que.qsize() == args.pre_num):
118 | que.get()
119 | # push current results into queue
120 | seg = copy.deepcopy(frame_tar_avg)
121 | frame, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size)
122 | que.put([frame,seg])
123 |
124 | # upsampling & argmax
125 | frame_tar_avg = torch.nn.functional.interpolate(frame_tar_avg,scale_factor=8,mode='bilinear')
126 | frame_tar_avg = frame_tar_avg.squeeze()
127 | frame_tar_avg = norm_mask(frame_tar_avg.squeeze())
128 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
129 |
130 | # saving to disk
131 | frame_tar_seg = frame_tar_seg.squeeze().cpu().numpy()
132 | frame_tar_seg = np.array(frame_tar_seg, dtype=np.uint8)
133 | frame_tar_seg = scipy.misc.imresize(frame_tar_seg, (ori_h, ori_w), "nearest")
134 |
135 | output_path = os.path.join(video_folder, frame_nm.split('.')[0]+'_seg.png')
136 | imwrite_indexed(out_path,frame_tar_seg)
137 |
138 | ############################## main function ##############################
139 |
140 | if(__name__ == '__main__'):
141 | args = parse_args()
142 | with open(args.val_txt) as f:
143 | lines = f.readlines()
144 | f.close()
145 |
146 | # loading pretrained model
147 | model = Model(pretrainRes=False, temp = args.temp, uselayer=4)
148 | if(args.multiGPU):
149 | model = nn.DataParallel(model)
150 | checkpoint = torch.load(args.checkpoint_dir)
151 | best_loss = checkpoint['best_loss']
152 | model.load_state_dict(checkpoint['state_dict'])
153 | print("=> loaded checkpoint '{} ({})' (epoch {})"
154 | .format(args.checkpoint_dir, best_loss, checkpoint['epoch']))
155 | model.cuda()
156 | model.eval()
157 |
158 |
159 | # start testing
160 | for cnt,line in enumerate(lines):
161 | video_nm = line.strip()
162 | print('[{:n}/{:n}] Begin to segment video {}.'.format(cnt,len(lines),video_nm))
163 |
164 | video_dir = os.path.join(args.davis_dir, video_nm)
165 | frame_list = read_frame_list(video_dir)
166 | seg_dir = frame_list[0].replace("JPEGImages","Annotations")
167 | seg_dir = seg_dir.replace("jpg","png")
168 | _, first_seg, seg_ori = read_seg(seg_dir, args.scale_size)
169 | test(model, frame_list, video_dir, first_seg, seg_ori)
170 |
--------------------------------------------------------------------------------
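`test.py` renormalizes each class channel of the averaged prediction before the final argmax; a toy check (not part of the repository) of `norm_mask`, assuming the `libs` package imports cleanly.

```
import torch
from libs.vis_utils import norm_mask

m = torch.tensor([[[0.2, 0.6], [0.4, 0.8]],    # channel 0: rescaled to [0, 1]
                  [[0.0, 0.0], [0.0, 0.0]]])   # channel 1: empty, left untouched
out = norm_mask(m.clone())                     # norm_mask modifies its argument in place
print(out[0])                                  # tensor([[0.0000, 0.6667], [0.3333, 1.0000]])
print(out[1])                                  # tensor([[0., 0.], [0., 0.]])
```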
/test_with_track.py:
--------------------------------------------------------------------------------
1 | # OS libraries
2 | import os
3 | import cv2
4 | import glob
5 | import copy
6 | import math
7 | import queue
8 | import argparse
9 | import scipy.misc
10 | import numpy as np
11 | from tqdm import tqdm
12 | from PIL import Image
13 |
14 | # Pytorch libraries
15 | import torch
16 | import torch.nn as nn
17 |
18 | # Customized libraries
19 | from libs.test_utils import *
20 | from libs.model import transform
21 | from libs.vis_utils import norm_mask
22 | import libs.transforms_pair as transforms
23 | from libs.model import Model_switchGTfixdot_swCC_Res as Model
24 | from libs.track_utils import seg2bbox, draw_bbox, match_ref_tar
25 | from libs.track_utils import squeeze_all, seg2bbox_v2, bbox_in_tar_scale
26 |
27 | ############################## helper functions ##############################
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--batch_size", type = int, default = 1,
32 | help = "batch size")
33 | parser.add_argument("-o","--out_dir", type = str,default="results_with_track/",
34 | help = "output path")
35 | parser.add_argument("--device", type = int, default = 5,
36 | help="0~4 for single GPU, 5 for dataparallel.")
37 | parser.add_argument("-c","--checkpoint_dir",type = str,
38 | default = "weights/checkpoint_latest.pth.tar",
39 | help = "checkpoints path")
40 | parser.add_argument("-s", "--scale_size", type = int, nargs = '+',
41 | help = "scale size, either a single number for short edge, or a pair for height and width")
42 | parser.add_argument("--pre_num", type = int, default = 7,
43 | help = "preceding frame numbers")
44 | parser.add_argument("--temp", type = float, default = 1,
45 | help = "softmax temperature")
46 | parser.add_argument("-t", "--topk", type = int, default = 5,
47 | help = "accumulate label from top k neighbors")
48 | parser.add_argument("-d", "--davis_dir", type = str,
49 | default = "/workspace/DAVIS/",
50 | help = "davis dataset path")
51 |
52 | print("Begin parser arguments.")
53 | args = parser.parse_args()
54 | args.is_train = False
55 |
56 | args.multiGPU = args.device == 5
57 | if not args.multiGPU:
58 | torch.cuda.set_device(args.device)
59 |
60 | args.val_txt = os.path.join(args.davis_dir, "ImageSets/2017/val.txt")
61 | args.davis_dir = os.path.join(args.davis_dir, "JPEGImages/480p/")
62 | return args
63 |
64 | def vis_bbox(im, bbox, name, coords, seg):
65 | im = im * 128 + 128
66 | im = im.squeeze().permute(1,2,0).cpu().numpy().astype(np.uint8)
67 | im = cv2.cvtColor(im, cv2.COLOR_LAB2BGR)
68 | fg_idx = seg.nonzero()
69 | im = draw_bbox(im, bbox, (0,0,255))
70 |
71 | for cnt in range(coords.size(0)):
72 | coord_i = coords[cnt]
73 |
74 | cv2.circle(im, (int(coord_i[0]*8), int(coord_i[1]*8)), 2, (0,255,0), thickness=-1)
75 | cv2.imwrite(name, im)
76 |
77 | ############################## tracking functions ##############################
78 |
79 | def adjust_bbox(bbox_now, bbox_pre, a, h, w):
80 | """
81 | Adjust a bounding box w.r.t. the previous frame,
82 | assuming objects do not undergo abrupt changes in size.
83 | """
84 | for cnt in bbox_pre.keys():
85 | if(cnt == 0):
86 | continue
87 | if(cnt in bbox_now and bbox_pre[cnt] is not None and bbox_now[cnt] is not None):
88 | bbox_now_h = (bbox_now[cnt].top + bbox_now[cnt].bottom) / 2.0
89 | bbox_now_w = (bbox_now[cnt].left + bbox_now[cnt].right) / 2.0
90 |
91 | bbox_now_height_ = bbox_now[cnt].bottom - bbox_now[cnt].top
92 | bbox_now_width_ = bbox_now[cnt].right - bbox_now[cnt].left
93 |
94 | bbox_pre_height = bbox_pre[cnt].bottom - bbox_pre[cnt].top
95 | bbox_pre_width = bbox_pre[cnt].right - bbox_pre[cnt].left
96 |
97 | bbox_now_height = a * bbox_now_height_ + (1 - a) * bbox_pre_height
98 | bbox_now_width = a * bbox_now_width_ + (1 - a) * bbox_pre_width
99 |
100 | bbox_now[cnt].left = math.floor(bbox_now_w - bbox_now_width / 2.0)
101 | bbox_now[cnt].right = math.ceil(bbox_now_w + bbox_now_width / 2.0)
102 | bbox_now[cnt].top = math.floor(bbox_now_h - bbox_now_height / 2.0)
103 | bbox_now[cnt].bottom = math.ceil(bbox_now_h + bbox_now_height / 2.0)
104 |
105 | bbox_now[cnt].left = max(0, bbox_now[cnt].left)
106 | bbox_now[cnt].right = min(w, bbox_now[cnt].right)
107 | bbox_now[cnt].top = max(0, bbox_now[cnt].top)
108 | bbox_now[cnt].bottom = min(h, bbox_now[cnt].bottom)
109 |
110 | return bbox_now
111 |
112 | def bbox_next_frame(img_ref, seg_ref, img_tar, bbox_ref):
113 | """
114 | Match bbox from the reference frame to the target frame
115 | """
116 | F_ref, F_tar = forward(img_ref, img_tar, model, seg_ref, return_feature=True)
117 | seg_ref = seg_ref.squeeze(0)
118 | F_ref, F_tar = squeeze_all(F_ref, F_tar)
119 | c, h, w = F_ref.size()
120 |
121 | # get coordinates of each point in the target frame
122 | coords_ref_tar = match_ref_tar(F_ref, F_tar, seg_ref, args.temp)
123 | # coordinates -> bbox
124 | bbox_tar = bbox_in_tar_scale(coords_ref_tar, bbox_ref, h, w)
125 | # adjust bbox
126 | bbox_tar = adjust_bbox(bbox_tar, bbox_ref, 0.1, h, w)
127 | return bbox_tar, coords_ref_tar
128 |
129 | def recoginition(img_ref, img_tar, bbox_ref, bbox_tar, seg_ref, model):
130 | """
131 | propagate from bbox in the reference frame to bbox in the target frame
132 | """
133 | F_ref, F_tar = forward(img_ref, img_tar, model, seg_ref, return_feature=True)
134 | seg_ref = seg_ref.squeeze()
135 | _, c, h, w = F_tar.size()
136 | seg_pred = torch.zeros(seg_ref.size())
137 |
138 | # calculate affinity only once to save time
139 | aff_whole = torch.mm(F_ref.view(c,-1).permute(1,0), F_tar.view(c,-1))
140 | aff_whole = torch.nn.functional.softmax(aff_whole * args.temp, dim=0)
141 |
142 | for cnt, br in bbox_ref.items():
143 | if not (cnt in bbox_tar):
144 | continue
145 | bt = bbox_tar[cnt]
146 | if(br is None or bt is None):
147 | continue
148 | seg_cnt = seg_ref[cnt]
149 |
150 | # affinity between two patches
151 | seg_ref_box = seg_cnt[br.top:br.bottom, br.left:br.right]
152 | seg_ref_box = seg_ref_box.unsqueeze(0).unsqueeze(0)
153 |
154 | h, w = F_ref.size(2), F_ref.size(3)
155 | mask = torch.zeros(h,w)
156 | mask[br.top:br.bottom, br.left:br.right] = 1
157 | mask = mask.view(-1)
158 | aff_row = aff_whole[mask.nonzero().squeeze(), :]
159 |
160 | h, w = F_tar.size(2), F_tar.size(3)
161 | mask = torch.zeros(h,w)
162 | mask[bt.top:bt.bottom, bt.left:bt.right] = 1
163 | mask = mask.view(-1)
164 | aff = aff_row[:, mask.nonzero().squeeze()]
165 |
166 | aff = aff.unsqueeze(0)
167 |
168 | seg_tar_box = transform_topk(aff,seg_ref_box.cuda(),k=args.topk,
169 | h2 = bt.bottom - bt.top,w2 = bt.right - bt.left)
170 | seg_pred[cnt, bt.top:bt.bottom, bt.left:bt.right] = seg_tar_box
171 |
172 | return seg_pred
173 |
174 | def disappear(seg,bbox_ref,bbox_tar=None):
175 | """
176 | Check if bbox disappear in the target frame.
177 | """
178 | b,c,h,w = seg.size()
179 | for cnt in range(c):
180 | if(torch.sum(seg[:,cnt,:,:]) < 3 or (not (cnt in bbox_ref))):
181 | return True
182 | if(bbox_ref[cnt] is None):
183 | return True
184 | if(bbox_ref[cnt].right - bbox_ref[cnt].left < 3 or bbox_ref[cnt].bottom - bbox_ref[cnt].top < 3):
185 | return True
186 |
187 | if(bbox_tar is not None):
188 | if(cnt not in bbox_tar.keys()):
189 | return True
190 | if(bbox_tar[cnt] is None):
191 | return True
192 | if(bbox_tar[cnt].right - bbox_tar[cnt].left < 3 or bbox_tar[cnt].bottom - bbox_tar[cnt].top < 3):
193 | return True
194 | return False
195 |
196 | ############################## testing functions ##############################
197 |
198 | def forward(frame1, frame2, model, seg, return_feature=False):
199 | n, c, h, w = frame1.size()
200 | frame1_gray = frame1[:,0].view(n,1,h,w)
201 | frame2_gray = frame2[:,0].view(n,1,h,w)
202 | frame1_gray = frame1_gray.repeat(1,3,1,1)
203 | frame2_gray = frame2_gray.repeat(1,3,1,1)
204 |
205 | output = model(frame1_gray, frame2_gray, frame1, frame2)
206 | if(return_feature):
207 | return output[-2], output[-1]
208 |
209 | aff = output[2]
210 | frame2_seg = transform_topk(aff,seg.cuda(),k=args.topk)
211 |
212 | return frame2_seg
213 |
214 | def test(model, frame_list, video_dir, first_seg, large_seg, first_bbox, seg_ori):
215 | video_dir = os.path.join(video_dir)
216 | video_nm = video_dir.split('/')[-1]
217 | video_folder = os.path.join(args.out_dir, video_nm)
218 | os.makedirs(video_folder, exist_ok = True)
219 | os.makedirs(os.path.join(video_folder, 'track'), exist_ok = True)
220 |
221 | transforms = create_transforms()
222 |
223 | # The queue stores `pre_num` preceding frames
224 | que = queue.Queue(args.pre_num)
225 |
226 | # frame 1
227 | frame1, ori_h, ori_w = read_frame(frame_list[0], transforms, args.scale_size)
228 | n, c, h, w = frame1.size()
229 |
230 | # saving first segmentation
231 | out_path = os.path.join(video_folder,"00000.png")
232 | imwrite_indexed(out_path, seg_ori)
233 |
234 | coords = first_seg[0,1].nonzero()
235 | coords = coords.flip(1)
236 |
237 | for cnt in tqdm(range(1,len(frame_list))):
238 | frame_tar, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size)
239 |
240 | with torch.no_grad():
241 |
242 | tmp_list = list(que.queue)
243 | if(len(tmp_list) > 0):
244 | pair = tmp_list[-1]
245 | framei = pair[0]
246 | segi = pair[1]
247 | bbox_pre = pair[2]
248 | else:
249 | bbox_pre = first_bbox
250 | framei = frame1
251 | segi = first_seg
252 | _, segi_int = torch.max(segi, dim=1)
253 | segi = to_one_hot(segi_int)
254 | bbox_tar, coords_ref_tar = bbox_next_frame(framei, segi, frame_tar, bbox_pre)
255 |
256 | if(bbox_tar is not None):
257 |
258 | if(1 in bbox_tar):
259 | tmp = copy.deepcopy(bbox_tar[1])
260 | if(tmp is not None):
261 | tmp.upscale(8)
262 | vis_bbox(frame_tar, tmp, os.path.join(video_folder, 'track', 'frame'+str(cnt+1)+'.png'), coords_ref_tar[1], segi[0,1,:,:])
263 | frame_tar_acc = recoginition(frame1, frame_tar, first_bbox, bbox_tar, first_seg, model)
264 | else:
265 | frame_tar_acc = forward(frame1, frame_tar, model, first_seg)
266 | frame_tar_acc = frame_tar_acc.cpu()
267 |
268 |
269 | # previous 7 frames
270 | tmp_queue = list(que.queue)
271 | for pair in tmp_queue:
272 | framei = pair[0]
273 | segi = pair[1]
274 | bboxi = pair[2]
275 |
276 | if(bbox_tar is None or disappear(segi, bboxi, bbox_tar)):
277 | frame_tar_est_i = forward(framei, frame_tar, model, segi)
278 | frame_tar_est_i = frame_tar_est_i.cpu()
279 | else:
280 |
281 | frame_tar_est_i = recoginition(framei, frame_tar, bboxi, bbox_tar, segi, model)
282 |
283 | frame_tar_acc += frame_tar_est_i.cpu().view(frame_tar_acc.size())
284 | frame_tar_avg = frame_tar_acc / (1 + len(tmp_queue))
285 |
286 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg",".png")
287 | out_path = os.path.join(video_folder,frame_nm)
288 |
289 | # upsampling & argmax
290 | if(frame_tar_avg.dim() == 3):
291 | frame_tar_avg = frame_tar_avg.unsqueeze(0)
292 | elif(frame_tar_avg.dim() == 2):
293 | frame_tar_avg = frame_tar_avg.unsqueeze(0).unsqueeze(0)
294 | frame_tar_up = torch.nn.functional.interpolate(frame_tar_avg,scale_factor=8,mode='bilinear')
295 |
296 | frame_tar_up = frame_tar_up.squeeze()
297 | frame_tar_up = norm_mask(frame_tar_up.squeeze())
298 | _, frame_tar_seg = torch.max(frame_tar_up.squeeze(), dim=0)
299 |
300 | frame_tar_seg = frame_tar_seg.squeeze().cpu().numpy()
301 | frame_tar_seg = np.array(frame_tar_seg, dtype=np.uint8)
302 | frame_tar_seg = scipy.misc.imresize(frame_tar_seg, (ori_h, ori_w), "nearest")
303 | imwrite_indexed(out_path,frame_tar_seg)
304 |
305 | if(que.qsize() == args.pre_num):
306 | que.get()
307 | seg = copy.deepcopy(frame_tar_avg.squeeze())
308 | frame, ori_h, ori_w = read_frame(frame_list[cnt], transforms, args.scale_size)
309 | bbox_tar = seg2bbox_v2(frame_tar_up.cpu(), bbox_pre)
310 | bbox_tar = adjust_bbox(bbox_tar, bbox_pre, 0.1, h, w)
311 | que.put([frame,seg.unsqueeze(0),bbox_tar])
312 |
313 | if(__name__ == '__main__'):
314 | args = parse_args()
315 | with open(args.val_txt) as f:
316 | lines = f.readlines()
317 | f.close()
318 |
319 | model = Model(pretrainRes=False, temp = args.temp, uselayer=4)
320 | if(args.multiGPU):
321 | model = nn.DataParallel(model)
322 | checkpoint = torch.load(args.checkpoint_dir)
323 | best_loss = checkpoint['best_loss']
324 | model.load_state_dict(checkpoint['state_dict'])
325 | print("=> loaded checkpoint '{} ({})' (epoch {})"
326 | .format(args.checkpoint_dir, best_loss, checkpoint['epoch']))
327 | model.cuda()
328 | model.eval()
329 |
330 | for cnt,line in enumerate(lines):
331 | video_nm = line.strip()
332 |         print('[{:n}/{:n}] Begin to segment video {}.'.format(cnt,len(lines),video_nm))
333 |
334 | video_dir = os.path.join(args.davis_dir, video_nm)
335 | frame_list = read_frame_list(video_dir)
336 | seg_dir = frame_list[0].replace("JPEGImages","Annotations")
337 | seg_dir = seg_dir.replace("jpg","png")
338 | large_seg, first_seg, seg_ori = read_seg(seg_dir, args.scale_size)
339 |
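        # Boxes are extracted from the full-resolution annotation and then scaled
        # by 1/8 so that they live on the model's feature-map grid (the network
        # output is 8x smaller than the input; see the scale_factor=8 upsampling
        # in test()).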
340 | first_bbox = seg2bbox(large_seg, margin=0.6)
341 | for k,v in first_bbox.items():
342 | v.upscale(0.125)
343 |
344 | test(model, frame_list, video_dir, first_seg, large_seg, first_bbox, seg_ori)
345 |
--------------------------------------------------------------------------------
/track_match_v1.py:
--------------------------------------------------------------------------------
1 | # A combination of tracking and matching:
2 | # 1. load full-resolution images and resize them to 640 x 640 (--full_size)
3 | # 2. warm-up stage: crop at a random location
4 | # 3. loc-match stage: add attention-based localization
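# Training proceeds in two stages (see train() and train_iter() below): the
# first --wepoch epochs warm up the matching branch on randomly cropped pairs
# from VidListv1; later epochs switch to VidListv2 and the localization
# ("loc-match") path, which also predicts the crop location in the second frame.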
5 | import os
6 | import cv2
7 | import sys
8 | import time
9 | import torch
10 | import logging
11 | import argparse
12 | import numpy as np
13 | import torch.nn as nn
14 | from libs.loader import VidListv1, VidListv2
15 | import torch.backends.cudnn as cudnn
16 | import libs.transforms_multi as transforms
17 | from model import track_match_comb as Model
18 | from libs.loss import L1_loss
19 | from libs.concentration_loss import ConcentrationSwitchLoss as ConcentrationLoss
20 | from libs.train_utils import save_vis, AverageMeter, save_checkpoint, log_current
21 | from libs.utils import diff_crop
22 |
23 |
24 | FORMAT = "[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s"
25 | logging.basicConfig(format=FORMAT)
26 | logger = logging.getLogger(__name__)
27 | logger.setLevel(logging.DEBUG)
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description='')
31 |
32 |     # file/folder paths
33 | parser.add_argument("--videoRoot", type=str, default="/Data2/Kinetices/compress/train_256/", help='train video path')
34 | parser.add_argument("--videoList", type=str, default="/Data2/Kinetices/compress/train.txt", help='train video list (after "train_256")')
35 | parser.add_argument("--encoder_dir",type=str, default='ae_small/encoder_single_gpu.pth', help="pretrained encoder")
36 | parser.add_argument("--decoder_dir",type=str, default='ae_small/decoder_single_gpu.pth', help="pretrained decoder")
37 | parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to latest checkpoint (default: none)')
38 | parser.add_argument("-c","--savedir",type=str,default="match_track_comb/",help='checkpoints path')
39 | parser.add_argument("--Resnet", type=str, default="r18", help="choose from r18 or r50")
40 |
41 | # main parameters
42 | parser.add_argument("--pretrainRes",action="store_true")
43 | parser.add_argument("--batchsize",type=int, default=1, help="batchsize")
44 | parser.add_argument('--workers', type=int, default=16)
45 | parser.add_argument("--patch_size", type=int, default=256, help="crop size for localization.")
46 | parser.add_argument("--full_size", type=int, default=640, help="full size for one frame.")
47 | parser.add_argument("--rotate",type=int,default=10,help='degree to rotate training images')
48 | parser.add_argument("--scale",type=float,default=1.2,help='random scale')
49 | parser.add_argument("--lr",type=float,default=0.0001,help='learning rate')
50 | parser.add_argument('--lr-mode', type=str, default='poly')
51 |     parser.add_argument("--window_len",type=int,default=2,help='number of images (2 for a pair, 3 for a triplet)')
52 | parser.add_argument("--log_interval",type=int,default=10,help='')
53 |     parser.add_argument("--save_interval",type=int,default=1000,help='save a checkpoint every x iterations')
54 | parser.add_argument("--momentum",type=float,default=0.9,help='momentum')
55 | parser.add_argument("--weight_decay",type=float,default=0.005,help='weight decay')
56 | parser.add_argument("--device", type=int, default=4, help="0~device_count-1 for single GPU, device_count for dataparallel.")
57 |     parser.add_argument("--temp", type=int, default=1, help="temperature for softmax.")
58 |
59 |     # set epochs
60 | parser.add_argument("--wepoch",type=int,default=10,help='warmup epoch')
61 | parser.add_argument("--nepoch",type=int,default=20,help='max epoch')
62 |
63 |     # concentration regularization
64 | parser.add_argument("--lc",type=float,default=1e4, help='weight of concentration loss')
65 | parser.add_argument("--lc_win",type=int,default=8, help='win_len for concentration loss')
66 |
67 |     # orthogonal regularization
68 | parser.add_argument("--color_switch",type=float,default=0.1, help='weight of color switch loss')
69 |     parser.add_argument("--coord_switch",type=float,default=0.1, help='weight of coordinate switch loss')
70 |
71 |
72 | print("Begin parser arguments.")
73 | args = parser.parse_args()
74 | assert args.videoRoot is not None
75 | assert args.videoList is not None
76 | if not os.path.exists(args.savedir):
77 | os.mkdir(args.savedir)
78 | args.savepatch = os.path.join(args.savedir,'savepatch')
79 | args.logfile = open(os.path.join(args.savedir,"logargs.txt"),"w")
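    # Convention: passing --device equal to the number of visible GPUs enables
    # DataParallel over all of them; any smaller value selects that single GPU.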
80 | args.multiGPU = args.device == torch.cuda.device_count()
81 |
82 | if not args.multiGPU:
83 | torch.cuda.set_device(args.device)
84 | if not os.path.exists(args.savepatch):
85 | os.mkdir(args.savepatch)
86 |
87 | args.vis = True
88 | if args.color_switch > 0:
89 | args.color_switch_flag = True
90 | else:
91 | args.color_switch_flag = False
92 | if args.coord_switch > 0:
93 | args.coord_switch_flag = True
94 | else:
95 | args.coord_switch_flag = False
96 |
97 | try:
98 | from tensorboardX import SummaryWriter
99 | global writer
100 | writer = SummaryWriter()
101 | except ImportError:
102 | args.vis = False
103 | print(' '.join(sys.argv))
104 | print('\n')
105 | args.logfile.write(' '.join(sys.argv))
106 | args.logfile.write('\n')
107 |
108 | for k, v in args.__dict__.items():
109 | print(k, ':', v)
110 | args.logfile.write('{}:{}\n'.format(k,v))
111 | args.logfile.close()
112 | return args
113 |
114 |
115 | def adjust_learning_rate(args, optimizer, epoch):
116 |     """Adjust the learning rate according to args.lr_mode: 'step' decays the initial LR by 10 every args.step epochs, 'poly' applies polynomial (power 0.9) decay over args.nepoch epochs."""
117 | if args.lr_mode == 'step':
118 | lr = args.lr * (0.1 ** (epoch // args.step))
119 | elif args.lr_mode == 'poly':
120 | lr = args.lr * (1 - epoch / args.nepoch) ** 0.9
121 | else:
122 | raise ValueError('Unknown lr mode {}'.format(args.lr_mode))
123 |
124 | for param_group in optimizer.param_groups:
125 | param_group['lr'] = lr
126 | return lr
127 |
128 |
129 | def create_loader(args):
130 | dataset_train_warm = VidListv1(args.videoRoot, args.videoList, args.patch_size, args.rotate, args.scale)
131 | dataset_train = VidListv2(args.videoRoot, args.videoList, args.patch_size, args.window_len, args.rotate, args.scale, args.full_size)
132 |
133 | if args.multiGPU:
134 | train_loader_warm = torch.utils.data.DataLoader(
135 | dataset_train_warm, batch_size=args.batchsize, shuffle = True, num_workers=args.workers, pin_memory=True, drop_last=True)
136 | train_loader = torch.utils.data.DataLoader(
137 | dataset_train, batch_size=args.batchsize, shuffle = True, num_workers=args.workers, pin_memory=True, drop_last=True)
138 | else:
139 | train_loader_warm = torch.utils.data.DataLoader(
140 | dataset_train_warm, batch_size=args.batchsize, shuffle = True, num_workers=0, drop_last=True)
141 | train_loader = torch.utils.data.DataLoader(
142 | dataset_train, batch_size=args.batchsize, shuffle = True, num_workers=0, drop_last=True)
143 | return train_loader_warm, train_loader
144 |
145 |
146 | def train(args):
147 | loader_warm, loader = create_loader(args)
148 | cudnn.benchmark = True
149 | best_loss = 1e10
150 | start_epoch = 0
151 |
152 | model = Model(args.pretrainRes, args.encoder_dir, args.decoder_dir, temp = args.temp, Resnet = args.Resnet, color_switch = args.color_switch_flag, coord_switch = args.coord_switch_flag)
153 |
154 | if args.multiGPU:
155 | model = torch.nn.DataParallel(model).cuda()
156 | closs = ConcentrationLoss(win_len=args.lc_win, stride=args.lc_win,
157 | F_size=torch.Size((args.batchsize//torch.cuda.device_count(),2, args.patch_size//8, args.patch_size//8)), temp = args.temp)
158 | closs = nn.DataParallel(closs).cuda()
159 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model._modules['module'].parameters()),args.lr)
160 | else:
161 | closs = ConcentrationLoss(win_len=args.lc_win, stride=args.lc_win,
162 | F_size=torch.Size((args.batchsize,2,
163 | args.patch_size//8,
164 | args.patch_size//8)), temp = args.temp)
165 | model.cuda()
166 | closs.cuda()
167 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),args.lr)
168 |
169 | if args.resume:
170 | if os.path.isfile(args.resume):
171 | print("=> loading checkpoint '{}'".format(args.resume))
172 | checkpoint = torch.load(args.resume)
173 | start_epoch = checkpoint['epoch']
174 | best_loss = checkpoint['best_loss']
175 | model.load_state_dict(checkpoint['state_dict'])
176 | print("=> loaded checkpoint '{} ({})' (epoch {})"
177 | .format(args.resume, best_loss, checkpoint['epoch']))
178 | else:
179 | print("=> no checkpoint found at '{}'".format(args.resume))
180 |
181 | for epoch in range(start_epoch, args.nepoch):
182 | if epoch < args.wepoch:
183 | lr = adjust_learning_rate(args, optimizer, epoch)
184 | print("Base lr for epoch {}: {}.".format(epoch, optimizer.param_groups[0]['lr']))
185 | best_loss = train_iter(args, loader_warm, model, closs, optimizer, epoch, best_loss)
186 | else:
187 | lr = adjust_learning_rate(args, optimizer, epoch-args.wepoch)
188 | print("Base lr for epoch {}: {}.".format(epoch, optimizer.param_groups[0]['lr']))
189 | best_loss = train_iter(args, loader, model, closs, optimizer, epoch, best_loss)
190 |
191 |
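# forward() during training: in warm-up mode the model only matches the two
# random crops; otherwise it additionally predicts a crop location new_c in
# frame2, and the ground-truth patch cropped there (via diff_crop) is appended
# to the outputs for the color reconstruction loss.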
192 | def forward(frame1, frame2, model, warm_up, patch_size=None):
193 | n, c, h, w = frame1.size()
194 | if warm_up:
195 | output = model(frame1, frame2)
196 | else:
197 | output = model(frame1, frame2, warm_up=False, patch_size=[patch_size//8, patch_size//8])
198 | new_c = output[2]
199 |         # ground-truth patch: crop frame2 at the predicted location new_c; the color loss compares the estimated patch against this crop
200 | # print("HERE2: ", frame2.size(), new_c, patch_size)
201 | color2_gt = diff_crop(frame2, new_c[:,0], new_c[:,2], new_c[:,1], new_c[:,3],
202 | patch_size, patch_size)
203 | output.append(color2_gt)
204 | return output
205 |
206 |
207 | def train_iter(args, loader, model, closs, optimizer, epoch, best_loss):
208 |     # meters for timing and running loss averages
209 |     batch_time = AverageMeter()
210 |     losses = AverageMeter()
211 |     c_losses = AverageMeter()
212 | model.train()
213 | end = time.time()
214 | if args.coord_switch_flag:
215 | coord_switch_loss = nn.L1Loss()
216 | sc_losses = AverageMeter()
217 |
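    # The pixel threshold handed to L1_loss is disabled (None) for the first
    # warm-up epoch and for the first two epochs after warm-up; otherwise it is
    # fixed at 2.5.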
218 | if epoch < 1 or (epoch>=args.wepoch and epoch< args.wepoch+2):
219 | thr = None
220 | else:
221 | thr = 2.5
222 |
223 | for i,frames in enumerate(loader):
224 | frame1_var = frames[0].cuda()
225 | frame2_var = frames[1].cuda()
226 |
227 | if epoch < args.wepoch:
228 | output = forward(frame1_var, frame2_var, model, warm_up=True)
229 | color2_est = output[0]
230 | aff = output[1]
231 | b,x,_ = aff.size()
232 | color1_est = None
233 | if args.color_switch_flag:
234 | color1_est = output[2]
235 | loss_ = L1_loss(color2_est, frame2_var, 10, 10, thr=thr, pred1=color1_est, frame1_var = frame1_var)
236 |
237 | if epoch >=1 and args.lc > 0:
238 | constraint_loss = torch.sum(closs(aff.view(b,1,x,x))) * args.lc
239 | c_losses.update(constraint_loss.item(), frame1_var.size(0))
240 | loss = loss_ + constraint_loss
241 | else:
242 | loss = loss_
243 | if(i % args.log_interval == 0):
244 | save_vis(color2_est, frame2_var, frame1_var, frame2_var, args.savepatch)
245 | else:
246 | # print("input: ", frame1_var.size(), frame2_var.size())
247 | output = forward(frame1_var, frame2_var, model, warm_up=False, patch_size = args.patch_size)
248 | color2_est = output[0]
249 | aff = output[1]
250 | new_c = output[2]
251 | coords = output[3]
252 | Fcolor2_crop = output[-1]
253 |
254 | b,x,x = aff.size()
255 | color1_est = None
256 | count = 3
257 |
258 | constraint_loss = torch.sum(closs(aff.view(b,1,x,x))) * args.lc
259 | c_losses.update(constraint_loss.item(), frame1_var.size(0))
260 |
261 | if args.color_switch_flag:
262 | count += 1
263 | color1_est = output[count]
264 |
265 | loss_color = L1_loss(color2_est, Fcolor2_crop, 10, 10, thr=thr, pred1=color1_est, frame1_var = frame1_var)
266 | loss_ = loss_color + constraint_loss
267 |
268 | if args.coord_switch_flag:
269 | count += 1
270 | grids = output[count]
271 | C11 = output[count+1]
272 | loss_coord = args.coord_switch * coord_switch_loss(C11, grids)
273 | loss = loss_ + loss_coord
274 | sc_losses.update(loss_coord.item(), frame1_var.size(0))
275 | else:
276 | loss = loss_
277 |
278 | if(i % args.log_interval == 0):
279 | save_vis(color2_est, Fcolor2_crop, frame1_var, frame2_var, args.savepatch, new_c)
280 |
281 | losses.update(loss.item(), frame1_var.size(0))
282 | optimizer.zero_grad()
283 | loss.backward()
284 | optimizer.step()
285 | batch_time.update(time.time() - end)
286 | end = time.time()
287 |
288 | if epoch >= args.wepoch and args.coord_switch_flag:
289 | logger.info('Epoch: [{0}][{1}/{2}]\t'
290 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
291 | 'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t '
292 | 'Coord switch Loss {scloss.val:.4f} ({scloss.avg:.4f})\t '
293 | 'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format(
294 | epoch, i+1, len(loader), batch_time=batch_time, loss=losses, scloss=sc_losses, c_loss= c_losses))
295 | else:
296 | logger.info('Epoch: [{0}][{1}/{2}]\t'
297 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
298 | 'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t '
299 | 'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format(
300 | epoch, i+1, len(loader), batch_time=batch_time, loss=losses, c_loss= c_losses))
301 |
302 | if((i + 1) % args.save_interval == 0):
303 | is_best = losses.avg < best_loss
304 | best_loss = min(losses.avg, best_loss)
305 | checkpoint_path = os.path.join(args.savedir, 'checkpoint_latest.pth.tar')
306 | save_checkpoint({
307 | 'epoch': epoch + 1,
308 | 'state_dict': model.state_dict(),
309 | 'best_loss': best_loss,
310 | }, is_best, filename=checkpoint_path, savedir = args.savedir)
311 | log_current(epoch, losses.avg, best_loss, filename = "log_current.txt", savedir=args.savedir)
312 |
313 | return best_loss
314 |
315 |
316 | if __name__ == '__main__':
317 | args = parse_args()
318 | train(args)
319 |     if args.vis: writer.close()  # writer exists only when tensorboardX was imported successfully
--------------------------------------------------------------------------------
/weights/checkpoint_latest.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/checkpoint_latest.pth.tar
--------------------------------------------------------------------------------
/weights/decoder_single_gpu.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/decoder_single_gpu.pth
--------------------------------------------------------------------------------
/weights/encoder_single_gpu.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaolonw/UVC-1/fafb36f1577080e8ecfa09d97dc2c024b04ccdb2/weights/encoder_single_gpu.pth
--------------------------------------------------------------------------------