├── Assests ├── dump ├── image.png └── 1554797352491.png ├── flownet2 ├── dumpy ├── get_flow.py ├── run-caffe2pytorch.sh ├── pytorch_load.py ├── test_flownet2.py ├── convert.py ├── models.py ├── datasets.py └── main.py ├── log └── dumpy.log ├── liteFlownet ├── __init__.py ├── flow_vis.py ├── correlation │ └── correlation.py └── lite_flownet.py ├── utils └── utils.py ├── ano_pre ├── util.py ├── losses.py ├── evaluate.py ├── train.py └── eval_metric.py ├── README.md ├── Dataset └── img_dataset.py └── models ├── unet.py └── pix2pix_networks.py /Assests/dump: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /flownet2/dumpy: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /log/dumpy.log: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /liteFlownet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Assests/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fjchange/pytorch_ano_pre/HEAD/Assests/image.png -------------------------------------------------------------------------------- /Assests/1554797352491.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fjchange/pytorch_ano_pre/HEAD/Assests/1554797352491.png -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import os 4 | 5 | def saver(model_state_dict,model_path,step,max_to_save=5): 6 | total_models=glob.glob(model_path+'*') 7 | if len(total_models)>=max_to_save: 8 | total_models.sort() 9 | os.remove(total_models[0]) 10 | torch.save(model_state_dict,model_path+'-'+str(step)) 11 | print('model {} save successfully!'.format(model_path+'-'+str(step))) 12 | -------------------------------------------------------------------------------- /flownet2/get_flow.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import torch 3 | # from networks import * 4 | from liteFlownet.flow_vis import vis_mv 5 | 6 | def get_flow(model,input_tensor): 7 | ''' 8 | input tensor is in range [0,1] and 9 | :param model: the loaded flownetSD model 10 | :param input_tensor: the pytorch tensor,shape as [batch_size,channels*2,width,height] 11 | range from[0,1] 12 | :return: 13 | ''' 14 | # the flownet2 need 15 | #to the scale of [0,255] 16 | # input_tensor=input_tensor*255.0 17 | input_tensor=input_tensor.view([-1,3,2,input_tensor.shape[-2],input_tensor.shape[-1]]) 18 | 19 | model.eval() 20 | flow=model(input_tensor) 21 | 22 | return flow 23 | 24 | def get_batch_flow(model,input_tensor): 25 | #[batch_size,channels*2,height,weight] 26 | input_tensor=input_tensor*255.0 27 | # input_tensor=input_tensor.view([input_tensor[0],-1,2,input_tensor.shape[-2],input_tensor.shape[-1]]) 28 | flow=model(input_tensor) 29 | return flow 30 | 31 | 32 | -------------------------------------------------------------------------------- 
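For reference, a minimal sketch of how the FlowNet2SD wrapper defined in flownet2/models.py is driven; this mirrors what get_flow() does after its reshape. It assumes the flownet2/networks package from NVIDIA's flownet2-pytorch is importable (models.py imports it) and that a converted FlowNet2-SD checkpoint exists; the checkpoint path is a placeholder and the random tensor stands in for two real frames.

```python
import torch
from models import FlowNet2SD        # defined in flownet2/models.py

model_path = '/path/to/FlowNet2-SD_checkpoint.pth.tar'   # placeholder path

flownet = FlowNet2SD().cuda().eval()
flownet.load_state_dict(torch.load(model_path)['state_dict'])

# Two consecutive RGB frames arranged as [batch, channels=3, num_frames=2, H, W],
# with H and W divisible by 64 (384x512 here) and raw pixel values in [0, 255],
# as in test_flownet2.py. get_flow() produces exactly this layout via its view().
frame_pair = torch.rand(1, 3, 2, 384, 512).cuda() * 255.0

with torch.no_grad():
    flow = flownet(frame_pair)        # -> [1, 2, 384, 512] flow field (u, v)
print(flow.shape)
```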
/flownet2/run-caffe2pytorch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FN2PYTORCH=${1:-/} 4 | 5 | # install custom layers 6 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 . 7 | sudo nvidia-docker run --rm -ti --volume=${FN2PYTORCH}:/flownet2-pytorch:rw --workdir=/flownet2-pytorch $USER/pytorch:CUDA8-py27 /bin/bash -c "./install.sh" 8 | 9 | # convert FlowNet2-C, CS, CSS, CSS-ft-sd, SD, S and 2 to PyTorch 10 | sudo nvidia-docker run -ti --volume=${FN2PYTORCH}:/fn2pytorch:rw flownet2:latest /bin/bash -c "source /flownet2/flownet2/set-env.sh && cd /flownet2/flownet2/models && \ 11 | python /fn2pytorch/convert.py ./FlowNet2-C/FlowNet2-C_weights.caffemodel ./FlowNet2-C/FlowNet2-C_deploy.prototxt.template /fn2pytorch && 12 | python /fn2pytorch/convert.py ./FlowNet2-CS/FlowNet2-CS_weights.caffemodel ./FlowNet2-CS/FlowNet2-CS_deploy.prototxt.template /fn2pytorch && \ 13 | python /fn2pytorch/convert.py ./FlowNet2-CSS/FlowNet2-CSS_weights.caffemodel.h5 ./FlowNet2-CSS/FlowNet2-CSS_deploy.prototxt.template /fn2pytorch && \ 14 | python /fn2pytorch/convert.py ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_weights.caffemodel.h5 ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_deploy.prototxt.template /fn2pytorch && \ 15 | python /fn2pytorch/convert.py ./FlowNet2-SD/FlowNet2-SD_weights.caffemodel.h5 ./FlowNet2-SD/FlowNet2-SD_deploy.prototxt.template /fn2pytorch && \ 16 | python /fn2pytorch/convert.py ./FlowNet2-S/FlowNet2-S_weights.caffemodel.h5 ./FlowNet2-S/FlowNet2-S_deploy.prototxt.template /fn2pytorch && \ 17 | python /fn2pytorch/convert.py ./FlowNet2/FlowNet2_weights.caffemodel.h5 ./FlowNet2/FlowNet2_deploy.prototxt.template /fn2pytorch" -------------------------------------------------------------------------------- /flownet2/pytorch_load.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | from networks import FlowNetSD 4 | import tensorflow as tf 5 | from models.unet import Unet 6 | # def load_conv2d(state_dict,name_pth,name_tf): 7 | def tf_model_pth(checkpoint_path,pth_output_path): 8 | model=FlowNetSD.FlowNetSD(batchNorm=False).eval() 9 | state_dict=model.state_dict() 10 | 11 | with open(checkpoint_path) as f: 12 | ckptFileName=f.readline().split('"')[1] 13 | 14 | reader=tf.train.NewCheckpointReader(ckptFileName) 15 | 16 | pth_keys=state_dict.keys() 17 | keys=sorted(reader.get_variable_to_shape_map().keys()) 18 | for pth_key in pth_keys: 19 | pth_keySplits=pth_key.split('.') 20 | key_pre='FlowNetSD/' 21 | 22 | if pth_keySplits[0][:8]=='upsample': 23 | key=key_pre+'upsample_flow'+pth_keySplits[0][-6]+'to'+pth_keySplits[0][-1]+'/'+pth_keySplits[-1] 24 | elif pth_keySplits[0][:5]=='inter': 25 | key=key_pre+'interconv'+pth_keySplits[0][-1]+'/'+pth_keySplits[-1] 26 | else: 27 | key=key_pre+pth_keySplits[0]+'/'+pth_keySplits[-1] 28 | 29 | if pth_keySplits[-1]=='weight': 30 | tensor = reader.get_tensor(key+'s') 31 | state_dict[pth_key]=torch.from_numpy(tensor).permute([3,2,0,1]) 32 | else: 33 | tensor = reader.get_tensor(key+'es') 34 | state_dict[pth_key]=torch.from_numpy(tensor) 35 | 36 | torch.save(model.state_dict(),pth_output_path) 37 | 38 | def unet_tf_pth(checkpoint_path,pth_output_path): 39 | model=Unet().eval() 40 | state_dict=model.state_dict() 41 | 42 | reader=tf.train.NewCheckpointReader(checkpoint_path) 43 | 44 | pth_keys=state_dict.keys() 45 | keys=sorted(reader.get_variable_to_shape_map().keys()) 46 | print(keys) 47 | print(pth_keys) 48 | # for pth_key in pth_keys: 
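# (Illustrative note, not part of the original script.) The commented-out loop above
# would need the same layout conversion used in tf_model_pth(): TensorFlow stores conv
# kernels as [kH, kW, in_ch, out_ch], while PyTorch expects [out_ch, in_ch, kH, kW],
# hence permute([3, 2, 0, 1]); bias tensors need no permutation. Quick sanity check:
#   >>> import numpy as np, torch
#   >>> tf_w = np.zeros([3, 3, 64, 128], np.float32)               # [kH, kW, Cin, Cout]
#   >>> tuple(torch.from_numpy(tf_w).permute([3, 2, 0, 1]).shape)
#   (128, 64, 3, 3)                                                 # [Cout, Cin, kH, kW]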
49 | 50 | 51 | if __name__=='__main__': 52 | #tf_model_pth(r'/home/fjc/FlowNetSD/checkpoint',r'/home/fjc/FlowNetSD/flownet-SD.pth') 53 | unet_tf_pth(r'/home/fjc/pretrains/ped1',r'/home/fjc/trans_pth/ped1.pth') 54 | -------------------------------------------------------------------------------- /ano_pre/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms 4 | 5 | def log10(t): 6 | """ 7 | Calculates the base-10 log of each element in t. 8 | @param t: The tensor from which to calculate the base-10 log. 9 | @return: A tensor with the base-10 log of each element in t. 10 | """ 11 | 12 | numerator = torch.log(t) 13 | denominator = torch.log(torch.FloatTensor([10.])).cuda() 14 | return numerator / denominator 15 | 16 | 17 | def psnr_error(gen_frames, gt_frames): 18 | """ 19 | Computes the Peak Signal to Noise Ratio error between the generated images and the ground 20 | truth images. 21 | @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by the 22 | generator model. 23 | @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames for 24 | each frame in gen_frames. 25 | @return: A scalar tensor. The mean Peak Signal to Noise Ratio error over each frame in the 26 | batch. 27 | """ 28 | shape = list(gen_frames.shape) 29 | num_pixels = (shape[1] * shape[2] * shape[3]) 30 | gt_frames = (gt_frames + 1.0) / 2.0 31 | gen_frames = (gen_frames + 1.0) / 2.0 32 | square_diff = (gt_frames - gen_frames)**2 33 | 34 | batch_errors = 10 * log10(1. / ((1. / num_pixels) * torch.sum(square_diff, [1, 2, 3]))) 35 | return torch.mean(batch_errors) 36 | 37 | #for [B,C,W,H] 38 | def bgr_gray(input_tensor): 39 | B=input_tensor[:,0].view(input_tensor.size()[0],1,input_tensor.size()[2],input_tensor.size()[3]) 40 | G=input_tensor[:,1].view(input_tensor.size()[0],1,input_tensor.size()[2],input_tensor.size()[3]) 41 | R=input_tensor[:,2].view(input_tensor.size()[0],1,input_tensor.size()[2],input_tensor.size()[3]) 42 | gray_tensor=B*0.114+G*0.587+R*0.299 43 | return gray_tensor 44 | 45 | def diff_mask(gen_frames, gt_frames, min_value=-1, max_value=1): 46 | # normalize to [0, 1] 47 | delta = max_value - min_value 48 | gen_frames = (gen_frames - min_value) / delta 49 | gt_frames = (gt_frames - min_value) / delta 50 | 51 | gen_gray_frames = bgr_gray(gen_frames) 52 | gt_gray_frames = bgr_gray(gt_frames) 53 | 54 | diff = torch.abs(gen_gray_frames - gt_gray_frames) 55 | return diff 56 | 57 | -------------------------------------------------------------------------------- /flownet2/test_flownet2.py: -------------------------------------------------------------------------------- 1 | import get_flow 2 | from models import FlowNet2SD 3 | from networks import FlowNetSD 4 | import numpy as np 5 | import torch 6 | from liteFlownet.flow_vis import vis_mv 7 | from scipy.misc import imread,imresize 8 | import argparse 9 | import os 10 | from Dataset.img_dataset import np_load_frame 11 | # img1='../liteFlownet/001.jpg' 12 | # img2='../liteFlownet/002.jpg' 13 | # img1='/hdd/fjc/VAD/ped1/training/frames/ped1_train_01/005.jpg' 14 | # img2='/hdd/fjc/VAD/ped1/training/frames/ped1_train_01/007.jpg' 15 | # img1='/hdd/fjc/VAD/ped2/training/frames/01/006.jpg' 16 | # img2='/hdd/fjc/VAD/ped2/training/frames/01/007.jpg' 17 | # img1='/hdd/fjc/VAD/shanghaitech/training/frames/01/01_001/01_001_0069.jpg' 18 | # 
img2='/hdd/fjc/VAD/shanghaitech/training/frames/01/01_001/01_001_0072.jpg' 19 | img1='/hdd/fjc/VAD/shanghaitech/training/frames/01_002/01_002_0075.jpg' 20 | img2='/hdd/fjc/VAD/shanghaitech/training/frames/01_002/01_002_0076.jpg' 21 | 22 | # img1='/hdd/fjc/VAD/avenue/training/frames/01/0003.jpg' 23 | # img2='/hdd/fjc/VAD/avenue/training/frames/01/0004.jpg' 24 | 25 | model_path='/home/fjc/FlowNet2-SD_checkpoint.pth.tar' 26 | flowSD_model_path='/home/fjc/FlowNetSD/flownet-SD.pth' 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES']='3' 29 | 30 | def test(): 31 | jpg1=imread(img1) 32 | jpg2=imread(img2) 33 | #[3,256,256] 34 | jpg1=imresize(jpg1,(384,512)) 35 | jpg2=imresize(jpg2,(384,512)) 36 | 37 | # jpg1=np.expand_dims(jpg1,1) 38 | # jpg2=np.expand_dims(jpg2,1) 39 | # 40 | # #[3,2,256,256] 41 | # images=np.concatenate([jpg1,jpg2],1) 42 | images=np.array([jpg1,jpg2],np.float32) 43 | #[2,256,256,3] 44 | # images=[jpg1,jpg2] 45 | # 46 | # images=np.array(images) 47 | 48 | images=np.transpose(images,[3,0,1,2]) 49 | 50 | 51 | images=torch.FloatTensor(images).cuda() 52 | 53 | # flownet=FlowNetSD.FlowNetSD(False).cuda().eval() 54 | # 55 | # flownet.load_state_dict(torch.load(flowSD_model_path)) 56 | 57 | flownet=FlowNet2SD().cuda().eval() 58 | flownet.load_state_dict(torch.load(model_path)['state_dict']) 59 | 60 | flow=get_flow.get_flow(flownet,images) 61 | flow=flow.cpu().detach().numpy() 62 | vis_mv(flow[0].transpose([1,2,0])) 63 | print(flow[0].transpose(1,2,0)) 64 | 65 | def test_flownet(): 66 | flownetSD=FlowNetSD.FlowNetSD().cuda().eval() 67 | a=flownetSD.state_dict() 68 | print(a) 69 | 70 | if __name__=='__main__': 71 | test() 72 | 73 | 74 | -------------------------------------------------------------------------------- /ano_pre/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | class Flow_Loss(nn.Module): 5 | def __init__(self): 6 | super(Flow_Loss,self).__init__() 7 | 8 | def forward(self, gen_flows,gt_flows): 9 | 10 | return torch.mean(torch.abs(gen_flows - gt_flows)) 11 | 12 | class Intensity_Loss(nn.Module): 13 | def __init__(self,l_num): 14 | super(Intensity_Loss,self).__init__() 15 | self.l_num=l_num 16 | def forward(self, gen_frames,gt_frames): 17 | 18 | return torch.mean(torch.abs((gen_frames-gt_frames)**self.l_num)) 19 | 20 | class Gradient_Loss(nn.Module): 21 | def __init__(self,alpha,channels): 22 | super(Gradient_Loss,self).__init__() 23 | self.alpha=alpha 24 | filter=torch.FloatTensor([[-1.,1.]]).cuda() 25 | 26 | self.filter_x = filter.view(1,1,1,2).repeat(1,channels,1,1) 27 | self.filter_y = filter.view(1,1,2,1).repeat(1,channels,1,1) 28 | 29 | 30 | def forward(self, gen_frames,gt_frames): 31 | 32 | 33 | # pos=torch.from_numpy(np.identity(channels,dtype=np.float32)) 34 | # neg=-1*pos 35 | # filter_x=torch.cat([neg,pos]).view(1,pos.shape[0],-1) 36 | # filter_y=torch.cat([pos.view(1,pos.shape[0],-1),neg.vew(1,neg.shape[0],-1)]) 37 | gen_frames_x=nn.functional.pad(gen_frames,(1,0,0,0)) 38 | gen_frames_y=nn.functional.pad(gen_frames,(0,0,1,0)) 39 | gt_frames_x=nn.functional.pad(gt_frames,(1,0,0,0)) 40 | gt_frames_y=nn.functional.pad(gt_frames,(0,0,1,0)) 41 | 42 | gen_dx=nn.functional.conv2d(gen_frames_x,self.filter_x) 43 | gen_dy=nn.functional.conv2d(gen_frames_y,self.filter_y) 44 | gt_dx=nn.functional.conv2d(gt_frames_x,self.filter_x) 45 | gt_dy=nn.functional.conv2d(gt_frames_y,self.filter_y) 46 | 47 | grad_diff_x=torch.abs(gt_dx-gen_dx) 48 | 
grad_diff_y=torch.abs(gt_dy-gen_dy) 49 | 50 | return torch.mean(grad_diff_x**self.alpha+grad_diff_y**self.alpha) 51 | 52 | class Adversarial_Loss(nn.Module): 53 | def __init__(self): 54 | super(Adversarial_Loss,self).__init__() 55 | def forward(self, fake_outputs): 56 | return torch.mean((fake_outputs-1)**2/2) 57 | class Discriminate_Loss(nn.Module): 58 | def __init__(self): 59 | super(Discriminate_Loss,self).__init__() 60 | def forward(self,real_outputs,fake_outputs ): 61 | return torch.mean((real_outputs-1)**2/2)+torch.mean(fake_outputs**2/2) 62 | -------------------------------------------------------------------------------- /ano_pre/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Dataset import img_dataset 3 | from models.unet import UNet 4 | import sys 5 | sys.path.append('..') 6 | from torch.utils.data import DataLoader 7 | from losses import * 8 | import numpy as np 9 | from util import psnr_error 10 | 11 | import os 12 | import time 13 | import pickle 14 | import eval_metric 15 | 16 | training_data_folder='your_path' 17 | testing_data_folder='your_path' 18 | 19 | dataset_name='avenue' 20 | 21 | psnr_dir='../psnr/' 22 | 23 | def evaluate(frame_num, layer_nums, input_channels, output_channels,model_path,evaluate_name,bn=False): 24 | ''' 25 | 26 | :param frame_num: 27 | :param layer_nums: 28 | :param input_channels: 29 | :param output_channels: 30 | :param model_path: 31 | :param evaluate_name: compute_auc 32 | :param bn: 33 | :return: 34 | ''' 35 | generator = UNet(n_channels=input_channels, layer_nums=layer_nums, output_channel=output_channels, 36 | bn=bn).cuda().eval() 37 | 38 | video_dirs = os.listdir(testing_data_folder) 39 | video_dirs.sort() 40 | 41 | num_videos = len(video_dirs) 42 | time_stamp = time.time() 43 | 44 | psnr_records=[] 45 | 46 | 47 | total = 0 48 | generator.load_state_dict(torch.load(model_path)) 49 | 50 | for dir in video_dirs: 51 | _temp_test_folder = os.path.join(testing_data_folder, dir) 52 | dataset = img_dataset.test_dataset(_temp_test_folder, clip_length=frame_num) 53 | 54 | len_dataset = dataset.pics_len 55 | test_iters = len_dataset - frame_num + 1 56 | test_counter = 0 57 | 58 | data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1) 59 | 60 | psnrs = np.empty(shape=(len_dataset,),dtype=np.float32) 61 | for test_input, _ in data_loader: 62 | test_target = test_input[:, -1].cuda() 63 | test_input = test_input[:, :-1].view(test_input.shape[0], -1, test_input.shape[-2], 64 | test_input.shape[-1]).cuda() 65 | 66 | g_output = generator(test_input) 67 | test_psnr = psnr_error(g_output, test_target) 68 | test_psnr = test_psnr.tolist() 69 | psnrs[test_counter+frame_num-1]=test_psnr 70 | 71 | test_counter += 1 72 | total+=1 73 | if test_counter >= test_iters: 74 | psnrs[:frame_num-1]=psnrs[frame_num-1] 75 | psnr_records.append(psnrs) 76 | print('finish test video set {}'.format(_temp_test_folder)) 77 | break 78 | 79 | result_dict = {'dataset': dataset_name, 'psnr': psnr_records, 'flow': [], 'names': [], 'diff_mask': []} 80 | 81 | used_time = time.time() - time_stamp 82 | print('total time = {}, fps = {}'.format(used_time, total / used_time)) 83 | 84 | pickle_path = os.path.join(psnr_dir, os.path.split(model_path)[-1]) 85 | with open(pickle_path, 'wb') as writer: 86 | pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL) 87 | 88 | results = eval_metric.evaluate(evaluate_name, pickle_path) 89 | print(results) 90 | 91 | 92 | if __name__ =='__main__': 93 | 
evaluate(frame_num=5,layer_nums=4,input_channels=12,output_channels=3,model_path='../pth_model/ano_pred_avenue_generator.pth-9000',evaluate_name='compute_auc') 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_ano_pre 2 | PyTorch re-implementation of ano_pre_cvpr2018, replacing FlowNet2 with LiteFlowNet 3 | 4 | ![img](https://github.com/StevenLiuWen/ano_pred_cvpr2018/blob/master/assets/architecture.JPG) 5 | 6 | [Future Frame Prediction for Anomaly Detection -- A New Baseline, CVPR 2018](https://arxiv.org/pdf/1712.09867.pdf) 7 | 8 | [tensorflow_official_implementation](https://github.com/StevenLiuWen/ano_pred_cvpr2018) 9 | 10 | 11 | **This repo modifies the normalization of the regular score and replaces FlowNetSD with LiteFlowNet.** 12 | AUC 85.6% ± 0.1% on the Avenue dataset 13 | 14 | **You can now use FlowNet2SD; modify the code in train.py as the comments describe.** 15 | 16 | ![img](https://github.com/fjchange/pytorch_ano_pre/blob/master/Assests/1554797352491.png) 17 | 18 | ## 1. requirements 19 | - pytorch >=0.4.1 20 | - tensorboardX (optional) 21 | 22 | ## 2. preparation 23 | 1. Download the CUHK Avenue dataset [download_link](https://onedrive.live.com/?authkey=%21AMqh2fTSemfrokE&id=3705E349C336415F%215109&cid=3705E349C336415F), unzip it wherever you like, and update the path in **train.py** 24 | 25 | 2. Download the LiteFlowNet model and update the path in **train.py** 26 | > wget --timestamping http://content.sniklaus.com/github/pytorch-liteflownet/network-sintel.pytorch 27 | 28 | **The quality of the optical flow matters; results improve if you fine-tune LiteFlowNet on the FlyingChairsSDHom dataset.** 29 | 30 | If you want to use FlowNet2SD, download the model from NVIDIA/flownet2-pytorch and update the path in **train.py**. 31 | > [Flownet2SD](https://drive.google.com/file/d/1QW03eyYG_vD-dT-Mx4wopYvtPu_msTKn/view?usp=sharin) 32 | 33 | 3. Set the model output path and log output path to wherever you want in **train.py**. 34 | 35 | ## 3. training 36 | 37 | > cd ano_pre 38 | 39 | > python train.py 40 | 41 | ## 4. evaluate 42 | Set model_path and evaluate_name as desired. 43 | 44 | > cd ano_pre 45 | 46 | > python evaluate.py 47 | 48 | ![img](https://github.com/fjchange/pytorch_ano_pre/blob/master/Assests/image.png) 49 | 50 | ## 5. reference 51 | 52 | If you find this useful, please cite the work as follows: 53 | 54 | ```code 55 | [1] @INPROCEEDINGS{liu2018ano_pred, 56 | author={W. Liu and W. Luo and D. Lian and S.
Gao}, 57 | booktitle={2018 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 58 | title={Future Frame Prediction for Anomaly Detection -- A New Baseline}, 59 | year={2018} 60 | } 61 | [2] misc{pytorch_ano_pred, 62 | author = {Jia-Chang Feng}, 63 | title = { A Reimplementation of {Ano_pred} Using {Pytorch}}, 64 | year = {2019}, 65 | howpublished = {\url{https://github.com/fjchange/pytorch_ano_pre}} 66 | } 67 | [3] @inproceedings{Hui_CVPR_2018, 68 | author = {Tak-Wai Hui and Xiaoou Tang and Chen Change Loy}, 69 | title = {{LiteFlowNet}: A Lightweight Convolutional Neural Network for Optical Flow Estimation}, 70 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, 71 | year = {2018} 72 | } 73 | [4] @misc{pytorch-liteflownet, 74 | author = {Simon Niklaus}, 75 | title = {A Reimplementation of {LiteFlowNet} Using {PyTorch}}, 76 | year = {2019}, 77 | howpublished = {\url{https://github.com/sniklaus/pytorch-liteflownet}} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /Dataset/img_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import transforms 4 | import numpy as np 5 | import cv2 6 | from collections import OrderedDict 7 | import glob 8 | import os 9 | 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | from torchvision.transforms import ToTensor 13 | rng = np.random.RandomState(2017) 14 | 15 | def np_load_frame(filename, resize_height, resize_width): 16 | """ 17 | Load image path and convert it to numpy.ndarray. Notes that the color channels are BGR and the color space 18 | is normalized from [0, 255] to [0, 1]. 19 | :param filename: the full path of image 20 | :param resize_height: resized height 21 | :param resize_width: resized width 22 | :return: numpy.ndarray 23 | """ 24 | image_decoded = cv2.imread(filename) 25 | image_resized = cv2.resize(image_decoded, (resize_width, resize_height)) 26 | image_resized = image_resized.astype(dtype=np.float32) 27 | image_resized = (image_resized )/255.0 28 | image_resized=np.transpose(image_resized,[2,0,1]) 29 | return image_resized 30 | 31 | class ano_pred_Dataset(Dataset): 32 | ''' 33 | specialized for ano pred model 34 | VAD dataset could not do any data augmentation 35 | normalized from [0,255] to [0,1] 36 | the channels are bgr( because of cv2 and liteFlownet 37 | ''' 38 | #video clip mean 39 | def __init__(self,dataset_folder,clip_length,size=(256,256)): 40 | self.dir=dataset_folder 41 | self.videos=OrderedDict() 42 | self.image_height=size[0] 43 | self.image_width=size[1] 44 | self.clip_length=clip_length 45 | self.setup() 46 | 47 | def __len__(self): 48 | return self.videos.__len__() 49 | 50 | def setup(self): 51 | videos = glob.glob(os.path.join(self.dir, '*')) 52 | for video in sorted(videos): 53 | video_name = video.split('/')[-1] 54 | self.videos[video_name] = {} 55 | self.videos[video_name]['path'] = video 56 | self.videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg')) 57 | self.videos[video_name]['frame'].sort() 58 | self.videos[video_name]['length'] = len(self.videos[video_name]['frame']) 59 | self.videos_keys=self.videos.keys() 60 | 61 | def __getitem__(self, indice): 62 | #each video get 4 frames as input and 1 frames as target output 63 | key=list(self.videos_keys)[indice] 64 | start = rng.randint(0, self.videos[key]['length'] - self.clip_length) 65 | video_clip=[] 66 | 67 | for 
frame_id in range(start,start+self.clip_length): 68 | #video_clip.append(frame_id) 69 | video_clip.append(np_load_frame(self.videos[key]['frame'][frame_id], self.image_height, self.image_width)) 70 | #video_clip=to_tensor(video_clip) 71 | video_clip=np.array(video_clip) 72 | video_clip=torch.from_numpy(video_clip) 73 | return video_clip 74 | #return video_clip,0 75 | 76 | class test_dataset(Dataset): 77 | # if use have to be very carefully 78 | # not cross the boundary 79 | 80 | def __init__(self,video_folder,clip_length,size=(256,256)): 81 | self.path=video_folder 82 | self.clip_length=clip_length 83 | self.img_height,self.img_width=size 84 | self.setup() 85 | 86 | def setup(self): 87 | self.pics=glob.glob(os.path.join(self.path,'*')) 88 | self.pics.sort() 89 | self.pics_len=len(self.pics) 90 | 91 | def __len__(self): 92 | return self.pics_len 93 | 94 | def __getitem__(self, indice): 95 | pic_clips=[] 96 | for frame_id in range(indice,indice+self.clip_length): 97 | pic_clips.append(np_load_frame(self.pics[frame_id],self.img_height,self.img_width)) 98 | pic_clips=np.array(pic_clips) 99 | pic_clips=torch.from_numpy(pic_clips) 100 | return pic_clips 101 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | 5 | import torch.nn.functional as F 6 | 7 | 8 | class double_conv(nn.Module): 9 | '''(conv => BN => ReLU) * 2''' 10 | 11 | def __init__(self, in_ch, out_ch,bn=False): 12 | super(double_conv, self).__init__() 13 | # self.bn=bn 14 | # if bn: 15 | self.conv= nn.Sequential( 16 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 17 | nn.BatchNorm2d(out_ch), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 20 | nn.BatchNorm2d(out_ch), 21 | nn.ReLU(inplace=True), 22 | ) 23 | # else: 24 | # self.conv = nn.Sequential( 25 | # nn.Conv2d(in_ch, out_ch, 3, padding=1), 26 | # nn.ReLU(inplace=True), 27 | # nn.Conv2d(out_ch, out_ch, 3, padding=1), 28 | # nn.ReLU(inplace=True), 29 | # ) 30 | 31 | def forward(self, x): 32 | x=self.conv(x) 33 | return x 34 | 35 | class inconv(nn.Module): 36 | def __init__(self, in_ch, out_ch,bn=False): 37 | super(inconv, self).__init__() 38 | self.conv = double_conv(in_ch, out_ch,bn) 39 | 40 | def forward(self, x): 41 | x = self.conv(x) 42 | return x 43 | 44 | 45 | class down(nn.Module): 46 | def __init__(self, in_ch, out_ch,bn=False): 47 | super(down, self).__init__() 48 | self.mpconv = nn.Sequential( 49 | nn.MaxPool2d(2), 50 | double_conv(in_ch, out_ch,bn), 51 | ) 52 | 53 | def forward(self, x): 54 | x = self.mpconv(x) 55 | return x 56 | 57 | 58 | class up(nn.Module): 59 | def __init__(self, in_ch, out_ch, bilinear=False,bn=False): 60 | super(up, self).__init__() 61 | self.bilinear=bilinear 62 | # would be a nice idea if the upsampling could be learned too, 63 | # but my machine do not have enough memory to handle all those weights 64 | if self.bilinear: 65 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 66 | nn.Conv2d(in_ch,in_ch//2,1),) 67 | 68 | else: 69 | self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2) 70 | 71 | self.conv = double_conv(in_ch, out_ch,bn) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | 76 | # input is CHW 77 | diffY = x2.size()[2] - x1.size()[2] 78 | diffX = x2.size()[3] - x1.size()[3] 79 | 80 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2)) 
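# (Illustrative note, not part of the original file.) After upsampling, x1 can be a pixel
# or two smaller than the skip-connection tensor x2 whenever the network input is not
# divisible by 2**depth; the pad above splits the difference between both sides (extra
# pixel on the right/bottom for odd differences) so the channel-wise torch.cat below sees
# matching spatial sizes. Example: x2 of size [B, 256, 65, 65] and upsampled x1 of size
# [B, 256, 64, 64] give diffY = diffX = 1, i.e. F.pad(x1, (0, 1, 0, 1)) -> [B, 256, 65, 65].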
81 | 82 | # for padding issues, see 83 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 84 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 85 | 86 | x = torch.cat([x2, x1], dim=1) 87 | x = self.conv(x) 88 | return x 89 | 90 | 91 | class outconv(nn.Module): 92 | def __init__(self, in_ch, out_ch): 93 | super(outconv, self).__init__() 94 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 95 | 96 | def forward(self, x): 97 | x = self.conv(x) 98 | return x 99 | 100 | class UNet(nn.Module): 101 | ''' 102 | layer_nums mean num of layers of half of the Unet 103 | and the features change with ratio of 2 104 | ''' 105 | def __init__(self,n_channels,layer_nums,features_root=64,output_channel=1,bn=False): 106 | super(UNet,self).__init__() 107 | self.inc = inconv(n_channels, 64,bn) 108 | self.down1 = down(64, 128,bn) 109 | self.down2 = down(128, 256,bn) 110 | self.down3 = down(256, 512,bn) 111 | self.up1 = up(512, 256) 112 | self.up2 = up(256, 128) 113 | self.up3 = up(128, 64) 114 | self.outc = outconv(64, output_channel) 115 | 116 | def forward(self, x): 117 | x1 = self.inc(x) 118 | x2 = self.down1(x1) 119 | x3 = self.down2(x2) 120 | x4 = self.down3(x3) 121 | x = self.up1(x4, x3) 122 | x = self.up2(x, x2) 123 | x = self.up3(x, x1) 124 | x = self.outc(x) 125 | 126 | return torch.sigmoid(x) 127 | 128 | ''' 129 | class SA_UNet(nn.Module): 130 | ''' 131 | def _test(): 132 | rand=torch.ones([4,12,256,256]).cuda() 133 | t=UNet(12,0,3).cuda() 134 | 135 | r=t(rand) 136 | print(r.grad_fn) 137 | print(r.requires_grad) 138 | 139 | if __name__=='__main__': 140 | _test() 141 | -------------------------------------------------------------------------------- /liteFlownet/flow_vis.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-08-03 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | 22 | def make_colorwheel(): 23 | ''' 24 | Generates a color wheel for optical flow visualization as presented in: 25 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 26 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 27 | 28 | According to the C++ source code of Daniel Scharstein 29 | According to the Matlab source code of Deqing Sun 30 | ''' 31 | 32 | RY = 15 33 | YG = 6 34 | GC = 4 35 | CB = 11 36 | BM = 13 37 | MR = 6 38 | 39 | ncols = RY + YG + GC + CB + BM + MR 40 | colorwheel = np.zeros((ncols, 3)) 41 | col = 0 42 | 43 | # RY 44 | colorwheel[0:RY, 0] = 255 45 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 46 | col = col+RY 47 | # YG 48 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 49 | colorwheel[col:col+YG, 1] = 255 50 | col = col+YG 51 | # GC 52 | colorwheel[col:col+GC, 1] = 255 53 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 54 | col = col+GC 55 | # CB 56 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 57 | colorwheel[col:col+CB, 2] = 255 58 | col = col+CB 59 | # BM 60 | colorwheel[col:col+BM, 2] = 255 61 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 62 | col = col+BM 63 | # MR 64 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 65 | colorwheel[col:col+MR, 0] = 255 66 | return colorwheel 67 | 68 | 69 | def flow_compute_color(u, v, convert_to_bgr=False): 70 | ''' 71 | Applies the flow color wheel to (possibly clipped) flow components u and v. 72 | 73 | According to the C++ source code of Daniel Scharstein 74 | According to the Matlab source code of Deqing Sun 75 | 76 | :param u: np.ndarray, input horizontal flow 77 | :param v: np.ndarray, input vertical flow 78 | :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB 79 | :return: 80 | ''' 81 | 82 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 83 | 84 | colorwheel = make_colorwheel() # shape [55x3] 85 | ncols = colorwheel.shape[0] 86 | 87 | rad = np.sqrt(np.square(u) + np.square(v)) 88 | a = np.arctan2(-v, -u)/np.pi 89 | 90 | fk = (a+1) / 2*(ncols-1) + 1 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 1 94 | f = fk - k0 95 | 96 | for i in range(colorwheel.shape[1]): 97 | 98 | tmp = colorwheel[:,i] 99 | col0 = tmp[k0] / 255.0 100 | col1 = tmp[k1] / 255.0 101 | col = (1-f)*col0 + f*col1 102 | 103 | idx = (rad <= 1) 104 | col[idx] = 1 - rad[idx] * (1-col[idx]) 105 | col[~idx] = col[~idx] * 0.75 # out of range? 
106 | 107 | # Note the 2-i => BGR instead of RGB 108 | ch_idx = 2-i if convert_to_bgr else i 109 | flow_image[:,:,ch_idx] = np.floor(255 * col) 110 | 111 | return flow_image 112 | 113 | 114 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 115 | ''' 116 | Expects a two dimensional flow image of shape [H,W,2] 117 | 118 | According to the C++ source code of Daniel Scharstein 119 | According to the Matlab source code of Deqing Sun 120 | 121 | :param flow_uv: np.ndarray of shape [H,W,2] 122 | :param clip_flow: float, maximum clipping value for flow 123 | :return: 124 | ''' 125 | 126 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 127 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 128 | 129 | if clip_flow is not None: 130 | flow_uv = np.clip(flow_uv, 0, clip_flow) 131 | 132 | u = flow_uv[:,:,0] 133 | v = flow_uv[:,:,1] 134 | 135 | rad = np.sqrt(np.square(u) + np.square(v)) 136 | rad_max = np.max(rad) 137 | 138 | epsilon = 1e-5 139 | u = u / (rad_max + epsilon) 140 | v = v / (rad_max + epsilon) 141 | 142 | return flow_compute_color(u, v, convert_to_bgr) 143 | 144 | 145 | def visualize(flo_path): 146 | flow_size=np.fromfile(flo_path,dtype=np.int32)[1:3] 147 | flow_uv=np.reshape(np.fromfile(flo_path,dtype=np.float32)[3:],[flow_size[0],flow_size[1],2]) 148 | flow_color=flow_to_color(flow_uv,convert_to_bgr=False) 149 | plt.imshow(flow_color) 150 | plt.show() 151 | 152 | -------------------------------------------------------------------------------- /flownet2/convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | 3 | import caffe 4 | from caffe.proto import caffe_pb2 5 | import sys, os 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import argparse, tempfile 11 | import numpy as np 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('caffe_model', help='input model in hdf5 or caffemodel format') 15 | parser.add_argument('prototxt_template',help='prototxt template') 16 | parser.add_argument('flownet2_pytorch', help='path to flownet2-pytorch') 17 | 18 | args = parser.parse_args() 19 | 20 | args.rgb_max = 255 21 | args.fp16 = False 22 | args.grads = {} 23 | 24 | # load models 25 | sys.path.append(args.flownet2_pytorch) 26 | 27 | import models 28 | from utils.param_utils import * 29 | 30 | width = 256 31 | height = 256 32 | keys = {'TARGET_WIDTH': width, 33 | 'TARGET_HEIGHT': height, 34 | 'ADAPTED_WIDTH':width, 35 | 'ADAPTED_HEIGHT':height, 36 | 'SCALE_WIDTH':1., 37 | 'SCALE_HEIGHT':1.,} 38 | 39 | template = '\n'.join(np.loadtxt(args.prototxt_template, dtype=str, delimiter='\n')) 40 | for k in keys: 41 | template = template.replace('$%s$'%(k),str(keys[k])) 42 | 43 | prototxt = tempfile.NamedTemporaryFile(mode='w', delete=True) 44 | prototxt.write(template) 45 | prototxt.flush() 46 | 47 | net = caffe.Net(prototxt.name, args.caffe_model, caffe.TEST) 48 | 49 | weights = {} 50 | biases = {} 51 | 52 | for k, v in list(net.params.items()): 53 | weights[k] = np.array(v[0].data).reshape(v[0].data.shape) 54 | biases[k] = np.array(v[1].data).reshape(v[1].data.shape) 55 | print((k, weights[k].shape, biases[k].shape)) 56 | 57 | if 'FlowNet2/' in args.caffe_model: 58 | model = models.FlowNet2(args) 59 | 60 | parse_flownetc(model.flownetc.modules(), weights, biases) 61 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 62 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 63 | 
parse_flownetsd(model.flownets_d.modules(), weights, biases, param_prefix='netsd_') 64 | parse_flownetfusion(model.flownetfusion.modules(), weights, biases, param_prefix='fuse_') 65 | 66 | state = {'epoch': 0, 67 | 'state_dict': model.state_dict(), 68 | 'best_EPE': 1e10} 69 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2_checkpoint.pth.tar')) 70 | 71 | elif 'FlowNet2-C/' in args.caffe_model: 72 | model = models.FlowNet2C(args) 73 | 74 | parse_flownetc(model.modules(), weights, biases) 75 | state = {'epoch': 0, 76 | 'state_dict': model.state_dict(), 77 | 'best_EPE': 1e10} 78 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-C_checkpoint.pth.tar')) 79 | 80 | elif 'FlowNet2-CS/' in args.caffe_model: 81 | model = models.FlowNet2CS(args) 82 | 83 | parse_flownetc(model.flownetc.modules(), weights, biases) 84 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 85 | 86 | state = {'epoch': 0, 87 | 'state_dict': model.state_dict(), 88 | 'best_EPE': 1e10} 89 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CS_checkpoint.pth.tar')) 90 | 91 | elif 'FlowNet2-CSS/' in args.caffe_model: 92 | model = models.FlowNet2CSS(args) 93 | 94 | parse_flownetc(model.flownetc.modules(), weights, biases) 95 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 96 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 97 | 98 | state = {'epoch': 0, 99 | 'state_dict': model.state_dict(), 100 | 'best_EPE': 1e10} 101 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS_checkpoint.pth.tar')) 102 | 103 | elif 'FlowNet2-CSS-ft-sd/' in args.caffe_model: 104 | model = models.FlowNet2CSS(args) 105 | 106 | parse_flownetc(model.flownetc.modules(), weights, biases) 107 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 108 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 109 | 110 | state = {'epoch': 0, 111 | 'state_dict': model.state_dict(), 112 | 'best_EPE': 1e10} 113 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS-ft-sd_checkpoint.pth.tar')) 114 | 115 | elif 'FlowNet2-S/' in args.caffe_model: 116 | model = models.FlowNet2S(args) 117 | 118 | parse_flownetsonly(model.modules(), weights, biases, param_prefix='') 119 | state = {'epoch': 0, 120 | 'state_dict': model.state_dict(), 121 | 'best_EPE': 1e10} 122 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-S_checkpoint.pth.tar')) 123 | 124 | elif 'FlowNet2-SD/' in args.caffe_model: 125 | model = models.FlowNet2SD(args) 126 | 127 | parse_flownetsd(model.modules(), weights, biases, param_prefix='') 128 | 129 | state = {'epoch': 0, 130 | 'state_dict': model.state_dict(), 131 | 'best_EPE': 1e10} 132 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-SD_checkpoint.pth.tar')) 133 | 134 | else: 135 | print(('model type cound not be determined from input caffe model %s'%(args.caffe_model))) 136 | quit() 137 | print(("done converting ", args.caffe_model)) -------------------------------------------------------------------------------- /flownet2/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | # from networks.resample2d_package.resample2d import Resample2d 9 | # from networks.channelnorm_package.channelnorm import ChannelNorm 10 | 11 | # from networks 
import FlowNetC 12 | # from networks import FlowNetS 13 | from flownet2.networks import FlowNetSD 14 | # from networks import FlowNetFusion 15 | 16 | from flownet2.networks.submodules import * 17 | 18 | 'Parameter count = 162,518,834' 19 | 20 | # 21 | 22 | class FlowNet2SD(FlowNetSD.FlowNetSD): 23 | def __init__(self,batchNorm=False, div_flow=20): 24 | super(FlowNet2SD, self).__init__( batchNorm=batchNorm) 25 | self.rgb_max = 255. 26 | self.div_flow = div_flow 27 | 28 | def forward(self, inputs): 29 | rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) 30 | x = (inputs - rgb_mean) / self.rgb_max 31 | x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1) 32 | 33 | out_conv0 = self.conv0(x) 34 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 35 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 36 | 37 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 38 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 39 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 40 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 41 | 42 | flow6 = self.predict_flow6(out_conv6) 43 | flow6_up = self.upsampled_flow6_to_5(flow6) 44 | out_deconv5 = self.deconv5(out_conv6) 45 | 46 | concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) 47 | out_interconv5 = self.inter_conv5(concat5) 48 | flow5 = self.predict_flow5(out_interconv5) 49 | 50 | flow5_up = self.upsampled_flow5_to_4(flow5) 51 | out_deconv4 = self.deconv4(concat5) 52 | 53 | concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) 54 | out_interconv4 = self.inter_conv4(concat4) 55 | flow4 = self.predict_flow4(out_interconv4) 56 | flow4_up = self.upsampled_flow4_to_3(flow4) 57 | out_deconv3 = self.deconv3(concat4) 58 | 59 | concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) 60 | out_interconv3 = self.inter_conv3(concat3) 61 | flow3 = self.predict_flow3(out_interconv3) 62 | flow3_up = self.upsampled_flow3_to_2(flow3) 63 | out_deconv2 = self.deconv2(concat3) 64 | 65 | concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) 66 | out_interconv2 = self.inter_conv2(concat2) 67 | flow2 = self.predict_flow2(out_interconv2) 68 | 69 | if self.training: 70 | return flow2, flow3, flow4, flow5, flow6 71 | else: 72 | return self.upsample1(flow2 * self.div_flow) 73 | 74 | 75 | # class FlowNet2CS(nn.Module): 76 | # 77 | # def __init__(self, args, batchNorm=False, div_flow=20.): 78 | # super(FlowNet2CS, self).__init__() 79 | # self.batchNorm = batchNorm 80 | # self.div_flow = div_flow 81 | # self.rgb_max = args.rgb_max 82 | # self.args = args 83 | # 84 | # self.channelnorm = ChannelNorm() 85 | # 86 | # # First Block (FlowNetC) 87 | # self.flownetc = FlowNetC.FlowNetC(args, batchNorm=self.batchNorm) 88 | # self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 89 | # 90 | # if args.fp16: 91 | # self.resample1 = nn.Sequential( 92 | # tofp32(), 93 | # Resample2d(), 94 | # tofp16()) 95 | # else: 96 | # self.resample1 = Resample2d() 97 | # 98 | # # Block (FlowNetS1) 99 | # self.flownets_1 = FlowNetS.FlowNetS(args, batchNorm=self.batchNorm) 100 | # self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear') 101 | # 102 | # for m in self.modules(): 103 | # if isinstance(m, nn.Conv2d): 104 | # if m.bias is not None: 105 | # init.uniform(m.bias) 106 | # init.xavier_uniform(m.weight) 107 | # 108 | # if isinstance(m, nn.ConvTranspose2d): 109 | # if m.bias is not None: 110 | # init.uniform(m.bias) 111 | # init.xavier_uniform(m.weight) 112 | # # init_deconv_bilinear(m.weight) 113 | # 114 | # def 
forward(self, inputs): 115 | # rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) 116 | # 117 | # x = (inputs - rgb_mean) / self.rgb_max 118 | # x1 = x[:, :, 0, :, :] 119 | # x2 = x[:, :, 1, :, :] 120 | # x = torch.cat((x1, x2), dim=1) 121 | # 122 | # # flownetc 123 | # flownetc_flow2 = self.flownetc(x)[0] 124 | # flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow) 125 | # 126 | # # warp img1 to img0; magnitude of diff between img0 and and warped_img1, 127 | # resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow) 128 | # diff_img0 = x[:, :3, :, :] - resampled_img1 129 | # norm_diff_img0 = self.channelnorm(diff_img0) 130 | # 131 | # # concat img0, img1, img1->img0, flow, diff-mag ; 132 | # concat1 = torch.cat((x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0), dim=1) 133 | # 134 | # # flownets1 135 | # flownets1_flow2 = self.flownets_1(concat1)[0] 136 | # flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow) 137 | # 138 | # return flownets1_flow 139 | # 140 | # 141 | # class FlowNet2CSS(nn.Module): 142 | # 143 | # def __init__(self, args, batchNorm=False, div_flow=20.): 144 | # super(FlowNet2CSS, self).__init__() 145 | # self.batchNorm = batchNorm 146 | # self.div_flow = div_flow 147 | # self.rgb_max = args.rgb_max 148 | # self.args = args 149 | # 150 | # self.channelnorm = ChannelNorm() 151 | # 152 | # # First Block (FlowNetC) 153 | # self.flownetc = FlowNetC.FlowNetC(args, batchNorm=self.batchNorm) 154 | # self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 155 | # 156 | # if args.fp16: 157 | # self.resample1 = nn.Sequential( 158 | # tofp32(), 159 | # Resample2d(), 160 | # tofp16()) 161 | # else: 162 | # self.resample1 = Resample2d() 163 | # 164 | # # Block (FlowNetS1) 165 | # self.flownets_1 = FlowNetS.FlowNetS(args, batchNorm=self.batchNorm) 166 | # self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear') 167 | # if args.fp16: 168 | # self.resample2 = nn.Sequential( 169 | # tofp32(), 170 | # Resample2d(), 171 | # tofp16()) 172 | # else: 173 | # self.resample2 = Resample2d() 174 | # 175 | # # Block (FlowNetS2) 176 | # self.flownets_2 = FlowNetS.FlowNetS(args, batchNorm=self.batchNorm) 177 | # self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') 178 | # 179 | # for m in self.modules(): 180 | # if isinstance(m, nn.Conv2d): 181 | # if m.bias is not None: 182 | # init.uniform(m.bias) 183 | # init.xavier_uniform(m.weight) 184 | # 185 | # if isinstance(m, nn.ConvTranspose2d): 186 | # if m.bias is not None: 187 | # init.uniform(m.bias) 188 | # init.xavier_uniform(m.weight) 189 | # # init_deconv_bilinear(m.weight) 190 | # 191 | # def forward(self, inputs): 192 | # rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) 193 | # 194 | # x = (inputs - rgb_mean) / self.rgb_max 195 | # x1 = x[:, :, 0, :, :] 196 | # x2 = x[:, :, 1, :, :] 197 | # x = torch.cat((x1, x2), dim=1) 198 | # 199 | # # flownetc 200 | # flownetc_flow2 = self.flownetc(x)[0] 201 | # flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow) 202 | # 203 | # # warp img1 to img0; magnitude of diff between img0 and and warped_img1, 204 | # resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow) 205 | # diff_img0 = x[:, :3, :, :] - resampled_img1 206 | # norm_diff_img0 = self.channelnorm(diff_img0) 207 | # 208 | # # concat img0, img1, img1->img0, flow, diff-mag ; 209 | # concat1 = torch.cat((x, resampled_img1, flownetc_flow / 
self.div_flow, norm_diff_img0), dim=1) 210 | # 211 | # # flownets1 212 | # flownets1_flow2 = self.flownets_1(concat1)[0] 213 | # flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow) 214 | # 215 | # # warp img1 to img0 using flownets1; magnitude of diff between img0 and and warped_img1 216 | # resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow) 217 | # diff_img0 = x[:, :3, :, :] - resampled_img1 218 | # norm_diff_img0 = self.channelnorm(diff_img0) 219 | # 220 | # # concat img0, img1, img1->img0, flow, diff-mag 221 | # concat2 = torch.cat((x, resampled_img1, flownets1_flow / self.div_flow, norm_diff_img0), dim=1) 222 | # 223 | # # flownets2 224 | # flownets2_flow2 = self.flownets_2(concat2)[0] 225 | # flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow) 226 | # 227 | # return flownets2_flow 228 | # 229 | -------------------------------------------------------------------------------- /ano_pre/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import sys 4 | 5 | from util import psnr_error 6 | from losses import * 7 | 8 | sys.path.append('..') 9 | from Dataset import img_dataset 10 | from models.unet import UNet,_test 11 | from models.pix2pix_networks import PixelDiscriminator 12 | from liteFlownet.lite_flownet import Network,batch_estimate 13 | # if you want to use flownet2-SD 14 | # from flownet2.models import FlowNet2SD 15 | from torch.utils.data import DataLoader 16 | from torch.autograd import Variable 17 | from tensorboardX import SummaryWriter 18 | 19 | 20 | from utils import utils 21 | import os 22 | 23 | from evaluate import evaluate 24 | #your gpu id 25 | torch.cuda.set_device(2) 26 | 27 | 28 | training_data_folder='your_path' 29 | testing_data_folder='your_path' 30 | 31 | writer_path='../log/ano_pred_avenue_2' 32 | 33 | model_generator_save_path='../pth_model/ano_pred_avenue_generator_2.pth' 34 | model_discriminator_save_path='../pth_model/ano_pred_avenue_discriminator_2.pth' 35 | 36 | lite_flow_model_path='../liteFlownet/network-default.pytorch' 37 | # FlowNet2SD path 38 | # flownet2SD_model_path='Your path' 39 | 40 | #----------- all the param as ano pred said --------- 41 | batch_size=4 42 | epochs=20000 43 | pretrain=False 44 | 45 | # color dataset 46 | g_lr=0.0002 47 | d_lr=0.00002 48 | 49 | #different range with the source version, should change 50 | lam_int=1.0*2 51 | lam_gd=1.0*2 52 | # here we use no flow loss 53 | lam_op=0#2.0 54 | 55 | lam_adv=0.05 56 | 57 | #for gradient loss 58 | alpha=1 59 | #for int loss 60 | l_num=2 61 | 62 | num_clips=5 63 | num_his=1 64 | num_unet_layers=4 65 | 66 | num_channels=3#avenue is 3, UCSD is 1 67 | discriminator_channels=[128,256,512,512] 68 | 69 | def weights_init_normal(m): 70 | classname = m.__class__.__name__ 71 | if classname.find('Conv') != -1: 72 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 73 | elif classname.find('BatchNorm2d') != -1: 74 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 75 | torch.nn.init.constant_(m.bias.data, 0.0) 76 | 77 | def train(frame_num,layer_nums,input_channels,output_channels,discriminator_num_filters,bn=False,pretrain=False, 78 | generator_pretrain_path=None,discriminator_pretrain_path=None): 79 | generator=UNet(n_channels=input_channels,layer_nums=layer_nums,output_channel=output_channels,bn=bn) 80 | discriminator=PixelDiscriminator(output_channels,discriminator_num_filters,use_norm=False) 81 | 82 | generator = generator.cuda() 83 | discriminator = discriminator.cuda() 84 | 85 | 
flow_network=Network() 86 | flow_network.load_state_dict(torch.load(lite_flow_model_path)) 87 | flow_network.cuda().eval() 88 | # if you want to use flownet2SD, comment out the part in front 89 | # flow_network=FlowNet2SD().cuda().eval() 90 | # flow_network.load_state_dict(torch.load(flownet2SD_model_path)['state_dict']) 91 | 92 | adversarial_loss=Adversarial_Loss().cuda() 93 | discriminate_loss=Discriminate_Loss().cuda() 94 | gd_loss=Gradient_Loss(alpha,num_channels).cuda() 95 | op_loss=Flow_Loss().cuda() 96 | int_loss=Intensity_Loss(l_num).cuda() 97 | step = 0 98 | 99 | if not pretrain: 100 | generator.apply(weights_init_normal) 101 | discriminator.apply(weights_init_normal) 102 | else: 103 | assert (generator_pretrain_path!=None and discriminator_pretrain_path!=None) 104 | generator.load_state_dict(torch.load(generator_pretrain_path)) 105 | discriminator.load_state_dict(torch.load(discriminator_pretrain_path)) 106 | step=int(generator_pretrain_path.split('-')[-1]) 107 | print('pretrained model loaded!') 108 | 109 | print('initializing the model with Generator-Unet {} layers,' 110 | 'PixelDiscriminator with filters {} '.format(layer_nums,discriminator_num_filters)) 111 | 112 | optimizer_G=torch.optim.Adam(generator.parameters(),lr=g_lr) 113 | optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=d_lr) 114 | 115 | 116 | writer=SummaryWriter(writer_path) 117 | 118 | dataset=img_dataset.ano_pred_Dataset(training_data_folder,frame_num) 119 | dataset_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True) 120 | 121 | test_dataset=img_dataset.ano_pred_Dataset(testing_data_folder,frame_num) 122 | test_dataloader=DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,num_workers=1,drop_last=True) 123 | 124 | for epoch in range(epochs): 125 | for (input,test_input) in zip(dataset_loader,test_dataloader): 126 | # generator = generator.train() 127 | # discriminator = discriminator.train() 128 | 129 | target=input[:,-1,:,:,:].cuda() 130 | 131 | input=input[:,:-1,] 132 | input_last=input[:,-1,].cuda() 133 | input=input.view(input.shape[0],-1,input.shape[-2],input.shape[-1]).cuda() 134 | 135 | test_target=test_input[:,-1,].cuda() 136 | test_input=test_input[:,:-1].view(test_input.shape[0],-1,test_input.shape[-2],test_input.shape[-1]).cuda() 137 | 138 | #------- update optim_G -------------- 139 | 140 | G_output=generator(input) 141 | 142 | pred_flow_esti_tensor=torch.cat([input_last,G_output],1) 143 | gt_flow_esti_tensor=torch.cat([input_last,target],1) 144 | flow_gt=batch_estimate(gt_flow_esti_tensor,flow_network) 145 | flow_pred=batch_estimate(pred_flow_esti_tensor,flow_network) 146 | 147 | # if you want to use flownet2SD, comment out the part in front 148 | # pred_flow_esti_tensor = torch.cat([input_last.view(-1,3,1,test_input.shape[-2],test_input.shape[-1]), G_output.view(-1,3,1,test_input.shape[-2],test_input.shape[-1])], 2) 149 | # gt_flow_esti_tensor = torch.cat([input_last.view(-1,3,1,test_input.shape[-2],test_input.shape[-1]), target.view(-1,3,1,test_input.shape[-2],test_input.shape[-1])], 2) 150 | # 151 | # flow_gt=flow_network(gt_flow_esti_tensor*255.0) 152 | # flow_pred=flow_network(pred_flow_esti_tensor*255.0) 153 | 154 | g_adv_loss=adversarial_loss(discriminator(G_output)) 155 | g_op_loss=op_loss(flow_pred,flow_gt) 156 | g_int_loss=int_loss(G_output,target) 157 | g_gd_loss=gd_loss(G_output,target) 158 | 159 | g_loss=lam_adv*g_adv_loss+lam_gd*g_gd_loss+lam_op*g_op_loss+lam_int*g_int_loss 160 | 161 | optimizer_G.zero_grad() 
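# (Illustrative note, not part of the original file.) Generator update: the combined
# objective computed above,
#     g_loss = lam_adv * L_adv + lam_gd * L_gradient + lam_op * L_flow + lam_int * L_intensity,
# is back-propagated through the generator only; the discriminator is updated in a separate
# step further down, on the real targets and on G_output.detach(), so the two optimisers
# alternate as in a standard GAN training loop.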
162 | 163 | g_loss.backward() 164 | optimizer_G.step() 165 | 166 | train_psnr=psnr_error(G_output,target) 167 | 168 | #----------- update optim_D ------- 169 | optimizer_D.zero_grad() 170 | 171 | d_loss=discriminate_loss(discriminator(target),discriminator(G_output.detach())) 172 | #d_loss.requires_grad=True 173 | 174 | d_loss.backward() 175 | optimizer_D.step() 176 | 177 | #----------- cal psnr -------------- 178 | test_generator=generator.eval() 179 | test_output=test_generator(test_input) 180 | test_psnr=psnr_error(test_output,test_target).cuda() 181 | 182 | if step%10==0: 183 | print("[{}/{}]: g_loss: {} d_loss {}".format(step,epoch,g_loss,d_loss)) 184 | print('\t gd_loss {}, op_loss {}, int_loss {} ,'.format(g_gd_loss,g_op_loss,g_int_loss)) 185 | print('\t train psnr{},test_psnr {}'.format(train_psnr,test_psnr)) 186 | 187 | writer.add_scalar('psnr/train_psnr', train_psnr, global_step=step) 188 | writer.add_scalar('psnr/test_psnr', test_psnr, global_step=step) 189 | 190 | writer.add_scalar('total_loss/g_loss', g_loss, global_step=step) 191 | writer.add_scalar('total_loss/d_loss', d_loss, global_step=step) 192 | writer.add_scalar('g_loss/adv_loss', g_adv_loss, global_step=step) 193 | writer.add_scalar('g_loss/op_loss', g_op_loss, global_step=step) 194 | writer.add_scalar('g_loss/int_loss', g_int_loss, global_step=step) 195 | writer.add_scalar('g_loss/gd_loss', g_gd_loss, global_step=step) 196 | 197 | writer.add_image('image/train_target', target[0], global_step=step) 198 | writer.add_image('image/train_output', G_output[0], global_step=step) 199 | writer.add_image('image/test_target', test_target[0], global_step=step) 200 | writer.add_image('image/test_output', test_output[0], global_step=step) 201 | 202 | step+=1 203 | 204 | if step%500==0: 205 | utils.saver(generator.state_dict(),model_generator_save_path,step,max_to_save=10) 206 | utils.saver(discriminator.state_dict(),model_discriminator_save_path,step,max_to_save=10) 207 | if step>=2000: 208 | print('==== begin evaluate the model of {} ===='.format(model_generator_save_path+'-'+str(step))) 209 | 210 | auc=evaluate(frame_num=5,layer_nums=4,input_channels=12,output_channels=3, 211 | model_path=model_generator_save_path+'-'+str(step),evaluate_name='compute_auc') 212 | writer.add_scalar('results/auc', auc, global_step=step) 213 | 214 | 215 | if __name__=='__main__': 216 | train(num_clips,num_unet_layers,num_channels*(num_clips-num_his),num_channels,discriminator_channels) 217 | # pretrain=True, 218 | # generator_pretrain_path='../pth_model/ano_pred_avenue_generator_2.pth-4500', 219 | # discriminator_pretrain_path='../pth_model/ano_pred_avenue_discriminator_2.pth-4500') 220 | 221 | #test(num_clips,num_unet_layers,num_channels*(num_clips-num_his),num_channels,discriminator_channels) 222 | #_test() 223 | #test(0,0,0,0,0) 224 | 225 | -------------------------------------------------------------------------------- /flownet2/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | import os, math, random 5 | from os.path import * 6 | import numpy as np 7 | 8 | from glob import glob 9 | import utils.frame_utils as frame_utils 10 | 11 | from scipy.misc import imread, imresize 12 | 13 | class StaticRandomCrop(object): 14 | def __init__(self, image_size, crop_size): 15 | self.th, self.tw = crop_size 16 | h, w = image_size 17 | self.h1 = random.randint(0, h - self.th) 18 | self.w1 = random.randint(0, w - self.tw) 19 | 20 | def __call__(self, img): 21 
| return img[self.h1:(self.h1+self.th), self.w1:(self.w1+self.tw),:] 22 | 23 | class StaticCenterCrop(object): 24 | def __init__(self, image_size, crop_size): 25 | self.th, self.tw = crop_size 26 | self.h, self.w = image_size 27 | def __call__(self, img): 28 | return img[(self.h-self.th)//2:(self.h+self.th)//2, (self.w-self.tw)//2:(self.w+self.tw)//2,:] 29 | 30 | class MpiSintel(data.Dataset): 31 | def __init__(self, args, is_cropped = False, root = '', dstype = 'clean', replicates = 1): 32 | self.args = args 33 | self.is_cropped = is_cropped 34 | self.crop_size = args.crop_size 35 | self.render_size = args.inference_size 36 | self.replicates = replicates 37 | 38 | flow_root = join(root, 'flow') 39 | image_root = join(root, dstype) 40 | 41 | file_list = sorted(glob(join(flow_root, '*/*.flo'))) 42 | 43 | self.flow_list = [] 44 | self.image_list = [] 45 | 46 | for file in file_list: 47 | if 'test' in file: 48 | # print file 49 | continue 50 | 51 | fbase = file[len(flow_root)+1:] 52 | fprefix = fbase[:-8] 53 | fnum = int(fbase[-8:-4]) 54 | 55 | img1 = join(image_root, fprefix + "%04d"%(fnum+0) + '.png') 56 | img2 = join(image_root, fprefix + "%04d"%(fnum+1) + '.png') 57 | 58 | if not isfile(img1) or not isfile(img2) or not isfile(file): 59 | continue 60 | 61 | self.image_list += [[img1, img2]] 62 | self.flow_list += [file] 63 | 64 | self.size = len(self.image_list) 65 | 66 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 67 | 68 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 69 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 70 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 71 | 72 | args.inference_size = self.render_size 73 | 74 | assert (len(self.image_list) == len(self.flow_list)) 75 | 76 | def __getitem__(self, index): 77 | 78 | index = index % self.size 79 | 80 | img1 = frame_utils.read_gen(self.image_list[index][0]) 81 | img2 = frame_utils.read_gen(self.image_list[index][1]) 82 | 83 | flow = frame_utils.read_gen(self.flow_list[index]) 84 | 85 | images = [img1, img2] 86 | image_size = img1.shape[:2] 87 | 88 | if self.is_cropped: 89 | cropper = StaticRandomCrop(image_size, self.crop_size) 90 | else: 91 | cropper = StaticCenterCrop(image_size, self.render_size) 92 | images = list(map(cropper, images)) 93 | flow = cropper(flow) 94 | 95 | images = np.array(images).transpose(3,0,1,2) 96 | flow = flow.transpose(2,0,1) 97 | 98 | images = torch.from_numpy(images.astype(np.float32)) 99 | flow = torch.from_numpy(flow.astype(np.float32)) 100 | 101 | return [images], [flow] 102 | 103 | def __len__(self): 104 | return self.size * self.replicates 105 | 106 | class MpiSintelClean(MpiSintel): 107 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 108 | super(MpiSintelClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'clean', replicates = replicates) 109 | 110 | class MpiSintelFinal(MpiSintel): 111 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 112 | super(MpiSintelFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'final', replicates = replicates) 113 | 114 | class FlyingChairs(data.Dataset): 115 | def __init__(self, args, is_cropped, root = '/path/to/FlyingChairs_release/data', replicates = 1): 116 | self.args = args 117 | self.is_cropped = is_cropped 118 | self.crop_size = args.crop_size 119 | self.render_size = args.inference_size 120 | self.replicates = replicates 121 | 122 | images = sorted( 
glob( join(root, '*.ppm') ) ) 123 | 124 | self.flow_list = sorted( glob( join(root, '*.flo') ) ) 125 | 126 | assert (len(images)//2 == len(self.flow_list)) 127 | 128 | self.image_list = [] 129 | for i in range(len(self.flow_list)): 130 | im1 = images[2*i] 131 | im2 = images[2*i + 1] 132 | self.image_list += [ [ im1, im2 ] ] 133 | 134 | assert len(self.image_list) == len(self.flow_list) 135 | 136 | self.size = len(self.image_list) 137 | 138 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 139 | 140 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 141 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 142 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 143 | 144 | args.inference_size = self.render_size 145 | 146 | def __getitem__(self, index): 147 | index = index % self.size 148 | 149 | img1 = frame_utils.read_gen(self.image_list[index][0]) 150 | img2 = frame_utils.read_gen(self.image_list[index][1]) 151 | 152 | flow = frame_utils.read_gen(self.flow_list[index]) 153 | 154 | images = [img1, img2] 155 | image_size = img1.shape[:2] 156 | if self.is_cropped: 157 | cropper = StaticRandomCrop(image_size, self.crop_size) 158 | else: 159 | cropper = StaticCenterCrop(image_size, self.render_size) 160 | images = list(map(cropper, images)) 161 | flow = cropper(flow) 162 | 163 | 164 | images = np.array(images).transpose(3,0,1,2) 165 | flow = flow.transpose(2,0,1) 166 | 167 | images = torch.from_numpy(images.astype(np.float32)) 168 | flow = torch.from_numpy(flow.astype(np.float32)) 169 | 170 | return [images], [flow] 171 | 172 | def __len__(self): 173 | return self.size * self.replicates 174 | 175 | class FlyingThings(data.Dataset): 176 | def __init__(self, args, is_cropped, root = '/path/to/flyingthings3d', dstype = 'frames_cleanpass', replicates = 1): 177 | self.args = args 178 | self.is_cropped = is_cropped 179 | self.crop_size = args.crop_size 180 | self.render_size = args.inference_size 181 | self.replicates = replicates 182 | 183 | image_dirs = sorted(glob(join(root, dstype, 'TRAIN/*/*'))) 184 | image_dirs = sorted([join(f, 'left') for f in image_dirs] + [join(f, 'right') for f in image_dirs]) 185 | 186 | flow_dirs = sorted(glob(join(root, 'optical_flow_flo_format/TRAIN/*/*'))) 187 | flow_dirs = sorted([join(f, 'into_future/left') for f in flow_dirs] + [join(f, 'into_future/right') for f in flow_dirs]) 188 | 189 | assert (len(image_dirs) == len(flow_dirs)) 190 | 191 | self.image_list = [] 192 | self.flow_list = [] 193 | 194 | for idir, fdir in zip(image_dirs, flow_dirs): 195 | images = sorted( glob(join(idir, '*.png')) ) 196 | flows = sorted( glob(join(fdir, '*.flo')) ) 197 | for i in range(len(flows)): 198 | self.image_list += [ [ images[i], images[i+1] ] ] 199 | self.flow_list += [flows[i]] 200 | 201 | assert len(self.image_list) == len(self.flow_list) 202 | 203 | self.size = len(self.image_list) 204 | 205 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 206 | 207 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 208 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 209 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 210 | 211 | args.inference_size = self.render_size 212 | 213 | def __getitem__(self, index): 214 | index = index % self.size 215 | 216 | img1 = frame_utils.read_gen(self.image_list[index][0]) 217 | img2 = frame_utils.read_gen(self.image_list[index][1]) 218 | 219 | flow = 
frame_utils.read_gen(self.flow_list[index]) 220 | 221 | images = [img1, img2] 222 | image_size = img1.shape[:2] 223 | if self.is_cropped: 224 | cropper = StaticRandomCrop(image_size, self.crop_size) 225 | else: 226 | cropper = StaticCenterCrop(image_size, self.render_size) 227 | images = list(map(cropper, images)) 228 | flow = cropper(flow) 229 | 230 | 231 | images = np.array(images).transpose(3,0,1,2) 232 | flow = flow.transpose(2,0,1) 233 | 234 | images = torch.from_numpy(images.astype(np.float32)) 235 | flow = torch.from_numpy(flow.astype(np.float32)) 236 | 237 | return [images], [flow] 238 | 239 | def __len__(self): 240 | return self.size * self.replicates 241 | 242 | class FlyingThingsClean(FlyingThings): 243 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 244 | super(FlyingThingsClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_cleanpass', replicates = replicates) 245 | 246 | class FlyingThingsFinal(FlyingThings): 247 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 248 | super(FlyingThingsFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_finalpass', replicates = replicates) 249 | 250 | class ChairsSDHom(data.Dataset): 251 | def __init__(self, args, is_cropped, root = '/path/to/chairssdhom/data', dstype = 'train', replicates = 1): 252 | self.args = args 253 | self.is_cropped = is_cropped 254 | self.crop_size = args.crop_size 255 | self.render_size = args.inference_size 256 | self.replicates = replicates 257 | 258 | image1 = sorted( glob( join(root, dstype, 't0/*.png') ) ) 259 | image2 = sorted( glob( join(root, dstype, 't1/*.png') ) ) 260 | self.flow_list = sorted( glob( join(root, dstype, 'flow/*.flo') ) ) 261 | 262 | assert (len(image1) == len(self.flow_list)) 263 | 264 | self.image_list = [] 265 | for i in range(len(self.flow_list)): 266 | im1 = image1[i] 267 | im2 = image2[i] 268 | self.image_list += [ [ im1, im2 ] ] 269 | 270 | assert len(self.image_list) == len(self.flow_list) 271 | 272 | self.size = len(self.image_list) 273 | 274 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 275 | 276 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 277 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 278 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 279 | 280 | args.inference_size = self.render_size 281 | 282 | def __getitem__(self, index): 283 | index = index % self.size 284 | 285 | img1 = frame_utils.read_gen(self.image_list[index][0]) 286 | img2 = frame_utils.read_gen(self.image_list[index][1]) 287 | 288 | flow = frame_utils.read_gen(self.flow_list[index]) 289 | flow = flow[::-1,:,:] 290 | 291 | images = [img1, img2] 292 | image_size = img1.shape[:2] 293 | if self.is_cropped: 294 | cropper = StaticRandomCrop(image_size, self.crop_size) 295 | else: 296 | cropper = StaticCenterCrop(image_size, self.render_size) 297 | images = list(map(cropper, images)) 298 | flow = cropper(flow) 299 | 300 | 301 | images = np.array(images).transpose(3,0,1,2) 302 | flow = flow.transpose(2,0,1) 303 | 304 | images = torch.from_numpy(images.astype(np.float32)) 305 | flow = torch.from_numpy(flow.astype(np.float32)) 306 | 307 | return [images], [flow] 308 | 309 | def __len__(self): 310 | return self.size * self.replicates 311 | 312 | class ChairsSDHomTrain(ChairsSDHom): 313 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 314 | super(ChairsSDHomTrain, 
self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'train', replicates = replicates) 315 | 316 | class ChairsSDHomTest(ChairsSDHom): 317 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 318 | super(ChairsSDHomTest, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'test', replicates = replicates) 319 | 320 | class ImagesFromFolder(data.Dataset): 321 | def __init__(self, args, is_cropped, root = '/path/to/frames/only/folder', iext = 'png', replicates = 1): 322 | self.args = args 323 | self.is_cropped = is_cropped 324 | self.crop_size = args.crop_size 325 | self.render_size = args.inference_size 326 | self.replicates = replicates 327 | 328 | images = sorted( glob( join(root, '*.' + iext) ) ) 329 | self.image_list = [] 330 | for i in range(len(images)-1): 331 | im1 = images[i] 332 | im2 = images[i+1] 333 | self.image_list += [ [ im1, im2 ] ] 334 | 335 | self.size = len(self.image_list) 336 | 337 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 338 | 339 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 340 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 341 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 342 | 343 | args.inference_size = self.render_size 344 | 345 | def __getitem__(self, index): 346 | index = index % self.size 347 | 348 | img1 = frame_utils.read_gen(self.image_list[index][0]) 349 | img2 = frame_utils.read_gen(self.image_list[index][1]) 350 | 351 | images = [img1, img2] 352 | image_size = img1.shape[:2] 353 | if self.is_cropped: 354 | cropper = StaticRandomCrop(image_size, self.crop_size) 355 | else: 356 | cropper = StaticCenterCrop(image_size, self.render_size) 357 | images = list(map(cropper, images)) 358 | 359 | images = np.array(images).transpose(3,0,1,2) 360 | images = torch.from_numpy(images.astype(np.float32)) 361 | 362 | return [images], [torch.zeros(images.size()[0:1] + (2,) + images.size()[-2:])] 363 | 364 | def __len__(self): 365 | return self.size * self.replicates 366 | 367 | ''' 368 | import argparse 369 | import sys, os 370 | import importlib 371 | from scipy.misc import imsave 372 | import numpy as np 373 | 374 | import datasets 375 | reload(datasets) 376 | 377 | parser = argparse.ArgumentParser() 378 | args = parser.parse_args() 379 | args.inference_size = [1080, 1920] 380 | args.crop_size = [384, 512] 381 | args.effective_batch_size = 1 382 | 383 | index = 500 384 | v_dataset = datasets.MpiSintelClean(args, True, root='../MPI-Sintel/flow/training') 385 | a, b = v_dataset[index] 386 | im1 = a[0].numpy()[:,0,:,:].transpose(1,2,0) 387 | im2 = a[0].numpy()[:,1,:,:].transpose(1,2,0) 388 | imsave('./img1.png', im1) 389 | imsave('./img2.png', im2) 390 | flow_utils.writeFlow('./flow.flo', b[0].numpy().transpose(1,2,0)) 391 | 392 | ''' 393 | -------------------------------------------------------------------------------- /liteFlownet/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import cupy 4 | import math 5 | import re 6 | 7 | 8 | class Stream: 9 | ptr = torch.cuda.current_stream().cuda_stream 10 | 11 | 12 | # end 13 | 14 | kernel_Correlation_rearrange = ''' 15 | extern "C" __global__ void kernel_Correlation_rearrange( 16 | const int n, 17 | const float* input, 18 | float* output 19 | ) { 20 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 21 | 22 | if (intIndex >= n) { 23 | return; 24 | } 25 | 26 | int 
intSample = blockIdx.z; 27 | int intChannel = blockIdx.y; 28 | 29 | float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 30 | 31 | __syncthreads(); 32 | 33 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}}; 34 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}}; 35 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX; 36 | 37 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue; 38 | } 39 | ''' 40 | 41 | kernel_Correlation_updateOutput = ''' 42 | extern "C" __global__ void kernel_Correlation_updateOutput( 43 | const int n, 44 | const float* rbot0, 45 | const float* rbot1, 46 | float* top 47 | ) { 48 | extern __shared__ char patch_data_char[]; 49 | 50 | float *patch_data = (float *)patch_data_char; 51 | 52 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 53 | int x1 = (blockIdx.x + 3) * {{intStride}}; 54 | int y1 = (blockIdx.y + 3) * {{intStride}}; 55 | int item = blockIdx.z; 56 | int ch_off = threadIdx.x; 57 | 58 | // Load 3D patch into shared shared memory 59 | for (int j = 0; j < 1; j++) { // HEIGHT 60 | for (int i = 0; i < 1; i++) { // WIDTH 61 | int ji_off = (j + i) * SIZE_3(rbot0); 62 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 63 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 64 | int idxPatchData = ji_off + ch; 65 | patch_data[idxPatchData] = rbot0[idx1]; 66 | } 67 | } 68 | } 69 | 70 | __syncthreads(); 71 | 72 | __shared__ float sum[32]; 73 | 74 | // Compute correlation 75 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 76 | sum[ch_off] = 0; 77 | 78 | int s2o = (top_channel % 7 - 3) * {{intStride}}; 79 | int s2p = (top_channel / 7 - 3) * {{intStride}}; 80 | 81 | for (int j = 0; j < 1; j++) { // HEIGHT 82 | for (int i = 0; i < 1; i++) { // WIDTH 83 | int ji_off = (j + i) * SIZE_3(rbot0); 84 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 85 | int x2 = x1 + s2o; 86 | int y2 = y1 + s2p; 87 | 88 | int idxPatchData = ji_off + ch; 89 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 90 | 91 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 92 | } 93 | } 94 | } 95 | 96 | __syncthreads(); 97 | 98 | if (ch_off == 0) { 99 | float total_sum = 0; 100 | for (int idx = 0; idx < 32; idx++) { 101 | total_sum += sum[idx]; 102 | } 103 | const int sumelems = SIZE_3(rbot0); 104 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 105 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 106 | } 107 | } 108 | } 109 | ''' 110 | 111 | kernel_Correlation_updateGradFirst = ''' 112 | #define ROUND_OFF 50000 113 | 114 | extern "C" __global__ void kernel_Correlation_updateGradFirst( 115 | const int n, 116 | const int intSample, 117 | const float* rbot0, 118 | const float* rbot1, 119 | const float* gradOutput, 120 | float* gradFirst, 121 | float* gradSecond 122 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 123 | int n = intIndex % SIZE_1(gradFirst); // channels 124 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos 125 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos 126 | 127 | // 
round_off is a trick to enable integer division with ceil, even for negative numbers 128 | // We use a large offset, for the inner part not to become negative. 129 | const int round_off = ROUND_OFF; 130 | const int round_off_s1 = {{intStride}} * round_off; 131 | 132 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 133 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 134 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 135 | 136 | // Same here: 137 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}} 138 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}} 139 | 140 | float sum = 0; 141 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 142 | xmin = max(0,xmin); 143 | xmax = min(SIZE_3(gradOutput)-1,xmax); 144 | 145 | ymin = max(0,ymin); 146 | ymax = min(SIZE_2(gradOutput)-1,ymax); 147 | 148 | for (int p = -3; p <= 3; p++) { 149 | for (int o = -3; o <= 3; o++) { 150 | // Get rbot1 data: 151 | int s2o = {{intStride}} * o; 152 | int s2p = {{intStride}} * p; 153 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 154 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 155 | 156 | // Index offset for gradOutput in following loops: 157 | int op = (p+3) * 7 + (o+3); // index[o,p] 158 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 159 | 160 | for (int y = ymin; y <= ymax; y++) { 161 | for (int x = xmin; x <= xmax; x++) { 162 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 163 | sum += gradOutput[idxgradOutput] * bot1tmp; 164 | } 165 | } 166 | } 167 | } 168 | } 169 | const int sumelems = SIZE_1(gradFirst); 170 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}}); 171 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; 172 | } } 173 | ''' 174 | 175 | kernel_Correlation_updateGradSecond = ''' 176 | #define ROUND_OFF 50000 177 | 178 | extern "C" __global__ void kernel_Correlation_updateGradSecond( 179 | const int n, 180 | const int intSample, 181 | const float* rbot0, 182 | const float* rbot1, 183 | const float* gradOutput, 184 | float* gradFirst, 185 | float* gradSecond 186 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 187 | int n = intIndex % SIZE_1(gradSecond); // channels 188 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos 189 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos 190 | 191 | // round_off is a trick to enable integer division with ceil, even for negative numbers 192 | // We use a large offset, for the inner part not to become negative. 
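// [editor note, not in the original kernel] A hedged worked example of the offset trick, using a hypothetical {{intStride}} = 2 and a dividend d = l - 3*{{intStride}} = 0, where ceil(d/2) should be 0:
//   without the offset: (d - 1) / 2 + 1 = (-1)/2 + 1 = 0 + 1 = 1   (wrong: C integer division truncates -1/2 toward zero instead of flooring it)
//   with the offset:    (d + 2*ROUND_OFF - 1) / 2 + 1 - ROUND_OFF = 99999/2 + 1 - 50000 = 49999 + 1 - 50000 = 0   (correct)
// Shifting the dividend by a large positive multiple of the stride keeps it positive, where truncation equals floor, and subtracting ROUND_OFF afterwards undoes the shift.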
193 | const int round_off = ROUND_OFF; 194 | const int round_off_s1 = {{intStride}} * round_off; 195 | 196 | float sum = 0; 197 | for (int p = -3; p <= 3; p++) { 198 | for (int o = -3; o <= 3; o++) { 199 | int s2o = {{intStride}} * o; 200 | int s2p = {{intStride}} * p; 201 | 202 | //Get X,Y ranges and clamp 203 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 204 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 205 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 206 | 207 | // Same here: 208 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}} 209 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}} 210 | 211 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 212 | xmin = max(0,xmin); 213 | xmax = min(SIZE_3(gradOutput)-1,xmax); 214 | 215 | ymin = max(0,ymin); 216 | ymax = min(SIZE_2(gradOutput)-1,ymax); 217 | 218 | // Get rbot0 data: 219 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 220 | float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] 221 | 222 | // Index offset for gradOutput in following loops: 223 | int op = (p+3) * 7 + (o+3); // index[o,p] 224 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 225 | 226 | for (int y = ymin; y <= ymax; y++) { 227 | for (int x = xmin; x <= xmax; x++) { 228 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 229 | sum += gradOutput[idxgradOutput] * bot0tmp; 230 | } 231 | } 232 | } 233 | } 234 | } 235 | const int sumelems = SIZE_1(gradSecond); 236 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}}); 237 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; 238 | } } 239 | ''' 240 | 241 | 242 | def cupy_kernel(strFunction, objectVariables): 243 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objectVariables['intStride'])) 244 | 245 | while True: 246 | objectMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 247 | 248 | if objectMatch is None: 249 | break 250 | # end 251 | 252 | intArg = int(objectMatch.group(2)) 253 | 254 | strTensor = objectMatch.group(4) 255 | intSizes = objectVariables[strTensor].size() 256 | 257 | strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg])) 258 | # end 259 | 260 | while True: 261 | objectMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 262 | 263 | if objectMatch is None: 264 | break 265 | # end 266 | 267 | intArgs = int(objectMatch.group(2)) 268 | strArgs = objectMatch.group(4).split(',') 269 | 270 | strTensor = strArgs[0] 271 | intStrides = objectVariables[strTensor].stride() 272 | strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( 273 | intStrides[intArg]) + ')' for intArg in range(intArgs)] 274 | 275 | strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 276 | # end 277 | 278 | return strKernel 279 | 280 | 281 | # end 282 | 283 | 
@cupy.util.memoize(for_each_device=True) 284 | def cupy_launch(strFunction, strKernel): 285 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 286 | 287 | 288 | # end 289 | 290 | class _FunctionCorrelation(torch.autograd.Function): 291 | @staticmethod 292 | def forward(self, first, second, intStride): 293 | rbot0 = first.new_zeros( 294 | [first.size(0), first.size(2) + (6 * intStride), first.size(3) + (6 * intStride), first.size(1)]) 295 | rbot1 = first.new_zeros( 296 | [first.size(0), first.size(2) + (6 * intStride), first.size(3) + (6 * intStride), first.size(1)]) 297 | 298 | self.save_for_backward(first, second, rbot0, rbot1) 299 | 300 | self.intStride = intStride 301 | 302 | assert (first.is_contiguous() == True) 303 | assert (second.is_contiguous() == True) 304 | 305 | output = first.new_zeros( 306 | [first.size(0), 49, int(math.ceil(first.size(2) / intStride)), int(math.ceil(first.size(3) / intStride))]) 307 | 308 | if first.is_cuda == True: 309 | n = first.size(2) * first.size(3) 310 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 311 | 'intStride': self.intStride, 312 | 'input': first, 313 | 'output': rbot0 314 | }))( 315 | grid=tuple([int((n + 16 - 1) / 16), first.size(1), first.size(0)]), 316 | block=tuple([16, 1, 1]), 317 | args=[n, first.data_ptr(), rbot0.data_ptr()], 318 | stream=Stream 319 | ) 320 | 321 | n = second.size(2) * second.size(3) 322 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 323 | 'intStride': self.intStride, 324 | 'input': second, 325 | 'output': rbot1 326 | }))( 327 | grid=tuple([int((n + 16 - 1) / 16), second.size(1), second.size(0)]), 328 | block=tuple([16, 1, 1]), 329 | args=[n, second.data_ptr(), rbot1.data_ptr()], 330 | stream=Stream 331 | ) 332 | 333 | n = output.size(1) * output.size(2) * output.size(3) 334 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 335 | 'intStride': self.intStride, 336 | 'rbot0': rbot0, 337 | 'rbot1': rbot1, 338 | 'top': output 339 | }))( 340 | grid=tuple([output.size(3), output.size(2), output.size(0)]), 341 | block=tuple([32, 1, 1]), 342 | shared_mem=first.size(1) * 4, 343 | args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()], 344 | stream=Stream 345 | ) 346 | 347 | elif first.is_cuda == False: 348 | raise NotImplementedError() 349 | 350 | # end 351 | 352 | return output 353 | 354 | # end 355 | 356 | @staticmethod 357 | def backward(self, gradOutput): 358 | first, second, rbot0, rbot1 = self.saved_tensors 359 | 360 | assert (gradOutput.is_contiguous() == True) 361 | 362 | gradFirst = first.new_zeros([first.size(0), first.size(1), first.size(2), first.size(3)]) if \ 363 | self.needs_input_grad[0] == True else None 364 | gradSecond = first.new_zeros([first.size(0), first.size(1), first.size(2), first.size(3)]) if \ 365 | self.needs_input_grad[1] == True else None 366 | 367 | if first.is_cuda == True: 368 | if gradFirst is not None: 369 | for intSample in range(first.size(0)): 370 | n = first.size(1) * first.size(2) * first.size(3) 371 | cupy_launch('kernel_Correlation_updateGradFirst', 372 | cupy_kernel('kernel_Correlation_updateGradFirst', { 373 | 'intStride': self.intStride, 374 | 'rbot0': rbot0, 375 | 'rbot1': rbot1, 376 | 'gradOutput': gradOutput, 377 | 'gradFirst': gradFirst, 378 | 'gradSecond': None 379 | }))( 380 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]), 381 | block=tuple([512, 1, 1]), 382 | args=[n, intSample, rbot0.data_ptr(), 
rbot1.data_ptr(), gradOutput.data_ptr(), 383 | gradFirst.data_ptr(), None], 384 | stream=Stream 385 | ) 386 | # end 387 | # end 388 | 389 | if gradSecond is not None: 390 | for intSample in range(first.size(0)): 391 | n = first.size(1) * first.size(2) * first.size(3) 392 | cupy_launch('kernel_Correlation_updateGradSecond', 393 | cupy_kernel('kernel_Correlation_updateGradSecond', { 394 | 'intStride': self.intStride, 395 | 'rbot0': rbot0, 396 | 'rbot1': rbot1, 397 | 'gradOutput': gradOutput, 398 | 'gradFirst': None, 399 | 'gradSecond': gradSecond 400 | }))( 401 | grid=tuple([int((n + 512 - 1) / 512), 1, 1]), 402 | block=tuple([512, 1, 1]), 403 | args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, 404 | gradSecond.data_ptr()], 405 | stream=Stream 406 | ) 407 | # end 408 | # end 409 | 410 | elif first.is_cuda == False: 411 | raise NotImplementedError() 412 | 413 | # end 414 | 415 | return gradFirst, gradSecond, None 416 | 417 | 418 | # end 419 | # end 420 | 421 | def FunctionCorrelation(tensorFirst, tensorSecond, intStride): 422 | return _FunctionCorrelation.apply(tensorFirst, tensorSecond, intStride) 423 | 424 | 425 | # end 426 | 427 | class ModuleCorrelation(torch.nn.Module): 428 | def __init__(self): 429 | super(ModuleCorrelation, self).__init__() 430 | 431 | # end 432 | 433 | def forward(self, tensorFirst, tensorSecond, intStride): 434 | return _FunctionCorrelation.apply(tensorFirst, tensorSecond, intStride) 435 | # end 436 | # end 437 | -------------------------------------------------------------------------------- /liteFlownet/lite_flownet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import getopt 6 | import math 7 | import numpy 8 | import os 9 | import PIL 10 | import PIL.Image 11 | import sys 12 | 13 | try: 14 | from .correlation import correlation # the custom cost volume layer 15 | except: 16 | sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python 17 | # end 18 | 19 | ########################################################## 20 | 21 | assert(int(str('').join(torch.__version__.split('.')[0:3])) >= 41) # requires at least pytorch version 0.4.1 22 | 23 | #torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance 24 | 25 | #torch.cuda.device(1) # change this if you have a multiple graphics cards and you want to utilize them 26 | 27 | # torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 28 | 29 | ########################################################## 30 | 31 | # arguments_strModel = 'default' 32 | # arguments_strFirst = './images/first.png' 33 | # arguments_strSecond = './images/second.png' 34 | # arguments_strOut = './out.flo' 35 | 36 | # for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: 37 | # if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use 38 | # if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame 39 | # if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame 40 | # if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored 41 | # end 42 | 43 | ########################################################## 44 | 45 | Backward_tensorGrid = 
{} 46 | 47 | def Backward(tensorInput, tensorFlow): 48 | if str(tensorFlow.size()) not in Backward_tensorGrid: 49 | tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1) 50 | tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3)) 51 | 52 | Backward_tensorGrid[str(tensorFlow.size())] = torch.cat([ tensorHorizontal, tensorVertical ], 1).cuda() 53 | # end 54 | 55 | tensorFlow = torch.cat([ tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0) ], 1) 56 | 57 | return torch.nn.functional.grid_sample(input=tensorInput, grid=(Backward_tensorGrid[str(tensorFlow.size())] + tensorFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros') 58 | # end 59 | 60 | ########################################################## 61 | 62 | class Network(torch.nn.Module): 63 | def __init__(self): 64 | super(Network, self).__init__() 65 | 66 | class Features(torch.nn.Module): 67 | def __init__(self): 68 | super(Features, self).__init__() 69 | 70 | self.moduleOne = torch.nn.Sequential( 71 | torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7, stride=1, padding=3), 72 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 73 | ) 74 | 75 | self.moduleTwo = torch.nn.Sequential( 76 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1), 77 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 78 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 79 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 80 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 81 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 82 | ) 83 | 84 | self.moduleThr = torch.nn.Sequential( 85 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 86 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 87 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 88 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 89 | ) 90 | 91 | self.moduleFou = torch.nn.Sequential( 92 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 93 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 94 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 95 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 96 | ) 97 | 98 | self.moduleFiv = torch.nn.Sequential( 99 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 100 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 101 | ) 102 | 103 | self.moduleSix = torch.nn.Sequential( 104 | torch.nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=2, padding=1), 105 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 106 | ) 107 | # end 108 | 109 | def forward(self, tensorInput): 110 | tensorOne = self.moduleOne(tensorInput) 111 | tensorTwo = self.moduleTwo(tensorOne) 112 | tensorThr = self.moduleThr(tensorTwo) 113 | tensorFou = self.moduleFou(tensorThr) 114 | tensorFiv = self.moduleFiv(tensorFou) 115 | tensorSix = self.moduleSix(tensorFiv) 116 | 117 | return [ tensorOne, tensorTwo, tensorThr, tensorFou, tensorFiv, tensorSix ] 118 | # end 119 | # end 120 | 121 | class Matching(torch.nn.Module): 122 | def __init__(self, 
intLevel): 123 | super(Matching, self).__init__() 124 | 125 | self.dblBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 126 | 127 | if intLevel != 2: 128 | self.moduleFeat = torch.nn.Sequential() 129 | 130 | elif intLevel == 2: 131 | self.moduleFeat = torch.nn.Sequential( 132 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), 133 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 134 | ) 135 | 136 | # end 137 | 138 | if intLevel == 6: 139 | self.moduleUpflow = None 140 | 141 | elif intLevel != 6: 142 | self.moduleUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1, bias=False, groups=2) 143 | 144 | # end 145 | 146 | if intLevel >= 4: 147 | self.moduleUpcorr = None 148 | 149 | elif intLevel < 4: 150 | self.moduleUpcorr = torch.nn.ConvTranspose2d(in_channels=49, out_channels=49, kernel_size=4, stride=2, padding=1, bias=False, groups=49) 151 | 152 | # end 153 | 154 | self.moduleMain = torch.nn.Sequential( 155 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1), 156 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 157 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 158 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 159 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 160 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 161 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 162 | ) 163 | # end 164 | 165 | def forward(self, tensorFirst, tensorSecond, tensorFeaturesFirst, tensorFeaturesSecond, tensorFlow): 166 | tensorFeaturesFirst = self.moduleFeat(tensorFeaturesFirst) 167 | tensorFeaturesSecond = self.moduleFeat(tensorFeaturesSecond) 168 | 169 | if tensorFlow is not None: 170 | tensorFlow = self.moduleUpflow(tensorFlow) 171 | # end 172 | 173 | if tensorFlow is not None: 174 | tensorFeaturesSecond = Backward(tensorInput=tensorFeaturesSecond, tensorFlow=tensorFlow * self.dblBackward) 175 | # end 176 | 177 | if self.moduleUpcorr is None: 178 | tensorCorrelation = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tensorFirst=tensorFeaturesFirst, tensorSecond=tensorFeaturesSecond, intStride=1), negative_slope=0.1, inplace=False) 179 | 180 | elif self.moduleUpcorr is not None: 181 | tensorCorrelation = self.moduleUpcorr(torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tensorFirst=tensorFeaturesFirst, tensorSecond=tensorFeaturesSecond, intStride=2), negative_slope=0.1, inplace=False)) 182 | 183 | # end 184 | 185 | return (tensorFlow if tensorFlow is not None else 0.0) + self.moduleMain(tensorCorrelation) 186 | # end 187 | # end 188 | 189 | class Subpixel(torch.nn.Module): 190 | def __init__(self, intLevel): 191 | super(Subpixel, self).__init__() 192 | 193 | self.dblBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 194 | 195 | if intLevel != 2: 196 | self.moduleFeat = torch.nn.Sequential() 197 | 198 | elif intLevel == 2: 199 | self.moduleFeat = torch.nn.Sequential( 200 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), 201 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 202 | ) 203 | 204 | # end 205 | 206 | self.moduleMain = torch.nn.Sequential( 207 | torch.nn.Conv2d(in_channels=[ 0, 0, 130, 130, 194, 258, 386 ][intLevel], out_channels=128, kernel_size=3, stride=1, 
padding=1), 208 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 209 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 210 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 211 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 212 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 213 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 214 | ) 215 | # end 216 | 217 | def forward(self, tensorFirst, tensorSecond, tensorFeaturesFirst, tensorFeaturesSecond, tensorFlow): 218 | tensorFeaturesFirst = self.moduleFeat(tensorFeaturesFirst) 219 | tensorFeaturesSecond = self.moduleFeat(tensorFeaturesSecond) 220 | 221 | if tensorFlow is not None: 222 | tensorFeaturesSecond = Backward(tensorInput=tensorFeaturesSecond, tensorFlow=tensorFlow * self.dblBackward) 223 | # end 224 | 225 | return (tensorFlow if tensorFlow is not None else 0.0) + self.moduleMain(torch.cat([ tensorFeaturesFirst, tensorFeaturesSecond, tensorFlow ], 1)) 226 | # end 227 | # end 228 | 229 | class Regularization(torch.nn.Module): 230 | def __init__(self, intLevel): 231 | super(Regularization, self).__init__() 232 | 233 | self.dblBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] 234 | 235 | self.intUnfold = [ 0, 0, 7, 5, 5, 3, 3 ][intLevel] 236 | 237 | if intLevel >= 5: 238 | self.moduleFeat = torch.nn.Sequential() 239 | 240 | elif intLevel < 5: 241 | self.moduleFeat = torch.nn.Sequential( 242 | torch.nn.Conv2d(in_channels=[ 0, 0, 32, 64, 96, 128, 192 ][intLevel], out_channels=128, kernel_size=1, stride=1, padding=0), 243 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 244 | ) 245 | 246 | # end 247 | 248 | self.moduleMain = torch.nn.Sequential( 249 | torch.nn.Conv2d(in_channels=[ 0, 0, 131, 131, 131, 131, 195 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1), 250 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 251 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 252 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 253 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 254 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 255 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 256 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 257 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 258 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 259 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 260 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 261 | ) 262 | 263 | if intLevel >= 5: 264 | self.moduleDist = torch.nn.Sequential( 265 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) 266 | ) 267 | 268 | elif intLevel < 5: 269 | self.moduleDist = torch.nn.Sequential( 270 | torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=([ 0, 0, 7, 5, 5, 3, 3 ][intLevel], 1), stride=1, padding=([ 0, 0, 3, 2, 2, 1, 1 ][intLevel], 0)), 271 | torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=(1, [ 0, 0, 7, 5, 5, 3, 3 ][intLevel]), stride=1, 
padding=(0, [ 0, 0, 3, 2, 2, 1, 1 ][intLevel])) 272 | ) 273 | 274 | # end 275 | 276 | self.moduleScaleX = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) 277 | self.moduleScaleY = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) 278 | # eny 279 | 280 | def forward(self, tensorFirst, tensorSecond, tensorFeaturesFirst, tensorFeaturesSecond, tensorFlow): 281 | tensorDifference = (tensorFirst - Backward(tensorInput=tensorSecond, tensorFlow=tensorFlow * self.dblBackward)).pow(2.0).sum(1, True).sqrt() 282 | 283 | tensorDist = self.moduleDist(self.moduleMain(torch.cat([ tensorDifference, tensorFlow - tensorFlow.view(tensorFlow.size(0), 2, -1).mean(2, True).view(tensorFlow.size(0), 2, 1, 1), self.moduleFeat(tensorFeaturesFirst) ], 1))) 284 | tensorDist = tensorDist.pow(2.0).neg() 285 | tensorDist = (tensorDist - tensorDist.max(1, True)[0]).exp() 286 | 287 | tensorDivisor = tensorDist.sum(1, True).reciprocal() 288 | 289 | tensorScaleX = self.moduleScaleX(tensorDist * torch.nn.functional.unfold(input=tensorFlow[:, 0:1, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tensorDist)) * tensorDivisor 290 | tensorScaleY = self.moduleScaleY(tensorDist * torch.nn.functional.unfold(input=tensorFlow[:, 1:2, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tensorDist)) * tensorDivisor 291 | 292 | return torch.cat([ tensorScaleX, tensorScaleY ], 1) 293 | # end 294 | # end 295 | 296 | self.moduleFeatures = Features() 297 | self.moduleMatching = torch.nn.ModuleList([ Matching(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 298 | self.moduleSubpixel = torch.nn.ModuleList([ Subpixel(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 299 | self.moduleRegularization = torch.nn.ModuleList([ Regularization(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) 300 | 301 | # self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch')) 302 | # end 303 | 304 | def forward(self, tensorFirst, tensorSecond): 305 | tensorFirst[:, 0, :, :] = tensorFirst[:, 0, :, :] - 0.411618 306 | tensorFirst[:, 1, :, :] = tensorFirst[:, 1, :, :] - 0.434631 307 | tensorFirst[:, 2, :, :] = tensorFirst[:, 2, :, :] - 0.454253 308 | 309 | tensorSecond[:, 0, :, :] = tensorSecond[:, 0, :, :] - 0.410782 310 | tensorSecond[:, 1, :, :] = tensorSecond[:, 1, :, :] - 0.433645 311 | tensorSecond[:, 2, :, :] = tensorSecond[:, 2, :, :] - 0.452793 312 | 313 | tensorFeaturesFirst = self.moduleFeatures(tensorFirst) 314 | tensorFeaturesSecond = self.moduleFeatures(tensorSecond) 315 | 316 | tensorFirst = [ tensorFirst ] 317 | tensorSecond = [ tensorSecond ] 318 | 319 | for intLevel in [ 1, 2, 3, 4, 5 ]: 320 | tensorFirst.append(torch.nn.functional.interpolate(input=tensorFirst[-1], size=(tensorFeaturesFirst[intLevel].size(2), tensorFeaturesFirst[intLevel].size(3)), mode='bilinear', align_corners=False)) 321 | tensorSecond.append(torch.nn.functional.interpolate(input=tensorSecond[-1], size=(tensorFeaturesSecond[intLevel].size(2), tensorFeaturesSecond[intLevel].size(3)), mode='bilinear', align_corners=False)) 322 | # end 323 | 324 | tensorFlow = None 325 | 326 | for intLevel in [ -1, -2, -3, -4, -5 ]: 327 | tensorFlow = self.moduleMatching[intLevel](tensorFirst[intLevel], tensorSecond[intLevel], tensorFeaturesFirst[intLevel], tensorFeaturesSecond[intLevel], tensorFlow) 328 | tensorFlow = 
self.moduleSubpixel[intLevel](tensorFirst[intLevel], tensorSecond[intLevel], tensorFeaturesFirst[intLevel], tensorFeaturesSecond[intLevel], tensorFlow) 329 | tensorFlow = self.moduleRegularization[intLevel](tensorFirst[intLevel], tensorSecond[intLevel], tensorFeaturesFirst[intLevel], tensorFeaturesSecond[intLevel], tensorFlow) 330 | # end 331 | 332 | return tensorFlow #* 20.0 333 | # end 334 | # end 335 | 336 | moduleNetwork = Network().cuda().eval() 337 | 338 | ########################################################## 339 | 340 | def estimate(tensorFirst, tensorSecond): 341 | assert(tensorFirst.size(1) == tensorSecond.size(1)) 342 | assert(tensorFirst.size(2) == tensorSecond.size(2)) 343 | 344 | 345 | 346 | #assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 347 | #assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 348 | intWidth = tensorFirst.size(2) 349 | intHeight = tensorFirst.size(1) 350 | tensorPreprocessedFirst = tensorFirst.cuda().view(1, 3, intHeight, intWidth) 351 | tensorPreprocessedSecond = tensorSecond.cuda().view(1, 3, intHeight, intWidth) 352 | 353 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) 354 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) 355 | 356 | tensorPreprocessedFirst = torch.nn.functional.interpolate(input=tensorPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 357 | tensorPreprocessedSecond = torch.nn.functional.interpolate(input=tensorPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 358 | 359 | tensorFlow = torch.nn.functional.interpolate(input=moduleNetwork(tensorPreprocessedFirst, tensorPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False) 360 | 361 | tensorFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 362 | tensorFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 363 | 364 | return tensorFlow[0, :, :, :].cpu() 365 | # end 366 | 367 | ########################################################## 368 | 369 | def batch_estimate(tensor_batch,moduleNetwork): 370 | # the tensor have been changed into [0.0,1.0] and [b c h w] 371 | intWidth = tensor_batch.size(3) 372 | intHeight = tensor_batch.size(2) 373 | tensorFirst=tensor_batch[:,:3,] 374 | tensorSecond=tensor_batch[:,3:,] 375 | 376 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) 377 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) 378 | 379 | tensorPreprocessedFirst = torch.nn.functional.interpolate(input=tensorFirst, 380 | size=(intPreprocessedHeight, intPreprocessedWidth), 381 | mode='bilinear', align_corners=False) 382 | tensorPreprocessedSecond = torch.nn.functional.interpolate(input=tensorSecond, 383 | size=(intPreprocessedHeight, intPreprocessedWidth), 384 | mode='bilinear', align_corners=False) 385 | 386 | tensorFlow = torch.nn.functional.interpolate(input=moduleNetwork(tensorPreprocessedFirst, tensorPreprocessedSecond), 387 | size=(intHeight, intWidth), mode='bilinear', align_corners=False) 388 | 389 | tensorFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 390 | tensorFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 391 | 392 | return tensorFlow 393 | 394 | # if __name__ == 
'__main__': 395 | # tensorFirst = torch.FloatTensor(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)) 396 | # tensorSecond = torch.FloatTensor(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)) 397 | # 398 | # tensorOutput = estimate(tensorFirst, tensorSecond) 399 | # 400 | # objectOutput = open(arguments_strOut, 'wb') 401 | # 402 | # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objectOutput) 403 | # numpy.array([ tensorOutput.size(2), tensorOutput.size(1) ], numpy.int32).tofile(objectOutput) 404 | # numpy.array(tensorOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objectOutput) 405 | # 406 | # objectOutput.close() 407 | # end 408 | -------------------------------------------------------------------------------- /ano_pre/eval_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as scio 3 | import os 4 | import argparse 5 | import pickle 6 | from sklearn import metrics 7 | import json 8 | import socket 9 | 10 | 11 | # data folder contain all datasets, such as ped1, ped2, avenue, shanghaitech, etc 12 | DATA_DIR = '/hdd/fjc/VAD/' 13 | 14 | # normalize scores in each sub video 15 | NORMALIZE = True 16 | 17 | # number of history frames, since in prediction based method, the first 4 frames can not be predicted, so that 18 | # the first 4frames are undecidable, we just ignore the first 4 frames 19 | DECIDABLE_IDX = 4 20 | 21 | 22 | def parser_args(): 23 | parser = argparse.ArgumentParser(description='evaluating the model, computing the roc/auc.') 24 | 25 | parser.add_argument('-f', '--file', type=str, help='the path of loss file.') 26 | parser.add_argument('-t', '--type', type=str, default='compute_auc', 27 | help='the type of evaluation, choosing type is: plot_roc, compute_auc, ' 28 | 'test_func\n, the default type is compute_auc') 29 | return parser.parse_args() 30 | 31 | 32 | class RecordResult(object): 33 | def __init__(self, fpr=None, tpr=None, auc=-np.inf, dataset=None, loss_file=None): 34 | self.fpr = fpr 35 | self.tpr = tpr 36 | self.auc = auc 37 | self.dataset = dataset 38 | self.loss_file = loss_file 39 | 40 | def __lt__(self, other): 41 | return self.auc < other.auc 42 | 43 | def __gt__(self, other): 44 | return self.auc > other.auc 45 | 46 | def __str__(self): 47 | return 'dataset = {}, loss file = {}, auc = {}'.format(self.dataset, self.loss_file, self.auc) 48 | 49 | 50 | class GroundTruthLoader(object): 51 | AVENUE = 'avenue' 52 | PED1 = 'ped1' 53 | PED1_PIXEL_SUBSET = 'ped1_pixel_subset' 54 | PED2 = 'ped2' 55 | ENTRANCE = 'enter' 56 | EXIT = 'exit' 57 | SHANGHAITECH = 'shanghaitech' 58 | SHANGHAITECH_LABEL_PATH = os.path.join(DATA_DIR, 'shanghaitech/testing/test_frame_mask') 59 | TOY_DATA = 'toydata' 60 | TOY_DATA_LABEL_PATH = os.path.join(DATA_DIR, TOY_DATA, 'toydata.json') 61 | 62 | NAME_MAT_MAPPING = { 63 | AVENUE: os.path.join(DATA_DIR, 'avenue/avenue.mat'), 64 | PED1: os.path.join(DATA_DIR, 'ped1/ped1.mat'), 65 | PED2: os.path.join(DATA_DIR, 'ped2/ped2.mat'), 66 | ENTRANCE: os.path.join(DATA_DIR, 'enter/enter.mat'), 67 | EXIT: os.path.join(DATA_DIR, 'exit/exit.mat') 68 | } 69 | 70 | NAME_FRAMES_MAPPING = { 71 | AVENUE: os.path.join(DATA_DIR, 'avenue/testing/frames'), 72 | PED1: os.path.join(DATA_DIR, 'ped1/testing/frames'), 73 | PED2: os.path.join(DATA_DIR, 'ped2/testing/frames'), 74 | ENTRANCE: os.path.join(DATA_DIR, 'enter/testing/frames'), 75 | 
EXIT: os.path.join(DATA_DIR, 'exit/testing/frames') 76 | } 77 | 78 | def __init__(self, mapping_json=None): 79 | """ 80 | Initialize a ground truth loader, which loads the ground truth for a given dataset name. 81 | 82 | :param mapping_json: the mapping from dataset name to the path of ground truth. 83 | """ 84 | 85 | if mapping_json is not None: 86 | with open(mapping_json, 'rb') as json_file: 87 | self.mapping = json.load(json_file) 88 | else: 89 | self.mapping = GroundTruthLoader.NAME_MAT_MAPPING 90 | 91 | def __call__(self, dataset): 92 | """ get the ground truth for the provided dataset name. 93 | 94 | :type dataset: str 95 | :param dataset: the name of dataset. 96 | :return: np.ndarray, shape(#video) 97 | np.array[0] contains all the start frame and end frame of abnormal events of video 0, 98 | and its shape is (#frames, ) 99 | """ 100 | 101 | if dataset == GroundTruthLoader.SHANGHAITECH: 102 | gt = self.__load_shanghaitech_gt() 103 | elif dataset == GroundTruthLoader.TOY_DATA: 104 | gt = self.__load_toydata_gt() 105 | else: 106 | gt = self.__load_ucsd_avenue_subway_gt(dataset) 107 | return gt 108 | 109 | def __load_ucsd_avenue_subway_gt(self, dataset): 110 | assert dataset in self.mapping, 'there is no dataset named {} \n Please check {}' \ 111 | .format(dataset, GroundTruthLoader.NAME_MAT_MAPPING.keys()) 112 | 113 | mat_file = self.mapping[dataset] 114 | abnormal_events = scio.loadmat(mat_file, squeeze_me=True)['gt'] 115 | 116 | if abnormal_events.ndim == 2: 117 | abnormal_events = abnormal_events.reshape(-1, abnormal_events.shape[0], abnormal_events.shape[1]) 118 | 119 | num_video = abnormal_events.shape[0] 120 | dataset_video_folder = GroundTruthLoader.NAME_FRAMES_MAPPING[dataset] 121 | video_list = os.listdir(dataset_video_folder) 122 | video_list.sort() 123 | 124 | assert num_video == len(video_list), 'ground truth does not match the number of testing videos. 
{} != {}' \ 125 | .format(num_video, len(video_list)) 126 | 127 | # get the total frames of sub video 128 | def get_video_length(sub_video_number): 129 | # video_name = video_name_template.format(sub_video_number) 130 | video_name = os.path.join(dataset_video_folder, video_list[sub_video_number]) 131 | assert os.path.isdir(video_name), '{} is not directory!'.format(video_name) 132 | 133 | length = len(os.listdir(video_name)) 134 | 135 | return length 136 | 137 | # need to test [].append, or np.array().append(), which one is faster 138 | gt = [] 139 | for i in range(num_video): 140 | length = get_video_length(i) 141 | 142 | sub_video_gt = np.zeros((length,), dtype=np.int8) 143 | sub_abnormal_events = abnormal_events[i] 144 | if sub_abnormal_events.ndim == 1: 145 | sub_abnormal_events = sub_abnormal_events.reshape((sub_abnormal_events.shape[0], -1)) 146 | 147 | _, num_abnormal = sub_abnormal_events.shape 148 | 149 | for j in range(num_abnormal): 150 | # (start - 1, end - 1) 151 | start = sub_abnormal_events[0, j] - 1 152 | end = sub_abnormal_events[1, j] 153 | 154 | sub_video_gt[start: end] = 1 155 | 156 | gt.append(sub_video_gt) 157 | 158 | return gt 159 | 160 | @staticmethod 161 | def __load_shanghaitech_gt(): 162 | video_path_list = os.listdir(GroundTruthLoader.SHANGHAITECH_LABEL_PATH) 163 | video_path_list.sort() 164 | 165 | gt = [] 166 | for video in video_path_list: 167 | # print(os.path.join(GroundTruthLoader.SHANGHAITECH_LABEL_PATH, video)) 168 | gt.append(np.load(os.path.join(GroundTruthLoader.SHANGHAITECH_LABEL_PATH, video))) 169 | 170 | return gt 171 | 172 | @staticmethod 173 | def __load_toydata_gt(): 174 | with open(GroundTruthLoader.TOY_DATA_LABEL_PATH, 'r') as gt_file: 175 | gt_dict = json.load(gt_file) 176 | 177 | gt = [] 178 | for video, video_info in gt_dict.items(): 179 | length = video_info['length'] 180 | video_gt = np.zeros((length,), dtype=np.uint8) 181 | sub_gt = np.array(np.matrix(video_info['gt'])) 182 | 183 | for anomaly in sub_gt: 184 | start = anomaly[0] 185 | end = anomaly[1] + 1 186 | video_gt[start: end] = 1 187 | gt.append(video_gt) 188 | return gt 189 | 190 | @staticmethod 191 | def get_pixel_masks_file_list(dataset): 192 | # pixel mask folder 193 | pixel_mask_folder = os.path.join(DATA_DIR, dataset, 'pixel_masks') 194 | pixel_mask_file_list = os.listdir(pixel_mask_folder) 195 | pixel_mask_file_list.sort() 196 | 197 | # get all testing videos 198 | dataset_video_folder = GroundTruthLoader.NAME_FRAMES_MAPPING[dataset] 199 | video_list = os.listdir(dataset_video_folder) 200 | video_list.sort() 201 | 202 | # get all testing video names with pixel masks 203 | pixel_video_ids = [] 204 | ids = 0 205 | for pixel_mask_name in pixel_mask_file_list: 206 | while ids < len(video_list): 207 | if video_list[ids] + '.npy' == pixel_mask_name: 208 | pixel_video_ids.append(ids) 209 | ids += 1 210 | break 211 | else: 212 | ids += 1 213 | 214 | assert len(pixel_video_ids) == len(pixel_mask_file_list) 215 | 216 | for i in range(len(pixel_mask_file_list)): 217 | pixel_mask_file_list[i] = os.path.join(pixel_mask_folder, pixel_mask_file_list[i]) 218 | 219 | return pixel_mask_file_list, pixel_video_ids 220 | 221 | 222 | def load_psnr_gt(loss_file): 223 | with open(loss_file, 'rb') as reader: 224 | # results { 225 | # 'dataset': the name of dataset 226 | # 'psnr': the psnr of each testing videos, 227 | # } 228 | 229 | # psnr_records['psnr'] is np.array, shape(#videos) 230 | # psnr_records[0] is np.array ------> 01.avi 231 | # psnr_records[1] is np.array ------> 02.avi 232 | # ...... 
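# [editor note, not part of the original source] A minimal, hedged sketch of how such a loss file
# could be written; the path and the two array lengths are made-up examples, and only the keys
# documented above ('dataset' and 'psnr', one np.array of per-frame PSNR values per testing video) are used:
#
#     import pickle
#     import numpy as np
#     results = {'dataset': 'avenue',
#                'psnr': [np.random.rand(1439), np.random.rand(1211)]}  # one array per testing video
#     with open('psnrs/avenue_demo.pkl', 'wb') as f:
#         pickle.dump(results, f)
#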
233 | # psnr_records[n] is np.array ------> xx.avi 234 | 235 | results = pickle.load(reader) 236 | 237 | dataset = results['dataset'] 238 | psnr_records = results['psnr'] 239 | 240 | num_videos = len(psnr_records) 241 | 242 | # load ground truth 243 | gt_loader = GroundTruthLoader() 244 | gt = gt_loader(dataset=dataset) 245 | 246 | assert num_videos == len(gt), 'the number of saved videos does not match the ground truth, {} != {}' \ 247 | .format(num_videos, len(gt)) 248 | 249 | return dataset, psnr_records, gt 250 | 251 | 252 | def load_psnr_gt_flow(loss_file): 253 | with open(loss_file, 'rb') as reader: 254 | # results { 255 | # 'dataset': the name of dataset 256 | # 'psnr': the psnr of each testing videos, 257 | # } 258 | 259 | # psnr_records['psnr'] is np.array, shape(#videos) 260 | # psnr_records[0] is np.array ------> 01.avi 261 | # psnr_records[1] is np.array ------> 02.avi 262 | # ...... 263 | # psnr_records[n] is np.array ------> xx.avi 264 | 265 | results = pickle.load(reader) 266 | 267 | dataset = results['dataset'] 268 | psnrs = results['psnr'] 269 | flows = results['flow'] 270 | 271 | num_videos = len(psnrs) 272 | 273 | # load ground truth 274 | gt_loader = GroundTruthLoader() 275 | gt = gt_loader(dataset=dataset) 276 | 277 | assert num_videos == len(gt), 'the number of saved videos does not match the ground truth, {} != {}' \ 278 | .format(num_videos, len(gt)) 279 | 280 | return dataset, psnrs, flows, gt 281 | 282 | 283 | def load_psnr(loss_file): 284 | """ 285 | load image psnr or optical flow psnr. 286 | :param loss_file: loss file path 287 | :return: 288 | """ 289 | with open(loss_file, 'rb') as reader: 290 | # results { 291 | # 'dataset': the name of dataset 292 | # 'psnr': the psnr of each testing videos, 293 | # } 294 | 295 | # psnr_records['psnr'] is np.array, shape(#videos) 296 | # psnr_records[0] is np.array ------> 01.avi 297 | # psnr_records[1] is np.array ------> 02.avi 298 | # ...... 
299 | # psnr_records[n] is np.array ------> xx.avi 300 | 301 | results = pickle.load(reader) 302 | psnrs = results['psnr'] 303 | return psnrs 304 | 305 | 306 | def get_scores_labels(loss_file): 307 | # the name of dataset, loss, and ground truth 308 | dataset, psnr_records, gt = load_psnr_gt(loss_file=loss_file) 309 | 310 | # the number of videos 311 | num_videos = len(psnr_records) 312 | 313 | scores = np.array([], dtype=np.float32) 314 | labels = np.array([], dtype=np.int8) 315 | # video normalization 316 | for i in range(num_videos): 317 | distance = psnr_records[i] 318 | 319 | if NORMALIZE: 320 | distance -= distance.min() # distances = (distance - min) / (max - min) 321 | distance /= distance.max() 322 | # distance = 1 - distance 323 | 324 | scores = np.concatenate((scores[:], distance[DECIDABLE_IDX:]), axis=0) 325 | labels = np.concatenate((labels[:], gt[i][DECIDABLE_IDX:]), axis=0) 326 | return dataset, scores, labels 327 | 328 | 329 | def precision_recall_auc(loss_file): 330 | if not os.path.isdir(loss_file): 331 | loss_file_list = [loss_file] 332 | else: 333 | loss_file_list = os.listdir(loss_file) 334 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 335 | 336 | optimal_results = RecordResult() 337 | for sub_loss_file in loss_file_list: 338 | dataset, scores, labels = get_scores_labels(sub_loss_file) 339 | precision, recall, thresholds = metrics.precision_recall_curve(labels, scores, pos_label=0) 340 | auc = metrics.auc(recall, precision) 341 | 342 | results = RecordResult(recall, precision, auc, dataset, sub_loss_file) 343 | 344 | if optimal_results < results: 345 | optimal_results = results 346 | 347 | if os.path.isdir(loss_file): 348 | print(results) 349 | print('##### optimal result and model = {}'.format(optimal_results)) 350 | return optimal_results 351 | 352 | 353 | def cal_eer(fpr, tpr): 354 | # makes fpr + tpr = 1 355 | eer = fpr[np.nanargmin(np.absolute((fpr + tpr - 1)))] 356 | return eer 357 | 358 | 359 | def compute_eer(loss_file): 360 | if not os.path.isdir(loss_file): 361 | loss_file_list = [loss_file] 362 | else: 363 | loss_file_list = os.listdir(loss_file) 364 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 365 | 366 | optimal_results = RecordResult(auc=np.inf) 367 | for sub_loss_file in loss_file_list: 368 | dataset, scores, labels = get_scores_labels(sub_loss_file) 369 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 370 | eer = cal_eer(fpr, tpr) 371 | 372 | results = RecordResult(fpr, tpr, eer, dataset, sub_loss_file) 373 | 374 | if optimal_results > results: 375 | optimal_results = results 376 | 377 | if os.path.isdir(loss_file): 378 | print(results) 379 | print('##### optimal result and model = {}'.format(optimal_results)) 380 | return optimal_results 381 | 382 | 383 | def compute_auc(loss_file): 384 | if not os.path.isdir(loss_file): 385 | loss_file_list = [loss_file] 386 | else: 387 | loss_file_list = os.listdir(loss_file) 388 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 389 | 390 | optimal_results = RecordResult() 391 | for sub_loss_file in loss_file_list: 392 | # the name of dataset, loss, and ground truth 393 | dataset, psnr_records, gt = load_psnr_gt(loss_file=sub_loss_file) 394 | 395 | # the number of videos 396 | num_videos = len(psnr_records) 397 | 398 | scores = np.array([], dtype=np.float32) 399 | labels = np.array([], dtype=np.int8) 400 | # video normalization 401 | for i in 
range(num_videos): 402 | distance = psnr_records[i] 403 | 404 | # if NORMALIZE: 405 | # distance -= distance.min() # distances = (distance - min) / (max - min) 406 | # distance /= distance.max() 407 | # distance = 1 - distance 408 | 409 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 410 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 411 | if NORMALIZE: 412 | scores -= scores.min() # scores = (scores - min) / (max - min) 413 | scores /= scores.max() 414 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0) 415 | auc = metrics.auc(fpr, tpr) 416 | 417 | results = RecordResult(fpr, tpr, auc, dataset, sub_loss_file) 418 | 419 | if optimal_results < results: 420 | optimal_results = results 421 | 422 | if os.path.isdir(loss_file): 423 | print(results) 424 | print('##### optimal result and model = {}'.format(optimal_results)) 425 | return optimal_results 426 | 427 | 428 | def average_psnr(loss_file): 429 | if not os.path.isdir(loss_file): 430 | loss_file_list = [loss_file] 431 | else: 432 | loss_file_list = os.listdir(loss_file) 433 | loss_file_list = [os.path.join(loss_file, sub_loss_file) for sub_loss_file in loss_file_list] 434 | 435 | max_avg_psnr = -np.inf 436 | max_file = '' 437 | for file in loss_file_list: 438 | psnr_records = load_psnr(file) 439 | 440 | psnr_records = np.concatenate(psnr_records, axis=0) 441 | avg_psnr = np.mean(psnr_records) 442 | if max_avg_psnr < avg_psnr: 443 | max_avg_psnr = avg_psnr 444 | max_file = file 445 | print('{}, average psnr = {}'.format(file, avg_psnr)) 446 | 447 | print('max average psnr file = {}, psnr = {}'.format(max_file, max_avg_psnr)) 448 | 449 | 450 | def calculate_psnr(loss_file): 451 | optical_result = compute_auc(loss_file) 452 | print('##### optimal result and model = {}'.format(optical_result)) 453 | 454 | mean_psnr = [] 455 | for file in os.listdir(loss_file): 456 | file = os.path.join(loss_file, file) 457 | dataset, psnr_records, gt = load_psnr_gt(file) 458 | 459 | psnr_records = np.concatenate(psnr_records, axis=0) 460 | gt = np.concatenate(gt, axis=0) 461 | 462 | mean_normal_psnr = np.mean(psnr_records[gt == 0]) 463 | mean_abnormal_psnr = np.mean(psnr_records[gt == 1]) 464 | mean = np.mean(psnr_records) 465 | print('mean normal psrn = {}, mean abnormal psrn = {}, mean = {}'.format( 466 | mean_normal_psnr, 467 | mean_abnormal_psnr, 468 | mean) 469 | ) 470 | mean_psnr.append(mean) 471 | print('max mean psnr = {}'.format(np.max(mean_psnr))) 472 | 473 | 474 | def calculate_score(loss_file): 475 | if not os.path.isdir(loss_file): 476 | loss_file_path = loss_file 477 | else: 478 | optical_result = compute_auc(loss_file) 479 | loss_file_path = optical_result.loss_file 480 | print('##### optimal result and model = {}'.format(optical_result)) 481 | dataset, psnr_records, gt = load_psnr_gt(loss_file=loss_file_path) 482 | 483 | # the number of videos 484 | num_videos = len(psnr_records) 485 | 486 | scores = np.array([], dtype=np.float32) 487 | labels = np.array([], dtype=np.int8) 488 | # video normalization 489 | for i in range(num_videos): 490 | distance = psnr_records[i] 491 | 492 | distance = (distance - distance.min()) / (distance.max() - distance.min()) 493 | 494 | scores = np.concatenate((scores, distance[DECIDABLE_IDX:]), axis=0) 495 | labels = np.concatenate((labels, gt[i][DECIDABLE_IDX:]), axis=0) 496 | 497 | mean_normal_scores = np.mean(scores[labels == 0]) 498 | mean_abnormal_scores = np.mean(scores[labels == 1]) 499 | print('mean normal scores = {}, mean abnormal scores = {}, ' 
500 | 'delta = {}'.format(mean_normal_scores, mean_abnormal_scores, mean_normal_scores - mean_abnormal_scores)) 501 | 502 | 503 | def test_func(*args): 504 | # simulate testing on CUHK AVENUE dataset 505 | dataset = GroundTruthLoader.AVENUE 506 | 507 | # load the ground truth 508 | gt_loader = GroundTruthLoader() 509 | gt = gt_loader(dataset=dataset) 510 | 511 | num_videos = len(gt) 512 | 513 | simulated_results = { 514 | 'dataset': dataset, 515 | 'psnr': [] 516 | } 517 | 518 | simulated_psnr = [] 519 | for i in range(num_videos): 520 | sub_video_length = gt[i].shape[0] 521 | simulated_psnr.append(np.random.random(size=sub_video_length)) 522 | 523 | simulated_results['psnr'] = simulated_psnr 524 | 525 | # writing to file, 'generated_loss.bin' 526 | with open('generated_loss.bin', 'wb') as writer: 527 | pickle.dump(simulated_results, writer, pickle.HIGHEST_PROTOCOL) 528 | 529 | print(file_path.name) 530 | result = compute_auc(file_path.name) 531 | 532 | print('optimal = {}'.format(result)) 533 | 534 | 535 | eval_type_function = { 536 | 'compute_auc': compute_auc, 537 | 'compute_eer': compute_eer, 538 | 'precision_recall_auc': precision_recall_auc, 539 | 'calculate_psnr': calculate_psnr, 540 | 'calculate_score': calculate_score, 541 | 'average_psnr': average_psnr, 542 | 'average_psnr_sample': average_psnr 543 | } 544 | 545 | 546 | def evaluate(eval_type, save_file): 547 | assert eval_type in eval_type_function, 'there is no type of evaluation {}, please check {}' \ 548 | .format(eval_type, eval_type_function.keys()) 549 | eval_func = eval_type_function[eval_type] 550 | optimal_results = eval_func(save_file) 551 | return optimal_results 552 | 553 | 554 | if __name__ == '__main__': 555 | args = parser_args() 556 | 557 | eval_type = args.type 558 | file_path = args.file 559 | 560 | print('Evaluate type = {}'.format(eval_type)) 561 | print('File path = {}'.format(file_path)) 562 | 563 | if eval_type == 'test_func': 564 | test_func() 565 | else: 566 | evaluate(eval_type, file_path) 567 | -------------------------------------------------------------------------------- /flownet2/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | from tensorboardX import SummaryWriter 8 | 9 | import argparse, os, sys, subprocess 10 | import setproctitle, colorama 11 | import numpy as np 12 | from tqdm import tqdm 13 | from glob import glob 14 | from os.path import * 15 | 16 | import models, losses, datasets 17 | from utils import flow_utils, tools 18 | 19 | # fp32 copy of parameters for update 20 | global param_copy 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--start_epoch', type=int, default=1) 26 | parser.add_argument('--total_epochs', type=int, default=10000) 27 | parser.add_argument('--batch_size', '-b', type=int, default=8, help="Batch size") 28 | parser.add_argument('--train_n_batches', type=int, default = -1, help='Number of min-batches per epoch. 
If < 0, it will be determined by training_dataloader') 29 | parser.add_argument('--crop_size', type=int, nargs='+', default = [256, 256], help="Spatial dimension to crop training samples for training") 30 | parser.add_argument('--gradient_clip', type=float, default=None) 31 | parser.add_argument('--schedule_lr_frequency', type=int, default=0, help='in number of iterations (0 for no schedule)') 32 | parser.add_argument('--schedule_lr_fraction', type=float, default=10) 33 | parser.add_argument("--rgb_max", type=float, default = 255.) 34 | 35 | parser.add_argument('--number_workers', '-nw', '--num_workers', type=int, default=8) 36 | parser.add_argument('--number_gpus', '-ng', type=int, default=-1, help='number of GPUs to use') 37 | parser.add_argument('--no_cuda', action='store_true') 38 | 39 | parser.add_argument('--seed', type=int, default=1) 40 | parser.add_argument('--name', default='run', type=str, help='a name to append to the save directory') 41 | parser.add_argument('--save', '-s', default='./work', type=str, help='directory for saving') 42 | 43 | parser.add_argument('--validation_frequency', type=int, default=5, help='validate every n epochs') 44 | parser.add_argument('--validation_n_batches', type=int, default=-1) 45 | parser.add_argument('--render_validation', action='store_true', help='run inference (save flows to file) and every validation_frequency epoch') 46 | 47 | parser.add_argument('--inference', action='store_true') 48 | parser.add_argument('--inference_size', type=int, nargs='+', default = [-1,-1], help='spatial size divisible by 64. default (-1,-1) - largest possible valid size would be used') 49 | parser.add_argument('--inference_batch_size', type=int, default=1) 50 | parser.add_argument('--inference_n_batches', type=int, default=-1) 51 | parser.add_argument('--save_flow', action='store_true', help='save predicted flows to file') 52 | 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 54 | parser.add_argument('--log_frequency', '--summ_iter', type=int, default=1, help="Log every n batches") 55 | 56 | parser.add_argument('--skip_training', action='store_true') 57 | parser.add_argument('--skip_validation', action='store_true') 58 | 59 | parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).') 60 | parser.add_argument('--fp16_scale', type=float, default=1024., help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 61 | 62 | tools.add_arguments_for_module(parser, models, argument_for_class='model', default='FlowNet2') 63 | 64 | tools.add_arguments_for_module(parser, losses, argument_for_class='loss', default='L1Loss') 65 | 66 | tools.add_arguments_for_module(parser, torch.optim, argument_for_class='optimizer', default='Adam', skip_params=['params']) 67 | 68 | tools.add_arguments_for_module(parser, datasets, argument_for_class='training_dataset', default='MpiSintelFinal', 69 | skip_params=['is_cropped'], 70 | parameter_defaults={'root': './MPI-Sintel/flow/training'}) 71 | 72 | tools.add_arguments_for_module(parser, datasets, argument_for_class='validation_dataset', default='MpiSintelClean', 73 | skip_params=['is_cropped'], 74 | parameter_defaults={'root': './MPI-Sintel/flow/training', 75 | 'replicates': 1}) 76 | 77 | tools.add_arguments_for_module(parser, datasets, argument_for_class='inference_dataset', default='MpiSintelClean', 78 | skip_params=['is_cropped'], 79 | parameter_defaults={'root': 
'./MPI-Sintel/flow/training', 80 | 'replicates': 1}) 81 | 82 | main_dir = os.path.dirname(os.path.realpath(__file__)) 83 | os.chdir(main_dir) 84 | 85 | # Parse the official arguments 86 | with tools.TimerBlock("Parsing Arguments") as block: 87 | args = parser.parse_args() 88 | if args.number_gpus < 0 : args.number_gpus = torch.cuda.device_count() 89 | 90 | # Get argument defaults (hastag #thisisahack) 91 | parser.add_argument('--IGNORE', action='store_true') 92 | defaults = vars(parser.parse_args(['--IGNORE'])) 93 | 94 | # Print all arguments, color the non-defaults 95 | for argument, value in sorted(vars(args).items()): 96 | reset = colorama.Style.RESET_ALL 97 | color = reset if value == defaults[argument] else colorama.Fore.MAGENTA 98 | block.log('{}{}: {}{}'.format(color, argument, value, reset)) 99 | 100 | args.model_class = tools.module_to_dict(models)[args.model] 101 | args.optimizer_class = tools.module_to_dict(torch.optim)[args.optimizer] 102 | args.loss_class = tools.module_to_dict(losses)[args.loss] 103 | 104 | args.training_dataset_class = tools.module_to_dict(datasets)[args.training_dataset] 105 | args.validation_dataset_class = tools.module_to_dict(datasets)[args.validation_dataset] 106 | args.inference_dataset_class = tools.module_to_dict(datasets)[args.inference_dataset] 107 | 108 | args.cuda = not args.no_cuda and torch.cuda.is_available() 109 | args.current_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).rstrip() 110 | args.log_file = join(args.save, 'args.txt') 111 | 112 | # dict to collect activation gradients (for training debug purpose) 113 | args.grads = {} 114 | 115 | if args.inference: 116 | args.skip_validation = True 117 | args.skip_training = True 118 | args.total_epochs = 1 119 | args.inference_dir = "{}/inference".format(args.save) 120 | 121 | print('Source Code') 122 | print((' Current Git Hash: {}\n'.format(args.current_hash))) 123 | 124 | # Change the title for `top` and `pkill` commands 125 | setproctitle.setproctitle(args.save) 126 | 127 | # Dynamically load the dataset class with parameters passed in via "--argument_[param]=[value]" arguments 128 | with tools.TimerBlock("Initializing Datasets") as block: 129 | args.effective_batch_size = args.batch_size * args.number_gpus 130 | args.effective_inference_batch_size = args.inference_batch_size * args.number_gpus 131 | args.effective_number_workers = args.number_workers * args.number_gpus 132 | gpuargs = {'num_workers': args.effective_number_workers, 133 | 'pin_memory': True, 134 | 'drop_last' : True} if args.cuda else {} 135 | inf_gpuargs = gpuargs.copy() 136 | inf_gpuargs['num_workers'] = args.number_workers 137 | 138 | if exists(args.training_dataset_root): 139 | train_dataset = args.training_dataset_class(args, True, **tools.kwargs_from_args(args, 'training_dataset')) 140 | block.log('Training Dataset: {}'.format(args.training_dataset)) 141 | block.log('Training Input: {}'.format(' '.join([str([d for d in x.size()]) for x in train_dataset[0][0]]))) 142 | block.log('Training Targets: {}'.format(' '.join([str([d for d in x.size()]) for x in train_dataset[0][1]]))) 143 | train_loader = DataLoader(train_dataset, batch_size=args.effective_batch_size, shuffle=True, **gpuargs) 144 | 145 | if exists(args.validation_dataset_root): 146 | validation_dataset = args.validation_dataset_class(args, True, **tools.kwargs_from_args(args, 'validation_dataset')) 147 | block.log('Validation Dataset: {}'.format(args.validation_dataset)) 148 | block.log('Validation Input: {}'.format(' '.join([str([d for d in 
x.size()]) for x in validation_dataset[0][0]]))) 149 | block.log('Validation Targets: {}'.format(' '.join([str([d for d in x.size()]) for x in validation_dataset[0][1]]))) 150 | validation_loader = DataLoader(validation_dataset, batch_size=args.effective_batch_size, shuffle=False, **gpuargs) 151 | 152 | if exists(args.inference_dataset_root): 153 | inference_dataset = args.inference_dataset_class(args, False, **tools.kwargs_from_args(args, 'inference_dataset')) 154 | block.log('Inference Dataset: {}'.format(args.inference_dataset)) 155 | block.log('Inference Input: {}'.format(' '.join([str([d for d in x.size()]) for x in inference_dataset[0][0]]))) 156 | block.log('Inference Targets: {}'.format(' '.join([str([d for d in x.size()]) for x in inference_dataset[0][1]]))) 157 | inference_loader = DataLoader(inference_dataset, batch_size=args.effective_inference_batch_size, shuffle=False, **inf_gpuargs) 158 | 159 | # Dynamically load model and loss class with parameters passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments 160 | with tools.TimerBlock("Building {} model".format(args.model)) as block: 161 | class ModelAndLoss(nn.Module): 162 | def __init__(self, args): 163 | super(ModelAndLoss, self).__init__() 164 | kwargs = tools.kwargs_from_args(args, 'model') 165 | self.model = args.model_class(args, **kwargs) 166 | kwargs = tools.kwargs_from_args(args, 'loss') 167 | self.loss = args.loss_class(args, **kwargs) 168 | 169 | def forward(self, data, target, inference=False ): 170 | output = self.model(data) 171 | 172 | loss_values = self.loss(output, target) 173 | 174 | if not inference : 175 | return loss_values 176 | else : 177 | return loss_values, output 178 | 179 | model_and_loss = ModelAndLoss(args) 180 | 181 | block.log('Effective Batch Size: {}'.format(args.effective_batch_size)) 182 | block.log('Number of parameters: {}'.format(sum([p.data.nelement() if p.requires_grad else 0 for p in model_and_loss.parameters()]))) 183 | 184 | # assing to cuda or wrap with dataparallel, model and loss 185 | if args.cuda and (args.number_gpus > 0) and args.fp16: 186 | block.log('Parallelizing') 187 | model_and_loss = nn.parallel.DataParallel(model_and_loss, device_ids=list(range(args.number_gpus))) 188 | 189 | block.log('Initializing CUDA') 190 | model_and_loss = model_and_loss.cuda().half() 191 | torch.cuda.manual_seed(args.seed) 192 | param_copy = [param.clone().type(torch.cuda.FloatTensor).detach() for param in model_and_loss.parameters()] 193 | 194 | elif args.cuda and args.number_gpus > 0: 195 | block.log('Initializing CUDA') 196 | model_and_loss = model_and_loss.cuda() 197 | block.log('Parallelizing') 198 | model_and_loss = nn.parallel.DataParallel(model_and_loss, device_ids=list(range(args.number_gpus))) 199 | torch.cuda.manual_seed(args.seed) 200 | 201 | else: 202 | block.log('CUDA not being used') 203 | torch.manual_seed(args.seed) 204 | 205 | # Load weights if needed, otherwise randomly initialize 206 | if args.resume and os.path.isfile(args.resume): 207 | block.log("Loading checkpoint '{}'".format(args.resume)) 208 | checkpoint = torch.load(args.resume) 209 | if not args.inference: 210 | args.start_epoch = checkpoint['epoch'] 211 | best_err = checkpoint['best_EPE'] 212 | model_and_loss.module.model.load_state_dict(checkpoint['state_dict']) 213 | block.log("Loaded checkpoint '{}' (at epoch {})".format(args.resume, checkpoint['epoch'])) 214 | 215 | elif args.resume and args.inference: 216 | block.log("No checkpoint found at '{}'".format(args.resume)) 217 | quit() 218 | 219 
| else: 220 | block.log("Random initialization") 221 | 222 | block.log("Initializing save directory: {}".format(args.save)) 223 | if not os.path.exists(args.save): 224 | os.makedirs(args.save) 225 | 226 | train_logger = SummaryWriter(log_dir = os.path.join(args.save, 'train'), comment = 'training') 227 | validation_logger = SummaryWriter(log_dir = os.path.join(args.save, 'validation'), comment = 'validation') 228 | 229 | # Dynamically load the optimizer with parameters passed in via "--optimizer_[param]=[value]" arguments 230 | with tools.TimerBlock("Initializing {} Optimizer".format(args.optimizer)) as block: 231 | kwargs = tools.kwargs_from_args(args, 'optimizer') 232 | if args.fp16: 233 | optimizer = args.optimizer_class([p for p in param_copy if p.requires_grad], **kwargs) 234 | else: 235 | optimizer = args.optimizer_class([p for p in model_and_loss.parameters() if p.requires_grad], **kwargs) 236 | for param, default in list(kwargs.items()): 237 | block.log("{} = {} ({})".format(param, default, type(default))) 238 | 239 | # Log all arguments to file 240 | for argument, value in sorted(vars(args).items()): 241 | block.log2file(args.log_file, '{}: {}'.format(argument, value)) 242 | 243 | # Reusable functions for training and validataion 244 | def train(args, epoch, start_iteration, data_loader, model, optimizer, logger, is_validate=False, offset=0): 245 | statistics = [] 246 | total_loss = 0 247 | 248 | if is_validate: 249 | model.eval() 250 | title = 'Validating Epoch {}'.format(epoch) 251 | args.validation_n_batches = np.inf if args.validation_n_batches < 0 else args.validation_n_batches 252 | progress = tqdm(tools.IteratorTimer(data_loader), ncols=100, total=np.minimum(len(data_loader), args.validation_n_batches), leave=True, position=offset, desc=title) 253 | else: 254 | model.train() 255 | title = 'Training Epoch {}'.format(epoch) 256 | args.train_n_batches = np.inf if args.train_n_batches < 0 else args.train_n_batches 257 | progress = tqdm(tools.IteratorTimer(data_loader), ncols=120, total=np.minimum(len(data_loader), args.train_n_batches), smoothing=.9, miniters=1, leave=True, position=offset, desc=title) 258 | 259 | last_log_time = progress._time() 260 | for batch_idx, (data, target) in enumerate(progress): 261 | 262 | data, target = [Variable(d) for d in data], [Variable(t) for t in target] 263 | if args.cuda and args.number_gpus == 1: 264 | data, target = [d.cuda(async=True) for d in data], [t.cuda(async=True) for t in target] 265 | 266 | optimizer.zero_grad() if not is_validate else None 267 | losses = model(data[0], target[0]) 268 | losses = [torch.mean(loss_value) for loss_value in losses] 269 | loss_val = losses[0] # Collect first loss for weight update 270 | total_loss += loss_val.data[0] 271 | loss_values = [v.data[0] for v in losses] 272 | 273 | # gather loss_labels, direct return leads to recursion limit error as it looks for variables to gather' 274 | loss_labels = list(model.module.loss.loss_labels) 275 | 276 | assert not np.isnan(total_loss) 277 | 278 | if not is_validate and args.fp16: 279 | loss_val.backward() 280 | if args.gradient_clip: 281 | torch.nn.utils.clip_grad_norm(model.parameters(), args.gradient_clip) 282 | 283 | params = list(model.parameters()) 284 | for i in range(len(params)): 285 | param_copy[i].grad = params[i].grad.clone().type_as(params[i]).detach() 286 | param_copy[i].grad.mul_(1./args.loss_scale) 287 | optimizer.step() 288 | for i in range(len(params)): 289 | params[i].data.copy_(param_copy[i].data) 290 | 291 | elif not is_validate: 292 | 
loss_val.backward() 293 | if args.gradient_clip: 294 | torch.nn.utils.clip_grad_norm(model.parameters(), args.gradient_clip) 295 | optimizer.step() 296 | 297 | # Update hyperparameters if needed 298 | global_iteration = start_iteration + batch_idx 299 | if not is_validate: 300 | tools.update_hyperparameter_schedule(args, epoch, global_iteration, optimizer) 301 | loss_labels.append('lr') 302 | loss_values.append(optimizer.param_groups[0]['lr']) 303 | 304 | loss_labels.append('load') 305 | loss_values.append(progress.iterable.last_duration) 306 | 307 | # Print out statistics 308 | statistics.append(loss_values) 309 | title = '{} Epoch {}'.format('Validating' if is_validate else 'Training', epoch) 310 | 311 | progress.set_description(title + ' ' + tools.format_dictionary_of_losses(loss_labels, statistics[-1])) 312 | 313 | if ((((global_iteration + 1) % args.log_frequency) == 0 and not is_validate) or 314 | (is_validate and batch_idx == args.validation_n_batches - 1)): 315 | 316 | global_iteration = global_iteration if not is_validate else start_iteration 317 | 318 | logger.add_scalar('batch logs per second', len(statistics) / (progress._time() - last_log_time), global_iteration) 319 | last_log_time = progress._time() 320 | 321 | all_losses = np.array(statistics) 322 | 323 | for i, key in enumerate(loss_labels): 324 | logger.add_scalar('average batch ' + str(key), all_losses[:, i].mean(), global_iteration) 325 | logger.add_histogram(str(key), all_losses[:, i], global_iteration) 326 | 327 | # Reset Summary 328 | statistics = [] 329 | 330 | if ( is_validate and ( batch_idx == args.validation_n_batches) ): 331 | break 332 | 333 | if ( (not is_validate) and (batch_idx == (args.train_n_batches)) ): 334 | break 335 | 336 | progress.close() 337 | 338 | return total_loss / float(batch_idx + 1), (batch_idx + 1) 339 | 340 | # Reusable functions for inference 341 | def inference(args, epoch, data_loader, model, offset=0): 342 | 343 | model.eval() 344 | 345 | if args.save_flow or args.render_validation: 346 | flow_folder = "{}/inference/{}.epoch-{}-flow-field".format(args.save,args.name.replace('/', '.'),epoch) 347 | if not os.path.exists(flow_folder): 348 | os.makedirs(flow_folder) 349 | 350 | 351 | args.inference_n_batches = np.inf if args.inference_n_batches < 0 else args.inference_n_batches 352 | 353 | progress = tqdm(data_loader, ncols=100, total=np.minimum(len(data_loader), args.inference_n_batches), desc='Inferencing ', 354 | leave=True, position=offset) 355 | 356 | statistics = [] 357 | total_loss = 0 358 | for batch_idx, (data, target) in enumerate(progress): 359 | if args.cuda: 360 | data, target = [d.cuda(async=True) for d in data], [t.cuda(async=True) for t in target] 361 | data, target = [Variable(d) for d in data], [Variable(t) for t in target] 362 | 363 | # when ground-truth flows are not available for inference_dataset, 364 | # the targets are set to all zeros. 
thus, losses are actually L1 or L2 norms of compute optical flows, 365 | # depending on the type of loss norm passed in 366 | with torch.no_grad(): 367 | losses, output = model(data[0], target[0], inference=True) 368 | 369 | losses = [torch.mean(loss_value) for loss_value in losses] 370 | loss_val = losses[0] # Collect first loss for weight update 371 | total_loss += loss_val.data[0] 372 | loss_values = [v.data[0] for v in losses] 373 | 374 | # gather loss_labels, direct return leads to recursion limit error as it looks for variables to gather' 375 | loss_labels = list(model.module.loss.loss_labels) 376 | 377 | statistics.append(loss_values) 378 | # import IPython; IPython.embed() 379 | if args.save_flow or args.render_validation: 380 | for i in range(args.inference_batch_size): 381 | _pflow = output[i].data.cpu().numpy().transpose(1, 2, 0) 382 | flow_utils.writeFlow( join(flow_folder, '%06d.flo'%(batch_idx * args.inference_batch_size + i)), _pflow) 383 | 384 | progress.set_description('Inference Averages for Epoch {}: '.format(epoch) + tools.format_dictionary_of_losses(loss_labels, np.array(statistics).mean(axis=0))) 385 | progress.update(1) 386 | 387 | if batch_idx == (args.inference_n_batches - 1): 388 | break 389 | 390 | progress.close() 391 | 392 | return 393 | 394 | # Primary epoch loop 395 | best_err = 1e8 396 | progress = tqdm(list(range(args.start_epoch, args.total_epochs + 1)), miniters=1, ncols=100, desc='Overall Progress', leave=True, position=0) 397 | offset = 1 398 | last_epoch_time = progress._time() 399 | global_iteration = 0 400 | 401 | for epoch in progress: 402 | if args.inference or (args.render_validation and ((epoch - 1) % args.validation_frequency) == 0): 403 | stats = inference(args=args, epoch=epoch - 1, data_loader=inference_loader, model=model_and_loss, offset=offset) 404 | offset += 1 405 | 406 | if not args.skip_validation and ((epoch - 1) % args.validation_frequency) == 0: 407 | validation_loss, _ = train(args=args, epoch=epoch - 1, start_iteration=global_iteration, data_loader=validation_loader, model=model_and_loss, optimizer=optimizer, logger=validation_logger, is_validate=True, offset=offset) 408 | offset += 1 409 | 410 | is_best = False 411 | if validation_loss < best_err: 412 | best_err = validation_loss 413 | is_best = True 414 | 415 | checkpoint_progress = tqdm(ncols=100, desc='Saving Checkpoint', position=offset) 416 | tools.save_checkpoint({ 'arch' : args.model, 417 | 'epoch': epoch, 418 | 'state_dict': model_and_loss.module.model.state_dict(), 419 | 'best_EPE': best_err}, 420 | is_best, args.save, args.model) 421 | checkpoint_progress.update(1) 422 | checkpoint_progress.close() 423 | offset += 1 424 | 425 | if not args.skip_training: 426 | train_loss, iterations = train(args=args, epoch=epoch, start_iteration=global_iteration, data_loader=train_loader, model=model_and_loss, optimizer=optimizer, logger=train_logger, offset=offset) 427 | global_iteration += iterations 428 | offset += 1 429 | 430 | # save checkpoint after every validation_frequency number of epochs 431 | if ((epoch - 1) % args.validation_frequency) == 0: 432 | checkpoint_progress = tqdm(ncols=100, desc='Saving Checkpoint', position=offset) 433 | tools.save_checkpoint({ 'arch' : args.model, 434 | 'epoch': epoch, 435 | 'state_dict': model_and_loss.module.model.state_dict(), 436 | 'best_EPE': train_loss}, 437 | False, args.save, args.model, filename = 'train-checkpoint.pth.tar') 438 | checkpoint_progress.update(1) 439 | checkpoint_progress.close() 440 | 441 | 442 | 
train_logger.add_scalar('seconds per epoch', progress._time() - last_epoch_time, epoch) 443 | last_epoch_time = progress._time() 444 | print("\n") 445 | -------------------------------------------------------------------------------- /models/pix2pix_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | 8 | ############################################################################### 9 | # Helper Functions 10 | ############################################################################### 11 | def get_norm_layer(norm_type='instance'): 12 | """Return a normalization layer 13 | 14 | Parameters: 15 | norm_type (str) -- the name of the normalization layer: batch | instance | none 16 | 17 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 18 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 19 | """ 20 | if norm_type == 'batch': 21 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 22 | elif norm_type == 'instance': 23 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 24 | elif norm_type == 'none': 25 | norm_layer = None 26 | else: 27 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 28 | return norm_layer 29 | 30 | 31 | def get_scheduler(optimizer, opt): 32 | """Return a learning rate scheduler 33 | 34 | Parameters: 35 | optimizer -- the optimizer of the network 36 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  37 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 38 | 39 | For 'linear', we keep the same learning rate for the first epochs 40 | and linearly decay the rate to zero over the next epochs. 41 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 42 | See https://pytorch.org/docs/stable/optim.html for more details. 43 | """ 44 | if opt.lr_policy == 'linear': 45 | def lambda_rule(epoch): 46 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 47 | return lr_l 48 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 49 | elif opt.lr_policy == 'step': 50 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 51 | elif opt.lr_policy == 'plateau': 52 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 53 | elif opt.lr_policy == 'cosine': 54 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 55 | else: 56 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 57 | return scheduler 58 | 59 | 60 | def init_weights(net, init_type='normal', init_gain=0.02): 61 | """Initialize network weights. 62 | 63 | Parameters: 64 | net (network) -- network to be initialized 65 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 66 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 67 | 68 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 69 | work better for some applications. Feel free to try yourself. 
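
    Example (a minimal sketch; the tiny network below is only for illustration):
        net = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64))
        init_weights(net, init_type='kaiming', init_gain=0.02)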
70 | """ 71 | def init_func(m): # define the initialization function 72 | classname = m.__class__.__name__ 73 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 74 | if init_type == 'normal': 75 | init.normal_(m.weight.data, 0.0, init_gain) 76 | elif init_type == 'xavier': 77 | init.xavier_normal_(m.weight.data, gain=init_gain) 78 | elif init_type == 'kaiming': 79 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 80 | elif init_type == 'orthogonal': 81 | init.orthogonal_(m.weight.data, gain=init_gain) 82 | else: 83 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 84 | if hasattr(m, 'bias') and m.bias is not None: 85 | init.constant_(m.bias.data, 0.0) 86 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 87 | init.normal_(m.weight.data, 1.0, init_gain) 88 | init.constant_(m.bias.data, 0.0) 89 | 90 | print('initialize network with %s' % init_type) 91 | net.apply(init_func) # apply the initialization function 92 | 93 | 94 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 95 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 96 | Parameters: 97 | net (network) -- the network to be initialized 98 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 99 | gain (float) -- scaling factor for normal, xavier and orthogonal. 100 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 101 | 102 | Return an initialized network. 103 | """ 104 | if len(gpu_ids) > 0: 105 | assert(torch.cuda.is_available()) 106 | net.to(gpu_ids[0]) 107 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 108 | init_weights(net, init_type, init_gain=init_gain) 109 | return net 110 | 111 | 112 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 113 | """Create a generator 114 | 115 | Parameters: 116 | input_nc (int) -- the number of channels in input images 117 | output_nc (int) -- the number of channels in output images 118 | ngf (int) -- the number of filters in the last conv layer 119 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 120 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 121 | use_dropout (bool) -- if use dropout layers. 122 | init_type (str) -- the name of our initialization method. 123 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 124 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 125 | 126 | Returns a generator 127 | 128 | Our current implementation provides two types of generators: 129 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 130 | The original U-Net paper: https://arxiv.org/abs/1505.04597 131 | 132 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 133 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 134 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 135 | 136 | 137 | The generator has been initialized by . It uses RELU for non-linearity. 
138 | """ 139 | net = None 140 | norm_layer = get_norm_layer(norm_type=norm) 141 | 142 | if netG == 'resnet_9blocks': 143 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 144 | elif netG == 'resnet_6blocks': 145 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 146 | elif netG == 'unet_128': 147 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 148 | elif netG == 'unet_256': 149 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 150 | else: 151 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 152 | return init_net(net, init_type, init_gain, gpu_ids) 153 | 154 | 155 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 156 | """Create a discriminator 157 | 158 | Parameters: 159 | input_nc (int) -- the number of channels in input images 160 | ndf (int) -- the number of filters in the first conv layer 161 | netD (str) -- the architecture's name: basic | n_layers | pixel 162 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 163 | norm (str) -- the type of normalization layers used in the network. 164 | init_type (str) -- the name of the initialization method. 165 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 166 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 167 | 168 | Returns a discriminator 169 | 170 | Our current implementation provides three types of discriminators: 171 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 172 | It can classify whether 70×70 overlapping patches are real or fake. 173 | Such a patch-level discriminator architecture has fewer parameters 174 | than a full-image discriminator and can work on arbitrarily-sized images 175 | in a fully convolutional fashion. 176 | 177 | [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator 178 | with the parameter (default=3 as used in [basic] (PatchGAN).) 179 | 180 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 181 | It encourages greater color diversity but has no effect on spatial statistics. 182 | 183 | The discriminator has been initialized by . It uses Leakly RELU for non-linearity. 184 | """ 185 | net = None 186 | norm_layer = get_norm_layer(norm_type=norm) 187 | 188 | if netD == 'basic': # default PatchGAN classifier 189 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 190 | elif netD == 'n_layers': # more options 191 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 192 | elif netD == 'pixel': # classify if each pixel is real or fake 193 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 194 | else: 195 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) 196 | return init_net(net, init_type, init_gain, gpu_ids) 197 | 198 | 199 | ############################################################################## 200 | # Classes 201 | ############################################################################## 202 | class GANLoss(nn.Module): 203 | """Define different GAN objectives. 204 | 205 | The GANLoss class abstracts away the need to create the target label tensor 206 | that has the same size as the input. 
207 | """ 208 | 209 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 210 | """ Initialize the GANLoss class. 211 | 212 | Parameters: 213 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 214 | target_real_label (bool) - - label for a real image 215 | target_fake_label (bool) - - label of a fake image 216 | 217 | Note: Do not use sigmoid as the last layer of Discriminator. 218 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 219 | """ 220 | super(GANLoss, self).__init__() 221 | self.register_buffer('real_label', torch.tensor(target_real_label)) 222 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 223 | self.gan_mode = gan_mode 224 | if gan_mode == 'lsgan': 225 | self.loss = nn.MSELoss() 226 | elif gan_mode == 'vanilla': 227 | self.loss = nn.BCEWithLogitsLoss() 228 | elif gan_mode in ['wgangp']: 229 | self.loss = None 230 | else: 231 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 232 | 233 | def get_target_tensor(self, prediction, target_is_real): 234 | """Create label tensors with the same size as the input. 235 | 236 | Parameters: 237 | prediction (tensor) - - tpyically the prediction from a discriminator 238 | target_is_real (bool) - - if the ground truth label is for real images or fake images 239 | 240 | Returns: 241 | A label tensor filled with ground truth label, and with the size of the input 242 | """ 243 | 244 | if target_is_real: 245 | target_tensor = self.real_label 246 | else: 247 | target_tensor = self.fake_label 248 | return target_tensor.expand_as(prediction) 249 | 250 | def __call__(self, prediction, target_is_real): 251 | """Calculate loss given Discriminator's output and grount truth labels. 252 | 253 | Parameters: 254 | prediction (tensor) - - tpyically the prediction output from a discriminator 255 | target_is_real (bool) - - if the ground truth label is for real images or fake images 256 | 257 | Returns: 258 | the calculated loss. 259 | """ 260 | if self.gan_mode in ['lsgan', 'vanilla']: 261 | target_tensor = self.get_target_tensor(prediction, target_is_real) 262 | loss = self.loss(prediction, target_tensor) 263 | elif self.gan_mode == 'wgangp': 264 | if target_is_real: 265 | loss = -prediction.mean() 266 | else: 267 | loss = prediction.mean() 268 | return loss 269 | 270 | 271 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 272 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 273 | 274 | Arguments: 275 | netD (network) -- discriminator network 276 | real_data (tensor array) -- real images 277 | fake_data (tensor array) -- generated images from the generator 278 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 279 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 280 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 281 | lambda_gp (float) -- weight for this loss 282 | 283 | Returns the gradient penalty loss 284 | """ 285 | if lambda_gp > 0.0: 286 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
287 | interpolatesv = real_data 288 | elif type == 'fake': 289 | interpolatesv = fake_data 290 | elif type == 'mixed': 291 | alpha = torch.rand(real_data.shape[0], 1) 292 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 293 | alpha = alpha.to(device) 294 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 295 | else: 296 | raise NotImplementedError('{} not implemented'.format(type)) 297 | interpolatesv.requires_grad_(True) 298 | disc_interpolates = netD(interpolatesv) 299 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 300 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 301 | create_graph=True, retain_graph=True, only_inputs=True) 302 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 303 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 304 | return gradient_penalty, gradients 305 | else: 306 | return 0.0, None 307 | 308 | 309 | class ResnetGenerator(nn.Module): 310 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 311 | 312 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 313 | """ 314 | 315 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 316 | """Construct a Resnet-based generator 317 | 318 | Parameters: 319 | input_nc (int) -- the number of channels in input images 320 | output_nc (int) -- the number of channels in output images 321 | ngf (int) -- the number of filters in the last conv layer 322 | norm_layer -- normalization layer 323 | use_dropout (bool) -- if use dropout layers 324 | n_blocks (int) -- the number of ResNet blocks 325 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 326 | """ 327 | assert(n_blocks >= 0) 328 | super(ResnetGenerator, self).__init__() 329 | if type(norm_layer) == functools.partial: 330 | use_bias = norm_layer.func == nn.InstanceNorm2d 331 | else: 332 | use_bias = norm_layer == nn.InstanceNorm2d 333 | 334 | model = [nn.ReflectionPad2d(3), 335 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 336 | norm_layer(ngf), 337 | nn.ReLU(True)] 338 | 339 | n_downsampling = 2 340 | for i in range(n_downsampling): # add downsampling layers 341 | mult = 2 ** i 342 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 343 | norm_layer(ngf * mult * 2), 344 | nn.ReLU(True)] 345 | 346 | mult = 2 ** n_downsampling 347 | for i in range(n_blocks): # add ResNet blocks 348 | 349 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 350 | 351 | for i in range(n_downsampling): # add upsampling layers 352 | mult = 2 ** (n_downsampling - i) 353 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 354 | kernel_size=3, stride=2, 355 | padding=1, output_padding=1, 356 | bias=use_bias), 357 | norm_layer(int(ngf * mult / 2)), 358 | nn.ReLU(True)] 359 | model += [nn.ReflectionPad2d(3)] 360 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 361 | model += [nn.Tanh()] 362 | 363 | self.model = nn.Sequential(*model) 364 | 365 | def forward(self, input): 366 | """Standard forward""" 367 | return self.model(input) 368 | 369 | 370 | class 
ResnetBlock(nn.Module): 371 | """Define a Resnet block""" 372 | 373 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 374 | """Initialize the Resnet block 375 | 376 | A resnet block is a conv block with skip connections 377 | We construct a conv block with build_conv_block function, 378 | and implement skip connections in function. 379 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 380 | """ 381 | super(ResnetBlock, self).__init__() 382 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 383 | 384 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 385 | """Construct a convolutional block. 386 | 387 | Parameters: 388 | dim (int) -- the number of channels in the conv layer. 389 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 390 | norm_layer -- normalization layer 391 | use_dropout (bool) -- if use dropout layers. 392 | use_bias (bool) -- if the conv layer uses bias or not 393 | 394 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 395 | """ 396 | conv_block = [] 397 | p = 0 398 | if padding_type == 'reflect': 399 | conv_block += [nn.ReflectionPad2d(1)] 400 | elif padding_type == 'replicate': 401 | conv_block += [nn.ReplicationPad2d(1)] 402 | elif padding_type == 'zero': 403 | p = 1 404 | else: 405 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 406 | 407 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 408 | if use_dropout: 409 | conv_block += [nn.Dropout(0.5)] 410 | 411 | p = 0 412 | if padding_type == 'reflect': 413 | conv_block += [nn.ReflectionPad2d(1)] 414 | elif padding_type == 'replicate': 415 | conv_block += [nn.ReplicationPad2d(1)] 416 | elif padding_type == 'zero': 417 | p = 1 418 | else: 419 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 420 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 421 | 422 | return nn.Sequential(*conv_block) 423 | 424 | def forward(self, x): 425 | """Forward function (with skip connections)""" 426 | out = x + self.conv_block(x) # add skip connections 427 | return out 428 | 429 | 430 | class UnetGenerator(nn.Module): 431 | """Create a Unet-based generator""" 432 | 433 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 434 | """Construct a Unet generator 435 | Parameters: 436 | input_nc (int) -- the number of channels in input images 437 | output_nc (int) -- the number of channels in output images 438 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 439 | image of size 128x128 will become of size 1x1 # at the bottleneck 440 | ngf (int) -- the number of filters in the last conv layer 441 | norm_layer -- normalization layer 442 | 443 | We construct the U-Net from the innermost layer to the outermost layer. 444 | It is a recursive process. 
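
        Example (a minimal sketch, mirroring the 'unet_256' configuration used by define_G):
            net = UnetGenerator(input_nc=3, output_nc=3, num_downs=8, ngf=64)
            out = net(torch.randn(1, 3, 256, 256))   # out also has shape (1, 3, 256, 256)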
445 | """ 446 | super(UnetGenerator, self).__init__() 447 | # construct unet structure 448 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 449 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 450 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 451 | # gradually reduce the number of filters from ngf * 8 to ngf 452 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 453 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 454 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 455 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 456 | 457 | def forward(self, input): 458 | """Standard forward""" 459 | return self.model(input) 460 | 461 | 462 | class UnetSkipConnectionBlock(nn.Module): 463 | """Defines the Unet submodule with skip connection. 464 | X -------------------identity---------------------- 465 | |-- downsampling -- |submodule| -- upsampling --| 466 | """ 467 | 468 | def __init__(self, outer_nc, inner_nc, input_nc=None, 469 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 470 | """Construct a Unet submodule with skip connections. 471 | 472 | Parameters: 473 | outer_nc (int) -- the number of filters in the outer conv layer 474 | inner_nc (int) -- the number of filters in the inner conv layer 475 | input_nc (int) -- the number of channels in input images/features 476 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 477 | outermost (bool) -- if this module is the outermost module 478 | innermost (bool) -- if this module is the innermost module 479 | norm_layer -- normalization layer 480 | user_dropout (bool) -- if use dropout layers. 
481 | """ 482 | super(UnetSkipConnectionBlock, self).__init__() 483 | self.outermost = outermost 484 | if type(norm_layer) == functools.partial: 485 | use_bias = norm_layer.func == nn.InstanceNorm2d 486 | else: 487 | use_bias = norm_layer == nn.InstanceNorm2d 488 | if input_nc is None: 489 | input_nc = outer_nc 490 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 491 | stride=2, padding=1, bias=use_bias) 492 | downrelu = nn.LeakyReLU(0.2, True) 493 | downnorm = norm_layer(inner_nc) 494 | uprelu = nn.ReLU(True) 495 | upnorm = norm_layer(outer_nc) 496 | 497 | if outermost: 498 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 499 | kernel_size=4, stride=2, 500 | padding=1) 501 | down = [downconv] 502 | up = [uprelu, upconv, nn.Tanh()] 503 | model = down + [submodule] + up 504 | elif innermost: 505 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 506 | kernel_size=4, stride=2, 507 | padding=1, bias=use_bias) 508 | down = [downrelu, downconv] 509 | up = [uprelu, upconv, upnorm] 510 | model = down + up 511 | else: 512 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 513 | kernel_size=4, stride=2, 514 | padding=1, bias=use_bias) 515 | down = [downrelu, downconv, downnorm] 516 | up = [uprelu, upconv, upnorm] 517 | 518 | if use_dropout: 519 | model = down + [submodule] + up + [nn.Dropout(0.5)] 520 | else: 521 | model = down + [submodule] + up 522 | 523 | self.model = nn.Sequential(*model) 524 | 525 | def forward(self, x): 526 | if self.outermost: 527 | return self.model(x) 528 | else: # add skip connections 529 | return torch.cat([x, self.model(x)], 1) 530 | 531 | 532 | class NLayerDiscriminator(nn.Module): 533 | """Defines a PatchGAN discriminator""" 534 | 535 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 536 | """Construct a PatchGAN discriminator 537 | 538 | Parameters: 539 | input_nc (int) -- the number of channels in input images 540 | ndf (int) -- the number of filters in the last conv layer 541 | n_layers (int) -- the number of conv layers in the discriminator 542 | norm_layer -- normalization layer 543 | """ 544 | super(NLayerDiscriminator, self).__init__() 545 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 546 | use_bias = norm_layer.func != nn.BatchNorm2d 547 | else: 548 | use_bias = norm_layer != nn.BatchNorm2d 549 | 550 | kw = 4 551 | padw = 1 552 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 553 | nf_mult = 1 554 | nf_mult_prev = 1 555 | for n in range(1, n_layers): # gradually increase the number of filters 556 | nf_mult_prev = nf_mult 557 | nf_mult = min(2 ** n, 8) 558 | sequence += [ 559 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 560 | norm_layer(ndf * nf_mult), 561 | nn.LeakyReLU(0.2, True) 562 | ] 563 | 564 | nf_mult_prev = nf_mult 565 | nf_mult = min(2 ** n_layers, 8) 566 | sequence += [ 567 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 568 | norm_layer(ndf * nf_mult), 569 | nn.LeakyReLU(0.2, True) 570 | ] 571 | 572 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 573 | self.model = nn.Sequential(*sequence) 574 | 575 | def forward(self, input): 576 | """Standard forward.""" 577 | return self.model(input) 578 | 579 | # modified as anopred said 580 | class PixelDiscriminator(nn.Module): 581 | """Defines a 1x1 PatchGAN discriminator 
(pixelGAN)""" 582 | 583 | def __init__(self, input_nc, num_filters, use_norm=False,norm_layer=nn.BatchNorm2d): 584 | """Construct a 1x1 PatchGAN discriminator 585 | 586 | Parameters: 587 | input_nc (int) -- the number of channels in input images 588 | ndf (int) -- the number of filters in the last conv layer 589 | norm_layer -- normalization layer 590 | """ 591 | ''' 592 | different from ano_pred with norm here 593 | ''' 594 | 595 | 596 | super(PixelDiscriminator, self).__init__() 597 | if use_norm: 598 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 599 | use_bias = norm_layer.func != nn.InstanceNorm2d 600 | else: 601 | use_bias = norm_layer != nn.InstanceNorm2d 602 | else: 603 | use_bias=True 604 | 605 | self.net=[] 606 | self.net.append(nn.Conv2d(input_nc,num_filters[0],kernel_size=4,padding=2,stride=2)) 607 | self.net.append(nn.LeakyReLU(0.1,True)) 608 | if use_norm: 609 | for i in range(1,len(num_filters)-1): 610 | self.net.extend([nn.Conv2d(num_filters[i-1],num_filters[i],4,2,2,bias=use_bias), 611 | nn.LeakyReLU(0.1,True), 612 | norm_layer(num_filters[i])]) 613 | else : 614 | for i in range(1,len(num_filters)-1): 615 | self.net.extend([nn.Conv2d(num_filters[i-1],num_filters[i],4,2,2,bias=use_bias), 616 | nn.LeakyReLU(0.1,True)]) 617 | self.net.append(nn.Conv2d(num_filters[-1],1,4,1,2)) 618 | # self.net = [ 619 | # nn.Conv2d(input_nc, num_filters[0], kernel_size=1, stride=1, padding=0), 620 | # nn.LeakyReLU(0.2, True), 621 | # nn.Conv2d(num_filters[0], num_filters[1], kernel_size=1, stride=1, padding=0, bias=use_bias), 622 | # norm_layer(num_filters[1]), 623 | # nn.LeakyReLU(0.2, True), 624 | # nn.Conv2d(num_filters[1], 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 625 | 626 | self.net = nn.Sequential(*self.net) 627 | 628 | def forward(self, input): 629 | """Standard forward.""" 630 | return self.net(input) 631 | 632 | --------------------------------------------------------------------------------