├── .gitmodules ├── README.md ├── flow ├── __init__.py ├── datasets │ ├── __init__.py │ └── nc.py ├── modules │ ├── __init__.py │ ├── estimators.py │ ├── grids.py │ ├── losses.py │ └── warps.py └── utils │ ├── __init__.py │ ├── meter.py │ ├── plot │ ├── __init__.py │ └── plot.py │ └── plot_old.py ├── images ├── model.png └── results.png ├── run.sh ├── test_with_ip_addr.py ├── train.py └── train_with_ip_addr.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "flow/submodules/OpticalFlowToolkit"] 2 | path = flow/submodules/OpticalFlowToolkit 3 | url = https://github.com/liruoteng/OpticalFlowToolkit 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *Deep Learning for Physical Processes: Integrating Prior Scientific Knowledge:* 2 | 3 | Official Pytorch implementation of ICLR 2018 paper [Deep Learning for Physical Processes: Integrating Prior Scientific Knowledge](https://openreview.net/pdf?id=By4HsfWAZ). 4 | 5 | ![alt text](images/model.png) 6 | 7 | ## Getting started: 8 | Clone repository along with submodules: ``` git clone --recursive https://github.com/emited/flow``` 9 | 10 | ## Dataset 11 | Download the data [here](http://marine.copernicus.eu/services-portfolio/access-to-products/?option=com_csw&view=details&product_id=GLOBAL_ANALYSIS_FORECAST_PHY_001_024). 12 | 13 | ## Results 14 | ![alt text](images/results.png) 15 | 16 | **Note**: By default, this implementation currently uses bilinear interpolation for warping. This scheme works well for modeling purely advective processes. For advective and diffusive processes, a gaussian warping scheme can be used (flow/modules/warps/GaussianWarpingScheme). The gaussian warping scheme will be integrated shortly into pytorch. Take a look at the pull request [here](https://github.com/pytorch/pytorch/pull/5487) for a status update. 
While waiting, it is possible to build pytorch from a forked version available [here](https://github.com/pajotarthur/pytorch). 17 | -------------------------------------------------------------------------------- /flow/__init__.py: -------------------------------------------------------------------------------- 1 | import flow.datasets 2 | import flow.modules 3 | import flow.utils -------------------------------------------------------------------------------- /flow/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nc import * 2 | -------------------------------------------------------------------------------- /flow/datasets/nc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data 3 | import pickle as pkl 4 | import numpy as np 5 | 6 | 7 | def _normalize_thetao_with_daily_stats(zdata): 8 | daily_mean = zdata['daily_mean'].reshape(-1, 1, 1) 9 | daily_std = zdata['daily_std'].reshape(-1, 1, 1) 10 | zdata['thetao'] = (zdata['thetao'] - daily_mean) / daily_std 11 | 12 | def _normalize_thetao(zdata): 13 | mean = zdata['thetao'].mean(axis=2).mean(axis=1) 14 | std = zdata['thetao'].reshape(zdata['thetao'].shape[0], -1).std(axis=1) 15 | zdata['thetao'] = (zdata['thetao'] - mean.reshape(-1, 1, 1)) / std.reshape(-1, 1, 1) 16 | #we have to modify mean and std if we renormalize again 17 | zdata['daily_mean'] = zdata['daily_mean'] + zdata['daily_std'] * mean 18 | zdata['daily_std'] = zdata['daily_std'] * std 19 | 20 | def _rescale_thetao(zdata): 21 | vmin = zdata['thetao'].min(axis=2).min(axis=1) 22 | vmax = zdata['thetao'].max(axis=2).max(axis=1) 23 | rmin = vmin.reshape(-1, 1, 1) 24 | rmax = vmax.reshape(-1, 1, 1) 25 | zdata['thetao'] = (zdata['thetao'] - rmin) / (rmax - rmin) 26 | #we have to modify mean and std if we renormalize again 27 | zdata['daily_mean'] = zdata['daily_mean'] + zdata['daily_std'] * vmin 28 | zdata['daily_std'] = 
class SSTSeq(torch.utils.data.Dataset):
    """Sliding-window dataset over per-zone ``thetao`` / ``uo`` / ``vo`` arrays.

    For every zone in ``zones`` the file ``<root>/data_<zone>.pkl`` is
    unpickled into a dict, optionally normalized, and then cut down along
    its leading axis with ``time_slice``.  A sample is a window of
    ``seq_len`` consecutive ``thetao`` frames, the ``target_seq_len`` frames
    that follow it, and the ``uo`` / ``vo`` frames at those target steps.

    Args:
        root: directory holding the ``data_<zone>.pkl`` files.
        seq_len: number of input frames per sample.
        target_seq_len: number of target frames per sample.
        time_slice: slice applied to the leading axis of ``thetao``, ``uo``
            and ``vo`` after normalization; ``None`` keeps everything.
        normalize_by_day: apply ``_normalize_thetao_with_daily_stats``.
        rescale_method: ``'norm'``, ``'minmax'`` or ``None`` (no rescaling).
        transform, target_transform, co_transform: stored on the instance;
            nothing in this class applies them.
        normalize_uv: apply ``_normalize_uo_vo``.
        zones: indexable collection of zone ids; ``None`` means
            ``range(1, 30)``.

    NOTE: the number of windows per zone is derived from the first zone
    only, so all zones are presumably expected to have the same length.
    """

    def __init__(self, root, seq_len=4, target_seq_len=6,
                 time_slice=None,
                 normalize_by_day=True,
                 rescale_method=None,
                 transform=None, target_transform=None,
                 co_transform=None,
                 normalize_uv=True,
                 zones=None):

        if zones is None:  # using all zones
            zones = range(1, 30)

        if time_slice is None:  # using all times
            time_slice = slice(None, None)

        self.root = root
        self.zones = zones
        self.seq_len = seq_len
        self.normalize_by_day = normalize_by_day
        self.rescale_method = rescale_method
        self.target_seq_len = target_seq_len
        self.time_slice = time_slice
        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform

        self.data = {}
        for zone in zones:
            path = os.path.join(root, 'data_' + str(zone) + '.pkl')
            # Close the file deterministically instead of leaking the handle.
            # NOTE: unpickling runs arbitrary code; only load trusted files.
            with open(path, 'rb') as f:
                zdata = pkl.load(f)

            if normalize_by_day:
                _normalize_thetao_with_daily_stats(zdata)

            if rescale_method == 'norm':
                print('=> norm rescale zone {}'.format(zone))
                _normalize_thetao(zdata)

            elif rescale_method == 'minmax':
                print('=> minmax rescale zone {}'.format(zone))
                _rescale_thetao(zdata)

            if normalize_uv:
                print('normalizing uv !')
                _normalize_uo_vo(zdata)

            # Restrict the time range after the statistics were computed;
            # done for every zone, independently of ``normalize_uv``.
            for var in ['thetao', 'uo', 'vo']:
                zdata[var] = zdata[var][time_slice]

            self.data[zone] = zdata

        # Number of complete (input + target) windows inside one zone.
        self.num_single = self.data[zones[0]]['thetao'].shape[0] - seq_len - target_seq_len + 1
        self.num = self.num_single * len(zones)

        print(f'size: {len(self)} num days: {len(self.data[self.zones[0]]["thetao"])}')

    def __getitem__(self, index):
        """Return ``(input, target, w_target)`` for the flat sample ``index``.

        Samples are laid out zone after zone, ``num_single`` windows each.
        ``w_target`` stacks the ``uo`` and ``vo`` target frames along a new
        axis 1 (``uo`` first).
        """
        zone = self.zones[index // self.num_single]
        start = index % self.num_single
        split = start + self.seq_len
        stop = split + self.target_seq_len
        zdata = self.data[zone]

        inputs = zdata['thetao'][start:split]
        target = zdata['thetao'][split:stop]
        w_target = np.stack([zdata['uo'][split:stop], zdata['vo'][split:stop]], axis=1)

        return inputs, target, w_target

    def __len__(self):
        return self.num
def soft_conv_transpose(in_planes, out_planes, kernel_size=3, stride=2, padding=1, bias=False):
    """Upsample-then-convolve replacement for a transposed convolution.

    The input is bilinearly upsampled by ``stride`` and then passed through a
    stride-1 ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: kernel size of the convolution.
        stride: upsampling factor.
        padding: zero padding of the convolution.
        bias: whether the convolution has a bias term.

    ``padding`` and ``bias`` used to be accepted but silently ignored (the
    convolution was always built with ``padding=1, bias=False``); they are
    now forwarded.  The defaults reproduce the previous behaviour.
    """
    return nn.Sequential(
        nn.Upsample(scale_factor=stride, mode='bilinear'),
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=1, padding=padding, bias=bias)
    )
58 | This is done by removing the last two convolutional 59 | layers of FlowNetS, and adding two deconvolutional 60 | layers at the end. 61 | """ 62 | 63 | def __init__(self, input_channels=4, output_channels=2, batch_norm=True, upsample_mode='bilinear'): 64 | super(ConvDeconvEstimator, self).__init__() 65 | 66 | self.batch_norm = batch_norm 67 | self.upsample_mode = upsample_mode 68 | self.input_channels = input_channels 69 | self.conv1 = conv(self.batch_norm, input_channels, 64, kernel_size=3, stride=2) 70 | self.conv2 = conv(self.batch_norm, 64, 128, kernel_size=3, stride=2) 71 | self.conv3 = conv(self.batch_norm, 128, 256, kernel_size=3, stride=2) 72 | self.conv3_1 = conv(self.batch_norm, 256, 256, kernel_size=3) 73 | self.conv4 = conv(self.batch_norm, 256, 512, kernel_size=3, stride=2) 74 | self.conv4_1 = conv(self.batch_norm, 512, 512, kernel_size=3) 75 | self.conv5 = conv(self.batch_norm, 512, 1024, stride=2) 76 | self.conv5_1 = conv(self.batch_norm, 1024, 1024) 77 | 78 | if upsample_mode == 'deconv': 79 | self.deconv4 = deconv(1024, 256) 80 | self.deconv3 = deconv(768, 128) 81 | self.deconv2 = deconv(384, 64) 82 | self.deconv1 = deconv(192, 32) 83 | self.deconv0 = deconv(96, 16) 84 | else: 85 | self.deconv4 = soft_deconv(1024, 256, upsample_mode=upsample_mode) 86 | self.deconv3 = soft_deconv(768, 128, upsample_mode=upsample_mode) 87 | self.deconv2 = soft_deconv(384, 64, upsample_mode=upsample_mode) 88 | self.deconv1 = soft_deconv(192, 32, upsample_mode=upsample_mode) 89 | self.deconv0 = soft_deconv(96, 16, upsample_mode=upsample_mode) 90 | 91 | self.predict_flow0 = predict_flow(16 + input_channels, output_channels) 92 | 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, 0.02 / n) 97 | if m.bias is not None: 98 | m.bias.data.zero_() 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | 
class DenseGridGen(nn.Module):
    """Turn a dense flow field into a sampling grid.

    An identity grid with both coordinates running from -1 to 1 is built for
    the spatial size of the input, and the input is subtracted from it.  The
    last axis of the result holds two coordinates: the first varies along
    the width axis, the second along the height axis.

    Args:
        transpose: when true the input is taken channel-first (B, C, H, W)
            and moved to channel-last before the subtraction; otherwise it
            must already be (B, H, W, C).
    """

    def __init__(self, transpose=True):
        super(DenseGridGen, self).__init__()
        self.transpose = transpose
        # Buffer so the identity grid follows the module across devices.
        self.register_buffer('grid', torch.Tensor())

    def forward(self, x):
        if self.transpose:
            # (B, C, H, W) -> (B, H, W, C)
            x = x.permute(0, 2, 3, 1)

        batch, height, width = x.size(0), x.size(1), x.size(2)

        # Each coordinate map has shape (height, width).
        cols = torch.linspace(-1, 1, width).unsqueeze(0).repeat(height, 1)
        rows = torch.linspace(-1, 1, height).unsqueeze(1).repeat(1, width)
        identity = torch.stack([cols, rows], -1)
        self.grid.resize_(identity.size()).copy_(identity)

        base = Variable(self.grid)
        base = base.unsqueeze(0).expand(batch, *base.size())
        return base - x
def AAE(input_flow, target_flow):
    '''Average Angular Error.

    Provides a relative measure of performance
    that avoids the divide by zero.

    Calculates the angle between input and target vectors
    augmented with an extra dimension where the associated
    scalar value for that dimension is one, i.e. between
    (u, v, 1) and (u', v', 1), and averages it over all positions.

    input_flow, target_flow: tensors of identical shape whose flow
        components are stored along dim 1.

    Returns a scalar tensor holding the mean angle in radians.
    '''

    num = 1 + torch.sum(input_flow * target_flow, 1)
    # Squared norm of the augmented vector: the extra unit component is
    # added once, after reducing over the channel axis.  (Adding it inside
    # the sum counted it once per channel, so identical flows did not give
    # a zero angle.)
    denom = 1 + torch.sum(input_flow ** 2, 1)
    denom_gt = 1 + torch.sum(target_flow ** 2, 1)
    cos = num / torch.sqrt(denom * denom_gt)
    # Rounding can push the cosine marginally outside [-1, 1], where acos
    # returns NaN.
    return torch.acos(cos.clamp(-1, 1)).mean()
class DivergenceLoss(torch.nn.Module):
    '''Penalizes the divergence du/dx + dv/dy of a flow field.

    loss: callable ``loss(input, target)`` comparing the divergence
        with an all-zero target.
    delta: grid spacing used in the finite differences.

    The flow ``w`` is indexed as (batch, component, height, width).
    Component 0 is differentiated along the width axis and component 1
    along the height axis, which is the (x, y) ordering in which
    DenseGridGen builds the sampling grid from the flow.
    '''

    def __init__(self, loss, delta=1):
        super(DivergenceLoss, self).__init__()
        self.delta = delta
        self.loss = loss

    def forward(self, w):
        # Forward differences, each along its own axis.  Both derivatives
        # used to be taken along the same axis, which is not a divergence.
        dudx = (w[:, 0, :, 1:] - w[:, 0, :, :-1]) / self.delta
        dvdy = (w[:, 1, 1:, :] - w[:, 1, :-1, :]) / self.delta
        # Keep only the positions where both differences are defined.
        div = dudx[:, :-1, :] + dvdy[:, :, :-1]
        return self.loss(div, div.detach() * 0)
class AverageMeter(object):
    """Tracks the latest value and the weighted running mean of a quantity.

    Attributes (all reset to 0 by ``reset``):
        val: value passed to the most recent ``update``.
        sum: accumulated ``val * n``.
        count: accumulated ``n``.
        avg: ``sum / count`` after the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        latest, mean = self.val, self.avg
        return '{:.10f} ({:.10f})'.format(latest, mean)
def from_matplotlib(fig):
    """Render ``fig`` and return its pixels as a uint8 array.

    The figure is drawn, its RGB buffer is read from the canvas, and the
    figure is then closed.  The returned array has shape
    ``(height, width, 3)``, where width and height come from
    ``fig.canvas.get_width_height()``.
    """
    fig.canvas.draw()
    rgb = fig.canvas.tostring_rgb()
    plt.close(fig)
    # np.fromstring(..., sep='') is the deprecated binary mode of
    # fromstring; frombuffer is its documented replacement.  The copy keeps
    # the result writable, as it was before.
    data = np.frombuffer(rgb, dtype=np.uint8).copy()
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
def plot_tensor(output, x, padding=6, pad_value=1):
    """
    Assemble one image comparing targets with model outputs.

    output: model output tensor; dim 1 is iterated over and dim 0 holds the
        predicted frames of each entry.
    x: tensor iterated over dim 1 in lockstep with ``output``; along dim 0
        its last ``output.size(0)`` entries are used as the targets and the
        entries before them as the inputs.
    padding, pad_value: forwarded to ``vutils.make_grid``.

    NOTE(review): this docstring previously gave the sizes as TxOxCxHxW for
    the output and Tx(I+O)xCxHxW for x; the slicing below instead splits
    dim 0 of ``x`` into inputs and targets -- confirm the intended layout
    against the callers.
    """
    samples = []
    # Walk over dim 1 of both tensors together.
    for o, xi in zip(output.transpose(0, 1), x.transpose(0, 1)):
        # Predicted frames, one per grid row (nrow=1).
        out = vutils.make_grid(o, padding=padding, nrow=1, pad_value=pad_value)
        # Everything before the last output.size(0) entries: the inputs.
        inn = vutils.make_grid(xi[:-output.size(0)], padding=padding, nrow=1, pad_value=pad_value)
        innout = torch.cat([inn, out], 1)
        # The last output.size(0) entries: the targets.
        targ = vutils.make_grid(xi[-output.size(0):], padding=padding, nrow=1, pad_value=pad_value)
        inntarg = torch.cat([inn, targ], 1)
        # inputs+targets and inputs+outputs joined along dim 2.
        sample = torch.cat([inntarg, innout], 2)
        samples.append(sample)
    # All entries joined along dim 2, then dims 1 and 2 swapped.
    return torch.cat(samples, 2).transpose(1, 2)
def color_code_image(xlim=(-10, 10), ylim=(-10, 10), res=100):
    """Return the flow color code rendered as an image.

    A ``res`` x ``res`` field covering ``xlim`` x ``ylim`` is built with
    ``color_code`` and converted with ``flowlib.flow_to_image``.

    This module has no ``flow_to_image`` of its own, so the previous bare
    call raised ``NameError``.  The conversion now goes through ``flowlib``
    with the component axis moved last, the same way ``plot_one_image``
    in this module calls it.
    """
    code = flowlib.flow_to_image(color_code(xlim, ylim, res).transpose(1, 2, 0))
    return code
value) where value 42 | is a dict of with keys 'in' and 'out', corresponding 43 | to the input and output sequences given to the model 44 | """ 45 | 46 | # retrieving max image value to not renorm 47 | vmin = np.inf 48 | vmax = - np.inf 49 | for _, v in x: 50 | for l in v: 51 | if v[l][0][0].shape[0] != 2: 52 | vmin = min(vmin, np.min(v[l])) 53 | vmax = max(vmax, np.max(v[l])) 54 | 55 | x = OrderedDict(x) 56 | 57 | # calculating max column length 58 | maxincols, maxoutcols = 0, 0 59 | for k in x: 60 | if 'in' in x[k]: 61 | maxincols = max(len(x[k]['in']), maxincols) 62 | if 'out' in x[k]: 63 | maxoutcols = max(len(x[k]['out']), maxoutcols) 64 | cols = maxincols + maxoutcols 65 | rows = len(x) * nsample 66 | 67 | res = 1.5 68 | plt.figure(figsize=(cols * res, rows * res)) 69 | for i, k in enumerate(x): 70 | for s in range(nsample): 71 | if 'in' in x[k]: 72 | title = '{}, in'.format(k) 73 | for t in range(len(x[k]['in'])): 74 | n = s * len(x) * cols + cols * i + t 75 | im = x[k]['in'][t][s] 76 | plt.subplot(rows, cols, n + 1) 77 | plot_one_image(im, title) 78 | if 'out' in x[k]: 79 | title = '{}, out'.format(k) 80 | for t in range(maxincols, maxincols + len(x[k]['out'])): 81 | im = x[k]['out'][t - maxincols][s] 82 | n = s * len(x) * cols + cols * i + t 83 | plt.subplot(rows, cols, n + 1) 84 | plot_one_image(im, title) 85 | plt.tight_layout() 86 | plt.subplots_adjust(wspace=0.001, hspace=0.1) 87 | return plt.gcf() -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emited/flow/f7007c58e0fbc9c2590afcd07868bcc1514e2e44/images/model.png -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/emited/flow/f7007c58e0fbc9c2590afcd07868bcc1514e2e44/images/results.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | run=" 2 | find -name "*.pyc" -delete && 3 | OMP_NUM_THREADS=2 CUDA_VISIBLE_DEVICES=$2 4 | python $1 ${@:3} 5 | " 6 | 7 | printf "$run \n" 8 | eval $run -------------------------------------------------------------------------------- /test_with_ip_addr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import visdom 5 | from datetime import datetime 6 | import matplotlib.pyplot as plt 7 | 8 | import torch 9 | import torch.utils.data 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | from torch.utils.data import DataLoader 14 | 15 | import flow.modules.losses as losses 16 | import flow.datasets as datasets 17 | import flow.modules.warps as warps 18 | import flow.modules.estimators as estimators 19 | from flow.utils.meter import AverageMeters 20 | # import flow.utils.plot as plot 21 | 22 | 23 | import sys 24 | sys.path.append('/home/debezenac/projects/flow_icml/flow/flow/utils') 25 | import plot 26 | 27 | warp_names = sorted(name for name in warps.__dict__ 28 | if not name.startswith('__')) 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch FlowNet Training on several datasets') 31 | parser.add_argument('--train-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/train', 32 | help='path to training dataset') 33 | parser.add_argument('--test-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/test', 34 | help='path to testing dataset') 35 | parser.add_argument('--train-zones', type=int, nargs='+', action='store', dest='train_zones', 
default=[20], 36 | help='geographical zones to train on. To train on all zones, add range(1, 30)') 37 | parser.add_argument('--test-zones', type=int, nargs='+', action='store', dest='test_zones', default=[20], 38 | help='geographical zones to test on. To test on all zones, add range(1, 30)') 39 | parser.add_argument('--rescale', default='norm', type=str, 40 | help='you can choose between minmax and norm') 41 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 42 | help='number of data loading workers (default: 4)') 43 | parser.add_argument('-b', '--batch-size', default=1, type=int, 44 | metavar='N', help='mini-batch size (default: 16)') 45 | parser.add_argument('-s', '--split', default=.8, type=float, metavar='%', 46 | help='split percentage of train samples vs test (default: .8)') 47 | parser.add_argument('--seq-len', default=4, type=int, 48 | help='number of input images as input of the estimator (horizon)') 49 | parser.add_argument('--target-seq-len', default=6, type=int, 50 | help='number of target images') 51 | parser.add_argument('--test-target-seq-len', default=10, type=int, 52 | help='number of test target images') 53 | parser.add_argument('--weight-decay', '--wd', default=4e-4, type=float, 54 | metavar='W', help='weight decay (default: 4e-4)') 55 | parser.add_argument('--warp', default='BilinearWarpingScheme', choices=warp_names, 56 | help='choose warping scheme to use:' + ' | '.join(warp_names)) 57 | parser.add_argument('--upsample', default='bilinear', choices=('deconv', 'nearest', 'bilinear'), 58 | help='choose from (deconv, nearest, bilinear)') 59 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 60 | metavar='LR', help='initial learning rate') 61 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 62 | help='momentum for sgd, alpha parameter for adam') 63 | parser.add_argument('--beta', default=0.999, type=float, metavar='M', 64 | help='beta parameters for adam') 65 | 
parser.add_argument('--smooth-coef', default=0.4, type=float,
                    help='coefficient associated to smoothness loss in cost function')
parser.add_argument('--div-coef', default=1, type=float,
                    help='coefficient associated to divergence loss in cost function')
parser.add_argument('--magn-coef', default=-0.003, type=float,
                    help='coefficient associated to magnitude loss in cost function')
parser.add_argument('--epochs', default=1, type=int, metavar='N',
                    help='number of evaluation passes over the test set (default: 1)')
parser.add_argument('--save-every', default=10, type=int, metavar='N',
                    help='unused by this evaluation script (accepted for CLI parity with training)')
parser.add_argument('--save-start', default=20, type=int, metavar='N',
                    help='unused by this evaluation script (accepted for CLI parity with training)')
parser.add_argument('--save-root', default='/net/drunk/debezenac/data/flow_icml/saved_modules_iclr_2018', type=str,
                    help='unused by this evaluation script (accepted for CLI parity with training)')
parser.add_argument('--env', default='main',
                    help='environment for visdom')
parser.add_argument('--no-plot', action='store_true',
                    help='do not plot images using visdom')
parser.add_argument('--no-cuda', action='store_true',
                    help='no cuda')
parser.add_argument('--load-root', default='/net/drunk/debezenac/data/flow_icml/saved_modules_iclr_2018', type=str,
                    help='directory containing the checkpoint to evaluate')
parser.add_argument('--load-fn', default='sam.-janv.-19-04:06:42.319084_470.pt', type=str,
                    help='file name of the checkpoint to evaluate')
parser.add_argument('--save-ims', action='store_true',
                    help='save result figures as PNG for a hand-picked subset of test batches')


args = parser.parse_args()

viz = visdom.Visdom(server='http://132.227.204.175', env=args.env)

# Hand-picked test samples exported by --save-ims, keyed by the forecast
# horizon (--test-target-seq-len).  main() shifts every entry by
# `seq_len - 1` to obtain the batch positions that are kept.
_SAVED_SAMPLES = {
    10: [8, 15, 45, 59, 63, 64, 74, 76, 83, 90, 97, 107, 111, 125, 136, 139,
         155, 171, 176, 182, 204, 213, 215, 218, 223, 224, 226, 232, 260, 263],
    20: [63, 83, 101, 109, 118, 135, 149, 150, 153, 158, 170, 189, 200, 229, 234],
}

# --save-ims writes its figures to this prefix + str(test_target_seq_len).
_IMAGES_SAVE_PREFIX = '/net/drunk/debezenac/data/flow_icml/saved_images/png/sst/test/adv/'

# Visdom window slot used for the per-batch image; position 1 is the slot the
# 'test' split occupied in the original {'train', 'test'} split table.
_TEST_IMAGE_WIN = 1


def main():
    """Evaluate a saved estimator/warp checkpoint on the held-out sequences.

    The checkpoint named by --load-root/--load-fn is loaded, the estimator is
    rolled out auto-regressively for every test batch (each predicted frame is
    fed back as the newest input frame), and the averaged losses are printed
    and plotted to visdom.  With --save-ims only a hand-picked subset of
    batches is evaluated and a PNG figure is written for each of them.

    Raises:
        ValueError: if --save-ims is given for a horizon that has no
            hand-picked sample list.
    """
    device = torch.device('cpu' if args.no_cuda else 'cuda')
    dtype = torch.get_default_dtype()  # inputs are cast to this, as the old Tensor() buffers did

    keep = None
    images_save_root = None
    if args.save_ims:
        # Fail before any data is loaded instead of hitting an unbound `index`.
        if args.test_target_seq_len not in _SAVED_SAMPLES:
            raise ValueError(
                '--save-ims has no hand-picked samples for --test-target-seq-len={}; '
                'supported horizons: {}'.format(args.test_target_seq_len,
                                                sorted(_SAVED_SAMPLES)))
        index = [s - args.seq_len + 1 for s in _SAVED_SAMPLES[args.test_target_seq_len]]
        print(f'index: {index}')
        keep = frozenset(index)
        images_save_root = _IMAGES_SAVE_PREFIX + str(args.test_target_seq_len)
        os.makedirs(images_save_root, exist_ok=True)

    print('=> loading datasets...')
    # The test set is the tail (time steps 3000 and later) of the data under
    # --train-root; --test-root is not read by this script.
    test_dset = datasets.SSTSeq(args.train_root,
                                seq_len=args.seq_len,
                                target_seq_len=args.test_target_seq_len,
                                zones=args.test_zones,
                                rescale_method=args.rescale,
                                time_slice=slice(3000, None),
                                normalize_uv=True,
                                )
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=not args.no_cuda,
                             )
    print('len(test)', len(test_dset))

    load_path = os.path.join(args.load_root, args.load_fn)
    print(f'Loading {load_path} ...')
    # NOTE(security): torch.load unpickles arbitrary objects -- only point this
    # at checkpoints you trust.  map_location lets a checkpoint saved on GPU be
    # evaluated with --no-cuda.
    loaded = torch.load(load_path, map_location='cpu' if args.no_cuda else None)
    print('loaded', loaded)
    estimator = loaded['estimator']
    warp = loaded['warp']

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())
    # Reported under the 'err_aae' key; it is the mean cosine similarity
    # between predicted and reference motion fields.
    sim_loss = torch.nn.functional.cosine_similarity

    cudnn.benchmark = True

    if not args.no_cuda:
        print('=> to cuda')
        warp.cuda()
        estimator.cuda()

    estimator.eval()
    warp.eval()

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):
        meters = AverageMeters()

        for i, (inputs, targets, w_targets) in enumerate(test_loader):
            if keep is not None:
                if i not in keep:
                    continue
                print('i=', i)

            # Pure evaluation: no autograd graph is needed for the roll-out.
            with torch.no_grad():
                x = inputs.to(device=device, dtype=dtype)
                # assumes targets is (batch, time, H, W) -- TODO confirm against SSTSeq;
                # after this, iterating `ys` yields one (batch, 1, H, W) frame per step.
                ys = targets.to(device=device, dtype=dtype).transpose(0, 1).unsqueeze(2)

                pl = sl = dvl = ml = err_aee = 0
                ims, ws = [], []
                for j, y in enumerate(ys):
                    w = estimator(x)
                    im = warp(x[:, -1].unsqueeze(1), w)
                    # Slide the input window: drop the oldest frame, append the prediction.
                    x = torch.cat([x[:, 1:], im], 1)

                    pl += photo_loss(im, y)
                    sl += smooth_loss(w)
                    dvl += div_loss(w)
                    ml += magn_loss(w)
                    # Follow the prediction's device so --no-cuda is honoured.
                    err_aee += sim_loss(w, w_targets[:, j].to(w.device)).mean()

                    ims.append(im.cpu().numpy())
                    ws.append(w.cpu().numpy())

                n_steps = len(ys)
                pl /= n_steps
                sl /= n_steps
                dvl /= n_steps
                ml /= n_steps
                err_aee /= n_steps

                loss = pl + args.smooth_coef * sl + args.div_coef * dvl + args.magn_coef * ml

            meters.update(
                dict(loss=loss.item(),
                     pl=pl.item(),
                     dl=dvl.item(),
                     sl=sl.item(),
                     ml=ml.item(),
                     err_aae=err_aee.item(),
                     ),
                n=x.size(0)
            )

            if args.save_ims:
                images = [
                    ('im', {'out': ims}),
                    ('ws', {'out': ws}),
                ]
                nplots = 1
                images_save_fn = os.path.join(
                    images_save_root,
                    f'images_{load_path.replace("/", "_")}_{i + args.seq_len - 1}.png')
                # plot_results draws the figure; savefig then writes the current figure.
                plot.plot_results(images, min(nplots, args.batch_size), cmap='tarn', renorm=False)
                print(images_save_fn)
                plt.savefig(images_save_fn)

            if not args.no_plot:
                images = [
                    ('target', {
                        'in': inputs.transpose(0, 1).numpy(),
                        'out': ys.cpu().numpy(),
                    }),
                    ('im', {'out': ims}),
                    ('ws', {'out': ws}),
                    ('ws_target', {'out': w_targets.transpose(0, 1).numpy()}),
                ]
                frame = plot.from_matplotlib(plot.plot_results(images))
                viz.image(frame.transpose(2, 0, 1),
                          opts=dict(title='{}, epoch {}'.format('TEST', epoch)),
                          win=_TEST_IMAGE_WIN,
                          )

        results = {'test': meters.avgs()}
        print('\n\nEpoch: {} {}: {}\t'.format(epoch, 'test', meters))
        print('seq_len:', args.test_target_seq_len)

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        legend = list(results)
        res = {}
        for split in legend:
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One visdom line plot per metric, created on first use and appended to afterwards.
        for metric, values in res.items():
            y = np.expand_dims(np.array(values), 0)
            x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=x, Y=y, opts=opts)
            else:
                viz.line(X=x, Y=y, opts=opts, win=viz_wins[metric], update='append')


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import argparse
import os
import numpy as np
import visdom

import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

import flow.modules.losses as losses
import flow.datasets as datasets
import flow.modules.warps as warps
import flow.modules.estimators as estimators
import flow.utils.plot as plot
from flow.utils.meter import AverageMeters


warp_names = sorted(name for name in warps.__dict__
                    if not name.startswith('__'))

parser = argparse.ArgumentParser(description='PyTorch FlowNet Training on several datasets')
parser.add_argument('--train-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/train',
                    help='path to training dataset')
parser.add_argument('--test-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/test',
                    help='path to testing dataset')
parser.add_argument('--train-zones', type=int, nargs='+', action='store', dest='train_zones', default=range(1, 30),
                    help='geographical zones to train on. To train on all zones, add range(1, 30)')
parser.add_argument('--test-zones', type=int, nargs='+', action='store', dest='test_zones', default=range(1, 30),
                    help='geographical zones to test on. To test on all zones, add range(1, 30)')
parser.add_argument('--rescale', default='norm', type=str,
                    help='you can choose between minmax and norm')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=64, type=int,
                    metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('-s', '--split', default=.8, type=float, metavar='%',
                    help='fraction of the training set used for training, the rest is used for validation (default: .8)')
parser.add_argument('--seq-len', default=4, type=int,
                    help='number of input images as input of the estimator (horizon)')
parser.add_argument('--target-seq-len', default=6, type=int,
                    help='number of target images')
parser.add_argument('--weight-decay', '--wd', default=4e-4, type=float,
                    metavar='W', help='weight decay (default: 4e-4)')
parser.add_argument('--warp', default='BilinearWarpingScheme', choices=warp_names,
                    help='choose warping scheme to use:' + ' | '.join(warp_names))
parser.add_argument('--upsample', default='bilinear', choices=('deconv', 'nearest', 'bilinear'),
                    help='choose from (deconv, nearest, bilinear)')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='beta1 coefficient for adam')
parser.add_argument('--beta', default=0.999, type=float, metavar='M',
                    help='beta2 coefficient for adam')
parser.add_argument('--smooth-coef', default=0.4, type=float,
                    help='coefficient associated to smoothness loss in cost function')
parser.add_argument('--div-coef', default=1, type=float,
                    help='coefficient associated to divergence loss in cost function')
parser.add_argument('--magn-coef', default=-0.003, type=float,
                    help='coefficient associated to magnitude loss in cost function')
parser.add_argument('--epochs', default=500, type=int, metavar='N',
                    help='number of total epochs to run (default: 500)')
parser.add_argument('--env', default='main',
                    help='environment for visdom')
parser.add_argument('--no-plot', action='store_true',
                    help='do not plot images using visdom')
parser.add_argument('--no-cuda', action='store_true',
                    help='no cuda')

args = parser.parse_args()

viz = visdom.Visdom(env=args.env)


def main():
    """Train the motion estimator and report train/valid/test losses.

    Every epoch runs the three splits in order.  For each batch the estimator
    is rolled out auto-regressively (each warped prediction becomes the newest
    input frame) and the photometric, smoothness, divergence and magnitude
    losses are averaged over the roll-out.  Only the 'train' split updates the
    weights; the per-split averages are printed and plotted to visdom.
    """
    device = torch.device('cpu' if args.no_cuda else 'cuda')
    dtype = torch.get_default_dtype()  # inputs are cast to this, as the old Tensor() buffers did

    print('=> loading datasets...')
    dset = datasets.SSTSeq(args.train_root,
                           seq_len=args.seq_len,
                           target_seq_len=args.target_seq_len,
                           zones=args.train_zones,
                           rescale_method=args.rescale,
                           )

    test_dset = datasets.SSTSeq(args.test_root,
                                seq_len=args.seq_len,
                                target_seq_len=args.target_seq_len,
                                zones=args.test_zones,
                                rescale_method=args.rescale,
                                )

    # The first `split` fraction of the training set is trained on, the rest validates.
    split_at = int(len(dset) * args.split)
    train_indices = range(0, split_at)
    val_indices = range(split_at, len(dset))

    pin = not args.no_cuda
    train_loader = DataLoader(dset,
                              batch_size=args.batch_size,
                              sampler=SubsetRandomSampler(train_indices),
                              num_workers=args.workers,
                              pin_memory=pin,
                              )
    val_loader = DataLoader(dset,
                            batch_size=args.batch_size,
                            sampler=SubsetRandomSampler(val_indices),
                            num_workers=args.workers,
                            pin_memory=pin,
                            )
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=pin,
                             )

    splits = {
        'train': train_loader,
        'valid': val_loader,
        'test': test_loader,
    }

    estimator = estimators.ConvDeconvEstimator(input_channels=args.seq_len,
                                               upsample_mode=args.upsample)
    warp = warps.__dict__[args.warp]()
    print("=> creating warping scheme '{}'".format(args.warp))

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())

    cudnn.benchmark = True
    optimizer = torch.optim.Adam(estimator.parameters(), args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Modules are moved to the GPU only when --no-cuda is absent.
    if not args.no_cuda:
        print('=> to cuda')
        warp.cuda()
        estimator.cuda()

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):

        results = {}
        for win_id, (split, loader) in enumerate(splits.items()):
            training = split == 'train'
            meters = AverageMeters()

            if training:
                estimator.train()
                warp.train()
            else:
                estimator.eval()
                warp.eval()

            for inputs, targets in loader:
                # Autograd graphs are only built for the split that back-propagates.
                with torch.set_grad_enabled(training):
                    x = inputs.to(device=device, dtype=dtype)
                    # assumes targets is (batch, time, H, W) -- TODO confirm against SSTSeq;
                    # after this, iterating `ys` yields one (batch, 1, H, W) frame per step.
                    ys = targets.to(device=device, dtype=dtype).transpose(0, 1).unsqueeze(2)

                    pl = sl = dvl = ml = 0
                    ims, ws = [], []
                    for y in ys:
                        w = estimator(x)
                        im = warp(x[:, -1].unsqueeze(1), w)
                        # Slide the input window: drop the oldest frame, append the prediction.
                        x = torch.cat([x[:, 1:], im], 1)

                        pl += photo_loss(im, y)
                        sl += smooth_loss(w)
                        dvl += div_loss(w)
                        ml += magn_loss(w)

                        ims.append(im.detach().cpu().numpy())
                        ws.append(w.detach().cpu().numpy())

                    n_steps = len(ys)
                    pl /= n_steps
                    sl /= n_steps
                    dvl /= n_steps
                    ml /= n_steps

                    loss = pl + args.smooth_coef * sl + args.div_coef * dvl + args.magn_coef * ml

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # .item() replaces the pre-0.4 `.data[0]` indexing of 0-dim tensors.
                meters.update(
                    dict(loss=loss.item(),
                         pl=pl.item(),
                         dl=dvl.item(),
                         sl=sl.item(),
                         ml=ml.item(),
                         ),
                    n=x.size(0)
                )

                if not args.no_plot:
                    images = [
                        ('target', {
                            'in': inputs.transpose(0, 1).numpy(),
                            'out': ys.detach().cpu().numpy(),
                        }),
                        ('im', {'out': ims}),
                        ('ws', {'out': ws}),
                    ]
                    # NOTE(review): the *_with_ip_addr scripts call plot.plot_results here;
                    # confirm that plot.plot_images still exists in flow.utils.plot.
                    frame = plot.from_matplotlib(plot.plot_images(images))
                    viz.image(frame.transpose(2, 0, 1),
                              opts=dict(title='{}, epoch {}'.format(split.upper(), epoch)),
                              win=win_id,
                              )

            results[split] = meters.avgs()
            print('\n\nEpoch: {} {}: {}\t'.format(epoch, split, meters))

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        legend = list(results)
        res = {}
        for split in legend:
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One visdom line plot per metric, created on first use and appended to afterwards.
        for metric, values in res.items():
            y = np.expand_dims(np.array(values), 0)
            x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=x, Y=y, opts=opts)
            else:
                viz.line(X=x, Y=y, opts=opts, win=viz_wins[metric], update='append')


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /train_with_ip_addr.py:
# --------------------------------------------------------------------------------
import argparse
import os
import numpy as np
import visdom
from datetime import datetime

import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

import flow.modules.losses as losses
import flow.datasets as datasets
import flow.modules.warps as warps
import flow.modules.estimators as estimators
import flow.utils.plot as plot
from flow.utils.meter import AverageMeters


warp_names = sorted(name for name in warps.__dict__
                    if not name.startswith('__'))

parser = argparse.ArgumentParser(description='PyTorch FlowNet Training on several datasets')
parser.add_argument('--train-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/train',
                    help='path to training dataset')
parser.add_argument('--test-root', metavar='DIR', default='/net/drunk/debezenac/CMEMS_DATA/datasets/np/test',
                    help='path to testing dataset')
parser.add_argument('--train-zones', type=int, nargs='+', action='store', dest='train_zones', default=[20],
                    help='geographical zones to train on. To train on all zones, add range(1, 30)')
parser.add_argument('--test-zones', type=int, nargs='+', action='store', dest='test_zones', default=[20],
                    help='geographical zones to test on. To test on all zones, add range(1, 30)')
parser.add_argument('--rescale', default='norm', type=str,
                    help='you can choose between minmax and norm')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=64, type=int,
                    metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('-s', '--split', default=.8, type=float, metavar='%',
                    help='split percentage of train samples vs test (default: .8)')
parser.add_argument('--seq-len', default=4, type=int,
                    help='number of input images as input of the estimator (horizon)')
parser.add_argument('--target-seq-len', default=6, type=int,
                    help='number of target images')
parser.add_argument('--test-target-seq-len', default=10, type=int,
                    help='number of test target images')
parser.add_argument('--weight-decay', '--wd', default=4e-4, type=float,
                    metavar='W', help='weight decay (default: 4e-4)')
parser.add_argument('--warp', default='BilinearWarpingScheme', choices=warp_names,
                    help='choose warping scheme to use:' + ' | '.join(warp_names))
parser.add_argument('--upsample', default='bilinear', choices=('deconv', 'nearest', 'bilinear'),
                    help='choose from (deconv, nearest, bilinear)')
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='beta1 coefficient for adam')
parser.add_argument('--beta', default=0.999, type=float, metavar='M',
                    help='beta2 coefficient for adam')
parser.add_argument('--smooth-coef', default=0.4, type=float,
                    help='coefficient associated to smoothness loss in cost function')
parser.add_argument('--div-coef', default=1, type=float,
                    help='coefficient associated to divergence loss in cost function')
parser.add_argument('--magn-coef', default=-0.003, type=float,
                    help='coefficient associated to magnitude loss in cost function')
parser.add_argument('--epochs', default=500, type=int, metavar='N',
                    help='number of total epochs to run (default: 500)')
parser.add_argument('--save-every', default=10, type=int, metavar='N',
                    help='save a checkpoint every N epochs (default: 10)')
parser.add_argument('--save-start', default=20, type=int, metavar='N',
                    help='first epoch at which checkpoints may be saved (default: 20)')
parser.add_argument('--save-root', default='/net/drunk/debezenac/data/flow_icml/saved_modules_iclr_2018', type=str,
                    help='directory where checkpoints are written')
parser.add_argument('--env', default='main',
                    help='environment for visdom')
parser.add_argument('--no-plot', action='store_true',
                    help='do not plot images using visdom')
parser.add_argument('--no-cuda', action='store_true',
                    help='no cuda')

args = parser.parse_args()

viz = visdom.Visdom(server='http://132.227.204.175', env=args.env)


def main():
    """Train the motion estimator, track a flow-error metric and save checkpoints.

    Training uses time steps before 3000 of the data under --train-root and the
    'test' split uses the remaining time steps of the same data (--test-root is
    not read).  Each batch is rolled out auto-regressively; the losses and the
    `losses.AAE` score against the reference motion fields are averaged over
    the steps actually rolled out.  Every --save-every epochs, starting at
    --save-start, the estimator, warp and optimizer are saved to --save-root.
    """
    device = torch.device('cpu' if args.no_cuda else 'cuda')
    dtype = torch.get_default_dtype()  # inputs are cast to this, as the old Tensor() buffers did

    print('=> loading datasets...')
    dset = datasets.SSTSeq(args.train_root,
                           seq_len=args.seq_len,
                           target_seq_len=args.target_seq_len,
                           zones=args.train_zones,
                           rescale_method=args.rescale,
                           time_slice=slice(None, 3000),
                           normalize_uv=True,
                           )

    test_dset = datasets.SSTSeq(args.train_root,
                                seq_len=args.seq_len,
                                target_seq_len=args.test_target_seq_len,
                                zones=args.test_zones,
                                rescale_method=args.rescale,
                                time_slice=slice(3000, None),
                                normalize_uv=True,
                                )

    pin = not args.no_cuda
    train_loader = DataLoader(dset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=pin,
                              drop_last=True,
                              )
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             num_workers=args.workers,
                             pin_memory=pin,
                             )

    splits = {
        'train': train_loader,
        'test': test_loader,
    }

    estimator = estimators.ConvDeconvEstimator(input_channels=args.seq_len,
                                               upsample_mode=args.upsample)
    warp = warps.__dict__[args.warp]()
    print("=> creating warping scheme '{}'".format(args.warp))

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())

    cudnn.benchmark = True
    optimizer = torch.optim.Adam(estimator.parameters(), args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Modules are moved to the GPU only when --no-cuda is absent.
    if not args.no_cuda:
        print('=> to cuda')
        warp.cuda()
        estimator.cuda()

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):

        results = {}
        for win_id, (split, loader) in enumerate(splits.items()):
            training = split == 'train'
            meters = AverageMeters()

            if training:
                estimator.train()
                warp.train()
            else:
                estimator.eval()
                warp.eval()

            for inputs, targets, w_targets in loader:
                # Autograd graphs are only built for the split that back-propagates.
                with torch.set_grad_enabled(training):
                    x = inputs.to(device=device, dtype=dtype)
                    # assumes targets is (batch, time, H, W) -- TODO confirm against SSTSeq;
                    # after this, iterating `ys` yields one (batch, 1, H, W) frame per step.
                    ys = targets.to(device=device, dtype=dtype).transpose(0, 1).unsqueeze(2)

                    pl = sl = dvl = ml = err_aee = 0
                    ims, ws = [], []
                    for j, y in enumerate(ys):
                        w = estimator(x)
                        im = warp(x[:, -1].unsqueeze(1), w)
                        # Slide the input window: drop the oldest frame, append the prediction.
                        x = torch.cat([x[:, 1:], im], 1)

                        pl += photo_loss(im, y)
                        sl += smooth_loss(w)
                        dvl += div_loss(w)
                        ml += magn_loss(w)
                        # Follow the prediction's device so --no-cuda is honoured.
                        err_aee += losses.AAE(w, w_targets[:, j].to(w.device))

                        ims.append(im.detach().cpu().numpy())
                        ws.append(w.detach().cpu().numpy())

                    # Average over the steps actually rolled out: the test split is built
                    # with --test-target-seq-len, not --target-seq-len.
                    n_steps = len(ys)
                    pl /= n_steps
                    sl /= n_steps
                    dvl /= n_steps
                    ml /= n_steps
                    err_aee /= n_steps

                    loss = pl + args.smooth_coef * sl + args.div_coef * dvl + args.magn_coef * ml

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                meters.update(
                    dict(loss=loss.item(),
                         pl=pl.item(),
                         dl=dvl.item(),
                         sl=sl.item(),
                         ml=ml.item(),
                         err_aae=err_aee.item(),
                         ),
                    n=x.size(0)
                )

                if not args.no_plot:
                    images = [
                        ('target', {
                            'in': inputs.transpose(0, 1).numpy(),
                            'out': ys.detach().cpu().numpy(),
                        }),
                        ('im', {'out': ims}),
                        ('ws', {'out': ws}),
                    ]
                    frame = plot.from_matplotlib(plot.plot_results(images))
                    viz.image(frame.transpose(2, 0, 1),
                              opts=dict(title='{}, epoch {}'.format(split.upper(), epoch)),
                              win=win_id,
                              )

            results[split] = meters.avgs()
            print('\n\nEpoch: {} {}: {}\t'.format(epoch, split, meters))

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        legend = list(results)
        res = {}
        for split in legend:
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One visdom line plot per metric, created on first use and appended to afterwards.
        for metric, values in res.items():
            y = np.expand_dims(np.array(values), 0)
            x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=x, Y=y, opts=opts)
            else:
                viz.line(X=x, Y=y, opts=opts, win=viz_wins[metric], update='append')

        if (epoch % args.save_every == 0) and (epoch >= args.save_start):
            # Whole modules are pickled because the evaluation script reads
            # loaded['estimator'] and loaded['warp'] directly.
            to_save = {
                'epoch': epoch,
                'estimator': estimator,
                'warp': warp,
                'optim': optimizer,
                'err_obs': results['test']['pl'],
                'err_aae': results['test']['err_aae'],
            }
            time_str = datetime.now().strftime("%a-%b-%d-%H:%M:%S.%f")
            os.makedirs(args.save_root, exist_ok=True)
            save_path = os.path.join(args.save_root, f'{time_str}_{epoch}.pt')
            print(f'Saving modules to {save_path} ...')
            torch.save(to_save, save_path)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------