├── README.md ├── authors ├── .DS_Store ├── ChengyuWang.png ├── JingShao.png ├── JunjieYan.png ├── XiaogangWang.png ├── XuJia.png ├── juntingpan.png └── lusheng.png ├── figs ├── .DS_Store ├── CVPRLogo.png ├── cvpr2019.png ├── full_architecture.png ├── gifs │ ├── .DS_Store │ ├── flow │ │ ├── .DS_Store │ │ ├── flow_1.gif │ │ ├── flow_2.gif │ │ ├── pcity_1.gif │ │ └── pcity_2.gif │ ├── generation │ │ ├── .DS_Store │ │ ├── gcity_1.gif │ │ └── gcity_2.gif │ ├── kitti │ │ ├── .DS_Store │ │ ├── kitti_1.gif │ │ ├── kitti_2.gif │ │ ├── kitti_3.gif │ │ ├── kitti_4.gif │ │ ├── kitti_5.gif │ │ └── kitti_6.gif │ ├── kth │ │ ├── .DS_Store │ │ ├── kth_1.gif │ │ ├── kth_2.gif │ │ ├── kth_3.gif │ │ ├── kth_4.gif │ │ └── kth_5.gif │ ├── length │ │ ├── .DS_Store │ │ ├── lcity_1.gif │ │ ├── lcity_3.gif │ │ ├── lcity_4.gif │ │ ├── lcity_5.gif │ │ └── lctiy_2.gif │ └── ucf101 │ │ ├── .DS_Store │ │ ├── ice_01.gif │ │ ├── ice_02.gif │ │ ├── ice_03.gif │ │ ├── violin_01.gif │ │ ├── violin_02.gif │ │ └── violin_03.gif ├── paper_thumbnail.jpg └── two_stage.png ├── gifs ├── flow │ ├── .DS_Store │ ├── flow_1.gif │ ├── flow_2.gif │ ├── pcity_1.gif │ └── pcity_2.gif ├── generation │ ├── .DS_Store │ ├── gcity_1.gif │ └── gcity_2.gif ├── kitti │ ├── .DS_Store │ ├── kitti_1.gif │ ├── kitti_2.gif │ ├── kitti_3.gif │ ├── kitti_4.gif │ ├── kitti_5.gif │ └── kitti_6.gif ├── kth │ ├── .DS_Store │ ├── kth_1.gif │ ├── kth_2.gif │ ├── kth_3.gif │ ├── kth_4.gif │ └── kth_5.gif ├── length │ ├── .DS_Store │ ├── lcity_1.gif │ ├── lcity_3.gif │ ├── lcity_4.gif │ ├── lcity_5.gif │ └── lctiy_2.gif └── ucf101 │ ├── .DS_Store │ ├── ice_01.gif │ ├── ice_02.gif │ ├── ice_03.gif │ ├── violin_01.gif │ ├── violin_02.gif │ └── violin_03.gif ├── logos ├── .DS_Store ├── cuhk.png └── sensetime.png └── src ├── dataset.py ├── datasets ├── __init__.py ├── cityscapes_dataset_w_mask.py ├── cityscapes_dataset_w_mask_pix2pixHD.py ├── cityscapes_dataset_w_mask_pix2pixHD_two_path.py ├── cityscapes_dataset_w_mask_two_path.py ├── dataset_path.py ├── kitti_dataset.py ├── kth_dataset.py └── ucf_dataset.py ├── file_list ├── cityscapes_train_sequence_full_9.txt ├── cityscapes_val_sequence_full_9.txt ├── kth_test_handwaving_16_ok.txt ├── kth_test_jogging_16_ok.txt ├── kth_test_walking_16_ok.txt ├── kth_train_handwaving_16_ok.txt ├── kth_train_jogging_16_ok.txt └── kth_train_walking_16_ok.txt ├── losses.py ├── models ├── __init__.py ├── multiframe_genmask.py ├── multiframe_w_mask_genmask.py ├── multiframe_w_mask_genmask_two_path.py ├── multiframe_w_mask_genmask_two_path_iterative.py ├── vgg_128.py └── vgg_utils.py ├── opts.py ├── test_refine.py ├── test_refine_w_mask.py ├── test_refine_w_mask_two_path.py ├── test_refine_w_mask_two_path_iterative.py ├── train_refine_multigpu.py ├── train_refine_multigpu_w_mask.py ├── train_refine_multigpu_w_mask_two_path.py └── utils ├── __init__.py ├── check_file_list.py ├── cityscapes_gen_list.py ├── cityscapes_gen_pix2pixImage_list.py ├── cityscapes_preprocess.py ├── get_ucf101_list.py ├── kth.py ├── kth_genlist.py ├── ops.py ├── semantic_segmask_order_data.py ├── svg_utils.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Video Generation from Single Semantic Label Map 2 | | ![CVPR 2019 logo][logo-cvpr] | Paper accepted at [2019 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)](http://cvpr2019.thecvf.com/) | 3 | |:-:|---| 4 | 5 | [logo-cvpr]: https://github.com/junting/seg2vid/blob/junting/figs/cvpr2019.png "CVPR 2019 logo" 6 | 7 | | ![Junting 
Pan][JuntingPan-photo] | ![Chengyu Wang][ChengyuWang-photo] | ![Xu Jia][XuJia-photo] | ![Jing Shao][JingShao-photo] | ![Lu Sheng][LuSheng-photo] |![Junjie Yan][JunjieYan-photo] | ![Xiaogang Wang][XiaogangWang-photo] | 8 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 9 | | [Junting Pan][JuntingPan-web] | [Chengyu Wang][ChengyuWang-web] | [Xu Jia][XuJia-web] | [Jing Shao][JingShao-web] | [Lu Sheng][LuSheng-web] | [Junjie Yan][JunjieYan-web] | [Xiaogang Wang][XiaogangWang-web] | 10 | 11 | [JuntingPan-web]: https://junting.github.io/ 12 | [ChengyuWang-web]: https://www.linkedin.com/in/chengyu-wang/ 13 | [XuJia-web]: https://stephenjia.github.io/ 14 | [JingShao-web]: https://amandajshao.github.io/ 15 | [LuSheng-web]: https://scholar.google.com.hk/citations?user=_8lB7xcAAAAJ&hl=en 16 | [JunjieYan-web]: http://www.cbsr.ia.ac.cn/users/jjyan/main.htm 17 | [XiaogangWang-web]: http://www.ee.cuhk.edu.hk/~xgwang/ 18 | 19 | [JuntingPan-photo]: https://github.com/junting/seg2vid/blob/junting/authors/juntingpan.png "Junting Pan" 20 | [ChengyuWang-photo]: https://github.com/junting/seg2vid/blob/junting/authors/ChengyuWang.png "Chengyu Wang" 21 | [XuJia-photo]: https://github.com/junting/seg2vid/blob/junting/authors/XuJia.png "Xu Jia" 22 | [JingShao-photo]: https://github.com/junting/seg2vid/blob/junting/authors/JingShao.png "JingShao" 23 | [LuSheng-photo]: https://github.com/junting/seg2vid/blob/junting/authors/lusheng.png "Lu Sheng" 24 | [JunjieYan-photo]: https://github.com/junting/seg2vid/blob/junting/authors/JunjieYan.png "Junjie Yan" 25 | [XiaogangWang-photo]: https://github.com/junting/seg2vid/blob/junting/authors/XiaogangWang.png "Xiaogang Wang" 26 | 27 | 28 | ## Abstract 29 | 30 | This paper proposes the novel task of video generation conditioned on a SINGLE semantic label map, which provides a good balance between flexibility and quality in the generation process. Different from typical end-to-end approaches, which model both scene content and dynamics in a single step, we propose to decompose this difficult task into two sub-problems. As current image generation methods do better than video generation in terms of detail, we synthesize high quality content by only generating the first frame. Then we animate the scene based on its semantic meaning to obtain the temporally coherent video, giving us excellent results overall. We employ a cVAE for predicting optical flow as a beneficial intermediate step to generate a video sequence conditioned on the initial single frame. A semantic label map is integrated into the flow prediction module to achieve major improvements in the image-to-video generation process. Extensive experiments on the Cityscapes dataset show that our method outperforms all competing methods. 31 | 32 | 33 | ## Publication 34 | 35 | Find our work on [arXiv](https://arxiv.org/abs/1903.04480). 36 | 37 | ![Image of the paper](https://github.com/junting/seg2vid/blob/junting/figs/paper_thumbnail.jpg) 38 | 39 | Please cite with the following Bibtex code: 40 | 41 | ``` 42 | @article{pan2019video, 43 | title={Video Generation from Single Semantic Label Map}, 44 | author={Pan, Junting and Wang, Chengyu and Jia, Xu and Shao, Jing and Sheng, Lu and Yan, Junjie and Wang, Xiaogang}, 45 | journal={arXiv preprint arXiv:1903.04480}, 46 | year={2019} 47 | } 48 | ``` 49 | 50 | You may also want to refer to our publication with the more human-friendly Chicago style: 51 | 52 | *Junting Pan, Chengyu Wang, Xu Jia, Jing Shao, Lu Sheng, Junjie Yan and Xiaogang Wang. "Video Generation from Single Semantic Label Map." 
CVPR 2019.* 53 | 54 | 55 | ## Models 56 | 57 | The Seg2Vid models presented in our work can be downloaded from the links provided below the figures: 58 | 59 | Seg2Vid Architecture 60 | ![architecture-fig] 61 | 62 | Img2Vid Architecture 63 | ![img2vid-fig] 64 | 65 | * [[Img2Img pretrained models]](https://github.com/NVIDIA/pix2pixHD) 66 | * [[Img2vid pretrained models]](https://drive.google.com/drive/folders/1-EuWjU2-UOFDBCoD5JRHn0F5xbevIZZg) 67 | 68 | [architecture-fig]: https://github.com/junting/seg2vid/blob/junting/figs/two_stage.png "seg2vid architecture" 69 | [img2vid-fig]: https://github.com/junting/seg2vid/blob/junting/figs/full_architecture.png "img2vid architecture" 70 | 71 | 72 | ## Visual Results 73 | ### Cityscapes (Generation) 74 | | ![Generated Video 1] | ![Generated Video 2] | 75 | |:-:|:-:| 76 | 77 | [Generated Video 1]:https://github.com/junting/seg2vid/blob/junting/gifs/generation/gcity_1.gif 78 | [Generated Video 2]:https://github.com/junting/seg2vid/blob/junting/gifs/generation/gcity_2.gif 79 | 80 | ### Cityscapes (Prediction given the 1st frame and its segmentation mask) 81 | | ![Predicted Video 1] | ![Predicted Video 2] | 82 | |:-:|:-:| 83 | | ![Predicted Flow 1] | ![Predicted Flow 2] | 84 | 85 | [Predicted Video 1]:https://github.com/junting/seg2vid/blob/junting/gifs/flow/pcity_1.gif 86 | [Predicted Video 2]:https://github.com/junting/seg2vid/blob/junting/gifs/flow/pcity_2.gif 87 | [Predicted Flow 1]:https://github.com/junting/seg2vid/blob/junting/gifs/flow/flow_1.gif 88 | [Predicted Flow 2]:https://github.com/junting/seg2vid/blob/junting/gifs/flow/flow_2.gif 89 | 90 | ### Cityscapes 24 frames (Prediction given the 1st frame and its segmentation mask) 91 | | ![Long Video 1] | ![Long Video 2] | ![Long Video 3] | 92 | |:-:|:-:|:-:| 93 | 94 | [Long Video 1]:https://github.com/junting/seg2vid/blob/junting/gifs/length/lcity_1.gif 95 | [Long Video 2]:https://github.com/junting/seg2vid/blob/junting/gifs/length/lctiy_2.gif 96 | [Long Video 3]:https://github.com/junting/seg2vid/blob/junting/gifs/length/lcity_3.gif 97 | 98 | ### UCF-101 (Prediction given the 1st frame) 99 | | ![UCF Video 1] | ![UCF Video 2] | ![UCF Video 3] | ![UCF Video 4] | ![UCF Video 5] | ![UCF Video 6] | 100 | |:-:|:-:|:-:|:-:|:-:|:-:| 101 | 102 | [UCF Video 1]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/ice_01.gif 103 | [UCF Video 2]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/ice_02.gif 104 | [UCF Video 3]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/ice_03.gif 105 | [UCF Video 4]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/violin_01.gif 106 | [UCF Video 5]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/violin_02.gif 107 | [UCF Video 6]:https://github.com/junting/seg2vid/blob/junting/gifs/ucf101/violin_03.gif 108 | 109 | ### KTH (Prediction given the 1st frame) 110 | | ![KTH Video 1] | ![KTH Video 2] | ![KTH Video 3] | ![KTH Video 4] | ![KTH Video 5] | 111 | |:-:|:-:|:-:|:-:|:-:| 112 | 113 | [KTH Video 1]:https://github.com/junting/seg2vid/blob/junting/gifs/kth/kth_1.gif 114 | [KTH Video 2]:https://github.com/junting/seg2vid/blob/junting/gifs/kth/kth_2.gif 115 | [KTH Video 3]:https://github.com/junting/seg2vid/blob/junting/gifs/kth/kth_3.gif 116 | [KTH Video 4]:https://github.com/junting/seg2vid/blob/junting/gifs/kth/kth_4.gif 117 | [KTH Video 5]:https://github.com/junting/seg2vid/blob/junting/gifs/kth/kth_5.gif 118 | 119 | ## Getting Started 120 | 121 | ### Dataset 122 | - Cityscapes 123 | - The Cityscapes dataset can be downloaded from the 
[official website](https://www.cityscapes-dataset.com/) (registration required). 124 | - We apply DeepLab-v3 ([GitHub repo](https://github.com/tensorflow/models/tree/master/research/deeplab)) to obtain the corresponding semantic segmentation maps. 125 | - We organize the dataset as follows: 126 | ``` 127 | seg2vid 128 | ├── authors 129 | ├── figs 130 | ├── gifs 131 | ├── logos 132 | ├── pretrained_models 133 | ├── src 134 | ├── data 135 | │ ├── cityscapes 136 | │ │ ├── leftImg8bit_sequence 137 | │ │ │ ├── train_256x128 138 | │ │ │ │ ├── aachen 139 | │ │ │ │ │ ├── aachen_000003_000019_leftImg8bit.png 140 | │ │ │ ├── val_256x128 141 | │ │ │ ├── val_pix2pixHD 142 | │ │ │ │ ├── frankfurt 143 | │ │ │ │ │ ├── frankfurt_000000_000294_pix2pixHD.png 144 | │ │ │ ├── train_semantic_segmask 145 | │ │ │ ├── val_semantic_segmask 146 | │ │ │ │ ├── frankfurt 147 | │ │ │ │ │ ├── frankfurt_000000_000275_ssmask.png 148 | │ │ ├── gtFine 149 | │ │ │ ├── train 150 | │ │ │ ├── val 151 | │ │ │ │ ├── frankfurt 152 | │ │ │ │ │ ├── frankfurt_000000_000294_gtFine_instanceIds.png 153 | ``` 154 | 155 | - KTH 156 | - We use the [KTH human action dataset](http://www.nada.kth.se/cvap/actions/), and we follow the data processing in [svg](https://github.com/edenton/svg). 157 | - UCF-101 158 | - The UCF-101 dataset can be downloaded from the [official website](https://www.crcv.ucf.edu/research/data-sets/human-actions/ucf101/). 159 | 160 | ### Testing 161 | ``` 162 | python -u test_refine_w_mask_two_path.py --suffix refine_w_mask_two_path --dataset cityscapes_two_path 163 | ``` 164 | ### Training 165 | ``` 166 | python -u train_refine_multigpu_w_mask_two_path.py --batch_size 8 --dataset cityscapes_two_path 167 | ``` 168 | 169 | ### Seg2Vid on PyTorch 170 | 171 | Seg2Vid is implemented in [PyTorch](https://pytorch.org/). A minimal data-loading sketch is included at the end of this README. 172 | 173 | ## Contact 174 | 175 | If you have any questions about our work or code that may be of interest to other researchers, please use the [public issues section](https://github.com/junting/seg2vid/issues) of this GitHub repo. Alternatively, drop us an e-mail. 
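## Data Loading Sketch

The snippet below is a minimal sketch of how the two-path Cityscapes loader in `src/datasets/cityscapes_dataset_w_mask_two_path.py` (the class used by the testing and training commands above) can be instantiated on its own. It is not part of the released scripts: it assumes the frames and `_ssmask.png` segmentation masks are organized as in the tree above, that the path constants in `src/datasets/dataset_path.py` have been edited to point at your local copy of the data, and that it is run from the `src/` directory.

```python
# Minimal sketch (illustrative, not part of the released scripts).
# Assumes dataset_path.py already points at your local Cityscapes copy.
from torch.utils.data import DataLoader

from datasets.dataset_path import (CITYSCAPES_VAL_DATA_PATH,
                                   CITYSCAPES_VAL_DATA_SEGMASK_PATH,
                                   CITYSCAPES_VAL_DATA_LIST)
from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes

dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH,
                     mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
                     datalist=CITYSCAPES_VAL_DATA_LIST,
                     size=(128, 256),           # (height, width) of the resized frames
                     num_frames=5,              # frames per clip
                     mask_suffix='ssmask.png',  # semantic segmentation mask suffix
                     returnpath=False)

loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

# Each batch yields the RGB clip plus background / foreground semantic masks.
frames, bg_mask, fg_mask = next(iter(loader))
print(frames.shape)   # (batch, num_frames, 3, H, W)
print(bg_mask.shape)  # (batch, 11, H, W) -- background classes 0-10
print(fg_mask.shape)  # (batch, 9, H, W)  -- foreground classes 11-19
```

Note that the bundled loaders rely on the legacy `scipy.misc.imread`, so an older SciPy release (or a small edit to use another image reader) may be needed.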
176 | 177 | 180 | -------------------------------------------------------------------------------- /authors/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/.DS_Store -------------------------------------------------------------------------------- /authors/ChengyuWang.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/ChengyuWang.png -------------------------------------------------------------------------------- /authors/JingShao.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/JingShao.png -------------------------------------------------------------------------------- /authors/JunjieYan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/JunjieYan.png -------------------------------------------------------------------------------- /authors/XiaogangWang.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/XiaogangWang.png -------------------------------------------------------------------------------- /authors/XuJia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/XuJia.png -------------------------------------------------------------------------------- /authors/juntingpan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/juntingpan.png -------------------------------------------------------------------------------- /authors/lusheng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/authors/lusheng.png -------------------------------------------------------------------------------- /figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/.DS_Store -------------------------------------------------------------------------------- /figs/CVPRLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/CVPRLogo.png -------------------------------------------------------------------------------- /figs/cvpr2019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/cvpr2019.png -------------------------------------------------------------------------------- /figs/full_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/full_architecture.png 
-------------------------------------------------------------------------------- /figs/gifs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/flow/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/flow/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/flow/flow_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/flow/flow_1.gif -------------------------------------------------------------------------------- /figs/gifs/flow/flow_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/flow/flow_2.gif -------------------------------------------------------------------------------- /figs/gifs/flow/pcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/flow/pcity_1.gif -------------------------------------------------------------------------------- /figs/gifs/flow/pcity_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/flow/pcity_2.gif -------------------------------------------------------------------------------- /figs/gifs/generation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/generation/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/generation/gcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/generation/gcity_1.gif -------------------------------------------------------------------------------- /figs/gifs/generation/gcity_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/generation/gcity_2.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_1.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_2.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_2.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_3.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_4.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_5.gif -------------------------------------------------------------------------------- /figs/gifs/kitti/kitti_6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kitti/kitti_6.gif -------------------------------------------------------------------------------- /figs/gifs/kth/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/kth/kth_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/kth_1.gif -------------------------------------------------------------------------------- /figs/gifs/kth/kth_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/kth_2.gif -------------------------------------------------------------------------------- /figs/gifs/kth/kth_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/kth_3.gif -------------------------------------------------------------------------------- /figs/gifs/kth/kth_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/kth_4.gif -------------------------------------------------------------------------------- /figs/gifs/kth/kth_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/kth/kth_5.gif -------------------------------------------------------------------------------- /figs/gifs/length/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/.DS_Store 
-------------------------------------------------------------------------------- /figs/gifs/length/lcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/lcity_1.gif -------------------------------------------------------------------------------- /figs/gifs/length/lcity_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/lcity_3.gif -------------------------------------------------------------------------------- /figs/gifs/length/lcity_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/lcity_4.gif -------------------------------------------------------------------------------- /figs/gifs/length/lcity_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/lcity_5.gif -------------------------------------------------------------------------------- /figs/gifs/length/lctiy_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/length/lctiy_2.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/.DS_Store -------------------------------------------------------------------------------- /figs/gifs/ucf101/ice_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/ice_01.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/ice_02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/ice_02.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/ice_03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/ice_03.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/violin_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/violin_01.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/violin_02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/violin_02.gif -------------------------------------------------------------------------------- /figs/gifs/ucf101/violin_03.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/gifs/ucf101/violin_03.gif -------------------------------------------------------------------------------- /figs/paper_thumbnail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/paper_thumbnail.jpg -------------------------------------------------------------------------------- /figs/two_stage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/figs/two_stage.png -------------------------------------------------------------------------------- /gifs/flow/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/flow/.DS_Store -------------------------------------------------------------------------------- /gifs/flow/flow_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/flow/flow_1.gif -------------------------------------------------------------------------------- /gifs/flow/flow_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/flow/flow_2.gif -------------------------------------------------------------------------------- /gifs/flow/pcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/flow/pcity_1.gif -------------------------------------------------------------------------------- /gifs/flow/pcity_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/flow/pcity_2.gif -------------------------------------------------------------------------------- /gifs/generation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/generation/.DS_Store -------------------------------------------------------------------------------- /gifs/generation/gcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/generation/gcity_1.gif -------------------------------------------------------------------------------- /gifs/generation/gcity_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/generation/gcity_2.gif -------------------------------------------------------------------------------- /gifs/kitti/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/.DS_Store 
-------------------------------------------------------------------------------- /gifs/kitti/kitti_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_1.gif -------------------------------------------------------------------------------- /gifs/kitti/kitti_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_2.gif -------------------------------------------------------------------------------- /gifs/kitti/kitti_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_3.gif -------------------------------------------------------------------------------- /gifs/kitti/kitti_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_4.gif -------------------------------------------------------------------------------- /gifs/kitti/kitti_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_5.gif -------------------------------------------------------------------------------- /gifs/kitti/kitti_6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kitti/kitti_6.gif -------------------------------------------------------------------------------- /gifs/kth/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/.DS_Store -------------------------------------------------------------------------------- /gifs/kth/kth_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/kth_1.gif -------------------------------------------------------------------------------- /gifs/kth/kth_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/kth_2.gif -------------------------------------------------------------------------------- /gifs/kth/kth_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/kth_3.gif -------------------------------------------------------------------------------- /gifs/kth/kth_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/kth_4.gif -------------------------------------------------------------------------------- /gifs/kth/kth_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/kth/kth_5.gif 
-------------------------------------------------------------------------------- /gifs/length/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/.DS_Store -------------------------------------------------------------------------------- /gifs/length/lcity_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/lcity_1.gif -------------------------------------------------------------------------------- /gifs/length/lcity_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/lcity_3.gif -------------------------------------------------------------------------------- /gifs/length/lcity_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/lcity_4.gif -------------------------------------------------------------------------------- /gifs/length/lcity_5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/lcity_5.gif -------------------------------------------------------------------------------- /gifs/length/lctiy_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/length/lctiy_2.gif -------------------------------------------------------------------------------- /gifs/ucf101/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/.DS_Store -------------------------------------------------------------------------------- /gifs/ucf101/ice_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/ice_01.gif -------------------------------------------------------------------------------- /gifs/ucf101/ice_02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/ice_02.gif -------------------------------------------------------------------------------- /gifs/ucf101/ice_03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/ice_03.gif -------------------------------------------------------------------------------- /gifs/ucf101/violin_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/violin_01.gif -------------------------------------------------------------------------------- /gifs/ucf101/violin_02.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/violin_02.gif -------------------------------------------------------------------------------- /gifs/ucf101/violin_03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/gifs/ucf101/violin_03.gif -------------------------------------------------------------------------------- /logos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/logos/.DS_Store -------------------------------------------------------------------------------- /logos/cuhk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/logos/cuhk.png -------------------------------------------------------------------------------- /logos/sensetime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/STVIR/seg2vid/89f5ca98dcac4f0fd9e592302b7a1430918d6779/logos/sensetime.png -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset_path import * 2 | import os 3 | 4 | def get_training_set(opt): 5 | assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth'] 6 | 7 | if opt.dataset == 'cityscapes': 8 | from datasets.cityscapes_dataset_w_mask import Cityscapes 9 | 10 | train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 11 | datalist=CITYSCAPES_VAL_DATA_LIST, 12 | size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames, 13 | mask_suffix='ssmask.png', returnpath=False) 14 | 15 | elif opt.dataset == 'cityscapes_two_path': 16 | from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes 17 | train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 18 | datalist=CITYSCAPES_VAL_DATA_LIST, 19 | size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames, 20 | mask_suffix='ssmask.png', returnpath=False) 21 | 22 | elif opt.dataset == 'kth': 23 | 24 | from datasets.kth_dataset import KTH 25 | train_Dataset = KTH(dataset_root=KTH_DATA_PATH, 26 | datalist=KTH_DATA_PATH_LIST, 27 | size=opt.input_size, num_frames=opt.num_frames) 28 | 29 | return train_Dataset 30 | 31 | 32 | def get_test_set(opt): 33 | assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth', 'ucf101', 'KITTI'] 34 | 35 | if opt.dataset == 'cityscapes': 36 | from datasets.cityscapes_dataset_w_mask import Cityscapes 37 | test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 38 | datalist=CITYSCAPES_VAL_DATA_LIST, 39 | size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames, 40 | mask_suffix='ssmask.png', returnpath=True) 41 | 42 | elif opt.dataset == 'cityscapes_two_path': 43 | from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes 44 | test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 45 | datalist=CITYSCAPES_VAL_DATA_LIST, 46 | size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames, 
47 | mask_suffix='ssmask.png', returnpath=True) 48 | 49 | elif opt.dataset == 'cityscapes_pix2pixHD': 50 | from cityscapes_dataloader_w_mask_pix2pixHD import Cityscapes 51 | test_Dataset = Cityscapes(datapath=CITYSCAPES_TEST_DATA_PATH, 52 | mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 53 | datalist=CITYSCAPES_VAL_DATA_MASK_LIST, 54 | size=opt.input_size, split='test', split_num=1, 55 | num_frames=opt.num_frames, mask_suffix='ssmask.png', returnpath=True) 56 | 57 | elif opt.dataset == 'kth': 58 | from datasets.kth_dataset import KTH 59 | test_Dataset = KTH(dataset_root=KTH_DATA_PATH, 60 | datalist='./file_list/kth_test_%s_16_ok.txt' % opt.category, 61 | size=opt.input_size, num_frames=opt.num_frames) 62 | 63 | elif opt.dataset == 'KITTI': 64 | from datasets.kitti_dataset import KITTI 65 | kitti_dataset_list = os.listdir(KITTI_DATA_PATH) 66 | test_Dataset = KITTI(datapath=KITTI_DATA_PATH, datalist=kitti_dataset_list, size=opt.input_size, 67 | returnpath=True) 68 | 69 | elif opt.dataset == 'ucf101': 70 | from datasets.ucf_dataset import UCF101 71 | test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, opt.category), 72 | datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())), 73 | returnpath=True) 74 | 75 | return test_Dataset 76 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_path import * 2 | from .cityscapes_dataset_w_mask_pix2pixHD import * 3 | from .cityscapes_dataset_w_mask_pix2pixHD_two_path import * 4 | from .cityscapes_dataset_w_mask import * 5 | from .cityscapes_dataset_w_mask_two_path import * 6 | from .kitti_dataset import * 7 | from .kth_dataset import * 8 | from .ucf_dataset import * 9 | 10 | 11 | # def model_entry(config): 12 | # return globals()[config['arch']](**config['kwargs']) 13 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_dataset_w_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 6 | import random 7 | import cv2 8 | import re 9 | import time 10 | from scipy.misc import imread 11 | random.seed(1234) 12 | 13 | 14 | # num_class = 20 15 | def cv2_tensor(pic): 16 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 17 | img = img.view(pic.shape[0], pic.shape[1], 3) 18 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 19 | return img.float().div(255) 20 | 21 | 22 | def replace_index_and_read(image_dir, indx, size): 23 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + indx).zfill(6) + image_dir[-16::] 24 | 25 | img = cv2.resize(cv2.cvtColor(cv2.imread(new_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (size[1], size[0])) 26 | # except: 27 | # print ('orgin_dir: ' + image_dir) 28 | # print ('new_dir: ' + new_dir) 29 | # # center crop 30 | # if img.shape[0] != img.shape[1]: 31 | # frame = cv2_tensor(img[:, 64:64+128]) 32 | # else: 33 | frame = cv2_tensor(img) 34 | return frame 35 | 36 | 37 | def load_mask(mask_dir, size): 38 | mask = imread(mask_dir) 39 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) 40 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(0, 20)], 0).astype(int) 41 | mask_volume = 
torch.from_numpy(mask_volume).contiguous().type(torch.FloatTensor) 42 | return mask_volume 43 | 44 | 45 | def imagetoframe(image_dir, size, num_frame): 46 | 47 | samples = [replace_index_and_read(image_dir, indx, size) for indx in range(num_frame)] 48 | return torch.stack(samples) 49 | 50 | 51 | def complete_full_list(image_dir, num_frames, output_name): 52 | dir_list = [image_dir[0:-22] + str(int(image_dir[-22:-16]) + i).zfill(6) + '_' + output_name for i in range(num_frames)] 53 | return dir_list 54 | 55 | 56 | class Cityscapes(Dataset): 57 | def __init__(self, datapath, mask_data_path, datalist, num_frames=5, size=(128, 128), split='train', split_num=1600, mask_suffix='gtFine_labelIds.png', returnpath=False): 58 | self.datapath = datapath 59 | # if split is 'train': 60 | # self.datalist = open(datalist).readlines()[0:-split_num] 61 | # else: 62 | # self.datalist = open(datalist).readlines()[-split_num::] 63 | self.datalist = open(datalist).readlines() 64 | self.size = size 65 | self.num_frame = num_frames 66 | self.mask_root = mask_data_path 67 | self.mask_suffix = mask_suffix 68 | self.returnPath = returnpath 69 | 70 | def __len__(self): 71 | return len(self.datalist) 72 | 73 | def __getitem__(self, idx): 74 | image_dir = os.path.join(self.datapath, self.datalist[idx].strip()) 75 | sample = imagetoframe(image_dir, self.size, self.num_frame) 76 | # mask_dir = os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+'gtFine_labelIds.png') 77 | mask_dir = os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+self.mask_suffix) 78 | mask = load_mask(mask_dir, self.size) 79 | if self.returnPath: 80 | return sample, mask, complete_full_list(self.datalist[idx].strip(), self.num_frame, 'pred.png') 81 | else: 82 | return sample, mask 83 | if __name__ == '__main__': 84 | 85 | start_time = time.time() 86 | from dataset_path import * 87 | cityscapes_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 88 | datalist=CITYSCAPES_VAL_DATA_MASK_LIST, 89 | size=(128, 256), split='train', split_num=1, 90 | num_frames=5, mask_suffix='ssmask.png') 91 | 92 | dataloader = DataLoader(cityscapes_Dataset, batch_size=32, shuffle=False, num_workers=4) 93 | 94 | sample, mask = iter(dataloader).next() 95 | print (sample.shape) 96 | print (mask.shape) 97 | 98 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_dataset_w_mask_pix2pixHD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 6 | import random 7 | import cv2 8 | import re 9 | import time 10 | from scipy.misc import imread 11 | random.seed(1234) 12 | 13 | # num_class = 20 14 | 15 | 16 | def cv2_tensor(pic): 17 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 18 | img = img.view(pic.shape[0], pic.shape[1], 3) 19 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 20 | return img.float().div(255) 21 | 22 | 23 | def replace_index_and_read(image_dir, indx, size): 24 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + indx).zfill(6) + '_pix2pixHD.png' 25 | try: 26 | img = cv2.resize(cv2.cvtColor(cv2.imread(new_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (size[1], size[0])) 27 | except: 28 | print ('orgin_dir: ' + image_dir) 29 | print ('new_dir: ' + new_dir) 30 | # center crop 31 | if img.shape[0] != img.shape[1]: 32 | frame = 
cv2_tensor(img[:, 64:64+128]) 33 | else: 34 | frame = cv2_tensor(img) 35 | return frame 36 | 37 | 38 | def load_mask(mask_dir, size): 39 | mask = imread(mask_dir) 40 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) 41 | # if flag == 'fg': 42 | # mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(11, 20)], 0).astype(int) 43 | # else: 44 | # mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(0, 11)], 0).astype(int) 45 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(0, 20)], 0).astype(int) 46 | mask_volume = torch.from_numpy(mask_volume).contiguous().type(torch.FloatTensor) 47 | return mask_volume 48 | 49 | 50 | def imagetoframe(image_dir, size, num_frame): 51 | 52 | samples = [replace_index_and_read(image_dir, indx, size) for indx in range(num_frame)] 53 | return torch.stack(samples) 54 | 55 | 56 | def complete_full_list(image_dir, num_frames, output_name): 57 | dir_list = [image_dir[0:-22] + str(int(image_dir[-22:-16]) + i).zfill(6) + '_' + output_name for i in range(num_frames)] 58 | return dir_list 59 | 60 | 61 | class Cityscapes(Dataset): 62 | def __init__(self, datapath, mask_data_path, datalist, num_frames=5, size=(128, 128), mask_suffix='gtFine_labelIds.png', returnpath=False): 63 | self.datapath = datapath 64 | # if split is 'train': 65 | # self.datalist = open(datalist).readlines()[0:-split_num] 66 | # else: 67 | # self.datalist = open(datalist).readlines()[-split_num::] 68 | self.datalist = open(datalist).readlines() 69 | self.size = size 70 | self.num_frame = num_frames 71 | self.mask_root = mask_data_path 72 | self.mask_suffix = mask_suffix 73 | self.returnPath = returnpath 74 | 75 | def __len__(self): 76 | return len(self.datalist) 77 | 78 | def __getitem__(self, idx): 79 | image_dir = os.path.join(self.datapath, self.datalist[idx].strip())[0:-15]+'pix2pixHD.png' 80 | # sample = imagetoframe(image_dir, self.size, self.num_frame) 81 | img = cv2.resize(cv2.cvtColor(cv2.imread(image_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (self.size[1], self.size[0])) 82 | sample = cv2_tensor(img) 83 | sample = sample.repeat(self.num_frame, 1, 1, 1) 84 | mask_dir = os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+self.mask_suffix) 85 | # bg_mask = load_mask(mask_dir, self.size, 'bg') 86 | # fg_mask = load_mask(mask_dir, self.size, 'fg') 87 | mask = load_mask(mask_dir, self.size) 88 | 89 | if self.returnPath: 90 | return sample, mask, complete_full_list(self.datalist[idx].strip(), self.num_frame, 'pred.png') 91 | else: 92 | return sample, mask 93 | 94 | 95 | if __name__ == '__main__': 96 | 97 | start_time = time.time() 98 | from dataset_path import * 99 | 100 | cityscapes_Dataset = Cityscapes(datapath=CITYSCAPES_TEST_DATA_PATH, 101 | mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 102 | datalist=CITYSCAPES_VAL_DATA_MASK_LIST, 103 | size=(128, 128), split='train', split_num=1, num_frames=8, mask_suffix='ssmask.png') 104 | 105 | dataloader = DataLoader(cityscapes_Dataset, batch_size=32, shuffle=False, num_workers=8) 106 | 107 | sample, mask= iter(dataloader).next() 108 | print (sample.shape) 109 | print (mask.shape) 110 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_dataset_w_mask_pix2pixHD_two_path.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 
6 | import random 7 | import cv2 8 | import re 9 | import time 10 | from scipy.misc import imread 11 | random.seed(1234) 12 | 13 | # num_class = 20 14 | 15 | 16 | def cv2_tensor(pic): 17 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 18 | img = img.view(pic.shape[0], pic.shape[1], 3) 19 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 20 | return img.float().div(255) 21 | 22 | 23 | def replace_index_and_read(image_dir, indx, size): 24 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + indx).zfill(6) + '_pix2pixHD.png' 25 | try: 26 | img = cv2.resize(cv2.cvtColor(cv2.imread(new_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (size[1], size[0])) 27 | except: 28 | print ('orgin_dir: ' + image_dir) 29 | print ('new_dir: ' + new_dir) 30 | # center crop 31 | if img.shape[0] != img.shape[1]: 32 | frame = cv2_tensor(img[:, 64:64+128]) 33 | else: 34 | frame = cv2_tensor(img) 35 | return frame 36 | 37 | 38 | def load_mask(mask_dir, size, flag): 39 | mask = imread(mask_dir) 40 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) 41 | if flag == 'fg': 42 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(11, 20)], 0).astype(int) 43 | else: 44 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(0, 11)], 0).astype(int) 45 | mask_volume = torch.from_numpy(mask_volume).contiguous().type(torch.FloatTensor) 46 | return mask_volume 47 | 48 | 49 | def imagetoframe(image_dir, size, num_frame): 50 | 51 | samples = [replace_index_and_read(image_dir, indx, size) for indx in range(num_frame)] 52 | return torch.stack(samples) 53 | 54 | 55 | def complete_full_list(image_dir, num_frames, output_name): 56 | dir_list = [image_dir[0:-22] + str(int(image_dir[-22:-16]) + i).zfill(6) + '_' + output_name for i in range(num_frames)] 57 | return dir_list 58 | 59 | 60 | class Cityscapes(Dataset): 61 | def __init__(self, datapath, mask_data_path, datalist, num_frames=5, size=(128, 128), split='train', split_num=1600, mask_suffix='gtFine_labelIds.png', returnpath=False): 62 | self.datapath = datapath 63 | # if split is 'train': 64 | # self.datalist = open(datalist).readlines()[0:-split_num] 65 | # else: 66 | # self.datalist = open(datalist).readlines()[-split_num::] 67 | self.datalist = open(datalist).readlines() 68 | self.size = size 69 | self.num_frame = num_frames 70 | self.mask_root = mask_data_path 71 | self.mask_suffix = mask_suffix 72 | self.returnPath = returnpath 73 | 74 | def __len__(self): 75 | return len(self.datalist) 76 | 77 | def __getitem__(self, idx): 78 | image_dir = os.path.join(self.datapath, self.datalist[idx].strip())[0:-15]+'pix2pixHD.png' 79 | # sample = imagetoframe(image_dir, self.size, self.num_frame) 80 | img = cv2.resize(cv2.cvtColor(cv2.imread(image_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (self.size[1], self.size[0])) 81 | sample = cv2_tensor(img) 82 | sample = sample.repeat(self.num_frame, 1, 1, 1) 83 | mask_dir = os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+self.mask_suffix) 84 | bg_mask = load_mask(mask_dir, self.size, 'bg') 85 | fg_mask = load_mask(mask_dir, self.size, 'fg') 86 | if self.returnPath: 87 | return sample, bg_mask, fg_mask, complete_full_list(self.datalist[idx].strip(), self.num_frame, 'pred.png') 88 | else: 89 | return sample, bg_mask, fg_mask 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | start_time = time.time() 95 | from dataset_path import * 96 | 97 | cityscapes_Dataset = Cityscapes(datapath=CITYSCAPES_TEST_DATA_PATH, 98 | 
mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 99 | datalist=CITYSCAPES_VAL_DATA_MASK_LIST, 100 | size=(128, 128), split='train', split_num=1, num_frames=8, mask_suffix='ssmask.png') 101 | 102 | dataloader = DataLoader(cityscapes_Dataset, batch_size=32, shuffle=False, num_workers=8) 103 | 104 | sample, bg_mask, fg_mask = iter(dataloader).next() 105 | print (sample.shape) 106 | print (bg_mask.shape) 107 | print (fg_mask.shape) 108 | 109 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_dataset_w_mask_two_path.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 6 | import random 7 | import cv2 8 | import re 9 | import time 10 | from scipy.misc import imread 11 | random.seed(1234) 12 | 13 | # num_class = 20 14 | def cv2_tensor(pic): 15 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 16 | img = img.view(pic.shape[0], pic.shape[1], 3) 17 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 18 | return img.float().div(255) 19 | 20 | def replace_index_and_read(image_dir, indx, size): 21 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + indx).zfill(6) + image_dir[-16::] 22 | try: 23 | img = cv2.resize(cv2.cvtColor(cv2.imread(new_dir, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (size[1], size[0])) 24 | except: 25 | print ('orgin_dir: ' + image_dir) 26 | print ('new_dir: ' + new_dir) 27 | # center crop 28 | # if img.shape[0] != img.shape[1]: 29 | # frame = cv2_tensor(img[:, 64:64+128]) 30 | # else: 31 | # frame = cv2_tensor(img) 32 | frame = cv2_tensor(img) 33 | return frame 34 | 35 | def load_mask(mask_dir, size, flag): 36 | mask = imread(mask_dir) 37 | mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST) 38 | if flag == 'fg': 39 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(11, 20)], 0).astype(int) 40 | else: 41 | mask_volume = np.concatenate([np.expand_dims(mask, 0) == i for i in range(0, 11)], 0).astype(int) 42 | mask_volume = torch.from_numpy(mask_volume).contiguous().type(torch.FloatTensor) 43 | return mask_volume 44 | 45 | def imagetoframe(image_dir, size, num_frame): 46 | 47 | samples = [replace_index_and_read(image_dir, indx, size) for indx in range(num_frame)] 48 | return torch.stack(samples) 49 | 50 | def complete_full_list(image_dir, num_frames, output_name): 51 | dir_list = [image_dir[0:-22] + str(int(image_dir[-22:-16]) + i).zfill(6) + '_' + output_name for i in range(num_frames)] 52 | return dir_list 53 | 54 | class Cityscapes(Dataset): 55 | def __init__(self, datapath, mask_data_path, datalist, num_frames=5, size=(128, 128), split='train', split_num=1600, 56 | mask_suffix='gtFine_labelIds.png', returnpath=False): 57 | self.datapath = datapath 58 | # if split is 'train': 59 | # self.datalist = open(datalist).readlines()[0:-split_num] 60 | # else: 61 | # self.datalist = open(datalist).readlines()[-split_num::] 62 | self.datalist = open(datalist).readlines() 63 | self.size = size 64 | self.num_frame = num_frames 65 | self.mask_root = mask_data_path 66 | self.mask_suffix = mask_suffix 67 | self.returnPath = returnpath 68 | def __len__(self): 69 | return len(self.datalist) 70 | 71 | def __getitem__(self, idx): 72 | image_dir = os.path.join(self.datapath, self.datalist[idx].strip()) 73 | sample = imagetoframe(image_dir, self.size, self.num_frame) 74 | # mask_dir = 
os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+'gtFine_labelIds.png') 75 | mask_dir = os.path.join(self.mask_root, self.datalist[idx].strip()[0:-15]+self.mask_suffix) 76 | bg_mask = load_mask(mask_dir, self.size, 'bg') 77 | fg_mask = load_mask(mask_dir, self.size, 'fg') 78 | if self.returnPath: 79 | return sample, bg_mask, fg_mask, complete_full_list(self.datalist[idx].strip(), self.num_frame, 'pred.png') 80 | else: 81 | return sample, bg_mask, fg_mask 82 | 83 | if __name__ == '__main__': 84 | 85 | start_time = time.time() 86 | 87 | from dataset_path import * 88 | 89 | cityscapes_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH, 90 | datalist=CITYSCAPES_VAL_DATA_MASK_LIST, 91 | size=(256, 128), split='train', split_num=1, 92 | num_frames=5, mask_suffix='ssmask.png', returnpath=True) 93 | 94 | dataloader = DataLoader(cityscapes_Dataset, batch_size=32, shuffle=False, num_workers=4) 95 | 96 | sample, bg_mask, fg_mask, paths = iter(dataloader).next() 97 | print (sample.shape) 98 | print (bg_mask.shape) 99 | print (fg_mask.shape) 100 | import pdb 101 | pdb.set_trace() 102 | 103 | # import pdb 104 | # pdb.set_trace() 105 | 106 | ''' 107 | (Pdb) paths[0][0] 108 | 'frankfurt/frankfurt_000000_013067_pred.png' 109 | (Pdb) paths[1][0] 110 | 'frankfurt/frankfurt_000000_013068_pred.png' 111 | ''' 112 | 113 | 114 | -------------------------------------------------------------------------------- /src/datasets/dataset_path.py: -------------------------------------------------------------------------------- 1 | CITYSCAPES_TRAIN_DATA_PATH = '/DATAshare/leftImg8bit_sequence/train_512x256/' 2 | CITYSCAPES_VAL_DATA_PATH = '/DATAshare/leftImg8bit_sequence/val_512x256/' 3 | 4 | CITYSCAPES_TRAIN_DATA_LIST = '/panjunting/video_generation/data/cityscapes_train_sequence_full.txt' 5 | CITYSCAPES_VAL_DATA_LIST = '/panjunting/video_generation/data/cityscapes_val_sequence_full.txt' 6 | 7 | CITYSCAPES_TRAIN_DATA_LIST_8 = '/panjunting/video_generation/data/cityscapes_train_sequence_full_8.txt' 8 | CITYSCAPES_VAL_DATA_LIST_8 = '/panjunting/video_generation/data/cityscapes_val_sequence_full_8.txt' 9 | 10 | CITYSCAPES_TEST_DATA_PATH = '/DATAshare/unzip/leftImg8bit_sequence/val_pix2pixHD/' 11 | 12 | CITYSCAPES_TRAIN_MASK_PATH = '/DATAshare/gtFine/train/' 13 | CITYSCAPES_TRAIN_DATA_MASK_LIST = '/panjunting/video_generation/data/cityscapes_train_sequence_w_mask.txt' 14 | 15 | CITYSCAPES_VAL_MASK_PATH = '/DATAshare/gtFine/val/' 16 | CITYSCAPES_VAL_DATA_MASK_LIST = '/panjunting/video_generation/data/cityscapes_val_sequence_w_mask.txt' 17 | 18 | CITYSCAPES_TRAIN_DATA_MASK_LIST_8 = '/panjunting/video_generation/data/cityscapes_train_sequence_w_mask_8.txt' 19 | CITYSCAPES_VAL_DATA_MASK_LIST_8 = '/panjunting/video_generation/data/cityscapes_val_sequence_w_mask_8.txt' 20 | 21 | CITYSCAPES_TRAIN_DATA_SEGMASK_PATH = '/DATAshare/unzip/leftImg8bit_sequence/train_semantic_segmask' 22 | CITYSCAPES_VAL_DATA_SEGMASK_PATH = '/DATAshare/unzip/leftImg8bit_sequence/val_semantic_segmask' 23 | 24 | KTH_DATA_PATH = '/panjunting/kth/KTH/processed' 25 | UCF_101_DATA_PATH = '/panjunting/f2video2.0/UCF-101' 26 | 27 | KITTI_DATA_PATH = '/panjunting/kitti' 28 | -------------------------------------------------------------------------------- /src/datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import 
utils 6 | import random 7 | import cv2 8 | import re 9 | import time 10 | random.seed(1234) 11 | 12 | 13 | def cv2_tensor(pic): 14 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 15 | img = img.view(pic.shape[0], pic.shape[1], 3) 16 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 17 | return img.float().div(255) 18 | 19 | 20 | class KITTI(Dataset): 21 | def __init__(self, datapath, datalist, size=(128, 128), returnpath=False): 22 | self.datapath = datapath 23 | # self.datalist = open(datalist).readlines() 24 | self.datalist = datalist 25 | self.size = size 26 | self.returnpath = returnpath 27 | 28 | def __len__(self): 29 | return len(self.datalist) 30 | 31 | def __getitem__(self, idx): 32 | image_name = os.path.join(self.datapath, self.datalist[idx].strip()) 33 | img = cv2.resize(cv2.cvtColor(cv2.imread(image_name, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB), (self.size[1], self.size[0])) 34 | sample = cv2_tensor(img) 35 | if self.returnpath: 36 | return sample, self.datalist[idx][0:-4] 37 | return sample 38 | 39 | 40 | if __name__ == '__main__': 41 | 42 | start_time = time.time() 43 | 44 | kitti_dataset_path = '/mnt/lustre/panjunting/kitti' 45 | kitti_dataset_list = os.listdir(kitti_dataset_path) 46 | 47 | kitti_Dataset = KITTI(datapath=kitti_dataset_path, 48 | datalist=kitti_dataset_list, 49 | size=(128, 256), returnpath=True) 50 | 51 | dataloader = DataLoader(kitti_Dataset, batch_size=32, shuffle=False, num_workers=8) 52 | 53 | sample, path = iter(dataloader).next() 54 | import pdb 55 | pdb.set_trace() 56 | print (sample.shape) # # 57 | # spent_time = time.time() - start_time 58 | # print spent_time 59 | # from tqdm import tqdm 60 | # a= [ 1 for sample in tqdm(iter(dataloader))] 61 | -------------------------------------------------------------------------------- /src/datasets/kth_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 6 | import random 7 | import cv2 8 | import re 9 | import time 10 | random.seed(1234) 11 | 12 | 13 | def cv2_tensor(pic): 14 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 15 | img = img.view(pic.shape[0], pic.shape[1], 1) 16 | img = img.transpose(0, 2).transpose(1, 2).contiguous() 17 | return img.float().div(255) 18 | 19 | def replace_index_and_read(image_dir, indx, size): 20 | new_dir = image_dir[0:-15] + str(int(image_dir[-15:-12]) + indx).zfill(3) + image_dir[-12::] 21 | try: 22 | img = cv2.resize(cv2.imread(new_dir, 0), size) 23 | except: 24 | print ('orgin_dir: ' + image_dir) 25 | print ('new_dir: ' + new_dir) 26 | # center crop 27 | # img = cv2.resize(cv2.imread(new_dir, 0), size) 28 | frame = cv2_tensor(img) 29 | return frame 30 | 31 | def imagetoframe(image_dir, size, num_frame): 32 | 33 | samples = [replace_index_and_read(image_dir, indx, size) for indx in range(num_frame)] 34 | return torch.stack(samples) 35 | 36 | def get_path_list(image_dir, num_frame): 37 | new_dirs = [image_dir[1:-15] + str(int(image_dir[-15:-12]) + indx).zfill(3) + image_dir[-12::] for indx in range(num_frame)] 38 | return new_dirs 39 | 40 | class KTH(Dataset): 41 | def __init__(self, dataset_root, datalist, num_frames=5, size=(128, 128), returnpath=False): 42 | self.datalist = open(datalist).readlines() 43 | self.size = size 44 | self.num_frame = num_frames 45 | self.dataset_root = dataset_root 46 | self.returnpath = returnpath 47 | def 
__len__(self): 48 | return len(self.datalist) 49 | 50 | def __getitem__(self, idx): 51 | sample = imagetoframe(self.dataset_root+self.datalist[idx].strip(), self.size, self.num_frame) 52 | 53 | if self.returnpath: 54 | paths = get_path_list(self.datalist[idx].strip(), self.num_frame) 55 | return sample, paths 56 | else: 57 | return sample 58 | 59 | 60 | if __name__ == '__main__': 61 | 62 | start_time = time.time() 63 | cityscapes_Dataset = KTH(dataset_root='/mnt/lustrenew/panjunting/kth/KTH/processed', datalist='kth_train_16.txt', 64 | size=(128, 128), num_frames=16) 65 | 66 | dataloader = DataLoader(cityscapes_Dataset, batch_size=32, shuffle=False, num_workers=1) 67 | 68 | sample = iter(dataloader).next() 69 | print (sample.size()) 70 | # from tqdm import tqdm 71 | # a= [ 1 for sample in tqdm(iter(dataloader))] 72 | -------------------------------------------------------------------------------- /src/datasets/ucf_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import utils 6 | import random 7 | 8 | import time 9 | random.seed(1234) 10 | 11 | 12 | class UCF101(Dataset): 13 | def __init__(self, datapath, datalist, num_frame=5, size=128, returnpath=False): 14 | self.datapath = datapath 15 | self.datalist = open(datalist).readlines() 16 | self.numframe = num_frame 17 | self.size = size 18 | self.returnpath = returnpath 19 | 20 | def __len__(self): 21 | return len(self.datalist) 22 | 23 | def get_path_list(self, image_dir, num_frame, start): 24 | # new_dirs = [image_dir[1:-15] + str(int(image_dir[-15:-12]) + indx).zfill(3) + image_dir[-12::] for indx in 25 | # range(num_frame)] 26 | # 27 | video_folder = os.path.join(image_dir[0:-4], str(start)) 28 | # new_dirs = [os.path.join(video_folder, '%d.png'%indx) for indx in range(num_frame)] 29 | new_dirs = [video_folder for indx in range(num_frame)] 30 | return new_dirs 31 | 32 | def __getitem__(self, idx): 33 | 34 | item = np.load(os.path.join(self.datapath, self.datalist[idx].split(' ')[0]).strip()) 35 | start = int(self.datalist[idx].split(' ')[1]) 36 | item = item[start:start + self.numframe, :, :, :] / 255.0 37 | 38 | data = torch.from_numpy(np.array(item)) 39 | data = data.contiguous() 40 | # print data.shape 41 | data = data.transpose(2, 3).transpose(1, 2) 42 | data = data.float() 43 | 44 | if self.size == 64: 45 | bs, T, c, h, w = data.size() 46 | data = Vb(data, requires_grad=False) 47 | data = F.avg_pool2d(data.view(-1, c, h, w), 2, 2) 48 | data = data.view(bs, T, c, h / 2, w / 2).data 49 | 50 | if self.returnpath: 51 | paths = self.get_path_list(self.datalist[idx].split(' ')[0].strip(), self.numframe, start) 52 | return data, paths 53 | return data 54 | 55 | 56 | if __name__ == '__main__': 57 | start_time = time.time() 58 | 59 | # v_IceDancing_g06_c01.npy 60 | 61 | Path = '/mnt/lustre/panjunting/f2video2.0/UCF-101' 62 | 63 | # train_Dataset = UCF101(datapath=os.path.join(Path, 'IceDancing'), 64 | # datalist=os.path.join(Path, 'list/trainicedancing.txt')) 65 | test_Dataset = UCF101(datapath=os.path.join(Path, 'IceDancing'), 66 | datalist=os.path.join(Path, 'list/testicedancing.txt'), returnpath=True) 67 | 68 | dataloader = DataLoader(test_Dataset, batch_size=32, shuffle=True, num_workers=8) 69 | 70 | sample, path = iter(dataloader).next() 71 | # print sample.size() 72 | import pdb 73 | pdb.set_trace() 74 | spent_time = time.time() - start_time 75 | # print spent_time 76 
| # from tqdm import tqdm 77 | # i = 0 78 | # for sample in tqdm(iter(dataloader)): 79 | # if i ==0: 80 | # x = sample.shape 81 | # i=1 82 | # # print sample.shape 83 | # if sample.shape != x: 84 | # print sample.shape -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import torchvision.utils as tov 7 | import cv2 8 | import datetime 9 | import numpy as np 10 | import utils.ops 11 | from utils.utils import * 12 | from models.vgg_utils import my_vgg 13 | from utils import ops 14 | from torchvision import transforms as trn 15 | 16 | 17 | preprocess = trn.Compose([ 18 | # trn.ToTensor(), 19 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]) 21 | 22 | mean = Vb(torch.FloatTensor([0.485, 0.456, 0.406])).view([1,3,1,1]) 23 | std = Vb(torch.FloatTensor([0.229, 0.224, 0.225])).view([1,3,1,1]) 24 | 25 | 26 | def normalize(x): 27 | gpu_id = x.get_device() 28 | return (x- mean.cuda(gpu_id))/std.cuda(gpu_id) 29 | 30 | 31 | class TrainingLoss(object): 32 | 33 | def __init__(self, opt, flowwarpper): 34 | self.opt = opt 35 | self.flowwarp = flowwarpper 36 | 37 | def gdloss(self, a, b): 38 | xloss = torch.sum( 39 | torch.abs(torch.abs(a[:, :, 1:, :] - a[:, :, :-1, :]) - torch.abs(b[:, :, 1:, :] - b[:, :, :-1, :]))) 40 | yloss = torch.sum( 41 | torch.abs(torch.abs(a[:, :, :, 1:] - a[:, :, :, :-1]) - torch.abs(b[:, :, :, 1:] - b[:, :, :, :-1]))) 42 | return (xloss + yloss) / (a.size()[0] * a.size()[1] * a.size()[2] * a.size()[3]) 43 | 44 | def vgg_loss(self, y_pred_feat, y_true_feat): 45 | loss = 0 46 | for i in range(len(y_pred_feat)): 47 | loss += (y_true_feat[i] - y_pred_feat[i]).abs().mean() 48 | return loss 49 | 50 | def _quickflowloss(self, flow, img, neighber=5, alpha=1): 51 | flow = flow * 128 52 | img = img * 256 53 | bs, c, h, w = img.size() 54 | center = int((neighber - 1) / 2) 55 | loss = [] 56 | neighberrange = list(range(neighber)) 57 | neighberrange.remove(center) 58 | for i in neighberrange: 59 | for j in neighberrange: 60 | flowsub = (flow[:, :, center:-center, center:-center] - 61 | flow[:, :, i:h - (neighber - i - 1), j:w - (neighber - j - 1)]) ** 2 62 | imgsub = (img[:, :, center:-center, center:-center] - 63 | img[:, :, i:h - (neighber - i - 1), j:w - (neighber - j - 1)]) ** 2 64 | flowsub = flowsub.sum(1) 65 | imgsub = imgsub.sum(1) 66 | indexsub = (i - center) ** 2 + (j - center) ** 2 67 | loss.append(flowsub * torch.exp(-alpha * imgsub - indexsub)) 68 | return torch.stack(loss).sum() / (bs * w * h) 69 | 70 | def quickflowloss(self, flow, img, t=1): 71 | flowloss = 0. 
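# quickflowloss: accumulate the image-guided flow smoothness penalty from
# _quickflowloss over the first t predicted frames (flow: [bs, 2, T, H, W], img: [bs, T, C, H, W]).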
72 | for ii in range(t): 73 | flowloss += self._quickflowloss(flow[:, :, ii, :, :], img[:, ii, :, :, :]) 74 | return flowloss 75 | 76 | def _flowgradloss(self, flow, image): 77 | flow = flow * 128 78 | image = image * 256 79 | flowgradx = ops.gradientx(flow) 80 | flowgrady = ops.gradienty(flow) 81 | imggradx = ops.gradientx(image) 82 | imggrady = ops.gradienty(image) 83 | weightx = torch.exp(-torch.mean(torch.abs(imggradx), 1, keepdim=True)) 84 | weighty = torch.exp(-torch.mean(torch.abs(imggrady), 1, keepdim=True)) 85 | lossx = flowgradx * weightx 86 | lossy = flowgrady * weighty 87 | # return torch.mean(torch.abs(lossx + lossy)) 88 | return torch.mean(torch.abs(lossx)) + torch.mean(torch.abs(lossy)) 89 | 90 | def flowgradloss(self, flow, image, t=1): 91 | 92 | flow_gradient_loss = 0. 93 | for ii in range(t): 94 | flow_gradient_loss += self._flowgradloss(flow[:, :, ii, :, :], image[:, ii, :, :, :]) 95 | return flow_gradient_loss 96 | 97 | def imagegradloss(self, input, target): 98 | input_gradx = ops.gradientx(input) 99 | input_grady = ops.gradienty(input) 100 | 101 | target_gradx = ops.gradientx(target) 102 | target_grady = ops.gradienty(target) 103 | 104 | return F.l1_loss(torch.abs(target_gradx), torch.abs(input_gradx)) \ 105 | + F.l1_loss(torch.abs(target_grady), torch.abs(input_grady)) 106 | 107 | def SSIM(self, x, y): 108 | C1 = 0.01 ** 2 109 | C2 = 0.03 ** 2 110 | 111 | mu_x = F.avg_pool2d(x, 3, 1) 112 | mu_y = F.avg_pool2d(y, 3, 1) 113 | 114 | sigma_x = F.avg_pool2d(x ** 2, 3, 1) - mu_x ** 2 115 | sigma_y = F.avg_pool2d(y ** 2, 3, 1) - mu_y ** 2 116 | sigma_xy = F.avg_pool2d(x * y, 3, 1) - mu_x * mu_y 117 | 118 | SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) 119 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2) 120 | 121 | SSIM = SSIM_n / SSIM_d 122 | 123 | return torch.clamp((1 - SSIM) / 2, 0, 1).mean() 124 | 125 | def image_similarity(self, x, y, opt=None): 126 | sim = 0 127 | # for ii in range(opt.num_predicted_frames): 128 | for ii in range(x.size()[1]): 129 | sim += opt.alpha_recon_image * self.SSIM(x[:, ii, ...], y[:, ii, ...]) \ 130 | + (1 - opt.alpha_recon_image) * F.l1_loss(x[:, ii, ...], y[:, ii, ...]) 131 | return sim 132 | 133 | def loss_function(self, mu, logvar, bs): 134 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 135 | KLD /= 1000 136 | return KLD 137 | 138 | def kl_criterion(self, mu, logvar, bs): 139 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 140 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 141 | KLD /= self.opt.batch_size 142 | return KLD 143 | 144 | def _flowconsist(self, flow, flowback, mask_fw=None, mask_bw=None): 145 | if mask_fw is not None: 146 | # mask_fw, mask_bw = occlusion(flow, flowback, self.flowwarp) 147 | prevloss = (mask_bw * torch.abs(self.flowwarp(flow, -flowback) - flowback)).mean() 148 | nextloss = (mask_fw * torch.abs(self.flowwarp(flowback, flow) - flow)).mean() 149 | else: 150 | prevloss = torch.abs(self.flowwarp(flow, -flowback) - flowback).mean() 151 | nextloss = torch.abs(self.flowwarp(flowback, flow) - flow).mean() 152 | return prevloss + nextloss 153 | 154 | def flowconsist(self, flow, flowback, mask_fw=None, mask_bw=None, t=4): 155 | flowcon = 0. 
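# flowconsist: sum the per-frame forward/backward flow consistency from _flowconsist
# over t frames, weighting by the occlusion masks whenever mask_fw / mask_bw are given.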
156 | if mask_bw is not None: 157 | for ii in range(t): 158 | flowcon += self._flowconsist(flow[:, :, ii, :, :], flowback[:, :, ii, :, :], 159 | mask_fw=mask_fw[:, ii:ii + 1, ...], 160 | mask_bw=mask_bw[:, ii:ii + 1, ...]) 161 | else: 162 | for ii in range(t): 163 | flowcon += self._flowconsist(flow[:, :, ii, :, :], flowback[:, :, ii, :, :]) 164 | return flowcon 165 | 166 | def reconlossT(self, x, y, t=4, mask=None): 167 | if mask is not None: 168 | x = x * mask.unsqueeze(2) 169 | y = y * mask.unsqueeze(2) 170 | 171 | loss = (x.contiguous() - y.contiguous()).abs().mean() 172 | return loss 173 | 174 | 175 | class losses_multigpu_only_mask(nn.Module): 176 | def __init__(self, opt, flowwarpper): 177 | super(losses_multigpu_only_mask, self).__init__() 178 | self.tl = TrainingLoss(opt, flowwarpper) 179 | self.flowwarpper = flowwarpper 180 | self.opt = opt 181 | 182 | def forward(self, frame1, frame2, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature, y_pred_before_refine=None): 183 | opt = self.opt 184 | flowwarpper = self.flowwarpper 185 | tl = self.tl 186 | output = y_pred 187 | 188 | '''flowloss''' 189 | flowloss = tl.quickflowloss(flow, frame2) 190 | flowloss += tl.quickflowloss(flowback, frame1.unsqueeze(1)) 191 | flowloss *= 0.01 192 | 193 | '''flow consist''' 194 | flowcon = tl.flowconsist(flow, flowback, mask_fw, mask_bw, t=opt.num_predicted_frames) 195 | 196 | '''kldloss''' 197 | kldloss = tl.loss_function(mu, logvar, opt.batch_size) 198 | 199 | '''flow gradient loss''' 200 | # flow_gradient_loss = tl.flowgradloss(flow, frame2) 201 | # flow_gradient_loss += tl.flowgradloss(flowback, frame1) 202 | # flow_gradient_loss *= 0.01 203 | 204 | '''Image Similarity loss''' 205 | sim_loss = tl.image_similarity(output, frame2, opt) 206 | 207 | '''reconstruct loss''' 208 | prevframe = [torch.unsqueeze(flowwarpper(frame2[:, ii, :, :, :], -flowback[:, :, ii, :, :]* mask_bw[:, ii:ii + 1, ...]), 1) 209 | for ii in range(opt.num_predicted_frames)] 210 | prevframe = torch.cat(prevframe, 1) 211 | 212 | reconloss_back = tl.reconlossT(prevframe, 213 | torch.unsqueeze(frame1, 1).repeat(1, opt.num_predicted_frames, 1, 1, 1), 214 | mask=mask_bw, t=opt.num_predicted_frames) 215 | reconloss = tl.reconlossT(output, frame2, t=opt.num_predicted_frames) 216 | 217 | if y_pred_before_refine is not None: 218 | reconloss_before = tl.reconlossT(y_pred_before_refine, frame2, mask=mask_fw, t=opt.num_predicted_frames) 219 | else: 220 | reconloss_before = 0. 
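# Without a pre-refinement prediction there is no extra reconstruction term; the VGG
# perceptual loss and the mask regularizer (pushing both occlusion masks toward 1) follow below.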
221 | 222 | '''vgg loss''' 223 | vgg_loss = tl.vgg_loss(prediction_vgg_feature, gt_vgg_feature) 224 | 225 | '''mask loss''' 226 | mask_loss = (1 - mask_bw).mean() + (1 - mask_fw).mean() 227 | 228 | return flowloss, reconloss, reconloss_back, reconloss_before, kldloss, flowcon, sim_loss, vgg_loss, mask_loss 229 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .multiframe_genmask import * 2 | from .multiframe_w_mask_genmask import * 3 | from .multiframe_w_mask_genmask_two_path import * 4 | from .multiframe_w_mask_genmask_two_path_iterative import * 5 | from .vgg_utils import * -------------------------------------------------------------------------------- /src/models/multiframe_genmask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import torchvision.models 7 | import torch.optim as optim 8 | import os 9 | import logging 10 | import torchvision.utils as tov 11 | 12 | import sys 13 | sys.path.insert(0, '../utils') 14 | from utils import utils 15 | from utils import ops 16 | from models.vgg_utils import my_vgg 17 | 18 | 19 | class motion_net(nn.Module): 20 | def __init__(self, opt): 21 | super(motion_net, self).__init__() 22 | # input 3*128*128 23 | self.main = nn.Sequential( 24 | nn.Conv2d(opt.num_frames*opt.input_channel, 32, 4, 2, 1, bias=False), # 64 25 | nn.LeakyReLU(0.2, inplace=True), 26 | nn.Conv2d(32, 64, 4, 2, 1, bias=False), # 32 27 | nn.BatchNorm2d(64), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), # 32 30 | nn.BatchNorm2d(64), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Conv2d(64, 128, 4, 2, 1, bias=False), # 16 33 | nn.BatchNorm2d(128), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | nn.Conv2d(128, 128, 3, 1, 1, bias=False), # 16 36 | nn.BatchNorm2d(128), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | nn.Conv2d(128, 64, 4, 2, 1, bias=False) # 8 39 | ) 40 | self.fc1 = nn.Linear(1024, 1024) 41 | self.fc2 = nn.Linear(1024, 1024) 42 | 43 | def forward(self, x): 44 | temp = self.main(x).view(-1, 1024) 45 | mu = self.fc1(temp) 46 | # print 'mu: '+str(mu.size()) 47 | logvar = self.fc2(temp) 48 | return mu, logvar 49 | 50 | 51 | class gateconv3d_bak(nn.Module): 52 | def __init__(self, innum, outnum, kernel, stride, pad): 53 | super(gateconv3d, self).__init__() 54 | self.conv = nn.Conv3d(innum, outnum * 2, kernel, stride, pad, bias=True) 55 | self.bn = nn.BatchNorm3d(outnum * 2) 56 | 57 | def forward(self, x): 58 | return F.glu(self.bn(self.conv(x)), 1) + x 59 | 60 | 61 | class gateconv3d(nn.Module): 62 | def __init__(self, innum, outnum, kernel, stride, pad): 63 | super(gateconv3d, self).__init__() 64 | self.conv = nn.Conv3d(innum, outnum, kernel, stride, pad, bias=True) 65 | self.bn = nn.BatchNorm3d(outnum) 66 | 67 | def forward(self, x): 68 | return F.leaky_relu(self.bn(self.conv(x)), 0.2) 69 | 70 | 71 | class convblock(nn.Module): 72 | def __init__(self, innum, outnum, kernel, stride, pad): 73 | super(convblock, self).__init__() 74 | self.main = nn.Sequential( 75 | nn.Conv2d(innum, outnum, kernel, stride, pad, bias=False), 76 | nn.BatchNorm2d(outnum), 77 | nn.LeakyReLU(0.2, inplace=True)) 78 | 79 | def forward(self, x): 80 | return self.main(x) 81 | 82 | 83 | class convbase(nn.Module): 84 | def __init__(self, 
innum, outnum, kernel, stride, pad): 85 | super(convbase, self).__init__() 86 | self.main = nn.Sequential( 87 | nn.Conv2d(innum, outnum, kernel, stride, pad), 88 | nn.LeakyReLU(0.2, inplace=True)) 89 | 90 | def forward(self, x): 91 | return self.main(x) 92 | 93 | 94 | class upconv(nn.Module): 95 | def __init__(self, innum, outnum, kernel, stride, pad): 96 | super(upconv, self).__init__() 97 | self.main = nn.Sequential( 98 | nn.Conv2d(innum, outnum * 2, kernel, stride, pad), 99 | nn.BatchNorm2d(outnum * 2), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | nn.Conv2d(outnum * 2, outnum, kernel, stride, pad), 102 | nn.BatchNorm2d(outnum), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Upsample(scale_factor=2) 105 | ) 106 | 107 | def forward(self, x): 108 | return self.main(x) 109 | 110 | 111 | class getflow(nn.Module): 112 | def __init__(self): 113 | super(getflow, self).__init__() 114 | self.main = nn.Sequential( 115 | upconv(64, 16, 5, 1, 2), 116 | nn.Conv2d(16, 2, 5, 1, 2), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.main(x) 121 | 122 | class get_occlusion_mask(nn.Module): 123 | def __init__(self): 124 | super(get_occlusion_mask, self).__init__() 125 | self.main = nn.Sequential( 126 | upconv(64, 16, 5, 1, 2), 127 | nn.Conv2d(16, 2, 5, 1, 2), 128 | ) 129 | 130 | def forward(self, x): 131 | return F.sigmoid(self.main(x)) 132 | 133 | 134 | class get_frames(nn.Module): 135 | def __init__(self, opt): 136 | super(get_frames, self).__init__() 137 | opt = opt 138 | self.main = nn.Sequential( 139 | upconv(64, 16, 5, 1, 2), 140 | nn.Conv2d(16, opt.input_channel, 5, 1, 2) 141 | ) 142 | 143 | def forward(self, x): 144 | return F.sigmoid(self.main(x)) 145 | 146 | 147 | class encoder(nn.Module): 148 | def __init__(self, opt): 149 | super(encoder, self).__init__() 150 | self.econv1 = convbase(opt.input_channel, 32, 4, 2, 1) # 32,64,64 151 | self.econv2 = convblock(32, 64, 4, 2, 1) # 64,32,32 152 | self.econv3 = convblock(64, 128, 4, 2, 1) # 128,16,16 153 | self.econv4 = convblock(128, 256, 4, 2, 1) # 256,8,8 154 | 155 | def forward(self, x): 156 | enco1 = self.econv1(x) # 32 157 | enco2 = self.econv2(enco1) # 64 158 | enco3 = self.econv3(enco2) # 128 159 | codex = self.econv4(enco3) # 256 160 | return enco1, enco2, enco3, codex 161 | 162 | 163 | class decoder(nn.Module): 164 | def __init__(self, opt): 165 | super(decoder, self).__init__() 166 | self.opt = opt 167 | self.dconv1 = convblock(256 + 16, 256, 3, 1, 1) # 256,8,8 168 | self.dconv2 = upconv(256, 128, 3, 1, 1) # 128,16,16 169 | self.dconv3 = upconv(256, 64, 3, 1, 1) # 64,32,32 170 | self.dconv4 = upconv(128, 32, 3, 1, 1) # 32,64,64 171 | self.gateconv1 = gateconv3d(64, 64, 3, 1, 1) 172 | self.gateconv2 = gateconv3d(32, 32, 3, 1, 1) 173 | 174 | def forward(self, enco1, enco2, enco3, z): 175 | opt = self.opt 176 | deco1 = self.dconv1(z) # .view(-1,256,4,4,4)# bs*4,256,8,8 177 | deco2 = torch.cat(torch.chunk(self.dconv2(deco1).unsqueeze(2), opt.num_predicted_frames, 0), 2) # bs*4,128,16,16 178 | deco2 = torch.cat(torch.unbind(torch.cat([deco2, torch.unsqueeze(enco3, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 179 | deco3 = torch.cat(self.dconv3(deco2).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 128,32,32 180 | deco3 = self.gateconv1(deco3) 181 | deco3 = torch.cat(torch.unbind(torch.cat([deco3, torch.unsqueeze(enco2, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 182 | deco4 = torch.cat(self.dconv4(deco3).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 32,4,64,64 183 | deco4 = 
self.gateconv2(deco4) 184 | deco4 = torch.cat(torch.unbind(torch.cat([deco4, torch.unsqueeze(enco1, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 185 | return deco4 186 | 187 | # from nets import Flow2Frame_warped 188 | # from vgg_128 import Flow2Frame_warped # with skip connections 189 | 190 | mean = Vb(torch.FloatTensor([0.485, 0.456, 0.406])).view([1,3,1,1]) 191 | std = Vb(torch.FloatTensor([0.229, 0.224, 0.225])).view([1,3,1,1]) 192 | 193 | 194 | class VAE(nn.Module): 195 | def __init__(self, hallucination=False, opt=None, refine=True): 196 | super(VAE, self).__init__() 197 | 198 | self.opt = opt 199 | self.hallucination = hallucination 200 | self.motion_net = motion_net(opt) 201 | self.encoder = encoder(opt) 202 | self.flow_decoder = decoder(opt) 203 | if self.hallucination: 204 | self.raw_decoder = decoder(opt) 205 | self.predict = get_frames(opt) 206 | 207 | self.zconv = convbase(256 + 64, int(16*self.opt.num_predicted_frames), 3, 1, 1) 208 | self.floww = ops.flowwrapper() 209 | self.fc = nn.Linear(1024, 1024) 210 | self.flownext = getflow() 211 | self.flowprev = getflow() 212 | self.get_mask = get_occlusion_mask() 213 | self.refine = refine 214 | if self.refine: 215 | from models.vgg_128 import RefineNet 216 | self.refine_net = RefineNet(num_channels=opt.input_channel) 217 | 218 | vgg19 = torchvision.models.vgg19(pretrained=True) 219 | self.vgg_net = my_vgg(vgg19) 220 | for param in self.vgg_net.parameters(): 221 | param.requires_grad = False 222 | 223 | def reparameterize(self, mu, logvar): 224 | if self.training: 225 | std = logvar.mul(0.5).exp_() 226 | eps = Vb(std.data.new(std.size()).normal_()) 227 | return eps.mul(std).add_(mu) 228 | else: 229 | return Vb(mu.data.new(mu.size()).normal_()) 230 | 231 | def _normalize(self, x): 232 | gpu_id = x.get_device() 233 | return (x - mean.cuda(gpu_id)) / std.cuda(gpu_id) 234 | 235 | 236 | def forward(self, x, data, noise_bg, z_m=None): 237 | 238 | frame1 = data[:, 0, :, :, :] 239 | frame2 = data[:, 1:, :, :, :] 240 | 241 | opt = self.opt 242 | 243 | y = torch.cat( 244 | [frame1, frame2.contiguous().view(-1, opt.num_predicted_frames * opt.input_channel, opt.input_size[0], 245 | opt.input_size[1]) - 246 | frame1.repeat(1, opt.num_predicted_frames, 1, 1)], 1) 247 | 248 | # Encoder Network --> encode input frames 249 | enco1, enco2, enco3, codex = self.encoder(x) 250 | 251 | # Motion Network --> compute latent vector 252 | mu, logvar = self.motion_net(y.contiguous().view(-1, opt.num_frames*opt.input_channel, opt.input_size[0], opt.input_size[1])) 253 | if z_m is None: 254 | z_m = self.reparameterize(mu, logvar) 255 | codey = self.zconv( 256 | torch.cat([self.fc(z_m).view(-1, 64, int(opt.input_size[0] / 16), int(opt.input_size[1] / 16)), codex], 1)) 257 | codex = torch.unsqueeze(codex, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1) # bs,256,4,8,8 258 | codey = torch.cat(torch.chunk(codey.unsqueeze(2), opt.num_predicted_frames, 1), 2) # bs,16,4,8,8 259 | z = torch.cat(torch.unbind(torch.cat([codex, codey], 1), 2), 0) # (256L, 272L, 8L, 8L) 272-256=16 260 | 261 | # Flow Decoder Network --> decode latent vectors into flow fields. 
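# The decoder output is chunked back into per-frame feature maps; flownext and flowprev
# then predict a forward and a backward flow field for every predicted frame.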
262 | flow_deco4 = self.flow_decoder(enco1, enco2, enco3, z) # (256, 64, 64, 64) 263 | flow = torch.cat(self.flownext(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 264 | flowback = torch.cat(self.flowprev(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 265 | 266 | # Warp frames using computed flows 267 | # out = [torch.unsqueeze(self.floww(x, flow[:, :, i, :, :]), 1) for i in range(opt.num_predicted_frames)] 268 | # out = torch.cat(out, 1) # (64, 4, 3, 128, 128) 269 | 270 | '''Compute Occlusion Mask''' 271 | # mask_fw, mask_bw = ops.get_occlusion_mask(flow, flowback, self.floww, opt, t=opt.num_predicted_frames) 272 | masks = torch.cat(self.get_mask(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 273 | mask_fw = masks[:,0,...] 274 | mask_bw = masks[:,1,...] 275 | 276 | '''Use mask before warpping''' 277 | output = ops.warp(x, flow, opt, self.floww, mask_fw) 278 | 279 | y_pred = output 280 | 281 | '''Go through the refine network.''' 282 | if self.refine: 283 | y_pred = ops.refine(output, flow, mask_fw, self.refine_net, opt, noise_bg) 284 | 285 | if self.training: 286 | 287 | tmp1 = output.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]) 288 | tmp2 = frame2.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]) 289 | 290 | if opt.input_channel ==1: 291 | tmp1 = tmp1.repeat(1, 3, 1, 1) 292 | tmp2 = tmp2.repeat(1, 3, 1, 1) 293 | 294 | prediction_vgg_feature = self.vgg_net(self._normalize(tmp1)) 295 | gt_vgg_feature = self.vgg_net(self._normalize(tmp2)) 296 | 297 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature 298 | 299 | else: 300 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw 301 | -------------------------------------------------------------------------------- /src/models/multiframe_w_mask_genmask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import torchvision.models 7 | import torch.optim as optim 8 | import os 9 | import logging 10 | import torchvision.utils as tov 11 | 12 | import sys 13 | sys.path.insert(0, '../utils') 14 | from utils import utils 15 | from utils import ops 16 | from models.vgg_utils import my_vgg 17 | 18 | 19 | class motion_net(nn.Module): 20 | def __init__(self, opt, input_channel, output_channel=int(1024/2)): 21 | super(motion_net, self).__init__() 22 | # input 3*128*128 23 | self.main = nn.Sequential( 24 | nn.Conv2d(input_channel, 32, 4, 2, 1, bias=False), # 64 25 | nn.LeakyReLU(0.2, inplace=True), 26 | nn.Conv2d(32, 64, 4, 2, 1, bias=False), # 32 27 | nn.BatchNorm2d(64), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), # 32 30 | nn.BatchNorm2d(64), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Conv2d(64, 128, 4, 2, 1, bias=False), # 16 33 | nn.BatchNorm2d(128), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | nn.Conv2d(128, 128, 3, 1, 1, bias=False), # 16 36 | nn.BatchNorm2d(128), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | nn.Conv2d(128, 64, 4, 2, 1, bias=False) # 8 39 | ) 40 | self.fc1 = nn.Linear(1024, output_channel) 41 | self.fc2 = nn.Linear(1024, output_channel) 42 | 43 | def forward(self, x): 44 | temp = self.main(x).view(-1, 1024) 45 | mu = self.fc1(temp) 46 | # print 'mu: 
'+str(mu.size()) 47 | logvar = self.fc2(temp) 48 | return mu, logvar 49 | 50 | 51 | class gateconv3d_bak(nn.Module): 52 | def __init__(self, innum, outnum, kernel, stride, pad): 53 | super(gateconv3d, self).__init__() 54 | self.conv = nn.Conv3d(innum, outnum * 2, kernel, stride, pad, bias=True) 55 | self.bn = nn.BatchNorm3d(outnum * 2) 56 | 57 | def forward(self, x): 58 | return F.glu(self.bn(self.conv(x)), 1) + x 59 | 60 | 61 | class gateconv3d(nn.Module): 62 | def __init__(self, innum, outnum, kernel, stride, pad): 63 | super(gateconv3d, self).__init__() 64 | self.conv = nn.Conv3d(innum, outnum, kernel, stride, pad, bias=True) 65 | self.bn = nn.BatchNorm3d(outnum) 66 | 67 | def forward(self, x): 68 | return F.leaky_relu(self.bn(self.conv(x)), 0.2) 69 | 70 | 71 | class convblock(nn.Module): 72 | def __init__(self, innum, outnum, kernel, stride, pad): 73 | super(convblock, self).__init__() 74 | self.main = nn.Sequential( 75 | nn.Conv2d(innum, outnum, kernel, stride, pad, bias=False), 76 | nn.BatchNorm2d(outnum), 77 | nn.LeakyReLU(0.2, inplace=True)) 78 | 79 | def forward(self, x): 80 | return self.main(x) 81 | 82 | 83 | class convbase(nn.Module): 84 | def __init__(self, innum, outnum, kernel, stride, pad): 85 | super(convbase, self).__init__() 86 | self.main = nn.Sequential( 87 | nn.Conv2d(innum, outnum, kernel, stride, pad), 88 | nn.LeakyReLU(0.2, inplace=True)) 89 | 90 | def forward(self, x): 91 | return self.main(x) 92 | 93 | 94 | class upconv(nn.Module): 95 | def __init__(self, innum, outnum, kernel, stride, pad): 96 | super(upconv, self).__init__() 97 | self.main = nn.Sequential( 98 | nn.Conv2d(innum, outnum * 2, kernel, stride, pad), 99 | nn.BatchNorm2d(outnum * 2), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | nn.Conv2d(outnum * 2, outnum, kernel, stride, pad), 102 | nn.BatchNorm2d(outnum), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Upsample(scale_factor=2, mode='bilinear') 105 | ) 106 | 107 | def forward(self, x): 108 | return self.main(x) 109 | 110 | 111 | class getflow(nn.Module): 112 | def __init__(self): 113 | super(getflow, self).__init__() 114 | self.main = nn.Sequential( 115 | upconv(64, 16, 5, 1, 2), 116 | nn.Conv2d(16, 2, 5, 1, 2), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.main(x) 121 | 122 | 123 | class get_occlusion_mask(nn.Module): 124 | def __init__(self): 125 | super(get_occlusion_mask, self).__init__() 126 | self.main = nn.Sequential( 127 | upconv(64, 16, 5, 1, 2), 128 | nn.Conv2d(16, 2, 5, 1, 2), 129 | ) 130 | 131 | def forward(self, x): 132 | return torch.sigmoid(self.main(x)) 133 | 134 | 135 | class get_frames(nn.Module): 136 | def __init__(self, opt): 137 | super(get_frames, self).__init__() 138 | opt = opt 139 | self.main = nn.Sequential( 140 | upconv(64, 16, 5, 1, 2), 141 | nn.Conv2d(16, opt.input_channel, 5, 1, 2) 142 | ) 143 | 144 | def forward(self, x): 145 | return torch.sigmoid(self.main(x)) 146 | 147 | 148 | class encoder(nn.Module): 149 | def __init__(self, opt): 150 | super(encoder, self).__init__() 151 | self.econv1 = convbase(opt.input_channel + opt.mask_channel, 32, 4, 2, 1) # 32,64,64 152 | self.econv2 = convblock(32, 64, 4, 2, 1) # 64,32,32 153 | self.econv3 = convblock(64, 128, 4, 2, 1) # 128,16,16 154 | self.econv4 = convblock(128, 256, 4, 2, 1) # 256,8,8 155 | 156 | def forward(self, x): 157 | enco1 = self.econv1(x) # 32 158 | enco2 = self.econv2(enco1) # 64 159 | enco3 = self.econv3(enco2) # 128 160 | codex = self.econv4(enco3) # 256 161 | return enco1, enco2, enco3, codex 162 | 163 | 164 | class decoder(nn.Module): 165 | 
def __init__(self, opt): 166 | super(decoder, self).__init__() 167 | self.opt = opt 168 | self.dconv1 = convblock(256 + 16, 256, 3, 1, 1) # 256,8,8 169 | self.dconv2 = upconv(256, 128, 3, 1, 1) # 128,16,16 170 | self.dconv3 = upconv(256, 64, 3, 1, 1) # 64,32,32 171 | self.dconv4 = upconv(128, 32, 3, 1, 1) # 32,64,64 172 | self.gateconv1 = gateconv3d(64, 64, 3, 1, 1) 173 | self.gateconv2 = gateconv3d(32, 32, 3, 1, 1) 174 | 175 | def forward(self, enco1, enco2, enco3, z): 176 | opt = self.opt 177 | deco1 = self.dconv1(z) # .view(-1,256,4,4,4)# bs*4,256,8,8 178 | deco2 = torch.cat(torch.chunk(self.dconv2(deco1).unsqueeze(2), opt.num_predicted_frames, 0), 2) # bs*4,128,16,16 179 | deco2 = torch.cat(torch.unbind(torch.cat([deco2, torch.unsqueeze(enco3, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 180 | deco3 = torch.cat(self.dconv3(deco2).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 128,32,32 181 | deco3 = self.gateconv1(deco3) 182 | deco3 = torch.cat(torch.unbind(torch.cat([deco3, torch.unsqueeze(enco2, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 183 | deco4 = torch.cat(self.dconv4(deco3).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 32,4,64,64 184 | deco4 = self.gateconv2(deco4) 185 | deco4 = torch.cat(torch.unbind(torch.cat([deco4, torch.unsqueeze(enco1, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 186 | return deco4 187 | 188 | 189 | mean = Vb(torch.FloatTensor([0.485, 0.456, 0.406])).view([1,3,1,1]) 190 | std = Vb(torch.FloatTensor([0.229, 0.224, 0.225])).view([1,3,1,1]) 191 | 192 | 193 | class VAE(nn.Module): 194 | def __init__(self, hallucination=False, opt=None, refine=True): 195 | super(VAE, self).__init__() 196 | 197 | self.opt = opt 198 | self.hallucination = hallucination 199 | 200 | self.motion_net = motion_net(opt, int(opt.num_frames*opt.input_channel)+20, 1024) 201 | 202 | self.encoder = encoder(opt) 203 | self.flow_decoder = decoder(opt) 204 | if self.hallucination: 205 | self.raw_decoder = decoder(opt) 206 | self.predict = get_frames(opt) 207 | 208 | self.zconv = convbase(256 + 64, 16*self.opt.num_predicted_frames, 3, 1, 1) 209 | self.floww = ops.flowwrapper() 210 | self.fc = nn.Linear(1024, 1024) 211 | self.flownext = getflow() 212 | self.flowprev = getflow() 213 | self.get_mask = get_occlusion_mask() 214 | self.refine = refine 215 | if self.refine: 216 | 217 | from models.vgg_128 import RefineNet 218 | self.refine_net = RefineNet(num_channels=opt.input_channel) 219 | 220 | vgg19 = torchvision.models.vgg19(pretrained=True) 221 | self.vgg_net = my_vgg(vgg19) 222 | for param in self.vgg_net.parameters(): 223 | param.requires_grad = False 224 | 225 | def reparameterize(self, mu, logvar): 226 | if self.training: 227 | std = logvar.mul(0.5).exp_() 228 | eps = Vb(std.data.new(std.size()).normal_()) 229 | return eps.mul(std).add_(mu) 230 | else: 231 | return Vb(mu.data.new(mu.size()).normal_()) 232 | 233 | def _normalize(self, x): 234 | gpu_id = x.get_device() 235 | return (x - mean.cuda(gpu_id)) / std.cuda(gpu_id) 236 | 237 | def forward(self, x, data, mask, noise_bg, z_m=None): 238 | 239 | frame1 = data[:, 0, :, :, :] 240 | frame2 = data[:, 1:, :, :, :] 241 | input = torch.cat([x, mask], 1) 242 | opt = self.opt 243 | 244 | y = torch.cat( 245 | [frame1, frame2.contiguous().view(-1, opt.num_predicted_frames * opt.input_channel, opt.input_size[0], opt.input_size[1]) 246 | - frame1.repeat(1, opt.num_predicted_frames, 1, 1)], 1) 247 | 248 | # Encoder Network --> encode input frames 249 | enco1, enco2, enco3, codex = 
self.encoder(input) 250 | 251 | # Motion Network --> compute latent vector 252 | mu, logvar = self.motion_net(torch.cat([y, mask], 1).contiguous()) 253 | 254 | if z_m is None: 255 | z_m = self.reparameterize(mu, logvar) 256 | codey = self.zconv( 257 | torch.cat([self.fc(z_m).view(-1, 64, int(opt.input_size[0] / 16), int(opt.input_size[1] / 16)), codex], 1)) 258 | codex = torch.unsqueeze(codex, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1) # bs,256,4,8,8 259 | codey = torch.cat(torch.chunk(codey.unsqueeze(2), opt.num_predicted_frames, 1), 2) # bs,16,4,8,8 260 | z = torch.cat(torch.unbind(torch.cat([codex, codey], 1), 2), 0) # (256L, 272L, 8L, 8L) 272-256=16 261 | 262 | # Flow Decoder Network --> decode latent vectors into flow fields. 263 | flow_deco4 = self.flow_decoder(enco1, enco2, enco3, z) # (256, 64, 64, 64) 264 | flow = torch.cat(self.flownext(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 265 | flowback = torch.cat(self.flowprev(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 266 | 267 | # Warp frames using computed flows 268 | # out = [torch.unsqueeze(self.floww(x, flow[:, :, i, :, :]), 1) for i in range(opt.num_predicted_frames)] 269 | # out = torch.cat(out, 1) # (64, 4, 3, 128, 128) 270 | 271 | '''Compute Occlusion Mask''' 272 | # mask_fw, mask_bw = ops.get_occlusion_mask(flow, flowback, self.floww, opt, t=opt.num_predicted_frames) 273 | masks = torch.cat(self.get_mask(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 274 | 2) # (64, 2, 4, 128, 128) 275 | mask_fw = masks[:, 0, ...] 276 | mask_bw = masks[:, 1, ...] 277 | 278 | '''Use mask before warpping''' 279 | output = ops.warp(x, flow, opt, self.floww, mask_fw) 280 | 281 | y_pred = output 282 | 283 | '''Go through the refine network.''' 284 | if self.refine: 285 | y_pred = ops.refine(output, flow, mask_fw, self.refine_net, opt, noise_bg) 286 | 287 | if self.training: 288 | # y_pred_vgg_feature = self.vgg_net( 289 | # self._normalize(y_pred.contiguous().view(-1, opt.input_channel, opt.input_size, opt.input_size))) 290 | prediction_vgg_feature = self.vgg_net( 291 | self._normalize(output.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]))) 292 | gt_vgg_feature = self.vgg_net( 293 | self._normalize(frame2.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]))) 294 | 295 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature#, y_pred_vgg_feature 296 | else: 297 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw -------------------------------------------------------------------------------- /src/models/multiframe_w_mask_genmask_two_path.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import torchvision.models 7 | import torch.optim as optim 8 | import os 9 | import logging 10 | import torchvision.utils as tov 11 | 12 | import sys 13 | sys.path.insert(0, '../utils') 14 | from utils import utils 15 | from utils import ops 16 | from models.vgg_utils import my_vgg 17 | 18 | 19 | class motion_net(nn.Module): 20 | def __init__(self, opt, input_channel, output_channel=int(1024/2)): 21 | super(motion_net, self).__init__() 22 | # input 3*128*128 23 | self.main = nn.Sequential( 24 | nn.Conv2d(input_channel, 32, 4, 2, 1, 
bias=False), # 64 25 | nn.LeakyReLU(0.2, inplace=True), 26 | nn.Conv2d(32, 64, 4, 2, 1, bias=False), # 32 27 | nn.BatchNorm2d(64), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | nn.Conv2d(64, 64, 3, 1, 1, bias=False), # 32 30 | nn.BatchNorm2d(64), 31 | nn.LeakyReLU(0.2, inplace=True), 32 | nn.Conv2d(64, 128, 4, 2, 1, bias=False), # 16 33 | nn.BatchNorm2d(128), 34 | nn.LeakyReLU(0.2, inplace=True), 35 | nn.Conv2d(128, 128, 3, 1, 1, bias=False), # 16 36 | nn.BatchNorm2d(128), 37 | nn.LeakyReLU(0.2, inplace=True), 38 | nn.Conv2d(128, 64, 4, 2, 1, bias=False) # 8 39 | ) 40 | self.fc1 = nn.Linear(1024, output_channel) 41 | self.fc2 = nn.Linear(1024, output_channel) 42 | 43 | def forward(self, x): 44 | temp = self.main(x).view(-1, 1024) 45 | mu = self.fc1(temp) 46 | # print 'mu: '+str(mu.size()) 47 | logvar = self.fc2(temp) 48 | return mu, logvar 49 | 50 | 51 | class gateconv3d_bak(nn.Module): 52 | def __init__(self, innum, outnum, kernel, stride, pad): 53 | super(gateconv3d, self).__init__() 54 | self.conv = nn.Conv3d(innum, outnum * 2, kernel, stride, pad, bias=True) 55 | self.bn = nn.BatchNorm3d(outnum * 2) 56 | 57 | def forward(self, x): 58 | return F.glu(self.bn(self.conv(x)), 1) + x 59 | 60 | 61 | class gateconv3d(nn.Module): 62 | def __init__(self, innum, outnum, kernel, stride, pad): 63 | super(gateconv3d, self).__init__() 64 | self.conv = nn.Conv3d(innum, outnum, kernel, stride, pad, bias=True) 65 | self.bn = nn.BatchNorm3d(outnum) 66 | 67 | def forward(self, x): 68 | return F.leaky_relu(self.bn(self.conv(x)), 0.2) 69 | 70 | 71 | class convblock(nn.Module): 72 | def __init__(self, innum, outnum, kernel, stride, pad): 73 | super(convblock, self).__init__() 74 | self.main = nn.Sequential( 75 | nn.Conv2d(innum, outnum, kernel, stride, pad, bias=False), 76 | nn.BatchNorm2d(outnum), 77 | nn.LeakyReLU(0.2, inplace=True)) 78 | 79 | def forward(self, x): 80 | return self.main(x) 81 | 82 | 83 | class convbase(nn.Module): 84 | def __init__(self, innum, outnum, kernel, stride, pad): 85 | super(convbase, self).__init__() 86 | self.main = nn.Sequential( 87 | nn.Conv2d(innum, outnum, kernel, stride, pad), 88 | nn.LeakyReLU(0.2, inplace=True)) 89 | 90 | def forward(self, x): 91 | return self.main(x) 92 | 93 | 94 | class upconv(nn.Module): 95 | def __init__(self, innum, outnum, kernel, stride, pad): 96 | super(upconv, self).__init__() 97 | self.main = nn.Sequential( 98 | nn.Conv2d(innum, outnum * 2, kernel, stride, pad), 99 | nn.BatchNorm2d(outnum * 2), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | nn.Conv2d(outnum * 2, outnum, kernel, stride, pad), 102 | nn.BatchNorm2d(outnum), 103 | nn.LeakyReLU(0.2, inplace=True), 104 | nn.Upsample(scale_factor=2, mode='bilinear') 105 | ) 106 | 107 | def forward(self, x): 108 | return self.main(x) 109 | 110 | 111 | class getflow(nn.Module): 112 | def __init__(self, output_channel=2): 113 | super(getflow, self).__init__() 114 | self.main = nn.Sequential( 115 | upconv(64, 16, 5, 1, 2), 116 | nn.Conv2d(16, output_channel, 5, 1, 2), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.main(x) 121 | 122 | 123 | class get_occlusion_mask(nn.Module): 124 | def __init__(self): 125 | super(get_occlusion_mask, self).__init__() 126 | self.main = nn.Sequential( 127 | upconv(64, 16, 5, 1, 2), 128 | nn.Conv2d(16, 2, 5, 1, 2), 129 | ) 130 | 131 | def forward(self, x): 132 | return torch.sigmoid(self.main(x)) 133 | 134 | 135 | class get_frames(nn.Module): 136 | def __init__(self, opt): 137 | super(get_frames, self).__init__() 138 | opt = opt 139 | self.main = nn.Sequential( 140 | 
upconv(64, 16, 5, 1, 2), 141 | nn.Conv2d(16, opt.input_channel, 5, 1, 2) 142 | ) 143 | 144 | def forward(self, x): 145 | return torch.sigmoid(self.main(x)) 146 | 147 | 148 | class encoder(nn.Module): 149 | def __init__(self, opt): 150 | super(encoder, self).__init__() 151 | self.econv1 = convbase(opt.input_channel + opt.mask_channel, 32, 4, 2, 1) # 32,64,64 152 | self.econv2 = convblock(32, 64, 4, 2, 1) # 64,32,32 153 | self.econv3 = convblock(64, 128, 4, 2, 1) # 128,16,16 154 | self.econv4 = convblock(128, 256, 4, 2, 1) # 256,8,8 155 | 156 | def forward(self, x): 157 | enco1 = self.econv1(x) # 32 158 | enco2 = self.econv2(enco1) # 64 159 | enco3 = self.econv3(enco2) # 128 160 | codex = self.econv4(enco3) # 256 161 | return enco1, enco2, enco3, codex 162 | 163 | 164 | class decoder(nn.Module): 165 | def __init__(self, opt): 166 | super(decoder, self).__init__() 167 | self.opt = opt 168 | self.dconv1 = convblock(256 + 16, 256, 3, 1, 1) # 256,8,8 169 | self.dconv2 = upconv(256, 128, 3, 1, 1) # 128,16,16 170 | self.dconv3 = upconv(256, 64, 3, 1, 1) # 64,32,32 171 | self.dconv4 = upconv(128, 32, 3, 1, 1) # 32,64,64 172 | self.gateconv1 = gateconv3d(64, 64, 3, 1, 1) 173 | self.gateconv2 = gateconv3d(32, 32, 3, 1, 1) 174 | 175 | def forward(self, enco1, enco2, enco3, z): 176 | opt = self.opt 177 | deco1 = self.dconv1(z) # .view(-1,256,4,4,4)# bs*4,256,8,8 178 | deco2 = torch.cat(torch.chunk(self.dconv2(deco1).unsqueeze(2), opt.num_predicted_frames, 0), 2) # bs*4,128,16,16 179 | deco2 = torch.cat(torch.unbind(torch.cat([deco2, torch.unsqueeze(enco3, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 180 | deco3 = torch.cat(self.dconv3(deco2).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 128,32,32 181 | deco3 = self.gateconv1(deco3) 182 | deco3 = torch.cat(torch.unbind(torch.cat([deco3, torch.unsqueeze(enco2, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 183 | deco4 = torch.cat(self.dconv4(deco3).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # 32,4,64,64 184 | deco4 = self.gateconv2(deco4) 185 | deco4 = torch.cat(torch.unbind(torch.cat([deco4, torch.unsqueeze(enco1, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1)], 1), 2), 0) 186 | return deco4 187 | 188 | 189 | mean = Vb(torch.FloatTensor([0.485, 0.456, 0.406])).view([1,3,1,1]) 190 | std = Vb(torch.FloatTensor([0.229, 0.224, 0.225])).view([1,3,1,1]) 191 | 192 | 193 | class VAE(nn.Module): 194 | def __init__(self, hallucination=False, opt=None, refine=True, bg=512, fg=512): 195 | super(VAE, self).__init__() 196 | 197 | self.opt = opt 198 | self.hallucination = hallucination 199 | 200 | # BG 201 | self.motion_net_bg = motion_net(opt, int(opt.num_frames*opt.input_channel)+11, bg) 202 | # FG 203 | self.motion_net_fg = motion_net(opt, int(opt.num_frames*opt.input_channel)+9, fg) 204 | 205 | self.encoder = encoder(opt) 206 | self.flow_decoder = decoder(opt) 207 | if self.hallucination: 208 | self.raw_decoder = decoder(opt) 209 | self.predict = get_frames(opt) 210 | 211 | self.zconv = convbase(256 + 64, 16*self.opt.num_predicted_frames, 3, 1, 1) 212 | self.floww = ops.flowwrapper() 213 | self.fc = nn.Linear(1024, 1024) 214 | self.flownext = getflow() 215 | self.flowprev = getflow() 216 | self.get_mask = get_occlusion_mask() 217 | self.refine = refine 218 | if self.refine: 219 | from models.vgg_128 import RefineNet 220 | self.refine_net = RefineNet(num_channels=opt.input_channel) 221 | 222 | vgg19 = torchvision.models.vgg19(pretrained=True) 223 | self.vgg_net = my_vgg(vgg19) 224 | for param in 
self.vgg_net.parameters(): 225 | param.requires_grad = False 226 | 227 | def reparameterize(self, mu, logvar): 228 | if self.training: 229 | std = logvar.mul(0.5).exp_() 230 | eps = Vb(std.data.new(std.size()).normal_()) 231 | return eps.mul(std).add_(mu) 232 | else: 233 | return Vb(mu.data.new(mu.size()).normal_()) 234 | 235 | def _normalize(self, x): 236 | gpu_id = x.get_device() 237 | return (x - mean.cuda(gpu_id)) / std.cuda(gpu_id) 238 | 239 | def forward(self, x, data, bg_mask, fg_mask, noise_bg, z_m=None): 240 | 241 | frame1 = data[:, 0, :, :, :] 242 | frame2 = data[:, 1:, :, :, :] 243 | mask = torch.cat([bg_mask, fg_mask], 1) 244 | input = torch.cat([x, mask], 1) 245 | opt = self.opt 246 | 247 | y = torch.cat( 248 | [frame1, frame2.contiguous().view(-1, opt.num_predicted_frames * opt.input_channel, opt.input_size[0], 249 | opt.input_size[1]) - 250 | frame1.repeat(1, opt.num_predicted_frames, 1, 1)], 1) 251 | 252 | # Encoder Network --> encode input frames 253 | enco1, enco2, enco3, codex = self.encoder(input) 254 | 255 | # Motion Network --> compute latent vector 256 | 257 | # BG 258 | mu_bg, logvar_bg = self.motion_net_bg(torch.cat([y, bg_mask], 1).contiguous()) 259 | # FG 260 | mu_fg, logvar_fg = self.motion_net_fg(torch.cat([y , fg_mask], 1).contiguous()) 261 | 262 | mu = torch.cat([mu_bg, mu_fg], 1) 263 | logvar = torch.cat([logvar_bg, logvar_fg], 1) 264 | 265 | # mu = mu_bg + mu_fg 266 | # logvar = logvar_bg + logvar_fg 267 | # print (mu.size()) 268 | # z_m = self.reparameterize(mu, logvar) 269 | # print (z_m.size()) 270 | if z_m is None: 271 | z_m = self.reparameterize(mu, logvar) 272 | 273 | codey = self.zconv(torch.cat([self.fc(z_m).view(-1, 64, int(opt.input_size[0]/16), int(opt.input_size[1]/16)), codex], 1)) 274 | codex = torch.unsqueeze(codex, 2).repeat(1, 1, opt.num_predicted_frames, 1, 1) # bs,256,4,8,8 275 | codey = torch.cat(torch.chunk(codey.unsqueeze(2), opt.num_predicted_frames, 1), 2) # bs,16,4,8,8 276 | z = torch.cat(torch.unbind(torch.cat([codex, codey], 1), 2), 0) # (256L, 272L, 8L, 8L) 272-256=16 277 | 278 | # Flow Decoder Network --> decode latent vectors into flow fields. 279 | flow_deco4 = self.flow_decoder(enco1, enco2, enco3, z) # (256, 64, 64, 64) 280 | flow = torch.cat(self.flownext(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 281 | flowback = torch.cat(self.flowprev(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 2) # (64, 2, 4, 128, 128) 282 | 283 | # Warp frames using computed flows 284 | # out = [torch.unsqueeze(self.floww(x, flow[:, :, i, :, :]), 1) for i in range(opt.num_predicted_frames)] 285 | # out = torch.cat(out, 1) # (64, 4, 3, 128, 128) 286 | 287 | '''Compute Occlusion Mask''' 288 | # mask_fw, mask_bw = ops.get_occlusion_mask(flow, flowback, self.floww, opt, t=opt.num_predicted_frames) 289 | masks = torch.cat(self.get_mask(flow_deco4).unsqueeze(2).chunk(opt.num_predicted_frames, 0), 290 | 2) # (64, 2, 4, 128, 128) 291 | mask_fw = masks[:, 0, ...] 292 | mask_bw = masks[:, 1, ...] 
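# mask_fw / mask_bw are soft per-frame occlusion maps in [0, 1] from get_occlusion_mask;
# mask_fw gates the warped frames below, while mask_bw is used by the backward
# reconstruction and consistency losses.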
293 | 294 | '''Use mask before warpping''' 295 | output = ops.warp(x, flow, opt, self.floww, mask_fw) 296 | 297 | y_pred = output 298 | 299 | '''Go through the refine network.''' 300 | if self.refine: 301 | y_pred = ops.refine(output, flow, mask_fw, self.refine_net, opt, noise_bg) 302 | 303 | if self.training: 304 | # y_pred_vgg_feature = self.vgg_net( 305 | # self._normalize(y_pred.contiguous().view(-1, opt.input_channel, opt.input_size, opt.input_size))) 306 | prediction_vgg_feature = self.vgg_net( 307 | self._normalize(output.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]))) 308 | gt_vgg_feature = self.vgg_net( 309 | self._normalize(frame2.contiguous().view(-1, opt.input_channel, opt.input_size[0], opt.input_size[1]))) 310 | 311 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature#, y_pred_vgg_feature 312 | else: 313 | return output, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw 314 | 315 | 316 | 317 | -------------------------------------------------------------------------------- /src/models/vgg_128.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class vgg_layer(nn.Module): 6 | def __init__(self, nin, nout): 7 | super(vgg_layer, self).__init__() 8 | self.main = nn.Sequential( 9 | nn.Conv2d(nin, nout, 3, 1, 1), 10 | nn.BatchNorm2d(nout), 11 | nn.LeakyReLU(0.2, inplace=True) 12 | ) 13 | 14 | def forward(self, input): 15 | return self.main(input) 16 | 17 | class encoder(nn.Module): 18 | def __init__(self, dim, nc=1): 19 | super(encoder, self).__init__() 20 | self.dim = dim 21 | # 128 x 128 22 | self.c1 = nn.Sequential( 23 | vgg_layer(nc, 64), 24 | vgg_layer(64, 64), 25 | ) 26 | # 64 x 64 27 | self.c2 = nn.Sequential( 28 | vgg_layer(64, 128), 29 | vgg_layer(128, 128), 30 | ) 31 | # 32 x 32 32 | self.c3 = nn.Sequential( 33 | vgg_layer(128, 256), 34 | vgg_layer(256, 256), 35 | vgg_layer(256, 256), 36 | ) 37 | # 16 x 16 38 | self.c4 = nn.Sequential( 39 | vgg_layer(256, 512), 40 | vgg_layer(512, 512), 41 | vgg_layer(512, 512), 42 | ) 43 | # 8 x 8 44 | self.c5 = nn.Sequential( 45 | vgg_layer(512, 512), 46 | vgg_layer(512, 512), 47 | vgg_layer(512, 512), 48 | ) 49 | # 4 x 4 50 | self.c6 = nn.Sequential( 51 | nn.Conv2d(512, dim, 4, 1, 0), 52 | nn.BatchNorm2d(dim), 53 | nn.Tanh() 54 | ) 55 | self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 56 | # self.mp = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 57 | def forward(self, input): 58 | h1 = self.c1(input) # 128 -> 64 59 | h2 = self.c2(self.mp(h1)) # 64 -> 32 60 | h3 = self.c3(self.mp(h2)) # 32 -> 16 61 | h4 = self.c4(self.mp(h3)) # 16 -> 8 62 | h5 = self.c5(self.mp(h4)) # 8 -> 4 63 | h6 = self.c6(self.mp(h5)) # 4 -> 1 64 | # return h6.view(-1, self.dim), [h1, h2, h3, h4, h5] 65 | return h6, [h1, h2, h3, h4, h5] 66 | 67 | 68 | class decoder(nn.Module): 69 | def __init__(self, dim, nc=1): 70 | super(decoder, self).__init__() 71 | self.dim = dim 72 | # 1 x 1 -> 4 x 4 73 | self.upc1 = nn.Sequential( 74 | nn.ConvTranspose2d(dim, 512, 4, 1, 0), 75 | nn.BatchNorm2d(512), 76 | nn.LeakyReLU(0.2, inplace=True) 77 | ) 78 | # 8 x 8 79 | self.upc2 = nn.Sequential( 80 | vgg_layer(512*2, 512), 81 | vgg_layer(512, 512), 82 | vgg_layer(512, 512) 83 | ) 84 | # 16 x 16 85 | self.upc3 = nn.Sequential( 86 | vgg_layer(512*2, 512), 87 | vgg_layer(512, 512), 88 | vgg_layer(512, 256) 89 | ) 90 | # 32 x 32 91 | self.upc4 = nn.Sequential( 92 | vgg_layer(256*2, 256), 93 | 
vgg_layer(256, 256), 94 | vgg_layer(256, 128) 95 | ) 96 | # 64 x 64 97 | self.upc5 = nn.Sequential( 98 | vgg_layer(128*2, 128), 99 | vgg_layer(128, 64) 100 | ) 101 | # 128 x 128 102 | self.upc6 = nn.Sequential( 103 | vgg_layer(64*2, 64), 104 | nn.ConvTranspose2d(64, nc, 3, 1, 1), 105 | nn.Sigmoid() 106 | ) 107 | self.up = nn.Upsample(scale_factor=2., mode='bilinear') 108 | 109 | def forward(self, input): 110 | vec, skip = input 111 | d1 = self.upc1(vec) # 1 -> 4 112 | # d1 = self.upc1(vec.view(-1, self.dim, 1, 1)) # 1 -> 4 113 | up1 = self.up(d1) # 4 -> 8 114 | d2 = self.upc2(torch.cat([up1, skip[4]], 1)) # 8 x 8 115 | up2 = self.up(d2) # 8 -> 16 116 | d3 = self.upc3(torch.cat([up2, skip[3]], 1)) # 16 x 16 117 | up3 = self.up(d3) # 16 -> 32 118 | d4 = self.upc4(torch.cat([up3, skip[2]], 1)) # 32 x 32 119 | up4 = self.up(d4) # 32 -> 64 120 | d5 = self.upc5(torch.cat([up4, skip[1]], 1)) # 64 x 64 121 | up5 = self.up(d5) # 64 -> 128 122 | output = self.upc6(torch.cat([up5, skip[0]], 1)) # 128 x 128 123 | return output 124 | 125 | 126 | class Flow2Frame_warped(nn.Module): 127 | def __init__(self, num_channels): 128 | super(Flow2Frame_warped, self).__init__() 129 | # input shape [batch, 3, 128, 128] 130 | self.flow_encoder = encoder(dim=512, nc=2) 131 | self.image_encoder = encoder(dim=1024, nc=num_channels) 132 | self.image_decoder = decoder(dim=1024+512, nc=num_channels) 133 | 134 | def forward(self, warped_img, flow): 135 | img_hidden, img_skip = self.image_encoder(warped_img) 136 | flow_hidden, _ = self.flow_encoder(flow) 137 | concatenated_features = torch.cat( 138 | [img_hidden, flow_hidden], dim=1) 139 | # print concatenated_features.size() 140 | return self.image_decoder((concatenated_features, img_skip)) 141 | 142 | 143 | '''version without flow encoder''' 144 | class RefineNet(nn.Module): 145 | def __init__(self, num_channels): 146 | super(RefineNet, self).__init__() 147 | # input shape [batch, 3, 128, 128] 148 | # self.flow_encoder = encoder(dim=512, nc=2) 149 | self.image_encoder = encoder(dim=1024, nc=num_channels) 150 | self.image_decoder = decoder(dim=1024, nc=num_channels) 151 | 152 | def forward(self, warped_img, flow): 153 | img_hidden, img_skip = self.image_encoder(warped_img) 154 | return self.image_decoder((img_hidden, img_skip)) -------------------------------------------------------------------------------- /src/models/vgg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class my_vgg(nn.Module): 6 | def __init__(self, vgg): 7 | super(my_vgg, self).__init__() 8 | self.vgg = vgg 9 | self.avgpool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)) 10 | 11 | def forward(self, img): 12 | # x = img.unsqueeze(0) 13 | x1 = self.vgg.features[0](img) 14 | x2 = self.vgg.features[1](x1) # relu 15 | x3 = self.vgg.features[2](x2) 16 | x4 = self.vgg.features[3](x3) # relu 17 | x5 = self.avgpool(x4) 18 | 19 | x6 = self.vgg.features[5](x5) 20 | x7 = self.vgg.features[6](x6) # relu 21 | x8 = self.vgg.features[7](x7) 22 | x9 = self.vgg.features[8](x8) # relu 23 | x10 = self.avgpool(x9) 24 | 25 | x11 = self.vgg.features[10](x10) 26 | x12 = self.vgg.features[11](x11) # relu 27 | x13 = self.vgg.features[12](x12) 28 | x14 = self.vgg.features[13](x13) # relu 29 | x15 = self.vgg.features[14](x14) 30 | x16 = self.vgg.features[15](x15) # relu 31 | x17 = self.vgg.features[16](x16) 32 | x18 = self.vgg.features[17](x17) # relu 33 | x19 = self.avgpool(x18) 34 | 35 | x20 = self.vgg.features[19](x19) 36 | x21 = 
self.vgg.features[20](x20) # relu 37 | x22 = self.vgg.features[21](x21) 38 | x23 = self.vgg.features[22](x22) # relu 39 | x24 = self.vgg.features[23](x23) 40 | x25 = self.vgg.features[24](x24) # relu 41 | x26 = self.vgg.features[25](x25) 42 | x27 = self.vgg.features[26](x26) # relu 43 | 44 | return x2, x4, x7, x9, x12, x14, x16, x18, x21, x23, x25, x27 45 | # return x1, x3, x6, x8, x11, x13, x15, x17, x20, x22, x24, x26 46 | 47 | 48 | if __name__ == '__main__': 49 | import torchvision 50 | vgg19 = torchvision.models.vgg19(pretrained=True) 51 | vgg = my_vgg(vgg19) 52 | tmp = torch.randn((32, 3, 16, 16)) 53 | x2, x4, x7, x9, x12, x14, x16, x18, x21, x23, x25, x27 = vgg(tmp) 54 | print(x27.size()) 55 | -------------------------------------------------------------------------------- /src/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_opts(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | '--batch_size', 8 | default=16, 9 | type=int, 10 | help='batch size') 11 | parser.add_argument( 12 | '--input_channel', 13 | default=3, 14 | type=int, 15 | help='input image channel (3 for RGB, 1 for Grayscale)') 16 | parser.add_argument( 17 | '--alpha_recon_image', 18 | default=0.85, 19 | type=float, 20 | help='weight of reconstruction loss.') 21 | parser.add_argument( 22 | '--input_size', 23 | default=(128, 256), 24 | type=tuple, 25 | help='input image size') 26 | parser.add_argument( 27 | '--num_frames', 28 | default=5, 29 | type=int, 30 | help='number of frames for each video clip') 31 | parser.add_argument( 32 | '--num_predicted_frames', 33 | default=4, 34 | type=int, 35 | help='number of frames to predict') 36 | parser.add_argument( 37 | '--num_epochs', 38 | default=1000, 39 | type=int, 40 | help= 41 | 'Max. number of epochs to train.' 42 | ) 43 | parser.add_argument( 44 | '--lr_rate', 45 | default=0.001, 46 | type=float, 47 | help='learning rate used for training.' 48 | ) 49 | parser.add_argument( 50 | '--lamda', 51 | default=0.1, 52 | type=float, 53 | help='weight use to penalize the generated occlusion mask.' 54 | ) 55 | parser.add_argument( 56 | '--workers', 57 | default=3, 58 | type=int, 59 | help='number of workers used for data loading.' 60 | ) 61 | parser.add_argument( 62 | '--dataset', 63 | default='cityscapes', 64 | type=str, 65 | help= 66 | 'Used dataset (cityscpes | cityscapes_two_path | kth | ucf101).' 
67 | ) 68 | parser.add_argument( 69 | '--iter_to_load', 70 | default=1, 71 | type=int, 72 | help='iteration to load' 73 | ) 74 | parser.add_argument( 75 | '--mask_channel', 76 | default=20, 77 | type=int, 78 | help='channel of the input semantic lable map' 79 | ) 80 | parser.add_argument( 81 | '--category', 82 | default='walking', 83 | type=str, 84 | help='class category of the video to train (only apply to KTH and UCF101)' 85 | ) 86 | parser.add_argument( 87 | '--seed', 88 | default=31415, 89 | type=int, 90 | help='Manually set random seed' 91 | ) 92 | parser.add_argument( 93 | '--suffix', 94 | default='', 95 | type=str, 96 | help='model suffix' 97 | ) 98 | 99 | args = parser.parse_args() 100 | 101 | return args -------------------------------------------------------------------------------- /src/test_refine.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable as Vb 4 | from torch.utils.data import DataLoader 5 | 6 | import os, time, sys 7 | from argparse import ArgumentParser, Namespace 8 | from tqdm import tqdm 9 | 10 | from models.multiframe_genmask import * 11 | from dataset import get_test_set 12 | from utils import utils 13 | from opts import parse_opts 14 | 15 | args = parse_opts() 16 | print(args) 17 | 18 | 19 | def make_save_dir(output_image_dir): 20 | val_cities = ['frankfurt', 'lindau', 'munster'] 21 | for city in val_cities: 22 | pathOutputImages = os.path.join(output_image_dir, city) 23 | if not os.path.isdir(pathOutputImages): 24 | os.makedirs(pathOutputImages) 25 | 26 | 27 | class flowgen(object): 28 | 29 | def __init__(self, opt): 30 | 31 | self.opt = opt 32 | 33 | print("Random Seed: ", self.opt.seed) 34 | torch.manual_seed(self.opt.seed) 35 | torch.cuda.manual_seed_all(self.opt.seed) 36 | 37 | dataset = opt.dataset 38 | self.suffix = '_' + opt.suffix 39 | 40 | self.refine = True 41 | self.useHallucination = False 42 | self.jobname = dataset + self.suffix 43 | self.modeldir = self.jobname + 'model' 44 | 45 | # whether to start training from an existing snapshot 46 | self.load = True 47 | self.iter_to_load = opt.iter_to_load 48 | 49 | ''' Cityscapes''' 50 | 51 | test_Dataset = get_test_set(opt) 52 | 53 | # self.sampledir = os.path.join('../city_scapes_test_results', self.jobname, 54 | # self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed)) 55 | 56 | self.sampledir = os.path.join('../ucf101_results', opt.category, self.suffix + '_' + str(opt.seed)) 57 | 58 | if not os.path.exists(self.sampledir): 59 | os.makedirs(self.sampledir) 60 | 61 | self.testloader = DataLoader(test_Dataset, batch_size=opt.batch_size, shuffle=False, pin_memory=True, 62 | num_workers=8) 63 | 64 | # Create Folder for test images. 
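        # Outputs are grouped by type next to self.sampledir: refined frames, frames
        # before refinement, forward/backward flow fields, and the matching
        # forward/backward occlusion masks, each with one sub-folder per Cityscapes
        # validation city (see make_save_dir above).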
65 | self.output_image_before_dir = self.sampledir + '_images_before' 66 | self.output_image_dir = self.sampledir + '_images' 67 | self.output_bw_flow_dir = self.sampledir + '_bw_flow' 68 | self.output_fw_flow_dir = self.sampledir + '_fw_flow' 69 | 70 | self.output_bw_mask_dir = self.sampledir + '_bw_mask' 71 | self.output_fw_mask_dir = self.sampledir + '_fw_mask' 72 | 73 | make_save_dir(self.output_image_dir) 74 | make_save_dir(self.output_image_before_dir) 75 | 76 | make_save_dir(self.output_bw_flow_dir) 77 | make_save_dir(self.output_fw_flow_dir) 78 | 79 | make_save_dir(self.output_fw_mask_dir) 80 | make_save_dir(self.output_bw_mask_dir) 81 | 82 | def test(self): 83 | 84 | opt = self.opt 85 | 86 | gpu_ids = range(torch.cuda.device_count()) 87 | print('Number of GPUs in use {}'.format(gpu_ids)) 88 | 89 | iteration = 0 90 | 91 | if torch.cuda.device_count() > 1: 92 | vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine), 93 | device_ids=gpu_ids).cuda() 94 | else: 95 | vae = VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine).cuda() 96 | 97 | print(self.jobname) 98 | 99 | if self.load: 100 | # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 101 | 102 | if opt.dataset == 'cityscapes': 103 | model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar' 104 | 105 | elif opt.dataset == 'ucf101': 106 | model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category.lower() + '_model.pth.tar' 107 | 108 | else: 109 | model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category + '_16frames_model.pth.tar' 110 | 111 | print("loading model from {}".format(model_name)) 112 | 113 | state_dict = torch.load(model_name) 114 | if torch.cuda.device_count() > 1: 115 | vae.module.load_state_dict(state_dict['vae']) 116 | else: 117 | vae.load_state_dict(state_dict['vae']) 118 | 119 | z_noise = torch.ones(1, 1024).normal_() 120 | for sample, _, paths in tqdm(iter(self.testloader)): 121 | # Set to evaluation mode (randomly sample z from the whole distribution) 122 | vae.eval() 123 | 124 | # Read data 125 | data = Vb(sample) 126 | 127 | # If test on generated images 128 | # data = data.unsqueeze(1) 129 | # data = data.repeat(1, opt.num_frames, 1, 1, 1) 130 | 131 | frame1 = data[:, 0, :, :, :] 132 | 133 | noise_bg = Vb(torch.randn(frame1.size())).cuda() 134 | 135 | z_m = Vb(z_noise.repeat(frame1.size()[0] * 4 * int(frame1.shape[-1] / 128), 1)) 136 | 137 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, noise_bg, 138 | z_m) 139 | 140 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, 141 | opt, 142 | eval=True, useMask=True) 143 | 144 | utils.save_images(self.output_image_dir, data, y_pred, paths, opt) 145 | utils.save_images(self.output_image_before_dir, data, y_pred_before_refine, paths, opt) 146 | 147 | data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy() 148 | utils.save_gif(data * 255, opt.num_frames, [8, 4], self.sampledir + '/{:06d}_real.gif'.format(iteration)) 149 | 150 | # utils.save_flows(self.output_fw_flow_dir, flow, paths) 151 | # utils.save_flows(self.output_bw_flow_dir, flowback, paths) 152 | # 153 | # utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths) 154 | # utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths) 155 | 156 | iteration += 1 157 | 158 | 159 | if __name__ == '__main__': 160 | a = flowgen(opt=args) 161 | a.test() 162 | 
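# Example invocations (illustrative only; flag values and the checkpoints expected
# under ../pretrained_models/ depend on your local setup, see opts.py for all options):
#   python test_refine.py --dataset cityscapes --batch_size 8 --suffix refine_genmask
#   python test_refine.py --dataset ucf101 --category PlayingViolin --suffix ucf --seed 31415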
-------------------------------------------------------------------------------- /src/test_refine_w_mask.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable as Vb 4 | from torch.utils.data import DataLoader 5 | 6 | import os, time, sys 7 | from tqdm import tqdm 8 | 9 | from models.multiframe_w_mask_genmask import * 10 | from dataset import get_test_set 11 | from utils import utils 12 | from opts import parse_opts 13 | 14 | args = parse_opts() 15 | print (args) 16 | 17 | 18 | def make_save_dir(output_image_dir): 19 | val_cities = ['frankfurt', 'lindau', 'munster'] 20 | for city in val_cities: 21 | pathOutputImages = os.path.join(output_image_dir, city) 22 | if not os.path.isdir(pathOutputImages): 23 | os.makedirs(pathOutputImages) 24 | 25 | 26 | class flowgen(object): 27 | 28 | def __init__(self, opt): 29 | 30 | self.opt = opt 31 | 32 | print("Random Seed: ", self.opt.seed) 33 | torch.manual_seed(self.opt.seed) 34 | torch.cuda.manual_seed_all(self.opt.seed) 35 | 36 | dataset = opt.dataset 37 | self.suffix = '_' + opt.suffix 38 | 39 | self.refine = True 40 | self.useHallucination = False 41 | self.jobname = dataset + self.suffix 42 | self.modeldir = self.jobname + 'model' 43 | 44 | # whether to start training from an existing snapshot 45 | self.load = True 46 | self.iter_to_load = opt.iter_to_load 47 | 48 | ''' Cityscapes''' 49 | test_Dataset = get_test_set(opt) 50 | 51 | self.sampledir = os.path.join('../city_scapes_test_results', self.jobname, 52 | self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed)) 53 | 54 | if not os.path.exists(self.sampledir): 55 | os.makedirs(self.sampledir) 56 | 57 | self.testloader = DataLoader(test_Dataset, batch_size=opt.batch_size, shuffle=False, pin_memory=True, num_workers=8) 58 | 59 | # Create Folder for test images. 
60 | 61 | self.output_image_dir = self.sampledir + '_images' 62 | self.output_image_before_dir = self.sampledir + '_images_before' 63 | self.output_bw_flow_dir = self.sampledir + '_bw_flow' 64 | self.output_fw_flow_dir = self.sampledir + '_fw_flow' 65 | 66 | self.output_bw_mask_dir = self.sampledir + '_bw_mask' 67 | self.output_fw_mask_dir = self.sampledir + '_fw_mask' 68 | 69 | make_save_dir(self.output_image_dir) 70 | make_save_dir(self.output_image_before_dir) 71 | 72 | make_save_dir(self.output_bw_flow_dir) 73 | make_save_dir(self.output_fw_flow_dir) 74 | 75 | make_save_dir(self.output_fw_mask_dir) 76 | make_save_dir(self.output_bw_mask_dir) 77 | 78 | def test(self): 79 | 80 | opt = self.opt 81 | gpu_ids = range(torch.cuda.device_count()) 82 | print ('Number of GPUs in use {}'.format(gpu_ids)) 83 | 84 | iteration = 0 85 | 86 | if torch.cuda.device_count() > 1: 87 | vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine), device_ids=gpu_ids).cuda() 88 | else: 89 | vae = VAE(hallucination=self.useHallucination, opt=opt).cuda() 90 | 91 | print(self.jobname) 92 | 93 | if self.load: 94 | # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 95 | model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_098000.pth.tar' 96 | 97 | print ("loading model from {}".format(model_name)) 98 | 99 | state_dict = torch.load(model_name) 100 | if torch.cuda.device_count() > 1: 101 | vae.module.load_state_dict(state_dict['vae']) 102 | else: 103 | vae.load_state_dict(state_dict['vae']) 104 | 105 | z_noise = torch.ones(1, 1024).normal_() 106 | 107 | for sample, mask, paths in tqdm(iter(self.testloader)): 108 | # Set to evaluation mode (randomly sample z from the whole distribution) 109 | vae.eval() 110 | 111 | # Read data 112 | data = Vb(sample) 113 | mask = Vb(mask) 114 | 115 | # If test on generated images 116 | # data = data.unsqueeze(1) 117 | # data = data.repeat(1, opt.num_frames, 1, 1, 1) 118 | 119 | frame1 = data[:, 0, :, :, :] 120 | noise_bg = Vb(torch.randn(frame1.size())).cuda() 121 | z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1)) 122 | 123 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, mask, noise_bg, z_m) 124 | 125 | # y_pred = y_pred * mask_fw.unsqueeze(2) 126 | 127 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, opt, 128 | eval=True, useMask=True, grid=[4, 4]) 129 | 130 | utils.save_images(self.output_image_before_dir, data, y_pred_before_refine, paths, opt) 131 | utils.save_images(self.output_image_dir, data, y_pred, paths, opt) 132 | 133 | utils.save_flows(self.output_fw_flow_dir, flow, paths) 134 | utils.save_flows(self.output_bw_flow_dir, flowback, paths) 135 | 136 | utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths) 137 | utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths) 138 | 139 | data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy() 140 | utils.save_gif(data * 255, opt.num_frames, [4, 4], self.sampledir + '/{:06d}_real.gif'.format(iteration)) 141 | 142 | iteration += 1 143 | 144 | 145 | if __name__ == '__main__': 146 | a = flowgen(opt=args) 147 | a.test() -------------------------------------------------------------------------------- /src/test_refine_w_mask_two_path.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable as Vb 4 | from 
torch.utils.data import DataLoader 5 | 6 | import os, time, sys 7 | from argparse import ArgumentParser, Namespace 8 | from tqdm import tqdm 9 | 10 | from models.multiframe_w_mask_genmask_two_path import * 11 | from dataset import get_test_set 12 | from utils import utils 13 | from opts import parse_opts 14 | 15 | args = parse_opts() 16 | print (args) 17 | 18 | 19 | def make_save_dir(output_image_dir): 20 | val_cities = ['frankfurt', 'lindau', 'munster'] 21 | for city in val_cities: 22 | pathOutputImages = os.path.join(output_image_dir, city) 23 | if not os.path.isdir(pathOutputImages): 24 | os.makedirs(pathOutputImages) 25 | 26 | 27 | class flowgen(object): 28 | 29 | def __init__(self, opt): 30 | 31 | self.opt = opt 32 | 33 | print("Random Seed: ", self.opt.seed) 34 | torch.manual_seed(self.opt.seed) 35 | torch.cuda.manual_seed_all(self.opt.seed) 36 | 37 | dataset = opt.dataset 38 | self.suffix = '_' + opt.suffix 39 | 40 | self.refine = True 41 | self.useHallucination = False 42 | self.jobname = dataset + self.suffix 43 | self.modeldir = self.jobname + 'model' 44 | 45 | # whether to start training from an existing snapshot 46 | self.load = True 47 | self.iter_to_load = opt.iter_to_load 48 | 49 | ''' Cityscapes''' 50 | test_Dataset = get_test_set(opt) 51 | 52 | self.sampledir = os.path.join('../city_scapes_test_results', self.jobname, 53 | self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed)) 54 | 55 | if not os.path.exists(self.sampledir): 56 | os.makedirs(self.sampledir) 57 | 58 | self.testloader = DataLoader(test_Dataset, batch_size=opt.batch_size, shuffle=False, pin_memory=True, num_workers=8) 59 | 60 | # Create Folder for test images. 61 | self.output_image_dir = self.sampledir + '_images' 62 | 63 | self.output_image_dir_before = self.sampledir + '_images_before' 64 | self.output_bw_flow_dir = self.sampledir + '_bw_flow' 65 | self.output_fw_flow_dir = self.sampledir + '_fw_flow' 66 | 67 | self.output_bw_mask_dir = self.sampledir + '_bw_mask' 68 | self.output_fw_mask_dir = self.sampledir + '_fw_mask' 69 | 70 | make_save_dir(self.output_image_dir) 71 | make_save_dir(self.output_image_dir_before) 72 | 73 | make_save_dir(self.output_bw_flow_dir) 74 | make_save_dir(self.output_fw_flow_dir) 75 | 76 | make_save_dir(self.output_fw_mask_dir) 77 | make_save_dir(self.output_bw_mask_dir) 78 | 79 | def test(self): 80 | 81 | opt = self.opt 82 | 83 | gpu_ids = range(torch.cuda.device_count()) 84 | print ('Number of GPUs in use {}'.format(gpu_ids)) 85 | 86 | iteration = 0 87 | 88 | if torch.cuda.device_count() > 1: 89 | vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine, bg=128, fg=896), device_ids=gpu_ids).cuda() 90 | else: 91 | vae = VAE(hallucination=self.useHallucination, opt=opt).cuda() 92 | 93 | print(self.jobname) 94 | 95 | if self.load: 96 | model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar' 97 | # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 98 | 99 | print ("loading model from {}".format(model_name)) 100 | 101 | state_dict = torch.load(model_name) 102 | if torch.cuda.device_count() > 1: 103 | vae.module.load_state_dict(state_dict['vae']) 104 | else: 105 | vae.load_state_dict(state_dict['vae']) 106 | 107 | z_noise = torch.ones(1, 1024).normal_() 108 | 109 | for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)): 110 | # Set to evaluation mode (randomly sample z from the whole distribution) 111 | vae.eval() 112 | 113 | # If test on generated images 114 | 
# data = data.unsqueeze(1) 115 | # data = data.repeat(1, opt.num_frames, 1, 1, 1) 116 | 117 | frame1 = data[:, 0, :, :, :] 118 | noise_bg = torch.randn(frame1.size()) 119 | z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1)) 120 | # 121 | 122 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, bg_mask, fg_mask, noise_bg, z_m) 123 | 124 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, opt, 125 | eval=True, useMask=True, grid=[4, 4]) 126 | 127 | '''save images''' 128 | utils.save_images(self.output_image_dir, data, y_pred, paths, opt) 129 | utils.save_images(self.output_image_dir_before, data, y_pred_before_refine, paths, opt) 130 | 131 | data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy() 132 | utils.save_gif(data * 255, opt.num_frames, [4, 4], self.sampledir + '/{:06d}_real.gif'.format(iteration)) 133 | 134 | '''save flows''' 135 | utils.save_flows(self.output_fw_flow_dir, flow, paths) 136 | utils.save_flows(self.output_bw_flow_dir, flowback, paths) 137 | 138 | '''save occlusion maps''' 139 | utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths) 140 | utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths) 141 | 142 | iteration += 1 143 | 144 | 145 | if __name__ == '__main__': 146 | a = flowgen(opt=args) 147 | a.test() -------------------------------------------------------------------------------- /src/test_refine_w_mask_two_path_iterative.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.autograd import Variable as Vb 4 | from torch.utils.data import DataLoader 5 | 6 | import os, time, sys 7 | from tqdm import tqdm 8 | 9 | from models.multiframe_w_mask_genmask_two_path_iterative import * 10 | from dataset import get_test_set 11 | from utils import utils 12 | from opts import parse_opts 13 | 14 | args = parse_opts() 15 | print (args) 16 | 17 | 18 | def make_save_dir(output_image_dir): 19 | val_cities = ['frankfurt', 'lindau', 'munster'] 20 | for city in val_cities: 21 | pathOutputImages = os.path.join(output_image_dir, city) 22 | if not os.path.isdir(pathOutputImages): 23 | os.makedirs(pathOutputImages) 24 | 25 | 26 | class flowgen(object): 27 | 28 | def __init__(self, opt): 29 | 30 | self.opt = opt 31 | 32 | print("Random Seed: ", self.opt.seed) 33 | torch.manual_seed(self.opt.seed) 34 | torch.cuda.manual_seed_all(self.opt.seed) 35 | 36 | dataset = opt.dataset 37 | self.suffix = '_' + opt.suffix 38 | 39 | self.refine = True 40 | self.useHallucination = False 41 | self.jobname = dataset + self.suffix 42 | self.modeldir = self.jobname + 'model' 43 | 44 | # whether to start training from an existing snapshot 45 | self.load = True 46 | self.iter_to_load = opt.iter_to_load 47 | 48 | ''' Cityscapes''' 49 | from cityscapes_dataloader_w_mask_two_path import Cityscapes 50 | 51 | test_Dataset = get_test_set(opt) 52 | 53 | self.sampledir = os.path.join('../city_scapes_test_results', self.jobname, 54 | self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed)+'_iterative') 55 | 56 | if not os.path.exists(self.sampledir): 57 | os.makedirs(self.sampledir) 58 | 59 | self.testloader = DataLoader(test_Dataset, batch_size=opt.batch_size, shuffle=False, pin_memory=True, num_workers=8) 60 | 61 | # Create Folder for test images. 
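        # The '_iterative' suffix appended to self.sampledir above keeps the outputs of
        # this long-sequence (re-fed) generation separate from the single-pass results.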
62 | self.output_image_dir = self.sampledir + '_images' 63 | 64 | self.output_image_dir_before = self.sampledir + '_images_before' 65 | self.output_bw_flow_dir = self.sampledir + '_bw_flow' 66 | self.output_fw_flow_dir = self.sampledir + '_fw_flow' 67 | 68 | self.output_bw_mask_dir = self.sampledir + '_bw_mask' 69 | self.output_fw_mask_dir = self.sampledir + '_fw_mask' 70 | 71 | make_save_dir(self.output_image_dir) 72 | make_save_dir(self.output_image_dir_before) 73 | 74 | make_save_dir(self.output_bw_flow_dir) 75 | make_save_dir(self.output_fw_flow_dir) 76 | 77 | make_save_dir(self.output_fw_mask_dir) 78 | make_save_dir(self.output_bw_mask_dir) 79 | 80 | def test(self): 81 | 82 | opt = self.opt 83 | 84 | gpu_ids = range(torch.cuda.device_count()) 85 | print ('Number of GPUs in use {}'.format(gpu_ids)) 86 | 87 | iteration = 0 88 | 89 | if torch.cuda.device_count() > 1: 90 | vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine, bg=128, fg=896), device_ids=gpu_ids).cuda() 91 | else: 92 | vae = VAE(hallucination=self.useHallucination, opt=opt).cuda() 93 | 94 | print(self.jobname) 95 | 96 | if self.load: 97 | # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 98 | model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar' 99 | 100 | print ("loading model from {}".format(model_name)) 101 | 102 | state_dict = torch.load(model_name) 103 | if torch.cuda.device_count() > 1: 104 | vae.module.load_state_dict(state_dict['vae']) 105 | else: 106 | vae.load_state_dict(state_dict['vae']) 107 | 108 | z_noise = torch.ones(1, 1024).normal_() 109 | 110 | for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)): 111 | # Set to evaluation mode (randomly sample z from the whole distribution) 112 | vae.eval() 113 | 114 | # If test on generated images 115 | # data = data.unsqueeze(1) 116 | # data = data.repeat(1, opt.num_frames, 1, 1, 1) 117 | 118 | frame1 = data[:, 0, :, :, :] 119 | noise_bg = torch.randn(frame1.size()) 120 | z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1)) 121 | 122 | y_pred_before_refine, y_pred, flow, flowback, mask_fw, mask_bw, warped_mask_bg, warped_mask_fg = vae(frame1, data, bg_mask, fg_mask, noise_bg, z_m) 123 | 124 | '''iterative generation''' 125 | 126 | for i in range(5): 127 | noise_bg = torch.randn(frame1.size()) 128 | 129 | y_pred_before_refine_1, y_pred_1, flow_1, flowback_1, mask_fw_1, mask_bw_1, warped_mask_bg, warped_mask_fg = vae(y_pred[:,-1,...], y_pred, warped_mask_bg, warped_mask_fg, noise_bg, z_m) 130 | 131 | y_pred_before_refine = torch.cat([y_pred_before_refine, y_pred_before_refine_1], 1) 132 | y_pred = torch.cat([y_pred, y_pred_1], 1) 133 | flow = torch.cat([flow, flow_1], 2) 134 | flowback = torch.cat([flowback, flowback_1], 2) 135 | mask_fw = torch.cat([mask_fw, mask_fw_1], 1) 136 | mask_bw = torch.cat([mask_bw, mask_bw_1], 1) 137 | 138 | print(y_pred_before_refine.size()) 139 | 140 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, opt, 141 | eval=True, useMask=True, grid=[4, 4]) 142 | 143 | # '''save images''' 144 | utils.save_images(self.output_image_dir, data, y_pred, paths, opt) 145 | utils.save_images(self.output_image_dir_before, data, y_pred_before_refine, paths, opt) 146 | 147 | iteration += 1 148 | 149 | 150 | if __name__ == '__main__': 151 | a = flowgen(opt=args) 152 | a.test() -------------------------------------------------------------------------------- 
/src/train_refine_multigpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import os, time, sys 6 | 7 | from models.multiframe_genmask import * 8 | from utils import utils 9 | from utils import ops 10 | import losses 11 | from dataset import get_training_set, get_test_set 12 | from opts import parse_opts 13 | 14 | opt = parse_opts() 15 | print (opt) 16 | 17 | 18 | class flowgen(object): 19 | 20 | def __init__(self, opt): 21 | self.opt = opt 22 | dataset = 'cityscapes_seq_full' 23 | self.workspace = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 24 | 25 | self.jobname = dataset + '_gpu8_refine_genmask_linklink_256_1node' 26 | self.modeldir = self.jobname + 'model' 27 | self.sampledir = os.path.join(self.workspace, self.jobname) 28 | self.parameterdir = self.sampledir + '/params' 29 | self.useHallucination = False 30 | 31 | if not os.path.exists(self.parameterdir): 32 | os.makedirs(self.parameterdir) 33 | 34 | # whether to start training from an existing snapshot 35 | self.load = False 36 | self.iter_to_load = 62000 37 | 38 | # Write parameters setting file 39 | if os.path.exists(self.parameterdir): 40 | utils.save_parameters(self) 41 | 42 | ''' Cityscapes''' 43 | train_Dataset = get_training_set(opt) 44 | test_Dataset = get_test_set(opt) 45 | 46 | self.trainloader = DataLoader(train_Dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers, 47 | pin_memory=True, drop_last=True) 48 | self.testloader = DataLoader(test_Dataset, batch_size=2, shuffle=False, num_workers=opt.workers, 49 | pin_memory=True, drop_last=True) 50 | 51 | def train(self): 52 | 53 | opt = self.opt 54 | gpu_ids = range(torch.cuda.device_count()) 55 | print ('Number of GPUs in use {}'.format(gpu_ids)) 56 | 57 | iteration = 0 58 | 59 | vae = VAE(hallucination=self.useHallucination, opt=opt).cuda() 60 | if torch.cuda.device_count() > 1: 61 | vae = nn.DataParallel(vae).cuda() 62 | 63 | objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww) 64 | 65 | print(self.jobname) 66 | 67 | optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate) 68 | 69 | if self.load: 70 | 71 | model_name = self.sampledir + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 72 | print ("loading model from {}".format(model_name)) 73 | 74 | state_dict = torch.load(model_name) 75 | if torch.cuda.device_count() > 1: 76 | vae.module.load_state_dict(state_dict['vae']) 77 | optimizer.load_state_dict(state_dict['optimizer']) 78 | else: 79 | vae.load_state_dict(state_dict['vae']) 80 | optimizer.load_state_dict(state_dict['optimizer']) 81 | iteration = self.iter_to_load + 1 82 | 83 | for epoch in range(opt.num_epochs): 84 | 85 | print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1)) 86 | print('-' * 10) 87 | 88 | for sample, _ in iter(self.trainloader): 89 | 90 | # get the inputs 91 | data = sample.cuda() 92 | frame1 = data[:, 0, :, :, :] 93 | frame2 = data[:, 1:, :, :, :] 94 | noise_bg = torch.randn(frame1.size()).cuda() 95 | 96 | start = time.time() 97 | 98 | # Set train mode 99 | vae.train() 100 | 101 | # zero the parameter gradients 102 | optimizer.zero_grad() 103 | 104 | # forward + backward + optimize 105 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature = vae( 106 | frame1, data, noise_bg) 107 | 108 | # Compute losses 109 | flowloss, reconloss, reconloss_back, 
reconloss_before, kldloss, flowcon, sim_loss, vgg_loss, mask_loss = objective_func( 110 | frame1, frame2, 111 | y_pred, mu, logvar, flow, flowback, 112 | mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature, 113 | y_pred_before_refine=y_pred_before_refine) 114 | 115 | loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) 116 | 117 | # backward 118 | loss.backward() 119 | 120 | # Update 121 | optimizer.step() 122 | end = time.time() 123 | 124 | # print statistics 125 | if iteration % 20 == 0: 126 | print( 127 | "iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, " 128 | "recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, " 129 | "img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}" 130 | .format(iteration, epoch, reconloss.item(), reconloss_back.item(), reconloss_before.item(), 131 | flowloss.item(), flowcon.item(), 132 | kldloss.item(), sim_loss.item(), vgg_loss.item(), mask_loss.item(), end - start)) 133 | 134 | if iteration % 500 == 0: 135 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, 136 | self.sampledir, opt) 137 | 138 | if iteration % 2000 == 0: 139 | # Set to evaluation mode (randomly sample z from the whole distribution) 140 | with torch.no_grad(): 141 | vae.eval() 142 | val_sample, _, _ = iter(self.testloader).next() 143 | 144 | # Read data 145 | data = val_sample.cuda() 146 | frame1 = data[:, 0, :, :, :] 147 | 148 | noise_bg = torch.randn(frame1.size()).cuda() 149 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, 150 | noise_bg) 151 | 152 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, 153 | self.sampledir, opt, 154 | eval=True, useMask=True) 155 | 156 | # Save model's parameter 157 | checkpoint_path = self.sampledir + '/{:06d}_model.pth.tar'.format(iteration) 158 | print("model saved to {}".format(checkpoint_path)) 159 | 160 | if torch.cuda.device_count() > 1: 161 | torch.save({'vae': vae.state_dict(), 'optimizer': optimizer.state_dict()}, 162 | checkpoint_path) 163 | else: 164 | torch.save({'vae': vae.module.state_dict(), 'optimizer': optimizer.state_dict()}, 165 | checkpoint_path) 166 | 167 | iteration += 1 168 | 169 | 170 | if __name__ == '__main__': 171 | '''Dist Init!!''' 172 | a = flowgen(opt) 173 | a.train() 174 | -------------------------------------------------------------------------------- /src/train_refine_multigpu_w_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import os, time, sys 6 | 7 | from models.multiframe_w_mask_genmask import * 8 | from utils import utils 9 | from utils import ops 10 | import losses 11 | from dataset import get_training_set, get_test_set 12 | from opts import parse_opts 13 | 14 | opt = parse_opts() 15 | print (opt) 16 | 17 | 18 | class flowgen(object): 19 | 20 | def __init__(self, opt): 21 | 22 | self.opt = opt 23 | dataset = 'cityscapes_seq_full' 24 | self.workspace = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 25 | 26 | self.jobname = dataset + '_gpu8_refine_genmask_linklink_256_1node' 27 | self.modeldir = self.jobname + 'model' 28 | self.sampledir = os.path.join(self.workspace, self.jobname) 29 | self.parameterdir = 
self.sampledir + '/params' 30 | self.useHallucination = False 31 | 32 | if not os.path.exists(self.parameterdir): 33 | os.makedirs(self.parameterdir) 34 | 35 | # whether to start training from an existing snapshot 36 | self.load = False 37 | self.iter_to_load = 62000 38 | 39 | # Write parameters setting file 40 | if os.path.exists(self.parameterdir): 41 | utils.save_parameters(self) 42 | 43 | ''' Cityscapes''' 44 | train_Dataset = get_training_set(opt) 45 | test_Dataset = get_test_set(opt) 46 | 47 | self.trainloader = DataLoader(train_Dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers, 48 | pin_memory=True, drop_last=True) 49 | self.testloader = DataLoader(test_Dataset, batch_size=2, shuffle=False, num_workers=opt.workers, 50 | pin_memory=True, drop_last=True) 51 | 52 | def train(self): 53 | 54 | opt = self.opt 55 | gpu_ids = range(torch.cuda.device_count()) 56 | print ('Number of GPUs in use {}'.format(gpu_ids)) 57 | 58 | iteration = 0 59 | 60 | vae = VAE(hallucination=self.useHallucination, opt=opt).cuda() 61 | if torch.cuda.device_count() > 1: 62 | vae = nn.DataParallel(vae).cuda() 63 | 64 | objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww) 65 | 66 | print(self.jobname) 67 | 68 | optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate) 69 | 70 | if self.load: 71 | 72 | model_name = self.sampledir + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 73 | print ("loading model from {}".format(model_name)) 74 | 75 | state_dict = torch.load(model_name) 76 | if torch.cuda.device_count() > 1: 77 | vae.module.load_state_dict(state_dict['vae']) 78 | optimizer.load_state_dict(state_dict['optimizer']) 79 | else: 80 | vae.load_state_dict(state_dict['vae']) 81 | optimizer.load_state_dict(state_dict['optimizer']) 82 | iteration = self.iter_to_load + 1 83 | 84 | for epoch in range(opt.num_epochs): 85 | 86 | print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1)) 87 | print('-' * 10) 88 | 89 | for sample, mask in iter(self.trainloader): 90 | 91 | # get the inputs 92 | data = sample.cuda() 93 | mask = mask.cuda() 94 | 95 | frame1 = data[:, 0, :, :, :] 96 | frame2 = data[:, 1:, :, :, :] 97 | noise_bg = torch.randn(frame1.size()).cuda() 98 | 99 | # torch.cuda.synchronize() 100 | start = time.time() 101 | 102 | # Set train mode 103 | vae.train() 104 | 105 | # zero the parameter gradients 106 | optimizer.zero_grad() 107 | 108 | # forward + backward + optimize 109 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature = vae(frame1, data, mask, noise_bg) 110 | 111 | # Compute losses 112 | flowloss, reconloss, reconloss_back, reconloss_before, kldloss, flowcon, sim_loss, vgg_loss, mask_loss = objective_func(frame1, frame2, 113 | y_pred, mu, logvar, flow, flowback, 114 | mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature, 115 | y_pred_before_refine=y_pred_before_refine) 116 | 117 | loss = (flowloss + 2.*reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1*mask_loss) 118 | 119 | # backward 120 | loss.backward() 121 | 122 | # Update 123 | optimizer.step() 124 | 125 | end = time.time() 126 | 127 | # print statistics 128 | if iteration % 20 == 0: 129 | print("iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, recon_loss_before = {:.3f}, " 130 | "flow_loss = {:.6f}, flow_consist = {:.3f}, " 131 | "kl_loss = {:.6f}, img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}" 132 | .format(iteration, epoch, 
reconloss.item(), reconloss_back.item(), reconloss_before.item(), flowloss.item(), flowcon.item(), 133 | kldloss.item(), sim_loss.item(), vgg_loss.item(), mask_loss.item(), end - start)) 134 | 135 | if iteration % 2000 == 0: 136 | # Set to evaluation mode (randomly sample z from the whole distribution) 137 | with torch.no_grad(): 138 | vae.eval() 139 | 140 | val_sample, val_mask, _ = iter(self.testloader).next() 141 | 142 | # Read data 143 | data = val_sample.cuda() 144 | mask = val_mask.cuda() 145 | frame1 = data[:, 0, :, :, :] 146 | 147 | noise_bg = torch.randn(frame1.size()).cuda() 148 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, 149 | mask, noise_bg) 150 | 151 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, 152 | self.sampledir, opt, 153 | eval=True, useMask=True) 154 | 155 | # Save model's parameter 156 | checkpoint_path = self.sampledir + '/{:06d}_model.pth.tar'.format(iteration) 157 | print("model saved to {}".format(checkpoint_path)) 158 | 159 | if torch.cuda.device_count() > 1: 160 | torch.save({'vae': vae.state_dict(), 'optimizer': optimizer.state_dict()}, 161 | checkpoint_path) 162 | else: 163 | torch.save({'vae': vae.module.state_dict(), 'optimizer': optimizer.state_dict()}, 164 | checkpoint_path) 165 | 166 | iteration += 1 167 | 168 | 169 | if __name__ == '__main__': 170 | a = flowgen(opt) 171 | a.train() 172 | -------------------------------------------------------------------------------- /src/train_refine_multigpu_w_mask_two_path.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable as Vb 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import os, time, sys 6 | 7 | from models.multiframe_w_mask_genmask_two_path import * 8 | from utils import utils 9 | from utils import ops 10 | import losses 11 | from dataset import get_training_set, get_test_set 12 | from opts import parse_opts 13 | 14 | opt = parse_opts() 15 | print (opt) 16 | 17 | 18 | class flowgen(object): 19 | 20 | def __init__(self, opt): 21 | 22 | self.opt = opt 23 | dataset = 'cityscapes_seq_full' 24 | self.workspace = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 25 | 26 | self.jobname = dataset + '_gpu8_refine_genmask_linklink_256_1node' 27 | self.modeldir = self.jobname + 'model' 28 | self.sampledir = os.path.join(self.workspace, self.jobname) 29 | self.parameterdir = self.sampledir + '/params' 30 | self.useHallucination = False 31 | 32 | if not os.path.exists(self.parameterdir): 33 | os.makedirs(self.parameterdir) 34 | 35 | # whether to start training from an existing snapshot 36 | self.load = False 37 | self.iter_to_load = 62000 38 | 39 | # Write parameters setting file 40 | if os.path.exists(self.parameterdir): 41 | utils.save_parameters(self) 42 | 43 | ''' Cityscapes''' 44 | train_Dataset = get_training_set(opt) 45 | test_Dataset = get_test_set(opt) 46 | 47 | self.trainloader = DataLoader(train_Dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers, 48 | pin_memory=True, drop_last=True) 49 | self.testloader = DataLoader(test_Dataset, batch_size=2, shuffle=False, num_workers=opt.workers, 50 | pin_memory=True, drop_last=True) 51 | 52 | def train(self): 53 | 54 | opt = self.opt 55 | gpu_ids = range(torch.cuda.device_count()) 56 | print ('Number of GPUs in use {}'.format(gpu_ids)) 57 | 58 | iteration = 0 59 | 60 | vae = VAE(hallucination=self.useHallucination, 
opt=opt).cuda() 61 | if torch.cuda.device_count() > 1: 62 | vae = nn.DataParallel(vae).cuda() 63 | 64 | objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww) 65 | 66 | print(self.jobname) 67 | 68 | optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate) 69 | 70 | if self.load: 71 | 72 | model_name = self.sampledir + '/{:06d}_model.pth.tar'.format(self.iter_to_load) 73 | print ("loading model from {}".format(model_name)) 74 | 75 | state_dict = torch.load(model_name) 76 | if torch.cuda.device_count() > 1: 77 | vae.module.load_state_dict(state_dict['vae']) 78 | optimizer.load_state_dict(state_dict['optimizer']) 79 | else: 80 | vae.load_state_dict(state_dict['vae']) 81 | optimizer.load_state_dict(state_dict['optimizer']) 82 | iteration = self.iter_to_load + 1 83 | 84 | for epoch in range(opt.num_epochs): 85 | 86 | print('Epoch {}/{}'.format(epoch, opt.num_epochs - 1)) 87 | print('-' * 10) 88 | 89 | for sample, bg_mask, fg_mask in iter(self.trainloader): 90 | 91 | # get the inputs 92 | data = sample.cuda() 93 | # mask = mask.cuda() 94 | bg_mask = bg_mask.cuda() 95 | fg_mask = fg_mask.cuda() 96 | # print('loaded data') 97 | 98 | frame1 = data[:, 0, :, :, :] 99 | frame2 = data[:, 1:, :, :, :] 100 | noise_bg = torch.randn(frame1.size()).cuda() 101 | 102 | start = time.time() 103 | 104 | # Set train mode 105 | vae.train() 106 | 107 | # zero the parameter gradients 108 | optimizer.zero_grad() 109 | 110 | # forward + backward + optimize 111 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature = vae( 112 | frame1, data, bg_mask, fg_mask, noise_bg) 113 | 114 | # Compute losses 115 | flowloss, reconloss, reconloss_back, reconloss_before, kldloss, flowcon, sim_loss, vgg_loss, mask_loss = objective_func( 116 | frame1, frame2, 117 | y_pred, mu, logvar, flow, flowback, 118 | mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature, 119 | y_pred_before_refine=y_pred_before_refine) 120 | 121 | loss = (flowloss + 2. 
* reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) 122 | 123 | # backward 124 | loss.backward() 125 | 126 | # Update 127 | optimizer.step() 128 | end = time.time() 129 | 130 | # print statistics 131 | if iteration % 20 == 0: 132 | print( 133 | "iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, " 134 | "recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, " 135 | "img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}" 136 | .format(iteration, epoch, reconloss.item(), reconloss_back.item(), reconloss_before.item(), 137 | flowloss.item(), flowcon.item(), 138 | kldloss.item(), sim_loss.item(), vgg_loss.item(), mask_loss.item(), end - start)) 139 | 140 | if iteration % 500 == 0: 141 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, 142 | self.sampledir, opt) 143 | 144 | if iteration % 2000 == 0: 145 | # Set to evaluation mode (randomly sample z from the whole distribution) 146 | with torch.no_grad(): 147 | vae.eval() 148 | val_sample, val_bg_mask, val_fg_mask, _ = iter(self.testloader).next() 149 | 150 | # Read data 151 | data = val_sample.cuda() 152 | bg_mask = val_bg_mask.cuda() 153 | fg_mask = val_fg_mask.cuda() 154 | frame1 = data[:, 0, :, :, :] 155 | 156 | noise_bg = torch.randn(frame1.size()).cuda() 157 | y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, 158 | bg_mask, 159 | fg_mask, 160 | noise_bg) 161 | 162 | utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, 163 | self.sampledir, opt, 164 | eval=True, useMask=True) 165 | 166 | # Save model's parameter 167 | checkpoint_path = self.sampledir + '/{:06d}_model.pth.tar'.format(iteration) 168 | print("model saved to {}".format(checkpoint_path)) 169 | 170 | if torch.cuda.device_count() > 1: 171 | torch.save({'vae': vae.state_dict(), 'optimizer': optimizer.state_dict()}, 172 | checkpoint_path) 173 | else: 174 | torch.save({'vae': vae.module.state_dict(), 'optimizer': optimizer.state_dict()}, 175 | checkpoint_path) 176 | 177 | iteration += 1 178 | 179 | 180 | if __name__ == '__main__': 181 | 182 | a = flowgen(opt) 183 | a.train() 184 | 185 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .ops import * 3 | 4 | -------------------------------------------------------------------------------- /src/utils/check_file_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import numpy as np 5 | import glob 6 | from tqdm import tqdm 7 | from multiprocessing.dummy import Pool as ThreadPool 8 | 9 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/train_extra/*' 10 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/demoVideo/' 11 | image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit_sequence/train/' 12 | listfile = open("cityscapes_train_sequence_full_8.txt", 'r') 13 | num_frame_to_predict = 8 14 | 15 | file_names = [file_name.strip() for file_name in listfile.readlines()] 16 | 17 | for file_name in tqdm(file_names): 18 | image_dir = image_root_dir + file_name 19 | 20 | for i in 
range(num_frame_to_predict): 21 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + i).zfill(6) + image_dir[-16::] 22 | if not os.path.isfile(new_dir): 23 | print new_dir 24 | print file_dir -------------------------------------------------------------------------------- /src/utils/cityscapes_gen_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | import numpy as np 5 | import glob 6 | from tqdm import tqdm 7 | from multiprocessing.dummy import Pool as ThreadPool 8 | 9 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/train_extra/*' 10 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/demoVideo/' 11 | 12 | image_root_dir = '/mnt/lustrenew/DATAshare/leftImg8bit_sequence/val/' 13 | listfile = open("cityscapes_val_sequence_full_18.txt", 'a') 14 | 15 | # image_root_dir = '/mnt/lustrenew/DATAshare/gtFine/val/' 16 | # listfile = open("cityscapes_val_sequence_w_mask_8.txt", 'a') 17 | print (image_root_dir) 18 | # max = [6299, 599, 4599] 19 | # i = 0 20 | num_frame_to_predict = 18 21 | 22 | def gen_list_per_city(sub_dir): 23 | # image_list = glob.glob(sub_dir + "/*_gtFine_labelIds.png") 24 | image_list = glob.glob(sub_dir + "/*.png") 25 | for image_dir in tqdm(image_list): 26 | flag = True 27 | for j in range(1, num_frame_to_predict): 28 | new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + j).zfill(6) + image_dir[-16::] 29 | if not os.path.isfile(new_dir): 30 | flag = False 31 | if flag: 32 | # Replace mask suffix for image suffix 33 | # img_dir = image_dir.split(image_root_dir)[-1].split('_gtFine_labelIds.png')[0] + '_leftImg8bit.png' 34 | listfile.write(image_dir.split(image_root_dir)[-1] + "\n") 35 | # i += 1 36 | 37 | # new_dir = image_dir[0:-22] + str(int(image_dir[-22:-16]) + num_frame_to_predict).zfill(6) + image_dir[-16::] 38 | # if os.path.isfile(new_dir): 39 | # listfile.write(image_dir.split(image_root_dir)[-1]+"\n") 40 | # i += 1 41 | # print i 42 | 43 | #25 26 27 28 29 44 | 45 | cities = [sub_dir for sub_dir in glob.glob(image_root_dir + '*')] 46 | print (cities) 47 | # for city in cities: 48 | # gen_list_per_city(city) 49 | # # make the Pool of workers 50 | pool = ThreadPool(len(cities)) 51 | 52 | # open the urls in their own threads 53 | # and return the results 54 | results = pool.map(gen_list_per_city, cities) 55 | listfile.close() 56 | # close the pool and wait for the work to finish 57 | pool.close() 58 | pool.join() -------------------------------------------------------------------------------- /src/utils/cityscapes_gen_pix2pixImage_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from tqdm import tqdm 4 | 5 | image_root_dir = '/mnt/lustre/share/charlie/synth/images/' 6 | image_list_file = open('cityscapes_test_pix2pixImage_list.txt', 'w') 7 | 8 | 9 | for file_name in tqdm(glob.glob(image_root_dir+'*_synthesized_image.jpg')): 10 | image_list_file.write(file_name.split(image_root_dir)[-1]+'\n') 11 | image_list_file.close() 12 | -------------------------------------------------------------------------------- /src/utils/cityscapes_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | from tqdm import tqdm 5 | 6 | from multiprocessing.dummy import Pool as ThreadPool 7 | 8 | # image_root_dir = 
'/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/train_extra/*' 9 | # resized_image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/train_extra_256x512/' 10 | 11 | image_root_dir = '/mnt/lustrenew/DATAshare/leftImg8bit_sequence/val/*' 12 | resized_image_root_dir = '/mnt/lustrenew/DATAshare/leftImg8bit_sequence/val_256x128' 13 | 14 | img_size = (256, 128) 15 | 16 | 17 | # for sub_dir in glob.glob(image_root_dir): 18 | # for image_dir in tqdm(glob.glob(sub_dir + "/*.png")): 19 | # imageResized = cv2.resize(cv2.imread(image_dir, cv2.IMREAD_COLOR), img_size, interpolation=cv2.INTER_AREA) 20 | # filename = image_dir.split('/')[-1] 21 | # city_name = sub_dir.split('/')[-1] 22 | # 23 | # pathOutputImages = os.path.join(resized_image_root_dir, city_name) 24 | # 25 | # if not os.path.isdir(pathOutputImages): 26 | # os.makedirs(pathOutputImages) 27 | # cv2.imwrite(os.path.join(pathOutputImages, filename), imageResized) 28 | 29 | 30 | def resize_and_save(sub_dir): 31 | print sub_dir 32 | for image_dir in tqdm(glob.glob(sub_dir + "/*.png")): 33 | imageResized = cv2.resize(cv2.imread(image_dir, cv2.IMREAD_COLOR), img_size, interpolation=cv2.INTER_AREA) 34 | filename = image_dir.split('/')[-1] 35 | city_name = sub_dir.split('/')[-1] 36 | 37 | pathOutputImages = os.path.join(resized_image_root_dir, city_name) 38 | 39 | if not os.path.isdir(pathOutputImages): 40 | os.makedirs(pathOutputImages) 41 | cv2.imwrite(os.path.join(pathOutputImages, filename), imageResized) 42 | 43 | 44 | cities = [sub_dir for sub_dir in glob.glob(image_root_dir)] 45 | 46 | # make the Pool of workers 47 | pool = ThreadPool(len(cities)) 48 | 49 | # open the urls in their own threads 50 | # and return the results 51 | results = pool.map(resize_and_save, cities) 52 | 53 | # close the pool and wait for the work to finish 54 | pool.close() 55 | pool.join() -------------------------------------------------------------------------------- /src/utils/get_ucf101_list.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | 5 | 6 | 7 | 8 | def listinit(path, testfilename='testbaskeball.txt', trainfilename='trainbaskeball.txt'): 9 | 10 | output_dir = '/mnt/lustre/panjunting/f2video2.0/UCF-101/list' 11 | 12 | fop1 = open(os.path.join(output_dir, testfilename), 'w') 13 | fop2 = open(os.path.join(output_dir, trainfilename), 'w') 14 | 15 | for item in os.listdir(path): 16 | if item.endswith('.npy'): 17 | k = random.randint(0, 9) 18 | print (item) 19 | try: 20 | a = np.load(os.path.join(path, item)) 21 | num = a.shape[0] 22 | 23 | if k == 9: 24 | for i in range(num - 4): 25 | fop1.write(item + ' ' + str(i) + '\n') 26 | else: 27 | for i in range(num - 4): 28 | fop2.write(item + ' ' + str(i) + '\n') 29 | except: 30 | print ("invalid npy") 31 | print (item) 32 | 33 | # class toframe(object): 34 | # def __init__(self, path): 35 | # self.path = path 36 | # self.datalist = os.listdir(path) 37 | # 38 | # def run1(self): 39 | # for item in self.datalist: 40 | # if mode == 'dir': 41 | # dirpath = os.path.join(self.path, item.strip('.avi')) 42 | # if tf.gfile.Exists(dirpath): 43 | # break 44 | # else: 45 | # tf.gfile.MkDir(dirpath) 46 | # elif mode == 'npy': 47 | # savef = [] 48 | # if item.endswith('.avi'): # and item.replace('.avi','.npy') not in self.datalist: 49 | # avipath = os.path.join(self.path, item) 50 | # cap = cv2.VideoCapture(avipath) 51 | # print avipath 52 | # ret = 1 53 | # while (ret): 54 | # ret, frame = 
cap.read() 55 | # if (ret): 56 | # try: 57 | # frame = cv2.resize(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)), (64, 64), 58 | # interpolation=cv2.INTER_AREA) 59 | # savef.append(frame) 60 | # except: 61 | # print 'noframe' 62 | # savef = np.array(savef) 63 | # np.save(avipath.replace('.avi', '_64.npy'), savef) 64 | 65 | 66 | from tqdm import tqdm 67 | 68 | def get_ucf_list(num_frames, in_filename, category): 69 | 70 | datapath = '/mnt/lustre/panjunting/f2video2.0/UCF-101' 71 | datalist = open(os.path.join(datapath, 'list', in_filename)).readlines() 72 | 73 | out_filename = in_filename[0:-4] + '_%dframes.txt'%num_frames 74 | 75 | f = open(os.path.join(datapath, out_filename), 'w') 76 | count = 0 77 | 78 | for idx in tqdm(range(len(datalist))): 79 | item = np.load(os.path.join(datapath, category, datalist[idx].split(' ')[0]).strip()) 80 | num = item.shape[0] 81 | start = int(datalist[idx].split(' ')[1]) 82 | if start + num_frames < num: 83 | f.write(datalist[idx]) 84 | count += 1 85 | print len(datalist) 86 | print count 87 | f.close() 88 | 89 | 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | # listinit(path='/mnt/lustre/panjunting/f2video2.0/UCF-101/Skiing', 95 | # testfilename='testskiing.txt', 96 | # trainfilename='trainskiing.txt') 97 | # get_ucf_list(18, 'trainskiing.txt', 'trainskiing_18frames.txt', 'Skiing') 98 | get_ucf_list(18, 'trainplayingviolin.txt', 'PlayingViolin') 99 | get_ucf_list(18, 'testplayingviolin.txt', 'PlayingViolin') 100 | -------------------------------------------------------------------------------- /src/utils/kth.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import numpy as np 4 | import socket 5 | import torch 6 | from scipy import misc 7 | from torch.utils.serialization import load_lua 8 | 9 | class KTH(object): 10 | 11 | def __init__(self, train, data_root, seq_len = 20, image_size=64): 12 | self.data_root = '%s/KTH/processed/' % data_root 13 | self.seq_len = seq_len 14 | self.image_size = image_size 15 | self.classes = ['boxing', 'handclapping', 'handwaving', 'jogging', 'running', 'walking'] 16 | 17 | self.dirs = os.listdir(self.data_root) 18 | if train: 19 | self.train = True 20 | data_type = 'train' 21 | self.persons = list(range(1, 21)) 22 | else: 23 | self.train = False 24 | self.persons = list(range(21, 26)) 25 | data_type = 'test' 26 | 27 | self.data= {} 28 | for c in self.classes: 29 | self.data[c] = load_lua('%s/%s/%s_meta%dx%d.t7' % (self.data_root, c, data_type, image_size, image_size)) 30 | 31 | 32 | self.seed_set = False 33 | 34 | def get_sequence(self): 35 | t = self.seq_len 36 | while True: # skip seqeunces that are too short 37 | c_idx = np.random.randint(len(self.classes)) 38 | c = self.classes[c_idx] 39 | vid_idx = np.random.randint(len(self.data[c])) 40 | vid = self.data[c][vid_idx] 41 | seq_idx = np.random.randint(len(vid['files'])) 42 | if len(vid['files'][seq_idx]) - t >= 0: 43 | break 44 | dname = '%s/%s/%s' % (self.data_root, c, vid['vid']) 45 | st = random.randint(0, len(vid['files'][seq_idx])-t) 46 | 47 | 48 | seq = [] 49 | for i in range(st, st+t): 50 | fname = '%s/%s' % (dname, vid['files'][seq_idx][i]) 51 | im = misc.imread(fname)/255. 
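            # Frames are grayscale; keep one channel and restore an explicit trailing
            # channel axis so each element has shape (image_size, image_size, 1).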
52 | seq.append(im[:, :, 0].reshape(self.image_size, self.image_size, 1)) 53 | return np.array(seq) 54 | 55 | def __getitem__(self, index): 56 | if not self.seed_set: 57 | self.seed_set = True 58 | random.seed(index) 59 | np.random.seed(index) 60 | #torch.manual_seed(index) 61 | return torch.from_numpy(self.get_sequence()) 62 | 63 | def __len__(self): 64 | return len(self.dirs)*36*5 # arbitrary 65 | 66 | -------------------------------------------------------------------------------- /src/utils/kth_genlist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | from tqdm import tqdm 5 | from multiprocessing.dummy import Pool as ThreadPool 6 | 7 | 8 | def gen_list_per_city(image_root_dir, category, split, datalist, num_frame_to_predict): 9 | listfile = open("kth_" + split + "_" + category + "_%d_ok.txt" % num_frame_to_predict, 'w') 10 | for image_dir in tqdm(datalist): 11 | flag = True 12 | for j in range(1, num_frame_to_predict): 13 | new_dir = image_dir[0:-15] + str(int(image_dir[-15:-12]) + j).zfill(3) + image_dir[-12::] 14 | new_dir = image_root_dir + new_dir 15 | if not os.path.isfile(new_dir): 16 | flag = False 17 | if flag: 18 | # Replace mask suffix for image suffix 19 | # img_dir = image_dir.split(image_root_dir)[-1].split('_gtFine_labelIds.png')[0] + '_leftImg8bit.png' 20 | listfile.write(image_dir + "\n") 21 | listfile.close() 22 | 23 | 24 | 25 | def get_list(category, num_frame_to_predict, split): 26 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/train_extra/*' 27 | # image_root_dir = '/mnt/lustre/panjunting/video_generation/cityscapes/leftImg8bit/demoVideo/' 28 | 29 | # image_root_dir = '/mnt/lustrenew/DATAshare/leftImg8bit_sequence/val/' 30 | # listfile = open("cityscapes_val_sequence_full_18.txt", 'a') 31 | 32 | image_root_dir = '/mnt/lustrenew/panjunting/kth/KTH/processed/' 33 | # listfile = open("kth_"+split+"_"+category+"_%d_ok.txt" %num_frame_to_predict, 'w') 34 | 35 | 36 | # image_root_dir = '/mnt/lustrenew/DATAshare/gtFine/val/' 37 | # listfile = open("cityscapes_val_sequence_w_mask_8.txt", 'a') 38 | print (image_root_dir) 39 | # max = [6299, 599, 4599] 40 | # i = 0 41 | # cities = [sub_dir for sub_dir in glob.glob(image_root_dir + '*')] 42 | # print (cities) 43 | # for city in cities: 44 | # gen_list_per_city(city) 45 | # # make the Pool of workers 46 | # pool = ThreadPool(len(cities)) 47 | 48 | 49 | datalist= open('kth_'+split+'_'+category+'_16_ok.txt', 'r') 50 | datalist = [l.strip() for l in datalist.readlines()] 51 | gen_list_per_city(image_root_dir, category, split, datalist, num_frame_to_predict) 52 | # open the urls in their own threads 53 | # and return the results 54 | # results = pool.map(gen_list_per_city, cities) 55 | # close the pool and wait for the work to finish 56 | # pool.close() 57 | # pool.join() 58 | 59 | 60 | def process_per_class(): 61 | listfile = open("kth_train_walking_16.txt", 'w') 62 | datalist = open('kth_train_16.txt','r') 63 | datalist = [l.strip() for l in datalist.readlines()] 64 | 65 | for image_dir in tqdm(datalist): 66 | if image_dir.split('/')[1] == 'walking': 67 | listfile.write(image_dir+'\n') 68 | listfile.close() 69 | 70 | 71 | def main(): 72 | # process_per_class() 73 | # main() 74 | 75 | # self.classes = ['boxing', 'handclapping', 'handwaving', 'jogging', 'running', 'walking'] 76 | datalist = open('kth_train_walking_16.txt','r') 77 | listfile = open('kth_train_walking_16_ok.txt', 'w') 78 | datalist = 
[l.strip() for l in datalist.readlines()] 79 | print(len(datalist)) 80 | datalist = set(datalist) 81 | for l in tqdm(datalist): 82 | listfile.write(l+'\n') 83 | print (len(datalist)) 84 | # process_per_class() 85 | 86 | def new_main(class_name='hadwaving', split='train'): 87 | 88 | data_list = open('kth_' + split + '_16.txt', 'r') 89 | datalist = [l.strip() for l in data_list.readlines()] 90 | 91 | listfile = open("kth_" + split + "_%s_16_ok.txt"%class_name, 'w') 92 | 93 | class_specific_list = [image_dir for image_dir in datalist if image_dir.split('/')[1] == class_name] 94 | print(len(class_specific_list)) 95 | 96 | class_specific_list_unique = set(class_specific_list) 97 | 98 | for l in tqdm(class_specific_list_unique): 99 | listfile.write(l+'\n') 100 | print (len(class_specific_list_unique)) 101 | listfile.close() 102 | data_list.close() 103 | 104 | # new_main('handwaving', split='train') 105 | # new_main('handwaving', split='test') 106 | 107 | # get_list('handwaving', 18, split='train') 108 | get_list('handwaving', 18, split='test') -------------------------------------------------------------------------------- /src/utils/ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | # import vgg 4 | from torch.autograd import Variable as Vb 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | import torchvision.utils as tov 9 | import cv2 10 | import datetime 11 | import numpy as np 12 | 13 | 14 | def make_color_wheel(): 15 | """ 16 | Generate color wheel according Middlebury color code 17 | :return: Color wheel 18 | """ 19 | RY = 15 20 | YG = 6 21 | GC = 4 22 | CB = 11 23 | BM = 13 24 | MR = 6 25 | 26 | ncols = RY + YG + GC + CB + BM + MR 27 | 28 | colorwheel = np.zeros([ncols, 3]) 29 | 30 | col = 0 31 | 32 | # RY 33 | colorwheel[0:RY, 0] = 255 34 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 35 | col += RY 36 | 37 | # YG 38 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 39 | colorwheel[col:col + YG, 1] = 255 40 | col += YG 41 | 42 | # GC 43 | colorwheel[col:col + GC, 1] = 255 44 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 45 | col += GC 46 | 47 | # CB 48 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 49 | colorwheel[col:col + CB, 2] = 255 50 | col += CB 51 | 52 | # BM 53 | colorwheel[col:col + BM, 2] = 255 54 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 55 | col += + BM 56 | 57 | # MR 58 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 59 | colorwheel[col:col + MR, 0] = 255 60 | 61 | return colorwheel 62 | 63 | 64 | def merge(images, size): 65 | cdim = images.shape[-1] 66 | h, w = images.shape[1], images.shape[2] 67 | if cdim == 1: 68 | img = np.zeros((h * size[0], w * size[1])) 69 | for idx, image in enumerate(images): 70 | i = idx % size[1] 71 | j = idx // size[1] 72 | img[j * h:j * h + h, i * w:i * w + w] = np.squeeze(image) 73 | return img 74 | else: 75 | img = np.zeros((h * size[0], w * size[1], cdim)) 76 | for idx, image in enumerate(images): 77 | i = idx % size[1] 78 | j = idx // size[1] 79 | img[j * h:j * h + h, i * w:i * w + w, :] = image 80 | # print img.shape 81 | return img 82 | 83 | 84 | def compute_color(u, v): 85 | """ 86 | compute optical flow color map 87 | :param u: optical flow horizontal map 88 | :param 
v: optical flow vertical map 89 | :return: optical flow in color code 90 | """ 91 | [h, w] = u.shape 92 | img = np.zeros([h, w, 3]) 93 | nanIdx = np.isnan(u) | np.isnan(v) 94 | u[nanIdx] = 0 95 | v[nanIdx] = 0 96 | 97 | colorwheel = make_color_wheel() 98 | ncols = np.size(colorwheel, 0) 99 | 100 | rad = np.sqrt(u ** 2 + v ** 2) 101 | 102 | a = np.arctan2(-v, -u) / np.pi 103 | 104 | fk = (a + 1) / 2 * (ncols - 1) + 1 105 | 106 | k0 = np.floor(fk).astype(int) 107 | 108 | k1 = k0 + 1 109 | k1[k1 == ncols + 1] = 1 110 | f = fk - k0 111 | 112 | for i in range(0, np.size(colorwheel, 1)): 113 | tmp = colorwheel[:, i] 114 | col0 = tmp[k0 - 1] / 255 115 | col1 = tmp[k1 - 1] / 255 116 | col = (1 - f) * col0 + f * col1 117 | 118 | idx = rad <= 1 119 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 120 | notidx = np.logical_not(idx) 121 | 122 | col[notidx] *= 0.75 123 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) 124 | 125 | return img 126 | 127 | 128 | # def saveflow(flows, imgsize, size, savepath): 129 | # num_images = size[0] * size[1] 130 | # flows = merge(flows[0:num_images], size) * 32 131 | # u = flows[:, :, 0] 132 | # v = flows[:, :, 1] 133 | # image = compute_color(u, v) 134 | # flow = cv.resize(image, imgsize) 135 | # cv2.imwrite(savepath, flow) 136 | 137 | def saveflow(flows, imgsize, savepath): 138 | u = flows[:, :, 0]*3 139 | v = flows[:, :, 1]*3 140 | image = compute_color(u, v) 141 | flow = cv2.resize(image, imgsize) 142 | cv2.imwrite(savepath, flow) 143 | 144 | 145 | def compute_flow_color_map(flows): 146 | u = flows[:, :, 0] * 3 147 | v = flows[:, :, 1] * 3 148 | flow = compute_color(u, v) 149 | # flow = cv2.resize(image, imgsize) 150 | return flow 151 | 152 | def compute_flow_img(flows, imgsize, size): 153 | # import pdb 154 | # pdb.set_trace() 155 | 156 | num_images = size[0] * size[1] 157 | flows = merge(flows[0:num_images], size) * 3 158 | u = flows[:, :, 0] 159 | v = flows[:, :, 1] 160 | image = compute_color(u, v) 161 | return image 162 | # cv2.imwrite(savepath, image) 163 | 164 | import imageio 165 | 166 | def save_flow_sequence(flows, length, imgsize, size, savepath): 167 | flow_seq = [np.uint8(compute_flow_img(flows[:,i,...], imgsize, size)) for i in range(length)] 168 | imageio.mimsave(savepath, flow_seq, fps=int(length)) 169 | 170 | 171 | def saveflowopencv(flows, imgsize, size, savepath): 172 | # print flows.shape 173 | hsv = np.uint8(np.zeros([128 * 4, 128 * 4, 3])) 174 | hsv[..., 1] = 255 175 | flows = np.clip(merge(flows, size), -1, 1) 176 | 177 | mag, ang = cv2.cartToPolar(flows[..., 0], flows[..., 1]) 178 | hsv[..., 0] = ang * 180 / np.pi / 2 179 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 180 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 181 | bgr = bgr * 10 182 | cv2.imwrite(savepath, bgr) 183 | 184 | 185 | class flowwrapper(nn.Module): 186 | def __init__(self): 187 | super(flowwrapper, self).__init__() 188 | 189 | def forward(self, x, flow): 190 | # flow: (batch size, 2, height, width) 191 | # x = x.cuda() 192 | N = x.size()[0] 193 | H = x.size()[2] 194 | W = x.size()[3] 195 | base_grid = torch.zeros([N, H, W, 2]) 196 | linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1]) 197 | base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0]) 198 | linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1]) 199 | base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1]) 200 | if x.is_cuda: 201 | base_grid = Vb(base_grid).cuda() 202 
| else: 203 | base_grid = Vb(base_grid) 204 | # print flow.shape 205 | flow = flow.transpose(1, 2).transpose(2, 3) 206 | # print flow.size() 207 | grid = base_grid - flow 208 | # print grid.size() 209 | out = F.grid_sample(x, grid) 210 | return out 211 | 212 | 213 | def testcode(): 214 | a = flowwrapper() 215 | img = cv2.imread('image.jpg') 216 | b, g, r = cv2.split(img) 217 | img = cv2.merge([r, g, b]) 218 | image = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) 219 | image = image.view(img.shape[0], img.shape[1], 3) 220 | image = image.transpose(0, 2).transpose(1, 2).contiguous() 221 | image = image.float().div(255) 222 | img = Vb(torch.stack([image])).cuda() 223 | flow = Vb(torch.randn([1, 2, img.size()[2], img.size()[3]]).div(40)).cuda() 224 | newimg = a(img, flow) 225 | tov.save_image(newimg.data, 'img1.jpg') 226 | 227 | 228 | def viewflow(filename): 229 | ''' 230 | cap=cv2.VideoCapture(filename) 231 | _,frame1=cap.read() 232 | _,frame2=cap.read() 233 | prvs = cv2.cvtColor(frame1,cv2.COLOR_BGR2GRAY) 234 | next = cv2.cvtColor(frame2,cv2.COLOR_BGR2GRAY) 235 | flow = cv2.calcOpticalFlowFarneback(prvs, next, 0.5, 3, 15, 3, 5, 1.2, 0) 236 | flow=np.array(flow) 237 | print flow.shape 238 | u=flow[:,:,0] 239 | v=flow[:,:,1] 240 | ''' 241 | u = -np.ones([128, 128]) * 40 242 | v = np.zeros([128, 128]) * 40 243 | image = compute_color(u, v) 244 | cv2.imwrite('flow.jpg', image) 245 | # cv2.imwrite('frame1.jpg',frame1) 246 | # cv2.imwrite('frame2.jpg',frame2) 247 | # cap.release() 248 | 249 | 250 | def gradientx(img): 251 | return img[:, :, :, :-1] - img[:, :, :, 1:] 252 | 253 | 254 | def gradienty(img): 255 | return img[:, :, :-1, :] - img[:, :, 1:, :] 256 | 257 | 258 | def length_sq(x): 259 | return torch.sum(x**2, dim=1) 260 | 261 | 262 | def occlusion(flow, flowback, flowwarp, opt): 263 | flow_bw_warped = flowwarp(flow, -flowback) 264 | flow_fw_warped = flowwarp(flowback, flow) 265 | 266 | # flow_diff_fw = torch.abs(flow - flow_bw_warped) 267 | # flow_diff_bw = torch.abs(flowback - flow_fw_warped) 268 | 269 | flow_diff_fw = torch.abs(flow - flow_fw_warped) 270 | flow_diff_bw = torch.abs(flowback - flow_bw_warped) 271 | 272 | occ_thresh = opt.alpha1 * (length_sq(flow) + length_sq(flowback)) + opt.alpha2 273 | 274 | occ_fw = (length_sq(flow_diff_fw) > occ_thresh).float().unsqueeze(1) 275 | occ_bw = (length_sq(flow_diff_bw) > occ_thresh).float().unsqueeze(1) 276 | 277 | return 1-occ_fw, 1-occ_bw 278 | 279 | 280 | def get_occlusion_mask(flow, flowback, flowwarpper, opt, t=4): 281 | mask_fw = [] 282 | mask_bw = [] 283 | for i in xrange(t): 284 | tmp_mask_fw, tmp_mask_bw = occlusion(flow[:, :, i, :, :], flowback[:, :, i, :, :], flowwarpper, opt) 285 | 286 | mask_fw.append(tmp_mask_fw) 287 | mask_bw.append(tmp_mask_bw) 288 | 289 | mask_bw = torch.cat(mask_bw, 1) 290 | mask_fw = torch.cat(mask_fw, 1) 291 | return mask_fw, mask_bw 292 | 293 | def warp(frame, flow, opt, flowwarpper, mask): 294 | '''Use mask before warpping''' 295 | out = [torch.unsqueeze(flowwarpper(frame, flow[:, :, i, :, :] * mask[:, i:i + 1, ...]), 1) 296 | for i in range(opt.num_predicted_frames)] 297 | output = torch.cat(out, 1) # (64, 4, 3, 128, 128) 298 | return output 299 | 300 | 301 | def warp_back(frame2, flowback, opt, flowwarpper, mask): 302 | prevframe = [ 303 | torch.unsqueeze(flowwarpper(frame2[:, ii, :, :, :], -flowback[:, :, ii, :, :] * mask[:, ii:ii + 1, ...]), 1) 304 | for ii in range(opt.num_predicted_frames)] 305 | output = torch.cat(prevframe, 1) 306 | return output 307 | 308 | 309 | def 
refine(input, flow, mask, refine_net, opt, noise_bg): 310 | '''Go through the refine network.''' 311 | # apply mask to the warpped image 312 | out = [torch.unsqueeze(refine_net(input[:, i, ...] * mask[:, i:i + 1, ...] + noise_bg * (1. - mask[:, i:i + 1, ...]) 313 | , flow[:, :, i, :, :] 314 | ), 1) for i in range(opt.num_predicted_frames)] 315 | 316 | out = torch.cat(out, 1) 317 | return out 318 | 319 | def refine_id(input, flow, mask, refine_net, opt, noise_bg): 320 | '''Go through the refine network.''' 321 | # apply mask to the warpped image 322 | out = [torch.unsqueeze(refine_net(input[:, i+1, ...] * mask[:, i:i + 1, ...] + noise_bg * (1. - mask[:, i:i + 1, ...]) 323 | , flow[:, :, i, :, :] 324 | ), 1) for i in range(opt.num_predicted_frames)] 325 | out1 = [refine_net(input[:, 0, ...], flow[:, :, 0, :, :]).unsqueeze(1)] 326 | 327 | out = torch.cat(out1+out, 1) 328 | return out 329 | 330 | def refine_w_mask(input, ssmask, flow, mask, refine_net, opt, noise_bg): 331 | '''Go through the refine network.''' 332 | # apply mask to the warpped image 333 | out = [torch.unsqueeze(refine_net(input[:, i, ...] * mask[:, i:i + 1, ...] + noise_bg * (1. - mask[:, i:i + 1, ...]) 334 | , flow[:, :, i, :, :], ssmask[:, i, ...] 335 | ), 1) for i in range(opt.num_predicted_frames)] 336 | 337 | out = torch.cat(out, 1) 338 | return out 339 | 340 | 341 | if __name__ == '__main__': 342 | 343 | viewflow('a') 344 | # viewflow('/ssd/10.10.20.21/share/guojiaming/UCF-101/Surfing/v_Surfing_g15_c02.avi') 345 | 346 | img = Vb(torch.randn([16, 3, 128, 128]).div(40)).cuda() 347 | flow = Vb(torch.randn([16, 2, img.size()[2], img.size()[3]]).div(40)).cuda() 348 | begin = datetime.datetime.now() 349 | print (quickflowloss(flow, img)) 350 | end = datetime.datetime.now() 351 | time2 = end-begin 352 | print (time2.total_seconds()) 353 | 354 | 355 | neighber=5 356 | bound = ((neighber-1)/2) 357 | x = torch.zeros([neighber, neighber]) 358 | linear_points = torch.linspace(-bound, bound, neighber) 359 | x = torch.ger(torch.ones(neighber), linear_points).expand_as(x) 360 | 361 | y = torch.ger(linear_points, torch.ones(neighber)).expand_as(x) 362 | dst = x**2+y**2 363 | print (dst) 364 | 365 | testcode() 366 | -------------------------------------------------------------------------------- /src/utils/semantic_segmask_order_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from tqdm import tqdm 4 | 5 | import multiprocessing as mp 6 | 7 | 8 | cities = glob.glob('/mnt/lustrenew/DATAshare/leftImg8bit_sequence/val/*') 9 | 10 | cities = [city[50:] for city in cities] 11 | 12 | # print (cities) 13 | 14 | # root_folder = '/mnt/lustrenew/DATAshare/leftImg8bit_sequence/train_semantic_segmask' 15 | root_folder = '/mnt/lustrenew/DATAshare/unzip/leftImg8bit_sequence/val_pix2pixHD' 16 | 17 | if not os.path.exists(root_folder): 18 | os.makedirs(root_folder) 19 | 20 | for city in cities: 21 | if not os.path.exists(os.path.join(root_folder, city)): 22 | os.makedirs(os.path.join(root_folder, city)) 23 | 24 | mask_folder = '/mnt/lustrenew/DATAshare/synth/images/*synthesized_image.jpg' 25 | 26 | segmask_list = glob.glob(mask_folder) 27 | 28 | # print (len(segmask_list)) 29 | # print (segmask_list[0]) 30 | # print(segmask_list[0][0:97]+ '\\' + segmask_list[0][97:-5] + '\\' + segmask_list[0][-5::]) 31 | # print(segmask_list[0][98:-5]+'_ssmask.png') 32 | 33 | # for segmask in tqdm(segmask_list): 34 | # 35 | # city = segmask[98:-5].split('_')[0] 36 | # new_segmask_name = 
segmask[98:-5]+'_ssmask.png' 37 | # 38 | # target_folder = os.path.join(root_folder, city) 39 | # command = 'mv ' + segmask[0:97]+ '\\' + segmask[97:-5] + '\\' + segmask[-5::] + ' ' + target_folder + '/' +new_segmask_name 40 | # os.system(command) 41 | # print (command) 42 | # 43 | # break 44 | 45 | def processing(segmask): 46 | city = segmask[38::].split('_')[0] 47 | new_segmask_name = segmask[38:-37] + 'pix2pixHD.png' 48 | 49 | target_folder = os.path.join(root_folder, city) 50 | # command = 'mv ' + segmask[0:97] + '\\' + segmask[97:-5] + '\\' + segmask[ 51 | # -5::] + ' ' + target_folder + '/' + new_segmask_name 52 | 53 | command = 'mv ' + segmask + ' ' + target_folder + '/' + new_segmask_name 54 | 55 | os.system(command) 56 | print (command) 57 | 58 | mp.Pool(mp.cpu_count()).map(processing, segmask_list) 59 | -------------------------------------------------------------------------------- /src/utils/svg_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import socket 4 | import argparse 5 | import os 6 | import numpy as np 7 | from sklearn.manifold import TSNE 8 | import scipy.misc 9 | import matplotlib 10 | matplotlib.use('agg') 11 | import matplotlib.pyplot as plt 12 | import functools 13 | from skimage.measure import compare_psnr as psnr_metric 14 | from skimage.measure import compare_ssim as ssim_metric 15 | from scipy import signal 16 | from scipy import ndimage 17 | from PIL import Image, ImageDraw 18 | 19 | 20 | from torchvision import datasets, transforms 21 | from torch.autograd import Variable 22 | import imageio 23 | 24 | 25 | hostname = socket.gethostname() 26 | 27 | def load_dataset(opt): 28 | if opt.dataset == 'smmnist': 29 | from data.moving_mnist import MovingMNIST 30 | train_data = MovingMNIST( 31 | train=True, 32 | data_root=opt.data_root, 33 | seq_len=opt.n_past+opt.n_future, 34 | image_size=opt.image_width, 35 | deterministic=False, 36 | num_digits=opt.num_digits) 37 | test_data = MovingMNIST( 38 | train=False, 39 | data_root=opt.data_root, 40 | seq_len=opt.n_eval, 41 | image_size=opt.image_width, 42 | deterministic=False, 43 | num_digits=opt.num_digits) 44 | elif opt.dataset == 'bair': 45 | from data.bair import RobotPush 46 | train_data = RobotPush( 47 | data_root=opt.data_root, 48 | train=True, 49 | seq_len=opt.n_past+opt.n_future, 50 | image_size=opt.image_width) 51 | test_data = RobotPush( 52 | data_root=opt.data_root, 53 | train=False, 54 | seq_len=opt.n_eval, 55 | image_size=opt.image_width) 56 | elif opt.dataset == 'kth': 57 | from data.kth import KTH 58 | train_data = KTH( 59 | train=True, 60 | data_root=opt.data_root, 61 | seq_len=opt.n_past+opt.n_future, 62 | image_size=opt.image_width) 63 | test_data = KTH( 64 | train=False, 65 | data_root=opt.data_root, 66 | seq_len=opt.n_eval, 67 | image_size=opt.image_width) 68 | 69 | return train_data, test_data 70 | 71 | def sequence_input(seq, dtype): 72 | return [Variable(x.type(dtype)) for x in seq] 73 | 74 | def normalize_data(opt, dtype, sequence): 75 | if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' : 76 | sequence.transpose_(0, 1) 77 | sequence.transpose_(3, 4).transpose_(2, 3) 78 | else: 79 | sequence.transpose_(0, 1) 80 | 81 | return sequence_input(sequence, dtype) 82 | 83 | def is_sequence(arg): 84 | return (not hasattr(arg, "strip") and 85 | not type(arg) is np.ndarray and 86 | not hasattr(arg, "dot") and 87 | (hasattr(arg, "__getitem__") or 88 | hasattr(arg, "__iter__"))) 89 | 90 | def 
image_tensor(inputs, padding=1): 91 | # assert is_sequence(inputs) 92 | assert len(inputs) > 0 93 | # print(inputs) 94 | 95 | # if this is a list of lists, unpack them all and grid them up 96 | if is_sequence(inputs[0]) or (hasattr(inputs, "dim") and inputs.dim() > 4): 97 | images = [image_tensor(x) for x in inputs] 98 | if images[0].dim() == 3: 99 | c_dim = images[0].size(0) 100 | x_dim = images[0].size(1) 101 | y_dim = images[0].size(2) 102 | else: 103 | c_dim = 1 104 | x_dim = images[0].size(0) 105 | y_dim = images[0].size(1) 106 | 107 | result = torch.ones(c_dim, 108 | x_dim * len(images) + padding * (len(images)-1), 109 | y_dim) 110 | for i, image in enumerate(images): 111 | result[:, i * x_dim + i * padding : 112 | (i+1) * x_dim + i * padding, :].copy_(image) 113 | 114 | return result 115 | 116 | # if this is just a list, make a stacked image 117 | else: 118 | images = [x.data if isinstance(x, torch.autograd.Variable) else x 119 | for x in inputs] 120 | # print(images) 121 | if images[0].dim() == 3: 122 | c_dim = images[0].size(0) 123 | x_dim = images[0].size(1) 124 | y_dim = images[0].size(2) 125 | else: 126 | c_dim = 1 127 | x_dim = images[0].size(0) 128 | y_dim = images[0].size(1) 129 | 130 | result = torch.ones(c_dim, 131 | x_dim, 132 | y_dim * len(images) + padding * (len(images)-1)) 133 | for i, image in enumerate(images): 134 | result[:, :, i * y_dim + i * padding : 135 | (i+1) * y_dim + i * padding].copy_(image) 136 | return result 137 | 138 | def save_np_img(fname, x): 139 | if x.shape[0] == 1: 140 | x = np.tile(x, (3, 1, 1)) 141 | img = scipy.misc.toimage(x, 142 | high=255*x.max(), 143 | channel_axis=0) 144 | img.save(fname) 145 | 146 | def make_image(tensor): 147 | tensor = tensor.cpu().clamp(0, 1) 148 | if tensor.size(0) == 1: 149 | tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) 150 | # pdb.set_trace() 151 | return scipy.misc.toimage(tensor.numpy(), 152 | high=255.*tensor.max(), 153 | channel_axis=0) 154 | 155 | def draw_text_tensor(tensor, text): 156 | np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() 157 | pil = Image.fromarray(np.uint8(np_x*255)) 158 | draw = ImageDraw.Draw(pil) 159 | draw.text((4, 64), text, (0,0,0)) 160 | img = np.asarray(pil) 161 | return Variable(torch.Tensor(img / 255.)).transpose(1, 2).transpose(0, 1) 162 | 163 | def save_gif(filename, inputs, duration=0.25): 164 | images = [] 165 | for tensor in inputs: 166 | img = image_tensor(tensor, padding=0) 167 | img = img.cpu() 168 | img = img.transpose(0,1).transpose(1,2).clamp(0,1) 169 | img = img.numpy()*255. 
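        # scale to [0, 255] here; the uint8 cast on the next line gives imageio valid 8-bit frames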
170 | img = img.astype('uint8') 171 | images.append(img) 172 | imageio.mimsave(filename, images, duration=duration) 173 | 174 | def save_gif_with_text(filename, inputs, text, duration=0.25): 175 | images = [] 176 | for tensor, text in zip(inputs, text): 177 | img = image_tensor([draw_text_tensor(ti, texti) for ti, texti in zip(tensor, text)], padding=0) 178 | img = img.cpu() 179 | img = img.transpose(0,1).transpose(1,2).clamp(0,1).numpy() 180 | images.append(img) 181 | imageio.mimsave(filename, images, duration=duration) 182 | 183 | def save_image(filename, tensor): 184 | img = make_image(tensor) 185 | img.save(filename) 186 | 187 | def save_tensors_image(filename, inputs, padding=1): 188 | images = image_tensor(inputs, padding) 189 | return save_image(filename, images) 190 | 191 | def prod(l): 192 | return functools.reduce(lambda x, y: x * y, l) 193 | 194 | def batch_flatten(x): 195 | return x.resize(x.size(0), prod(x.size()[1:])) 196 | 197 | def clear_progressbar(): 198 | # moves up 3 lines 199 | print("\033[2A") 200 | # deletes the whole line, regardless of character position 201 | print("\033[2K") 202 | # moves up two lines again 203 | print("\033[2A") 204 | 205 | def mse_metric(x1, x2): 206 | err = np.sum((x1 - x2) ** 2) 207 | err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) 208 | return err 209 | 210 | def eval_seq(gt, pred): 211 | T = len(gt) 212 | bs = gt[0].shape[0] 213 | ssim = np.zeros((bs, T)) 214 | psnr = np.zeros((bs, T)) 215 | mse = np.zeros((bs, T)) 216 | for i in range(bs): 217 | for t in range(T): 218 | for c in range(gt[t][i].shape[0]): 219 | ssim[i, t] += ssim_metric(gt[t][i][c], pred[t][i][c]) 220 | psnr[i, t] += psnr_metric(gt[t][i][c], pred[t][i][c]) 221 | ssim[i, t] /= gt[t][i].shape[0] 222 | psnr[i, t] /= gt[t][i].shape[0] 223 | mse[i, t] = mse_metric(gt[t][i], pred[t][i]) 224 | 225 | return mse, ssim, psnr 226 | 227 | # ssim function used in Babaeizadeh et al. (2017), Fin et al. (2016), etc. 
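# finn_ssim below uses an 11x11 Gaussian window (sigma = 1.5); finn_eval_seq averages the per-channel scores and maps NaN SSIM values to -1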
228 | def finn_eval_seq(gt, pred): 229 | T = len(gt) 230 | bs = gt[0].shape[0] 231 | ssim = np.zeros((bs, T)) 232 | psnr = np.zeros((bs, T)) 233 | mse = np.zeros((bs, T)) 234 | for i in range(bs): 235 | for t in range(T): 236 | for c in range(gt[t][i].shape[0]): 237 | res = finn_ssim(gt[t][i][c], pred[t][i][c]).mean() 238 | if math.isnan(res): 239 | ssim[i, t] += -1 240 | else: 241 | ssim[i, t] += res 242 | psnr[i, t] += finn_psnr(gt[t][i][c], pred[t][i][c]) 243 | ssim[i, t] /= gt[t][i].shape[0] 244 | psnr[i, t] /= gt[t][i].shape[0] 245 | mse[i, t] = mse_metric(gt[t][i], pred[t][i]) 246 | 247 | return mse, ssim, psnr 248 | 249 | 250 | def finn_psnr(x, y): 251 | mse = ((x - y)**2).mean() 252 | return 10*np.log(1/mse)/np.log(10) 253 | 254 | 255 | def gaussian2(size, sigma): 256 | A = 1/(2.0*np.pi*sigma**2) 257 | x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 258 | g = A*np.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) 259 | return g 260 | 261 | def fspecial_gauss(size, sigma): 262 | x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 263 | g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) 264 | return g/g.sum() 265 | 266 | def finn_ssim(img1, img2, cs_map=False): 267 | img1 = img1.astype(np.float64) 268 | img2 = img2.astype(np.float64) 269 | size = 11 270 | sigma = 1.5 271 | window = fspecial_gauss(size, sigma) 272 | K1 = 0.01 273 | K2 = 0.03 274 | L = 1 #bitdepth of image 275 | C1 = (K1*L)**2 276 | C2 = (K2*L)**2 277 | mu1 = signal.fftconvolve(img1, window, mode='valid') 278 | mu2 = signal.fftconvolve(img2, window, mode='valid') 279 | mu1_sq = mu1*mu1 280 | mu2_sq = mu2*mu2 281 | mu1_mu2 = mu1*mu2 282 | sigma1_sq = signal.fftconvolve(img1*img1, window, mode='valid') - mu1_sq 283 | sigma2_sq = signal.fftconvolve(img2*img2, window, mode='valid') - mu2_sq 284 | sigma12 = signal.fftconvolve(img1*img2, window, mode='valid') - mu1_mu2 285 | if cs_map: 286 | return (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* 287 | (sigma1_sq + sigma2_sq + C2)), 288 | (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) 289 | else: 290 | return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* 291 | (sigma1_sq + sigma2_sq + C2)) 292 | 293 | 294 | def init_weights(m): 295 | classname = m.__class__.__name__ 296 | if classname.find('Conv') != -1 or classname.find('Linear') != -1: 297 | m.weight.data.normal_(0.0, 0.02) 298 | m.bias.data.fill_(0) 299 | elif classname.find('BatchNorm') != -1: 300 | m.weight.data.normal_(1.0, 0.02) 301 | m.bias.data.fill_(0) 302 | 303 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import imageio 7 | import cv2 8 | import torch 9 | import sys 10 | from utils import ops 11 | 12 | 13 | def save_images(images, size, image_path): 14 | # images = (images+1.)/2. 
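    # images are expected to already be in [0, 255]; merge() tiles the first size[0]*size[1] samples into one grid image before saving with PIL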
15 | # import pdb 16 | # pdb.set_trace() 17 | num_images = size[0] * size[1] 18 | puzzle = merge(images[0:num_images], size) 19 | 20 | im = Image.fromarray(np.uint8(puzzle)) 21 | return im.save(image_path) 22 | 23 | 24 | def save_gif(images, length, size, gifpath): 25 | num_images = size[0] * size[1] 26 | images = np.array(images[0:num_images]) 27 | savegif = [np.uint8(merge(images[:, times, :, :, :], size)) for times in range(0, length)] 28 | imageio.mimsave(gifpath, savegif, fps=int(length)) 29 | 30 | 31 | def merge(images, size): 32 | cdim = images.shape[-1] 33 | h, w = images.shape[1], images.shape[2] 34 | if cdim == 1: 35 | img = np.zeros((h * size[0], w * size[1])) 36 | for idx, image in enumerate(images): 37 | i = idx % size[1] 38 | j = idx // size[1] 39 | img[j * h:j * h + h, i * w:i * w + w] = np.squeeze(image) 40 | return img 41 | else: 42 | img = np.zeros((h * size[0], w * size[1], cdim)) 43 | for idx, image in enumerate(images): 44 | i = idx % size[1] 45 | j = idx // size[1] 46 | img[j * h:j * h + h, i * w:i * w + w, :] = image 47 | # print img.shape 48 | return img 49 | 50 | 51 | def psnr(img1, img2): 52 | mse = np.mean((img1 - img2) ** 2) 53 | if mse == 0: 54 | return 100 55 | PIXEL_MAX = 255.0 56 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 57 | 58 | 59 | def sharpness(img1, img2): 60 | dxI = np.abs(img1[:, 1:, 1:, :] - img1[:, :-1, 1:, :]) 61 | dyI = np.abs(img1[:, 1:, 1:, :] - img1[:, 1:, :-1, :]) 62 | dxJ = np.abs(img2[:, 1:, 1:, :] - img2[:, :-1, 1:, :]) 63 | dyJ = np.abs(img2[:, 1:, 1:, :] - img2[:, 1:, :-1, :]) 64 | PIXEL_MAX = 255.0 65 | grad = np.mean(np.abs(dxI + dyI - dxJ - dyJ)) 66 | if grad == 0: 67 | return 100 68 | else: 69 | return 20 * math.log10(PIXEL_MAX / math.sqrt(grad)) 70 | 71 | 72 | def save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, sampledir, opt, eval=False, useMask=True, 73 | grid=[8, 4], single=False, bidirectional=False): 74 | 75 | frame1 = data[:, 0, :, :, :] 76 | 77 | num_predicted_frames = y_pred.size()[1] -1 78 | num_frames = y_pred.size()[1] 79 | 80 | if useMask: 81 | save_gif(mask_fw.unsqueeze(4).data.cpu().numpy() * 255., num_predicted_frames, grid, sampledir + 82 | '/{:06d}_foward_occ_map.gif'.format(iteration)) 83 | save_gif(mask_bw.unsqueeze(4).data.cpu().numpy() * 255., num_predicted_frames, grid, sampledir + 84 | '/{:06d}_backward_occ_map.gif'.format(iteration)) 85 | 86 | 87 | # Save results before refinement 88 | frame1_ = torch.unsqueeze(frame1, 1) 89 | if bidirectional: 90 | fakegif_before_refinement = torch.cat([y_pred_before_refine[:, 0:3, ...], frame1_.cuda(), y_pred_before_refine[:, 3::, ...]], 1) 91 | else: 92 | fakegif_before_refinement = torch.cat([frame1_.cuda(), y_pred_before_refine], 1) 93 | fakegif_before_refinement = fakegif_before_refinement.transpose(2, 3).transpose(3, 4).data.cpu().numpy() 94 | 95 | 96 | # Save reconstructed or sampled video 97 | if bidirectional: 98 | fakegif = torch.cat([y_pred[:,0:3,...], frame1_.cuda(), y_pred[:,3::,...]], 1) 99 | 100 | else: 101 | fakegif = torch.cat([frame1_.cuda(), y_pred], 1) 102 | fakegif = fakegif.transpose(2, 3).transpose(3, 4).data.cpu().numpy() 103 | 104 | # Save flow field 105 | # _flow = flow[:, :, -1, :, :] 106 | # _flow = _flow.cpu().data.transpose(1, 2).transpose(2, 3).numpy() 107 | _flow = flow.permute(0, 2, 3, 4, 1) 108 | _flow = _flow.cpu().data.numpy() 109 | 110 | if eval: 111 | save_file_name = 'sample' 112 | # Save ground truth sample 113 | if bidirectional: 114 | data = data[:, [1, 2, 3, 0, 4, 5, 6, 7], 
...].cpu().data.transpose(2, 3).transpose(3, 4).numpy() 115 | else: 116 | data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy() 117 | save_gif(data * 255, opt.num_frames, [8, 4], sampledir + '/{:06d}_gt.gif'.format(iteration)) 118 | else: 119 | save_file_name = 'recon' 120 | 121 | save_gif(fakegif * 255, num_frames, grid, sampledir + '/{:06d}_%s.gif'.format(iteration)%save_file_name) 122 | save_gif(fakegif_before_refinement * 255, num_frames, grid, 123 | sampledir + '/{:06d}_%s_bf_refine.gif'.format(iteration)%save_file_name) 124 | # ops.saveflow(_flow, opt.input_size, grid, sampledir + '/{:06d}_%s_flow.jpg'.format(iteration)%save_file_name) 125 | ops.save_flow_sequence(_flow, num_predicted_frames, opt.input_size, grid, sampledir + '/{:06d}_%s_flow.gif'.format(iteration) % save_file_name) 126 | 127 | if single: 128 | import scipy.misc 129 | for i in range(5): 130 | scipy.misc.imsave(sampledir +'/{:06d}_'.format(iteration) + str(i)+'.png', fakegif[0,i,...]) 131 | 132 | 133 | def save_parameters(flowgen): 134 | '''Write parameters setting file''' 135 | with open(os.path.join(flowgen.parameterdir, 'params.txt'), 'w') as file: 136 | file.write(flowgen.jobname) 137 | file.write('Training Parameters: \n') 138 | file.write(str(flowgen.opt) + '\n') 139 | if flowgen.load: 140 | file.write('Load pretrained model: ' + str(flowgen.load) + '\n') 141 | file.write('Iteration to load:' + str(flowgen.iter_to_load) + '\n') 142 | import cv2 143 | 144 | def save_images(root_dir, data, y_pred, paths, opt): 145 | 146 | frame1 = data[:, 0, :, :, :] 147 | frame1_ = torch.unsqueeze(frame1, 1) 148 | frame_sequence = torch.cat([frame1_.cuda(), y_pred], 1) 149 | frame_sequence = frame_sequence.permute((0, 1, 3, 4, 2)).cpu().data.numpy()* 255 # batch, num_frame, H, W, C 150 | 151 | for i in range(y_pred.size()[0]): 152 | 153 | # save images as gif 154 | frames_fo_save = [np.uint8(frame_sequence[i][frame_id]) for frame_id in range(y_pred.size()[1]+1)] 155 | # 3fps 156 | aux_dir = os.path.join(root_dir, paths[0][i][0:-22]) 157 | if not os.path.isdir(aux_dir): 158 | os.makedirs(aux_dir) 159 | 160 | imageio.mimsave(os.path.join(root_dir, paths[0][i][0:-4] + '.gif'), frames_fo_save, fps=int(len(paths)*2)) 161 | 162 | 163 | # new added 164 | 165 | for j, frame in enumerate(frames_fo_save): 166 | # import pdb 167 | # pdb.set_trace() 168 | cv2.imwrite(os.path.join(root_dir, paths[0][i][0:-4] + '{:02d}.png'.format(j)), cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)) 169 | 170 | # for j in range(len(frame_sequence[0])): 171 | # # aux_dir = os.path.join(root_dir, paths[j][i][0:-22]) 172 | # # if not os.path.isdir(aux_dir): 173 | # # os.makedirs(aux_dir) 174 | # frame = frame_sequence[i][j] 175 | # # frameResized = cv2.resize(frame, (256, 128), interpolation=cv2.INTER_LINEAR) 176 | # # cv2.imwrite(os.path.join(root_dir, paths[j][i]), frame) 177 | # # cv2.imwrite(os.path.join(root_dir, paths[j][i]), cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)) 178 | # cv2.imwrite(os.path.join(root_dir, paths[j][i]), frame) 179 | 180 | 181 | 182 | def save_images_ucf(root_dir, data, y_pred, paths, opt): 183 | 184 | frame1 = data[:, 0, :, :, :] 185 | frame1_ = torch.unsqueeze(frame1, 1) 186 | frame_sequence = torch.cat([frame1_.cuda(), y_pred], 1) 187 | frame_sequence = frame_sequence.permute((0, 1, 3, 4, 2)).cpu().data.numpy()* 255 # batch, num_frame, H, W, C 188 | 189 | for i in range(y_pred.size()[0]): 190 | 191 | # save images as gif 192 | frames_fo_save = [np.uint8(frame_sequence[i][frame_id]) for frame_id in range(y_pred.size()[1]+1)] 193 | # 3fps 
194 | # import pdb 195 | # pdb.set_trace() 196 | aux_dir = os.path.join(root_dir, paths[0][i]) 197 | if not os.path.isdir(aux_dir): 198 | os.makedirs(aux_dir) 199 | 200 | imageio.mimsave(os.path.join(root_dir, paths[0][i] + '.gif'), frames_fo_save, fps=int(len(paths)*2)) 201 | 202 | # new added 203 | 204 | for j, frame in enumerate(frames_fo_save): 205 | # import pdb 206 | # pdb.set_trace() 207 | cv2.imwrite(os.path.join(root_dir, paths[0][i], '{:02d}.png'.format(j)), cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)) 208 | 209 | 210 | 211 | def save_images_kitti(root_dir, data, y_pred, paths, opt): 212 | 213 | frame1 = data[:, 0, :, :, :] 214 | frame1_ = torch.unsqueeze(frame1, 1) 215 | frame_sequence = torch.cat([frame1_.cuda(), y_pred], 1) 216 | frame_sequence = frame_sequence.permute((0, 1, 3, 4, 2)).cpu().data.numpy()* 255 # batch, num_frame, H, W, C 217 | 218 | for i in range(y_pred.size()[0]): 219 | 220 | # save images as gif 221 | frames_fo_save = [np.uint8(frame_sequence[i][frame_id]) for frame_id in range(y_pred.size()[1]+1)] 222 | # 3fps 223 | # import pdb 224 | # pdb.set_trace() 225 | aux_dir = os.path.join(root_dir, paths[i]) 226 | if not os.path.isdir(aux_dir): 227 | os.makedirs(aux_dir) 228 | 229 | imageio.mimsave(os.path.join(root_dir, paths[i] + '.gif'), frames_fo_save, fps=int(len(paths)*2)) 230 | 231 | # new added 232 | 233 | for j, frame in enumerate(frames_fo_save): 234 | # import pdb 235 | # pdb.set_trace() 236 | frame = cv2.resize(frame, (256, 78)) 237 | cv2.imwrite(os.path.join(root_dir, paths[i], '{:02d}.png'.format(j)), cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)) 238 | 239 | 240 | 241 | def save_flows(root_dir, flow, paths): 242 | # print(flow.size()) 243 | _flow = flow.permute(0, 2, 3, 4, 1) 244 | _flow = _flow.cpu().data.numpy() 245 | # mask = mask.unsqueeze(4) 246 | # # print (mask.size()) 247 | # mask = mask.data.cpu().numpy() * 255. 248 | 249 | for i in range(flow.size()[0]): 250 | 251 | # save flow*mask as gif 252 | # *mask[i][frame_id]) 253 | flow_fo_save = [np.uint8(ops.compute_flow_color_map(_flow[i][frame_id])) for frame_id in range(len(paths)-1)] 254 | # 3fps 255 | imageio.mimsave(os.path.join(root_dir, paths[0][i][0:-4] + '.gif'), flow_fo_save, fps=int(len(paths)-1-2)) 256 | 257 | for j in range(flow.size()[2]): 258 | ops.saveflow(_flow[i][j], (256, 128), os.path.join(root_dir, paths[j+1][i])) 259 | 260 | 261 | def save_occ_map(root_dir, mask, paths): 262 | mask = mask.data.cpu().numpy() * 255. 
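    # occlusion masks arrive as [0, 1] tensors; scaling by 255 turns them into greyscale images for cv2.imwrite below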
263 | for i in range(mask.shape[0]): 264 | for j in range(mask.shape[1]): 265 | cv2.imwrite(os.path.join(root_dir, paths[j+1][i]), mask[i][j]) 266 | 267 | 268 | def save_samples_no_flow(data, y_pred, iteration, sampledir, opt, eval=False, 269 | grid=[8, 4], single=False, bidirectional=False): 270 | 271 | frame1 = data[:, 0, :, :, :] 272 | num_frames = y_pred.size()[1] 273 | 274 | # Save results before refinement 275 | frame1_ = torch.unsqueeze(frame1, 1) 276 | 277 | # Save reconstructed or sampled video 278 | if bidirectional: 279 | fakegif = torch.cat([y_pred[:,0:3,...], frame1_.cuda(), y_pred[:,3::,...]], 1) 280 | 281 | else: 282 | fakegif = torch.cat([frame1_.cuda(), y_pred], 1) 283 | fakegif = fakegif.transpose(2, 3).transpose(3, 4).data.cpu().numpy() 284 | 285 | if eval: 286 | save_file_name = 'sample' 287 | # Save ground truth sample 288 | if bidirectional: 289 | data = data[:, [1, 2, 3, 0, 4, 5, 6, 7], ...].cpu().data.transpose(2, 3).transpose(3, 4).numpy() 290 | else: 291 | data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy() 292 | save_gif(data * 255, opt.num_frames, [8, 4], sampledir + '/{:06d}_gt.gif'.format(iteration)) 293 | else: 294 | save_file_name = 'recon' 295 | 296 | save_gif(fakegif * 255, num_frames, grid, sampledir + '/{:06d}_%s.gif'.format(iteration)%save_file_name) 297 | 298 | if single: 299 | import scipy.misc 300 | for i in range(5): 301 | scipy.misc.imsave(sampledir +'/{:06d}_'.format(iteration) + str(i)+'.png', fakegif[0,i,...]) 302 | --------------------------------------------------------------------------------
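The flow-visualisation helpers in `src/utils/ops.py` are only called from the training and test scripts, so a minimal usage sketch is given below. It is not part of the repository: it assumes the snippet is run from `src/` (so that `utils.ops` is importable, as it is in `src/utils/utils.py`) with torch, OpenCV, numpy and imageio installed, and the synthetic flow field and output filename are illustrative only.

```python
# Hedged sketch (not from the repository): colour-code a synthetic optical-flow
# field with the helpers defined in src/utils/ops.py.
import numpy as np
from utils import ops  # importing ops pulls in torch, cv2, numpy and imageio

h, w = 128, 256
flow = np.zeros((h, w, 2), dtype=np.float32)
flow[..., 0] = np.linspace(-1.0, 1.0, w)[None, :]  # horizontal component u
flow[..., 1] = np.linspace(-1.0, 1.0, h)[:, None]  # vertical component v

# compute_color takes separate u/v maps and returns an H x W x 3 colour image
rgb = ops.compute_color(flow[..., 0], flow[..., 1])

# saveflow scales the flow by 3, colour-codes it, resizes it to (width, height)
# and writes it with cv2.imwrite
ops.saveflow(flow, (w, h), 'synthetic_flow.jpg')
```

The `(H, W, 2)` layout with the u/v components in the last dimension matches what `save_flows` in `src/utils/utils.py` produces after permuting the predicted flow tensor, so the sketch mirrors how the test scripts visualise flow.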