├── .gitignore ├── ALVC.py ├── BasketballPass ├── f001.png ├── f002.png ├── f003.png ├── f004.png ├── f005.png ├── f006.png ├── f007.png ├── f008.png ├── f009.png ├── f010.png ├── f011.png ├── f012.png ├── f013.png ├── f014.png ├── f015.png ├── f016.png ├── f017.png ├── f018.png ├── f019.png ├── f020.png ├── f021.png ├── f022.png ├── f023.png ├── f024.png ├── f025.png ├── f026.png ├── f027.png ├── f028.png ├── f029.png ├── f030.png ├── f031.png ├── f032.png ├── f033.png ├── f034.png ├── f035.png ├── f036.png ├── f037.png ├── f038.png ├── f039.png ├── f040.png ├── f041.png ├── f042.png ├── f043.png ├── f044.png ├── f045.png ├── f046.png ├── f047.png ├── f048.png ├── f049.png ├── f050.png ├── f051.png ├── f052.png ├── f053.png ├── f054.png ├── f055.png ├── f056.png ├── f057.png ├── f058.png ├── f059.png ├── f060.png ├── f061.png ├── f062.png ├── f063.png ├── f064.png ├── f065.png ├── f066.png ├── f067.png ├── f068.png ├── f069.png ├── f070.png ├── f071.png ├── f072.png ├── f073.png ├── f074.png ├── f075.png ├── f076.png ├── f077.png ├── f078.png ├── f079.png ├── f080.png ├── f081.png ├── f082.png ├── f083.png ├── f084.png ├── f085.png ├── f086.png ├── f087.png ├── f088.png ├── f089.png ├── f090.png ├── f091.png ├── f092.png ├── f093.png ├── f094.png ├── f095.png ├── f096.png ├── f097.png ├── f098.png ├── f099.png └── f100.png ├── CNN_img.py ├── CNN_recurrent.py ├── Extrapolation.py ├── Interpolation_Compression.py ├── MC_network_inter.py ├── README.md ├── Recurrent_AutoEncoder_Extrapolation.py ├── Recurrent_Prob_Model.py ├── arithmeticcoding.py ├── func.py ├── functions.py ├── functions_inter.py ├── helper.py ├── helper2.py ├── inv_flow.py ├── mc_func.py ├── motion.py ├── ms_ssim_np.py ├── rec_exp.py ├── sepconv_inter.py └── sepconv_inter_enc.py /.gitignore: -------------------------------------------------------------------------------- 1 | tensorflow_compression 2 | model 3 | BasketballPass* 4 | __pycache__ 5 | model.zip 6 | t.py 7 | VVCSoftware_VTM -------------------------------------------------------------------------------- /ALVC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser( 5 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | parser.add_argument("--path", default='BasketballPass') 7 | parser.add_argument("--l", type=int, default=512, choices=[256, 512, 1024, 2048]) 8 | args = parser.parse_args() 9 | 10 | frames = len(os.listdir(args.path)) 11 | 12 | print('Running ALVC with extrapolation (P-frames)') 13 | os.system('python Recurrent_AutoEncoder_Extrapolation.py --path ' + args.path + ' --frame ' + str(frames) 14 | + ' --l ' + str(args.l)) 15 | 16 | print('Running entropy coding') 17 | os.system('python Recurrent_Prob_Model.py --path ' + args.path + ' --frame ' + str(frames) 18 | + ' --l ' + str(args.l)) 19 | 20 | print('Running ALVC with interpolation (B-frames)') 21 | os.system('python Interpolation_Compression.py --path ' + args.path + ' --frame ' + str(frames) 22 | + ' --l ' + str(args.l)) -------------------------------------------------------------------------------- /BasketballPass/f001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYang-home/ALVC/f0073dba6aa32dd40812a2be42149c0bb67d08ea/BasketballPass/f001.png -------------------------------------------------------------------------------- /BasketballPass/f002.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYang-home/ALVC/f0073dba6aa32dd40812a2be42149c0bb67d08ea/BasketballPass/f002.png --------------------------------------------------------------------------------
/BasketballPass/f003.png ... /BasketballPass/f100.png: binary PNG frames of the BasketballPass test sequence; the remaining 98 entries follow the same pattern as f001/f002 above, each linking to https://raw.githubusercontent.com/RenYang-home/ALVC/f0073dba6aa32dd40812a2be42149c0bb67d08ea/BasketballPass/ plus the frame name.
--------------------------------------------------------------------------------
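A usage sketch (not a file of this repository): with a frame folder laid out as above, the pretrained models unpacked into ./model, and tensorflow_compression importable (both names appear in .gitignore), the whole pipeline is driven by ALVC.py, which infers the frame count from the folder and runs the extrapolation (P-frame), entropy-coding and interpolation (B-frame) stages in turn, e.g.

    python ALVC.py --path BasketballPass --l 1024

where --l selects the rate point (lambda in {256, 512, 1024, 2048}, default 512). The helper below is likewise hypothetical, not part of the repository; it checks that a custom sequence matches the conventions the scripts rely on: frames named f001.png, f002.png, ... (every script builds file names as 'f' + str(idx).zfill(3) + '.png', and ALVC.py takes len(os.listdir(path)) as the frame count), with dimensions divisible by 16 (an assumption inferred from the Height // 16, Width // 16 recurrent-state shapes in Extrapolation.py).

import os
import imageio

# check_sequence: hypothetical pre-flight check for a custom input folder.
def check_sequence(path):
    # frames must be consecutively numbered f001.png, f002.png, ...
    names = sorted(f for f in os.listdir(path) if f.endswith('.png'))
    assert names, 'no PNG frames found in ' + path
    expected = ['f' + str(i + 1).zfill(3) + '.png' for i in range(len(names))]
    assert names == expected, 'frames must be named f001.png, f002.png, ...'
    # the recurrent states are allocated at H//4 and H//16 resolutions, so
    # dimensions that are multiples of 16 are assumed here to be safe
    height, width = imageio.imread(os.path.join(path, names[0])).shape[:2]
    assert height % 16 == 0 and width % 16 == 0, 'frame size should be a multiple of 16'
    return len(names), height, width

print(check_sequence('BasketballPass'))  # (100, 240, 416) for the bundled demo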
/CNN_img.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import tensorflow_compression as tfc 7 | 8 | 9 | def MV_analysis(tensor, num_filters, M): 10 | """Builds the analysis transform.""" 11 | 12 | with tf.variable_scope("MV_analysis"): 13 | with tf.variable_scope("layer_0"): 14 | layer = tfc.SignalConv2D( 15 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 16 | use_bias=True, activation=tfc.GDN()) 17 | tensor = layer(tensor) 18 | 19 | with tf.variable_scope("layer_1"): 20 | layer = tfc.SignalConv2D( 21 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 22 | use_bias=True, activation=tfc.GDN()) 23 | tensor = layer(tensor) 24 | 25 | with tf.variable_scope("layer_2"): 26 | layer = tfc.SignalConv2D( 27 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 28 | use_bias=True, activation=tfc.GDN()) 29 | tensor = layer(tensor) 30 | 31 | with tf.variable_scope("layer_3"): 32 | layer = tfc.SignalConv2D( 33 | M, (3, 3), corr=True, strides_down=2, padding="same_zeros", 34 | use_bias=False, activation=None) 35 | tensor = layer(tensor) 36 | 37 | return tensor 38 | 39 | 40 | def MV_synthesis(tensor, num_filters, out_filters=2): 41 | """Builds the synthesis transform.""" 42 | 43 | with tf.variable_scope("MV_synthesis"): 44 | with tf.variable_scope("layer_0"): 45
| layer = tfc.SignalConv2D( 46 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 47 | use_bias=True, activation=tfc.GDN(inverse=True)) 48 | tensor = layer(tensor) 49 | 50 | with tf.variable_scope("layer_1"): 51 | layer = tfc.SignalConv2D( 52 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 53 | use_bias=True, activation=tfc.GDN(inverse=True)) 54 | tensor = layer(tensor) 55 | 56 | with tf.variable_scope("layer_2"): 57 | layer = tfc.SignalConv2D( 58 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 59 | use_bias=True, activation=tfc.GDN(inverse=True)) 60 | tensor = layer(tensor) 61 | 62 | with tf.variable_scope("layer_3"): 63 | layer = tfc.SignalConv2D( 64 | out_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 65 | use_bias=True, activation=None) 66 | tensor = layer(tensor) 67 | 68 | return tensor 69 | 70 | def Res_analysis(tensor, num_filters, M, reuse=False): 71 | """Builds the analysis transform.""" 72 | 73 | with tf.variable_scope("analysis", reuse=reuse): 74 | with tf.variable_scope("layer_0"): 75 | layer = tfc.SignalConv2D( 76 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 77 | use_bias=True, activation=tfc.GDN()) 78 | tensor = layer(tensor) 79 | 80 | with tf.variable_scope("layer_1"): 81 | layer = tfc.SignalConv2D( 82 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 83 | use_bias=True, activation=tfc.GDN()) 84 | tensor = layer(tensor) 85 | 86 | with tf.variable_scope("layer_2"): 87 | layer = tfc.SignalConv2D( 88 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 89 | use_bias=True, activation=tfc.GDN()) 90 | tensor = layer(tensor) 91 | 92 | with tf.variable_scope("layer_3"): 93 | layer = tfc.SignalConv2D( 94 | M, (5, 5), corr=True, strides_down=2, padding="same_zeros", 95 | use_bias=False, activation=None) 96 | tensor = layer(tensor) 97 | 98 | return tensor 99 | 100 | def Res_synthesis(tensor, num_filters, reuse=False): 101 | """Builds the synthesis transform.""" 102 | 103 | with tf.variable_scope("synthesis", reuse=reuse): 104 | with tf.variable_scope("layer_0"): 105 | layer = tfc.SignalConv2D( 106 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 107 | use_bias=True, activation=tfc.GDN(inverse=True)) 108 | tensor = layer(tensor) 109 | 110 | with tf.variable_scope("layer_1"): 111 | layer = tfc.SignalConv2D( 112 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 113 | use_bias=True, activation=tfc.GDN(inverse=True)) 114 | tensor = layer(tensor) 115 | 116 | with tf.variable_scope("layer_2"): 117 | layer = tfc.SignalConv2D( 118 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 119 | use_bias=True, activation=tfc.GDN(inverse=True)) 120 | tensor = layer(tensor) 121 | 122 | with tf.variable_scope("layer_3"): 123 | layer = tfc.SignalConv2D( 124 | 3, (5, 5), corr=False, strides_up=2, padding="same_zeros", 125 | use_bias=True, activation=None) 126 | tensor = layer(tensor) 127 | 128 | return tensor -------------------------------------------------------------------------------- /CNN_recurrent.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_compression as tfc 3 | import functions 4 | 5 | def one_step_rnn(tensor, state_c, state_h, Height, Width, num_filters, scale, kernal, act): 6 | 7 | tensor = tf.expand_dims(tensor, axis=1) 8 | 9 | cell = functions.ConvLSTMCell(shape=[Height // scale, Width // scale], 
activation=act, 10 | filters=num_filters, kernel=kernal) 11 | state = tf.nn.rnn_cell.LSTMStateTuple(state_c, state_h) 12 | tensor, state = tf.nn.dynamic_rnn(cell, tensor, initial_state=state, dtype=tensor.dtype) 13 | state_c, state_h = state 14 | 15 | tensor = tf.squeeze(tensor, axis=1) 16 | 17 | return tensor, state_c, state_h 18 | 19 | 20 | def MV_analysis(tensor, num_filters, out_filters, Height, Width, c_state, h_state, act): 21 | """Builds the analysis transform.""" 22 | 23 | with tf.variable_scope("MV_analysis", reuse=tf.AUTO_REUSE): 24 | with tf.variable_scope("layer_0"): 25 | layer = tfc.SignalConv2D( 26 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 27 | use_bias=True, activation=tfc.GDN()) 28 | tensor = layer(tensor) 29 | 30 | with tf.variable_scope("layer_1"): 31 | layer = tfc.SignalConv2D( 32 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 33 | use_bias=True, activation=tfc.GDN()) 34 | tensor = layer(tensor) 35 | 36 | with tf.variable_scope("recurrent"): 37 | tensor, c_state_out, h_state_out = one_step_rnn(tensor, c_state, h_state, 38 | Height, Width, num_filters, 39 | scale=4, kernal=[3, 3], act=act) 40 | 41 | with tf.variable_scope("layer_2"): 42 | layer = tfc.SignalConv2D( 43 | num_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 44 | use_bias=True, activation=tfc.GDN()) 45 | tensor = layer(tensor) 46 | 47 | with tf.variable_scope("layer_3"): 48 | layer = tfc.SignalConv2D( 49 | out_filters, (3, 3), corr=True, strides_down=2, padding="same_zeros", 50 | use_bias=True, activation=None) 51 | tensor = layer(tensor) 52 | 53 | return tensor, c_state_out, h_state_out 54 | 55 | 56 | def MV_synthesis(tensor, num_filters, Height, Width, c_state, h_state, act): 57 | """Builds the synthesis transform.""" 58 | 59 | with tf.variable_scope("MV_synthesis", reuse=tf.AUTO_REUSE): 60 | with tf.variable_scope("layer_0"): 61 | layer = tfc.SignalConv2D( 62 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 63 | use_bias=True, activation=tfc.GDN(inverse=True)) 64 | tensor = layer(tensor) 65 | 66 | with tf.variable_scope("layer_1"): 67 | layer = tfc.SignalConv2D( 68 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 69 | use_bias=True, activation=tfc.GDN(inverse=True)) 70 | tensor = layer(tensor) 71 | 72 | with tf.variable_scope("recurrent"): 73 | tensor, c_state_out, h_state_out = one_step_rnn(tensor, c_state, h_state, 74 | Height, Width, num_filters, 75 | scale=4, kernal=[3, 3], act=act) 76 | 77 | with tf.variable_scope("layer_2"): 78 | layer = tfc.SignalConv2D( 79 | num_filters, (3, 3), corr=False, strides_up=2, padding="same_zeros", 80 | use_bias=True, activation=tfc.GDN(inverse=True)) 81 | tensor = layer(tensor) 82 | 83 | with tf.variable_scope("layer_3"): 84 | layer = tfc.SignalConv2D( 85 | 2, (3, 3), corr=False, strides_up=2, padding="same_zeros", 86 | use_bias=True, activation=None) 87 | tensor = layer(tensor) 88 | 89 | return tensor, c_state_out, h_state_out 90 | 91 | 92 | def Res_analysis(tensor, num_filters, out_filters, Height, Width, c_state, h_state, act): 93 | """Builds the analysis transform.""" 94 | 95 | with tf.variable_scope("analysis", reuse=tf.AUTO_REUSE): 96 | with tf.variable_scope("layer_0"): 97 | layer = tfc.SignalConv2D( 98 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 99 | use_bias=True, activation=tfc.GDN()) 100 | tensor = layer(tensor) 101 | 102 | with tf.variable_scope("layer_1"): 103 | layer = tfc.SignalConv2D( 104 | num_filters, (5, 5), 
corr=True, strides_down=2, padding="same_zeros", 105 | use_bias=True, activation=tfc.GDN()) 106 | tensor = layer(tensor) 107 | 108 | with tf.variable_scope("recurrent"): 109 | tensor, c_state_out, h_state_out = one_step_rnn(tensor, c_state, h_state, 110 | Height, Width, num_filters, 111 | scale=4, kernal=[5, 5], act=act) 112 | 113 | with tf.variable_scope("layer_2"): 114 | layer = tfc.SignalConv2D( 115 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 116 | use_bias=True, activation=tfc.GDN()) 117 | tensor = layer(tensor) 118 | 119 | with tf.variable_scope("layer_3"): 120 | layer = tfc.SignalConv2D( 121 | out_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 122 | use_bias=True, activation=None) 123 | tensor = layer(tensor) 124 | 125 | return tensor, c_state_out, h_state_out 126 | 127 | 128 | def Res_synthesis(tensor, num_filters, Height, Width, c_state, h_state, act): 129 | """Builds the synthesis transform.""" 130 | 131 | with tf.variable_scope("synthesis", reuse=tf.AUTO_REUSE): 132 | with tf.variable_scope("layer_0"): 133 | layer = tfc.SignalConv2D( 134 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 135 | use_bias=True, activation=tfc.GDN(inverse=True)) 136 | tensor = layer(tensor) 137 | 138 | with tf.variable_scope("layer_1"): 139 | layer = tfc.SignalConv2D( 140 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 141 | use_bias=True, activation=tfc.GDN(inverse=True)) 142 | tensor = layer(tensor) 143 | 144 | with tf.variable_scope("recurrent"): 145 | tensor, c_state_out, h_state_out = one_step_rnn(tensor, c_state, h_state, 146 | Height, Width, num_filters, 147 | scale=4, kernal=[5, 5], act=act) 148 | 149 | with tf.variable_scope("layer_2"): 150 | layer = tfc.SignalConv2D( 151 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 152 | use_bias=True, activation=tfc.GDN(inverse=True)) 153 | tensor = layer(tensor) 154 | 155 | with tf.variable_scope("layer_3"): 156 | layer = tfc.SignalConv2D( 157 | 3, (5, 5), corr=False, strides_up=2, padding="same_zeros", 158 | use_bias=True, activation=None) 159 | tensor = layer(tensor) 160 | 161 | return tensor, c_state_out, h_state_out 162 | 163 | 164 | def rec_prob(tensor, num_filters, Height, Width, c_state, h_state, k=3, act=tf.tanh): 165 | 166 | with tf.variable_scope("CNN_input"): 167 | tensor = tf.expand_dims(tensor, axis=1) 168 | y1 = functions.recurrent_cnn(tensor, 1, layer=4, num_filters=num_filters, stride=1, 169 | out_filters=num_filters, kernel=[k, k], act=tf.nn.relu, act_last=None) 170 | y1 = tf.squeeze(y1, axis=1) 171 | 172 | with tf.variable_scope("RNN"): 173 | y2, c_state_out, h_state_out = one_step_rnn(y1, c_state, h_state, 174 | Height, Width, num_filters, 175 | scale=16, kernal=[k, k], act=act) 176 | 177 | with tf.variable_scope("CNN_output"): 178 | y2 = tf.expand_dims(y2, axis=1) 179 | y3 = functions.recurrent_cnn(y2, 1, layer=4, num_filters=num_filters, stride=1, 180 | out_filters=2 * num_filters, kernel=[k, k], act=tf.nn.relu, act_last=None) 181 | y3 = tf.squeeze(y3, axis=1) 182 | 183 | return y3, c_state_out, h_state_out 184 | 185 | 186 | def rec_prob_new(tensor, num_filters, Height, Width, c_state, h_state, k=3, act=tf.tanh): 187 | 188 | with tf.variable_scope("CNN_input"): 189 | tensor = tf.expand_dims(tensor, axis=1) 190 | y1 = functions.recurrent_cnn(tensor, 1, layer=4, num_filters=num_filters, stride=1, 191 | out_filters=num_filters, kernel=[k, k], act=tf.nn.relu, act_last=None) 192 | y1 = tf.squeeze(y1, axis=1) 193 | 194 | with 
tf.variable_scope("RNN"): 195 | y2, c_state_out, h_state_out = one_step_rnn(y1, c_state, h_state, 196 | Height, Width, num_filters, 197 | scale=16, kernal=[k, k], act=act) 198 | 199 | with tf.variable_scope("CNN_output_1"): 200 | y2 = tf.expand_dims(y2, axis=1) 201 | y3 = functions.recurrent_cnn(y2, 1, layer=4, num_filters=num_filters, stride=1, 202 | out_filters=num_filters, kernel=[k, k], act=tf.nn.relu, act_last=None) 203 | y3 = tf.squeeze(y3, axis=1) 204 | 205 | with tf.variable_scope("CNN_output_2"): 206 | 207 | y4 = functions.recurrent_cnn(y2, 1, layer=4, num_filters=num_filters, stride=1, 208 | out_filters=num_filters, kernel=[k, k], act=tf.nn.relu, act_last=None) 209 | y4 = tf.squeeze(y4, axis=1) 210 | 211 | y5 = tf.concat([y3, y4], axis=-1) 212 | 213 | return y5, c_state_out, h_state_out 214 | 215 | 216 | def bpp_est(x_target, sigma_mu, num_filters, tiny=1e-10): 217 | 218 | sigma, mu = tf.split(sigma_mu, [num_filters, num_filters], axis=-1) 219 | 220 | half = tf.constant(.5, dtype=tf.float32) 221 | 222 | upper = tf.math.add(x_target, half) 223 | lower = tf.math.add(x_target, -half) 224 | 225 | sig = tf.maximum(sigma, -7.0) 226 | upper_l = tf.sigmoid(tf.multiply((upper - mu), (tf.exp(-sig) + tiny))) 227 | lower_l = tf.sigmoid(tf.multiply((lower - mu), (tf.exp(-sig) + tiny))) 228 | p_element = upper_l - lower_l 229 | p_element = tf.clip_by_value(p_element, tiny, 1 - tiny) 230 | 231 | ent = -tf.log(p_element) / tf.log(2.0) 232 | bits = tf.math.reduce_sum(ent) 233 | 234 | return bits, sigma, mu -------------------------------------------------------------------------------- /Extrapolation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import rec_exp as nn 3 | import numpy as np 4 | import argparse 5 | import imageio 6 | # from func import * 7 | import os 8 | 9 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 10 | 11 | # config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0}) 12 | config = tf.ConfigProto(allow_soft_placement=True) 13 | sess = tf.Session(config=config) 14 | 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--motion", type=str, default='flow_mc', choices=['flow', 'flow_mc', 'sepconv']) 18 | parser.add_argument("--input_norm", type=int, default=0, choices=[0, 1]) 19 | parser.add_argument("--ens", type=int, default=0, choices=[0, 1]) 20 | parser.add_argument("--path", type=str) 21 | parser.add_argument("--idx", type=int, default=3) 22 | parser.add_argument("--l", type=int, default=1024, choices=[256, 512, 1024, 2048]) 23 | parser.add_argument("--dirc", type=str, default='fw', choices=['fw', 'bw']) 24 | args = parser.parse_args() 25 | 26 | os.makedirs(args.path + '/extra_states', exist_ok=True) 27 | 28 | if args.dirc == 'fw': 29 | frame_1 = imageio.imread(args.path + 'f' + str(args.idx - 2).zfill(3) + '.png').astype(np.float32) / 255.0 30 | frame_2 = imageio.imread(args.path + 'f' + str(args.idx - 1).zfill(3) + '.png').astype(np.float32) / 255.0 31 | else: 32 | frame_1 = imageio.imread(args.path + 'f' + str(args.idx + 2).zfill(3) + '.png').astype(np.float32) / 255.0 33 | frame_2 = imageio.imread(args.path + 'f' + str(args.idx + 1).zfill(3) + '.png').astype(np.float32) / 255.0 34 | 35 | batch_size = 1 36 | Height = frame_1.shape[0] 37 | Width = frame_1.shape[1] 38 | 39 | frame_1 = np.expand_dims(frame_1, axis=0) 40 | frame_2 = np.expand_dims(frame_2, axis=0) 41 | 42 | if not os.path.exists(args.path + 
'/extra_states/state_enc_1.npy'): 43 | state_enc_1 = np.zeros([batch_size, Height // 4, Width // 4, 128], dtype=np.float32) 44 | state_enc_2 = np.zeros([batch_size, Height // 4, Width // 4, 128], dtype=np.float32) 45 | state_dec_1 = np.zeros([batch_size, Height // 4, Width // 4, 128], dtype=np.float32) 46 | state_dec_2 = np.zeros([batch_size, Height // 4, Width // 4, 128], dtype=np.float32) 47 | state_fea_1 = np.zeros([batch_size, Height // 16, Width // 16, 512], dtype=np.float32) 48 | state_fea_2 = np.zeros([batch_size, Height // 16, Width // 16, 512], dtype=np.float32) 49 | 50 | if args.motion == 'flow' or args.motion == 'flow_mc': 51 | flow = np.zeros([batch_size, Height, Width, 4], dtype=np.float32) 52 | 53 | else: 54 | 55 | state_enc_1 = np.load(args.path + '/extra_states/state_enc_1.npy') 56 | state_enc_2 = np.load(args.path + '/extra_states/state_enc_2.npy') 57 | state_dec_1 = np.load(args.path + '/extra_states/state_dec_1.npy') 58 | state_dec_2 = np.load(args.path + '/extra_states/state_dec_2.npy') 59 | state_fea_1 = np.load(args.path + '/extra_states/state_fea_1.npy') 60 | state_fea_2 = np.load(args.path + '/extra_states/state_fea_2.npy') 61 | 62 | if args.motion == 'flow' or args.motion == 'flow_mc': 63 | flow = np.load(args.path + '/extra_states/pre_flow.npy') 64 | 65 | state_encoder = tf.nn.rnn_cell.LSTMStateTuple(state_enc_1, state_enc_2) 66 | state_decoder = tf.nn.rnn_cell.LSTMStateTuple(state_dec_1, state_dec_2) 67 | state_feature = tf.nn.rnn_cell.LSTMStateTuple(state_fea_1, state_fea_2) 68 | 69 | # if args.motion == 'flow' or args.motion == 'flow_mc': 70 | frame_input = tf.concat([frame_1, frame_2, flow], axis=-1) 71 | frame_output, state_encoder, state_decoder, state_feature, flow \ 72 | = nn.get_network_pp(frame_input, state_encoder, state_decoder, state_feature, args.motion, args.input_norm) 73 | 74 | s_enc_1, s_enc_2 = state_encoder 75 | s_dec_1, s_dec_2 = state_decoder 76 | s_fea_1, s_fea_2 = state_feature 77 | 78 | saver = tf.train.Saver(max_to_keep=None) 79 | save_root = './model/Extrapolation' 80 | save_path = save_root + '/lambda_' + str(args.l) + '_extra/' 81 | # latest_model = tf.train.latest_checkpoint(checkpoint_dir=save_path) 82 | print("\033[31m" + save_path + "\033[0m") 83 | if os.path.exists(save_path + 'model.ckpt.index'): 84 | saver.restore(sess, save_path + 'model.ckpt') 85 | else: 86 | saver.restore(sess, save_path + 'model.ckpt-150000') 87 | 88 | frame_out, state_enc_1, state_enc_2, \ 89 | state_dec_1, state_dec_2, state_fea_1, state_fea_2, pre_flow \ 90 | = sess.run([frame_output, s_enc_1, s_enc_2, s_dec_1, s_dec_2, s_fea_1, s_fea_2, flow]) 91 | 92 | # np.save(args.path + 'f' + str(args.idx).zfill(3) + '_extra.npy', frame_out) 93 | frame_out = np.uint8(np.round(np.clip(frame_out, 0, 1) * 255.0)) 94 | imageio.imwrite(args.path + 'f' + str(args.idx).zfill(3) + '_extra.png', frame_out[0]) 95 | 96 | np.save(args.path + '/extra_states/state_enc_1.npy', state_enc_1) 97 | np.save(args.path + '/extra_states/state_enc_2.npy', state_enc_2) 98 | np.save(args.path + '/extra_states/state_dec_1.npy', state_dec_1) 99 | np.save(args.path + '/extra_states/state_dec_2.npy', state_dec_2) 100 | np.save(args.path + '/extra_states/state_fea_1.npy', state_fea_1) 101 | np.save(args.path + '/extra_states/state_fea_2.npy', state_fea_2) 102 | np.save(args.path + '/extra_states/pre_flow.npy', pre_flow) 103 | 104 | -------------------------------------------------------------------------------- /Interpolation_Compression.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import helper 5 | from functions_inter import * 6 | from scipy import misc 7 | import inv_flow 8 | import mc_func 9 | 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 11 | 12 | config = tf.ConfigProto(allow_soft_placement=True) 13 | sess = tf.Session(config=config) 14 | 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--path", default='BasketballPass') 18 | parser.add_argument("--frame", type=int, default=14) 19 | parser.add_argument("--f_P", type=int, default=5) 20 | parser.add_argument("--inter", type=int, default=2) 21 | parser.add_argument("--b_P", type=int, default=5) 22 | parser.add_argument("--mode", default='PSNR', choices=['PSNR', 'MS-SSIM']) 23 | parser.add_argument("--metric", default='PSNR', choices=['PSNR', 'MS-SSIM']) 24 | parser.add_argument("--VTM", type=int, default=1, choices=[0, 1]) 25 | parser.add_argument("--python_path", default='python') 26 | parser.add_argument("--l", type=int, default=512, choices=[256, 512, 1024, 2048]) 27 | parser.add_argument("--N", type=int, default=128, choices=[128]) 28 | parser.add_argument("--M", type=int, default=128, choices=[128]) 29 | args = parser.parse_args() 30 | 31 | path_root = './' 32 | path_raw = args.path + '/' 33 | 34 | # Settings 35 | I_level, Height, Width, batch_size, Channel, \ 36 | activation, GOP_size, GOP_num, \ 37 | path, path_com, path_bin, path_lat = helper.configure(args, path_root=path_root, path_raw=path_raw) 38 | 39 | 40 | # Placeholder 41 | data_tensor = tf.placeholder(tf.float32, [batch_size, 5, Height, Width, Channel]) 42 | inter_num = tf.placeholder(tf.float32, []) 43 | 44 | [frame_left, frame_0, frame_t, frame_1, frame_right] = tf.unstack(data_tensor, axis=1) 45 | 46 | def q_flow(flow1, flow2, t): 47 | 48 | T = inter_num + 1 49 | 50 | a = (2 * T * flow1 + 2 * flow2)/(T + T ** 2) 51 | v0 = (-(T ** 2) * flow1 + flow2)/(T + T ** 2) 52 | 53 | return 0.5 * a * (t ** 2) + v0 * t 54 | 55 | def parametric_relu(_x, name='alpha'): 56 | alphas = tf.get_variable(name, _x.get_shape()[-1], 57 | initializer=tf.constant_initializer(0.25), 58 | dtype=tf.float32) 59 | pos = tf.nn.relu(_x) 60 | neg = alphas * (_x - abs(_x)) * 0.5 61 | 62 | return pos + neg 63 | 64 | with tf.variable_scope("flow_motion_short", reuse=tf.AUTO_REUSE): 65 | flow_0left, _, _, _, _, _ = inv_flow.optical_flow(frame_left, frame_0, batch_size, Height, Width) 66 | with tf.variable_scope("flow_motion_long", reuse=tf.AUTO_REUSE): 67 | flow_01, _, _, _, _, _ = inv_flow.optical_flow(frame_1, frame_0, batch_size, Height, Width) 68 | with tf.variable_scope("flow_motion_short", reuse=tf.AUTO_REUSE): 69 | flow_1right, _, _, _, _, _ = inv_flow.optical_flow(frame_right, frame_1, batch_size, Height, Width) 70 | with tf.variable_scope("flow_motion_long", reuse=tf.AUTO_REUSE): 71 | flow_10, _, _, _, _, _ = inv_flow.optical_flow(frame_0, frame_1, batch_size, Height, Width) 72 | 73 | flow_0t = q_flow(flow_0left, flow_01, t=1) 74 | flow_1t = q_flow(flow_1right, flow_10, t=inter_num) 75 | 76 | flow_t0 = inv_flow.reverse_flow(flow_0t, Height, Width) 77 | flow_t1 = inv_flow.reverse_flow(flow_1t, Height, Width) 78 | 79 | t_warp0 = tf.contrib.image.dense_image_warp(frame_0, flow_t0) 80 | t_warp1 = tf.contrib.image.dense_image_warp(frame_1, flow_t1) 81 | 82 | with tf.variable_scope('flow_refine'): 83 | refine_input = tf.concat([frame_0, frame_1, flow_01, flow_10, 
flow_t0, flow_t1, t_warp0, t_warp1], axis=-1) 84 | refine_output, feature = mc_func.refine_net(refine_input, out_channel=8) 85 | 86 | refine_1, refine_2, refine_3, refine_4 = tf.split(refine_output, [2, 2, 2, 2], axis=-1) 87 | 88 | flow_t0_refine = tf.contrib.image.dense_image_warp(flow_t0, 10 * tf.tanh(refine_1)) + refine_2 89 | flow_t1_refine = tf.contrib.image.dense_image_warp(flow_t1, 10 * tf.tanh(refine_3)) + refine_4 90 | 91 | frame_t_warp0 = tf.contrib.image.dense_image_warp(frame_0, flow_t0_refine) 92 | frame_t_warp1 = tf.contrib.image.dense_image_warp(frame_1, flow_t1_refine) 93 | 94 | with tf.variable_scope("masknet"): 95 | mask_input = tf.concat([frame_t_warp0, frame_t_warp1, feature], axis=-1) 96 | tensor_mask = tf.layers.conv2d(inputs=mask_input, filters=32, kernel_size=5, strides=1, activation=parametric_relu, padding='same') 97 | tensor_mask = tf.layers.conv2d(inputs=tensor_mask, filters=16, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 98 | tensor_mask = tf.sigmoid(tf.layers.conv2d(inputs=tensor_mask, filters=2, kernel_size=3, strides=1, padding='same')) 99 | 100 | frame_t_warp = tf.div_no_nan((frame_t_warp0 * tensor_mask[:, :, :, 0:1] * inter_num + frame_t_warp1 * tensor_mask[:, :, :, 1:2]), 101 | (tensor_mask[:, :, :, 0:1] * inter_num + tensor_mask[:, :, :, 1:2])) 102 | 103 | with tf.variable_scope('post'): 104 | 105 | input_to_post = tf.concat([frame_0, frame_1, frame_t_warp, tensor_mask, frame_t_warp0, frame_t_warp1, flow_t0_refine, flow_t1_refine], axis=-1) 106 | output = frame_t_warp + mc_func.MC_RLVC(input_to_post) 107 | 108 | entropy_mv = tfc.EntropyBottleneck() 109 | entropy_res = tfc.EntropyBottleneck() 110 | 111 | frame_t_com, mse_loss, psnr_loss, bpp_loss, flow_lat, res_lat = \ 112 | DVC_compress(output, frame_t, entropy_mv, entropy_res, batch_size, Height, Width, args, training=False) 113 | 114 | psnr_value = np.load(path_bin + 'quality.npy') 115 | bpp_value = np.load(path_bin + 'bpp.npy') 116 | 117 | sess.run(tf.global_variables_initializer()) 118 | all_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 119 | saver = tf.train.Saver(max_to_keep=None, var_list=all_var) 120 | model_path = './model/Interpolation/lambda_' + str(args.l) + '_inter_' + str(args.inter) + '/' 121 | saver.restore(sess, save_path=model_path + 'model.ckpt-200000') 122 | 123 | # encode GOPs 124 | for g in range(GOP_num): 125 | 126 | F_left = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P).zfill(3) + '.png').astype(float) 127 | F_0 = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + 1).zfill(3) + '.png').astype(float) 128 | F_1 = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + args.inter + 2).zfill(3) + '.png').astype(float) 129 | F_right = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + args.inter + 3).zfill(3) + '.png').astype(float) 130 | 131 | G_t = misc.imread(path_raw + 'f' + str(g * GOP_size + args.f_P + 2).zfill(3) + '.png').astype(float) 132 | 133 | input_data = np.stack([F_left, F_0, G_t, F_1, F_right], axis=0) 134 | input_data = np.expand_dims(input_data/255.0, axis=0) 135 | 136 | psnr, bpp, F_t, mv_latent, res_latent \ 137 | = sess.run([psnr_loss, bpp_loss, frame_t_com, flow_lat, res_lat], 138 | feed_dict={data_tensor:input_data, inter_num:args.inter}) 139 | 140 | F_t = np.clip(F_t, 0, 1) 141 | F_t = np.uint8(F_t * 255.0) 142 | 143 | psnr_value[g * GOP_size + args.f_P + 1] = psnr 144 | bpp_value[g * GOP_size + args.f_P + 1] = bpp 145 | 146 | print('Frame', g * GOP_size + args.f_P + 2, args.metric + ' =', 
psnr) 147 | 148 | misc.imsave(path_com + 'f' + str(g * GOP_size + args.f_P + 2).zfill(3) + '.png', F_t[0]) 149 | np.save(path_lat + 'f' + str(g * GOP_size + args.f_P + 2).zfill(3) + '_mv.npy', mv_latent) 150 | np.save(path_lat + 'f' + str(g * GOP_size + args.f_P + 2).zfill(3) + '_res.npy', res_latent) 151 | 152 | sess.run(tf.global_variables_initializer()) 153 | all_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 154 | saver = tf.train.Saver(max_to_keep=None, var_list=all_var) 155 | model_path = './model/Interpolation/lambda_' + str(args.l) + '_inter_' + str(args.inter - 1) + '/' 156 | saver.restore(sess, save_path=model_path + 'model.ckpt-200000') 157 | 158 | # encode GOPs 159 | for g in range(GOP_num): 160 | 161 | # print(GOP_size, g * GOP_size + args.f_P + args.inter + 3) 162 | F_left = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + 1).zfill(3) + '.png').astype(float) 163 | F_0 = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + 2).zfill(3) + '.png').astype(float) 164 | F_1 = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + args.inter + 2).zfill(3) + '.png').astype(float) 165 | F_right = misc.imread(path_com + 'f' + str(g * GOP_size + args.f_P + args.inter + 3).zfill(3) + '.png').astype(float) 166 | 167 | G_t = misc.imread(path_raw + 'f' + str(g * GOP_size + args.f_P + args.inter + 1).zfill(3) + '.png').astype(float) 168 | 169 | input_data = np.stack([F_left, F_0, G_t, F_1, F_right], axis=0) 170 | input_data = np.expand_dims(input_data / 255.0, axis=0) 171 | 172 | psnr, bpp, F_t, mv_latent, res_latent\ 173 | = sess.run([psnr_loss, bpp_loss, frame_t_com, flow_lat, res_lat], 174 | feed_dict={data_tensor: input_data, inter_num:args.inter - 1}) 175 | 176 | F_t = np.clip(F_t, 0, 1) 177 | F_t = np.uint8(F_t * 255.0) 178 | 179 | psnr_value[g * GOP_size + args.f_P + args.inter] = psnr 180 | bpp_value[g * GOP_size + args.f_P + args.inter] = bpp 181 | 182 | print('Frame', g * GOP_size + args.f_P + args.inter + 1, args.metric + ' =', psnr) 183 | 184 | misc.imsave(path_com + 'f' + str(g * GOP_size + args.f_P + args.inter + 1).zfill(3) + '.png', F_t[0]) 185 | np.save(path_lat + 'f' + str(g * GOP_size + args.f_P + args.inter + 1).zfill(3) + '_mv.npy', mv_latent) 186 | np.save(path_lat + 'f' + str(g * GOP_size + args.f_P + args.inter + 1).zfill(3) + '_res.npy', res_latent) 187 | 188 | print('Average: ' + args.path, np.average(psnr_value), np.average(bpp_value)) 189 | 190 | np.save(path_bin + '/quality.npy', psnr_value) 191 | np.save(path_bin + '/bpp.npy', bpp_value) 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /MC_network_inter.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def resblock(input, IC, OC, name): 5 | 6 | l1 = tf.nn.relu(input, name=name + 'relu1') 7 | 8 | l1 = tf.layers.conv2d(inputs=l1, filters=np.minimum(IC, OC), kernel_size=3, strides=1, padding='same', 9 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l1') 10 | 11 | l2 = tf.nn.relu(l1, name='relu2') 12 | 13 | l2 = tf.layers.conv2d(inputs=l2, filters=OC, kernel_size=3, strides=1, padding='same', 14 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l2') 15 | 16 | if IC != OC: 17 | input = tf.layers.conv2d(inputs=input, filters=OC, kernel_size=1, strides=1, padding='same', 18 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'map') 19 | 
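# When the input and output channel counts differ (IC != OC), the 1x1 convolution above
# projects the identity branch to OC channels so that the element-wise addition below
# ('input + l2') is shape-compatible; otherwise the input is added to the residual branch unchanged.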
20 | return input + l2 21 | 22 | 23 | def MC(input): 24 | 25 | m1 = tf.layers.conv2d(inputs=input, filters=64, kernel_size=3, strides=1, padding='same', 26 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc1') 27 | 28 | m2 = resblock(m1, 64, 64, name='mc2') 29 | 30 | m3 = tf.layers.average_pooling2d(m2, pool_size=2, strides=2, padding='same') 31 | 32 | m4 = resblock(m3, 64, 64, name='mc4') 33 | 34 | m5 = tf.layers.average_pooling2d(m4, pool_size=2, strides=2, padding='same') 35 | 36 | m6 = resblock(m5, 64, 64, name='mc6') 37 | 38 | m7 = resblock(m6, 64, 64, name='mc7') 39 | 40 | m8 = tf.image.resize_images(m7, [2 * tf.shape(m7)[1], 2 * tf.shape(m7)[2]]) 41 | 42 | m8 = m4 + m8 43 | 44 | m9 = resblock(m8, 64, 64, name='mc9') 45 | 46 | m10 = tf.image.resize_images(m9, [2 * tf.shape(m9)[1], 2 * tf.shape(m9)[2]]) 47 | 48 | m10 = m2 + m10 49 | 50 | m11 = resblock(m10, 64, 64, name='mc11') 51 | 52 | m12 = tf.layers.conv2d(inputs=m11, filters=64, kernel_size=3, strides=1, padding='same', 53 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc12') 54 | 55 | m12 = tf.nn.relu(m12, name='relu12') 56 | 57 | m13 = tf.layers.conv2d(inputs=m12, filters=3, kernel_size=3, strides=1, padding='same', 58 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc13') 59 | 60 | return m13 61 | 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Our other works on learned video compression: 2 | 3 | - Perceptual Learned Video Compression (PLVC) (**IJCAI 2022**) [[Paper](https://arxiv.org/abs/2109.03082)] [[Codes](https://github.com/RenYang-home/PLVC)] 4 | 5 | - Hierarchical Learned Video Compression (HLVC) (**CVPR 2020**) [[Paper](https://arxiv.org/abs/2003.01966)] [[Codes](https://github.com/RenYang-home/HLVC)] 6 | 7 | - Recurrent Learned Video Compression (RLVC) (**IEEE J-STSP 2021**) [[Paper](https://ieeexplore.ieee.org/abstract/document/9288876)] [[Codes](https://github.com/RenYang-home/RLVC)] 8 | 9 | - OpenDVC: An open source implementation of DVC [[Codes](https://github.com/RenYang-home/OpenDVC)] [[Technical report](https://arxiv.org/abs/2006.15862)] 10 | 11 | # Advancing Learned Video Compression with In-loop Frame Prediction 12 | 13 | The project page for the paper: 14 | 15 | > Ren Yang, Radu Timofte and Luc Van Gool, "Advancing Learned Video Compression with In-loop Frame Prediction", IEEE Transactions on Circuits and Systems for Video Technology (IEEE T-CSVT), 2022. [[Paper]](https://ieeexplore.ieee.org/abstract/document/9950550) 16 | 17 | If our paper and codes are useful for your research, please cite: 18 | ``` 19 | @article{yang2022advancing, 20 | title={Advancing Learned Video Compression with In-loop Frame Prediction}, 21 | author={Yang, Ren and Timofte, Radu and Van Gool, Luc}, 22 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 23 | year={2022}, 24 | publisher={IEEE} 25 | } 26 | ``` 27 | 28 | If you have questions or find bugs, please contact: 29 | 30 | Ren Yang @ ETH Zurich, Switzerland 31 | 32 | Email: r.yangchn@gmail.com 33 | 34 | ## Codes 35 | 36 | ### Preparation 37 | 38 | We feed RGB images into our encoder. To compress a YUV video, please first convert it to PNG images with the following command.
39 | 40 | ``` 41 | ffmpeg -pix_fmt yuv420p -s WidthxHeight -i Name.yuv -vframes Frame path_to_PNG/f%03d.png 42 | ``` 43 | 44 | Note that our codes currently only support frames whose height and width are multiples of 16. Therefore, if the height and width of the frames are not multiples of 16, please first crop the frames, e.g., 45 | 46 | ``` 47 | ffmpeg -pix_fmt yuv420p -s 1920x1080 -i Name.yuv -vframes Frame -filter:v "crop=1920:1072:0:0" path_to_PNG/f%03d.png 48 | ``` 49 | 50 | We provide a prepared sequence *BasketballPass* here as a test demo, which contains the PNG files of the first 100 frames. 51 | 52 | ### Dependency 53 | 54 | - Tensorflow 1.12 55 | 56 | (*Since we train the models on tensorflow-compression 1.0, which is only compatible with tf 1.12, the pre-trained models are not compatible with higher versions.*) 57 | 58 | - Tensorflow-compression 1.0 ([Download link](https://github.com/tensorflow/compression/releases/tag/v1.0)) 59 | 60 | (*After downloading, put the folder "tensorflow_compression" in the same directory as the codes.*) 61 | 62 | - SciPy 1.2.0 63 | 64 | (*Since we use misc.imread, do not use higher versions in which misc.imread is removed.*) 65 | 66 | - Pre-trained models ([Download link](https://drive.google.com/file/d/1WJo_VkyG4qRRyGox_R1ip0l4LMcrIQur/view?usp=sharing)) 67 | 68 | (*Download the folder "model" into the same directory as the codes.*) 69 | 70 | - VTM ([Download link](https://vcgit.hhi.fraunhofer.de/jvet/VVCSoftware_VTM)) 71 | 72 | (*In our PSNR model, we use VVC to compress I-frames. Please compile VTM and put the folder "VVCSoftware_VTM" in the same directory as the codes.*) 73 | 74 | ### Test code 75 | 76 | The arguments of the ALVC test code (ALVC.py) include: 77 | 78 | ``` 79 | --path, the path to PNG files; 80 | 81 | --l, lambda value. The pre-trained PSNR models are trained with 4 lambda values, i.e., 256, 512, 1024 and 2048, with increasing bit-rate/PSNR.
82 | ``` 83 | For example: 84 | ``` 85 | python ALVC.py --path BasketballPass 86 | ``` 87 | -------------------------------------------------------------------------------- /Recurrent_AutoEncoder_Extrapolation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_compression as tfc 5 | import os 6 | from scipy import misc 7 | import CNN_recurrent 8 | import motion 9 | import functions 10 | import helper2 11 | # import flow_vis 12 | 13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 14 | 15 | config = tf.ConfigProto(allow_soft_placement=True) 16 | sess = tf.Session(config=config) 17 | 18 | parser = argparse.ArgumentParser( 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument("--path", default='BasketballPass') 21 | parser.add_argument("--frame", type=int, default=100) 22 | parser.add_argument("--f_P", type=int, default=5) 23 | parser.add_argument("--inter", type=int, default=2) 24 | parser.add_argument("--b_P", type=int, default=5) 25 | parser.add_argument("--mode", default='PSNR', choices=['PSNR']) 26 | parser.add_argument("--metric", default='PSNR', choices=['PSNR']) 27 | parser.add_argument("--VTM", type=int, default=1, choices=[0, 1]) 28 | parser.add_argument("--l", type=int, default=512, choices=[256, 512, 1024, 2048]) 29 | parser.add_argument("--N", type=int, default=128, choices=[128]) 30 | parser.add_argument("--M", type=int, default=128, choices=[128]) 31 | 32 | args = parser.parse_args() 33 | 34 | # Settings 35 | I_level, Height, Width, batch_size, Channel, \ 36 | activation, GOP_size, GOP_num, \ 37 | path, path_com, path_bin, path_lat = helper2.configure(args) 38 | 39 | if os.path.exists(path_com + '/extra_states'): 40 | os.system('rm -r ' + path_com + '/extra_states') 41 | 42 | # Placeholder 43 | Y0_com_tensor = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel]) # reference frame 44 | Y1_raw_tensor = tf.placeholder(tf.float32, [batch_size, Height, Width, Channel]) # raw frame to compress 45 | 46 | hidden_states = tf.placeholder(tf.float32, [8, batch_size, Height//4, Width//4, args.N]) # hidden states in RAE 47 | 48 | c_enc_mv, h_enc_mv, \ 49 | c_dec_mv, h_dec_mv, \ 50 | c_enc_res, h_enc_res, \ 51 | c_dec_res, h_dec_res = tf.split(hidden_states, 8, axis=0) 52 | 53 | RPM_flag = tf.placeholder(tf.bool, []) # use RPM (=1) or bottleneck (=0) 54 | 55 | # motion estimation 56 | with tf.variable_scope("flow_motion"): 57 | motion_tensor, _, _, _, _, _ = motion.optical_flow(Y0_com_tensor, Y1_raw_tensor, batch_size, Height, Width) 58 | 59 | # RAE encoder for motion 60 | motion_latent, c_enc_mv_out, h_enc_mv_out = CNN_recurrent.MV_analysis(motion_tensor, num_filters=args.N, out_filters=args.M, 61 | Height=Height, Width=Width, 62 | c_state=c_enc_mv[0], h_state=h_enc_mv[0], act=activation) 63 | 64 | # encode the latent of the first P frame by the bottleneck 65 | entropy_bottleneck = tfc.EntropyBottleneck(name='entropy_bottleneck') 66 | string = tf.squeeze(entropy_bottleneck.compress(motion_latent), axis=0) 67 | motion_latent_decom = entropy_bottleneck.decompress(tf.expand_dims(string, 0), [Height//16, Width//16, args.M], channels=args.M) 68 | motion_latent_hat = tf.cond(RPM_flag, lambda: tf.round(motion_latent), lambda: motion_latent_decom) 69 | 70 | # RAE decoder for motion 71 | motion_hat, c_dec_mv_out, h_dec_mv_out = CNN_recurrent.MV_synthesis(motion_latent_hat, num_filters=args.N, 72 | Height=Height, Width=Width, 73 | 
c_state=c_dec_mv[0], h_state=h_dec_mv[0], act=activation) 74 | 75 | # Motion Compensation 76 | Y1_warp = tf.contrib.image.dense_image_warp(Y0_com_tensor, motion_hat) 77 | # Y1_warp = warp.dense_image_warp(Y0_com_tensor, motion_hat) 78 | 79 | MC_input = tf.concat([motion_hat, Y0_com_tensor, Y1_warp], axis=-1) 80 | Y1_MC = functions.MC_RLVC(MC_input) 81 | 82 | # Get residual 83 | Res = Y1_raw_tensor - Y1_MC 84 | 85 | # RAE encoder for residual 86 | res_latent, c_enc_res_out, h_enc_res_out = CNN_recurrent.Res_analysis(Res, num_filters=args.N, out_filters=args.M, 87 | Height=Height, Width=Width, 88 | c_state=c_enc_res[0], h_state=h_enc_res[0], act=activation) 89 | 90 | # encode the latent of the first P frame by the bottleneck 91 | entropy_bottleneck2 = tfc.EntropyBottleneck(name='entropy_bottleneck_1_1') 92 | string2 = entropy_bottleneck2.compress(res_latent) 93 | string2 = tf.squeeze(string2, axis=0) 94 | res_latent_decom = entropy_bottleneck2.decompress(tf.expand_dims(string2, 0), [Height//16, Width//16, args.M], channels=args.M) 95 | res_latent_hat = tf.cond(RPM_flag, lambda: tf.round(res_latent), lambda: res_latent_decom) 96 | 97 | # RAE decoder for residual 98 | res_hat, c_dec_res_out, h_dec_res_out = CNN_recurrent.Res_synthesis(res_latent_hat, num_filters=args.N, 99 | Height=Height, Width=Width, 100 | c_state=c_dec_res[0], h_state=h_dec_res[0], act=activation) 101 | 102 | # reconstructed frame 103 | Y1_decoded = tf.clip_by_value(res_hat + Y1_MC, 0, 1) 104 | 105 | # output hidden states 106 | hidden_states_out = tf.stack([c_enc_mv_out, h_enc_mv_out, 107 | c_dec_mv_out, h_dec_mv_out, 108 | c_enc_res_out, h_enc_res_out, 109 | c_dec_res_out, h_dec_res_out], axis=0) 110 | 111 | # PANR or MS-SSIM 112 | # if args.metric == 'PSNR': 113 | mse = tf.reduce_mean(tf.squared_difference(Y1_decoded, Y1_raw_tensor)) 114 | quality_tensor = 10.0*tf.log(1.0/mse)/tf.log(10.0) 115 | mse_mc = tf.reduce_mean(tf.squared_difference(Y1_MC, Y1_raw_tensor)) 116 | psnr_mc = 10.0 * tf.log(1.0 / mse_mc) / tf.log(10.0) 117 | # elif args.metric == 'MS-SSIM': 118 | # quality_tensor = tf.math.reduce_mean(tf.image.ssim_multiscale(Y1_decoded, Y1_raw_tensor, max_val=1)) 119 | # mse_mc = tf.reduce_mean(tf.squared_difference(Y1_MC, Y1_raw_tensor)) 120 | # psnr_mc = 10.0 * tf.log(1.0 / mse_mc) / tf.log(10.0) 121 | 122 | # load model 123 | saver = tf.train.Saver(max_to_keep=None) 124 | model_path = './model/Extrapolation/lambda_' + str(args.l) + '_extra' 125 | saver.restore(sess, save_path=model_path + '/model.ckpt-150000') 126 | 127 | 128 | # init quality 129 | quality_frame = np.zeros([args.frame]) 130 | 131 | # encode the first I frame 132 | frame_index = 1 133 | quality = helper2.encode_I(args, frame_index, I_level, path, path_com, path_bin) 134 | quality_frame[frame_index - 1] = quality 135 | 136 | if os.path.exists(path_com + '/extra_states'): 137 | os.system('rm -r ' + path_com + '/extra_states') 138 | 139 | # encode GOPs 140 | for g in range(GOP_num): 141 | 142 | np.save(path_bin + 'quality.npy', quality_frame) 143 | 144 | # forward P frames 145 | 146 | # load I frame (compressed) 147 | frame_index = g * GOP_size + 1 148 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 149 | F0_com = np.expand_dims(F0_com, axis=0) 150 | 151 | for f in range(args.f_P): 152 | 153 | # load P frame (raw) 154 | frame_index = g * GOP_size + f + 2 155 | F1_raw = misc.imread(path + 'f' + str(frame_index).zfill(3) + '.png') 156 | F1_raw = np.expand_dims(F1_raw, axis=0) 157 | 158 | # init hidden states 159 | if f % 5 == 
0: 160 | h_state = np.zeros([8, batch_size, Height // 4, Width // 4, args.N], dtype=np.float) 161 | # since the model is optimized on 6 frames, we reset hidden states every 6 P frames 162 | 163 | if f == 0: 164 | flag = False 165 | # the first P frame uses bottleneck 166 | else: 167 | flag = True 168 | 169 | if f >= 1: 170 | # python_cpu = '' 171 | os.system( 172 | 'python Extrapolation.py --path ' + path_com + ' --idx ' + str(frame_index) + ' --l ' + str(args.l)) 173 | F_ref = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '_extra.png').astype(float) 174 | F_ref = np.expand_dims(F_ref, axis=0) 175 | # F_ref = np.load(path_com + 'f' + str(frame_index).zfill(3) + '_extra.npy') * 255.0 176 | mse = np.mean(np.power(np.subtract(F_ref / 255.0, F1_raw / 255.0), 2.0)) 177 | quality = 10 * np.log10(1.0 / mse) 178 | 179 | # print('Frame', frame_index, args.metric + '_extra =', quality) 180 | 181 | else: 182 | F_ref = F0_com 183 | 184 | # run RAE 185 | F0_com, string_MV, string_Res, quality, h_state, latent_mv, latent_res, psnr_mc_value, motion_map \ 186 | = sess.run([Y1_decoded, string, string2, quality_tensor, 187 | hidden_states_out, motion_latent_hat, res_latent_hat, psnr_mc, motion_tensor], 188 | feed_dict={Y0_com_tensor: F_ref / 255.0, Y1_raw_tensor: F1_raw / 255.0, 189 | hidden_states: h_state, RPM_flag: flag}) 190 | F0_com = F0_com * 255 191 | # flow_color = flow_vis.flow_to_color(motion_map[0], convert_to_bgr=False) 192 | # save bottleneck bitstream 193 | if not flag: 194 | with open(path_bin + '/f' + str(frame_index).zfill(3) + '.bin', "wb") as ff: 195 | ff.write(np.array(len(string_MV), dtype=np.uint16).tobytes()) 196 | ff.write(string_MV) 197 | ff.write(string_Res) 198 | 199 | # save compressed frame and latents 200 | misc.imsave(path_com + '/f' + str(frame_index).zfill(3) + '.png', np.uint8(np.round(F0_com[0]))) 201 | # misc.imsave(path_com + '/f' + str(frame_index).zfill(3) + '_flow.png', flow_color) 202 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_mv.npy', latent_mv) 203 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_res.npy', latent_res) 204 | 205 | quality_frame[frame_index - 1] = quality 206 | 207 | print('Frame', frame_index, args.metric + ' =', quality) 208 | 209 | os.system('rm -r ' + path_com + '/extra_states') 210 | 211 | # encode the next I frame 212 | frame_index = (g + 1) * GOP_size + 1 213 | quality = helper2.encode_I(args, frame_index, I_level, path, path_com, path_bin) 214 | quality_frame[frame_index - 1] = quality 215 | 216 | # backward P frames 217 | 218 | # load I frame (compressed) 219 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 220 | F0_com = np.expand_dims(F0_com, axis=0) 221 | 222 | for f in range(args.b_P): 223 | 224 | # load P frame (raw) 225 | frame_index = (g + 1) * GOP_size - f 226 | F1_raw = misc.imread(path + 'f' + str(frame_index).zfill(3) + '.png') 227 | F1_raw = np.expand_dims(F1_raw, axis=0) 228 | 229 | # init hidden states 230 | if f % 5 == 0: 231 | h_state = np.zeros([8, batch_size, Height // 4, Width // 4, args.N], dtype=np.float) 232 | # since the model is optimized on 6 frames, we reset hidden states every 6 P frames 233 | 234 | if f == 0: 235 | flag = False 236 | # the first P frame uses bottleneck 237 | else: 238 | flag = True 239 | 240 | if f >= 1: 241 | # python_cpu = '' 242 | os.system( 243 | 'python Extrapolation.py --dirc bw --path ' + path_com + ' --idx ' + str(frame_index) + ' --l ' + str(args.l)) 244 | F_ref = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + 
'_extra.png').astype(float) 245 | F_ref = np.expand_dims(F_ref, axis=0) 246 | 247 | mse = np.mean(np.power(np.subtract(F_ref / 255.0, F1_raw / 255.0), 2.0)) 248 | quality = 10 * np.log10(1.0 / mse) 249 | 250 | # print('Frame', frame_index, args.metric + '_extra =', quality) 251 | 252 | else: 253 | F_ref = F0_com 254 | 255 | # run RAE 256 | F0_com, string_MV, string_Res, quality, h_state, latent_mv, latent_res, psnr_mc_value, motion_map \ 257 | = sess.run([Y1_decoded, string, string2, quality_tensor, 258 | hidden_states_out, motion_latent_hat, res_latent_hat, psnr_mc, motion_tensor], 259 | feed_dict={Y0_com_tensor: F_ref / 255.0, Y1_raw_tensor: F1_raw / 255.0, 260 | hidden_states: h_state, RPM_flag: flag}) 261 | F0_com = F0_com * 255 262 | 263 | # flow_color = flow_vis.flow_to_color(motion_map[0], convert_to_bgr=False) 264 | 265 | # save bottleneck bitstream 266 | if not flag: 267 | with open(path_bin + '/f' + str(frame_index).zfill(3) + '.bin', "wb") as ff: 268 | ff.write(np.array(len(string_MV), dtype=np.uint16).tobytes()) 269 | ff.write(string_MV) 270 | ff.write(string_Res) 271 | 272 | # save compressed frame and latents 273 | misc.imsave(path_com + '/f' + str(frame_index).zfill(3) + '.png', np.uint8(np.round(F0_com[0]))) 274 | # misc.imsave(path_com + '/f' + str(frame_index).zfill(3) + '_flow.png', flow_color) 275 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_mv.npy', latent_mv) 276 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_res.npy', latent_res) 277 | 278 | quality_frame[frame_index - 1] = quality 279 | 280 | print('Frame', frame_index, args.metric + ' =', quality) 281 | 282 | os.system('rm -r ' + path_com + '/extra_states') 283 | 284 | # encode rest frames 285 | rest_frame_num = args.frame - 1 - GOP_size * GOP_num 286 | 287 | # load I frame (compressed) 288 | frame_index = GOP_num * GOP_size + 1 289 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 290 | F0_com = np.expand_dims(F0_com, axis=0) 291 | 292 | for f in range(rest_frame_num): 293 | 294 | # load P frame (raw) 295 | frame_index = GOP_num * GOP_size + f + 2 296 | F1_raw = misc.imread(path + 'f' + str(frame_index).zfill(3) + '.png') 297 | F1_raw = np.expand_dims(F1_raw, axis=0) 298 | 299 | # init hidden states 300 | if f % 5 == 0: 301 | h_state = np.zeros([8, batch_size, Height // 4, Width // 4, args.N], dtype=np.float) 302 | if os.path.exists(path_com + '/extra_states'): 303 | os.system('rm -r ' + path_com + '/extra_states') 304 | # since the model is optimized on 6 frames, we reset hidden states every 6 P frames 305 | 306 | if f == 0: 307 | flag = False 308 | # the first P frame uses the bottleneck 309 | else: 310 | flag = True 311 | 312 | 313 | if f >= 1: 314 | # python_cpu = '' 315 | os.system('python Extrapolation.py --path ' + path_com + ' --idx ' + str(frame_index) + ' --l ' + str(args.l)) 316 | 317 | F_ref = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '_extra.png').astype(float) 318 | F_ref = np.expand_dims(F_ref, axis=0) 319 | 320 | mse = np.mean(np.power(np.subtract(F_ref / 255.0, F1_raw / 255.0), 2.0)) 321 | quality = 10 * np.log10(1.0 / mse) 322 | 323 | # print('Frame', frame_index, args.metric + '_extra =', quality) 324 | 325 | else: 326 | F_ref = F0_com 327 | 328 | 329 | # run RAE 330 | F0_com, string_MV, string_Res, quality, h_state, latent_mv, latent_res, psnr_mc_value, res_value \ 331 | = sess.run([Y1_decoded, string, string2, quality_tensor, 332 | hidden_states_out, motion_latent_hat, res_latent_hat, psnr_mc, Res], 333 | 
feed_dict={Y0_com_tensor: F_ref / 255.0, Y1_raw_tensor: F1_raw / 255.0, 334 | hidden_states: h_state, RPM_flag: flag}) 335 | F0_com = F0_com * 255 336 | # save bottleneck bitstream 337 | if not flag: 338 | with open(path_bin + '/f' + str(frame_index).zfill(3) + '.bin', "wb") as ff: 339 | ff.write(np.array(len(string_MV), dtype=np.uint16).tobytes()) 340 | ff.write(string_MV) 341 | ff.write(string_Res) 342 | 343 | # save compressed frame and latents 344 | misc.imsave(path_com + '/f' + str(frame_index).zfill(3) + '.png', np.uint8(np.round(F0_com[0]))) 345 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_mv.npy', latent_mv) 346 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_res.npy', latent_res) 347 | 348 | quality_frame[frame_index - 1] = quality 349 | 350 | print('Frame', frame_index, args.metric + ' =', quality) 351 | 352 | if os.path.exists(path_com + '/extra_states'): 353 | os.system('rm -r ' + path_com + '/extra_states') 354 | 355 | np.save(path_bin + 'quality.npy', quality_frame) 356 | 357 | os.system('rm ' + path_com + '/*_extra.png') 358 | os.system('rm ' + path_com + '/*.yuv') 359 | 360 | 361 | 362 | 363 | 364 | 365 | -------------------------------------------------------------------------------- /Recurrent_Prob_Model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import os 5 | import CNN_recurrent 6 | import helper2 7 | 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 9 | 10 | # use CPU for RPM to ensure the determinism 11 | # config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0}) 12 | config = tf.ConfigProto(allow_soft_placement=True) 13 | sess = tf.Session(config=config) 14 | 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--path", default='BasketballPass') 18 | parser.add_argument("--frame", type=int, default=100) 19 | parser.add_argument("--f_P", type=int, default=5) 20 | parser.add_argument("--inter", type=int, default=2) 21 | parser.add_argument("--b_P", type=int, default=5) 22 | parser.add_argument("--mode", default='PSNR', choices=['PSNR', 'MS-SSIM']) 23 | parser.add_argument("--l", type=int, default=512, choices=[256, 512, 1024, 2048]) 24 | parser.add_argument("--entropy_coding", type=int, default=0) 25 | parser.add_argument("--N", type=int, default=128, choices=[128]) 26 | parser.add_argument("--M", type=int, default=128, choices=[128]) 27 | args = parser.parse_args() 28 | 29 | # Settings 30 | I_level, Height, Width, batch_size, Channel, \ 31 | activation, GOP_size, GOP_num, \ 32 | path, path_com, path_bin, path_lat = helper2.configure(args) 33 | 34 | # Placeholder 35 | prior_tensor = tf.placeholder(tf.float32, [batch_size, Height//16, Width//16, args.M]) # previous latent 36 | latent_tensor = tf.placeholder(tf.float32, [batch_size, Height//16, Width//16, args.M]) # latent to compress 37 | 38 | hidden_states = tf.placeholder(tf.float32, [2, batch_size, Height//16, Width//16, args.N]) # hidden states in RPM 39 | 40 | c_prob, h_prob = tf.split(hidden_states, 2, axis=0) 41 | 42 | # RPM network 43 | prob_latent, c_prob_out, h_prob_out \ 44 | = CNN_recurrent.rec_prob(prior_tensor, args.N, Height, Width, c_prob[0], h_prob[0]) 45 | 46 | # estimate bpp 47 | bits_est, sigma, mu = CNN_recurrent.bpp_est(latent_tensor, prob_latent, args.N) 48 | 49 | hidden_states_out = tf.stack([c_prob_out, h_prob_out], axis = 0) 50 | 51 | # calculates bits for I frames and 
bottlenecks 52 | total_bits = 0 53 | bpp = np.zeros([args.frame]) 54 | 55 | for g in range(GOP_num + 1): 56 | 57 | I_index = g * GOP_size + 1 58 | 59 | if I_index <= args.frame: 60 | 61 | if os.path.exists(path_bin + 'f' + str(I_index).zfill(3) + '.bin'): 62 | total_bits += os.path.getsize(path_bin + 'f' + str(I_index).zfill(3) + '.bin') * 8 63 | bpp[I_index - 1] = os.path.getsize(path_bin + 'f' + str(I_index).zfill(3) + '.bin') * 8 /Height/Width 64 | else: 65 | bpp_I = np.load(path_bin + 'f' + str(I_index).zfill(3) + '.npy') 66 | total_bits += bpp_I * Height * Width 67 | bpp[I_index - 1] = bpp_I 68 | 69 | # if there exists forward P frame(s), I_index + 1 is encoded by the bottleneck 70 | if args.f_P > 0 and I_index + 1 <= args.frame: 71 | total_bits += os.path.getsize(path_bin + 'f' + str(I_index + 1).zfill(3) + '.bin') * 8 72 | bpp[I_index] = os.path.getsize(path_bin + 'f' + str(I_index + 1).zfill(3) + '.bin') * 8 / Height / Width 73 | 74 | # if there exists backward P frame(s), I_index - 1 is encoded by the bottleneck 75 | if args.b_P > 0 and I_index - 1 >= 1: 76 | total_bits += os.path.getsize(path_bin + 'f' + str(I_index - 1).zfill(3) + '.bin') * 8 77 | bpp[I_index - 2] = os.path.getsize(path_bin + 'f' + str(I_index - 1).zfill(3) + '.bin') * 8 / Height / Width 78 | 79 | # start RPM 80 | 81 | latents = ['mv', 'res'] # two kinds of latents 82 | 83 | for lat in latents: 84 | 85 | # load model 86 | model_path = './model/RPM/RPM_' + args.mode + '_' + str(args.l) + '_' + lat 87 | saver = tf.train.Saver(max_to_keep=None) 88 | saver.restore(sess, save_path=model_path + '/model.ckpt') 89 | 90 | # encode GOPs 91 | for g in range(GOP_num): 92 | 93 | # forward P frames (only if more than 2 P frames exist) 94 | if args.f_P >= 2: 95 | # load first prior 96 | frame_index = g * GOP_size + 2 97 | prior_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy') 98 | 99 | # init state 100 | h_state = np.zeros([2, batch_size, Height // 16, Width // 16, args.N], dtype=np.float) 101 | 102 | for f in range(args.f_P - 1): 103 | 104 | # load latent 105 | frame_index = g * GOP_size + f + 3 106 | latent_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy') 107 | 108 | # run RPM 109 | bits_estimation, sigma_value, mu_value, h_state \ 110 | = sess.run([bits_est, sigma, mu, hidden_states_out], 111 | feed_dict={prior_tensor: prior_value, latent_tensor: latent_value, 112 | hidden_states: h_state}) 113 | 114 | if args.entropy_coding: 115 | bits_value = helper2.entropy_coding(frame_index, lat, path_bin, latent_value, sigma_value, mu_value) 116 | total_bits += bits_value 117 | # print('Frame', frame_index, lat + '_bits =', bits_value) 118 | else: 119 | total_bits += bits_estimation 120 | # print('Frame', frame_index, lat + '_bits =', bits_estimation) 121 | bpp[frame_index - 1] += bits_estimation / Height / Width 122 | 123 | # the latent will be the prior for the next latent 124 | prior_value = latent_value 125 | 126 | # backward P frames (only if more than 2 P frames exist) 127 | if args.b_P >= 2: 128 | # load first prior 129 | frame_index = (g + 1) * GOP_size 130 | prior_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy') 131 | 132 | # init state 133 | h_state = np.zeros([2, batch_size, Height // 16, Width // 16, args.N], dtype=np.float) 134 | 135 | for f in range(args.b_P - 1): 136 | 137 | # load latent 138 | frame_index = (g + 1) * GOP_size - f - 1 139 | latent_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + 
'.npy') 140 | 141 | # run RPM 142 | bits_estimation, sigma_value, mu_value, h_state \ 143 | = sess.run([bits_est, sigma, mu, hidden_states_out], 144 | feed_dict={prior_tensor: prior_value, latent_tensor: latent_value, 145 | hidden_states: h_state}) 146 | 147 | if args.entropy_coding: 148 | bits_value = helper2.entropy_coding(frame_index, lat, path_bin, latent_value, sigma_value, mu_value) 149 | total_bits += bits_value 150 | # print('Frame', frame_index, lat + '_bits =', bits_value) 151 | else: 152 | total_bits += bits_estimation 153 | # print('Frame', frame_index, lat + '_bits =', bits_estimation) 154 | bpp[frame_index - 1] += bits_estimation / Height / Width 155 | 156 | # the latent will be the prior for the next latent 157 | prior_value = latent_value 158 | 159 | # encode rest frames (only if more than 2 P frames exist) 160 | rest_frame_num = args.frame - 1 - GOP_size * GOP_num 161 | 162 | if rest_frame_num >= 2: 163 | # load first prior 164 | frame_index = GOP_num * GOP_size + 2 165 | prior_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy') 166 | 167 | # init state 168 | h_state = np.zeros([2, batch_size, Height // 16, Width // 16, args.N], dtype=np.float) 169 | 170 | for f in range(rest_frame_num - 1): 171 | 172 | # load latent 173 | frame_index = GOP_num * GOP_size + f + 3 174 | latent_value = np.load(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy') 175 | 176 | # run RPM 177 | bits_estimation, sigma_value, mu_value, h_state \ 178 | = sess.run([bits_est, sigma, mu, hidden_states_out], 179 | feed_dict={prior_tensor: prior_value, latent_tensor: latent_value, 180 | hidden_states: h_state}) 181 | 182 | if args.entropy_coding: 183 | bits_value = helper2.entropy_coding(frame_index, lat, path_bin, latent_value, sigma_value, mu_value) 184 | total_bits += bits_value 185 | # print('Frame', frame_index, lat + '_bits =', bits_value) 186 | else: 187 | total_bits += bits_estimation 188 | # print('Frame', frame_index, lat + '_bits =', bits_estimation) 189 | bpp[frame_index - 1] += bits_estimation / Height / Width 190 | 191 | # the latent will be the prior for the next latent 192 | prior_value = latent_value 193 | 194 | bpp_video = total_bits/args.frame/Height/Width 195 | 196 | os.system('rm ' + path_bin + '/*.bin') 197 | np.save(path_bin + 'bpp.npy', bpp) 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /arithmeticcoding.py: -------------------------------------------------------------------------------- 1 | # 2 | # Reference arithmetic coding 3 | # Copyright (c) Project Nayuki 4 | # 5 | # https://www.nayuki.io/page/reference-arithmetic-coding 6 | # https://github.com/nayuki/Reference-arithmetic-coding 7 | # 8 | 9 | import sys 10 | import numpy as np 11 | import math 12 | import scipy.special 13 | 14 | python3 = sys.version_info.major >= 3 15 | 16 | 17 | # ---- Arithmetic coding core classes ---- 18 | 19 | # Provides the state and behaviors that arithmetic coding encoders and decoders share. 20 | class ArithmeticCoderBase(object): 21 | 22 | # Constructs an arithmetic coder, which initializes the code range. 23 | def __init__(self, numbits): 24 | if numbits < 1: 25 | raise ValueError("State size out of range") 26 | 27 | # -- Configuration fields -- 28 | # Number of bits for the 'low' and 'high' state variables. Must be at least 1. 
29 | # - Larger values are generally better - they allow a larger maximum frequency total (maximum_total), 30 | # and they reduce the approximation error inherent in adapting fractions to integers; 31 | # both effects reduce the data encoding loss and asymptotically approach the efficiency 32 | # of arithmetic coding using exact fractions. 33 | # - But larger state sizes increase the computation time for integer arithmetic, 34 | # and compression gains beyond ~30 bits essentially zero in real-world applications. 35 | # - Python has native bigint arithmetic, so there is no upper limit to the state size. 36 | # For Java and C++ where using native machine-sized integers makes the most sense, 37 | # they have a recommended value of num_state_bits=32 as the most versatile setting. 38 | self.num_state_bits = numbits 39 | # Maximum range (high+1-low) during coding (trivial), which is 2^num_state_bits = 1000...000. 40 | self.full_range = 1 << self.num_state_bits 41 | # The top bit at width num_state_bits, which is 0100...000. 42 | self.half_range = self.full_range >> 1 # Non-zero 43 | # The second highest bit at width num_state_bits, which is 0010...000. This is zero when num_state_bits=1. 44 | self.quarter_range = self.half_range >> 1 # Can be zero 45 | # Minimum range (high+1-low) during coding (non-trivial), which is 0010...010. 46 | self.minimum_range = self.quarter_range + 2 # At least 2 47 | # Maximum allowed total from a frequency table at all times during coding. This differs from Java 48 | # and C++ because Python's native bigint avoids constraining the size of intermediate computations. 49 | self.maximum_total = self.minimum_range 50 | # Bit mask of num_state_bits ones, which is 0111...111. 51 | self.state_mask = self.full_range - 1 52 | 53 | # -- State fields -- 54 | # Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. 55 | self.low = 0 56 | # High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. 57 | self.high = self.state_mask 58 | 59 | # Updates the code range (low and high) of this arithmetic coder as a result 60 | # of processing the given symbol with the given frequency table. 61 | # Invariants that are true before and after encoding/decoding each symbol 62 | # (letting full_range = 2^num_state_bits): 63 | # - 0 <= low <= code <= high < full_range. ('code' exists only in the decoder.) 64 | # Therefore these variables are unsigned integers of num_state_bits bits. 65 | # - low < 1/2 * full_range <= high. 66 | # In other words, they are in different halves of the full range. 67 | # - (low < 1/4 * full_range) || (high >= 3/4 * full_range). 68 | # In other words, they are not both in the middle two quarters. 69 | # - Let range = high - low + 1, then full_range/4 < minimum_range 70 | # <= range <= full_range. These invariants for 'range' essentially 71 | # dictate the maximum total that the incoming frequency table can have. 
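# Worked example (illustrative only): with num_state_bits = 3, full_range = 8 and
# minimum_range = maximum_total = 4. Starting from low = 0, high = 7 and a frequency
# table with total = 4 in which the coded symbol covers the cumulative interval [1, 3),
# the update below gives
#     newlow  = 0 + 1 * 8 // 4     = 2
#     newhigh = 0 + 3 * 8 // 4 - 1 = 5.
# The top bits of 2 (010) and 5 (101) differ, so no bit is shifted out, but the pair
# lies in the middle two quarters, so one underflow step rescales (2, 5) back to (0, 7).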
72 | def update(self, freqs, symbol): 73 | # State check 74 | low = self.low 75 | high = self.high 76 | if low >= high or (low & self.state_mask) != low or (high & self.state_mask) != high: 77 | raise AssertionError("Low or high out of range") 78 | range = high - low + 1 79 | if not (self.minimum_range <= range <= self.full_range): 80 | raise AssertionError("Range out of range") 81 | 82 | # Frequency table values check 83 | total = freqs.get_total() 84 | symlow = freqs.get_low(symbol) 85 | symhigh = freqs.get_high(symbol) 86 | if symlow == symhigh: 87 | raise ValueError("Symbol has zero frequency") 88 | if total > self.maximum_total: 89 | raise ValueError("Cannot code symbol because total is too large") 90 | 91 | # Update range 92 | newlow = low + symlow * range // total 93 | newhigh = low + symhigh * range // total - 1 94 | self.low = int(newlow) 95 | self.high = int(newhigh) 96 | 97 | # While low and high have the same top bit value, shift them out 98 | while ((self.low ^ self.high) & self.half_range) == 0: 99 | self.shift() 100 | self.low = ((self.low << 1) & self.state_mask) 101 | self.high = ((self.high << 1) & self.state_mask) | 1 102 | # Now low's top bit must be 0 and high's top bit must be 1 103 | 104 | # While low's top two bits are 01 and high's are 10, delete the second highest bit of both 105 | while (self.low & ~self.high & self.quarter_range) != 0: 106 | self.underflow() 107 | self.low = (self.low << 1) ^ self.half_range 108 | self.high = ((self.high ^ self.half_range) << 1) | self.half_range | 1 109 | 110 | # Called to handle the situation when the top bit of 'low' and 'high' are equal. 111 | def shift(self): 112 | raise NotImplementedError() 113 | 114 | # Called to handle the situation when low=01(...) and high=10(...). 115 | def underflow(self): 116 | raise NotImplementedError() 117 | 118 | 119 | # Encodes symbols and writes to an arithmetic-coded bit stream. 120 | class ArithmeticEncoder(ArithmeticCoderBase): 121 | 122 | # Constructs an arithmetic coding encoder based on the given bit output stream. 123 | def __init__(self, numbits, bitout): 124 | super(ArithmeticEncoder, self).__init__(numbits) 125 | # The underlying bit output stream. 126 | self.output = bitout 127 | # Number of saved underflow bits. This value can grow without bound. 128 | self.num_underflow = 0 129 | 130 | # Encodes the given symbol based on the given frequency table. 131 | # This updates this arithmetic coder's state and may write out some bits. 132 | def write(self, freqs, symbol): 133 | if not isinstance(freqs, CheckedFrequencyTable): 134 | freqs = CheckedFrequencyTable(freqs) 135 | self.update(freqs, symbol) 136 | 137 | # Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly. 138 | # It is important that this method must be called at the end of the each encoding process. 139 | # Note that this method merely writes data to the underlying output stream but does not close it. 140 | def finish(self): 141 | self.output.write(1) 142 | 143 | def shift(self): 144 | bit = self.low >> (self.num_state_bits - 1) 145 | self.output.write(bit) 146 | 147 | # Write out the saved underflow bits 148 | for _ in range(self.num_underflow): 149 | self.output.write(bit ^ 1) 150 | self.num_underflow = 0 151 | 152 | def underflow(self): 153 | self.num_underflow += 1 154 | 155 | 156 | # Reads from an arithmetic-coded bit stream and decodes symbols. 
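# A minimal encode/decode round trip with the ArithmeticEncoder above and the
# ArithmeticDecoder below might look as follows (an illustrative sketch only; it
# assumes this file is importable as `arithmeticcoding`):
#
#     import io
#     import arithmeticcoding as ac
#
#     freqs = ac.SimpleFrequencyTable([5, 3, 1, 1])      # four symbols
#     symbols = [0, 1, 0, 2, 0, 3]
#
#     buf = io.BytesIO()
#     bitout = ac.BitOutputStream(buf)
#     enc = ac.ArithmeticEncoder(32, bitout)
#     for s in symbols:
#         enc.write(freqs, s)
#     enc.finish()
#     while bitout.numbitsfilled != 0:                   # pad to a byte boundary
#         bitout.write(0)
#
#     bitin = ac.BitInputStream(io.BytesIO(buf.getvalue()))
#     dec = ac.ArithmeticDecoder(32, bitin)
#     assert [dec.read(freqs) for _ in symbols] == symbols
#
# In this repository these classes are driven by helper2.entropy_coding (see
# Recurrent_Prob_Model.py), presumably with the Gaussian-based ModelFrequencyTable
# defined further below.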
157 | class ArithmeticDecoder(ArithmeticCoderBase): 158 | 159 | # Constructs an arithmetic coding decoder based on the 160 | # given bit input stream, and fills the code bits. 161 | def __init__(self, numbits, bitin): 162 | super(ArithmeticDecoder, self).__init__(numbits) 163 | # The underlying bit input stream. 164 | self.input = bitin 165 | # The current raw code bits being buffered, which is always in the range [low, high]. 166 | self.code = 0 167 | for _ in range(self.num_state_bits): 168 | self.code = self.code << 1 | self.read_code_bit() 169 | 170 | # Decodes the next symbol based on the given frequency table and returns it. 171 | # Also updates this arithmetic coder's state and may read in some bits. 172 | def read(self, freqs): 173 | if not isinstance(freqs, CheckedFrequencyTable): 174 | freqs = CheckedFrequencyTable(freqs) 175 | 176 | # Translate from coding range scale to frequency table scale 177 | total = freqs.get_total() 178 | if total > self.maximum_total: 179 | raise ValueError("Cannot decode symbol because total is too large") 180 | range = self.high - self.low + 1 181 | offset = self.code - self.low 182 | value = ((offset + 1) * total - 1) // range 183 | assert value * range // total <= offset 184 | assert 0 <= value < total 185 | 186 | # A kind of binary search. Find highest symbol such that freqs.get_low(symbol) <= value. 187 | start = 0 188 | end = freqs.get_symbol_limit() 189 | while end - start > 1: 190 | middle = (start + end) >> 1 191 | if freqs.get_low(middle) > value: 192 | end = middle 193 | else: 194 | start = middle 195 | assert start + 1 == end 196 | 197 | symbol = start 198 | assert freqs.get_low(symbol) * range // total <= offset < freqs.get_high(symbol) * range // total 199 | self.update(freqs, symbol) 200 | if not (self.low <= self.code <= self.high): 201 | raise AssertionError("Code out of range") 202 | return symbol 203 | 204 | def shift(self): 205 | self.code = ((self.code << 1) & self.state_mask) | self.read_code_bit() 206 | 207 | def underflow(self): 208 | self.code = (self.code & self.half_range) | ((self.code << 1) & (self.state_mask >> 1)) | self.read_code_bit() 209 | 210 | # Returns the next bit (0 or 1) from the input stream. The end 211 | # of stream is treated as an infinite number of trailing zeros. 212 | def read_code_bit(self): 213 | temp = self.input.read() 214 | if temp == -1: 215 | temp = 0 216 | return temp 217 | 218 | 219 | # ---- Frequency table classes ---- 220 | 221 | # A table of symbol frequencies. The table holds data for symbols numbered from 0 222 | # to get_symbol_limit()-1. Each symbol has a frequency, which is a non-negative integer. 223 | # Frequency table objects are primarily used for getting cumulative symbol 224 | # frequencies. These objects can be mutable depending on the implementation. 225 | class FrequencyTable(object): 226 | 227 | # Returns the number of symbols in this frequency table, which is a positive number. 228 | def get_symbol_limit(self): 229 | raise NotImplementedError() 230 | 231 | # Returns the frequency of the given symbol. The returned value is at least 0. 232 | def get(self, symbol): 233 | raise NotImplementedError() 234 | 235 | # Sets the frequency of the given symbol to the given value. 236 | # The frequency value must be at least 0. 237 | def set(self, symbol, freq): 238 | raise NotImplementedError() 239 | 240 | # Increments the frequency of the given symbol. 241 | def increment(self, symbol): 242 | raise NotImplementedError() 243 | 244 | # Returns the total of all symbol frequencies. 
The returned value is at 245 | # least 0 and is always equal to get_high(get_symbol_limit() - 1). 246 | def get_total(self): 247 | raise NotImplementedError() 248 | 249 | # Returns the sum of the frequencies of all the symbols strictly 250 | # below the given symbol value. The returned value is at least 0. 251 | def get_low(self, symbol): 252 | raise NotImplementedError() 253 | 254 | # Returns the sum of the frequencies of the given symbol 255 | # and all the symbols below. The returned value is at least 0. 256 | def get_high(self, symbol): 257 | raise NotImplementedError() 258 | 259 | 260 | # An immutable frequency table where every symbol has the same frequency of 1. 261 | # Useful as a fallback model when no statistics are available. 262 | class FlatFrequencyTable(FrequencyTable): 263 | 264 | # Constructs a flat frequency table with the given number of symbols. 265 | def __init__(self, numsyms): 266 | if numsyms < 1: 267 | raise ValueError("Number of symbols must be positive") 268 | self.numsymbols = numsyms # Total number of symbols, which is at least 1 269 | 270 | # Returns the number of symbols in this table, which is at least 1. 271 | def get_symbol_limit(self): 272 | return self.numsymbols 273 | 274 | # Returns the frequency of the given symbol, which is always 1. 275 | def get(self, symbol): 276 | self._check_symbol(symbol) 277 | return 1 278 | 279 | # Returns the total of all symbol frequencies, which is 280 | # always equal to the number of symbols in this table. 281 | def get_total(self): 282 | return self.numsymbols 283 | 284 | # Returns the sum of the frequencies of all the symbols strictly below 285 | # the given symbol value. The returned value is equal to 'symbol'. 286 | def get_low(self, symbol): 287 | self._check_symbol(symbol) 288 | return symbol 289 | 290 | # Returns the sum of the frequencies of the given symbol and all 291 | # the symbols below. The returned value is equal to 'symbol' + 1. 292 | def get_high(self, symbol): 293 | self._check_symbol(symbol) 294 | return symbol + 1 295 | 296 | # Returns silently if 0 <= symbol < numsymbols, otherwise raises an exception. 297 | def _check_symbol(self, symbol): 298 | if 0 <= symbol < self.numsymbols: 299 | return 300 | else: 301 | raise ValueError("Symbol out of range") 302 | 303 | # Returns a string representation of this frequency table. The format is subject to change. 304 | def __str__(self): 305 | return "FlatFrequencyTable={}".format(self.numsymbols) 306 | 307 | # Unsupported operation, because this frequency table is immutable. 308 | def set(self, symbol, freq): 309 | raise NotImplementedError() 310 | 311 | # Unsupported operation, because this frequency table is immutable. 312 | def increment(self, symbol): 313 | raise NotImplementedError() 314 | 315 | 316 | # A mutable table of symbol frequencies. The number of symbols cannot be changed 317 | # after construction. The current algorithm for calculating cumulative frequencies 318 | # takes linear time, but there exist faster algorithms such as Fenwick trees. 319 | class SimpleFrequencyTable(FrequencyTable): 320 | 321 | # Constructs a simple frequency table in one of two ways: 322 | # - SimpleFrequencyTable(sequence): 323 | # Builds a frequency table from the given sequence of symbol frequencies. 324 | # There must be at least 1 symbol, and no symbol has a negative frequency. 325 | # - SimpleFrequencyTable(freqtable): 326 | # Builds a frequency table by copying the given frequency table. 
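# For example (illustrative): SimpleFrequencyTable([4, 2, 1]) builds a table from
# per-symbol counts, while SimpleFrequencyTable(FlatFrequencyTable(5)) copies an
# existing table (here: five symbols, each with frequency 1).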
327 | def __init__(self, freqs): 328 | if isinstance(freqs, FrequencyTable): 329 | numsym = freqs.get_symbol_limit() 330 | self.frequencies = [freqs.get(i) for i in range(numsym)] 331 | else: # Assume it is a sequence type 332 | self.frequencies = list(freqs) # Make copy 333 | 334 | # 'frequencies' is a list of the frequency for each symbol. 335 | # Its length is at least 1, and each element is non-negative. 336 | if len(self.frequencies) < 1: 337 | raise ValueError("At least 1 symbol needed") 338 | for freq in self.frequencies: 339 | if freq < 0: 340 | raise ValueError("Negative frequency") 341 | 342 | # Always equal to the sum of 'frequencies' 343 | self.total = sum(self.frequencies) 344 | 345 | # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 346 | # Initialized lazily. When it is not None, the data is valid. 347 | self.cumulative = None 348 | 349 | # Returns the number of symbols in this frequency table, which is at least 1. 350 | def get_symbol_limit(self): 351 | return len(self.frequencies) 352 | 353 | # Returns the frequency of the given symbol. The returned value is at least 0. 354 | def get(self, symbol): 355 | self._check_symbol(symbol) 356 | return self.frequencies[symbol] 357 | 358 | # Sets the frequency of the given symbol to the given value. The frequency value 359 | # must be at least 0. If an exception is raised, then the state is left unchanged. 360 | def set(self, symbol, freq): 361 | self._check_symbol(symbol) 362 | if freq < 0: 363 | raise ValueError("Negative frequency") 364 | temp = self.total - self.frequencies[symbol] 365 | assert temp >= 0 366 | self.total = temp + freq 367 | self.frequencies[symbol] = freq 368 | self.cumulative = None 369 | 370 | # Increments the frequency of the given symbol. 371 | def increment(self, symbol): 372 | self._check_symbol(symbol) 373 | self.total += 1 374 | self.frequencies[symbol] += 1 375 | self.cumulative = None 376 | 377 | # Returns the total of all symbol frequencies. The returned value is at 378 | # least 0 and is always equal to get_high(get_symbol_limit() - 1). 379 | def get_total(self): 380 | return self.total 381 | 382 | # Returns the sum of the frequencies of all the symbols strictly 383 | # below the given symbol value. The returned value is at least 0. 384 | def get_low(self, symbol): 385 | self._check_symbol(symbol) 386 | if self.cumulative is None: 387 | self._init_cumulative() 388 | return self.cumulative[symbol] 389 | 390 | # Returns the sum of the frequencies of the given symbol 391 | # and all the symbols below. The returned value is at least 0. 392 | def get_high(self, symbol): 393 | self._check_symbol(symbol) 394 | if self.cumulative is None: 395 | self._init_cumulative() 396 | return self.cumulative[symbol + 1] 397 | 398 | # Recomputes the array of cumulative symbol frequencies. 399 | def _init_cumulative(self): 400 | cumul = [0] 401 | sum = 0 402 | for freq in self.frequencies: 403 | sum += freq 404 | cumul.append(sum) 405 | assert sum == self.total 406 | self.cumulative = cumul 407 | 408 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 409 | def _check_symbol(self, symbol): 410 | if 0 <= symbol < len(self.frequencies): 411 | return 412 | else: 413 | raise ValueError("Symbol out of range") 414 | 415 | # Returns a string representation of this frequency table, 416 | # useful for debugging only, and the format is subject to change. 
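# For instance (illustrative), str(SimpleFrequencyTable([2, 1])) returns "0\t2\n1\t1\n".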
417 | def __str__(self): 418 | result = "" 419 | for (i, freq) in enumerate(self.frequencies): 420 | result += "{}\t{}\n".format(i, freq) 421 | return result 422 | 423 | 424 | # A wrapper that checks the preconditions (arguments) and postconditions (return value) of all 425 | # the frequency table methods. Useful for finding faults in a frequency table implementation. 426 | class CheckedFrequencyTable(FrequencyTable): 427 | 428 | def __init__(self, freqtab): 429 | # The underlying frequency table that holds the data 430 | self.freqtable = freqtab 431 | 432 | def get_symbol_limit(self): 433 | result = self.freqtable.get_symbol_limit() 434 | if result <= 0: 435 | raise AssertionError("Non-positive symbol limit") 436 | return result 437 | 438 | def get(self, symbol): 439 | result = self.freqtable.get(symbol) 440 | if not self._is_symbol_in_range(symbol): 441 | raise AssertionError("ValueError expected") 442 | if result < 0: 443 | raise AssertionError("Negative symbol frequency") 444 | return result 445 | 446 | def get_total(self): 447 | result = self.freqtable.get_total() 448 | if result < 0: 449 | raise AssertionError("Negative total frequency") 450 | return result 451 | 452 | def get_low(self, symbol): 453 | if self._is_symbol_in_range(symbol): 454 | low = self.freqtable.get_low(symbol) 455 | high = self.freqtable.get_high(symbol) 456 | if not (0 <= low <= high <= self.freqtable.get_total()): 457 | raise AssertionError("Symbol low cumulative frequency out of range") 458 | return low 459 | else: 460 | self.freqtable.get_low(symbol) 461 | raise AssertionError("ValueError expected") 462 | 463 | def get_high(self, symbol): 464 | if self._is_symbol_in_range(symbol): 465 | low = self.freqtable.get_low(symbol) 466 | high = self.freqtable.get_high(symbol) 467 | if not (0 <= low <= high <= self.freqtable.get_total()): 468 | raise AssertionError("Symbol high cumulative frequency out of range") 469 | return high 470 | else: 471 | self.freqtable.get_high(symbol) 472 | raise AssertionError("ValueError expected") 473 | 474 | def __str__(self): 475 | return "CheckedFrequencyTable (" + str(self.freqtable) + ")" 476 | 477 | def set(self, symbol, freq): 478 | self.freqtable.set(symbol, freq) 479 | if not self._is_symbol_in_range(symbol) or freq < 0: 480 | raise AssertionError("ValueError expected") 481 | 482 | def increment(self, symbol): 483 | self.freqtable.increment(symbol) 484 | if not self._is_symbol_in_range(symbol): 485 | raise AssertionError("ValueError expected") 486 | 487 | def _is_symbol_in_range(self, symbol): 488 | return 0 <= symbol < self.get_symbol_limit() 489 | 490 | 491 | # ---- Bit-oriented I/O streams ---- 492 | 493 | # A stream of bits that can be read. Because they come from an underlying byte stream, 494 | # the total number of bits is always a multiple of 8. The bits are read in big endian. 495 | class BitInputStream(object): 496 | 497 | # Constructs a bit input stream based on the given byte input stream. 498 | def __init__(self, inp): 499 | # The underlying byte stream to read from 500 | self.input = inp 501 | # Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached 502 | self.currentbyte = 0 503 | # Number of remaining bits in the current byte, always between 0 and 7 (inclusive) 504 | self.numbitsremaining = 0 505 | 506 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if 507 | # the end of stream is reached. The end of stream always occurs on a byte boundary. 
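# For example (illustrative), a stream wrapping the single byte 0b10110000 yields the
# bits 1, 0, 1, 1, 0, 0, 0, 0 in that order and then -1, since each byte is consumed
# starting from its most significant bit.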
508 | def read(self): 509 | if self.currentbyte == -1: 510 | return -1 511 | if self.numbitsremaining == 0: 512 | temp = self.input.read(1) 513 | if len(temp) == 0: 514 | self.currentbyte = -1 515 | return -1 516 | self.currentbyte = temp[0] if python3 else ord(temp) 517 | self.numbitsremaining = 8 518 | assert self.numbitsremaining > 0 519 | self.numbitsremaining -= 1 520 | return (self.currentbyte >> self.numbitsremaining) & 1 521 | 522 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError 523 | # if the end of stream is reached. The end of stream always occurs on a byte boundary. 524 | def read_no_eof(self): 525 | result = self.read() 526 | if result != -1: 527 | return result 528 | else: 529 | raise EOFError() 530 | 531 | # Closes this stream and the underlying input stream. 532 | def close(self): 533 | self.input.close() 534 | self.currentbyte = -1 535 | self.numbitsremaining = 0 536 | 537 | 538 | # A stream where bits can be written to. Because they are written to an underlying 539 | # byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. 540 | # The bits are written in big endian. 541 | class BitOutputStream(object): 542 | 543 | # Constructs a bit output stream based on the given byte output stream. 544 | def __init__(self, out): 545 | self.output = out # The underlying byte stream to write to 546 | self.currentbyte = 0 # The accumulated bits for the current byte, always in the range [0x00, 0xFF] 547 | self.numbitsfilled = 0 # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive) 548 | 549 | # Writes a bit to the stream. The given bit must be 0 or 1. 550 | def write(self, b): 551 | if b not in (0, 1): 552 | raise ValueError("Argument must be 0 or 1") 553 | self.currentbyte = (self.currentbyte << 1) | b 554 | self.numbitsfilled += 1 555 | if self.numbitsfilled == 8: 556 | towrite = bytes((self.currentbyte,)) if python3 else chr(self.currentbyte) 557 | self.output.write(towrite) 558 | self.currentbyte = 0 559 | self.numbitsfilled = 0 560 | 561 | # Closes this stream and the underlying output stream. If called when this 562 | # bit stream is not at a byte boundary, then the minimum number of "0" bits 563 | # (between 0 and 7 of them) are written as padding to reach the next byte boundary. 564 | def close(self): 565 | while self.numbitsfilled != 0: 566 | self.write(0) 567 | self.output.close() 568 | 569 | class TryFrequencyTable(FrequencyTable): 570 | 571 | def __init__(self, f): 572 | 573 | self.f = f 574 | self.num_symbols = len(f) 575 | self.total = np.sum(self.f) 576 | self.cumulative = None 577 | 578 | # Returns the number of symbols in this frequency table, which is at least 1. 579 | def get_symbol_limit(self): 580 | return self.num_symbols 581 | 582 | # Returns the frequency of the given symbol. The returned value is at least 0. 583 | def get(self, symbol): 584 | self._check_symbol(symbol) 585 | freq = self.f[symbol] 586 | return freq 587 | 588 | def get_total(self): 589 | return self.total 590 | 591 | # Returns the sum of the frequencies of all the symbols strictly 592 | # below the given symbol value. The returned value is at least 0. 593 | def get_low(self, symbol): 594 | self._check_symbol(symbol) 595 | low = np.sum(self.f[0: symbol]) 596 | return low 597 | 598 | # Returns the sum of the frequencies of the given symbol 599 | # and all the symbols below. The returned value is at least 0. 
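# For example (illustrative), with f = [5, 3, 1]: get_low(1) = 5 and get_high(1) = 8,
# i.e. symbol 1 occupies the cumulative interval [5, 8) out of a total of 9.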
600 | def get_high(self, symbol): 601 | self._check_symbol(symbol) 602 | high = np.sum(self.f[0: symbol + 1]) 603 | 604 | return high 605 | 606 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 607 | def _check_symbol(self, symbol): 608 | if 0 <= symbol < self.get_symbol_limit(): 609 | return 610 | else: 611 | raise ValueError("Symbol out of range") 612 | 613 | # Returns a string representation of this frequency table, 614 | # useful for debugging only, and the format is subject to change. 615 | def __str__(self): 616 | result = "" 617 | return result 618 | 619 | 620 | class OurFrequencyTable(FrequencyTable): 621 | 622 | def __init__(self, f): 623 | 624 | self.f = f 625 | self.num_symbols = len(f) 626 | self.total = np.sum(self.f) 627 | self.cumulative = None 628 | 629 | # Returns the number of symbols in this frequency table, which is at least 1. 630 | def get_symbol_limit(self): 631 | return self.num_symbols 632 | 633 | # Returns the frequency of the given symbol. The returned value is at least 0. 634 | def get(self, symbol): 635 | self._check_symbol(symbol) 636 | freq = self.f[symbol] 637 | return freq 638 | 639 | def get_total(self): 640 | return self.total 641 | 642 | # Returns the sum of the frequencies of all the symbols strictly 643 | # below the given symbol value. The returned value is at least 0. 644 | def get_low(self, symbol): 645 | self._check_symbol(symbol) 646 | low = np.int(np.sum(self.f[0: symbol])) 647 | return low 648 | 649 | # Returns the sum of the frequencies of the given symbol 650 | # and all the symbols below. The returned value is at least 0. 651 | def get_high(self, symbol): 652 | self._check_symbol(symbol) 653 | high = np.sum(self.f[0: symbol + 1]) 654 | 655 | return high 656 | 657 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 658 | def _check_symbol(self, symbol): 659 | if 0 <= symbol < self.get_symbol_limit(): 660 | return 661 | else: 662 | raise ValueError("Symbol out of range") 663 | 664 | # Returns a string representation of this frequency table, 665 | # useful for debugging only, and the format is subject to change. 666 | def __str__(self): 667 | result = "" 668 | return result 669 | 670 | class ModelFrequencyTable(FrequencyTable): 671 | # Constructs a simple frequency table in one of two ways: 672 | # - SimpleFrequencyTable(sequence): 673 | # Builds a frequency table from the given sequence of symbol frequencies. 674 | # There must be at least 1 symbol, and no symbol has a negative frequency. 675 | # - SimpleFrequencyTable(freqtable): 676 | # Builds a frequency table by copying the given frequency table. 677 | def __init__(self, mu_val=0, sigma_val=1): 678 | self.mul_factor = 10000000 679 | self.num_symbols = 513 680 | self.EOF = self.num_symbols - 1 681 | 682 | self.mu_val = mu_val 683 | self.sigma_val = np.abs(sigma_val) 684 | 685 | # self.TINY = 1e-2 686 | self.TINY = 1e-10 687 | 688 | # print("mu_val: " + str(mu_val)) 689 | # print("sigma_val: " + str(sigma_val)) 690 | # Always equal to the sum of 'frequencies' 691 | self.total = self.mul_factor + 513 692 | 693 | # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 694 | # Initialized lazily. When it is not None, the data is valid. 695 | self.cumulative = None 696 | 697 | def set_mu(self, mu_val): 698 | self.mu_val = mu_val 699 | 700 | def set_sigma(self, sigma_val): 701 | self.sigma_val = sigma_val 702 | 703 | # Returns the number of symbols in this frequency table, which is at least 1. 
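# Note on this table: get(), get_low() and get_high() below discretize a Gaussian CDF
# with mean mu_val and standard deviation sigma_val (0.5 * (1 + erf(.)) is the normal
# CDF), scale it by mul_factor, and add one extra count per symbol so that every
# symbol -- including the EOF symbol 512 -- keeps a non-zero frequency; total is
# therefore fixed to mul_factor + 513.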
704 | def get_symbol_limit(self): 705 | return self.num_symbols 706 | 707 | # Returns the frequency of the given symbol. The returned value is at least 0. 708 | def get(self, symbol): 709 | self._check_symbol(symbol) 710 | 711 | if symbol == self.EOF: 712 | return 1 713 | else: 714 | c2 = 0.5 * (1 + scipy.special.erf((symbol + 0.5 - self.mu_val) / ((self.sigma_val + self.TINY) * 2 ** 0.5))) 715 | c1 = 0.5 * (1 + scipy.special.erf((symbol - 0.5 - self.mu_val) / ((self.sigma_val + self.TINY) * 2 ** 0.5))) 716 | freq = int(math.floor((c2 - c1) * self.mul_factor) + 1) 717 | 718 | return freq 719 | 720 | # Sets the frequency of the given symbol to the given value. The frequency value 721 | # must be at least 0. If an exception is raised, then the state is left unchanged. 722 | # def set(self, symbol, freq): 723 | # self._check_symbol(symbol) 724 | # if freq < 0: 725 | # raise ValueError("Negative frequency") 726 | # temp = self.total - self.frequencies[symbol] 727 | # assert temp >= 0 728 | # self.total = temp + freq 729 | # self.frequencies[symbol] = freq 730 | # self.cumulative = None 731 | # 732 | # # Increments the frequency of the given symbol. 733 | # def increment(self, symbol): 734 | # self._check_symbol(symbol) 735 | # self.total += 1 736 | # self.frequencies[symbol] += 1 737 | # self.cumulative = None 738 | 739 | # Returns the total of all symbol frequencies. The returned value is at 740 | # least 0 and is always equal to get_high(get_symbol_limit() - 1). 741 | def get_total(self): 742 | return self.total 743 | 744 | # Returns the sum of the frequencies of all the symbols strictly 745 | # below the given symbol value. The returned value is at least 0. 746 | def get_low(self, symbol): 747 | self._check_symbol(symbol) 748 | c = 0.5 * (1 + scipy.special.erf(((symbol-1) + 0.5 - self.mu_val) / ((self.sigma_val + self.TINY) * 2 ** 0.5))) 749 | c = int(math.floor(c * self.mul_factor) + symbol) 750 | return c 751 | 752 | # Returns the sum of the frequencies of the given symbol 753 | # and all the symbols below. The returned value is at least 0. 754 | def get_high(self, symbol): 755 | self._check_symbol(symbol) 756 | c = 0.5 * (1 + scipy.special.erf(((symbol) + 0.5 - self.mu_val) / ((self.sigma_val + self.TINY) * 2 ** 0.5))) 757 | c = int(math.floor(c * self.mul_factor) + symbol + 1) 758 | 759 | # if symbol == self.EOF: 760 | # c = c + 1 761 | 762 | return c 763 | 764 | # Recomputes the array of cumulative symbol frequencies. 765 | # def _init_cumulative(self): 766 | # cumul = [0] 767 | # sum = 0 768 | # for freq in self.frequencies: 769 | # sum += freq 770 | # cumul.append(sum) 771 | # assert sum == self.total 772 | # self.cumulative = cumul 773 | 774 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 775 | def _check_symbol(self, symbol): 776 | if 0 <= symbol < self.get_symbol_limit(): 777 | return 778 | else: 779 | raise ValueError("Symbol out of range") 780 | 781 | # Returns a string representation of this frequency table, 782 | # useful for debugging only, and the format is subject to change. 783 | def __str__(self): 784 | result = "" 785 | # for (i, freq) in enumerate(self.frequencies): 786 | # result += "{}\t{}\n".format(i, freq) 787 | return result 788 | 789 | class logFrequencyTable(FrequencyTable): 790 | # Constructs a simple frequency table in one of two ways: 791 | # - SimpleFrequencyTable(sequence): 792 | # Builds a frequency table from the given sequence of symbol frequencies. 
793 | # There must be at least 1 symbol, and no symbol has a negative frequency. 794 | # - SimpleFrequencyTable(freqtable): 795 | # Builds a frequency table by copying the given frequency table. 796 | def __init__(self, mu_val, sigma_val, num_symbols): 797 | self.mul_factor = 10000000 798 | self.num_symbols = num_symbols 799 | self.mu_val = mu_val 800 | self.sigma_val = np.abs(sigma_val) 801 | 802 | # self.TINY = 1e-2 803 | self.TINY = 1e-10 804 | 805 | # Always equal to the sum of 'frequencies' 806 | self.total = self.mul_factor + 1 807 | 808 | # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 809 | # Initialized lazily. When it is not None, the data is valid. 810 | self.cumulative = None 811 | 812 | def set_mu(self, mu_val): 813 | self.mu_val = mu_val 814 | 815 | def set_sigma(self, sigma_val): 816 | self.sigma_val = sigma_val 817 | 818 | # Returns the number of symbols in this frequency table, which is at least 1. 819 | def get_symbol_limit(self): 820 | return self.num_symbols 821 | 822 | # Returns the frequency of the given symbol. The returned value is at least 0. 823 | def get(self, symbol): 824 | self._check_symbol(symbol) 825 | 826 | if symbol == self.EOF: 827 | return 1 828 | else: 829 | c2 = scipy.special.expit(np.multiply((symbol + 0.5 - self.mu_val), (self.sigma_val ** 2 + self.TINY))) 830 | c1 = scipy.special.expit(np.multiply((symbol - 0.5 - self.mu_val), (self.sigma_val ** 2 + self.TINY))) 831 | freq = int(math.floor((c2 - c1) * self.mul_factor) + 1) 832 | 833 | return freq 834 | 835 | # Returns the total of all symbol frequencies. The returned value is at 836 | # least 0 and is always equal to get_high(get_symbol_limit() - 1). 837 | def get_total(self): 838 | return self.total 839 | 840 | # Returns the sum of the frequencies of all the symbols strictly 841 | # below the given symbol value. The returned value is at least 0. 842 | def get_low(self, symbol): 843 | self._check_symbol(symbol) 844 | c = scipy.special.expit(np.multiply((symbol - 1 + 0.5 - self.mu_val), (self.sigma_val ** 2 + self.TINY))) 845 | c = int(math.floor(c * self.mul_factor)) 846 | return c 847 | 848 | # Returns the sum of the frequencies of the given symbol 849 | # and all the symbols below. The returned value is at least 0. 850 | def get_high(self, symbol): 851 | self._check_symbol(symbol) 852 | c = scipy.special.expit(np.multiply((symbol + 0.5 - self.mu_val), (self.sigma_val ** 2 + self.TINY))) 853 | c = int(math.floor(c * self.mul_factor) + 1) 854 | 855 | return c 856 | 857 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 858 | def _check_symbol(self, symbol): 859 | if 0 <= symbol < self.get_symbol_limit(): 860 | return 861 | else: 862 | raise ValueError("Symbol out of range") 863 | 864 | # Returns a string representation of this frequency table, 865 | # useful for debugging only, and the format is subject to change. 866 | def __str__(self): 867 | result = "" 868 | # for (i, freq) in enumerate(self.frequencies): 869 | # result += "{}\t{}\n".format(i, freq) 870 | return result 871 | 872 | class logFrequencyTable_exp(FrequencyTable): 873 | # Constructs a simple frequency table in one of two ways: 874 | # - SimpleFrequencyTable(sequence): 875 | # Builds a frequency table from the given sequence of symbol frequencies. 876 | # There must be at least 1 symbol, and no symbol has a negative frequency. 877 | # - SimpleFrequencyTable(freqtable): 878 | # Builds a frequency table by copying the given frequency table. 
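# In this variant the frequencies come from a logistic rather than a Gaussian CDF:
# scipy.special.expit((x - mu_val) * exp(-sigma_val)) is the CDF of a logistic
# distribution with location mu_val and scale exp(sigma_val), and sigma_val is
# clamped to at least -7 so the scale stays above exp(-7). This is the table built by
# entropy_coding() / entropy_decoding() in helper.py and helper2.py. Caveat: get()
# refers to self.EOF, which this class never defines, so calling it would raise an
# AttributeError; the cumulative queries get_low() / get_high() / get_total() appear
# to be the only ones exercised during coding.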
879 | def __init__(self, mu_val, sigma_val, num_symbols): 880 | self.mul_factor = 10000000 881 | self.num_symbols = num_symbols 882 | self.mu_val = mu_val 883 | self.sigma_val = np.maximum(sigma_val, -7.0) 884 | 885 | # self.TINY = 1e-2 886 | self.TINY = 1e-10 887 | 888 | # Always equal to the sum of 'frequencies' 889 | self.total = self.mul_factor + num_symbols 890 | 891 | # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive). 892 | # Initialized lazily. When it is not None, the data is valid. 893 | self.cumulative = None 894 | 895 | # def set_mu(self, mu_val): 896 | # self.mu_val = mu_val 897 | # 898 | # def set_sigma(self, sigma_val): 899 | # self.sigma_val = sigma_val 900 | 901 | # Returns the number of symbols in this frequency table, which is at least 1. 902 | def get_symbol_limit(self): 903 | return self.num_symbols 904 | 905 | # Returns the frequency of the given symbol. The returned value is at least 0. 906 | def get(self, symbol): 907 | self._check_symbol(symbol) 908 | 909 | if symbol == self.EOF: 910 | return 1 911 | else: 912 | c2 = scipy.special.expit(np.multiply((symbol + 0.5 - self.mu_val), (np.exp(-self.sigma_val) + self.TINY))) 913 | c1 = scipy.special.expit(np.multiply((symbol - 0.5 - self.mu_val), (np.exp(-self.sigma_val) + self.TINY))) 914 | freq = int(math.floor((c2 - c1) * self.mul_factor) + 1) 915 | # freq = int(math.floor(c2 * self.mul_factor) + 1) - int(math.floor(c1 * self.mul_factor)) 916 | 917 | return freq 918 | 919 | # Returns the total of all symbol frequencies. The returned value is at 920 | # least 0 and is always equal to get_high(get_symbol_limit() - 1). 921 | def get_total(self): 922 | return self.total 923 | 924 | # Returns the sum of the frequencies of all the symbols strictly 925 | # below the given symbol value. The returned value is at least 0. 926 | def get_low(self, symbol): 927 | self._check_symbol(symbol) 928 | c = scipy.special.expit(np.multiply((symbol - 1 + 0.5 - self.mu_val), (np.exp(-self.sigma_val) + self.TINY))) 929 | # c = int(math.floor(c * self.mul_factor)) 930 | c = int(math.floor(c * self.mul_factor) + symbol) 931 | return c 932 | 933 | # Returns the sum of the frequencies of the given symbol 934 | # and all the symbols below. The returned value is at least 0. 935 | def get_high(self, symbol): 936 | self._check_symbol(symbol) 937 | c = scipy.special.expit(np.multiply((symbol + 0.5 - self.mu_val), (np.exp(-self.sigma_val) + self.TINY))) 938 | # c = int(math.floor(c * self.mul_factor) + 1) 939 | c = int(math.floor(c * self.mul_factor) + symbol + 1) 940 | return c 941 | 942 | # Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception. 943 | def _check_symbol(self, symbol): 944 | if 0 <= symbol < self.get_symbol_limit(): 945 | return 946 | else: 947 | print(symbol) 948 | raise ValueError("Symbol out of range") 949 | 950 | # Returns a string representation of this frequency table, 951 | # useful for debugging only, and the format is subject to change. 
952 | def __str__(self): 953 | result = "" 954 | # for (i, freq) in enumerate(self.frequencies): 955 | # result += "{}\t{}\n".format(i, freq) 956 | return result -------------------------------------------------------------------------------- /func.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | # config = tf.ConfigProto(allow_soft_placement=True) 5 | # sess = tf.Session(config=config) 6 | # 7 | # batch_size = 3 8 | # Height = 128 9 | # Width = 128 10 | # Channel = 3 11 | # 12 | # def read_png(path): 13 | # 14 | # image_group = [] 15 | # 16 | # index = np.random.randint(1, 6) 17 | # # index = 2 18 | # 19 | # for i in range(index, index + 3): 20 | # 21 | # string = tf.read_file(path + '/im' + str(index) + '.png') 22 | # image = tf.image.decode_image(string, channels=3) 23 | # image = tf.cast(image, tf.float32) 24 | # image /= 255 25 | # 26 | # image_group.append(image) 27 | # 28 | # return tf.stack(image_group, axis=0) 29 | # 30 | # # with tf.device("/cpu:0"): 31 | # train_files = np.load('/scratch_net/maja_second/compression/models/folder_vimeo.npy').tolist() 32 | # 33 | # train_dataset = tf.data.Dataset.from_tensor_slices(train_files) 34 | # train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat() 35 | # train_dataset = train_dataset.map(read_png, 36 | # num_parallel_calls=16) 37 | # train_dataset = train_dataset.map(lambda x: tf.random_crop(x, (3, Height, Width, 3)), 38 | # num_parallel_calls=16) 39 | # train_dataset = train_dataset.batch(batch_size) 40 | # train_dataset = train_dataset.prefetch(32) 41 | # 42 | # data_tensor = train_dataset.make_one_shot_iterator().get_next() 43 | # data_tensor = tf.ensure_shape(data_tensor, (batch_size, 3, Height, Width, 3)) 44 | 45 | def channel_norm(R): 46 | 47 | # R_mean, R_var = tf.nn.moments(R, axes=[1, 2, 3]) 48 | # R_mean = R_mean[:, tf.newaxis, tf.newaxis, tf.newaxis] 49 | # R_var = R_var[:, tf.newaxis, tf.newaxis, tf.newaxis] 50 | # R_dev = tf.sqrt(R_var) 51 | 52 | R_mean = tf.reduce_mean(R, axis=[1, 2, 3]) 53 | R_mean = R_mean[:, tf.newaxis, tf.newaxis, tf.newaxis] 54 | R_dev = tf.sqrt(tf.reduce_mean(tf.square(R - R_mean), axis=[1, 2, 3])) 55 | R_dev = R_dev[:, tf.newaxis, tf.newaxis, tf.newaxis] 56 | 57 | R_norm = (R - R_mean) / (R_dev + 0.0000001) 58 | R1, R2 = tf.split(R_norm, 2, axis=-1) 59 | 60 | return R1, R2, R_mean, R_dev 61 | 62 | def input_norm(x): 63 | 64 | R = tf.concat([x[:, :, :, 0:1], x[:, :, :, 3:4]], axis=-1) 65 | R1, R2, _, _ = channel_norm(R) 66 | 67 | G = tf.concat([x[:, :, :, 1:2], x[:, :, :, 4:5]], axis=-1) 68 | G1, G2, _, _ = channel_norm(G) 69 | 70 | B = tf.concat([x[:, :, :, 2:3], x[:, :, :, 5:6]], axis=-1) 71 | B1, B2, _, _ = channel_norm(B) 72 | 73 | x_norm = tf.concat([R1, G1, B1, R2, G2, B2], axis=-1) 74 | 75 | return x_norm 76 | 77 | def input_norm_2(x): 78 | 79 | x_1, x_2 = tf.split(x, 2, axis=-1) 80 | x_new = tf.stack([x_1, x_2], axis=1) 81 | 82 | x_mean = tf.reduce_mean(x_new, axis=[1, 2, 3]) 83 | x_mean = x_mean[:, tf.newaxis, tf.newaxis, tf.newaxis, :] 84 | x_dev = tf.sqrt(tf.reduce_mean((x_new - x_mean) ** 2, axis=[1, 2, 3])) 85 | x_dev = x_dev[:, tf.newaxis, tf.newaxis, tf.newaxis, :] 86 | 87 | x_norm = (x_new - x_mean) / (x_dev + 0.0000001) 88 | x_1_norm, x_2_norm = tf.split(x_norm, 2, axis=1) 89 | 90 | x_1_norm = tf.squeeze(x_1_norm) 91 | x_2_norm = tf.squeeze(x_2_norm) 92 | 93 | return tf.concat([x_1_norm, x_2_norm], axis=-1) 94 | 95 | 96 | def input_norm_np(in_frames, batch): 97 | 98 | 
out_frames = np.copy(in_frames) 99 | 100 | for b in range(batch): 101 | for ch in range(3): 102 | xx = np.concatenate([in_frames[b, :, :, ch:ch + 1], in_frames[b, :, :, ch + 3:ch + 4]], axis=-1) 103 | 104 | R_m_np = np.mean(xx) 105 | R_dev_np = np.std(xx) 106 | 107 | xx = (xx - R_m_np) / R_dev_np 108 | 109 | xx_1, xx_2 = np.split(xx, 2, axis=-1) 110 | 111 | out_frames[b, :, :, ch:ch + 1] = xx_1 112 | out_frames[b, :, :, ch + 3:ch + 4] = xx_2 113 | 114 | return out_frames 115 | 116 | # [frame_in1, frame_out_gt, frame_in2] = tf.split(data_tensor, 3, axis=1) 117 | # input_frames = tf.squeeze(tf.concat([frame_in1, frame_in2], axis=-1)) 118 | # x_output = input_norm(input_frames) 119 | # x_output_2 = input_norm_2(input_frames) 120 | # 121 | # x_o, x_o_2, in_frames = sess.run([x_output, x_output_2, input_frames]) 122 | # 123 | # xx = input_norm_np(in_frames, batch_size) 124 | # 125 | # mse = np.mean(np.square(x_o - x_o_2)) 126 | # psnr = 10 * np.log10(1 / mse) 127 | # print(mse, psnr) 128 | # 129 | # mse = np.mean(np.square(x_o - xx)) 130 | # psnr = 10 * np.log10(1 / mse) 131 | # print(mse, psnr) 132 | # 133 | # mse = np.mean(np.square(xx - x_o_2)) 134 | # psnr = 10 * np.log10(1 / mse) 135 | # print(mse, psnr) 136 | 137 | 138 | 139 | class ConvLSTMCell(tf.nn.rnn_cell.RNNCell): 140 | """A LSTM cell with convolutions instead of multiplications. 141 | Reference: 142 | Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015. 143 | """ 144 | 145 | def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=False, peephole=False, data_format='channels_last', reuse=None): 146 | super(ConvLSTMCell, self).__init__(_reuse=reuse) 147 | self._kernel = kernel 148 | self._filters = filters 149 | self._forget_bias = forget_bias 150 | self._activation = activation 151 | self._normalize = normalize 152 | self._peephole = peephole 153 | if data_format == 'channels_last': 154 | self._size = tf.TensorShape(shape + [self._filters]) 155 | self._feature_axis = self._size.ndims 156 | self._data_format = None 157 | elif data_format == 'channels_first': 158 | self._size = tf.TensorShape([self._filters] + shape) 159 | self._feature_axis = 0 160 | self._data_format = 'NC' 161 | else: 162 | raise ValueError('Unknown data_format') 163 | 164 | @property 165 | def state_size(self): 166 | return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size) 167 | 168 | @property 169 | def output_size(self): 170 | return self._size 171 | 172 | def call(self, x, state): 173 | c, h = state 174 | 175 | x = tf.concat([x, h], axis=self._feature_axis) 176 | n = x.shape[-1].value 177 | m = 4 * self._filters if self._filters > 1 else 4 178 | W = tf.get_variable('kernel', self._kernel + [n, m]) 179 | y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format) 180 | if not self._normalize: 181 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 182 | j, i, f, o = tf.split(y, 4, axis=self._feature_axis) 183 | 184 | if self._peephole: 185 | i += tf.get_variable('W_ci', c.shape[1:]) * c 186 | f += tf.get_variable('W_cf', c.shape[1:]) * c 187 | 188 | if self._normalize: 189 | j = tf.contrib.layers.layer_norm(j) 190 | i = tf.contrib.layers.layer_norm(i) 191 | f = tf.contrib.layers.layer_norm(f) 192 | 193 | f = tf.sigmoid(f + self._forget_bias) 194 | i = tf.sigmoid(i) 195 | c = c * f + i * self._activation(j) 196 | 197 | if self._peephole: 198 | o += tf.get_variable('W_co', 
c.shape[1:]) * c 199 | 200 | if self._normalize: 201 | o = tf.contrib.layers.layer_norm(o) 202 | c = tf.contrib.layers.layer_norm(c) 203 | 204 | o = tf.sigmoid(o) 205 | h = o * self._activation(c) 206 | 207 | state = tf.nn.rnn_cell.LSTMStateTuple(c, h) 208 | 209 | return h, state 210 | 211 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def resblock(input, IC, OC, name, reuse=tf.AUTO_REUSE): 5 | 6 | l1 = tf.nn.relu(input, name=name + 'relu1') 7 | 8 | l1 = tf.layers.conv2d(inputs=l1, filters=np.minimum(IC, OC), kernel_size=3, strides=1, padding='same', 9 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l1', reuse=reuse) 10 | 11 | l2 = tf.nn.relu(l1, name='relu2') 12 | 13 | l2 = tf.layers.conv2d(inputs=l2, filters=OC, kernel_size=3, strides=1, padding='same', 14 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l2', reuse=reuse) 15 | 16 | if IC != OC: 17 | input = tf.layers.conv2d(inputs=input, filters=OC, kernel_size=1, strides=1, padding='same', 18 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'map', reuse=reuse) 19 | 20 | return input + l2 21 | 22 | 23 | def MC_RLVC(input, reuse=tf.AUTO_REUSE): 24 | 25 | m1 = tf.layers.conv2d(inputs=input, filters=64, kernel_size=3, strides=1, padding='same', 26 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc1', reuse=reuse) 27 | 28 | m2 = resblock(m1, 64, 64, name='mc2', reuse=reuse) 29 | 30 | m3 = tf.layers.average_pooling2d(m2, pool_size=2, strides=2, padding='same') 31 | 32 | m4 = resblock(m3, 64, 64, name='mc4', reuse=reuse) 33 | 34 | m5 = tf.layers.average_pooling2d(m4, pool_size=2, strides=2, padding='same') 35 | 36 | m6 = resblock(m5, 64, 64, name='mc6', reuse=reuse) 37 | 38 | m7 = resblock(m6, 64, 64, name='mc7', reuse=reuse) 39 | 40 | m8 = tf.image.resize_images(m7, [2 * tf.shape(m7)[1], 2 * tf.shape(m7)[2]]) 41 | 42 | m8 = m4 + m8 43 | 44 | m9 = resblock(m8, 64, 64, name='mc9', reuse=reuse) 45 | 46 | m10 = tf.image.resize_images(m9, [2 * tf.shape(m9)[1], 2 * tf.shape(m9)[2]]) 47 | 48 | m10 = m2 + m10 49 | 50 | m11 = resblock(m10, 64, 64, name='mc11', reuse=reuse) 51 | 52 | m12 = tf.layers.conv2d(inputs=m11, filters=64, kernel_size=3, strides=1, padding='same', 53 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc12', reuse=reuse) 54 | 55 | m12 = tf.nn.relu(m12, name='relu12') 56 | 57 | m13 = tf.layers.conv2d(inputs=m12, filters=3, kernel_size=3, strides=1, padding='same', 58 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc13', reuse=reuse) 59 | 60 | return m13 61 | 62 | 63 | def cnn_layers(tensor, layer, num_filters, out_filters, kernel, stride=2, uni=True, act=tf.nn.relu, act_last=None, reuse=tf.AUTO_REUSE): 64 | 65 | for l in range(layer-1): 66 | 67 | tensor = tf.layers.conv2d(inputs=tensor, filters=num_filters, kernel_size=kernel, padding='same', 68 | reuse=reuse, activation=act, strides=stride, 69 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=uni), name='cnn_' + str(l + 1)) 70 | 71 | tensor = tf.layers.conv2d(inputs=tensor, filters=out_filters, kernel_size=kernel, padding='same', 72 | reuse=reuse, activation=act_last, strides=stride, 73 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=uni), name='cnn_' + 
str(layer)) 74 | 75 | return tensor 76 | 77 | 78 | def dnn_layers(tensor, layer, num_filters, out_filters, kernel, stride=2, uni=True, act=tf.nn.relu, act_last=None, reuse=tf.AUTO_REUSE): 79 | 80 | for l in range(layer-1): 81 | 82 | tensor = tf.layers.conv2d_transpose(inputs=tensor, filters=num_filters, kernel_size=kernel, padding='same', 83 | reuse=reuse, activation=act, strides=stride, 84 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=uni), name='dnn_' + str(l + 1)) 85 | 86 | tensor = tf.layers.conv2d_transpose(inputs=tensor, filters=out_filters, kernel_size=kernel, padding='same', 87 | reuse=reuse, activation=act_last, strides=stride, 88 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=uni), name='dnn_' + str(layer)) 89 | 90 | return tensor 91 | 92 | 93 | def recurrent_cnn(tensor, step, layer, num_filters, out_filters, kernel, stride=2, uni=True, act=tf.nn.relu, act_last=None, reuse=tf.AUTO_REUSE): 94 | 95 | for i in range(step): 96 | 97 | tensor_i = tensor[:, i, :, :, :] 98 | tensor_i = cnn_layers(tensor_i, layer, num_filters, out_filters, kernel, stride, uni, act, act_last, reuse) 99 | 100 | if i == 0: 101 | tensor_out = tf.expand_dims(tensor_i, 1) 102 | else: 103 | tensor_out = tf.concat([tensor_out, tf.expand_dims(tensor_i, 1)], axis=1) 104 | 105 | return tensor_out 106 | 107 | 108 | def recurrent_dnn(tensor, step, layer, num_filters, out_filters, kernel, stride=2, uni=True, act=tf.nn.relu, act_last=None, reuse=tf.AUTO_REUSE): 109 | 110 | for i in range(step): 111 | 112 | tensor_i = tensor[:, i, :, :, :] 113 | tensor_i = dnn_layers(tensor_i, layer, num_filters, out_filters, kernel, stride, uni, act, act_last, reuse) 114 | 115 | if i == 0: 116 | tensor_out = tf.expand_dims(tensor_i, 1) 117 | else: 118 | tensor_out = tf.concat([tensor_out, tf.expand_dims(tensor_i, 1)], axis=1) 119 | 120 | return tensor_out 121 | 122 | 123 | class ConvLSTMCell(tf.nn.rnn_cell.RNNCell): 124 | """A LSTM cell with convolutions instead of multiplications. 125 | Reference: 126 | Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015. 
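In this implementation a single 'SAME' convolution over the channel-wise concatenation
of the input x and the hidden state h produces 4 * filters feature maps, which call()
splits into the candidate (j), input (i), forget (f) and output (o) gates; optional
peephole terms add elementwise W_c* * c to i, f and o, and optional layer
normalization is applied to the gate pre-activations and to the cell state.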
127 | """ 128 | 129 | def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=False, peephole=False, data_format='channels_last', reuse=None): 130 | super(ConvLSTMCell, self).__init__(_reuse=reuse) 131 | self._kernel = kernel 132 | self._filters = filters 133 | self._forget_bias = forget_bias 134 | self._activation = activation 135 | self._normalize = normalize 136 | self._peephole = peephole 137 | if data_format == 'channels_last': 138 | self._size = tf.TensorShape(shape + [self._filters]) 139 | self._feature_axis = self._size.ndims 140 | self._data_format = None 141 | elif data_format == 'channels_first': 142 | self._size = tf.TensorShape([self._filters] + shape) 143 | self._feature_axis = 0 144 | self._data_format = 'NC' 145 | else: 146 | raise ValueError('Unknown data_format') 147 | 148 | @property 149 | def state_size(self): 150 | return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size) 151 | 152 | @property 153 | def output_size(self): 154 | return self._size 155 | 156 | def call(self, x, state): 157 | c, h = state 158 | 159 | x = tf.concat([x, h], axis=self._feature_axis) 160 | n = x.shape[-1].value 161 | m = 4 * self._filters if self._filters > 1 else 4 162 | W = tf.get_variable('kernel', self._kernel + [n, m]) 163 | y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format) 164 | if not self._normalize: 165 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 166 | j, i, f, o = tf.split(y, 4, axis=self._feature_axis) 167 | 168 | if self._peephole: 169 | i += tf.get_variable('W_ci', c.shape[1:]) * c 170 | f += tf.get_variable('W_cf', c.shape[1:]) * c 171 | 172 | if self._normalize: 173 | j = tf.contrib.layers.layer_norm(j) 174 | i = tf.contrib.layers.layer_norm(i) 175 | f = tf.contrib.layers.layer_norm(f) 176 | 177 | f = tf.sigmoid(f + self._forget_bias) 178 | i = tf.sigmoid(i) 179 | c = c * f + i * self._activation(j) 180 | 181 | if self._peephole: 182 | o += tf.get_variable('W_co', c.shape[1:]) * c 183 | 184 | if self._normalize: 185 | o = tf.contrib.layers.layer_norm(o) 186 | c = tf.contrib.layers.layer_norm(c) 187 | 188 | o = tf.sigmoid(o) 189 | h = o * self._activation(c) 190 | 191 | state = tf.nn.rnn_cell.LSTMStateTuple(c, h) 192 | 193 | return h, state 194 | 195 | 196 | -------------------------------------------------------------------------------- /functions_inter.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_compression as tfc 3 | import sepconv_inter_enc as nn 4 | import sepconv_inter as nn_inter 5 | import motion 6 | import MC_network_inter 7 | import CNN_img 8 | import numpy as np 9 | 10 | def inter_short(frame_1, frame_2): 11 | 12 | frames = tf.concat([frame_1, frame_2], axis=-1) 13 | output_frame, _ = nn_inter.get_network_pp(frames, motion_flag='flow_mc') 14 | 15 | return output_frame 16 | 17 | def inter_long(ref_frames): 18 | 19 | [frame_in0, frame_in1, frame_in2, frame_in3] = tf.split(ref_frames, 4, axis=1) 20 | 21 | input_pair_1 = tf.squeeze(tf.concat([frame_in0, frame_in1], axis=-1), axis=1) 22 | input_pair_2 = tf.squeeze(tf.concat([frame_in1, frame_in2], axis=-1), axis=1) 23 | input_pair_3 = tf.squeeze(tf.concat([frame_in2, frame_in3], axis=-1), axis=1) 24 | input_all = tf.squeeze(tf.concat([frame_in0, frame_in1, frame_in2, frame_in3], axis=-1), axis=1) 25 | 26 | pool5_1, skip4_1, skip3_1, skip2_1 = nn.get_network_enc(input_pair_1, 'enc_1') 27 | pool5_2, skip4_2, skip3_2, skip2_2 = 
nn.get_network_enc(input_pair_2, 'enc_2') 28 | pool5_3, skip4_3, skip3_3, skip2_3 = nn.get_network_enc(input_pair_3, 'enc_3') 29 | 30 | pool5 = nn.conv_map(pool5_1, pool5_2, pool5_3, 512, 'map_5') 31 | skip4 = nn.conv_map(skip4_1, skip4_2, skip4_3, 256, 'map_4') 32 | skip3 = nn.conv_map(skip3_1, skip3_2, skip3_3, 128, 'map_3') 33 | skip2 = nn.conv_map(skip2_1, skip2_2, skip2_3, 64, 'map_2') 34 | 35 | output_frame = nn.get_network_dec(pool5, skip4, skip3, skip2, input_all) 36 | 37 | return output_frame 38 | 39 | def DVC_compress(Y0_com, Y1_raw, entropy_mv, entropy_res, batch_size, Height, Width, args, training=True): 40 | 41 | with tf.variable_scope("flow_motion"): 42 | 43 | flow_tensor, _, _, _, _, _ = motion.optical_flow(Y0_com, Y1_raw, batch_size, Height, Width) 44 | 45 | # Encode flow 46 | flow_latent = CNN_img.MV_analysis(flow_tensor, args.N, args.M) 47 | 48 | string_mv = entropy_mv.compress(flow_latent) 49 | # string_mv = tf.squeeze(string_mv, axis=0) 50 | 51 | flow_latent_hat, MV_likelihoods = entropy_mv(flow_latent, training=training) 52 | 53 | flow_hat = CNN_img.MV_synthesis(flow_latent_hat, args.N) 54 | 55 | # Motion Compensation 56 | Y1_warp = tf.contrib.image.dense_image_warp(Y0_com, flow_hat) 57 | 58 | MC_input = tf.concat([flow_hat, Y0_com, Y1_warp], axis=-1) 59 | Y1_MC = MC_network_inter.MC(MC_input) 60 | 61 | # Encode residual 62 | Res = Y1_raw - Y1_MC 63 | 64 | res_latent = CNN_img.Res_analysis(Res, num_filters=args.N, M=args.M) 65 | 66 | string_res = entropy_res.compress(res_latent) 67 | # string_res = tf.squeeze(string_res, axis=0) 68 | 69 | res_latent_hat, Res_likelihoods = entropy_res(res_latent, training=training) 70 | 71 | Res_hat = CNN_img.Res_synthesis(res_latent_hat, num_filters=args.N) 72 | 73 | # Reconstructed frame 74 | Y1_com = Res_hat + Y1_MC 75 | 76 | # Total number of bits divided by number of pixels. 77 | train_bpp_MV = tf.reduce_sum(tf.log(MV_likelihoods)) / (-np.log(2) * Height * Width * batch_size) 78 | train_bpp_Res = tf.reduce_sum(tf.log(Res_likelihoods)) / (-np.log(2) * Height * Width * batch_size) 79 | train_bpp = train_bpp_MV + train_bpp_Res 80 | 81 | # Mean squared error across pixels. 
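# (The rate terms above convert the summed log-likelihoods to bits by dividing by
# -log(2) and normalize by the number of pixels, giving an estimated bits-per-pixel
# for the motion and residual latents; the PSNR below assumes pixel values scaled to
# [0, 1], i.e. PSNR = 10 * log10(1 / MSE).)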
82 | # if args.mode == 'PSNR': 83 | total_mse = tf.reduce_mean(tf.squared_difference(Y1_com, Y1_raw)) 84 | psnr = 10.0*tf.log(1.0/total_mse)/tf.log(10.0) 85 | # else: 86 | # total_mse = 0 87 | # psnr = tf.math.reduce_mean(tf.image.ssim_multiscale(tf.clip_by_value(Y1_com, 0, 1), Y1_raw, max_val=1)) 88 | 89 | return Y1_com, total_mse, psnr, train_bpp, flow_latent_hat, res_latent_hat 90 | 91 | 92 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import os 4 | from scipy import misc 5 | from ms_ssim_np import MultiScaleSSIM 6 | import arithmeticcoding 7 | 8 | def configure(args, path_root, path_raw): 9 | 10 | if args.l == 256: 11 | I_level = 37 12 | elif args.l == 512: 13 | I_level = 32 14 | elif args.l == 1024: 15 | I_level = 27 16 | elif args.l == 2048: 17 | I_level = 22 18 | 19 | elif args.l == 8: 20 | I_level = 3 21 | elif args.l == 16: 22 | I_level = 4 23 | elif args.l == 32: 24 | I_level = 5 25 | elif args.l == 64: 26 | I_level = 6 27 | 28 | path = args.path + '/' 29 | 30 | # if args.mode == 'MS-SSIM': 31 | # path_com = path_root + args.path + '_SSIM_' + str(args.l) + '/frames/' 32 | # path_bin = path_root + args.path + '_SSIM_' + str(args.l) + '/bitstreams/' 33 | # path_lat = path_root + args.path + '_SSIM_' + str(args.l) + '/latents/' 34 | # else: 35 | path_com = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/frames/' 36 | path_bin = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/RD_results/' 37 | path_lat = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/latents/' 38 | 39 | os.makedirs(path_com, exist_ok=True) 40 | os.makedirs(path_bin, exist_ok=True) 41 | os.makedirs(path_lat, exist_ok=True) 42 | 43 | F1 = misc.imread(path_raw + 'f001.png') 44 | Height = np.size(F1, 0) 45 | Width = np.size(F1, 1) 46 | batch_size = 1 47 | Channel = 3 48 | 49 | if (Height % 16 != 0) or (Width % 16 != 0): 50 | raise ValueError('Height and Width must be a mutiple of 16.') 51 | 52 | activation = tf.nn.relu 53 | 54 | GOP_size = args.f_P + args.b_P + 1 + args.inter 55 | GOP_num = int(np.floor((args.frame - 1)/GOP_size)) 56 | 57 | return I_level, Height, Width, batch_size, \ 58 | Channel, activation, GOP_size, GOP_num, \ 59 | path, path_com, path_bin, path_lat 60 | 61 | def configure_gan(args): 62 | 63 | I_level = args.quality 64 | 65 | path = args.path 66 | path_video = '/srv/beegfs-benderdata/scratch/reyang_data/data/RLVC/GAN_results_inter/' + args.path + '_' + str(I_level) + '_0.001_rc1' 67 | 68 | # path_com = './GAN_nrec/' + args.path + '_' + str(I_level) \ 69 | # + '_' + str(args.w_g) + '_rc' + str(args.rc) + '/frames/' 70 | # path_bin = './GAN_nrec/' + args.path + '_' + str(I_level) \ 71 | # + '_' + str(args.w_g) + '_rc' + str(args.rc) + '/bitstreams/' 72 | # path_lat = './GAN_nrec/' + args.path + '_' + str(I_level) \ 73 | # + '_' + str(args.w_g) + '_rc' + str(args.rc) + '/latents/' 74 | 75 | path_com = path_video + '/frames/' 76 | path_bin = path_video + '/bitstreams/' 77 | path_lat = path_video + '/latents/' 78 | 79 | os.makedirs(path_com, exist_ok=True) 80 | os.makedirs(path_bin, exist_ok=True) 81 | os.makedirs(path_lat, exist_ok=True) 82 | 83 | F1 = misc.imread('/srv/beegfs-benderdata/scratch/reyang_data/data/RLVC/' + path + '/f001.png') 84 | Height = np.size(F1, 0) 85 | Width = np.size(F1, 1) 86 | batch_size = 1 87 | Channel = 3 88 | 89 | if (Height % 16 != 0) or (Width 
% 16 != 0): 90 | raise ValueError('Height and Width must be a mutiple of 16.') 91 | 92 | activation = tf.nn.relu 93 | 94 | GOP_size = args.f_P + args.b_P + args.inter + 1 95 | GOP_num = int(np.floor((args.frame - 1)/GOP_size)) 96 | 97 | return I_level, Height, Width, batch_size, \ 98 | Channel, activation, GOP_size, GOP_num, \ 99 | path, path_com, path_bin, path_lat 100 | 101 | def configure_decoder(args): 102 | 103 | path = '/srv/beegfs02/scratch/reyang_data/data/origCfP/' + args.path + '/' 104 | path_com = './results/' + args.path + '_' + args.mode + '_' + str(args.l) + '/frames_dec/' 105 | path_bin = './results/' + args.path + '_' + args.mode + '_' + str(args.l) + '/bitstreams/' 106 | path_lat = './results/' + args.path + '_' + args.mode + '_' + str(args.l) + '/latents_dec/' 107 | 108 | os.makedirs(path_com, exist_ok=True) 109 | os.makedirs(path_lat, exist_ok=True) 110 | 111 | activation = tf.nn.relu 112 | 113 | GOP_size = args.f_P + args.b_P + 1 114 | GOP_num = int(np.floor((args.frame - 1)/GOP_size)) 115 | 116 | return activation, GOP_size, GOP_num, \ 117 | path, path_com, path_bin, path_lat 118 | 119 | 120 | def encode_I(args, frame_index, I_level, path, path_com, path_bin): 121 | 122 | if args.mode == 'PSNR': 123 | 124 | if args.VTM == 1: 125 | 126 | F1 = misc.imread(path + '/f001.png') 127 | Height = np.size(F1, 0) 128 | Width = np.size(F1, 1) 129 | 130 | # path_yuv = './RLVC_VTM/' + args.path + '_' + args.mode + '_' + str(args.l) + '/frames/' 131 | # 132 | # if not os.path.exists(path_com + 'f' + str(frame_index).zfill(3) + '.yuv'): 133 | # 134 | os.system('ffmpeg -i ' + path + 'f' + str(frame_index).zfill(3) + '.png ' 135 | '-pix_fmt yuv444p ' + path + 'f' + str(frame_index).zfill(3) + '.yuv -y -loglevel error') 136 | os.system( 137 | '/scratch_net/maja_second/VVCSoftware_VTM/bin/EncoderAppStatic -c /scratch_net/maja_second/VVCSoftware_VTM/encoder_intra_vtm.cfg ' 138 | '-i ' + path + 'f' + str(frame_index).zfill(3) + '.yuv -b ' + path_bin + 'f' + str(frame_index).zfill(3) + '.bin ' 139 | '-o ' + path_com + 'f' + str(frame_index).zfill(3) + '.yuv -f 1 -fr 2 -wdt ' + str(Width) + ' -hgt ' + str(Height) + 140 | ' -q ' + str(I_level) + ' --InputBitDepth=8 --OutputBitDepth=8 --OutputBitDepthC=8 --InputChromaFormat=444 > /dev/null') 141 | 142 | # os.system('cp ' + './RLVC_VTM_extra_7/' + args.path + '_' + args.mode + '_' + str(args.l) + '/bitstreams/' + 'f' + str(frame_index).zfill(3) + '.bin ' + 143 | # path_bin + 'f' + str(frame_index).zfill(3) + '.bin') 144 | # os.system('cp ' + './RLVC_VTM_extra_7/' + args.path + '_' + args.mode + '_' + str(args.l) + '/frames/' + 'f' + str(frame_index).zfill(3) + '.png ' + 145 | # path_com + 'f' + str(frame_index).zfill(3) + '.png') 146 | os.system( 147 | 'ffmpeg -f rawvideo -pix_fmt yuv444p -s ' + str(Width) + 'x' + str(Height) + 148 | ' -i ' + path_com + 'f' + str(frame_index).zfill(3) + '.yuv ' 149 | + path_com + 'f' + str(frame_index).zfill(3) + '.png -y -loglevel error') 150 | 151 | else: 152 | os.system('bpgenc -f 444 -m 9 ' + path + 'f' + str(frame_index).zfill(3) + '.png ' 153 | '-o ' + path_bin + 'f' + str(frame_index).zfill(3) + '.bin -q ' + str(I_level)) 154 | os.system('bpgdec ' + path_bin + 'f' + str(frame_index).zfill(3) + '.bin ' 155 | '-o ' + path_com + 'f' + str(frame_index).zfill(3) + '.png') 156 | 157 | elif args.mode == 'MS-SSIM': 158 | os.system(args.python_path + ' ' + args.CA_model_path + '/encode.py --model_type 1 ' 159 | '--input_path ' + path + 'f' + str(frame_index).zfill(3) + '.png' + 160 | ' --compressed_file_path ' + 
path_bin + 'f' + str(frame_index).zfill(3) + '.bin' 161 | + ' --quality_level ' + str(I_level)) 162 | os.system(args.python_path + ' ' + args.CA_model_path + '/decode.py --compressed_file_path ' 163 | + path_bin + 'f' + str(frame_index).zfill(3) + '.bin' 164 | + ' --recon_path ' + path_com + 'f' + str(frame_index).zfill(3) + '.png') 165 | 166 | # bits = os.path.getsize(path_bin + str(frame_index).zfill(3) + '.bin') 167 | # bits = bits * 8 168 | 169 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 170 | F0_raw = misc.imread(path + 'f' + str(frame_index).zfill(3) + '.png') 171 | 172 | F0_com = np.expand_dims(F0_com, axis=0) 173 | F0_raw = np.expand_dims(F0_raw, axis=0) 174 | 175 | if args.metric == 'PSNR': 176 | mse = np.mean(np.power(np.subtract(F0_com / 255.0, F0_raw / 255.0), 2.0)) 177 | quality = 10 * np.log10(1.0 / mse) 178 | elif args.metric == 'MS-SSIM': 179 | quality = MultiScaleSSIM(F0_com, F0_raw, max_val=255) 180 | 181 | print('Frame', frame_index, args.metric + ' =', quality) 182 | 183 | return quality 184 | 185 | def hific_I(args, frame_index, I_level, path, path_com, path_bin): 186 | 187 | # python = '/scratch_net/maja_second/miniconda3/envs/env_cpu_3/bin/python3.6' 188 | # os.system(python + ' tfci.py compress hific-' + I_level + ' ' + path + ' ' + path_bin) 189 | # os.system(python + ' tfci.py decompress ' + path_bin + ' ' + path_com) 190 | # 191 | path_hific = '/scratch_net/maja_second/compression/models/hific_results/' 192 | # 193 | # if I_level == 'hhi': 194 | # I_level = 'hi' 195 | 196 | os.system('cp ' + path_hific + path + '_' + str(I_level) + '/f' + str(frame_index).zfill(3) + '.tfci ' + path_bin) 197 | os.system('cp ' + path_hific + path + '_' + str(I_level) + '/f' + str(frame_index).zfill(3) + '.png ' + path_com) 198 | 199 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 200 | F0_raw = misc.imread(path + '/f' + str(frame_index).zfill(3) + '.png') 201 | 202 | F0_com = np.expand_dims(F0_com, axis=0) 203 | F0_raw = np.expand_dims(F0_raw, axis=0) 204 | 205 | if args.metric == 'PSNR': 206 | mse = np.mean(np.power(np.subtract(F0_com / 255.0, F0_raw / 255.0), 2.0)) 207 | quality = 10 * np.log10(1.0 / mse) 208 | elif args.metric == 'MS-SSIM': 209 | quality = MultiScaleSSIM(F0_com, F0_raw, max_val=255) 210 | 211 | bits = os.path.getsize(path_bin + '/f' + str(frame_index).zfill(3) + '.tfci') 212 | bits = bits * 8 / np.size(F0_com, 1) / np.size(F0_com, 2) 213 | 214 | print('Frame', frame_index, 'bpp =', bits, args.metric + ' =', quality) 215 | 216 | return quality, bits 217 | 218 | def decode_I(args, frame_index, path_com, path_bin): 219 | 220 | if args.mode == 'PSNR': 221 | os.system('bpgdec ' + path_bin + 'f' + str(frame_index).zfill(3) + '.bin ' 222 | '-o ' + path_com + 'f' + str(frame_index).zfill(3) + '.png') 223 | 224 | elif args.mode == 'MS-SSIM': 225 | os.system(args.python_path + ' ' + args.CA_model_path + '/decode.py --compressed_file_path ' 226 | + path_bin + 'f' + str(frame_index).zfill(3) + '.bin' 227 | + ' --recon_path ' + path_com + 'f' + str(frame_index).zfill(3) + '.png') 228 | 229 | print('Decoded I-frame', frame_index) 230 | 231 | 232 | def entropy_coding(frame_index, lat, path_bin, latent, sigma, mu): 233 | 234 | if lat == 'mv': 235 | bias = 50 236 | else: 237 | bias = 100 238 | 239 | bin_name = 'f' + str(frame_index).zfill(3) + '_' + lat + '.bin' 240 | bitout = arithmeticcoding.BitOutputStream(open(path_bin + bin_name, "wb")) 241 | enc = arithmeticcoding.ArithmeticEncoder(32, bitout) 242 | 243 | for h in 
range(latent.shape[1]): 244 | for w in range(latent.shape[2]): 245 | for ch in range(latent.shape[3]): 246 | mu_val = mu[0, h, w, ch] + bias 247 | sigma_val = sigma[0, h, w, ch] 248 | symbol = latent[0, h, w, ch] + bias 249 | 250 | freq = arithmeticcoding.logFrequencyTable_exp(mu_val, sigma_val, np.int(bias * 2 + 1)) 251 | enc.write(freq, symbol) 252 | 253 | enc.finish() 254 | bitout.close() 255 | 256 | bits_value = os.path.getsize(path_bin + bin_name) * 8 257 | 258 | return bits_value 259 | 260 | 261 | def entropy_decoding(frame_index, lat, path_bin, path_lat, sigma, mu): 262 | 263 | if lat == 'mv': 264 | bias = 50 265 | else: 266 | bias = 100 267 | 268 | bin_name = 'f' + str(frame_index).zfill(3) + '_' + lat + '.bin' 269 | bitin = arithmeticcoding.BitInputStream(open(path_bin + bin_name, "rb")) 270 | dec = arithmeticcoding.ArithmeticDecoder(32, bitin) 271 | 272 | latent = np.zeros([1, mu.shape[1], mu.shape[2], mu.shape[3]]) 273 | 274 | for h in range(mu.shape[1]): 275 | for w in range(mu.shape[2]): 276 | for ch in range(mu.shape[3]): 277 | 278 | mu_val = mu[0, h, w, ch] + bias 279 | sigma_val = sigma[0, h, w, ch] 280 | 281 | freq = arithmeticcoding.logFrequencyTable_exp(mu_val, sigma_val, np.int(bias * 2 + 1)) 282 | symbol = dec.read(freq) 283 | latent[0, h, w, ch] = symbol - bias 284 | 285 | bitin.close() 286 | 287 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy', latent) 288 | print('Decoded latent_' + lat + ' frame', frame_index) 289 | 290 | return latent 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /helper2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import os 4 | from scipy import misc 5 | from ms_ssim_np import MultiScaleSSIM 6 | import arithmeticcoding 7 | 8 | def configure(args): 9 | 10 | if args.l == 256: 11 | I_level = 37 12 | elif args.l == 512: 13 | I_level = 32 14 | elif args.l == 1024: 15 | I_level = 27 16 | elif args.l == 2048: 17 | I_level = 22 18 | 19 | elif args.l == 8: 20 | I_level = 3 21 | elif args.l == 16: 22 | I_level = 4 23 | elif args.l == 32: 24 | I_level = 5 25 | elif args.l == 64: 26 | I_level = 6 27 | 28 | path = args.path + '/' 29 | 30 | path_root = './' 31 | 32 | # if args.mode == 'MS-SSIM': 33 | # path_com = path_root + args.path + '_SSIM_' + str(args.l) + '/frames/' 34 | # path_bin = path_root + args.path + '_SSIM_' + str(args.l) + '/bitstreams/' 35 | # path_lat = path_root + args.path + '_SSIM_' + str(args.l) + '/latents/' 36 | # else: 37 | path_com = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/frames/' 38 | path_bin = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/RD_results/' 39 | path_lat = path_root + args.path + '_' + args.mode + '_' + str(args.l) + '/latents/' 40 | 41 | os.makedirs(path_com, exist_ok=True) 42 | os.makedirs(path_bin, exist_ok=True) 43 | os.makedirs(path_lat, exist_ok=True) 44 | 45 | F1 = misc.imread(path + 'f001.png') 46 | Height = np.size(F1, 0) 47 | Width = np.size(F1, 1) 48 | batch_size = 1 49 | Channel = 3 50 | 51 | if (Height % 16 != 0) or (Width % 16 != 0): 52 | raise ValueError('Height and Width must be a mutiple of 16.') 53 | 54 | activation = tf.nn.relu 55 | 56 | GOP_size = args.f_P + args.b_P + 1 + args.inter 57 | GOP_num = int(np.floor((args.frame - 1)/GOP_size)) 58 | 59 | return I_level, Height, Width, batch_size, \ 60 | Channel, activation, GOP_size, GOP_num, \ 61 | path, path_com, path_bin, 
path_lat 62 | 63 | 64 | def encode_I(args, frame_index, I_level, path, path_com, path_bin): 65 | 66 | F1 = misc.imread(path + '/f001.png') 67 | Height = np.size(F1, 0) 68 | Width = np.size(F1, 1) 69 | 70 | os.system('ffmpeg -i ' + path + 'f' + str(frame_index).zfill(3) + '.png ' 71 | '-pix_fmt yuv444p ' + path + 'f' + str(frame_index).zfill(3) + '.yuv -y -loglevel error') 72 | os.system( 73 | './VVCSoftware_VTM/bin/EncoderAppStatic -c ./VVCSoftware_VTM/cfg/encoder_intra_vtm.cfg ' 74 | '-i ' + path + 'f' + str(frame_index).zfill(3) + '.yuv -b ' + path_bin + 'f' + str(frame_index).zfill(3) + '.bin ' 75 | '-o ' + path_com + 'f' + str(frame_index).zfill(3) + '.yuv -f 1 -fr 2 -wdt ' + str(Width) + ' -hgt ' + str(Height) + 76 | ' -q ' + str(I_level) + ' --InputBitDepth=8 --OutputBitDepth=8 --OutputBitDepthC=8 --InputChromaFormat=444 > /dev/null') 77 | 78 | os.system( 79 | 'ffmpeg -f rawvideo -pix_fmt yuv444p -s ' + str(Width) + 'x' + str(Height) + 80 | ' -i ' + path_com + 'f' + str(frame_index).zfill(3) + '.yuv ' 81 | + path_com + 'f' + str(frame_index).zfill(3) + '.png -y -loglevel error') 82 | 83 | 84 | F0_com = misc.imread(path_com + 'f' + str(frame_index).zfill(3) + '.png') 85 | F0_raw = misc.imread(path + 'f' + str(frame_index).zfill(3) + '.png') 86 | 87 | F0_com = np.expand_dims(F0_com, axis=0) 88 | F0_raw = np.expand_dims(F0_raw, axis=0) 89 | 90 | # if args.metric == 'PSNR': 91 | mse = np.mean(np.power(np.subtract(F0_com / 255.0, F0_raw / 255.0), 2.0)) 92 | quality = 10 * np.log10(1.0 / mse) 93 | # elif args.metric == 'MS-SSIM': 94 | # quality = MultiScaleSSIM(F0_com, F0_raw, max_val=255) 95 | 96 | print('Frame', frame_index, args.metric + ' =', quality) 97 | 98 | return quality 99 | 100 | 101 | def entropy_coding(frame_index, lat, path_bin, latent, sigma, mu): 102 | 103 | if lat == 'mv': 104 | bias = 50 105 | else: 106 | bias = 100 107 | 108 | bin_name = 'f' + str(frame_index).zfill(3) + '_' + lat + '.bin' 109 | bitout = arithmeticcoding.BitOutputStream(open(path_bin + bin_name, "wb")) 110 | enc = arithmeticcoding.ArithmeticEncoder(32, bitout) 111 | 112 | for h in range(latent.shape[1]): 113 | for w in range(latent.shape[2]): 114 | for ch in range(latent.shape[3]): 115 | mu_val = mu[0, h, w, ch] + bias 116 | sigma_val = sigma[0, h, w, ch] 117 | symbol = latent[0, h, w, ch] + bias 118 | 119 | freq = arithmeticcoding.logFrequencyTable_exp(mu_val, sigma_val, np.int(bias * 2 + 1)) 120 | enc.write(freq, symbol) 121 | 122 | enc.finish() 123 | bitout.close() 124 | 125 | bits_value = os.path.getsize(path_bin + bin_name) * 8 126 | 127 | return bits_value 128 | 129 | 130 | def entropy_decoding(frame_index, lat, path_bin, path_lat, sigma, mu): 131 | 132 | if lat == 'mv': 133 | bias = 50 134 | else: 135 | bias = 100 136 | 137 | bin_name = 'f' + str(frame_index).zfill(3) + '_' + lat + '.bin' 138 | bitin = arithmeticcoding.BitInputStream(open(path_bin + bin_name, "rb")) 139 | dec = arithmeticcoding.ArithmeticDecoder(32, bitin) 140 | 141 | latent = np.zeros([1, mu.shape[1], mu.shape[2], mu.shape[3]]) 142 | 143 | for h in range(mu.shape[1]): 144 | for w in range(mu.shape[2]): 145 | for ch in range(mu.shape[3]): 146 | 147 | mu_val = mu[0, h, w, ch] + bias 148 | sigma_val = sigma[0, h, w, ch] 149 | 150 | freq = arithmeticcoding.logFrequencyTable_exp(mu_val, sigma_val, np.int(bias * 2 + 1)) 151 | symbol = dec.read(freq) 152 | latent[0, h, w, ch] = symbol - bias 153 | 154 | bitin.close() 155 | 156 | np.save(path_lat + '/f' + str(frame_index).zfill(3) + '_' + lat + '.npy', latent) 157 | print('Decoded 
latent_' + lat + ' frame', frame_index) 158 | 159 | return latent 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /inv_flow.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def convnet(im1_warp, im2, flow, layer): 5 | 6 | with tf.variable_scope("flow_cnn_" + str(layer), reuse=tf.AUTO_REUSE): 7 | 8 | input = tf.concat([im1_warp, im2, flow], axis=-1) 9 | 10 | conv1 = tf.layers.conv2d(inputs=input, filters=32, kernel_size=[7, 7], padding="same", 11 | activation=tf.nn.relu) 12 | conv2 = tf.layers.conv2d(inputs=conv1, filters=64, kernel_size=[7, 7], padding="same", 13 | activation=tf.nn.relu) 14 | conv3 = tf.layers.conv2d(inputs=conv2, filters=32, kernel_size=[7, 7], padding="same", 15 | activation=tf.nn.relu) 16 | conv4 = tf.layers.conv2d(inputs=conv3, filters=16, kernel_size=[7, 7], padding="same", 17 | activation=tf.nn.relu) 18 | conv5 = tf.layers.conv2d(inputs=conv4, filters=2 , kernel_size=[7, 7], padding="same", 19 | activation=None) 20 | 21 | return conv5 22 | 23 | 24 | def loss(flow_course, im1, im2, layer): 25 | 26 | flow = tf.image.resize_images(flow_course, [tf.shape(im1)[1], tf.shape(im2)[2]]) 27 | im1_warped = tf.contrib.image.dense_image_warp(im1, flow) 28 | res = convnet(im1_warped, im2, flow, layer) 29 | flow_fine = res + flow 30 | 31 | im1_warped_fine = tf.contrib.image.dense_image_warp(im1, flow_fine) 32 | loss_layer = tf.reduce_mean(tf.squared_difference(im1_warped_fine, im2)) 33 | 34 | return loss_layer, flow_fine 35 | 36 | 37 | def optical_flow(im1_4, im2_4, batch, h, w): 38 | 39 | im1_3 = tf.layers.average_pooling2d(im1_4, pool_size=2, strides=2, padding='same') 40 | im1_2 = tf.layers.average_pooling2d(im1_3, pool_size=2, strides=2, padding='same') 41 | im1_1 = tf.layers.average_pooling2d(im1_2, pool_size=2, strides=2, padding='same') 42 | im1_0 = tf.layers.average_pooling2d(im1_1, pool_size=2, strides=2, padding='same') 43 | 44 | im2_3 = tf.layers.average_pooling2d(im2_4, pool_size=2, strides=2, padding='same') 45 | im2_2 = tf.layers.average_pooling2d(im2_3, pool_size=2, strides=2, padding='same') 46 | im2_1 = tf.layers.average_pooling2d(im2_2, pool_size=2, strides=2, padding='same') 47 | im2_0 = tf.layers.average_pooling2d(im2_1, pool_size=2, strides=2, padding='same') 48 | 49 | flow_zero = tf.zeros([batch, h//16, w//16, 2]) 50 | 51 | loss_0, flow_0 = loss(flow_zero, im1_0, im2_0, 0) 52 | loss_1, flow_1 = loss(flow_0, im1_1, im2_1, 1) 53 | loss_2, flow_2 = loss(flow_1, im1_2, im2_2, 2) 54 | loss_3, flow_3 = loss(flow_2, im1_3, im2_3, 3) 55 | loss_4, flow_4 = loss(flow_3, im1_4, im2_4, 4) 56 | 57 | return flow_4, loss_0, loss_1, loss_2, loss_3, loss_4 58 | 59 | 60 | def reverse_sample(x_shift, y_shift, h, w, weight): 61 | 62 | x, y = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij') 63 | 64 | x = tf.cast(tf.expand_dims(x, -1), tf.float32) 65 | y = tf.cast(tf.expand_dims(y, -1), tf.float32) 66 | 67 | x -= x_shift 68 | y -= y_shift 69 | 70 | x = tf.clip_by_value(x, 0, h - 1) 71 | y = tf.clip_by_value(y, 0, w - 1) 72 | 73 | grid1 = tf.concat([x, y], axis=-1) 74 | grid1 = tf.cast(grid1, tf.int32) 75 | 76 | tf_zeros = tf.zeros([h, w, 1, 1], tf.int32) 77 | indices = tf.expand_dims(grid1, 2) 78 | indices = tf.concat([indices, tf_zeros], axis=-1) 79 | 80 | ref_x = tf.Variable(tf.zeros([h, w, 1], np.float32), trainable=False, dtype=tf.float32) 81 | ref_y = tf.Variable(tf.zeros([h, w, 1], np.float32), trainable=False, 
dtype=tf.float32) 82 | ref_w = tf.Variable(tf.zeros([h, w, 1], np.float32) + 1e-9, trainable=False, dtype=tf.float32) 83 | 84 | inv_flow_x = tf.scatter_nd_update(ref_x, indices, -x_shift * weight) 85 | inv_flow_y = tf.scatter_nd_update(ref_y, indices, -y_shift * weight) 86 | 87 | inv_flow_batch = tf.expand_dims(tf.concat([inv_flow_x, inv_flow_y], axis=-1), axis=0) 88 | 89 | weight_x = tf.scatter_nd_update(ref_w, indices, weight) 90 | 91 | weight_batch = tf.expand_dims(weight_x, axis=0) 92 | 93 | return inv_flow_batch, weight_batch 94 | 95 | def reverse_flow(flow_input, h, w): 96 | 97 | flow_list = tf.unstack(flow_input) 98 | 99 | inv_flow = [] 100 | 101 | for flow in flow_list: 102 | 103 | x_flow, y_flow = tf.split(flow, [1, 1], axis=-1) 104 | x_1 = tf.floor(x_flow) 105 | x_2 = x_1 + 1 106 | y_1 = tf.floor(y_flow) 107 | y_2 = y_1 + 1 108 | 109 | weight_1 = tf.exp(-((x_flow - x_1) ** 2 + (y_flow - y_1) ** 2)) 110 | weight_2 = tf.exp(-((x_flow - x_1) ** 2 + (y_flow - y_2) ** 2)) 111 | weight_3 = tf.exp(-((x_flow - x_2) ** 2 + (y_flow - y_1) ** 2)) 112 | weight_4 = tf.exp(-((x_flow - x_2) ** 2 + (y_flow - y_2) ** 2)) 113 | 114 | inv_flow_1, norm_1 = reverse_sample(x_1, y_1, h, w, weight_1) 115 | inv_flow_2, norm_2 = reverse_sample(x_1, y_2, h, w, weight_2) 116 | inv_flow_3, norm_3 = reverse_sample(x_2, y_1, h, w, weight_3) 117 | inv_flow_4, norm_4 = reverse_sample(x_2, y_2, h, w, weight_4) 118 | 119 | inv_flow_batch = inv_flow_1 + inv_flow_2 + inv_flow_3 + inv_flow_4 120 | norm_batch = norm_1 + norm_2 + norm_3 + norm_4 121 | 122 | inv_flow_norm = tf.divide(inv_flow_batch, norm_batch) 123 | inv_flow.append(inv_flow_norm) 124 | 125 | inv_flow = tf.concat(inv_flow, axis=0) 126 | 127 | return inv_flow 128 | -------------------------------------------------------------------------------- /mc_func.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def resblock(input, IC, OC, name, reuse=tf.AUTO_REUSE): 5 | 6 | l1 = tf.nn.relu(input, name=name + 'relu1') 7 | 8 | l1 = tf.layers.conv2d(inputs=l1, filters=np.minimum(IC, OC), kernel_size=3, strides=1, padding='same', 9 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l1', reuse=reuse) 10 | 11 | l2 = tf.nn.relu(l1, name='relu2') 12 | 13 | l2 = tf.layers.conv2d(inputs=l2, filters=OC, kernel_size=3, strides=1, padding='same', 14 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'l2', reuse=reuse) 15 | 16 | if IC != OC: 17 | input = tf.layers.conv2d(inputs=input, filters=OC, kernel_size=1, strides=1, padding='same', 18 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name=name + 'map', reuse=reuse) 19 | 20 | return input + l2 21 | 22 | def resblock_init(input, IC, OC, name, reuse=tf.AUTO_REUSE): 23 | 24 | l1 = tf.nn.relu(input, name=name + 'relu1') 25 | 26 | l1 = tf.layers.conv2d(inputs=l1, filters=np.minimum(IC, OC), kernel_size=3, strides=1, padding='same', name=name + 'l1', reuse=reuse) 27 | 28 | l2 = tf.nn.relu(l1, name='relu2') 29 | 30 | l2 = tf.layers.conv2d(inputs=l2, filters=OC, kernel_size=3, strides=1, padding='same', name=name + 'l2', reuse=reuse) 31 | 32 | if IC != OC: 33 | input = tf.layers.conv2d(inputs=input, filters=OC, kernel_size=1, strides=1, padding='same', name=name + 'map', reuse=reuse) 34 | 35 | return input + l2 36 | 37 | def MC_RLVC(input, reuse=tf.AUTO_REUSE): 38 | 39 | m1 = tf.layers.conv2d(inputs=input, filters=64, kernel_size=3, strides=1, 
padding='same', 40 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc1', reuse=reuse) 41 | 42 | m2 = resblock(m1, 64, 64, name='mc2', reuse=reuse) 43 | 44 | m3 = tf.layers.average_pooling2d(m2, pool_size=2, strides=2, padding='same') 45 | 46 | m4 = resblock(m3, 64, 64, name='mc4', reuse=reuse) 47 | 48 | m5 = tf.layers.average_pooling2d(m4, pool_size=2, strides=2, padding='same') 49 | 50 | m6 = resblock(m5, 64, 64, name='mc6', reuse=reuse) 51 | 52 | m7 = resblock(m6, 64, 64, name='mc7', reuse=reuse) 53 | 54 | m8 = tf.image.resize_images(m7, [2 * tf.shape(m7)[1], 2 * tf.shape(m7)[2]]) 55 | 56 | m8 = m4 + m8 57 | 58 | m9 = resblock(m8, 64, 64, name='mc9', reuse=reuse) 59 | 60 | m10 = tf.image.resize_images(m9, [2 * tf.shape(m9)[1], 2 * tf.shape(m9)[2]]) 61 | 62 | m10 = m2 + m10 63 | 64 | m11 = resblock(m10, 64, 64, name='mc11', reuse=reuse) 65 | 66 | m12 = tf.layers.conv2d(inputs=m11, filters=64, kernel_size=3, strides=1, padding='same', 67 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc12', reuse=reuse) 68 | 69 | m12 = tf.nn.relu(m12, name='relu12') 70 | 71 | m13 = tf.layers.conv2d(inputs=m12, filters=3, kernel_size=3, strides=1, padding='same', 72 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc13', reuse=reuse) 73 | 74 | return m13 75 | 76 | def MC_RLVC_init(input, reuse=tf.AUTO_REUSE): 77 | 78 | m1 = tf.layers.conv2d(inputs=input, filters=64, kernel_size=3, strides=1, padding='same', name='mc1', reuse=reuse) 79 | 80 | m2 = resblock_init(m1, 64, 64, name='mc2', reuse=reuse) 81 | 82 | m3 = tf.layers.average_pooling2d(m2, pool_size=2, strides=2, padding='same') 83 | 84 | m4 = resblock_init(m3, 64, 64, name='mc4', reuse=reuse) 85 | 86 | m5 = tf.layers.average_pooling2d(m4, pool_size=2, strides=2, padding='same') 87 | 88 | m6 = resblock_init(m5, 64, 64, name='mc6', reuse=reuse) 89 | 90 | m7 = resblock_init(m6, 64, 64, name='mc7', reuse=reuse) 91 | 92 | m8 = tf.image.resize_images(m7, [2 * tf.shape(m7)[1], 2 * tf.shape(m7)[2]]) 93 | 94 | m8 = m4 + m8 95 | 96 | m9 = resblock_init(m8, 64, 64, name='mc9', reuse=reuse) 97 | 98 | m10 = tf.image.resize_images(m9, [2 * tf.shape(m9)[1], 2 * tf.shape(m9)[2]]) 99 | 100 | m10 = m2 + m10 101 | 102 | m11 = resblock_init(m10, 64, 64, name='mc11', reuse=reuse) 103 | 104 | m12 = tf.layers.conv2d(inputs=m11, filters=64, kernel_size=3, strides=1, padding='same', name='mc12', reuse=reuse) 105 | 106 | m12 = tf.nn.relu(m12, name='relu12') 107 | 108 | m13 = tf.layers.conv2d(inputs=m12, filters=3, kernel_size=3, strides=1, padding='same', name='mc13', reuse=reuse) 109 | 110 | return m13 111 | 112 | def MC_light(input, filter_num = 32, out_filter=3, name='post_light', reuse=tf.AUTO_REUSE): 113 | 114 | with tf.variable_scope(name, reuse=reuse): 115 | 116 | m1 = tf.layers.conv2d(inputs=input, filters=filter_num, kernel_size=3, strides=1, padding='same', 117 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc1') 118 | 119 | m2 = resblock(m1, filter_num, filter_num, name='mc2') 120 | 121 | # m3 = tf.layers.average_pooling2d(m2, pool_size=2, strides=2, padding='same') 122 | # 123 | # m4 = resblock(m3, filter_num, filter_num, name='mc4') 124 | # 125 | # m9 = resblock(m4, filter_num, filter_num, name='mc9') 126 | # 127 | # m10 = tf.image.resize_images(m9, [2 * tf.shape(m9)[1], 2 * tf.shape(m9)[2]]) 128 | # 129 | # m10 = m2 + m10 130 | 131 | m11 = resblock(m2, filter_num, filter_num, name='mc11') + m1 132 | 133 | m12 = tf.layers.conv2d(inputs=m11, 
filters=filter_num, kernel_size=3, strides=1, padding='same', 134 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc12') 135 | 136 | m12 = tf.nn.relu(m12, name='relu12') 137 | 138 | m13 = tf.layers.conv2d(inputs=m12, filters=out_filter, kernel_size=3, strides=1, padding='same', 139 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=True), name='mc13') 140 | 141 | return m13 142 | 143 | 144 | def refine_net(x, out_channel): 145 | def parametric_relu(_x, name='alpha'): 146 | alphas = tf.get_variable(name, _x.get_shape()[-1], 147 | initializer=tf.constant_initializer(0.25), 148 | dtype=tf.float32) 149 | pos = tf.nn.relu(_x) 150 | neg = alphas * (_x - abs(_x)) * 0.5 151 | 152 | return pos + neg 153 | 154 | def layers(tensor, down=False, up=False, filters=64, layer_num=3, f_size=3): 155 | if down: 156 | tensor = tf.layers.average_pooling2d(tensor, 2, 2, padding='same') 157 | if up: 158 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 159 | for i in range(layer_num): 160 | tensor = tf.layers.conv2d(tensor, filters, f_size, activation=parametric_relu, padding='same') 161 | 162 | return tensor 163 | 164 | with tf.variable_scope('unet', None, [x]): 165 | 166 | with tf.variable_scope('encoder', None, [x]): 167 | with tf.variable_scope('downscale_1', None, [x]): 168 | pool1 = layers(x, down=False, up=False, filters=32, layer_num=2, f_size=7) 169 | 170 | with tf.variable_scope('downscale_2', None, [pool1]): 171 | pool2 = layers(pool1, down=True, up=False, filters=64, layer_num=2, f_size=5) 172 | 173 | with tf.variable_scope('downscale_3', None, [pool2]): 174 | pool3 = layers(pool2, down=True, up=False, filters=128, layer_num=2) 175 | 176 | with tf.variable_scope('downscale_4', None, [pool3]): 177 | pool4 = layers(pool3, down=True, up=False, filters=256, layer_num=2) 178 | 179 | with tf.variable_scope('downscale_5', None, [pool4]): 180 | pool5 = layers(pool4, down=True, up=False, filters=512, layer_num=3) 181 | 182 | with tf.variable_scope('decoder', None, [pool5, pool4, pool3, pool2, pool1]): 183 | 184 | with tf.variable_scope('upscale_4', None, [pool5, pool4]): 185 | up4 = layers(pool5, down=False, up=True, filters=256, layer_num=1) 186 | up4 = tf.concat([up4, pool4], axis=-1) 187 | up4 = tf.layers.conv2d(up4, 256, 3, activation=parametric_relu, padding='same') 188 | 189 | with tf.variable_scope('upscale_3', None, [up4, pool3]): 190 | up3 = layers(up4, down=False, up=True, filters=128, layer_num=1) 191 | up3 = tf.concat([up3, pool3], axis=-1) 192 | up3 = tf.layers.conv2d(up3, 128, 3, activation=parametric_relu, padding='same') 193 | 194 | with tf.variable_scope('upscale_2', None, [up3, pool2]): 195 | up2 = layers(up3, down=False, up=True, filters=64, layer_num=1) 196 | up2 = tf.concat([up2, pool2], axis=-1) 197 | up2 = tf.layers.conv2d(up2, 64, 3, activation=parametric_relu, padding='same') 198 | 199 | with tf.variable_scope('upscale_1', None, [up2, pool1]): 200 | up1 = layers(up2, down=False, up=True, filters=32, layer_num=1) 201 | up1 = tf.concat([up1, pool1], axis=-1) 202 | up1 = tf.layers.conv2d(up1, 32, 3, activation=parametric_relu, padding='same') 203 | 204 | with tf.variable_scope('output', None, [up1]): 205 | 206 | output = tf.layers.conv2d(up1, out_channel, 3, padding='same') 207 | 208 | return output, up1 209 | 210 | -------------------------------------------------------------------------------- /motion.py: -------------------------------------------------------------------------------- 1 | import 
tensorflow as tf 2 | import numpy as np 3 | 4 | def convnet(im1_warp, im2, flow, layer): 5 | 6 | with tf.variable_scope("flow_cnn_" + str(layer), reuse=tf.AUTO_REUSE): 7 | 8 | input = tf.concat([im1_warp, im2, flow], axis=-1) 9 | 10 | conv1 = tf.layers.conv2d(inputs=input, filters=32, kernel_size=[7, 7], padding="same", 11 | activation=tf.nn.relu) 12 | conv2 = tf.layers.conv2d(inputs=conv1, filters=64, kernel_size=[7, 7], padding="same", 13 | activation=tf.nn.relu) 14 | conv3 = tf.layers.conv2d(inputs=conv2, filters=32, kernel_size=[7, 7], padding="same", 15 | activation=tf.nn.relu) 16 | conv4 = tf.layers.conv2d(inputs=conv3, filters=16, kernel_size=[7, 7], padding="same", 17 | activation=tf.nn.relu) 18 | conv5 = tf.layers.conv2d(inputs=conv4, filters=2 , kernel_size=[7, 7], padding="same", 19 | activation=None) 20 | 21 | return conv5 22 | 23 | 24 | def loss(flow_course, im1, im2, layer): 25 | 26 | flow = tf.image.resize_images(flow_course, [tf.shape(im1)[1], tf.shape(im2)[2]]) 27 | im1_warped = tf.contrib.image.dense_image_warp(im1, flow) 28 | res = convnet(im1_warped, im2, flow, layer) 29 | flow_fine = res + flow 30 | 31 | im1_warped_fine = tf.contrib.image.dense_image_warp(im1, flow_fine) 32 | loss_layer = tf.reduce_mean(tf.squared_difference(im1_warped_fine, im2)) 33 | 34 | return loss_layer, flow_fine 35 | 36 | 37 | def optical_flow(im1_4, im2_4, batch, h, w): 38 | 39 | im1_3 = tf.layers.average_pooling2d(im1_4, pool_size=2, strides=2, padding='same') 40 | im1_2 = tf.layers.average_pooling2d(im1_3, pool_size=2, strides=2, padding='same') 41 | im1_1 = tf.layers.average_pooling2d(im1_2, pool_size=2, strides=2, padding='same') 42 | im1_0 = tf.layers.average_pooling2d(im1_1, pool_size=2, strides=2, padding='same') 43 | 44 | im2_3 = tf.layers.average_pooling2d(im2_4, pool_size=2, strides=2, padding='same') 45 | im2_2 = tf.layers.average_pooling2d(im2_3, pool_size=2, strides=2, padding='same') 46 | im2_1 = tf.layers.average_pooling2d(im2_2, pool_size=2, strides=2, padding='same') 47 | im2_0 = tf.layers.average_pooling2d(im2_1, pool_size=2, strides=2, padding='same') 48 | 49 | flow_zero = tf.zeros([batch, h//16, w//16, 2]) 50 | 51 | loss_0, flow_0 = loss(flow_zero, im1_0, im2_0, 0) 52 | loss_1, flow_1 = loss(flow_0, im1_1, im2_1, 1) 53 | loss_2, flow_2 = loss(flow_1, im1_2, im2_2, 2) 54 | loss_3, flow_3 = loss(flow_2, im1_3, im2_3, 3) 55 | loss_4, flow_4 = loss(flow_3, im1_4, im2_4, 4) 56 | 57 | return flow_4, loss_0, loss_1, loss_2, loss_3, loss_4 58 | 59 | def tf_inverse_flow(flow_input, b, h, w): 60 | 61 | # x = vertical (channel 0 in flow), y = horizontal (channel 1 in flow) 62 | flow_list = tf.unstack(flow_input) 63 | 64 | x, y = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij') 65 | 66 | x = tf.expand_dims(x, -1) 67 | y = tf.expand_dims(y, -1) 68 | 69 | grid = tf.cast(tf.concat([x, y], axis=-1), tf.float32) 70 | 71 | for r in range(b): 72 | 73 | flow = flow_list[r] 74 | grid1 = grid - flow 75 | 76 | x1, y1 = tf.split(grid1, [1, 1], axis=-1) 77 | x1 = tf.clip_by_value(x1, 0, h - 1) 78 | y1 = tf.clip_by_value(y1, 0, w - 1) 79 | grid1 = tf.concat([x1, y1], axis=-1) 80 | grid1 = tf.cast(grid1, tf.int32) 81 | 82 | tf_zeros = tf.zeros([h, w, 1, 1], tf.int32) 83 | indices = tf.expand_dims(grid1, 2) 84 | indices = tf.concat([indices, tf_zeros], axis=-1) 85 | 86 | flow_x, flow_y = tf.split(flow, [1, 1], axis=-1) 87 | 88 | ref_x = tf.Variable(np.zeros([h, w, 1], np.float32), trainable=False, dtype=tf.float32) 89 | ref_y = tf.Variable(np.zeros([h, w, 1], np.float32), trainable=False, 
dtype=tf.float32) 90 | inv_flow_x = tf.scatter_nd_update(ref_x, indices, -flow_x) 91 | inv_flow_y = tf.scatter_nd_update(ref_y, indices, -flow_y) 92 | inv_flow_batch = tf.expand_dims(tf.concat([inv_flow_x, inv_flow_y], axis=-1), axis=0) 93 | 94 | if r == 0: 95 | inv_flow = inv_flow_batch 96 | else: 97 | inv_flow = tf.concat([inv_flow, inv_flow_batch], axis=0) 98 | 99 | return inv_flow 100 | 101 | def reverse_sample(x_shift, y_shift, h, w, weight): 102 | 103 | x, y = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij') 104 | 105 | x = tf.cast(tf.expand_dims(x, -1), tf.float32) 106 | y = tf.cast(tf.expand_dims(y, -1), tf.float32) 107 | 108 | x -= x_shift 109 | y -= y_shift 110 | 111 | x = tf.clip_by_value(x, 0, h - 1) 112 | y = tf.clip_by_value(y, 0, w - 1) 113 | 114 | grid1 = tf.concat([x, y], axis=-1) 115 | grid1 = tf.cast(grid1, tf.int32) 116 | 117 | tf_zeros = tf.zeros([h, w, 1, 1], tf.int32) 118 | indices = tf.expand_dims(grid1, 2) 119 | indices = tf.concat([indices, tf_zeros], axis=-1) 120 | 121 | ref_x = tf.Variable(tf.zeros([h, w, 1], np.float32), trainable=False, dtype=tf.float32) 122 | ref_y = tf.Variable(tf.zeros([h, w, 1], np.float32), trainable=False, dtype=tf.float32) 123 | ref_w = tf.Variable(tf.zeros([h, w, 1], np.float32) + 1e-9, trainable=False, dtype=tf.float32) 124 | 125 | inv_flow_x = tf.scatter_nd_update(ref_x, indices, -x_shift * weight) 126 | inv_flow_y = tf.scatter_nd_update(ref_y, indices, -y_shift * weight) 127 | 128 | inv_flow_batch = tf.expand_dims(tf.concat([inv_flow_x, inv_flow_y], axis=-1), axis=0) 129 | 130 | weight_x = tf.scatter_nd_update(ref_w, indices, weight) 131 | 132 | weight_batch = tf.expand_dims(weight_x, axis=0) 133 | 134 | return inv_flow_batch, weight_batch 135 | 136 | def reverse_flow(flow_input, h, w): 137 | 138 | flow_list = tf.unstack(flow_input) 139 | 140 | inv_flow = [] 141 | 142 | for flow in flow_list: 143 | 144 | x_flow, y_flow = tf.split(flow, [1, 1], axis=-1) 145 | x_1 = tf.floor(x_flow) 146 | x_2 = x_1 + 1 147 | y_1 = tf.floor(y_flow) 148 | y_2 = y_1 + 1 149 | 150 | weight_1 = tf.exp(-((x_flow - x_1) ** 2 + (y_flow - y_1) ** 2)) 151 | weight_2 = tf.exp(-((x_flow - x_1) ** 2 + (y_flow - y_2) ** 2)) 152 | weight_3 = tf.exp(-((x_flow - x_2) ** 2 + (y_flow - y_1) ** 2)) 153 | weight_4 = tf.exp(-((x_flow - x_2) ** 2 + (y_flow - y_2) ** 2)) 154 | 155 | inv_flow_1, norm_1 = reverse_sample(x_1, y_1, h, w, weight_1) 156 | inv_flow_2, norm_2 = reverse_sample(x_1, y_2, h, w, weight_2) 157 | inv_flow_3, norm_3 = reverse_sample(x_2, y_1, h, w, weight_3) 158 | inv_flow_4, norm_4 = reverse_sample(x_2, y_2, h, w, weight_4) 159 | 160 | inv_flow_batch = inv_flow_1 + inv_flow_2 + inv_flow_3 + inv_flow_4 161 | norm_batch = norm_1 + norm_2 + norm_3 + norm_4 162 | 163 | inv_flow_norm = tf.divide(inv_flow_batch, norm_batch) 164 | inv_flow.append(inv_flow_norm) 165 | 166 | inv_flow = tf.concat(inv_flow, axis=0) 167 | 168 | return inv_flow 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /ms_ssim_np.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | # Adapted with tf_msssim_np, removed main() function 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from scipy import signal 22 | from scipy.ndimage.filters import convolve 23 | 24 | 25 | def tf_msssim_np(img1, img2, data_format='NHWC'): 26 | assert img1.shape.ndims == img2.shape.ndims == 4, 'Expected {}, got {}'.format(data_format, img1.shape, img2.shape) 27 | assert tf.uint8.is_compatible_with(img1.dtype), 'Expected uint8 intput' 28 | assert tf.uint8.is_compatible_with(img2.dtype), 'Expected uint8 intput' 29 | 30 | if data_format == 'NCHW': 31 | def make_NHWC(x): 32 | return tf.transpose(x, (0, 2, 3, 1), name='make_NHWC') 33 | return tf_msssim_np(make_NHWC(img1), make_NHWC(img2), data_format='NHWC') 34 | 35 | assert img1.shape[3] == 3, 'Expected 3-channel images, got {}'.format(img1) 36 | 37 | with tf.name_scope('ms-ssim_np'): 38 | v = tf.py_func(_calc_msssim_orig, [img1, img2], tf.float32, stateful=False, name='MS-SSIM') 39 | v.set_shape(()) 40 | return v 41 | 42 | 43 | def _calc_msssim_orig(img1, img2): 44 | v = MultiScaleSSIM(img1, img2, max_val=255) 45 | if np.isnan(v): 46 | print(img1[0, :15, :15, 0]) 47 | print(img2[0, :15, :15, 0]) 48 | return np.float32(v) 49 | 50 | 51 | def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, 52 | k1=0.01, k2=0.03, weights=None): 53 | """Return the MS-SSIM score between `img1` and `img2`. 54 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 55 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 56 | similarity for image quality assessment" (2003). 57 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 58 | Author's MATLAB implementation: 59 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 60 | Arguments: 61 | img1: Numpy array holding the first RGB image batch. 62 | img2: Numpy array holding the second RGB image batch. 63 | max_val: the dynamic range of the images (i.e., the difference between the 64 | maximum the and minimum allowed values). 65 | filter_size: Size of blur kernel to use (will be reduced for small images). 66 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 67 | for small images). 68 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 69 | the original paper). 70 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 71 | the original paper). 72 | weights: List of weights for each level; if none, use five levels and the 73 | weights from the original paper. 74 | Returns: 75 | MS-SSIM score between `img1` and `img2`. 76 | Raises: 77 | RuntimeError: If input images don't have the same shape or don't have four 78 | dimensions: [batch_size, height, width, depth]. 79 | """ 80 | if img1.shape != img2.shape: 81 | raise RuntimeError('Input images must have the same shape (%s vs. 
%s).', 82 | img1.shape, img2.shape) 83 | if img1.ndim != 4: 84 | raise RuntimeError('Input images must have four dimensions, not %d', 85 | img1.ndim) 86 | 87 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 88 | weights = np.array(weights if weights else 89 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 90 | levels = weights.size 91 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 92 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 93 | mssim = np.array([]) 94 | mcs = np.array([]) 95 | for _ in range(levels): 96 | ssim, cs = _SSIMForMultiScale( 97 | im1, im2, max_val=max_val, filter_size=filter_size, 98 | filter_sigma=filter_sigma, k1=k1, k2=k2) 99 | mssim = np.append(mssim, ssim) 100 | mcs = np.append(mcs, cs) 101 | filtered = [convolve(im, downsample_filter, mode='reflect') 102 | for im in [im1, im2]] 103 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 104 | return (np.prod(mcs[0:levels - 1] ** weights[0:levels - 1]) * 105 | (mssim[levels - 1] ** weights[levels - 1])) 106 | 107 | 108 | def _FSpecialGauss(size, sigma): 109 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 110 | radius = size // 2 111 | offset = 0.0 112 | start, stop = -radius, radius + 1 113 | if size % 2 == 0: 114 | offset = 0.5 115 | stop -= 1 116 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 117 | assert len(x) == size 118 | g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) 119 | return g / g.sum() 120 | 121 | 122 | def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, 123 | filter_sigma=1.5, k1=0.01, k2=0.03): 124 | """Return the Structural Similarity Map between `img1` and `img2`. 125 | This function attempts to match the functionality of ssim_index_new.m by 126 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 127 | Arguments: 128 | img1: Numpy array holding the first RGB image batch. 129 | img2: Numpy array holding the second RGB image batch. 130 | max_val: the dynamic range of the images (i.e., the difference between the 131 | maximum the and minimum allowed values). 132 | filter_size: Size of blur kernel to use (will be reduced for small images). 133 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 134 | for small images). 135 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 136 | the original paper). 137 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 138 | the original paper). 139 | Returns: 140 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 141 | `img2`. 142 | Raises: 143 | RuntimeError: If input images don't have the same shape or don't have four 144 | dimensions: [batch_size, height, width, depth]. 145 | """ 146 | if img1.shape != img2.shape: 147 | raise RuntimeError('Input images must have the same shape (%s vs. %s).', 148 | img1.shape, img2.shape) 149 | if img1.ndim != 4: 150 | raise RuntimeError('Input images must have four dimensions, not %d', 151 | img1.ndim) 152 | 153 | img1 = img1.astype(np.float64) 154 | img2 = img2.astype(np.float64) 155 | _, height, width, _ = img1.shape 156 | 157 | # Filter size can't be larger than height or width of images. 158 | size = min(filter_size, height, width) 159 | 160 | # Scale down sigma if a smaller filter size is used. 
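# (sigma is reduced in the same ratio as the window, so the Gaussian keeps its
#  relative width; a filter_size of 0 yields sigma = 0 and skips filtering via the
#  branch below.)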
161 | sigma = size * filter_sigma / filter_size if filter_size else 0 162 | 163 | if filter_size: 164 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 165 | mu1 = signal.fftconvolve(img1, window, mode='valid') 166 | mu2 = signal.fftconvolve(img2, window, mode='valid') 167 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 168 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 169 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 170 | else: 171 | # Empty blur kernel so no need to convolve. 172 | mu1, mu2 = img1, img2 173 | sigma11 = img1 * img1 174 | sigma22 = img2 * img2 175 | sigma12 = img1 * img2 176 | 177 | mu11 = mu1 * mu1 178 | mu22 = mu2 * mu2 179 | mu12 = mu1 * mu2 180 | sigma11 -= mu11 181 | sigma22 -= mu22 182 | sigma12 -= mu12 183 | 184 | # Calculate intermediate values used by both ssim and cs_map. 185 | c1 = (k1 * max_val) ** 2 186 | c2 = (k2 * max_val) ** 2 187 | v1 = 2.0 * sigma12 + c2 188 | v2 = sigma11 + sigma22 + c2 189 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 190 | cs = np.mean(v1 / v2) 191 | return ssim, cs -------------------------------------------------------------------------------- /rec_exp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # import networks._tf as _tf 3 | # from networks.ops.gpu_ops import SEPCONV_MODULE 4 | from func import * 5 | # from pretrained import * 6 | import functions 7 | 8 | def get_network_pp(x, state_enc, state_dec, state_feat, motion_flag='flow_mc', in_norm=0): 9 | 10 | def parametric_relu(_x, name='alpha'): 11 | alphas = tf.get_variable(name, _x.get_shape()[-1], 12 | initializer=tf.constant_initializer(0.25), 13 | dtype=tf.float32) 14 | pos = tf.nn.relu(_x) 15 | neg = alphas * (_x - abs(_x)) * 0.5 16 | 17 | return pos + neg 18 | 19 | def one_step_rnn(tensor, state, num_filters=128, kernel=3, act=parametric_relu): 20 | 21 | tensor = tf.expand_dims(tensor, axis=1) 22 | # print(tensor.shape.as_list()[2:4]) 23 | cell = ConvLSTMCell(shape=tensor.shape.as_list()[2:4], activation=act, 24 | filters=num_filters, kernel=[kernel, kernel]) 25 | 26 | tensor, state = tf.nn.dynamic_rnn(cell, tensor, initial_state=state, dtype=tensor.dtype) 27 | tensor = tf.squeeze(tensor, axis=1) 28 | 29 | return tensor, state 30 | 31 | def sepconv(tensor, kh, kv): 32 | 33 | t_shape = tf.shape(tensor) 34 | 35 | image_patches = tf.reshape(tf.image.extract_image_patches( 36 | tensor, ksizes=[1, 51, 51, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME'), 37 | (t_shape[0], t_shape[1], t_shape[2], 51, 51, t_shape[3])) 38 | 39 | frame = tf.reduce_sum(tf.reduce_sum(image_patches * tf.expand_dims(tf.expand_dims(kh, -2), -1) 40 | * tf.expand_dims(tf.expand_dims(kv, -1), -1), axis=-2), axis=-2) 41 | 42 | return frame 43 | 44 | def layers(tensor, down=False, up=False, filters=64, layer_num=3): 45 | if down: 46 | tensor = tf.layers.conv2d(tensor, filters, 3, strides=(2, 2), activation=parametric_relu, padding='same') 47 | if up: 48 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 49 | for i in range(layer_num): 50 | tensor = tf.layers.conv2d(tensor, filters, 3, activation=parametric_relu, padding='same') 51 | 52 | return tensor 53 | 54 | def resblock(tensor, filters, num=2): 55 | 56 | for i in range(num): 57 | 58 | l1 = tf.layers.conv2d(inputs=tensor, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 59 | l2 = 
tf.layers.conv2d(inputs=l1, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 60 | tensor += l2 61 | 62 | return tensor 63 | 64 | def subnet(tensor, out_filter=51): 65 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 66 | tensor = tf.layers.conv2d(tensor, 64, 3, activation=parametric_relu, padding='same') 67 | tensor = tf.layers.conv2d(tensor, out_filter, 3, padding='same') 68 | 69 | return tensor 70 | 71 | with tf.variable_scope('unet', None, [x, state_enc, state_dec, state_feat], reuse=tf.AUTO_REUSE): 72 | 73 | with tf.variable_scope('encoder', None, [x, state_enc, state_feat]): 74 | with tf.variable_scope('downscale_1', None, [x]): 75 | 76 | if in_norm == 1: 77 | 78 | if motion_flag == 'flow': 79 | 80 | x_in, flow_in = tf.split(x, [6, 4], axis=-1) 81 | x_in = input_norm(x) 82 | # print('input_norm') 83 | x_in = tf.concat([x_in, flow_in], axis=-1) 84 | 85 | else: 86 | x_in = input_norm(x) 87 | # print('input_norm') 88 | else: 89 | x_in = x 90 | # print('w/o input_norm') 91 | 92 | pool1 = layers(x_in, down=False, up=False, filters=32, layer_num=3) 93 | 94 | with tf.variable_scope('downscale_2', None, [pool1]): 95 | pool2 = layers(pool1, down=True, up=False, filters=64, layer_num=1) 96 | 97 | with tf.variable_scope('downscale_3', None, [pool2]): 98 | pool3 = layers(pool2, down=True, up=False, filters=128, layer_num=1) 99 | 100 | with tf.variable_scope('rec_enc', None, [pool3, state_enc]): 101 | pool3_rec, state_enc = one_step_rnn(pool3, state_enc) 102 | 103 | with tf.variable_scope('downscale_4', None, [pool3_rec]): 104 | pool4 = layers(pool3_rec, down=True, up=False, filters=256, layer_num=1) 105 | 106 | with tf.variable_scope('downscale_5', None, [pool4, state_feat]): 107 | pool5 = layers(pool4, down=True, up=False, filters=512, layer_num=1) 108 | pool5_rec, state_feat = one_step_rnn(pool5, state_feat, num_filters=512) 109 | pool5 = layers(pool5_rec, down=False, up=False, filters=512, layer_num=1) 110 | 111 | with tf.variable_scope('decoder', None, [pool5, pool4, pool3, pool2, pool1, state_dec]): 112 | 113 | with tf.variable_scope('upscale_4', None, [pool5, pool4]): 114 | up4 = layers(pool5, down=False, up=True, filters=256, layer_num=2) 115 | up4 += resblock(pool4, filters=256, num=1) 116 | 117 | with tf.variable_scope('upscale_3', None, [up4, pool3]): 118 | up3 = layers(up4, down=False, up=True, filters=128, layer_num=2) 119 | up3 += resblock(pool3, filters=128, num=1) 120 | 121 | with tf.variable_scope('rec_dec', None, [up3, state_dec]): 122 | up3_rec, state_dec = one_step_rnn(up3, state_dec) 123 | 124 | with tf.variable_scope('upscale_2', None, [up3_rec, pool2]): 125 | up2 = layers(up3_rec, down=False, up=True, filters=64, layer_num=2) 126 | up2 += resblock(pool2, filters=64, num=1) 127 | 128 | if motion_flag == 'sepconv': 129 | 130 | with tf.variable_scope('sepconv', None, [up2, x]): 131 | 132 | with tf.variable_scope('frame_1', None, [up2]): 133 | kv_1 = subnet(up2) 134 | kh_1 = subnet(up2) 135 | 136 | with tf.variable_scope('frame_2', None, [up2]): 137 | kv_2 = subnet(up2) 138 | kh_2 = subnet(up2) 139 | 140 | frame_1 = sepconv(x[:, :, :, :3], kv_1, kh_1) 141 | frame_2 = sepconv(x[:, :, :, 3:], kv_2, kh_2) 142 | 143 | output = frame_1 + frame_2 144 | 145 | else: 146 | 147 | with tf.variable_scope('flow', None, [up2, x]): 148 | 149 | with tf.variable_scope('frame_1', None, [up2]): 150 | flow_mask_1 = subnet(up2, out_filter=3) 151 | flow_1, mask_1 = tf.split(flow_mask_1, [2, 1], axis=-1) 152 | with 
tf.variable_scope('frame_2', None, [up2]): 153 | flow_mask_2 = subnet(up2, out_filter=3) 154 | flow_2, mask_2 = tf.split(flow_mask_2, [2, 1], axis=-1) 155 | 156 | frame_1 = mask_1 * tf.contrib.image.dense_image_warp(x[:, :, :, 0:3], flow_1) 157 | frame_2 = mask_2 * tf.contrib.image.dense_image_warp(x[:, :, :, 3:6], flow_2) 158 | 159 | if motion_flag == 'flow': 160 | output = frame_1 + frame_2 161 | else: 162 | with tf.variable_scope('refine', None, [x, flow_1, flow_2, mask_1, mask_2, frame_1, frame_2]): 163 | 164 | input_to_refine = tf.concat([x[:, :, :, 0:3], x[:, :, :, 3:6], 165 | flow_1, flow_2, mask_1, mask_2, 166 | frame_1, frame_2], axis=-1) 167 | 168 | output = frame_1 + frame_2 + functions.MC_RLVC(input_to_refine) 169 | 170 | if motion_flag == 'sepconv': 171 | 172 | return output, state_enc, state_dec, state_feat 173 | 174 | else: 175 | 176 | return output, state_enc, state_dec, state_feat, tf.concat([flow_1, flow_2], axis=-1) 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /sepconv_inter.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # import networks._tf as _tf 3 | # from networks.ops.gpu_ops import SEPCONV_MODULE 4 | from func import * 5 | # from pretrained import * 6 | import mc_func 7 | 8 | def get_network_pp(x, motion_flag='flow_mc'): 9 | 10 | def parametric_relu(_x, name='alpha'): 11 | alphas = tf.get_variable(name, _x.get_shape()[-1], 12 | initializer=tf.constant_initializer(0.25), 13 | dtype=tf.float32) 14 | pos = tf.nn.relu(_x) 15 | neg = alphas * (_x - abs(_x)) * 0.5 16 | 17 | return pos + neg 18 | 19 | def one_step_rnn(tensor, state, num_filters=128, kernel=3, act=parametric_relu): 20 | 21 | tensor = tf.expand_dims(tensor, axis=1) 22 | # print(tensor.shape.as_list()[2:4]) 23 | cell = ConvLSTMCell(shape=tensor.shape.as_list()[2:4], activation=act, 24 | filters=num_filters, kernel=[kernel, kernel]) 25 | 26 | tensor, state = tf.nn.dynamic_rnn(cell, tensor, initial_state=state, dtype=tensor.dtype) 27 | tensor = tf.squeeze(tensor, axis=1) 28 | 29 | return tensor, state 30 | 31 | def sepconv(tensor, kh, kv): 32 | 33 | t_shape = tf.shape(tensor) 34 | 35 | image_patches = tf.reshape(tf.image.extract_image_patches( 36 | tensor, ksizes=[1, 51, 51, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME'), 37 | (t_shape[0], t_shape[1], t_shape[2], 51, 51, t_shape[3])) 38 | 39 | frame = tf.reduce_sum(tf.reduce_sum(image_patches * tf.expand_dims(tf.expand_dims(kh, -2), -1) 40 | * tf.expand_dims(tf.expand_dims(kv, -1), -1), axis=-2), axis=-2) 41 | 42 | return frame 43 | 44 | def layers(tensor, down=False, up=False, filters=64, layer_num=3): 45 | if down: 46 | tensor = tf.layers.conv2d(tensor, filters, 3, strides=(2, 2), activation=parametric_relu, padding='same') 47 | if up: 48 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 49 | for i in range(layer_num): 50 | tensor = tf.layers.conv2d(tensor, filters, 3, activation=parametric_relu, padding='same') 51 | 52 | return tensor 53 | 54 | def resblock(tensor, filters, num=2): 55 | 56 | for i in range(num): 57 | 58 | l1 = tf.layers.conv2d(inputs=tensor, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 59 | l2 = tf.layers.conv2d(inputs=l1, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 60 | tensor += l2 61 | 62 | return tensor 63 | 64 | def subnet(tensor, filter_num=64, out_filter=3): 65 | 
tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 66 | tensor = tf.layers.conv2d(tensor, filter_num, 3, activation=parametric_relu, padding='same') 67 | tensor = tf.layers.conv2d(tensor, out_filter, 3, padding='same') 68 | 69 | return tensor 70 | 71 | with tf.variable_scope('unet', None, [x], reuse=tf.AUTO_REUSE): 72 | 73 | with tf.variable_scope('encoder', None, [x]): 74 | with tf.variable_scope('downscale_1', None, [x]): 75 | pool1 = layers(x, down=False, up=False, filters=32, layer_num=3) 76 | 77 | with tf.variable_scope('downscale_2', None, [pool1]): 78 | pool2 = layers(pool1, down=True, up=False, filters=64, layer_num=1) 79 | 80 | with tf.variable_scope('downscale_3', None, [pool2]): 81 | pool3 = layers(pool2, down=True, up=False, filters=128, layer_num=1) 82 | 83 | with tf.variable_scope('downscale_4', None, [pool3]): 84 | pool4 = layers(pool3, down=True, up=False, filters=256, layer_num=1) 85 | 86 | with tf.variable_scope('downscale_5', None, [pool4]): 87 | pool5 = layers(pool4, down=True, up=False, filters=512, layer_num=3) 88 | 89 | with tf.variable_scope('decoder', None, [pool5, pool4, pool3, pool2, pool1]): 90 | 91 | with tf.variable_scope('upscale_4', None, [pool5, pool4]): 92 | up4 = layers(pool5, down=False, up=True, filters=256, layer_num=2) 93 | up4 += resblock(pool4, filters=256, num=1) 94 | 95 | with tf.variable_scope('upscale_3', None, [up4, pool3]): 96 | up3 = layers(up4, down=False, up=True, filters=128, layer_num=2) 97 | up3 += resblock(pool3, filters=128, num=1) 98 | 99 | with tf.variable_scope('upscale_2', None, [up3, pool2]): 100 | up2 = layers(up3, down=False, up=True, filters=64, layer_num=2) 101 | up2 += resblock(pool2, filters=64, num=1) 102 | 103 | with tf.variable_scope('flow', None, [up2, x]): 104 | 105 | if motion_flag == 'flow_mc': 106 | 107 | with tf.variable_scope('frame_1', None, [up2]): 108 | flow_mask_1 = subnet(up2, out_filter=3) 109 | flow_1, mask_1 = tf.split(flow_mask_1, [2, 1], axis=-1) 110 | with tf.variable_scope('frame_2', None, [up2]): 111 | flow_mask_2 = subnet(up2, out_filter=3) 112 | flow_2, mask_2 = tf.split(flow_mask_2, [2, 1], axis=-1) 113 | 114 | flag = 1 115 | 116 | else: 117 | with tf.variable_scope('frame_12', None, [up2]): 118 | flow_mask = subnet(up2, filter_num=128, out_filter=6) 119 | flow_1, mask_1, flow_2, mask_2 = tf.split(flow_mask, [2, 1, 2, 1], axis=-1) 120 | 121 | flag = 2 122 | 123 | frame_1 = mask_1 * tf.contrib.image.dense_image_warp(x[:, :, :, 0:3], flow_1) 124 | frame_2 = mask_2 * tf.contrib.image.dense_image_warp(x[:, :, :, 3:6], flow_2) 125 | 126 | with tf.variable_scope('refine', None, [x, flow_1, flow_2, mask_1, mask_2, frame_1, frame_2]): 127 | 128 | input_to_refine = tf.concat([x, flow_1, flow_2, mask_1, mask_2, frame_1, frame_2], axis=-1) 129 | 130 | output = frame_1 + frame_2 + mc_func.MC_RLVC(input_to_refine) 131 | 132 | return output, flag 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /sepconv_inter_enc.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # import networks._tf as _tf 3 | # from networks.ops.gpu_ops import SEPCONV_MODULE 4 | from func import * 5 | # from pretrained import * 6 | import mc_func 7 | 8 | def get_network_enc(x, name): 9 | 10 | def parametric_relu(_x, name='alpha'): 11 | alphas = tf.get_variable(name, _x.get_shape()[-1], 12 | initializer=tf.constant_initializer(0.25), 13 | dtype=tf.float32) 14 | pos = tf.nn.relu(_x) 
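# (_x - abs(_x)) / 2 keeps only the negative part of _x, so pos + neg below
# implements a parametric ReLU: relu(x) + alpha * min(x, 0).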
15 | neg = alphas * (_x - abs(_x)) * 0.5 16 | 17 | return pos + neg 18 | 19 | def layers(tensor, down=False, up=False, filters=64, layer_num=3): 20 | if down: 21 | tensor = tf.layers.conv2d(tensor, filters, 3, strides=(2, 2), activation=parametric_relu, padding='same') 22 | if up: 23 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 24 | for i in range(layer_num): 25 | tensor = tf.layers.conv2d(tensor, filters, 3, activation=parametric_relu, padding='same') 26 | 27 | return tensor 28 | 29 | def resblock(tensor, filters, num=2): 30 | 31 | for i in range(num): 32 | 33 | l1 = tf.layers.conv2d(inputs=tensor, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 34 | l2 = tf.layers.conv2d(inputs=l1, filters=filters, kernel_size=3, strides=1, activation=parametric_relu, padding='same') 35 | tensor += l2 36 | 37 | return tensor 38 | 39 | with tf.variable_scope(name, None, [x], reuse=False): 40 | 41 | with tf.variable_scope('encoder', None, [x]): 42 | with tf.variable_scope('downscale_1', None, [x]): 43 | pool1 = layers(x, down=False, up=False, filters=32, layer_num=3) 44 | 45 | with tf.variable_scope('downscale_2', None, [pool1]): 46 | pool2 = layers(pool1, down=True, up=False, filters=64, layer_num=1) 47 | skip2 = resblock(pool2, filters=64, num=1) 48 | 49 | with tf.variable_scope('downscale_3', None, [pool2]): 50 | pool3 = layers(pool2, down=True, up=False, filters=128, layer_num=1) 51 | skip3 = resblock(pool3, filters=128, num=1) 52 | 53 | with tf.variable_scope('downscale_4', None, [pool3]): 54 | pool4 = layers(pool3, down=True, up=False, filters=256, layer_num=1) 55 | skip4 = resblock(pool4, filters=256, num=1) 56 | 57 | with tf.variable_scope('downscale_5', None, [pool4]): 58 | pool5 = layers(pool4, down=True, up=False, filters=512, layer_num=3) 59 | 60 | return pool5, skip4, skip3, skip2 61 | 62 | def get_network_dec(pool5, skip4, skip3, skip2, x): 63 | 64 | def parametric_relu(_x, name='alpha'): 65 | alphas = tf.get_variable(name, _x.get_shape()[-1], 66 | initializer=tf.constant_initializer(0.25), 67 | dtype=tf.float32) 68 | pos = tf.nn.relu(_x) 69 | neg = alphas * (_x - abs(_x)) * 0.5 70 | 71 | return pos + neg 72 | 73 | 74 | def layers(tensor, down=False, up=False, filters=64, layer_num=3): 75 | if down: 76 | tensor = tf.layers.conv2d(tensor, filters, 3, strides=(2, 2), activation=parametric_relu, padding='same') 77 | if up: 78 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 79 | for i in range(layer_num): 80 | tensor = tf.layers.conv2d(tensor, filters, 3, activation=parametric_relu, padding='same') 81 | 82 | return tensor 83 | 84 | def subnet(tensor, filter_num=64, out_filter=3): 85 | tensor = tf.image.resize_bilinear(tensor, [2 * tf.shape(tensor)[1], 2 * tf.shape(tensor)[2]]) 86 | tensor = tf.layers.conv2d(tensor, filter_num, 3, activation=parametric_relu, padding='same') 87 | tensor = tf.layers.conv2d(tensor, out_filter, 3, padding='same') 88 | 89 | return tensor 90 | 91 | with tf.variable_scope('dec', None, [pool5, skip4, skip3, skip2, x], reuse=tf.AUTO_REUSE): 92 | 93 | with tf.variable_scope('upscale_4', None, [pool5, skip4]): 94 | up4 = layers(pool5, down=False, up=True, filters=256, layer_num=2) 95 | up4 += skip4 96 | 97 | with tf.variable_scope('upscale_3', None, [up4, skip3]): 98 | up3 = layers(up4, down=False, up=True, filters=128, layer_num=2) 99 | up3 += skip3 100 | 101 | with tf.variable_scope('upscale_2', None, [up3, skip2]): 102 | up2 = 
layers(up3, down=False, up=True, filters=64, layer_num=2) 103 | up2 += skip2 104 | 105 | with tf.variable_scope('frame_1', None, [up2]): 106 | flow_mask_1 = subnet(up2, out_filter=3) 107 | flow_1, mask_1 = tf.split(flow_mask_1, [2, 1], axis=-1) 108 | with tf.variable_scope('frame_2', None, [up2]): 109 | flow_mask_2 = subnet(up2, out_filter=3) 110 | flow_2, mask_2 = tf.split(flow_mask_2, [2, 1], axis=-1) 111 | 112 | frame_1 = mask_1 * tf.contrib.image.dense_image_warp(x[:, :, :, 3:6], flow_1) 113 | frame_2 = mask_2 * tf.contrib.image.dense_image_warp(x[:, :, :, 6:9], flow_2) 114 | 115 | with tf.variable_scope('refine', None, [x, flow_1, flow_2, mask_1, mask_2, frame_1, frame_2]): 116 | 117 | input_to_refine = tf.concat([x, flow_1, flow_2, mask_1, mask_2, frame_1, frame_2], axis=-1) 118 | output = frame_1 + frame_2 + mc_func.MC_RLVC(input_to_refine) 119 | 120 | return output 121 | 122 | 123 | def conv_map(x1, x2, x3, filters, name): 124 | 125 | def parametric_relu(_x, name='alpha'): 126 | alphas = tf.get_variable(name, _x.get_shape()[-1], 127 | initializer=tf.constant_initializer(0.25), 128 | dtype=tf.float32) 129 | pos = tf.nn.relu(_x) 130 | neg = alphas * (_x - abs(_x)) * 0.5 131 | 132 | return pos + neg 133 | 134 | with tf.variable_scope(name, None, [x1, x2, x3], reuse=tf.AUTO_REUSE): 135 | 136 | x = tf.concat([x1, x2, x3], axis=-1) 137 | y = tf.layers.conv2d(inputs=x, filters=filters, kernel_size=1, strides=1, activation=parametric_relu, 138 | padding='same') 139 | 140 | return y 141 | 142 | 143 | 144 | --------------------------------------------------------------------------------
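A minimal usage sketch for the coarse-to-fine flow estimator and the forward-splatting flow inversion defined in motion.py. It is not part of the repository: it assumes TensorFlow 1.x with tf.contrib available, randomly initialized weights (no pretrained checkpoint is restored), and frames whose height and width are multiples of 16, as required by the h//16, w//16 coarsest pyramid level; the 416x240 resolution below matches the bundled BasketballPass sequence.

import numpy as np
import tensorflow as tf

import motion

batch, h, w = 1, 240, 416  # BasketballPass frames are 416x240 (divisible by 16)

im1 = tf.placeholder(tf.float32, [batch, h, w, 3])
im2 = tf.placeholder(tf.float32, [batch, h, w, 3])

# Coarse-to-fine optical flow from im1 to im2, plus the warping loss at each pyramid level.
flow, loss_0, loss_1, loss_2, loss_3, loss_4 = motion.optical_flow(im1, im2, batch, h, w)

# Approximate inverse flow: the forward flow is splatted onto its four neighbouring
# integer target positions with Gaussian weights (reverse_sample) and normalized.
inv_flow = motion.reverse_flow(flow, h, w)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # random conv weights and the scatter buffers
    f1 = np.random.rand(batch, h, w, 3).astype(np.float32)
    f2 = np.random.rand(batch, h, w, 3).astype(np.float32)
    fwd, inv = sess.run([flow, inv_flow], feed_dict={im1: f1, im2: f2})
    print(fwd.shape, inv.shape)  # both (1, 240, 416, 2)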