├── README.md
├── dsbn_init.py
├── model.py
├── sesr.py
├── sesr_common.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # CDCL
2 | Cross-Dataset Collaborative Learning for Semantic Segmentation in Autonomous Driving
3 | 
4 | We are preparing the project and will release the source code later.
5 | 
--------------------------------------------------------------------------------
/dsbn_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # Expand an ImageNet-pretrained ResNet state_dict for Domain-Specific BatchNorm (DSBN):
5 | # each BN affine parameter and running statistic is cloned into per-domain copies
6 | # stored under 'bns.{domain}.*'.
7 | def _update_initial_weights_dsbn(state_dict, num_classes=1000, num_domains=2):
8 |     new_state_dict = state_dict.copy()
9 | 
10 |     for key, val in state_dict.items():
11 |         if 'weight' in key:
12 |             for d in range(num_domains):
13 |                 new_state_dict[key[0:-6] + 'bns.{}.weight'.format(d)] = val.data.clone()
14 |         elif 'bias' in key:
15 |             for d in range(num_domains):
16 |                 new_state_dict[key[0:-4] + 'bns.{}.bias'.format(d)] = val.data.clone()
17 | 
18 |         if 'running_mean' in key:
19 |             for d in range(num_domains):
20 |                 new_state_dict[key[0:-12] + 'bns.{}.running_mean'.format(d)] = val.data.clone()
21 | 
22 |         if 'running_var' in key:
23 |             for d in range(num_domains):
24 |                 new_state_dict[key[0:-11] + 'bns.{}.running_var'.format(d)] = val.data.clone()
25 | 
26 |         if 'num_batches_tracked' in key:
27 |             for d in range(num_domains):
28 |                 new_state_dict[
29 |                     key[0:-len('num_batches_tracked')] + 'bns.{}.num_batches_tracked'.format(d)] = val.data.clone()
30 | 
31 |     # Drop the classifier head when the target task has a different class count.
32 |     if num_classes != 1000 or len([key for key in new_state_dict.keys() if 'fc' in key]) > 1:
33 |         key_list = list(new_state_dict.keys())
34 |         for key in key_list:
35 |             if 'fc' in key:
36 |                 print('pretrained {} are not used as initial params.'.format(key))
37 |                 del new_state_dict[key]
38 | 
39 |     return new_state_dict
40 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .utils import icnr_init
6 | from .sesr import SESR
7 | 
8 | class ConvBlock(nn.Sequential):
9 |     def __init__(self, input_channels, output_channels, kernel_size=3, with_bn=False):
10 |         layers = [nn.Conv2d(input_channels, output_channels, kernel_size, padding=kernel_size//2, bias=(not with_bn))]
11 |         if with_bn: layers.append(nn.BatchNorm2d(output_channels))
12 |         layers.append(nn.ReLU())
13 | 
14 |         nn.init.kaiming_uniform_(layers[0].weight)
15 |         if not with_bn:
16 |             nn.init.normal_(layers[0].bias, mean=0.0, std=0.01)
17 | 
18 |         super().__init__(*layers)
19 | 
20 | class ConvNextBlock(nn.Module):
21 |     def __init__(self, input_channels, output_channels, kernel_size=3):
22 |         super().__init__()
23 |         assert input_channels==output_channels, "For ConvNextBlock input_channels has to be equal to output_channels."
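        # A ConvNeXt-style pair of residual stages: `dwpw` is a depthwise (spatial) conv
        # -> BN -> pointwise 1x1 (channel) conv -> GELU, and `dwpw_ib` repeats the idea
        # with a grouped inverted bottleneck that expands to 2x channels before
        # projecting back. forward() adds each stage back onto its input.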
24 | self.dwpw = nn.Sequential( 25 | nn.Conv2d(input_channels, output_channels, kernel_size, padding=kernel_size//2, groups=output_channels, bias=False), 26 | nn.BatchNorm2d(output_channels), 27 | nn.Conv2d(output_channels, output_channels, kernel_size=1, bias=True), 28 | nn.GELU() 29 | ) 30 | self.dwpw_ib = nn.Sequential( 31 | nn.Conv2d(output_channels, output_channels, kernel_size, padding=kernel_size//2, groups=output_channels, bias=False), 32 | nn.BatchNorm2d(output_channels), 33 | nn.Conv2d(output_channels, 2*output_channels, kernel_size=1, groups=2, bias=True), 34 | nn.GELU(), 35 | nn.Conv2d(2*output_channels, output_channels, kernel_size=1, groups=2, bias=True), 36 | ) 37 | 38 | def forward(self, x): 39 | mid = x + self.dwpw(x) 40 | out = mid + self.dwpw_ib(mid) 41 | return out 42 | 43 | class UNetTiny(nn.Module): 44 | def __init__(self, input_channels): 45 | super().__init__() 46 | self.encoder1 = nn.Sequential( 47 | ConvBlock(input_channels, 16), 48 | ConvBlock(16, 16), 49 | ) 50 | self.encoder2 = nn.Sequential( 51 | ConvBlock(16, 32), 52 | ConvBlock(32, 32), 53 | ) 54 | self.encoder3 = nn.Sequential( 55 | ConvBlock(32, 48), 56 | ConvBlock(48, 48), 57 | ) 58 | self.bottleneck = nn.Sequential( 59 | ConvBlock(48, 64), 60 | ConvBlock(64, 64), 61 | ) 62 | self.decoder3 = nn.Sequential( 63 | ConvBlock(64+48, 48, kernel_size=1), 64 | ConvBlock(48, 48) 65 | ) 66 | self.decoder2 = nn.Sequential( 67 | ConvBlock(48+32, 32, kernel_size=1), 68 | ConvBlock(32, 32) 69 | ) 70 | self.decoder1 = nn.Sequential( 71 | ConvBlock(32+16, 32, kernel_size=1), 72 | ConvBlock(32, 32) 73 | ) 74 | 75 | def forward(self, x): 76 | skip1 = self.encoder1(x) 77 | skip2 = self.encoder2(F.max_pool2d(skip1, kernel_size=2)) 78 | skip3 = self.encoder3(F.max_pool2d(skip2, kernel_size=2)) 79 | up3 = self.bottleneck(F.max_pool2d(skip3, kernel_size=2)) 80 | up2 = self.decoder3(torch.cat([F.interpolate(up3, scale_factor=2, mode='nearest'), skip3], dim=1)) 81 | up1 = self.decoder2(torch.cat([F.interpolate(up2, scale_factor=2, mode='nearest'), skip2], dim=1)) 82 | out = self.decoder1(torch.cat([F.interpolate(up1, scale_factor=2, mode='nearest'), skip1], dim=1)) 83 | return out 84 | 85 | class UNetBaseline(nn.Module): # ~25k MACs/pixel 86 | def __init__(self, input_channels): 87 | super().__init__() 88 | self.input_conv = nn.Conv2d(input_channels, 16, kernel_size=1) 89 | self.encoder1 = nn.Sequential( 90 | ConvBlock(16, 16), 91 | ConvBlock(16, 16), 92 | ) 93 | self.encoder2 = nn.Sequential( 94 | ConvBlock(16, 32), 95 | ConvBlock(32, 32), 96 | ConvBlock(32, 32), 97 | ) 98 | self.bottleneck = nn.Sequential( 99 | ConvBlock(32, 48), 100 | ConvBlock(48, 48), 101 | ConvBlock(48, 48), 102 | ) 103 | self.decoder2 = nn.Sequential( 104 | ConvBlock(48+32, 32, kernel_size=1), 105 | ConvBlock(32, 32), 106 | ConvBlock(32, 32) 107 | ) 108 | self.decoder1 = nn.Sequential( 109 | ConvBlock(32+16, 16, kernel_size=1), 110 | ConvBlock(16, 16), 111 | ConvBlock(16, 16) 112 | ) 113 | 114 | def forward(self, x): 115 | x = self.input_conv(x) 116 | skip1 = self.encoder1(x) 117 | skip2 = self.encoder2(F.max_pool2d(skip1, kernel_size=2)) 118 | up2 = self.bottleneck(F.max_pool2d(skip2, kernel_size=2)) 119 | up1 = self.decoder2(torch.cat([F.interpolate(up2, scale_factor=2, mode='nearest'), skip2], dim=1)) 120 | out = self.decoder1(torch.cat([F.interpolate(up1, scale_factor=2, mode='nearest'), skip1], dim=1)) 121 | return out 122 | 123 | class FilterNetwork(nn.Module): 124 | def __init__(self, input_channels, recurrent_channels): 125 | super().__init__() 126 
| self.recurrent_channels = recurrent_channels
127 |         self.conv = nn.Conv2d(input_channels, 10 + recurrent_channels, kernel_size=1)
128 |         nn.init.zeros_(self.conv.bias)
129 | 
130 |     def forward(self, features, sr_color, history_color):
131 |         filters, recurrent = self.conv(features).split([10, self.recurrent_channels], dim=1)
132 | 
133 |         filters = F.softmax(filters, dim=1)
134 |         recurrent = torch.sigmoid(recurrent) # We use sigmoid for the recurrent state
135 | 
136 |         def apply_filter(img, filter):
137 |             N,C,H,W = img.shape
138 |             KS = int(filter.size(1)**0.5)
139 |             assert KS**2 == filter.size(1) and KS%2 == 1, f"Wrong filtering kernel shape: {filter.shape}"
140 | 
141 |             # unfold (N,C,H,W) tensor into (N,C*KS*KS,H*W) patches
142 |             img_unfolded = F.unfold(img, (KS, KS), padding=KS//2)
143 |             img_patches = img_unfolded.reshape(N,C,KS*KS,H,W)
144 | 
145 |             img_filtered = (img_patches * filter.unsqueeze(1)).sum(dim=2)
146 |             return img_filtered
147 | 
148 |         current_filter, history_filter = filters.split([9,1], dim=1)
149 | 
150 |         sr_color_filtered = apply_filter(sr_color, current_filter)
151 |         history_color_filtered = history_color * history_filter
152 | 
153 |         return sr_color_filtered + history_color_filtered, recurrent
154 | 
155 | class ColorPredictionNetwork(nn.Module):
156 |     def __init__(self, input_channels, recurrent_channels):
157 |         super().__init__()
158 |         self.recurrent_channels = recurrent_channels
159 |         self.conv = nn.Conv2d(input_channels, 3 + recurrent_channels, kernel_size=1)
160 |         nn.init.zeros_(self.conv.bias)
161 | 
162 |     def forward(self, features, sr_color, history_color):
163 |         # sr_color and history_color are not used but they're kept for FilterNetwork compatibility
164 |         color, recurrent = self.conv(features).split([3, self.recurrent_channels], dim=1)
165 | 
166 |         color = torch.relu(color)
167 |         recurrent = torch.relu(recurrent) # Unlike FilterNetwork, the recurrent state here is kept non-negative with ReLU
168 | 
169 |         return color, recurrent
170 | 
171 | class AMDNetTiny(nn.Module):
172 |     def __init__(self):
173 |         super().__init__()
174 |         self.feature_network = UNetTiny(16)
175 |         self.prediction_network = FilterNetwork(8, recurrent_channels=1)
176 | 
177 |         # ICNR init because the decoder output is pixel-shuffled in forward()
178 |         self.feature_network.decoder1[1][0].weight.data.copy_(icnr_init(self.feature_network.decoder1[1][0].weight.data))
179 | 
180 |     def forward(self, x, sr_color, prev_color):
181 |         input_unshuffled = F.pixel_unshuffle(x, downscale_factor=2)
182 |         features = self.feature_network(input_unshuffled)
183 |         features_shuffled = F.pixel_shuffle(features, upscale_factor=2)
184 | 
185 |         final_color, recurrent_state = self.prediction_network(features_shuffled, sr_color, prev_color)
186 | 
187 |         return final_color, recurrent_state
188 | 
189 | class KPNBaseline(nn.Module):
190 |     def __init__(self):
191 |         super().__init__()
192 |         self.feature_network = SESR(in_channels=11, return_outs=16)
193 |         self.prediction_network = FilterNetwork(16, recurrent_channels=8)
194 | 
195 |     def forward(self, x, sr_color, prev_color):
196 |         features = self.feature_network(x)
197 | 
198 |         final_color, recurrent_state = self.prediction_network(features, sr_color, prev_color)
199 | 
200 |         return final_color, recurrent_state
201 | 
202 | class CPNBaseline(nn.Module):
203 |     def __init__(self):
204 |         super().__init__()
205 |         self.feature_network = SESR(in_channels=15, return_outs=16) # 15 = 3 color + 12 recurrent channels (see do_batch in train.py)
206 |         self.prediction_network = ColorPredictionNetwork(16, recurrent_channels=12)
207 | 
208 |     def forward(self, x, sr_color, prev_color):
209 |         features = self.feature_network(x)
210 | 
211 |         final_color, recurrent_state = self.prediction_network(features, sr_color, prev_color)
212 | 
213 |         return final_color, recurrent_state
214 | 
--------------------------------------------------------------------------------
/sesr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import numpy as np
4 | import torch.nn as nn
5 | import mlsr_training.sesr_common as common
6 | 
7 | def Conv(in_channels, out_channels, kernel_size, stride, padding, groups=1):
8 |     conv_seq = nn.Sequential()
9 |     conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
10 |                      kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)
11 |     conv_seq.add_module('conv', conv)
12 |     return conv_seq
13 | 
14 | ## Residual Block (RCAB) with structural re-parameterization: at train time the
15 | ## block is the sum of three parallel branches (a 3x3 expand + 1x1 project path,
16 | ## a plain 1x1 path, and an optional identity); switch_to_deploy() folds them
17 | ## into a single 3x3 convolution for inference.
18 | class RCAB(nn.Module):
19 |     def __init__(
20 |         self, conv, n_feat, kernel_size, reduction,
21 |         res_scale=1, deploy=False):
22 | 
23 |         super(RCAB, self).__init__()
24 |         self.in_channels = n_feat
25 |         self.groups = 1
26 |         self.res_scale = res_scale
27 |         self.deploy = deploy
28 |         self.se = nn.Identity()
29 |         r = reduction # expansion ratio of the dense branch (previously hardcoded to 16)
30 | 
31 |         if deploy:
32 |             self.body_reparam = nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=kernel_size, stride=1,
33 |                                           padding=1, dilation=1, groups=1, bias=False)
34 |         else:
35 |             self.body_identity = None
36 | 
37 |             self.body_dense = Conv(n_feat, r*n_feat, kernel_size, 1, 1, 1)
38 |             self.body_dense_1x1 = Conv(r*n_feat, n_feat, 1, 1, 0, 1)
39 |             self.body_1x1 = Conv(n_feat, n_feat, 1, 1, 0, 1)
40 | 
41 |     def forward(self, x):
42 |         if hasattr(self, 'body_reparam'):
43 |             return self.body_reparam(x)
44 |         else:
45 |             if self.body_identity is None:
46 |                 id_out = 0
47 |             else:
48 |                 id_out = self.body_identity(x)
49 |             y = self.body_dense(x)
50 |             y = self.body_dense_1x1(y)
51 |             return y + self.body_1x1(x) + id_out
52 | 
53 |     def get_equivalent_kernel_bias(self):
54 |         # The dense path is sequential (3x3 expand then 1x1 project), so its two
55 |         # kernels are composed first; the 1x1 and identity branches are parallel
56 |         # and can simply be padded and summed. All branches are bias-free.
57 |         kernel_dense = self._compose_dense_kernels()
58 |         kernel1x1, bias1x1 = self._fuse_tensor(self.body_1x1)
59 |         kernelid, biasid = self._fuse_tensor(self.body_identity)
60 |         return kernel_dense + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, None
61 | 
62 |     def _compose_dense_kernels(self):
63 |         # Same trick as ConvGroup.merge_tensor below: convolving the 3x3 kernel
64 |         # with the flipped 1x1 kernel yields the kernel of the composed convolution.
65 |         kernel3x3 = self.body_dense.conv.weight
66 |         kernel1x1 = self.body_dense_1x1.conv.weight
67 |         return torch.conv2d(kernel3x3.permute(1, 0, 2, 3), kernel1x1.flip(-1, -2), padding=0).permute(1, 0, 2, 3)
68 | 
69 |     def _pad_1x1_to_3x3_tensor(self, kernel1x1):
70 |         if kernel1x1 is None:
71 |             return 0
72 |         else:
73 |             return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
74 | 
75 |     def _fuse_tensor(self, branch):
76 |         if branch is None:
77 |             return 0, 0
78 |         if isinstance(branch, nn.Sequential):
79 |             kernel = branch.conv.weight
80 |         else:
81 |             # A bare BatchNorm branch fuses to an identity kernel
82 |             assert isinstance(branch, nn.BatchNorm2d)
83 |             if not hasattr(self, 'id_tensor'):
84 |                 input_dim = self.in_channels // self.groups
85 |                 kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
86 |                 for i in range(self.in_channels):
87 |                     kernel_value[i, i % input_dim, 1, 1] = 1
88 |                 self.id_tensor = torch.from_numpy(kernel_value)
89 |             kernel = self.id_tensor
90 |         return kernel, None
91 | 
92 |     def switch_to_deploy(self):
93 |         if hasattr(self, 'body_reparam'):
94 |             return
95 |         kernel, bias = self.get_equivalent_kernel_bias()
96 |         self.body_reparam = nn.Conv2d(in_channels=self.body_dense.conv.in_channels, out_channels=self.body_dense_1x1.conv.out_channels,
97 |                                       kernel_size=self.body_dense.conv.kernel_size, stride=self.body_dense.conv.stride,
98 |                                       padding=self.body_dense.conv.padding, dilation=self.body_dense.conv.dilation, groups=self.body_dense.conv.groups, bias=False)
99 | 
100 |         self.body_reparam.weight.data = kernel
101 |         for para in self.parameters():
102 |             para.detach_()
103 |         self.__delattr__('body_dense')
104 |         self.__delattr__('body_dense_1x1')
105 |         self.__delattr__('body_1x1')
106 |         if hasattr(self, 'body_identity'):
107 |             self.__delattr__('body_identity')
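# Why kernel composition works (a sketch of the identity the merge code relies on):
# for bias-free convolutions, K_b * (K_a * x) = (K_b . K_a) * x, where `.` denotes
# kernel composition. When K_b is 1x1 the composed kernel keeps K_a's spatial size,
# so the fused block is a single conv with kernel
#   compose(K_dense_1x1, K_dense_3x3) + pad_to_3x3(K_1x1) + I.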
108 | 
109 | ## ConvGroup: an expand conv (kernel_size) followed by a 1x1 projection;
110 | ## merge_tensor() composes the pair into a single kernel for deployment.
111 | class ConvGroup(nn.Module):
112 |     def __init__(
113 |         self, conv, in_feat, mid_feat, out_feat, kernel_size, deploy=False):
114 | 
115 |         super(ConvGroup, self).__init__()
116 |         self.deploy = deploy
117 | 
118 |         if deploy:
119 |             self.body_reparam = nn.Conv2d(in_channels=in_feat, out_channels=out_feat, kernel_size=kernel_size, stride=1,
120 |                                           padding=(kernel_size - 1) // 2, dilation=1, groups=1, bias=False)
121 |         else:
122 |             self.body_dense = Conv(in_feat, mid_feat, kernel_size, 1, (kernel_size - 1) // 2, 1)
123 |             self.body_1x1 = Conv(mid_feat, out_feat, 1, 1, 0, 1)
124 | 
125 |     def forward(self, x):
126 |         if hasattr(self, 'body_reparam'):
127 |             return self.body_reparam(x)
128 |         else:
129 |             x = self.body_dense(x)
130 |             x = self.body_1x1(x)
131 |             return x
132 | 
133 |     def merge_tensor(self):
134 |         kernel5x5 = self.body_dense.conv.weight
135 |         kernel1x1 = self.body_1x1.conv.weight
136 |         return torch.conv2d(kernel5x5.permute(1, 0, 2, 3), kernel1x1.flip(-1, -2), padding=0).permute(1, 0, 2, 3)
137 | 
138 |     def switch_to_deploy(self):
139 |         if hasattr(self, 'body_reparam'):
140 |             return
141 |         kernel = self.merge_tensor()
142 |         self.body_reparam = nn.Conv2d(in_channels=self.body_dense.conv.in_channels, out_channels=self.body_1x1.conv.out_channels,
143 |                                       kernel_size=self.body_dense.conv.kernel_size, stride=self.body_dense.conv.stride,
144 |                                       padding=self.body_dense.conv.padding, dilation=self.body_dense.conv.dilation, groups=self.body_dense.conv.groups, bias=False)
145 | 
146 |         self.body_reparam.weight.data = kernel
147 |         for para in self.parameters():
148 |             para.detach_()
149 |         self.__delattr__('body_dense')
150 |         self.__delattr__('body_1x1')
151 | 
152 | class ResidualGroup(nn.Module):
153 |     def __init__(self, conv, n_feat, kernel_size, reduction, res_scale, n_resblocks, deploy):
154 |         super(ResidualGroup, self).__init__()
155 |         self.deploy = deploy
156 |         modules_body = [
157 |             RCAB(
158 |                 conv, n_feat, kernel_size, reduction, res_scale=1, deploy=self.deploy) \
159 |             for _ in range(n_resblocks)]
160 |         self.body = nn.Sequential(*modules_body)
161 |         self.act = nn.ReLU()
162 | 
163 |     def forward(self, x):
164 |         res = self.body(x)
165 |         res += x
166 |         return self.act(res)
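# SESR overall: a head ConvGroup (5x5) -> n_resgroups re-parameterizable residual
# groups -> a tail ConvGroup (5x5). Every block above collapses to a plain conv
# via switch_to_deploy(), so the deployed network is a short stack of single convs.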
167 | 
168 | class SESR(nn.Module):
169 |     def __init__(self, in_channels:int, return_outs:int, gen_cfg=None, scale:int=2, conv=common.default_conv):
170 |         super(SESR, self).__init__()
171 | 
172 |         n_resgroups = 7 # a larger variant used 11 groups
173 |         n_resblocks = 1
174 |         n_feats = 16
175 |         kernel_size = 3
176 |         reduction = 16
177 |         scale = 1 # NOTE: overrides the `scale` argument; upscaling happens outside this network
178 |         act = nn.PReLU()
179 |         deploy = False
180 |         res_scale = 1
181 |         self.deploy = deploy
182 | 
183 |         # define head module
184 |         self.head = ConvGroup(conv, in_feat=in_channels, mid_feat=256, out_feat=n_feats, kernel_size=5, deploy=deploy)
185 | 
186 |         # define body module
187 |         modules_body = [
188 |             ResidualGroup(
189 |                 conv, n_feats, kernel_size, reduction, res_scale=res_scale, n_resblocks=n_resblocks, deploy=deploy) \
190 |             for _ in range(n_resgroups)]
191 | 
192 |         self.body = nn.Sequential(*modules_body)
193 | 
194 |         # define tail module
195 |         self.tail = ConvGroup(conv, in_feat=n_feats, mid_feat=256, out_feat=scale*scale*return_outs, kernel_size=5, deploy=deploy)
196 | 
197 |     def forward(self, x):
198 |         x = self.head(x)
199 |         res = self.body(x)
200 |         res = res + x
201 |         res = self.tail(res)
202 |         return res
203 | 
204 |     # reset()/detach() are vestigial hooks for recurrent convolutions; the modules
205 |     # used here are stateless, so the calls are guarded with hasattr().
206 |     def reset(self):
207 |         for m in [self.head, *self.body, self.tail]:
208 |             if hasattr(m, 'reset'):
209 |                 m.reset()
210 | 
211 |     def detach(self):
212 |         for m in [self.head, *self.body, self.tail]:
213 |             if hasattr(m, 'detach'):
214 |                 m.detach()
215 | 
216 |     def load_state_dict(self, state_dict, strict=False):
217 |         own_state = self.state_dict()
218 |         for name, param in state_dict.items():
219 |             if name in own_state:
220 |                 if isinstance(param, nn.Parameter):
221 |                     param = param.data
222 |                 try:
223 |                     own_state[name].copy_(param)
224 |                 except Exception:
225 |                     if name.find('tail') >= 0:
226 |                         print('Replace pre-trained upsampler to new one...')
227 |                     else:
228 |                         raise RuntimeError('While copying the parameter named {}, '
229 |                                            'whose dimensions in the model are {} and '
230 |                                            'whose dimensions in the checkpoint are {}.'
231 |                                            .format(name, own_state[name].size(), param.size()))
232 |             elif strict:
233 |                 if name.find('tail') == -1:
234 |                     raise KeyError('unexpected key "{}" in state_dict'
235 |                                    .format(name))
236 | 
237 |         if strict:
238 |             missing = set(own_state.keys()) - set(state_dict.keys())
239 |             if len(missing) > 0:
240 |                 raise KeyError('missing keys in state_dict: "{}"'.format(missing))
241 | 
242 | 
243 | def model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
244 |     if do_copy:
245 |         model = copy.deepcopy(model)
246 |     for module in model.modules():
247 |         if hasattr(module, 'switch_to_deploy'):
248 |             module.switch_to_deploy()
249 |     if save_path is not None:
250 |         torch.save(model.state_dict(), save_path)
251 |         print('Saved converted model to:', save_path)
252 |     return model
253 | 
--------------------------------------------------------------------------------
/sesr_common.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
9 |     return nn.Conv2d(
10 |         in_channels, out_channels, kernel_size,
11 |         padding=(kernel_size//2), bias=bias)
12 | 
13 | 
14 | 
15 | class MeanShift(nn.Conv2d):
16 |     def __init__(
17 |         self, rgb_range,
18 |         rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
19 | 
20 |         super(MeanShift, self).__init__(3, 3, kernel_size=1)
21 |         std = torch.Tensor(rgb_std)
22 |         self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
23 |         self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
24 |         for p in self.parameters():
25 |             p.requires_grad = False
26 | 
27 | 
28 | class MeanShiftConv(nn.Module):
29 |     def __init__(
30 |         self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
31 | 
32 |         super(MeanShiftConv, self).__init__()
33 |         self.mean_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1,
34 |                                    padding=0, dilation=1, groups=1, bias=True)
35 | 
36 |         std = torch.Tensor(rgb_std)
37 |         self.mean_conv.weight.data = torch.eye(3).view(3,
1, 1, 1) 38 | self.mean_conv.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 39 | 40 | for p in self.parameters(): 41 | p.requires_grad = False 42 | 43 | def forward(self, x): 44 | return self.mean_conv(x) 45 | 46 | 47 | class BasicBlock(nn.Sequential): 48 | def __init__( 49 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 50 | bn=True, act=nn.ReLU(True)): 51 | 52 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 53 | if bn: 54 | m.append(nn.BatchNorm2d(out_channels)) 55 | if act is not None: 56 | m.append(act) 57 | 58 | super(BasicBlock, self).__init__(*m) 59 | 60 | 61 | 62 | class ResBlock(nn.Module): 63 | def __init__( 64 | self, conv, n_feats, kernel_size, 65 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 66 | 67 | super(ResBlock, self).__init__() 68 | m = [] 69 | for i in range(2): 70 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 71 | if bn: 72 | m.append(nn.BatchNorm2d(n_feats)) 73 | if i == 0: 74 | m.append(act) 75 | 76 | self.body = nn.Sequential(*m) 77 | self.res_scale = res_scale 78 | 79 | def forward(self, x): 80 | res = self.body(x).mul(self.res_scale) 81 | res += x 82 | 83 | return res 84 | 85 | class Upsampler(nn.Sequential): 86 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 87 | 88 | m = [] 89 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 90 | for _ in range(int(math.log(scale, 2))): 91 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 92 | m.append(nn.PixelShuffle(2)) 93 | if bn: 94 | m.append(nn.BatchNorm2d(n_feats)) 95 | if act == 'relu': 96 | m.append(nn.ReLU(True)) 97 | elif act == 'prelu': 98 | m.append(nn.PReLU(n_feats)) 99 | 100 | elif scale == 3: 101 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 102 | m.append(nn.PixelShuffle(3)) 103 | if bn: 104 | m.append(nn.BatchNorm2d(n_feats)) 105 | if act == 'relu': 106 | m.append(nn.ReLU(True)) 107 | elif act == 'prelu': 108 | m.append(nn.PReLU(n_feats)) 109 | else: 110 | raise NotImplementedError 111 | 112 | super(Upsampler, self).__init__(*m) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | if torch.cuda.is_available(): 3 | torch.backends.cudnn.benchmark = True 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.utils.tensorboard import SummaryWriter 8 | import torch.profiler 9 | from contextlib import ExitStack 10 | import logging 11 | 12 | from tqdm import tqdm 13 | import numpy as np 14 | import random 15 | import argparse 16 | import os, gc, time 17 | from datetime import timedelta 18 | import piq 19 | 20 | from mlsr_training.data import SequenceDataset, ExampleDataset 21 | from mlsr_training.model import AMDNetTiny, KPNBaseline, CPNBaseline 22 | from mlsr_training.utils import REPROJECT, rgb_to_luminance, convert_color_space, match_sub_arg_params, parse_sub_arg_params 23 | from mlsr_training.loss import instantiate_losses, instantiate_metrics 24 | import mlsr_training.cov 25 | 26 | import optuna, pathlib 27 | from typing import Any, Dict, List 28 | 29 | 30 | reprojection_obj_bicubic = None # Lazy init with w/h 31 | reprojection_obj_bilinear = None # Lazy init with w/h 32 | 33 | def ticks_to_human_time(ticks:int, digits:int=3): 34 | seconds = ticks # ticks are seconds... 
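    # digits=3 builds the format string "{}.{:03.0f}", i.e. H:MM:SS.mmm:
    # divmod() splits the rounded tick count into whole seconds and a
    # zero-padded fractional part.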
35 |     i_sec, f_sec = divmod(round(seconds*10**digits), 10**digits)
36 |     return ("{}.{:0%d.0f}" % digits).format(timedelta(seconds=i_sec), f_sec)
37 | 
38 | def do_batch(args:argparse.Namespace, model:torch.nn.Module, batch:List, loss_objs:Dict):
39 |     # Extract buffers from batched tensors
40 |     sr_color_b, lr_mv_b, hr_color_b = batch
41 |     n,seq_len,_,h,w = sr_color_b.shape
42 | 
43 |     global reprojection_obj_bicubic
44 |     if reprojection_obj_bicubic is None:
45 |         reprojection_obj_bicubic = REPROJECT(w, h, 'bicubic').to(device=args.device)
46 |     global reprojection_obj_bilinear
47 |     if reprojection_obj_bilinear is None:
48 |         reprojection_obj_bilinear = REPROJECT(w, h, 'bilinear').to(device=args.device)
49 | 
50 |     # Move all data to GPU asynchronously
51 |     sr_color_b = sr_color_b.to(device=args.device, non_blocking=True).clip(args.hdr_min, args.hdr_max).float() # Upscaled color in linear HDR
52 |     lr_mv_b = lr_mv_b.to(device=args.device, non_blocking=True) # LR dilated motion vectors
53 |     hr_color_b = hr_color_b.to(device=args.device, non_blocking=True).clip(args.hdr_min, args.hdr_max).float() # Target color in linear HDR
54 | 
55 |     # Tonemap colors (linear HDR -> model color space)
56 |     sr_color_b = convert_color_space(sr_color_b, "linear", args.model_cspace)
57 | 
58 |     # Upscale batched MV
59 |     hr_mv_b = F.interpolate(lr_mv_b.flatten(0,1), scale_factor=2, mode='nearest').view(n,seq_len,2,h,w)
60 | 
61 |     # Iterate over time in a batch
62 |     predictions = torch.empty_like(hr_color_b)
63 |     for t in range(seq_len):
64 |         sr_color = sr_color_b[:,t]
65 |         hr_mv = hr_mv_b[:,t]
66 | 
67 |         # Initialize internal model states at the beginning of a sequence
68 |         if t == 0:
69 |             reprojected_prev_pred = sr_color
70 |             reprojected_recurrent = torch.zeros(n, model.prediction_network.recurrent_channels, h, w, device=sr_color.device)
71 | 
72 |         # Reproject previous prediction and recurrent channel for t>0
73 |         else:
74 |             reprojected_prev_pred = reprojection_obj_bicubic(prev_pred, hr_mv)
75 |             reprojected_prev_pred = reprojected_prev_pred.clip(min=args.ldr_min) # clip because bicubic can overshoot
76 |             # note that the recurrent state should be bilinearly interpolated
77 |             reprojected_recurrent = reprojection_obj_bilinear(recurrent, hr_mv)
78 | 
79 |         # Prepare NN inputs (KPN version)
80 |         # x = torch.cat([
81 |         #     (sr_color - reprojected_prev_pred).abs().amax(dim=1, keepdim=True), # Max absolute difference
82 |         #     rgb_to_luminance(sr_color), # Current frame luma
83 |         #     rgb_to_luminance(reprojected_prev_pred), # History frame luma
84 |         #     reprojected_recurrent # Recurrent channel
85 |         # ], dim=1)
86 | 
87 |         # Prepare NN inputs (CPN version)
88 |         x = torch.cat([
89 |             sr_color,
90 |             reprojected_recurrent
91 |         ], dim=1)
92 | 
93 |         # Graph capture must happen after `x` is built (this block previously ran before `x` was assigned)
94 |         if args.graph_create:
95 |             model = torch.cuda.make_graphed_callables(model, (x, sr_color, reprojected_prev_pred))
96 |             args.graph_create = False
97 | 
98 |         # Model prediction
99 |         pred, new_recurrent = model(x, sr_color, reprojected_prev_pred)
100 | 
101 |         # Store prediction for loss calculation
102 |         predictions[:,t] = pred
103 | 
104 |         # Set internal model states for the next frame
105 |         prev_pred = pred
106 |         recurrent = new_recurrent
107 | 
108 |     # Color conversion
109 |     predictions = convert_color_space(predictions, args.model_cspace, args.loss_cspace)
110 |     hr_color_b = convert_color_space(hr_color_b, "linear", args.loss_cspace)
111 | 
112 |     losses = {}
113 |     for loss_name, loss_obj in loss_objs.items():
114 |         if 'temporal' in loss_name:
115 |             losses[loss_name] = loss_obj(predictions, hr_color_b, hr_mv_b)
116 |         else:
117 |             losses[loss_name] = loss_obj(predictions.flatten(0,1), hr_color_b.flatten(0,1))
118 | 
119 |     return losses, predictions, hr_color_b
120 | 
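# Note: do_batch unrolls the recurrent model over the whole clip and keeps every
# timestep's prediction, so the loss backpropagates through all train_clip_length
# steps (full backprop-through-time within a clip; nothing is detached mid-clip).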
121 | def parse_args() -> argparse.Namespace:
122 |     parser = argparse.ArgumentParser()
123 | 
124 |     # Add `--arg` and `--no-arg` flags with the passed default
125 |     def add_bool(name:str, default_value:bool):
126 |         parser.add_argument(f'--{name}', action='store_true')
127 |         parser.add_argument(f'--no-{name}', dest=name, action='store_false')
128 |         parser.set_defaults(**{name: default_value})
129 | 
130 |     parser.add_argument("--model", type=str, default="CPNBaseline")
131 |     parser.add_argument("--optimizer", type=str, default="RMSprop")
132 |     parser.add_argument('--output_dir', type=str, default='./output')
133 | 
134 |     # Note: we recommend `--train_batch_size 8 --batch_accumulate_count 4` on
135 |     # local machines (assuming ~20GB). On servers (assuming > 60GB RAM)
136 |     # the batch size should be 32.
137 |     parser.add_argument('--train_batch_size', type=int, default=2)
138 |     parser.add_argument('--valid_batch_size', type=int, default=4)
139 |     parser.add_argument('--test_batch_size', type=int, default=2)
140 |     parser.add_argument('--lr', type=float, default=1e-4)
141 |     parser.add_argument('--batch_accumulate_count', type=int, default=16)
142 |     parser.add_argument('--num_workers', type=int, default=12)
143 |     parser.add_argument('--train_clip_length', type=int, default=16)
144 |     parser.add_argument('--epochs', type=int, default=100)
145 | 
146 |     add_bool("train", True)
147 |     parser.add_argument('--train_dataset', type=str, default='/proj/dataset4/dongz/18films')
148 |     parser.add_argument('--train_seq_count', type=int, default=20000)
149 | 
150 | 
151 |     parser.add_argument('--training_losses', default=None, nargs="*")
152 |     parser.add_argument('--validation_metrics', default=None, nargs="*")
153 |     parser.add_argument('--testing_metrics', default=None, nargs="*")
154 | 
155 | 
156 |     add_bool("valid", True)
157 |     #parser.add_argument('--valid_dataset', type=str, default=None)
158 |     parser.add_argument('--valid_dataset', type=str, default='/proj/dataset4/Downloads/validation_sequences/')
159 |     parser.add_argument('--valid_seq_count', type=int, default=1230)
160 | 
161 |     add_bool("test", False)
162 |     parser.add_argument('--test_dataset', type=str, default=None)
163 |     parser.add_argument('--test_seq_count', type=int, default=10*640//32)
164 | 
165 |     # Auto weight losses (useful for testing)
166 |     # Uses: https://openaccess.thecvf.com/content/WACV2021/papers/Groenendijk_Multi-Loss_Weighting_With_Coefficient_of_Variations_WACV_2021_paper.pdf
167 |     add_bool("cov_weight", False)
168 | 
169 |     parser.add_argument('--model_cspace', type=str, default="st2084")
170 |     parser.add_argument('--loss_cspace', type=str, default="st2084")
171 |     parser.add_argument('--metric_cspace', type=str, default="st2084")
172 | 
173 |     parser.add_argument('--device', type=str, default='cuda:0')
174 |     parser.add_argument('--seed', type=int, default=0)
175 |     parser.add_argument('--hdr_min', type=float, default=1e-6)
176 |     parser.add_argument('--hdr_max', type=float, default=10000.0)
177 |     parser.add_argument('--ldr_min', type=float, default=1e-6)
178 |     parser.add_argument('--ldr_max', type=float, default=1.0) # not used
179 | 
180 | 
181 |     optuna_args = parser.add_argument_group("Optuna")
182 |     optuna_args.add_argument('--optuna_trials', type=int, default=0, help="int: number of optuna trials to do")
183 |     optuna_args.add_argument('--optuna_study_name', type=str, default=None, help="str: name of optuna study")
184 | 
optuna_args.add_argument('--optuna_study_path', type=str, default=None, help="str: path to optuna db (sqlite)") 185 | 186 | 187 | optuna_args.add_argument('--criteron_metric', type=str, default=None, help="str: validation metric to use as return value from training loop, and optuna reporting") 188 | 189 | parser.add_argument('--out_path', type=str, default=os.getcwd()) 190 | 191 | # Enable Weights and Biases 192 | optuna_args.add_argument('--wandb', type=str, default=None, help="str: Name of wandb project") 193 | 194 | # Enable Tensorboard 195 | add_bool('tb_enable', True) 196 | add_bool('profiler_enable', False) 197 | add_bool('profiler_performance', True) 198 | add_bool('profiler_memory', False) 199 | add_bool('profiler_stacks', True) 200 | add_bool('profiler_train', True) 201 | add_bool('profiler_validation', False) 202 | 203 | add_bool('graph_create', False) 204 | 205 | args = parser.parse_args() 206 | assert not (args.profiler_train and args.profiler_validation), "Can only profile one side or the other." 207 | 208 | return args 209 | 210 | def main(args:argparse.Namespace) -> float: 211 | if args.out_path != os.getcwd(): 212 | # maybe explicitly manage out paths later if this is an issue. 213 | os.makedirs(args.out_path, exist_ok=True) 214 | os.chdir(args.out_path) 215 | 216 | 217 | # Log arguments for posterity 218 | print("---------------------") 219 | print("Trainer:") 220 | print(args) 221 | print("---------------------") 222 | if args.wandb: 223 | import wandb 224 | wandb.login() 225 | wandb.init(project=args.wandb, config=args) 226 | else: 227 | wandb = None 228 | 229 | 230 | # Set logging config 231 | logging.basicConfig(format="%(levelname)s: %(message)s") 232 | 233 | # Set seeds 234 | torch.manual_seed(args.seed) 235 | random.seed(args.seed) 236 | np.random.seed(args.seed) 237 | 238 | def get_dataset(path:str, seq_count:int, split:str, arg_name:str): 239 | if path is None: 240 | logging.warning(f"Using ExampleDataset for {arg_name}") 241 | return ExampleDataset(num_sequences=seq_count) 242 | else: 243 | if not os.path.exists(path): 244 | raise ValueError(f"--{arg_name} is not a valid path: `{path}`") 245 | return SequenceDataset(path, split=split, num_sequences=seq_count, name=arg_name) 246 | 247 | # Avoid the final incomplete batch but on the granularity of the accumulated batch size. 
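    # e.g. with the default --train_batch_size 2 and --batch_accumulate_count 16,
    # the sequence count is rounded down to a multiple of 32.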
248 |     args.train_seq_count = (args.train_seq_count // (args.train_batch_size * args.batch_accumulate_count)) * (args.train_batch_size * args.batch_accumulate_count)
249 |     train_ds = get_dataset(args.train_dataset, args.train_seq_count, split='train', arg_name="train_dataset")
250 |     valid_ds = get_dataset(args.valid_dataset, args.valid_seq_count, split='valid', arg_name="valid_dataset")
251 |     test_ds = get_dataset(args.test_dataset, args.test_seq_count, split='train', arg_name="test_dataset")
252 | 
253 |     if args.train and train_ds.num_sequences == 0:
254 |         raise Exception(f"Error: train dataset cannot be found at {args.train_dataset}")
255 |     if args.valid and valid_ds.num_sequences == 0:
256 |         raise Exception(f"Error: validation dataset cannot be found at {args.valid_dataset}")
257 |     if args.test and test_ds.num_sequences == 0:
258 |         raise Exception(f"Error: test dataset cannot be found at {args.test_dataset}")
259 | 
260 | 
261 |     # Create train/valid dataloaders
262 |     train_dl = torch.utils.data.DataLoader(train_ds, batch_size=args.train_batch_size, drop_last=False, num_workers=args.num_workers, pin_memory=True)
263 |     valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=args.valid_batch_size, drop_last=False, num_workers=args.num_workers, pin_memory=True)
264 |     test_dl = torch.utils.data.DataLoader(test_ds, batch_size=args.test_batch_size, drop_last=False, num_workers=args.num_workers, pin_memory=True)
265 | 
266 |     # Create a model
267 |     if args.model == "AMDNetTiny":
268 |         model = AMDNetTiny().to(device=args.device)
269 |     elif args.model == "KPNBaseline":
270 |         model = KPNBaseline().to(device=args.device)
271 |     elif args.model == "CPNBaseline":
272 |         model = CPNBaseline().to(device=args.device)
273 |     else:
274 |         raise ValueError(f"passed --model ({args.model}) is not recognized")
275 | 
276 |     # Create an optimizer
277 |     if match_sub_arg_params("RMSprop", args.optimizer):
278 |         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, **parse_sub_arg_params(args.optimizer))
279 |     elif match_sub_arg_params("AdamW", args.optimizer):
280 |         optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, **parse_sub_arg_params(args.optimizer))
281 |     else:
282 |         raise ValueError(f"passed --optimizer ({args.optimizer}) is not recognized")
283 | 
284 |     # Create a scheduler
285 |     # TODO: command line drive options
286 |     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80], gamma=0.1)
287 | 
288 | 
289 |     # Default values.
290 |     if args.training_losses is None:
291 |         args.training_losses = ["l1(weight=0.21)", "ms-ssim(weight=0.5)", "temporal-l1(weight=0.29)"]
292 |     if args.validation_metrics is None:
293 |         args.validation_metrics = ["lpips", "ms-ssim(data_range=1.0)", "temporal-l1"]
294 |     if args.testing_metrics is None:
295 |         args.testing_metrics = ["lpips"]
296 | 
297 |     args.training_loss_objs, args.training_loss_weights = instantiate_losses(args.training_losses, args.device)
298 |     args.validation_metric_objs = instantiate_metrics(args.validation_metrics, args.device)
299 |     args.testing_metric_objs = instantiate_metrics(args.testing_metrics, args.device)
300 | 
301 |     # Training losses also need weights. Val/test metrics do not.
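    # Coefficient-of-Variation weighting (Groenendijk et al., WACV 2021, linked above):
    # each loss is weighted by its relative variability (sigma/mu), so losses whose
    # values have stopped moving are automatically down-weighted.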
302 | if args.cov_weight: 303 | args.training_loss_covs = {} 304 | for k in args.training_loss_objs.keys(): 305 | args.training_loss_covs[k] = mlsr_training.cov.CovModule(cov_samples_max=len(train_dl)).to(device=args.device) 306 | if not os.path.exists(args.output_dir): 307 | os.makedirs(args.output_dir) 308 | tb_summary_writer = SummaryWriter(args.output_dir) if args.tb_enable else None 309 | with ExitStack() as context_stack: 310 | profiler = None 311 | if args.profiler_enable: 312 | profiler = torch.profiler.profile( 313 | activities=[ 314 | torch.profiler.ProfilerActivity.CPU, 315 | torch.profiler.ProfilerActivity.CUDA, 316 | ], 317 | schedule=torch.profiler.schedule( 318 | wait=20, 319 | warmup=2, 320 | active=5, 321 | repeat=1), 322 | on_trace_ready=torch.profiler.tensorboard_trace_handler(tb_summary_writer.log_dir), 323 | with_stack=args.profiler_stacks, 324 | record_shapes=False, 325 | with_flops=False, 326 | profile_memory=args.profiler_memory, 327 | ) 328 | context_stack.push(profiler) 329 | 330 | 331 | for epoch in range(args.epochs): 332 | # Try to prevent too much memory growth by clearing memory each epoch. 333 | gc.collect() 334 | torch.cuda.empty_cache() 335 | 336 | torch.manual_seed(args.seed + epoch) 337 | random.seed(args.seed + epoch) 338 | np.random.seed(args.seed + epoch) 339 | 340 | metrics = {} 341 | training_loss_dict = {"combined":0.0} 342 | validation_metric_dict = {} 343 | if args.criteron_metric is not None: 344 | validation_metric_dict[args.criteron_metric] = 0.0 345 | train_dl_epoch_ticks = train_proc_epoch_ticks = 0 346 | valid_dl_epoch_ticks = valid_proc_epoch_ticks = 0 347 | # Training 348 | if(args.train): 349 | train_dl.dataset.shuffle(epoch) 350 | model.train() 351 | 352 | # These need to match the order they are placed in the tensor 353 | loss_names = [] 354 | for k in args.training_loss_objs.keys(): 355 | loss_names.append(k) 356 | loss_names.append(f"weight-{k}") 357 | loss_names.append("combined") 358 | 359 | loss_values = torch.empty((len(train_dl),len(loss_names)), device=args.device, requires_grad=False) 360 | optimizer.zero_grad(set_to_none=True) 361 | 362 | def train_inner(): 363 | dl_epoch_ticks = proc_epoch_ticks = 0 364 | 365 | dl_start_ticks = time.monotonic() 366 | for batch_index, batch in enumerate(tqdm(train_dl)): 367 | proc_start_ticks = dl_end_ticks = time.monotonic() 368 | dl_epoch_ticks += dl_end_ticks - dl_start_ticks 369 | 370 | with torch.profiler.record_function(f"Train_{epoch}_{batch_index}"): 371 | # Pick a batch seed that will work the same for any batch_accumulate_count 372 | batch_seed = epoch * (len(train_dl) // args.batch_accumulate_count) + (batch_index // args.batch_accumulate_count) 373 | temporal_start = random.Random(batch_seed).randint(0, batch[0].shape[1] - args.train_clip_length) 374 | # truncate the data inputs to the requested clip length 375 | batch_small = [] 376 | for i in range(len(batch)): 377 | batch_small.append( batch[i][:,temporal_start:temporal_start+args.train_clip_length] ) 378 | batch = batch_small 379 | 380 | # Prediction 381 | losses, _, _ = do_batch(args, model, batch, args.training_loss_objs) 382 | 383 | 384 | if args.cov_weight: 385 | assert len(losses) == len(args.training_loss_objs) 386 | cov_list = [] 387 | for k,v in losses.items(): 388 | cov_list.append(args.training_loss_covs[k](v)) 389 | 390 | # Accumulate call updates the weights, which we can log 391 | loss = mlsr_training.cov.CovAccumulate(*cov_list).squeeze() 392 | 393 | # Log the weights: 394 | for k in losses.keys(): 395 | 
args.training_loss_weights[k] = args.training_loss_covs[k].weight_normalized.squeeze() 396 | else: 397 | # Accumulate losses with static weights 398 | loss = torch.zeros((1), device=args.device) 399 | 400 | for k,v in losses.items(): 401 | loss += args.training_loss_weights[k] * v 402 | 403 | # insert losses into GPU tensor so we don't sync w/ cpu each batch. 404 | # put them in the same order as the loss names, this is important for reporting 405 | assert len(loss_values[batch_index]) == len(losses)*2+1 406 | for i,(k,v) in enumerate(losses.items()): 407 | loss_values[batch_index][i*2] = v 408 | loss_values[batch_index][i*2 + 1] = args.training_loss_weights[k] 409 | loss_values[batch_index][-1] = loss.unsqueeze(0) 410 | 411 | loss /= args.batch_accumulate_count 412 | loss.backward() 413 | 414 | # Allow gradient accumulation so we can run a similar batch size on server parts vs local parts, 415 | # without changing other parameters. 416 | if (((batch_index + 1) % args.batch_accumulate_count) == 0) or batch_index + 1 == len(train_dl): 417 | optimizer.step() 418 | optimizer.zero_grad(set_to_none=True) 419 | 420 | if profiler != None and args.profiler_train: 421 | profiler.step() 422 | 423 | proc_end_ticks = time.monotonic() 424 | 425 | proc_epoch_ticks += proc_end_ticks - proc_start_ticks 426 | dl_start_ticks = time.monotonic() 427 | 428 | avg_loss_values = torch.mean(loss_values, dim=0).cpu() 429 | loss_dict = {} 430 | for k,v in zip(loss_names, avg_loss_values): 431 | loss_dict[k] = v.item() 432 | return loss_dict, dl_epoch_ticks, proc_epoch_ticks 433 | 434 | # Run in a function (local scope) to ensure all tensors are freed 435 | training_loss_dict, train_dl_epoch_ticks, train_proc_epoch_ticks = train_inner() 436 | 437 | # Update scheduler 438 | scheduler.step() 439 | 440 | for k,v in training_loss_dict.items(): 441 | metrics[f"training/loss-{k}"] = v 442 | 443 | 444 | # Validation 445 | if(args.valid): 446 | model.eval() 447 | 448 | def val_inner(): 449 | dl_epoch_ticks = proc_epoch_ticks = 0 450 | 451 | with torch.no_grad(): 452 | # These need to match the order they are placed in the tensor 453 | metric_names = [] 454 | for k in args.validation_metric_objs.keys(): 455 | metric_names.append(k) 456 | 457 | metrics_values = torch.empty((len(valid_dl),len(metric_names)), device=args.device, requires_grad=False) 458 | dl_start_ticks = time.monotonic() 459 | for batch_index, batch in enumerate(tqdm(valid_dl)): 460 | proc_start_ticks = dl_end_ticks = time.monotonic() 461 | dl_epoch_ticks += dl_end_ticks - dl_start_ticks 462 | 463 | with torch.profiler.record_function(f"Validation_{epoch}_{batch_index}"): 464 | # Prediction 465 | metrics, _,_ = do_batch(args, model, batch, args.validation_metric_objs) 466 | 467 | # insert metrics into GPU tensor so we don't sync w/ cpu each batch. 
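                        # (metrics stay on the GPU; a single .cpu() after the loop replaces a per-batch host sync)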
468 | # put them in the same order as the metric names, this is important for reporting 469 | assert len(metrics_values[batch_index]) == len(metrics) 470 | for i,(k,v) in enumerate(metrics.items()): 471 | metrics_values[batch_index][i] = v 472 | 473 | if profiler != None and args.profiler_validation: 474 | profiler.step() 475 | 476 | proc_end_ticks = time.monotonic() 477 | proc_epoch_ticks += proc_end_ticks - proc_start_ticks 478 | dl_start_ticks = time.monotonic() 479 | 480 | avg_metrics_values = torch.mean(metrics_values, dim=0).cpu() 481 | metric_dict = {} 482 | for k,v in zip(metric_names, avg_metrics_values): 483 | metric_dict[k] = v.item() 484 | return metric_dict, dl_epoch_ticks, proc_epoch_ticks 485 | 486 | # Run in a function (local scope) to ensure all tensors are freed 487 | validation_metric_dict, valid_dl_epoch_ticks, valid_proc_epoch_ticks = val_inner() 488 | 489 | for k,v in validation_metric_dict.items(): 490 | metrics[f"validation/metric-{k}"] = v 491 | 492 | 493 | metrics[f"perf/train-dl"] = train_dl_epoch_ticks 494 | metrics[f"perf/train-proc"] = train_proc_epoch_ticks 495 | metrics[f"perf/validation-dl"] = valid_dl_epoch_ticks 496 | metrics[f"perf/validation-proc"] = valid_proc_epoch_ticks 497 | 498 | if tb_summary_writer is not None: 499 | for k,v in metrics.items(): 500 | tb_summary_writer.add_scalar(k, v, epoch) 501 | 502 | if wandb: 503 | wandb.log(metrics) 504 | 505 | print(f"Epoch: {epoch} Loss:{training_loss_dict['combined']} Valid Loss:{validation_metric_dict}") 506 | print(f"Epoch: {epoch} Train: dl:{ticks_to_human_time(train_dl_epoch_ticks)} proc:{ticks_to_human_time(train_proc_epoch_ticks)} Valid: dl:{ticks_to_human_time(valid_dl_epoch_ticks)} proc:{ticks_to_human_time(valid_proc_epoch_ticks)}") 507 | 508 | 509 | if hasattr(args, 'optuna_trial') and args.optuna_trial: 510 | trial = args.optuna_trial 511 | 512 | if args.criteron_metric is not None: 513 | trial.report(validation_metric_dict[args.criteron_metric], epoch) 514 | 515 | # Handle pruning based on the intermediate value. 516 | if trial.should_prune() or not np.isfinite(training_loss_dict["combined"]): 517 | raise optuna.TrialPruned() 518 | 519 | if not np.isfinite(training_loss_dict["combined"]): 520 | raise Exception("Error Training loss was NaN. Training has failed.") 521 | 522 | 523 | # Test 524 | if(args.test): 525 | model.eval() 526 | def test_inner(): 527 | with torch.no_grad(): 528 | # These need to match the order they are placed in the tensor 529 | metric_names = [] 530 | for k in args.testing_metric_objs.keys(): 531 | metric_names.append(k) 532 | metrics_values = torch.empty((len(test_dl),len(metric_names)), device=args.device, requires_grad=False) 533 | for batch_index, batch in enumerate(tqdm(test_dl)): 534 | # Prediction 535 | metrics, _, _ = do_batch(args, model, batch, args.testing_metric_objs) 536 | 537 | # insert metrics into GPU tensor so we don't sync w/ cpu each batch. 
538 |                     # put them in the same order as the metric names, this is important for reporting
539 |                     assert len(metrics_values[batch_index]) == len(metrics)
540 |                     for i,(k,v) in enumerate(metrics.items()):
541 |                         metrics_values[batch_index][i] = v
542 | 
543 | 
544 |                 avg_metrics_values = torch.mean(metrics_values, dim=0).cpu()
545 |                 metric_dict = {}
546 |                 for k,v in zip(metric_names, avg_metrics_values):
547 |                     metric_dict[k] = v.item()
548 |                 return metric_dict
549 | 
550 |         test_metric_dict = test_inner()
551 | 
552 |         if tb_summary_writer is not None:
553 |             for k,v in test_metric_dict.items():
554 |                 tb_summary_writer.add_scalar(f"test/metric-{k}", v, epoch)
555 | 
556 |         print(f"Test Loss: {test_metric_dict}")
557 | 
558 |     if tb_summary_writer is not None:
559 |         tb_summary_writer.flush()
560 |         tb_summary_writer.close()
561 | 
562 | 
563 |     # Save model after training
564 |     torch.save(model.state_dict(), './model.pt')
565 | 
566 |     if args.profiler_enable: print(profiler.key_averages(group_by_stack_n=10).table(sort_by='self_cpu_time_total', row_limit=10))
567 |     if args.profiler_enable: print(profiler.key_averages(group_by_stack_n=10).table(sort_by='self_cuda_time_total', row_limit=10))
568 | 
569 |     if args.criteron_metric is not None:
570 |         return validation_metric_dict[args.criteron_metric]
571 |     else:
572 |         return None
573 | 
574 | 
575 | def create_study(args: argparse.Namespace, base_trials:List[Dict[str, Any]] = None):
576 |     # if running multi-process there can be some competition to see who creates the optuna db.
577 |     # only one will win, but the others will need to retry.
578 |     for i in range(10):
579 |         study = None
580 |         if os.path.exists(args.optuna_study_path):
581 |             study = optuna.create_study(study_name=args.optuna_study_name, load_if_exists=True, storage=f"sqlite:///{args.optuna_study_path}", sampler=sampler, pruner=pruner) # sampler/pruner come from module scope (set under __main__)
582 |             break
583 | 
584 |         import tempfile
585 |         import shutil
586 | 
587 |         # Create a temp file and then move it to the final destination. only one will win this way.
588 |         # (prevent timing issues at study creation time)
589 |         with tempfile.TemporaryDirectory() as tmp:
590 |             temp_path = os.path.join(tmp, 'test.db')
591 |             # use temp path
592 |             temp_study = optuna.create_study(study_name=args.optuna_study_name, load_if_exists=False, storage=f"sqlite:///{temp_path}", sampler=sampler, pruner=pruner)
593 | 
594 |             # Seed study with initial trials. This is optional but often useful;
595 |             # make sure only the starting run adds the trials though.
596 |             # https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.enqueue_trial
597 |             for t in (base_trials or []): # guard: base_trials defaults to None
598 |                 temp_study.enqueue_trial(t)
599 | 
600 |             del temp_study
601 | 
602 |             try:
603 |                 pathlib.Path(args.optuna_study_path).parent.mkdir(exist_ok=True, parents=True)
604 |                 shutil.move(temp_path, args.optuna_study_path)
605 |             except Exception:
606 |                 # it's ok if we failed, one of our siblings likely succeeded
607 |                 pass
608 | 
609 |             del temp_path
610 | 
611 |     if study is None:
612 |         raise Exception("Unable to create or load study...")
613 | 
614 |     return study
615 | 
616 | if __name__ == "__main__":
617 |     args = parse_args()
618 | 
619 |     # Log any additional data here into args, which will be printed/logged later.
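    # (CUDA device name and total memory; guarded below so CPU-only machines still run)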
620 |     try:
621 |         args.device_name = torch.cuda.get_device_name(args.device)
622 |         args.device_memory = torch.cuda.get_device_properties(args.device).total_memory
623 |     except Exception:
624 |         pass
625 | 
626 |     if args.optuna_trials < 1:
627 |         main(args)
628 |     else:
629 |         if args.optuna_study_name is None:
630 |             raise ValueError("optuna usage requires a study name: pass --optuna_study_name [name]")
631 |         if args.optuna_study_path is None:
632 |             args.optuna_study_path = os.path.join(args.out_path, "optuna.db")
633 | 
634 |         sampler = optuna.samplers.TPESampler()
635 |         pruner = optuna.pruners.SuccessiveHalvingPruner()
636 |         pruner = optuna.pruners.PatientPruner(wrapped_pruner=pruner, patience=2)
637 | 
638 |         # simple wrapper around main() which applies optuna parameters
639 |         def objective(trial:optuna.trial.Trial):
640 |             import copy
641 |             args_copy = copy.deepcopy(args)
642 |             args_copy.optuna_trial = trial
643 |             # put each run into a separate output folder (so we can look back on the respective .pt files, etc)
644 |             args_copy.out_path = os.path.join(args.out_path, args.optuna_study_name, f"{trial.number:04}")
645 | 
646 |             ###################################
647 |             # override parameters as needed for study goals
648 | 
649 |             args_copy.lr = trial.suggest_float("lr", 1e-5, 0.5, log=True)
650 | 
651 |             ###################################
652 | 
653 |             return main(args_copy)
654 | 
655 |         study = create_study(args, [
656 |             ###################################
657 |             # Add explicit trials (or not) to the study.
658 | 
659 |             { "lr": 5e-4, } # Seed the first trial with a hand-picked starting value
660 | 
661 |             ###################################
662 |         ])
663 | 
664 |         study.optimize(objective, n_trials=args.optuna_trials)
665 | 
666 |         print(study.best_params)
667 | 
--------------------------------------------------------------------------------
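A quick way to sanity-check the re-parameterization path in sesr.py (a hedged sketch; it assumes the files above are importable as the `mlsr_training` package referenced by the imports, and `in_channels=15` matches the CPN input of 3 color + 12 recurrent channels):

```python
import torch
from mlsr_training.sesr import SESR, model_convert

# Build the training-time model (multi-branch blocks) and fold it for inference.
model = SESR(in_channels=15, return_outs=16)
deploy_model = model_convert(model, save_path=None, do_copy=True)

# The folded network should match the multi-branch one up to floating-point error.
x = torch.randn(1, 15, 64, 64)
with torch.no_grad():
    y_train = model(x)
    y_deploy = deploy_model(x)
print(torch.allclose(y_train, y_deploy, atol=1e-5))
```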