├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── epsilonparam ├── __pycache__ │ └── config.cpython-39.pyc ├── config.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── load_data.cpython-310.pyc │ │ ├── load_data.cpython-37.pyc │ │ ├── load_data.cpython-38.pyc │ │ ├── load_data.cpython-39.pyc │ │ ├── load_dataset.cpython-310.pyc │ │ ├── load_dataset.cpython-37.pyc │ │ ├── load_dataset.cpython-38.pyc │ │ ├── load_dataset.cpython-39.pyc │ │ ├── transposed_collate.cpython-310.pyc │ │ ├── transposed_collate.cpython-37.pyc │ │ ├── transposed_collate.cpython-38.pyc │ │ └── transposed_collate.cpython-39.pyc │ ├── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── audi.cpython-310.pyc │ │ │ ├── audi.cpython-38.pyc │ │ │ ├── audi.cpython-39.pyc │ │ │ ├── bair_robot_pushing.cpython-310.pyc │ │ │ ├── bair_robot_pushing.cpython-37.pyc │ │ │ ├── bair_robot_pushing.cpython-38.pyc │ │ │ ├── bair_robot_pushing.cpython-39.pyc │ │ │ ├── big.cpython-310.pyc │ │ │ ├── big.cpython-37.pyc │ │ │ ├── big.cpython-38.pyc │ │ │ ├── big.cpython-39.pyc │ │ │ ├── bouncing_ball.cpython-310.pyc │ │ │ ├── bouncing_ball.cpython-37.pyc │ │ │ ├── bouncing_ball.cpython-38.pyc │ │ │ ├── bouncing_ball.cpython-39.pyc │ │ │ ├── bouncing_ball_creator.cpython-310.pyc │ │ │ ├── bouncing_ball_creator.cpython-37.pyc │ │ │ ├── bouncing_ball_creator.cpython-38.pyc │ │ │ ├── bouncing_ball_creator.cpython-39.pyc │ │ │ ├── city.cpython-310.pyc │ │ │ ├── city.cpython-38.pyc │ │ │ ├── city.cpython-39.pyc │ │ │ ├── climate.cpython-310.pyc │ │ │ ├── climate.cpython-38.pyc │ │ │ ├── climate.cpython-39.pyc │ │ │ ├── freedoc.cpython-37.pyc │ │ │ ├── image.cpython-310.pyc │ │ │ ├── image.cpython-37.pyc │ │ │ ├── image.cpython-38.pyc │ 
│ │ ├── image.cpython-39.pyc │ │ │ ├── kinetics.cpython-37.pyc │ │ │ ├── kth_actions.cpython-310.pyc │ │ │ ├── kth_actions.cpython-37.pyc │ │ │ ├── kth_actions.cpython-38.pyc │ │ │ ├── kth_actions.cpython-39.pyc │ │ │ ├── moving_mnist.cpython-310.pyc │ │ │ ├── moving_mnist.cpython-37.pyc │ │ │ ├── moving_mnist.cpython-38.pyc │ │ │ ├── moving_mnist.cpython-39.pyc │ │ │ ├── simu.cpython-310.pyc │ │ │ ├── simu.cpython-38.pyc │ │ │ ├── simu.cpython-39.pyc │ │ │ ├── stochastic_moving_mnist.cpython-310.pyc │ │ │ ├── stochastic_moving_mnist.cpython-37.pyc │ │ │ ├── stochastic_moving_mnist.cpython-38.pyc │ │ │ ├── stochastic_moving_mnist.cpython-39.pyc │ │ │ ├── uvg.cpython-310.pyc │ │ │ ├── uvg.cpython-37.pyc │ │ │ ├── uvg.cpython-38.pyc │ │ │ ├── uvg.cpython-39.pyc │ │ │ ├── vimeo.cpython-310.pyc │ │ │ ├── vimeo.cpython-37.pyc │ │ │ ├── vimeo.cpython-38.pyc │ │ │ ├── vimeo.cpython-39.pyc │ │ │ ├── youtube.cpython-310.pyc │ │ │ ├── youtube.cpython-37.pyc │ │ │ ├── youtube.cpython-38.pyc │ │ │ └── youtube.cpython-39.pyc │ │ ├── audi.py │ │ ├── bair.py │ │ ├── bair_robot.py │ │ ├── bair_robot_pushing.py │ │ ├── big.py │ │ ├── bouncing_ball.py │ │ ├── bouncing_ball_creator.py │ │ ├── city.py │ │ ├── climate.py │ │ ├── image.py │ │ ├── kth_actions.py │ │ ├── moving_mnist.py │ │ ├── simu.py │ │ ├── stochastic_moving_mnist.py │ │ ├── uvg.py │ │ ├── vimeo.py │ │ └── youtube.py │ ├── load_data.py │ ├── load_dataset.py │ ├── misc_data_util │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── convert_bair.cpython-37.pyc │ │ │ ├── convert_bair.cpython-38.pyc │ │ │ ├── convert_kth_actions.cpython-37.pyc │ │ │ ├── convert_kth_actions.cpython-38.pyc │ │ │ ├── kth_actions_frames.cpython-37.pyc │ │ │ ├── kth_actions_frames.cpython-38.pyc │ │ │ ├── transforms.cpython-310.pyc │ │ │ ├── transforms.cpython-37.pyc │ │ │ ├── transforms.cpython-38.pyc │ │ │ ├── 
transforms.cpython-39.pyc │ │ │ ├── url_save.cpython-310.pyc │ │ │ ├── url_save.cpython-37.pyc │ │ │ ├── url_save.cpython-38.pyc │ │ │ └── url_save.cpython-39.pyc │ │ ├── convert_bair.py │ │ ├── convert_kth_actions.py │ │ ├── kth_actions_frames.py │ │ ├── transforms.py │ │ └── url_save.py │ └── transposed_collate.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── compress_modules.cpython-39.pyc │ │ ├── denoising_diffusion.cpython-39.pyc │ │ ├── network_components.cpython-39.pyc │ │ ├── unet.cpython-39.pyc │ │ └── utils.cpython-39.pyc │ ├── ae.py │ ├── compress_modules.py │ ├── denoising_diffusion.py │ ├── distill_diffusion.py │ ├── distill_trainer.py │ ├── network_components.py │ ├── trainer.py │ ├── unet.py │ └── utils.py ├── test_epsilonparam.py └── train.py ├── imgs ├── 1.png ├── 2.png └── 3.png └── xparam ├── config.py ├── config_ae.py ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── load_data.cpython-39.pyc │ └── load_dataset.cpython-39.pyc ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── audi.cpython-310.pyc │ │ ├── audi.cpython-38.pyc │ │ ├── audi.cpython-39.pyc │ │ ├── bair_robot_pushing.cpython-310.pyc │ │ ├── bair_robot_pushing.cpython-37.pyc │ │ ├── bair_robot_pushing.cpython-38.pyc │ │ ├── bair_robot_pushing.cpython-39.pyc │ │ ├── big.cpython-310.pyc │ │ ├── big.cpython-37.pyc │ │ ├── big.cpython-38.pyc │ │ ├── big.cpython-39.pyc │ │ ├── bouncing_ball.cpython-310.pyc │ │ ├── bouncing_ball.cpython-37.pyc │ │ ├── bouncing_ball.cpython-38.pyc │ │ ├── bouncing_ball.cpython-39.pyc │ │ ├── bouncing_ball_creator.cpython-310.pyc │ │ ├── bouncing_ball_creator.cpython-37.pyc │ │ ├── bouncing_ball_creator.cpython-38.pyc │ │ ├── bouncing_ball_creator.cpython-39.pyc │ │ ├── city.cpython-310.pyc │ │ ├── city.cpython-38.pyc │ │ ├── city.cpython-39.pyc │ │ ├── 
climate.cpython-310.pyc │ │ ├── climate.cpython-38.pyc │ │ ├── climate.cpython-39.pyc │ │ ├── freedoc.cpython-37.pyc │ │ ├── image.cpython-310.pyc │ │ ├── image.cpython-37.pyc │ │ ├── image.cpython-38.pyc │ │ ├── image.cpython-39.pyc │ │ ├── kinetics.cpython-37.pyc │ │ ├── kth_actions.cpython-310.pyc │ │ ├── kth_actions.cpython-37.pyc │ │ ├── kth_actions.cpython-38.pyc │ │ ├── kth_actions.cpython-39.pyc │ │ ├── moving_mnist.cpython-310.pyc │ │ ├── moving_mnist.cpython-37.pyc │ │ ├── moving_mnist.cpython-38.pyc │ │ ├── moving_mnist.cpython-39.pyc │ │ ├── simu.cpython-310.pyc │ │ ├── simu.cpython-38.pyc │ │ ├── simu.cpython-39.pyc │ │ ├── stochastic_moving_mnist.cpython-310.pyc │ │ ├── stochastic_moving_mnist.cpython-37.pyc │ │ ├── stochastic_moving_mnist.cpython-38.pyc │ │ ├── stochastic_moving_mnist.cpython-39.pyc │ │ ├── uvg.cpython-310.pyc │ │ ├── uvg.cpython-37.pyc │ │ ├── uvg.cpython-38.pyc │ │ ├── uvg.cpython-39.pyc │ │ ├── vimeo.cpython-310.pyc │ │ ├── vimeo.cpython-37.pyc │ │ ├── vimeo.cpython-38.pyc │ │ ├── vimeo.cpython-39.pyc │ │ ├── youtube.cpython-310.pyc │ │ ├── youtube.cpython-37.pyc │ │ ├── youtube.cpython-38.pyc │ │ └── youtube.cpython-39.pyc │ ├── audi.py │ ├── bair.py │ ├── bair_robot.py │ ├── bair_robot_pushing.py │ ├── big.py │ ├── bouncing_ball.py │ ├── bouncing_ball_creator.py │ ├── city.py │ ├── climate.py │ ├── image.py │ ├── kth_actions.py │ ├── moving_mnist.py │ ├── simu.py │ ├── stochastic_moving_mnist.py │ ├── uvg.py │ ├── vimeo.py │ └── youtube.py ├── load_data.py ├── load_dataset.py ├── misc_data_util │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── convert_bair.cpython-37.pyc │ │ ├── convert_bair.cpython-38.pyc │ │ ├── convert_kth_actions.cpython-37.pyc │ │ ├── convert_kth_actions.cpython-38.pyc │ │ ├── kth_actions_frames.cpython-37.pyc │ │ ├── kth_actions_frames.cpython-38.pyc │ │ ├── transforms.cpython-310.pyc 
│ │ ├── transforms.cpython-37.pyc │ │ ├── transforms.cpython-38.pyc │ │ ├── transforms.cpython-39.pyc │ │ ├── url_save.cpython-310.pyc │ │ ├── url_save.cpython-37.pyc │ │ ├── url_save.cpython-38.pyc │ │ └── url_save.cpython-39.pyc │ ├── convert_bair.py │ ├── convert_kth_actions.py │ ├── kth_actions_frames.py │ ├── transforms.py │ └── url_save.py └── transposed_collate.py ├── modules ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── compress_modules.cpython-310.pyc │ ├── compress_modules.cpython-39.pyc │ ├── denoising_diffusion.cpython-310.pyc │ ├── denoising_diffusion.cpython-39.pyc │ ├── network_components.cpython-310.pyc │ ├── network_components.cpython-39.pyc │ ├── unet.cpython-310.pyc │ ├── unet.cpython-39.pyc │ ├── utils.cpython-310.pyc │ └── utils.cpython-39.pyc ├── compress_modules.py ├── denoising_diffusion.py ├── network_components.py ├── trainer.py ├── unet.py └── utils.py ├── test_xparam.py ├── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | submit.sh 2 | test_slurm.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Buggyyang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lossy Image Compression with Conditional Diffusion Models 2 | 3 | This repository contains the codebase for our paper on [Lossy Image Compression with Conditional Diffusion Models](https://arxiv.org/pdf/2209.06950.pdf). We provide an off-the-shelf test code for both x-parameterization and epsilon-parameterization. 4 | 5 | ## Usage 6 | 7 | - There are two separate folders, `epsilonparam` and `xparam`, for the epsilon-parameterization and x-parameterization models, respectively. Please use the appropriate folder depending on the model you want to work with. This is because there are some minor differences between the x-param and e-param models, making them incompatible with each other. 8 | - Before running the test code, please read the comments about the arguments in the code to ensure proper usage. 9 | - Note that this test code is provided as a template, you may need to write your own dataloader to get the actual results. 10 | - We also provide 3 images from Kodak dataset in the `imgs` folder for testing. You can use them to test the code. 11 | 12 | ## Model Weights 13 | 14 | The model weights can be downloaded from [this link](https://drive.google.com/drive/folders/197Wl5cwjaCvrEvggMcyNeHOSxq2rDZ1F?usp=sharing). 
15 | - Why the x-param weights are approximately twice as large as the epsilon-param weights? For the x-parameterization, I saved both the exponential moving average (ema) and the latest model. When I load the model, I only load the ema. 16 | 17 | Please feel free to explore the code and experiment with the models. If you have any questions or encounter any issues, don't hesitate to reach out to us. 18 | 19 | ## Environment 20 | 21 | please use the environment.yml file to create a conda environment. (It may contain redundant packages.) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: exp_pytorch 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - appdirs=1.4.4=pyhd3eb1b0_0 11 | - blas=1.0=mkl 12 | - blessings=1.7=py39h06a4308_1002 13 | - brotli=1.0.9=h5eee18b_7 14 | - brotli-bin=1.0.9=h5eee18b_7 15 | - brotlipy=0.7.0=py39h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - ca-certificates=2023.01.10=h06a4308_0 18 | - certifi=2022.12.7=py39h06a4308_0 19 | - cffi=1.15.1=py39h5eee18b_3 20 | - colorama=0.4.6=pyhd8ed1ab_0 21 | - contourpy=1.0.5=py39hdb19cb5_0 22 | - cryptography=39.0.1=py39h9ce1e76_0 23 | - cuda-cudart=11.8.89=0 24 | - cuda-cupti=11.8.87=0 25 | - cuda-libraries=11.8.0=0 26 | - cuda-nvrtc=11.8.89=0 27 | - cuda-nvtx=11.8.86=0 28 | - cuda-runtime=11.8.0=0 29 | - cycler=0.11.0=pyhd3eb1b0_0 30 | - dbus=1.13.18=hb2f20db_0 31 | - expat=2.4.9=h6a678d5_0 32 | - ffmpeg=4.3=hf484d3e_0 33 | - filelock=3.9.0=py39h06a4308_0 34 | - flit-core=3.8.0=py39h06a4308_0 35 | - fontconfig=2.14.1=h4c34cd2_2 36 | - fonttools=4.25.0=pyhd3eb1b0_0 37 | - freetype=2.12.1=h4a9f257_0 38 | - giflib=5.2.1=h5eee18b_3 39 | - glib=2.69.1=he621ea3_2 40 | - gmp=6.2.1=h295c915_3 41 | - gmpy2=2.1.2=py39heeb90bb_0 42 | - gnutls=3.6.15=he1e5248_0 43 | - 
gpustat=0.6.0=pyhd3eb1b0_1 44 | - gst-plugins-base=1.14.1=h6a678d5_1 45 | - gstreamer=1.14.1=h5eee18b_1 46 | - icu=58.2=he6710b0_3 47 | - idna=3.4=py39h06a4308_0 48 | - importlib_resources=5.2.0=pyhd3eb1b0_1 49 | - intel-openmp=2021.4.0=h06a4308_3561 50 | - jinja2=3.1.2=py39h06a4308_0 51 | - jpeg=9e=h5eee18b_1 52 | - kiwisolver=1.4.4=py39h6a678d5_0 53 | - krb5=1.19.4=h568e23c_0 54 | - lame=3.100=h7b6447c_0 55 | - lcms2=2.12=h3be6417_0 56 | - ld_impl_linux-64=2.38=h1181459_1 57 | - lerc=3.0=h295c915_0 58 | - libbrotlicommon=1.0.9=h5eee18b_7 59 | - libbrotlidec=1.0.9=h5eee18b_7 60 | - libbrotlienc=1.0.9=h5eee18b_7 61 | - libclang=14.0.6=default_hc6dbbc7_1 62 | - libclang13=14.0.6=default_he11475f_1 63 | - libcublas=11.11.3.6=0 64 | - libcufft=10.9.0.58=0 65 | - libcufile=1.6.0.25=0 66 | - libcurand=10.3.2.56=0 67 | - libcusolver=11.4.1.48=0 68 | - libcusparse=11.7.5.86=0 69 | - libdeflate=1.17=h5eee18b_0 70 | - libedit=3.1.20221030=h5eee18b_0 71 | - libevent=2.1.12=h8f2d780_0 72 | - libffi=3.4.2=h6a678d5_6 73 | - libgcc-ng=11.2.0=h1234567_1 74 | - libgfortran-ng=11.2.0=h00389a5_1 75 | - libgfortran5=11.2.0=h1234567_1 76 | - libgomp=11.2.0=h1234567_1 77 | - libiconv=1.16=h7f8727e_2 78 | - libidn2=2.3.2=h7f8727e_0 79 | - libllvm14=14.0.6=hdb19cb5_2 80 | - libnpp=11.8.0.86=0 81 | - libnvjpeg=11.9.0.86=0 82 | - libpng=1.6.39=h5eee18b_0 83 | - libpq=12.9=h16c4e8d_3 84 | - libstdcxx-ng=11.2.0=h1234567_1 85 | - libtasn1=4.19.0=h5eee18b_0 86 | - libtiff=4.5.0=h6a678d5_2 87 | - libunistring=0.9.10=h27cfd23_0 88 | - libuuid=1.41.5=h5eee18b_0 89 | - libwebp=1.2.4=h11a3e52_1 90 | - libwebp-base=1.2.4=h5eee18b_1 91 | - libxcb=1.15=h7f8727e_0 92 | - libxkbcommon=1.0.1=h5eee18b_1 93 | - libxml2=2.10.3=hcbfbd50_0 94 | - libxslt=1.1.37=h2085143_0 95 | - lz4-c=1.9.4=h6a678d5_0 96 | - matplotlib=3.7.1=py39h06a4308_1 97 | - matplotlib-base=3.7.1=py39h417a72b_1 98 | - mkl=2021.4.0=h06a4308_640 99 | - mkl-service=2.4.0=py39h7f8727e_0 100 | - mkl_fft=1.3.1=py39hd3c417c_0 101 | - 
mkl_random=1.2.2=py39h51133e4_0 102 | - mpc=1.1.0=h10f8cd9_1 103 | - mpfr=4.0.2=hb69a4c5_1 104 | - mpmath=1.2.1=py39h06a4308_0 105 | - munkres=1.1.4=py_0 106 | - ncurses=6.4=h6a678d5_0 107 | - nettle=3.7.3=hbbd107a_1 108 | - networkx=2.8.4=py39h06a4308_1 109 | - nspr=4.33=h295c915_0 110 | - nss=3.74=h0370c37_0 111 | - numpy-base=1.23.5=py39h31eccc5_0 112 | - nvidia-ml=7.352.0=pyhd3eb1b0_0 113 | - openh264=2.1.1=h4ff587b_0 114 | - openssl=1.1.1t=h7f8727e_0 115 | - packaging=23.0=py39h06a4308_0 116 | - pcre=8.45=h295c915_0 117 | - pillow=9.4.0=py39h6a678d5_0 118 | - pip=23.0.1=py39h06a4308_0 119 | - ply=3.11=py39h06a4308_0 120 | - pooch=1.4.0=pyhd3eb1b0_0 121 | - psutil=5.9.0=py39h5eee18b_0 122 | - pycparser=2.21=pyhd3eb1b0_0 123 | - pyopenssl=23.0.0=py39h06a4308_0 124 | - pyparsing=3.0.9=py39h06a4308_0 125 | - pyqt=5.15.7=py39h6a678d5_1 126 | - pyqt5-sip=12.11.0=py39h6a678d5_1 127 | - pysocks=1.7.1=py39h06a4308_0 128 | - python=3.9.16=h7a1cb2a_2 129 | - python-dateutil=2.8.2=pyhd3eb1b0_0 130 | - pytorch=2.0.0=py3.9_cuda11.8_cudnn8.7.0_0 131 | - pytorch-cuda=11.8=h7e8668a_3 132 | - pytorch-mutex=1.0=cuda 133 | - qt-main=5.15.2=h8373d8f_8 134 | - qt-webengine=5.15.9=hbbf29b9_6 135 | - qtwebkit=5.212=h3fafdc1_5 136 | - readline=8.2=h5eee18b_0 137 | - scipy=1.10.1=py39h14f4228_0 138 | - sip=6.6.2=py39h6a678d5_0 139 | - six=1.16.0=pyhd3eb1b0_1 140 | - sqlite=3.41.1=h5eee18b_0 141 | - sympy=1.11.1=py39h06a4308_0 142 | - tk=8.6.12=h1ccaba5_0 143 | - toml=0.10.2=pyhd3eb1b0_0 144 | - torchaudio=2.0.0=py39_cu118 145 | - torchtriton=2.0.0=py39 146 | - torchvision=0.15.0=py39_cu118 147 | - tornado=6.2=py39h5eee18b_0 148 | - tqdm=4.65.0=pyhd8ed1ab_1 149 | - typing_extensions=4.4.0=py39h06a4308_0 150 | - tzdata=2023c=h04d1e81_0 151 | - urllib3=1.26.15=py39h06a4308_0 152 | - xz=5.2.10=h5eee18b_1 153 | - zlib=1.2.13=h5eee18b_0 154 | - zstd=1.5.4=hc292b87_0 155 | - pip: 156 | - absl-py==1.4.0 157 | - addict==2.4.0 158 | - autograd==1.5 159 | - blosc2==2.0.0 160 | - cachetools==5.3.0 
161 | - chardet==5.2.0 162 | - charset-normalizer==3.1.0 163 | - compressai==1.2.4 164 | - cython==0.29.34 165 | - einops==0.6.0 166 | - ema-pytorch==0.2.3 167 | - ftfy==6.1.1 168 | - future==0.18.3 169 | - google-auth==2.17.3 170 | - google-auth-oauthlib==0.4.6 171 | - grpcio==1.54.0 172 | - huggingface-hub==0.13.4 173 | - imageio==2.27.0 174 | - imgaug==0.4.0 175 | - importlib-metadata==6.6.0 176 | - lazy-loader==0.2 177 | - lmdb==1.4.1 178 | - lpips==0.1.4 179 | - markdown==3.4.3 180 | - markupsafe==2.1.2 181 | - msgpack==1.0.5 182 | - numexpr==2.8.4 183 | - numpy==1.24.3 184 | - oauthlib==3.2.2 185 | - openai-clip==1.0.1 186 | - opencv-python==4.7.0.72 187 | - protobuf==3.20.3 188 | - py-cpuinfo==9.0.0 189 | - pyasn1==0.5.0 190 | - pyasn1-modules==0.3.0 191 | - pytorch-msssim==0.2.1 192 | - pywavelets==1.4.1 193 | - pyyaml==6.0 194 | - regex==2023.3.23 195 | - requests==2.28.2 196 | - requests-oauthlib==1.3.1 197 | - rsa==4.9 198 | - scikit-image==0.20.0 199 | - setuptools==67.7.1 200 | - shapely==2.0.1 201 | - tables==3.8.0 202 | - tensorboard==2.11.2 203 | - tensorboard-data-server==0.6.1 204 | - tensorboard-plugin-wit==1.8.1 205 | - tifffile==2023.4.12 206 | - timm==0.6.13 207 | - tomli==2.0.1 208 | - werkzeug==2.2.3 209 | - wheel==0.40.0 210 | - yapf==0.33.0 211 | - zipp==3.15.0 212 | -------------------------------------------------------------------------------- /epsilonparam/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/config.py: -------------------------------------------------------------------------------- 1 | # training config 2 | n_step = 1000000 3 | scheduler_checkpoint_step = 100000 4 | log_checkpoint_step = 5000 5 | 
gradient_accumulate_every = 1 6 | lr = 4e-5 7 | decay = 0.9 8 | minf = 0.5 9 | optimizer = "adam" # adamw or adam 10 | n_workers = 4 11 | 12 | # load 13 | load_model = True 14 | load_step = True 15 | 16 | # diffusion config 17 | pred_mode = 'noise' 18 | loss_type = "l1" 19 | iteration_step = 20000 20 | sample_steps = 500 21 | embed_dim = 64 22 | dim_mults = (1, 2, 3, 4, 5, 6) 23 | hyper_dim_mults = (4, 4, 4) 24 | context_channels = 3 25 | clip_noise = "none" 26 | val_num_of_batch = 1 27 | additional_note = "" 28 | vbr = False 29 | context_dim_mults = (1, 2, 3, 4) 30 | sample_mode = "ddim" 31 | var_schedule = "linear" 32 | aux_loss_type = "lpips" 33 | compressor = "big" 34 | 35 | # data config 36 | data_config = { 37 | "dataset_name": "vimeo", 38 | "data_path": "*", 39 | "sequence_length": 1, 40 | "img_size": 256, 41 | "img_channel": 3, 42 | "add_noise": False, 43 | "img_hz_flip": False, 44 | } 45 | 46 | batch_size = 4 47 | 48 | result_root = "*" 49 | tensorboard_root = "*" 50 | -------------------------------------------------------------------------------- /epsilonparam/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_data 2 | -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_data.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_data.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_data.cpython-38.pyc -------------------------------------------------------------------------------- 
/epsilonparam/data/__pycache__/load_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_data.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/load_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/load_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/transposed_collate.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/transposed_collate.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/transposed_collate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/transposed_collate.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/transposed_collate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/transposed_collate.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/__pycache__/transposed_collate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/__pycache__/transposed_collate.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .bair_robot_pushing import BAIRRobotPushing 2 | # from .bair_robot import BAIRRobot 3 | from .kth_actions import KTHActions 4 | from .moving_mnist import MovingMNIST 5 | from .stochastic_moving_mnist import StochasticMovingMNIST 6 | from .bouncing_ball_creator import make_bouncing_ball_dataset 7 | from .bouncing_ball import BouncingBall 8 | from .big import 
BIG 9 | from .image import IMG 10 | from .vimeo import VIMEO 11 | from .youtube import Youtube 12 | from .uvg import UVG 13 | from .audi import AUDI 14 | from .climate import ClimateData 15 | from .city import CITY 16 | from .simu import Simulation -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/audi.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/audi.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/audi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/audi.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/audi.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/audi.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bair_robot_pushing.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/big.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/big.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/big.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/big.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/big.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/big.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/big.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/big.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball.cpython-39.pyc -------------------------------------------------------------------------------- 
/epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/city.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/city.cpython-310.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/city.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/city.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/city.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/city.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/climate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/climate.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/climate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/climate.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/climate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/climate.cpython-39.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/freedoc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/freedoc.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/image.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/image.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/image.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/image.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/image.cpython-39.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/kinetics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/kinetics.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/kth_actions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/kth_actions.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/kth_actions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/kth_actions.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/kth_actions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/kth_actions.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/kth_actions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/kth_actions.cpython-39.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/moving_mnist.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/simu.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/simu.cpython-310.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/simu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/simu.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/simu.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/simu.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/uvg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/uvg.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/uvg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/uvg.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/uvg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/uvg.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/uvg.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/uvg.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/vimeo.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/vimeo.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/vimeo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/vimeo.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/vimeo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/vimeo.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/vimeo.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/vimeo.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/youtube.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/youtube.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/youtube.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/youtube.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/youtube.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/youtube.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/__pycache__/youtube.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/datasets/__pycache__/youtube.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/datasets/audi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from pathlib import Path 6 | import torch 7 | 8 | 9 | class AUDI(Dataset): 10 | def __init__(self, path, num_of_frame, train=True, transform=None, add_noise=False): 11 | assert os.path.exists(path), "Invalid path to AUDI data set: " + path 12 | self.path = path 13 | self.transform = transform 14 | self.train = train 15 | if train: 16 | self.video_list = list( 
17 | Path(os.path.join(path, "camera_lidar_semantic")).glob("*/camera/cam_front_center") 18 | )[:-1] 19 | else: 20 | self.video_list = list( 21 | Path(os.path.join(path, "camera_lidar_semantic")).glob("*/camera/cam_front_center") 22 | )[-1:] 23 | self.add_noise = add_noise 24 | self.num_of_frame = num_of_frame 25 | self.img_paths = [] 26 | for each in self.video_list: 27 | self.img_paths.append(sorted(list(each.glob("**/*small.png")))) 28 | 29 | def __getitem__(self, ind): 30 | # load the images from the ind directory to get list of PIL images 31 | if self.train: 32 | start_index = torch.randint(0, len(self.img_paths[ind]) - self.num_of_frame, (1,)).item() 33 | else: 34 | start_index = 525 35 | imgs = [Image.open(self.img_paths[ind][start_index + i]) for i in range(self.num_of_frame)] 36 | if self.transform is not None: 37 | imgs = self.transform(imgs) 38 | 39 | if self.add_noise: 40 | imgs = imgs + (torch.rand_like(imgs) - 0.5) / 256.0 41 | 42 | return imgs 43 | 44 | def __len__(self): 45 | # total number of videos 46 | return len(self.video_list) 47 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/bair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from scipy.misc import imresize 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.misc import imresize 7 | from scipy.misc import imread 8 | 9 | 10 | class RobotPush(object): 11 | 12 | """Data Handler that loads robot pushing data.""" 13 | 14 | def __init__(self, data_root, train=True, seq_len=20, image_size=64): 15 | self.root_dir = data_root 16 | if train: 17 | self.data_dir = '%s/processed_data/train' % self.root_dir 18 | self.ordered = False 19 | else: 20 | self.data_dir = '%s/processed_data/test' % self.root_dir 21 | self.ordered = True 22 | self.dirs = [] 23 | for d1 in os.listdir(self.data_dir): 24 | for d2 in os.listdir('%s/%s' % (self.data_dir, d1)): 25 | 
self.dirs.append('%s/%s/%s' % (self.data_dir, d1, d2)) 26 | self.seq_len = seq_len 27 | self.image_size = image_size 28 | self.seed_is_set = False # multi threaded loading 29 | self.d = 0 30 | 31 | def set_seed(self, seed): 32 | if not self.seed_is_set: 33 | self.seed_is_set = True 34 | np.random.seed(seed) 35 | 36 | def __len__(self): 37 | return 10000 38 | 39 | def get_seq(self): 40 | if self.ordered: 41 | d = self.dirs[self.d] 42 | if self.d == len(self.dirs) - 1: 43 | self.d = 0 44 | else: 45 | self.d+=1 46 | else: 47 | d = self.dirs[np.random.randint(len(self.dirs))] 48 | image_seq = [] 49 | for i in range(self.seq_len): 50 | fname = '%s/%d.png' % (d, i) 51 | im = imread(fname).reshape(1, 64, 64, 3) 52 | image_seq.append(im/255.) 53 | image_seq = np.concatenate(image_seq, axis=0) 54 | return image_seq 55 | 56 | 57 | def __getitem__(self, index): 58 | self.set_seed(index) 59 | return self.get_seq() 60 | 61 | 62 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/bair_robot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | import os 5 | 6 | class RobotPushing(torch.utils.data.Dataset): 7 | """ 8 | dataset class for moving-mnist dataset 9 | """ 10 | def __init__(self, is_train, data_path=None): 11 | super(RobotPushing, self).__init__() 12 | if data_path is None: 13 | self.data_path = '/local-scratch/chenleic/Data/BAIR/robot_pushing/main_frames' 14 | # print() 15 | else: 16 | self.data_path = data_path 17 | 18 | 19 | all_vid_list = sorted(os.listdir(self.data_path)) 20 | 21 | self.is_train = is_train 22 | if self.is_train: 23 | self.vid_list = all_vid_list[:10000] 24 | else: 25 | self.vid_list = all_vid_list[-100:] 26 | 27 | 28 | def __len__(self): 29 | return len(self.vid_list) 30 | 31 | 32 | def __getitem__(self, item): 33 | frames = np.load(os.path.join(self.data_path, self.vid_list[item])) 34 | 
frames = torch.FloatTensor(frames).permute(3, 0, 2, 1).contiguous() 35 | 36 | frames = frames/255 37 | 38 | return frames 39 | 40 | 41 | if __name__ == '__main__': 42 | dataset = RobotPushing(is_train=True) 43 | dataloader = torch.utils.data.DataLoader(dataset) 44 | 45 | import matplotlib.pyplot as plt 46 | 47 | for batch in dataloader: 48 | frame = batch[0,:,0,:,:].permute(1,2,0).contiguous() 49 | plt.imshow(frame) 50 | plt.draw() 51 | plt.savefig('/local-scratch/chenleic/Projects/seq_flow/seq_flow_robot_results/check_frame.jpg') 52 | break 53 | 54 | print(batch.size()) 55 | 56 | 57 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/bair_robot_pushing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torch 5 | 6 | 7 | class BAIRRobotPushing(Dataset): 8 | """ 9 | Dataset object for BAIR robot pushing dataset. The dataset must be stored 10 | with each video in a separate directory: 11 | /path 12 | /0 13 | /0.png 14 | /1.png 15 | /... 16 | /1 17 | /... 
18 | """ 19 | 20 | def __init__(self, path, transform=None, add_noise=False): 21 | assert os.path.exists(path), 'Invalid path to BAIR data set: ' + path 22 | self.path = path 23 | self.transform = transform 24 | self.video_list = os.listdir(self.path) 25 | 26 | self.add_noise = add_noise 27 | 28 | def __getitem__(self, ind): 29 | # load the images from the ind directory to get list of PIL images 30 | img_names = os.listdir(os.path.join(self.path, self.video_list[ind])) 31 | img_names = [img_name.split('.')[0] for img_name in img_names] 32 | img_names.sort(key=float) 33 | imgs = [Image.open(os.path.join(self.path, self.video_list[ind], i + '.png')) for i in img_names] 34 | if self.transform is not None: 35 | # apply the image/video transforms 36 | imgs = self.transform(imgs) 37 | 38 | # imgs = imgs.unsqueeze(1) 39 | 40 | if self.add_noise: 41 | imgs = imgs + (torch.rand_like(imgs)-0.5) / 256. 42 | 43 | return imgs 44 | 45 | def __len__(self): 46 | # total number of videos 47 | return len(self.video_list) 48 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/big.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from torchvision.io import read_video 5 | import torch 6 | 7 | 8 | class BIG(Dataset): 9 | """ 10 | Dataset object for BAIR robot pushing dataset. The dataset must be stored 11 | with each video in a separate directory: 12 | /path 13 | /0 14 | /0.png 15 | /1.png 16 | /... 17 | /1 18 | /... 
19 | """ 20 | 21 | def __init__(self, path, transform=None, add_noise=False, img_mode=False): 22 | assert os.path.exists( 23 | path), 'Invalid path to UCF+HMDB data set: ' + path 24 | self.path = path 25 | self.transform = transform 26 | self.video_list = os.listdir(self.path) 27 | self.img_mode = img_mode 28 | self.add_noise = add_noise 29 | 30 | def __getitem__(self, ind): 31 | # load the images from the ind directory to get list of PIL images 32 | img_names = os.listdir(os.path.join( 33 | self.path, self.video_list[ind])) 34 | img_names = [img_name.split('.')[0] for img_name in img_names] 35 | img_names.sort(key=float) 36 | if not self.img_mode: 37 | imgs = [Image.open(os.path.join( 38 | self.path, self.video_list[ind], i + '.png')) for i in img_names] 39 | else: 40 | select = torch.randint(0, len(img_names), (1,)) 41 | imgs = [Image.open(os.path.join( 42 | self.path, self.video_list[ind], img_names[select] + '.png'))] 43 | if self.transform is not None: 44 | # apply the image/video transforms 45 | imgs = self.transform(imgs) 46 | 47 | # imgs = imgs.unsqueeze(1) 48 | 49 | if self.add_noise: 50 | imgs = imgs + (torch.rand_like(imgs)-0.5) / 256. 51 | 52 | return imgs 53 | 54 | def __len__(self): 55 | # total number of videos 56 | return len(self.video_list) 57 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/bouncing_ball.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BouncingBall(Dataset): 8 | """ 9 | Dataset class for moving MNIST dataset. 
10 | 11 | Args: 12 | path (str): path to the .mat dataset 13 | transform (torchvision.transforms): image/video transforms 14 | """ 15 | def __init__(self, path, sequence_lengh): 16 | assert os.path.exists(path), 'Invalid path to Bouncing Ball data set: ' + path 17 | self.sequence_length = sequence_lengh 18 | self.data = np.load(path) 19 | 20 | def __getitem__(self, ind): 21 | imgs = self.data[ind,:,:,:].astype('float32') 22 | s, h, w = imgs.shape 23 | imgs = imgs.reshape(s, 1, h, w) 24 | 25 | imgs = imgs[:self.sequence_length, :, :, :] 26 | 27 | return torch.FloatTensor(imgs).contiguous() 28 | 29 | def __len__(self): 30 | return self.data.shape[0] 31 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/bouncing_ball_creator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script comes from the RTRBM code by Ilya Sutskever from 3 | http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar 4 | """ 5 | 6 | from numpy import * 7 | from scipy import * 8 | import pdb 9 | import pickle 10 | import scipy.io 11 | import sys, os 12 | 13 | import matplotlib 14 | 15 | matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | shape_std = shape 19 | 20 | 21 | def shape(A): 22 | if isinstance(A, ndarray): 23 | return shape_std(A) 24 | else: 25 | return A.shape() 26 | 27 | 28 | size_std = size 29 | 30 | 31 | def size(A): 32 | if isinstance(A, ndarray): 33 | return size_std(A) 34 | else: 35 | return A.size() 36 | 37 | 38 | det = linalg.det 39 | 40 | 41 | def new_speeds(m1, m2, v1, v2): 42 | new_v2 = (2 * m1 * v1 + v2 * (m2 - m1)) / (m1 + m2) 43 | new_v1 = new_v2 + (v2 - v1) 44 | return new_v1, new_v2 45 | 46 | 47 | def norm(x): return sqrt((x ** 2).sum()) 48 | 49 | 50 | def sigmoid(x): return 1. / (1. + exp(-x)) 51 | 52 | 53 | SIZE = 10 54 | 55 | 56 | # size of bounding box: SIZE X SIZE. 
57 | 58 | def bounce_n(T=128, n=2, r=None, m=None): 59 | if r is None: r = array([1.2] * n) 60 | if m is None: m = array([1] * n) 61 | # r is to be rather small. 62 | X = zeros((T, n, 2), dtype='float') 63 | v = randn(n, 2) 64 | v = v / norm(v) * .5 65 | good_config = False 66 | while not good_config: 67 | x = 2 + rand(n, 2) * 8 68 | good_config = True 69 | for i in range(n): 70 | for z in range(2): 71 | if x[i][z] - r[i] < 0: good_config = False 72 | if x[i][z] + r[i] > SIZE: good_config = False 73 | 74 | # that's the main part. 75 | for i in range(n): 76 | for j in range(i): 77 | if norm(x[i] - x[j]) < r[i] + r[j]: 78 | good_config = False 79 | 80 | eps = .5 81 | for t in range(T): 82 | # for how long do we show small simulation 83 | 84 | for i in range(n): 85 | X[t, i] = x[i] 86 | 87 | for mu in range(int(1 / eps)): 88 | 89 | for i in range(n): 90 | x[i] += eps * v[i] 91 | 92 | for i in range(n): 93 | for z in range(2): 94 | if x[i][z] - r[i] < 0: v[i][z] = abs(v[i][z]) # want positive 95 | if x[i][z] + r[i] > SIZE: v[i][z] = -abs(v[i][z]) # want negative 96 | 97 | for i in range(n): 98 | for j in range(i): 99 | if norm(x[i] - x[j]) < r[i] + r[j]: 100 | # the bouncing off part: 101 | w = x[i] - x[j] 102 | w = w / norm(w) 103 | 104 | v_i = dot(w.transpose(), v[i]) 105 | v_j = dot(w.transpose(), v[j]) 106 | 107 | new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j) 108 | 109 | v[i] += w * (new_v_i - v_i) 110 | v[j] += w * (new_v_j - v_j) 111 | 112 | return X 113 | 114 | 115 | def ar(x, y, z): 116 | return z / 2 + arange(x, y, z, dtype='float') 117 | 118 | 119 | def matricize(X, res, r=None): 120 | T, n = shape(X)[0:2] 121 | if r is None: r = array([1.2] * n) 122 | 123 | A = zeros((T, res, res), dtype='float') 124 | 125 | [I, J] = meshgrid(ar(0, 1, 1. / res) * SIZE, ar(0, 1, 1. 
/ res) * SIZE) 126 | 127 | for t in range(T): 128 | for i in range(n): 129 | A[t] += exp(-(((I - X[t, i, 0]) ** 2 + (J - X[t, i, 1]) ** 2) / (r[i] ** 2)) ** 4) 130 | 131 | A[t][A[t] > 1] = 1 132 | return A 133 | 134 | 135 | def bounce_mat(res, n=2, T=128, r=None): 136 | if r == None: r = array([1.2] * n) 137 | x = bounce_n(T, n, r); 138 | A = matricize(x, res, r) 139 | return A 140 | 141 | 142 | def bounce_vec(res, n=2, T=128, r=None, m=None): 143 | if r == None: r = array([1.2] * n) 144 | x = bounce_n(T, n, r, m); 145 | V = matricize(x, res, r) 146 | return V.reshape(T, res ** 2) 147 | 148 | 149 | # make sure you have this folder 150 | # logdir = './sample' 151 | 152 | 153 | # def show_sample(V): 154 | # T = len(V) 155 | # res = int(sqrt(shape(V)[1])) 156 | # for t in range(T): 157 | # plt.imshow(V[t].reshape(res, res), cmap=matplotlib.cm.Greys_r) 158 | # # Save it 159 | # fname = logdir + '/' + str(t) + '.png' 160 | # plt.savefig(fname) 161 | 162 | 163 | def make_bouncing_ball_dataset(data_path, res, n_ball, T, N_train, N_val): 164 | train_data = zeros((N_train, T, res, res)) 165 | for i in range(N_train): 166 | train_data[i] = bounce_vec(res=res, n=n_ball, T=T).reshape((T, res, res)) 167 | sys.stdout.write('\rcreating bouncing ball train clip {}/{}'.format(i, N_train)) 168 | sys.stdout.flush() 169 | 170 | train_data = train_data.reshape((N_train, T, res, res)) 171 | save(os.path.join(data_path, 'bouncing_balls_train_data.npy'), train_data) 172 | print() 173 | 174 | val_data = zeros((N_val, T, res, res)) 175 | for i in range(N_val): 176 | val_data[i] = bounce_vec(res=res, n=n_ball, T=T).reshape((T, res, res)) 177 | sys.stdout.write('\rcreating bouncing ball val clip {}/{}'.format(i, N_val)) 178 | sys.stdout.flush() 179 | 180 | val_data = val_data.reshape((N_val, T, res, res)) 181 | save(os.path.join(data_path, 'bouncing_balls_val_data.npy'), val_data) 182 | 183 | 184 | 185 | # if __name__ == "__main__": 186 | # res = 30 187 | # T = 100 188 | # N = 4000 189 | # 
dat = empty((N), dtype=object) 190 | # for i in range(N): 191 | # dat[i] = bounce_vec(res=res, n=3, T=100) 192 | # data = {} 193 | # data['Data'] = dat 194 | # scipy.io.savemat('bouncing_balls_training_data.mat', data) 195 | # 196 | # N = 200 197 | # dat = empty((N), dtype=object) 198 | # for i in range(N): 199 | # dat[i] = bounce_vec(res=res, n=3, T=100) 200 | # data = {} 201 | # data['Data'] = dat 202 | # scipy.io.savemat('bouncing_balls_testing_data.mat', data) 203 | 204 | # show one video 205 | # show_sample(dat[1]) 206 | # ffmpeg -start_number 0 -i %d.png -c:v libx264 -pix_fmt yuv420p -r 30 sample.mp4 -------------------------------------------------------------------------------- /epsilonparam/data/datasets/city.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from pathlib import Path 6 | import torch 7 | 8 | 9 | class CITY(Dataset): 10 | def __init__(self, path, num_of_frame, train=True, transform=None, add_noise=False): 11 | assert os.path.exists(path), "Invalid path to CITY data set: " + path 12 | self.path = path 13 | self.transform = transform 14 | self.train = train 15 | if train: 16 | self.frame_list = Path(os.path.join(path, "leftImg8bit_sequence/train")).glob("*/*.png") 17 | else: 18 | self.frame_list = Path(os.path.join(path, "leftImg8bit_sequence/val")).glob("*/*.png") 19 | self.add_noise = add_noise 20 | self.num_of_frame = num_of_frame 21 | self.frame_list = sorted(self.frame_list) 22 | 23 | def __getitem__(self, ind): 24 | # load the images from the ind directory to get list of PIL images 25 | first_frame_ind = ind * 30 26 | last_frame_ind = (ind+1) * 30 27 | if self.train: 28 | start_ind = torch.randint(first_frame_ind, last_frame_ind - self.num_of_frame, (1,)).item() 29 | else: 30 | start_ind = first_frame_ind 31 | imgs = [Image.open(self.frame_list[start_ind + i]) for i in range(self.num_of_frame)] 32 
| if self.transform is not None: 33 | imgs = self.transform(imgs) 34 | 35 | if self.add_noise: 36 | imgs = imgs + (torch.rand_like(imgs) - 0.5) / 256.0 37 | 38 | return imgs 39 | 40 | def __len__(self): 41 | # total number of videos 42 | return len(self.frame_list) // 30 43 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/climate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import os 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class ClimateData(Dataset): 9 | def __init__(self, path, mode): 10 | data = np.load(os.path.join(path, "climate_timestep/W_fields.npy"), mmap_mode="r") 11 | data = np.reshape(data, (-1, 192, 30, 128), order="F") 12 | data = np.reshape(data, (-1, 24, 8, 30, 128)) 13 | self.mean = data.mean() 14 | self.std = np.std(data) 15 | data = (data - self.mean) / self.std 16 | 17 | if mode == "train": 18 | 19 | self.t = 20 20 | train = data[:, :20, :, :, :] 21 | del data 22 | train = np.reshape(train, (-1, 8, 30, 128)) 23 | train = np.reshape(train, (-1, 30, 128)) 24 | train = np.pad(train, ((0, 0), (1, 1), (0, 0)), "symmetric") 25 | self.data = torch.from_numpy(train).float() 26 | 27 | else: 28 | 29 | self.t = 4 30 | test = data[:, 20:, :, :, :] 31 | del data 32 | test = np.reshape(test, (-1, 8, 30, 128)) 33 | test = np.reshape(test, (-1, 30, 128)) 34 | test = np.pad(test, ((0, 0), (1, 1), (0, 0)), "symmetric") 35 | self.data = torch.from_numpy(test).float() 36 | 37 | def __len__(self): 38 | 39 | return self.data.size()[0] 40 | 41 | def __getitem__(self, idx): 42 | 43 | width = self.t * 8 44 | start = int(idx / (width)) 45 | p = idx % width 46 | if p > width - 8: 47 | p = width - 8 48 | begin = start * width + p 49 | return self.data[begin : begin + 8, :, :].unsqueeze(1) 50 | 51 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/image.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class IMG(Dataset): 7 | def __init__(self, path, transform=None): 8 | assert os.path.exists(path), 'Invalid path to IMAGE data set: ' + path 9 | self.path = path 10 | self.transform = transform 11 | self.img_list = os.listdir(self.path) 12 | 13 | def __getitem__(self, ind): 14 | # load the images from the ind directory to get list of PIL images 15 | img = [Image.open(os.path.join(self.path, self.img_list[ind]))] 16 | if self.transform is not None: 17 | img = self.transform(img) 18 | if img.shape[1] == 1: 19 | img = img.expand(-1, 3, -1, -1) 20 | return img 21 | 22 | def __len__(self): 23 | # total number of videos 24 | return len(self.img_list) 25 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/kth_actions.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torch 5 | 6 | 7 | class KTHActions(Dataset): 8 | """ 9 | Dataset object for KTH actions dataset. The dataset must be stored 10 | with each video (action sequence) in a separate directory: 11 | /path 12 | /person01_walking_d1_0 13 | /0.png 14 | /1.png 15 | /... 16 | /person01_walking_d1_1 17 | /... 
18 | """ 19 | def __init__(self, path, transform=None, add_noise=False): 20 | assert os.path.exists(path), 'Invalid path to KTH actions data set: ' + path 21 | self.path = path 22 | self.transform = transform 23 | self.video_list = os.listdir(self.path) 24 | self.add_noise = add_noise 25 | 26 | def __getitem__(self, ind): 27 | # load the images from the ind directory to get list of PIL images 28 | img_names = os.listdir(os.path.join(self.path, self.video_list[ind])) 29 | img_names = [img_name.split('.')[0] for img_name in img_names] 30 | img_names.sort(key=float) 31 | imgs = [Image.open(os.path.join(self.path, self.video_list[ind], i + '.png')).convert('L') for i in img_names] 32 | if self.transform is not None: 33 | # apply the image/video transforms 34 | imgs = self.transform(imgs) 35 | 36 | if self.add_noise: 37 | imgs += torch.randn_like(imgs)/256 38 | 39 | return imgs 40 | 41 | def __len__(self): 42 | # returns the total number of videos 43 | return len(os.listdir(self.path)) 44 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/moving_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class MovingMNIST(Dataset): 8 | """ 9 | Dataset class for moving MNIST dataset. 
10 | 11 | Args: 12 | path (str): path to the .npy dataset 13 | transform (torchvision.transforms): image/video transforms 14 | """ 15 | 16 | def __init__(self, path, transform=None, add_noise=False): 17 | assert os.path.exists(path), 'Invalid path to Moving MNIST data set: ' + path 18 | self.transform = transform 19 | self.data = np.load(path) 20 | self.add_noise = add_noise 21 | 22 | def __getitem__(self, ind): 23 | imgs = self.data[:, ind, :, :].astype('float32') 24 | s, h, w = imgs.shape 25 | imgs = imgs.reshape(s, 1, h, w) 26 | if self.transform is not None: 27 | # apply the image/video transforms 28 | imgs = self.transform(imgs) 29 | 30 | if self.add_noise: 31 | imgs += torch.randn_like(imgs)/256 32 | return imgs 33 | 34 | def __len__(self): 35 | return self.data.shape[1] 36 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/simu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.transforms.functional as G 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class Simulation(Dataset): 8 | 9 | def __init__(self, path, number_of_frame, train, size, transform=None): 10 | 11 | data = np.load(path).astype(np.single) 12 | mmin = data.min() 13 | mmax = data.max() 14 | self.number_of_frame = number_of_frame 15 | self.transform = transform 16 | 17 | if train: 18 | self.t = 1000 19 | train = data[:8000, :, :] 20 | train = (train - mmin) / (mmax - mmin) 21 | self.data = torch.from_numpy(train) 22 | self.data = self.data.unsqueeze(1) 23 | self.data = G.resize(self.data, size) 24 | 25 | else: 26 | 27 | self.t = 250 28 | test = data[8000:, :, :] 29 | test = (test - mmin) / (mmax - mmin) 30 | self.data = torch.from_numpy(test) 31 | self.data = self.data.unsqueeze(1) 32 | self.data = G.resize(self.data, size) 33 | 34 | def __len__(self): 35 | 36 | return self.data.size()[0] 37 | 38 | def __getitem__(self, idx): 39 | 40 | 
width = self.t 41 | start = int(idx/(width)) 42 | p = idx % width 43 | if p > width - self.number_of_frame: 44 | p = width - self.number_of_frame 45 | begin = start * width + p 46 | frames = self.data[begin:begin + self.number_of_frame, :, :, :] 47 | return frames -------------------------------------------------------------------------------- /epsilonparam/data/datasets/stochastic_moving_mnist.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torchvision import datasets, transforms 5 | import torch 6 | 7 | 8 | class StochasticMovingMNIST(Dataset): 9 | """Data Handler that creates Bouncing MNIST dataset on the fly.""" 10 | def __init__(self, 11 | train, 12 | data_root, 13 | seq_len=20, 14 | num_digits=2, 15 | image_size=64, 16 | deterministic=True, 17 | add_noise=False, 18 | epoch_size=0): 19 | path = data_root 20 | self.seq_len = seq_len 21 | self.num_digits = num_digits 22 | self.image_size = image_size 23 | self.step_length = 0.1 24 | self.digit_size = 32 25 | self.deterministic = deterministic 26 | self.seed_is_set = False # multi threaded loading 27 | self.channels = 1 28 | self.add_noise = add_noise 29 | self.epoch_size = epoch_size 30 | 31 | self.data = datasets.MNIST(path, 32 | train=train, 33 | download=True, 34 | transform=transforms.Compose( 35 | [transforms.Scale(self.digit_size), 36 | transforms.ToTensor()])) 37 | 38 | self.N = len(self.data) 39 | 40 | def set_seed(self, seed): 41 | if not self.seed_is_set: 42 | self.seed_is_set = True 43 | np.random.seed(seed) 44 | 45 | def __len__(self): 46 | if self.epoch_size > 0: 47 | return self.epoch_size 48 | else: 49 | return self.N 50 | 51 | def __getitem__(self, index): 52 | self.set_seed(index) 53 | image_size = self.image_size 54 | digit_size = self.digit_size 55 | x = np.zeros((self.seq_len, image_size, image_size, self.channels), dtype=np.float32) 56 | for n in 
range(self.num_digits): 57 | idx = np.random.randint(self.N) 58 | digit, _ = self.data[idx] 59 | 60 | sx = np.random.randint(image_size - digit_size) 61 | sy = np.random.randint(image_size - digit_size) 62 | dx = np.random.randint(-4, 5) 63 | dy = np.random.randint(-4, 5) 64 | for t in range(self.seq_len): 65 | if sy < 0: 66 | sy = 0 67 | if self.deterministic: 68 | dy = -dy 69 | else: 70 | dy = np.random.randint(1, 5) 71 | dx = np.random.randint(-4, 5) 72 | elif sy >= image_size - 32: 73 | sy = image_size - 32 - 1 74 | if self.deterministic: 75 | dy = -dy 76 | else: 77 | dy = np.random.randint(-4, 0) 78 | dx = np.random.randint(-4, 5) 79 | 80 | if sx < 0: 81 | sx = 0 82 | if self.deterministic: 83 | dx = -dx 84 | else: 85 | dx = np.random.randint(1, 5) 86 | dy = np.random.randint(-4, 5) 87 | elif sx >= image_size - 32: 88 | sx = image_size - 32 - 1 89 | if self.deterministic: 90 | dx = -dx 91 | else: 92 | dx = np.random.randint(-4, 0) 93 | dy = np.random.randint(-4, 5) 94 | 95 | x[t, sy:sy + 32, sx:sx + 32, 0] += digit.numpy().squeeze() 96 | sy += dy 97 | sx += dx 98 | 99 | x = torch.FloatTensor(x).permute(0, 3, 1, 2).contiguous() 100 | if self.add_noise: 101 | x += torch.randn_like(x) / 256 102 | 103 | x[x < 0] = 0. 104 | x[x > 1] = 1. 
105 | return x 106 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/uvg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import torch 6 | import random 7 | 8 | 9 | class UVG(Dataset): 10 | def __init__(self, path, nframe=3, transform=None, seed=1212): 11 | assert os.path.exists(path), 'Invalid path to uvg data set: ' + path 12 | random.seed(seed) 13 | ldir = os.listdir(path) 14 | random.shuffle(ldir) 15 | self.transform = transform 16 | video_list = np.core.defchararray.add(f'{path}/', ldir) 17 | self.video_list = video_list 18 | self.nframe = nframe 19 | 20 | def __getitem__(self, ind): 21 | tot_nframe = len(os.listdir(self.video_list[ind])) 22 | assert tot_nframe >= self.nframe 23 | start_ind = torch.randint(1, 1 + tot_nframe - self.nframe, (1, )).item() 24 | imgs = [ 25 | Image.open(os.path.join(self.video_list[ind], 26 | str(img_name) + '.png')) for img_name in range(start_ind, start_ind + self.nframe) 27 | ] 28 | if self.transform is not None: 29 | imgs = self.transform(imgs) 30 | 31 | return imgs 32 | 33 | def __len__(self): 34 | # total number of videos 35 | return len(self.video_list) 36 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/vimeo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import torch 6 | 7 | 8 | class VIMEO(Dataset): 9 | 10 | def __init__(self, path, train=True, transform=None, add_noise=False): 11 | assert os.path.exists( 12 | path), 'Invalid path to VIMEO data set: ' + path 13 | self.path = path 14 | self.transform = transform 15 | if train: 16 | self.video_list = os.path.join(path, 'sep_trainlist.txt') 17 | else: 18 | 
self.video_list = os.path.join(path, 'sep_testlist.txt') 19 | self.video_list = np.loadtxt(self.video_list, dtype=str) 20 | self.video_list = np.core.defchararray.add(f'{os.path.join(path, "sequences")}/', self.video_list) 21 | 22 | self.add_noise = add_noise 23 | 24 | def __getitem__(self, ind): 25 | # load the images from the ind directory to get list of PIL images 26 | img_names = os.listdir(str(self.video_list[ind])) 27 | imgs = [Image.open(os.path.join(self.video_list[ind], str(img_name))) 28 | for img_name in img_names] 29 | if self.transform is not None: 30 | imgs = self.transform(imgs) 31 | 32 | if self.add_noise: 33 | imgs = imgs + (torch.rand_like(imgs)-0.5) / 256. 34 | 35 | return imgs 36 | 37 | def __len__(self): 38 | # total number of videos 39 | return len(self.video_list) 40 | -------------------------------------------------------------------------------- /epsilonparam/data/datasets/youtube.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image, ImageChops 4 | from torch.utils.data import Dataset 5 | import torch 6 | import random 7 | 8 | 9 | class Youtube(Dataset): 10 | def __init__(self, path, nframe=3, train=True, transform=None, seed=1212): 11 | assert os.path.exists(path), 'Invalid path to youtube data set: ' + path 12 | random.seed(seed) 13 | ldir = os.listdir(path) 14 | random.shuffle(ldir) 15 | self.transform = transform 16 | video_list = np.core.defchararray.add(f'{path}/', ldir) 17 | if train: 18 | self.video_list = video_list[:-32] 19 | else: 20 | self.video_list = video_list[-32:] 21 | self.nframe = nframe 22 | 23 | def __getitem__(self, ind): 24 | tot_nframe = len(os.listdir(self.video_list[ind])) 25 | assert tot_nframe >= self.nframe 26 | start_ind = torch.randint(1, 1 + tot_nframe - self.nframe, (1, )).item() 27 | imgs = [ 28 | Image.open(os.path.join(self.video_list[ind], 29 | str(img_name) + '.png')) for img_name in range(start_ind, 
start_ind + self.nframe) 30 | ] 31 | if self.transform is not None: 32 | imgs = self.transform(imgs) 33 | 34 | return imgs 35 | 36 | def __len__(self): 37 | # total number of videos 38 | return len(self.video_list) 39 | -------------------------------------------------------------------------------- /epsilonparam/data/load_data.py: -------------------------------------------------------------------------------- 1 | from .load_dataset import load_dataset 2 | from .transposed_collate import train_transposed_collate, test_transposed_collate 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | 8 | def load_data(data_config, batch_size, num_workers=4, pin_memory=True, distributed=False): 9 | """ 10 | Wrapper around load_dataset. Gets the dataset, then places it in a DataLoader. 11 | 12 | Args: 13 | data_config (dict): data configuration dictionary 14 | batch_size (dict): run configuration dictionary 15 | num_workers (int): number of threads of multi-processed data Loading 16 | pin_memory (bool): whether or not to pin memory in cpu 17 | sequence (bool): whether data examples are sequences, in which case the 18 | data loader returns transposed batches with the sequence 19 | step as the first dimension and batch index as the 20 | second dimension 21 | """ 22 | train, val = load_dataset(data_config) 23 | train_spl = DistributedSampler(train) if distributed else None 24 | val_spl = DistributedSampler(val, shuffle=False) if distributed else None 25 | 26 | if train is not None: 27 | train = DataLoader( 28 | train, 29 | batch_size=batch_size, 30 | shuffle=False if distributed else True, 31 | collate_fn=train_transposed_collate, 32 | num_workers=num_workers, 33 | pin_memory=pin_memory, 34 | sampler=train_spl 35 | ) 36 | 37 | if val is not None: 38 | val = DataLoader( 39 | val, 40 | batch_size=batch_size, 41 | shuffle=False, 42 | 
collate_fn=test_transposed_collate, 43 | num_workers=num_workers, 44 | pin_memory=pin_memory, 45 | sampler=val_spl 46 | ) 47 | return train, val 48 | -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__init__.py -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/convert_bair.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/convert_bair.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/convert_bair.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/convert_bair.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-38.pyc -------------------------------------------------------------------------------- 
/epsilonparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-38.pyc 
-------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-310.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-37.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-38.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/data/misc_data_util/__pycache__/url_save.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/convert_bair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import tensorflow.compat.v1 as tf 5 | from tensorflow.python.platform import gfile 6 | from imageio import imwrite as imsave 7 | 8 | # Convert BAIR robot pushing data to numpy to use with PyTorch 9 | # Based on Emily Denton's script: https://github.com/edenton/svg/blob/master/data/convert_bair.py 10 | 11 | 12 | def convert(data_path): 13 | # iterate through the data splits 14 | for data_split in ['train', 'test']: 15 | os.makedirs(os.path.join(data_path, data_split)) 16 | data_split_path = os.path.join(data_path, 'softmotion30_44k', data_split) 17 | data_split_files = gfile.Glob(os.path.join(data_split_path, '*')) 18 | # iterate through the TF records 19 | for f in data_split_files: 20 | print('Current file: ' + f) 21 | ind = int(f.split('/')[-1].split('_')[1]) # starting video index 22 | # iterate through the sequences in this TF record 23 | for serialized_example in tf.python_io.tf_record_iterator(f): 24 | os.makedirs(os.path.join(data_path, data_split, str(ind))) 25 | example = tf.train.Example() 26 | example.ParseFromString(serialized_example) 27 | # iterate through the sequence 28 | for i in range(30): 29 | image_name = str(i) + '/image_aux1/encoded' 30 | byte_str = example.features.feature[image_name].bytes_list.value[0] 31 | img = Image.frombytes('RGB', (64, 64), byte_str) 32 | img = np.array(img.getdata()).reshape(img.size[1], img.size[0], 3) / 255. 
33 | imsave(os.path.join(data_path, data_split, str(ind), str(i) + '.png'), img) 34 | print(' Finished processing sequence ' + str(ind)) 35 | ind += 1 36 | -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/convert_kth_actions.py: -------------------------------------------------------------------------------- 1 | import os 2 | from imageio import imwrite as imsave 3 | from moviepy.editor import VideoFileClip 4 | from .kth_actions_frames import kth_actions_dict 5 | 6 | settings = ['d1', 'd2', 'd3', 'd4'] 7 | actions = ['walking', 'jogging', 'running', 'boxing', 'handwaving', 'handclapping'] 8 | person_ids = {'train': ['11', '12', '13', '14', '15', '16', '17', '18'], 9 | 'val': ['19', '20', '21', '23', '24', '25', '01', '04'], 10 | 'test': ['22', '02', '03', '05', '06', '07', '08', '09', '10']} 11 | 12 | 13 | def convert(data_path): 14 | # iterate through the data splits 15 | for data_split in ['train', 'val', 'test']: 16 | print('Converting ' + data_split) 17 | os.makedirs(os.path.join(data_path, data_split)) 18 | split_person_ids = person_ids[data_split] 19 | # iterate through the ids, actions, and settings for this split 20 | for person_id in split_person_ids: 21 | print(' Converting person' + person_id) 22 | for action in kth_actions_dict['person'+person_id]: 23 | for setting in kth_actions_dict['person'+person_id][action]: 24 | frame_nums = kth_actions_dict['person'+person_id][action][setting] 25 | if len(frame_nums) > 0: 26 | start_frames = [frame_pair[0] for frame_pair in frame_nums] 27 | end_frames = [frame_pair[1] for frame_pair in frame_nums] 28 | # load the video 29 | file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi' 30 | print(file_name) 31 | video = VideoFileClip(os.path.join(data_path, action, file_name)) 32 | # write each sequence to a directory 33 | sequence_frame_index = 0 34 | sequence_index = 0 35 | sequence_name = '' 36 | in_sequence = False 37 | for 
frame_index, frame in enumerate(video.iter_frames()): 38 | if frame_index + 1 in start_frames: 39 | # start a new sequence 40 | in_sequence = True 41 | sequence_frame_index = 0 42 | sequence_name = 'person' + person_id + '_' + action + '_' + setting + '_' + str(sequence_index) 43 | os.makedirs(os.path.join(data_path, data_split, sequence_name)) 44 | if frame_index + 1 in end_frames: 45 | # end the current sequence 46 | in_sequence = False 47 | sequence_index += 1 48 | if frame_index + 1 == max(end_frames): 49 | break 50 | if in_sequence: 51 | # write frame to the current sequence 52 | frame = frame.astype('float32') / 255. 53 | imsave(os.path.join(data_path, data_split, sequence_name, str(sequence_frame_index) + '.png'), frame) 54 | sequence_frame_index += 1 55 | del video.reader 56 | del video 57 | -------------------------------------------------------------------------------- /epsilonparam/data/misc_data_util/url_save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import urllib 5 | 6 | """ 7 | Utility functions for downloading files. 
8 | """ 9 | 10 | def report_hook(count, block_size, total_size): 11 | # to display download progress 12 | # see https://blog.shichao.io/2012/10/04/progress_speed_indicator_for_urlretrieve_in_python.html 13 | global start_time 14 | if count == 0: 15 | start_time = time.time() 16 | return 17 | duration = time.time() - start_time 18 | progress_size = int(count * block_size) 19 | speed = int(progress_size / (1024 * duration)) 20 | percent = min(int(count*block_size*100/total_size),100) 21 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 22 | (percent, progress_size / (1024 * 1024), speed, duration)) 23 | sys.stdout.flush() 24 | 25 | def save(url, file_name): 26 | urllib.request.urlretrieve(url, file_name, report_hook) 27 | -------------------------------------------------------------------------------- /epsilonparam/data/transposed_collate.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | 4 | 5 | def train_transposed_collate(batch): 6 | """ 7 | Wrapper around the default collate function to return sequences of PyTorch 8 | tensors with sequence step as the first dimension and batch index as the 9 | second dimension. 10 | 11 | Args: 12 | batch (list): data examples 13 | """ 14 | batch = filter(lambda img: img is not None, batch) 15 | collated_batch = default_collate(list(batch)) 16 | transposed_batch = collated_batch.transpose_(0, 1) 17 | # assert transposed_batch.shape[0] >= 4 18 | # idx = torch.randint(4, transposed_batch.shape[0] + 1, size=(1,)).item() 19 | # return transposed_batch[:idx] 20 | return transposed_batch 21 | 22 | 23 | def test_transposed_collate(batch): 24 | """ 25 | Wrapper around the default collate function to return sequences of PyTorch 26 | tensors with sequence step as the first dimension and batch index as the 27 | second dimension. 
28 | 29 | Args: 30 | batch (list): data examples 31 | """ 32 | batch = filter(lambda img: img is not None, batch) 33 | collated_batch = default_collate(list(batch)) 34 | transposed_batch = collated_batch.transpose_(0, 1) 35 | return transposed_batch 36 | -------------------------------------------------------------------------------- /epsilonparam/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/modules/__init__.py -------------------------------------------------------------------------------- /epsilonparam/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/modules/__pycache__/compress_modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/modules/__pycache__/compress_modules.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/modules/__pycache__/denoising_diffusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/epsilonparam/modules/__pycache__/denoising_diffusion.cpython-39.pyc -------------------------------------------------------------------------------- /epsilonparam/modules/__pycache__/network_components.cpython-39.pyc: -------------------------------------------------------------------------------- 
class EMA:
    """Exponential moving average of model parameters."""

    def __init__(self, beta):
        super().__init__()
        # Interpolation factor: weight kept by the running average each update.
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of `ma_model` toward `current_model`, in place."""
        for src_param, avg_param in zip(current_model.parameters(), ma_model.parameters()):
            avg_param.data = self.update_average(avg_param.data, src_param.data)

    def update_average(self, old, new):
        """Return the beta-weighted average; seed with `new` when no history exists."""
        return new if old is None else old * self.beta + (1 - self.beta) * new
33 | train_dl, 34 | val_dl, 35 | scheduler_function, 36 | ema_decay=0.995, 37 | train_lr=1e-4, 38 | train_num_steps=1000000, 39 | scheduler_checkpoint_step=100000, 40 | step_start_ema=2000, 41 | update_ema_every=10, 42 | save_and_sample_every=1000, 43 | results_folder="./results", 44 | tensorboard_dir="./tensorboard_logs/diffusion-video/", 45 | model_name="model", 46 | val_num_of_batch=1, 47 | optimizer="adam", 48 | sample_mode="ddpm" 49 | ): 50 | super().__init__() 51 | self.model = diffusion_model 52 | # self.ema = EMA(ema_decay) 53 | # self.ema_model = copy.deepcopy(self.model) 54 | self.sample_mode = sample_mode 55 | # self.update_ema_every = update_ema_every 56 | self.val_num_of_batch = val_num_of_batch 57 | self.sample_steps = sample_steps 58 | 59 | # self.step_start_ema = step_start_ema 60 | self.save_and_sample_every = save_and_sample_every 61 | 62 | self.train_num_steps = train_num_steps 63 | 64 | self.train_dl_class = train_dl 65 | self.val_dl_class = val_dl 66 | self.train_dl = cycle(train_dl) 67 | self.val_dl = cycle(val_dl) 68 | if optimizer == "adam": 69 | self.opt = Adam(self.model.parameters(), lr=train_lr) 70 | elif optimizer == "adamw": 71 | self.opt = AdamW(self.model.parameters(), lr=train_lr) 72 | self.scheduler = LambdaLR(self.opt, lr_lambda=scheduler_function) 73 | 74 | self.step = 0 75 | self.device = rank 76 | self.scheduler_checkpoint_step = scheduler_checkpoint_step 77 | 78 | self.results_folder = Path(results_folder) 79 | self.results_folder.mkdir(exist_ok=True) 80 | self.model_name = model_name 81 | 82 | # if os.path.isdir(tensorboard_dir): 83 | # shutil.rmtree(tensorboard_dir) 84 | self.writer = SummaryWriter(tensorboard_dir) 85 | 86 | # self.reset_parameters() 87 | 88 | # def reset_parameters(self): 89 | # # self.ema_model.load_state_dict(self.model.state_dict()) 90 | # pass 91 | 92 | # def step_ema(self): 93 | # # if self.step < self.step_start_ema: 94 | # # self.reset_parameters() 95 | # # else: 96 | # # 
# --- Trainer checkpoint methods ---

def save(self):
    """Write a checkpoint containing the current step and model weights.

    Checkpoints rotate across 3 slots (suffix _0/_1/_2) so an interrupted
    write never destroys the only copy.
    """
    data = {
        "step": self.step,
        "model": self.model.state_dict(),
    }
    idx = (self.step // self.save_and_sample_every) % 3
    torch.save(data, str(self.results_folder / f"{self.model_name}_{idx}.pt"))

def load(self, idx=0, load_step=True):
    """Restore model weights (and optionally the step counter) from slot `idx`.

    Args:
        idx: checkpoint rotation slot to read (0-2).
        load_step: when True, also restore the training step counter.
    """
    data = torch.load(
        str(self.results_folder / f"{self.model_name}_{idx}.pt"),
        # Map to CPU first so checkpoints load regardless of the GPU layout
        # they were trained on.
        map_location=lambda storage, loc: storage,
    )

    if load_step:
        self.step = data["step"]
    try:
        # DistributedDataParallel wraps the real network under `.module`.
        self.model.module.load_state_dict(data["model"], strict=False)
    except AttributeError:
        # Plain (non-DDP) model: load directly. Narrowed from a bare
        # `except:`, which also hid real failures such as corrupted
        # checkpoints or mismatched state dicts.
        self.model.load_state_dict(data["model"], strict=False)
) 156 | self.writer.add_images( 157 | f"compressed/num{i}", 158 | compressed.clamp(0.0, 1.0), 159 | self.step // self.save_and_sample_every, 160 | ) 161 | self.writer.add_images( 162 | f"original/num{i}", 163 | batch[0], 164 | self.step // self.save_and_sample_every, 165 | ) 166 | self.save() 167 | 168 | self.step += 1 169 | self.save() 170 | print("training completed") 171 | -------------------------------------------------------------------------------- /epsilonparam/modules/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .utils import exists, default 4 | from .network_components import ( 5 | LayerNorm, 6 | Residual, 7 | # SinusoidalPosEmb, 8 | Upsample, 9 | Downsample, 10 | PreNorm, 11 | LinearAttention, 12 | # Block, 13 | ResnetBlock 14 | ) 15 | 16 | 17 | class Unet(nn.Module): 18 | def __init__( 19 | self, 20 | dim, 21 | out_dim=None, 22 | dim_mults=(1, 2, 4, 8), 23 | context_dim_mults=(1, 2, 3, 3), 24 | channels=3, 25 | context_channels=3, 26 | with_time_emb=True, 27 | ): 28 | super().__init__() 29 | self.channels = channels 30 | 31 | dims = [channels, *map(lambda m: dim * m, dim_mults)] 32 | context_dims = [context_channels, *map(lambda m: dim * m, context_dim_mults)] 33 | in_out = list(zip(dims[:-1], dims[1:])) 34 | 35 | if with_time_emb: 36 | time_dim = dim 37 | # self.time_mlp = nn.Sequential( 38 | # SinusoidalPosEmb(dim), nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) 39 | # ) 40 | self.time_mlp = nn.Sequential(nn.Linear(1, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) 41 | else: 42 | time_dim = None 43 | self.time_mlp = None 44 | 45 | self.downs = nn.ModuleList([]) 46 | self.ups = nn.ModuleList([]) 47 | num_resolutions = len(in_out) 48 | 49 | for ind, (dim_in, dim_out) in enumerate(in_out): 50 | is_last = ind >= (num_resolutions - 1) 51 | 52 | self.downs.append( 53 | nn.ModuleList( 54 | [ 55 | ResnetBlock( 56 | dim_in + context_dims[ind] 57 | if 
# --- Unet encode/decode stages ---

def encode(self, x, t, context):
    """Run the down path; return bottleneck features plus per-scale skips."""
    skips = []
    for level, (block_a, block_b, attention, down) in enumerate(self.downs):
        if level < len(context):
            # Inject the compressed context at the matching resolution.
            x = torch.cat([x, context[level]], dim=1)
        x = block_a(x, t)
        x = block_b(x, t)
        x = attention(x)
        skips.append(x)
        x = down(x)
    return self.mid_block1(x, t), skips

def decode(self, x, h, t):
    """Run bottleneck attention then the up path, consuming skips in reverse."""
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)
    for block_a, block_b, attention, up in self.ups:
        # Skip connections are popped LIFO, i.e. deepest resolution first.
        x = torch.cat((x, h.pop()), dim=1)
        x = block_a(x, t)
        x = block_b(x, t)
        x = attention(x)
        x = up(x)
    return self.final_conv(x)
import torch
from inspect import isfunction
from torch.autograd import Function
import numpy as np


def exists(x):
    """True when `x` is not None."""
    return x is not None


def default(val, d):
    """Return `val` when set; otherwise `d` (calling it when it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    """Yield items from `dl` forever, restarting it whenever it is exhausted."""
    while True:
        yield from dl


def num_to_groups(num, divisor):
    """Split `num` into `divisor`-sized chunks plus one smaller remainder chunk."""
    groups = [divisor] * (num // divisor)
    leftover = num % divisor
    if leftover > 0:
        groups.append(leftover)
    return groups


def extract(a, t, x_shape):
    """Gather per-timestep values from `a` and reshape to broadcast over `x_shape`."""
    batch = t.shape[0]
    return a.gather(-1, t).reshape(batch, *((1,) * (len(x_shape) - 1)))


def extract_tensor(a, t, place_holder=None):
    """Index `a` at timestep `t` for every batch element."""
    return a[t, torch.arange(len(t))]


def noise_like(shape, device, repeat=False):
    """Gaussian noise of `shape`; optionally one sample repeated over the batch."""
    if repeat:
        return torch.randn((1, *shape[1:]), device=device).repeat(
            shape[0], *((1,) * (len(shape) - 1))
        )
    return torch.randn(shape, device=device)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


def linear_beta_schedule(timesteps):
    """Linear schedule scaled so endpoints match the 1000-step DDPM defaults."""
    scale = 1000 / timesteps
    return np.linspace(scale * 0.0001, scale * 0.02, timesteps)


def noise(input, scale):
    """Additive uniform noise in [-scale/2, scale/2) — a quantization proxy."""
    return input + scale * (torch.rand_like(input) - 0.5)


def round_w_offset(input, loc):
    """Round `input` relative to offset `loc` with straight-through gradients."""
    return STERound.apply(input - loc) + loc


def quantize(x, mode='noise', offset=None):
    """Quantize `x` via noise injection, STE rounding, or offset dequantization."""
    if mode == 'noise':
        return noise(x, 1)
    if mode == 'round':
        return STERound.apply(x)
    if mode == 'dequantize':
        return round_w_offset(x, offset)
    raise NotImplementedError


class STERound(Function):
    """round() with a straight-through (identity) gradient."""

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        return g


class LowerBound(Function):
    """max(inputs, bound) whose gradient still passes where it helps recovery."""

    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, b)
        return torch.max(inputs, b)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors
        # Pass gradient where the bound is inactive, or where the gradient
        # would push a clamped value back toward the feasible region.
        keep = (inputs >= b) | (grad_output < 0)
        return keep.type(grad_output.dtype) * grad_output, None


class UpperBound(Function):
    """min(inputs, bound); mirrored gradient rule of LowerBound."""

    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, b)
        return torch.min(inputs, b)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors
        keep = (inputs <= b) | (grad_output > 0)
        return keep.type(grad_output.dtype) * grad_output, None
class NormalDistribution:
    '''
    A normal distribution
    '''

    def __init__(self, loc, scale):
        assert loc.shape == scale.shape
        self.loc = loc
        self.scale = scale

    @property
    def mean(self):
        # Detached so the mean can be inspected without extending the graph.
        return self.loc.detach()

    def std_cdf(self, inputs):
        """Standard-normal CDF evaluated elementwise via the erfc identity."""
        return 0.5 * torch.erfc(-(2 ** -0.5) * inputs)

    def sample(self):
        """Draw one reparameterized sample: loc + scale * N(0, 1)."""
        return self.loc + self.scale * torch.randn_like(self.scale)

    def likelihood(self, x, min=1e-9):
        """Probability mass of the unit-width bin centered at `x`, floored at `min`."""
        centered = torch.abs(x - self.loc)
        upper = self.std_cdf((.5 - centered) / self.scale)
        lower = self.std_cdf((-.5 - centered) / self.scale)
        return LowerBound.apply(upper - lower, min)

    def scaled_likelihood(self, x, s=1, min=1e-9):
        """Like `likelihood` but for a bin of width `s`."""
        centered = torch.abs(x - self.loc)
        half_width = s * .5
        upper = self.std_cdf((half_width - centered) / self.scale)
        lower = self.std_cdf((-half_width - centered) / self.scale)
        return LowerBound.apply(upper - lower, min)
args = parser.parse_args()

def main(rank):
    """Compress every .png/.jpg in args.img_dir with a pretrained diffusion codec.

    `rank` is the CUDA device index on which the model and tensors are placed.
    """
    denoise_model = Unet(
        dim=64,
        channels=3,
        context_channels=3,
        dim_mults=(1, 2, 3, 4, 5, 6),
        context_dim_mults=(1, 2, 3, 4),
    )

    context_model = BigCompressor(
        dim=64,
        dim_mults=(1, 2, 3, 4),
        hyper_dims_mults=(4, 4, 4),
        channels=3,
        out_channels=3,
        vbr=False,
    )

    diffusion = GaussianDiffusion(
        denoise_fn=denoise_model,
        context_fn=context_model,
        num_timesteps=20000,
        loss_type="l1",
        clip_noise="none",
        vbr=False,
        lagrangian=0.9,
        pred_mode="noise",
        var_schedule="linear",
        aux_loss_weight=args.lpips_weight,
        aux_loss_type="lpips"
    )

    # CPU map_location so the checkpoint loads regardless of its training GPU layout.
    loaded_param = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
    diffusion.load_state_dict(loaded_param["model"])
    diffusion.to(rank)
    diffusion.eval()

    for img in os.listdir(args.img_dir):
        if not (img.endswith(".png") or img.endswith(".jpg")):
            continue
        # Decode to [0, 1] float, add a batch dimension, move to the device.
        src = torchvision.io.read_image(os.path.join(args.img_dir, img))
        src = src.unsqueeze(0).float().to(rank) / 255.0
        compressed, bpp = diffusion.compress(
            src * 2.0 - 1.0,  # model expects inputs in [-1, 1]
            sample_steps=args.n_denoise_step,
            sample_mode="ddim",
            bpp_return_mean=False,
            # Decoding starts from gamma-scaled Gaussian noise.
            init=torch.randn_like(src) * args.gamma
        )
        compressed = compressed.clamp(-1, 1) / 2.0 + 0.5
        pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True)
        torchvision.utils.save_image(compressed.cpu(), os.path.join(args.out_dir, img))
        print("bpp:", bpp)


if __name__ == "__main__":
    main(args.device)
from data import load_data
import argparse
import os
import torch.distributed as dist
import torch.multiprocessing as mp
from modules.denoising_diffusion import GaussianDiffusion
from modules.unet import Unet
from modules.trainer import Trainer
from modules.compress_modules import BigCompressor, SimpleCompressor
import config


parser = argparse.ArgumentParser(description="values from bash script")
parser.add_argument("--device", type=int, required=True, help="cuda device number")
parser.add_argument("--beta", type=float, required=True, help="beta")
parser.add_argument("--alpha", type=float, required=True, help="alpha")
args = parser.parse_args()

# Run identifier: encodes every hyperparameter that distinguishes a checkpoint.
model_name = (
    f"{config.compressor}-{config.loss_type}-{config.data_config['dataset_name']}"
    f"-d{config.embed_dim}-t{config.iteration_step}-b{args.beta}-vbr{config.vbr}"
    f"-{config.pred_mode}-{config.var_schedule}-aux{args.alpha}{config.aux_loss_type if args.alpha>0 else ''}{config.additional_note}"
)


def schedule_func(ep):
    """Exponential LR decay, floored at config.minf."""
    return max(config.decay ** ep, config.minf)


def main():
    """Build the model stack from `config` + CLI args and run training."""
    train_data, val_data = load_data(
        config.data_config,
        config.batch_size,
        pin_memory=False,
        num_workers=config.n_workers,
    )

    denoise_model = Unet(
        dim=config.embed_dim,
        channels=config.data_config["img_channel"],
        context_channels=config.context_channels,
        dim_mults=config.dim_mults,
        context_dim_mults=config.context_dim_mults
    )

    # Both compressor variants share the same constructor signature, so
    # dispatch on the configured name instead of duplicating the kwargs.
    compressor_classes = {"big": BigCompressor, "simple": SimpleCompressor}
    if config.compressor not in compressor_classes:
        raise NotImplementedError
    context_model = compressor_classes[config.compressor](
        dim=config.embed_dim,
        dim_mults=config.context_dim_mults,
        hyper_dims_mults=config.hyper_dim_mults,
        channels=config.data_config["img_channel"],
        out_channels=config.context_channels,
        vbr=config.vbr
    )

    diffusion = GaussianDiffusion(
        denoise_fn=denoise_model,
        context_fn=context_model,
        clip_noise=config.clip_noise,
        num_timesteps=config.iteration_step,
        loss_type=config.loss_type,
        vbr=config.vbr,
        lagrangian=args.beta,
        pred_mode=config.pred_mode,
        aux_loss_weight=args.alpha,
        aux_loss_type=config.aux_loss_type,
        var_schedule=config.var_schedule
    ).to(args.device)

    trainer = Trainer(
        rank=args.device,
        sample_steps=config.sample_steps,
        diffusion_model=diffusion,
        train_dl=train_data,
        val_dl=val_data,
        scheduler_function=schedule_func,
        scheduler_checkpoint_step=config.scheduler_checkpoint_step,
        train_lr=config.lr,
        train_num_steps=config.n_step,
        save_and_sample_every=config.log_checkpoint_step,
        results_folder=os.path.join(config.result_root, f"{model_name}/"),
        tensorboard_dir=os.path.join(config.tensorboard_root, f"{model_name}/"),
        model_name=model_name,
        val_num_of_batch=config.val_num_of_batch,
        optimizer=config.optimizer,
        sample_mode=config.sample_mode
    )

    if config.load_model:
        trainer.load(load_step=config.load_step)

    trainer.train()


if __name__ == "__main__":
    main()
# training config
n_step = 1000000  # total number of optimizer steps
scheduler_checkpoint_step = 100000  # step at which LR-scheduler stepping begins
log_checkpoint_step = 5000  # save a checkpoint / run validation every N steps
gradient_accumulate_every = 1
lr = 4e-5
decay = 0.9  # multiplicative LR decay factor per scheduler step
minf = 0.5  # floor for the LR multiplier
ema_decay = 0.95
optimizer = "adam"  # adamw or adam
ema_step = 5
ema_start_step = 2000
n_workers = 4  # dataloader worker processes

# load
load_model = True  # resume from an existing checkpoint if present
load_step = True  # also restore the step counter when resuming

# diffusion config
pred_mode = 'noise'
loss_type = "l1"
iteration_step = 20000  # number of diffusion timesteps
sample_steps = 1000  # denoising steps used when sampling at validation
embed_dim = 64
dim_mults = (1, 2, 3, 4, 5, 6)  # channel multipliers for the Unet down/up path
hyper_dim_mults = (4, 4, 4)
context_channels = 3
clip_noise = False
val_num_of_batch = 1  # number of validation batches logged per checkpoint
additional_note = ""  # free-form suffix appended to the run name
context_dim_mults = (1, 2, 3, 4)
sample_mode = "ddim"
var_schedule = "cosine"
aux_loss_type = "lpips"

# data config
data_config = {
    "dataset_name": "vimeo",
    "data_path": "/extra/ucibdl0/shared/data",
    "sequence_length": 1,
    "img_size": 256,
    "img_channel": 3,
    "add_noise": False,
    "img_hz_flip": False,
}

batch_size = 4

# NOTE(review): cluster-specific absolute paths — adjust per deployment.
result_root = "/extra/ucibdl0/ruihan/params_compress_v7"
tensorboard_root = "/extra/ucibdl0/ruihan/tblogs_compress_v7"
-------------------------------------------------------------------------------- /xparam/config_ae.py: -------------------------------------------------------------------------------- 1 | # training config 2 | n_step = 1000000 3 | scheduler_checkpoint_step = 100000 4 | log_checkpoint_step = 5000 5 | lr = 4e-5 6 | decay = 0.9 7 | minf = 0.5 8 | optimizer = "adam" # adamw or adam 9 | n_workers = 4 10 | 11 | # load 12 | load_model = True 13 | load_step = True 14 | 15 | # data config 16 | data_config = { 17 | "dataset_name": "vimeo", 18 | "data_path": "/extra/ucibdl0/shared/data", 19 | "sequence_length": 1, 20 | "img_size": 256, 21 | "img_channel": 3, 22 | "add_noise": False, 23 | "img_hz_flip": False, 24 | } 25 | 26 | batch_size = 4 27 | result_root = "/extra/ucibdl0/ruihan/params_compress_v7" 28 | tensorboard_root = "/extra/ucibdl0/ruihan/tblogs_compress_v7" 29 | -------------------------------------------------------------------------------- /xparam/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import load_data 2 | -------------------------------------------------------------------------------- /xparam/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/__pycache__/load_data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/__pycache__/load_data.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/__pycache__/load_dataset.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/__pycache__/load_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .bair_robot_pushing import BAIRRobotPushing 2 | # from .bair_robot import BAIRRobot 3 | from .kth_actions import KTHActions 4 | from .moving_mnist import MovingMNIST 5 | from .stochastic_moving_mnist import StochasticMovingMNIST 6 | from .bouncing_ball_creator import make_bouncing_ball_dataset 7 | from .bouncing_ball import BouncingBall 8 | from .big import BIG 9 | from .image import IMG 10 | from .vimeo import VIMEO 11 | from .youtube import Youtube 12 | from .uvg import UVG 13 | from .audi import AUDI 14 | from .climate import ClimateData 15 | from .city import CITY 16 | from .simu import Simulation -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/audi.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/audi.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/audi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/audi.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/audi.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/audi.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bair_robot_pushing.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/big.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/big.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/big.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/big.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/big.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/big.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/big.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/big.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/bouncing_ball_creator.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/city.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/city.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/city.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/city.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/city.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/city.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/climate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/climate.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/climate.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/climate.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/climate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/climate.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/freedoc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/freedoc.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/image.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/image.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/image.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/image.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/image.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/image.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/kinetics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/kinetics.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/kth_actions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/kth_actions.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/kth_actions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/kth_actions.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/kth_actions.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/kth_actions.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/kth_actions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/kth_actions.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/moving_mnist.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/moving_mnist.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/moving_mnist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/moving_mnist.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/moving_mnist.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/moving_mnist.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/moving_mnist.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/moving_mnist.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/simu.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/simu.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/simu.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/simu.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/simu.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/simu.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/stochastic_moving_mnist.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/uvg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/uvg.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/uvg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/uvg.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/uvg.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/uvg.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/uvg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/uvg.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/vimeo.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/vimeo.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/vimeo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/vimeo.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/vimeo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/vimeo.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/vimeo.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/vimeo.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/youtube.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/youtube.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/youtube.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/youtube.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/youtube.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/youtube.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/__pycache__/youtube.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/datasets/__pycache__/youtube.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/datasets/audi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from pathlib import Path 6 | 
import torch 7 | 8 | 9 | class AUDI(Dataset): 10 | def __init__(self, path, num_of_frame, train=True, transform=None, add_noise=False): 11 | assert os.path.exists(path), "Invalid path to AUDI data set: " + path 12 | self.path = path 13 | self.transform = transform 14 | self.train = train 15 | if train: 16 | self.video_list = list( 17 | Path(os.path.join(path, "camera_lidar_semantic")).glob("*/camera/cam_front_center") 18 | )[:-1] 19 | else: 20 | self.video_list = list( 21 | Path(os.path.join(path, "camera_lidar_semantic")).glob("*/camera/cam_front_center") 22 | )[-1:] 23 | self.add_noise = add_noise 24 | self.num_of_frame = num_of_frame 25 | self.img_paths = [] 26 | for each in self.video_list: 27 | self.img_paths.append(sorted(list(each.glob("**/*small.png")))) 28 | 29 | def __getitem__(self, ind): 30 | # load the images from the ind directory to get list of PIL images 31 | if self.train: 32 | start_index = torch.randint(0, len(self.img_paths[ind]) - self.num_of_frame, (1,)).item() 33 | else: 34 | start_index = 525 35 | imgs = [Image.open(self.img_paths[ind][start_index + i]) for i in range(self.num_of_frame)] 36 | if self.transform is not None: 37 | imgs = self.transform(imgs) 38 | 39 | if self.add_noise: 40 | imgs = imgs + (torch.rand_like(imgs) - 0.5) / 256.0 41 | 42 | return imgs 43 | 44 | def __len__(self): 45 | # total number of videos 46 | return len(self.video_list) 47 | -------------------------------------------------------------------------------- /xparam/data/datasets/bair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from scipy.misc import imresize 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.misc import imresize 7 | from scipy.misc import imread 8 | 9 | 10 | class RobotPush(object): 11 | 12 | """Data Handler that loads robot pushing data.""" 13 | 14 | def __init__(self, data_root, train=True, seq_len=20, image_size=64): 15 | self.root_dir = data_root 16 | if 
train: 17 | self.data_dir = '%s/processed_data/train' % self.root_dir 18 | self.ordered = False 19 | else: 20 | self.data_dir = '%s/processed_data/test' % self.root_dir 21 | self.ordered = True 22 | self.dirs = [] 23 | for d1 in os.listdir(self.data_dir): 24 | for d2 in os.listdir('%s/%s' % (self.data_dir, d1)): 25 | self.dirs.append('%s/%s/%s' % (self.data_dir, d1, d2)) 26 | self.seq_len = seq_len 27 | self.image_size = image_size 28 | self.seed_is_set = False # multi threaded loading 29 | self.d = 0 30 | 31 | def set_seed(self, seed): 32 | if not self.seed_is_set: 33 | self.seed_is_set = True 34 | np.random.seed(seed) 35 | 36 | def __len__(self): 37 | return 10000 38 | 39 | def get_seq(self): 40 | if self.ordered: 41 | d = self.dirs[self.d] 42 | if self.d == len(self.dirs) - 1: 43 | self.d = 0 44 | else: 45 | self.d+=1 46 | else: 47 | d = self.dirs[np.random.randint(len(self.dirs))] 48 | image_seq = [] 49 | for i in range(self.seq_len): 50 | fname = '%s/%d.png' % (d, i) 51 | im = imread(fname).reshape(1, 64, 64, 3) 52 | image_seq.append(im/255.) 
53 | image_seq = np.concatenate(image_seq, axis=0) 54 | return image_seq 55 | 56 | 57 | def __getitem__(self, index): 58 | self.set_seed(index) 59 | return self.get_seq() 60 | 61 | 62 | -------------------------------------------------------------------------------- /xparam/data/datasets/bair_robot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | import os 5 | 6 | class RobotPushing(torch.utils.data.Dataset): 7 | """ 8 | dataset class for moving-mnist dataset 9 | """ 10 | def __init__(self, is_train, data_path=None): 11 | super(RobotPushing, self).__init__() 12 | if data_path is None: 13 | self.data_path = '/local-scratch/chenleic/Data/BAIR/robot_pushing/main_frames' 14 | # print() 15 | else: 16 | self.data_path = data_path 17 | 18 | 19 | all_vid_list = sorted(os.listdir(self.data_path)) 20 | 21 | self.is_train = is_train 22 | if self.is_train: 23 | self.vid_list = all_vid_list[:10000] 24 | else: 25 | self.vid_list = all_vid_list[-100:] 26 | 27 | 28 | def __len__(self): 29 | return len(self.vid_list) 30 | 31 | 32 | def __getitem__(self, item): 33 | frames = np.load(os.path.join(self.data_path, self.vid_list[item])) 34 | frames = torch.FloatTensor(frames).permute(3, 0, 2, 1).contiguous() 35 | 36 | frames = frames/255 37 | 38 | return frames 39 | 40 | 41 | if __name__ == '__main__': 42 | dataset = RobotPushing(is_train=True) 43 | dataloader = torch.utils.data.DataLoader(dataset) 44 | 45 | import matplotlib.pyplot as plt 46 | 47 | for batch in dataloader: 48 | frame = batch[0,:,0,:,:].permute(1,2,0).contiguous() 49 | plt.imshow(frame) 50 | plt.draw() 51 | plt.savefig('/local-scratch/chenleic/Projects/seq_flow/seq_flow_robot_results/check_frame.jpg') 52 | break 53 | 54 | print(batch.size()) 55 | 56 | 57 | -------------------------------------------------------------------------------- /xparam/data/datasets/bair_robot_pushing.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torch 5 | 6 | 7 | class BAIRRobotPushing(Dataset): 8 | """ 9 | Dataset object for BAIR robot pushing dataset. The dataset must be stored 10 | with each video in a separate directory: 11 | /path 12 | /0 13 | /0.png 14 | /1.png 15 | /... 16 | /1 17 | /... 18 | """ 19 | 20 | def __init__(self, path, transform=None, add_noise=False): 21 | assert os.path.exists(path), 'Invalid path to BAIR data set: ' + path 22 | self.path = path 23 | self.transform = transform 24 | self.video_list = os.listdir(self.path) 25 | 26 | self.add_noise = add_noise 27 | 28 | def __getitem__(self, ind): 29 | # load the images from the ind directory to get list of PIL images 30 | img_names = os.listdir(os.path.join(self.path, self.video_list[ind])) 31 | img_names = [img_name.split('.')[0] for img_name in img_names] 32 | img_names.sort(key=float) 33 | imgs = [Image.open(os.path.join(self.path, self.video_list[ind], i + '.png')) for i in img_names] 34 | if self.transform is not None: 35 | # apply the image/video transforms 36 | imgs = self.transform(imgs) 37 | 38 | # imgs = imgs.unsqueeze(1) 39 | 40 | if self.add_noise: 41 | imgs = imgs + (torch.rand_like(imgs)-0.5) / 256. 42 | 43 | return imgs 44 | 45 | def __len__(self): 46 | # total number of videos 47 | return len(self.video_list) 48 | -------------------------------------------------------------------------------- /xparam/data/datasets/big.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from torchvision.io import read_video 5 | import torch 6 | 7 | 8 | class BIG(Dataset): 9 | """ 10 | Dataset object for BAIR robot pushing dataset. 
The dataset must be stored 11 | with each video in a separate directory: 12 | /path 13 | /0 14 | /0.png 15 | /1.png 16 | /... 17 | /1 18 | /... 19 | """ 20 | 21 | def __init__(self, path, transform=None, add_noise=False, img_mode=False): 22 | assert os.path.exists( 23 | path), 'Invalid path to UCF+HMDB data set: ' + path 24 | self.path = path 25 | self.transform = transform 26 | self.video_list = os.listdir(self.path) 27 | self.img_mode = img_mode 28 | self.add_noise = add_noise 29 | 30 | def __getitem__(self, ind): 31 | # load the images from the ind directory to get list of PIL images 32 | img_names = os.listdir(os.path.join( 33 | self.path, self.video_list[ind])) 34 | img_names = [img_name.split('.')[0] for img_name in img_names] 35 | img_names.sort(key=float) 36 | if not self.img_mode: 37 | imgs = [Image.open(os.path.join( 38 | self.path, self.video_list[ind], i + '.png')) for i in img_names] 39 | else: 40 | select = torch.randint(0, len(img_names), (1,)) 41 | imgs = [Image.open(os.path.join( 42 | self.path, self.video_list[ind], img_names[select] + '.png'))] 43 | if self.transform is not None: 44 | # apply the image/video transforms 45 | imgs = self.transform(imgs) 46 | 47 | # imgs = imgs.unsqueeze(1) 48 | 49 | if self.add_noise: 50 | imgs = imgs + (torch.rand_like(imgs)-0.5) / 256. 51 | 52 | return imgs 53 | 54 | def __len__(self): 55 | # total number of videos 56 | return len(self.video_list) 57 | -------------------------------------------------------------------------------- /xparam/data/datasets/bouncing_ball.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class BouncingBall(Dataset): 8 | """ 9 | Dataset class for moving MNIST dataset. 
10 | 11 | Args: 12 | path (str): path to the .mat dataset 13 | transform (torchvision.transforms): image/video transforms 14 | """ 15 | def __init__(self, path, sequence_lengh): 16 | assert os.path.exists(path), 'Invalid path to Bouncing Ball data set: ' + path 17 | self.sequence_length = sequence_lengh 18 | self.data = np.load(path) 19 | 20 | def __getitem__(self, ind): 21 | imgs = self.data[ind,:,:,:].astype('float32') 22 | s, h, w = imgs.shape 23 | imgs = imgs.reshape(s, 1, h, w) 24 | 25 | imgs = imgs[:self.sequence_length, :, :, :] 26 | 27 | return torch.FloatTensor(imgs).contiguous() 28 | 29 | def __len__(self): 30 | return self.data.shape[0] 31 | -------------------------------------------------------------------------------- /xparam/data/datasets/bouncing_ball_creator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script comes from the RTRBM code by Ilya Sutskever from 3 | http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar 4 | """ 5 | 6 | from numpy import * 7 | from scipy import * 8 | import pdb 9 | import pickle 10 | import scipy.io 11 | import sys, os 12 | 13 | import matplotlib 14 | 15 | matplotlib.use('Agg') 16 | import matplotlib.pyplot as plt 17 | 18 | shape_std = shape 19 | 20 | 21 | def shape(A): 22 | if isinstance(A, ndarray): 23 | return shape_std(A) 24 | else: 25 | return A.shape() 26 | 27 | 28 | size_std = size 29 | 30 | 31 | def size(A): 32 | if isinstance(A, ndarray): 33 | return size_std(A) 34 | else: 35 | return A.size() 36 | 37 | 38 | det = linalg.det 39 | 40 | 41 | def new_speeds(m1, m2, v1, v2): 42 | new_v2 = (2 * m1 * v1 + v2 * (m2 - m1)) / (m1 + m2) 43 | new_v1 = new_v2 + (v2 - v1) 44 | return new_v1, new_v2 45 | 46 | 47 | def norm(x): return sqrt((x ** 2).sum()) 48 | 49 | 50 | def sigmoid(x): return 1. / (1. + exp(-x)) 51 | 52 | 53 | SIZE = 10 54 | 55 | 56 | # size of bounding box: SIZE X SIZE. 
57 | 58 | def bounce_n(T=128, n=2, r=None, m=None): 59 | if r is None: r = array([1.2] * n) 60 | if m is None: m = array([1] * n) 61 | # r is to be rather small. 62 | X = zeros((T, n, 2), dtype='float') 63 | v = randn(n, 2) 64 | v = v / norm(v) * .5 65 | good_config = False 66 | while not good_config: 67 | x = 2 + rand(n, 2) * 8 68 | good_config = True 69 | for i in range(n): 70 | for z in range(2): 71 | if x[i][z] - r[i] < 0: good_config = False 72 | if x[i][z] + r[i] > SIZE: good_config = False 73 | 74 | # that's the main part. 75 | for i in range(n): 76 | for j in range(i): 77 | if norm(x[i] - x[j]) < r[i] + r[j]: 78 | good_config = False 79 | 80 | eps = .5 81 | for t in range(T): 82 | # for how long do we show small simulation 83 | 84 | for i in range(n): 85 | X[t, i] = x[i] 86 | 87 | for mu in range(int(1 / eps)): 88 | 89 | for i in range(n): 90 | x[i] += eps * v[i] 91 | 92 | for i in range(n): 93 | for z in range(2): 94 | if x[i][z] - r[i] < 0: v[i][z] = abs(v[i][z]) # want positive 95 | if x[i][z] + r[i] > SIZE: v[i][z] = -abs(v[i][z]) # want negative 96 | 97 | for i in range(n): 98 | for j in range(i): 99 | if norm(x[i] - x[j]) < r[i] + r[j]: 100 | # the bouncing off part: 101 | w = x[i] - x[j] 102 | w = w / norm(w) 103 | 104 | v_i = dot(w.transpose(), v[i]) 105 | v_j = dot(w.transpose(), v[j]) 106 | 107 | new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j) 108 | 109 | v[i] += w * (new_v_i - v_i) 110 | v[j] += w * (new_v_j - v_j) 111 | 112 | return X 113 | 114 | 115 | def ar(x, y, z): 116 | return z / 2 + arange(x, y, z, dtype='float') 117 | 118 | 119 | def matricize(X, res, r=None): 120 | T, n = shape(X)[0:2] 121 | if r is None: r = array([1.2] * n) 122 | 123 | A = zeros((T, res, res), dtype='float') 124 | 125 | [I, J] = meshgrid(ar(0, 1, 1. / res) * SIZE, ar(0, 1, 1. 
/ res) * SIZE) 126 | 127 | for t in range(T): 128 | for i in range(n): 129 | A[t] += exp(-(((I - X[t, i, 0]) ** 2 + (J - X[t, i, 1]) ** 2) / (r[i] ** 2)) ** 4) 130 | 131 | A[t][A[t] > 1] = 1 132 | return A 133 | 134 | 135 | def bounce_mat(res, n=2, T=128, r=None): 136 | if r == None: r = array([1.2] * n) 137 | x = bounce_n(T, n, r); 138 | A = matricize(x, res, r) 139 | return A 140 | 141 | 142 | def bounce_vec(res, n=2, T=128, r=None, m=None): 143 | if r == None: r = array([1.2] * n) 144 | x = bounce_n(T, n, r, m); 145 | V = matricize(x, res, r) 146 | return V.reshape(T, res ** 2) 147 | 148 | 149 | # make sure you have this folder 150 | # logdir = './sample' 151 | 152 | 153 | # def show_sample(V): 154 | # T = len(V) 155 | # res = int(sqrt(shape(V)[1])) 156 | # for t in range(T): 157 | # plt.imshow(V[t].reshape(res, res), cmap=matplotlib.cm.Greys_r) 158 | # # Save it 159 | # fname = logdir + '/' + str(t) + '.png' 160 | # plt.savefig(fname) 161 | 162 | 163 | def make_bouncing_ball_dataset(data_path, res, n_ball, T, N_train, N_val): 164 | train_data = zeros((N_train, T, res, res)) 165 | for i in range(N_train): 166 | train_data[i] = bounce_vec(res=res, n=n_ball, T=T).reshape((T, res, res)) 167 | sys.stdout.write('\rcreating bouncing ball train clip {}/{}'.format(i, N_train)) 168 | sys.stdout.flush() 169 | 170 | train_data = train_data.reshape((N_train, T, res, res)) 171 | save(os.path.join(data_path, 'bouncing_balls_train_data.npy'), train_data) 172 | print() 173 | 174 | val_data = zeros((N_val, T, res, res)) 175 | for i in range(N_val): 176 | val_data[i] = bounce_vec(res=res, n=n_ball, T=T).reshape((T, res, res)) 177 | sys.stdout.write('\rcreating bouncing ball val clip {}/{}'.format(i, N_val)) 178 | sys.stdout.flush() 179 | 180 | val_data = val_data.reshape((N_val, T, res, res)) 181 | save(os.path.join(data_path, 'bouncing_balls_val_data.npy'), val_data) 182 | 183 | 184 | 185 | # if __name__ == "__main__": 186 | # res = 30 187 | # T = 100 188 | # N = 4000 189 | # 
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from pathlib import Path
import torch


class CITY(Dataset):
    """Cityscapes leftImg8bit video clips; each video is 30 consecutive frames."""

    def __init__(self, path, num_of_frame, train=True, transform=None, add_noise=False):
        assert os.path.exists(path), "Invalid path to CITY data set: " + path
        self.path = path
        self.transform = transform
        self.train = train
        # Collect every frame of the chosen split; sorting groups the 30
        # frames of each video contiguously.
        split = "train" if train else "val"
        found = Path(os.path.join(path, "leftImg8bit_sequence/" + split)).glob("*/*.png")
        self.add_noise = add_noise
        self.num_of_frame = num_of_frame
        self.frame_list = sorted(found)

    def __getitem__(self, ind):
        # Frames of video `ind` occupy flat indices [ind*30, (ind+1)*30).
        lo = ind * 30
        hi = (ind + 1) * 30
        if self.train:
            # Random temporal crop of num_of_frame consecutive frames.
            start = torch.randint(lo, hi - self.num_of_frame, (1,)).item()
        else:
            start = lo
        clip = [Image.open(self.frame_list[start + k]) for k in range(self.num_of_frame)]
        if self.transform is not None:
            clip = self.transform(clip)

        if self.add_noise:
            # Uniform dequantization noise in [-1/512, 1/512).
            clip = clip + (torch.rand_like(clip) - 0.5) / 256.0

        return clip

    def __len__(self):
        # 30 frames per video.
        return len(self.frame_list) // 30


class ClimateData(Dataset):
    """Normalized climate W-field frames served as 8-frame clips."""

    def __init__(self, path, mode):
        fields = np.load(os.path.join(path, "climate_timestep/W_fields.npy"), mmap_mode="r")
        fields = np.reshape(fields, (-1, 192, 30, 128), order="F")
        fields = np.reshape(fields, (-1, 24, 8, 30, 128))
        # Global mean/std normalization over the full array.
        self.mean = fields.mean()
        self.std = np.std(fields)
        fields = (fields - self.mean) / self.std

        # First 20 of the 24 groups are train, the remaining 4 are test.
        if mode == "train":
            self.t = 20
            chunk = fields[:, :20, :, :, :]
        else:
            self.t = 4
            chunk = fields[:, 20:, :, :, :]
        del fields
        chunk = np.reshape(chunk, (-1, 8, 30, 128))
        chunk = np.reshape(chunk, (-1, 30, 128))
        # Symmetric-pad height 30 -> 32.
        chunk = np.pad(chunk, ((0, 0), (1, 1), (0, 0)), "symmetric")
        self.data = torch.from_numpy(chunk).float()

    def __len__(self):
        return self.data.size()[0]

    def __getitem__(self, idx):
        # Map the flat index to an 8-frame window that stays inside one
        # sequence of length self.t * 8.
        width = self.t * 8
        seq = idx // width
        offset = idx % width
        if offset > width - 8:
            offset = width - 8
        begin = seq * width + offset
        return self.data[begin : begin + 8, :, :].unsqueeze(1)
import os
from PIL import Image
from torch.utils.data import Dataset
import torch


class IMG(Dataset):
    """Single-image dataset over a flat directory of image files."""

    def __init__(self, path, transform=None):
        """
        Args:
            path (str): directory containing the image files
            transform: transforms applied to the loaded list of PIL images
        """
        assert os.path.exists(path), 'Invalid path to IMAGE data set: ' + path
        self.path = path
        self.transform = transform
        self.img_list = os.listdir(self.path)

    def __getitem__(self, ind):
        # Wrapped in a list so video-style transforms apply uniformly.
        img = [Image.open(os.path.join(self.path, self.img_list[ind]))]
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): assumes the transform returned a tensor shaped
        # (seq, channels, H, W) — with transform=None this attribute access
        # would fail on a plain list; confirm callers always pass a transform.
        if img.shape[1] == 1:
            # Replicate grayscale to 3 channels.
            img = img.expand(-1, 3, -1, -1)
        return img

    def __len__(self):
        # total number of images
        return len(self.img_list)


class KTHActions(Dataset):
    """
    Dataset object for KTH actions dataset. The dataset must be stored
    with each video (action sequence) in a separate directory:
        /path
            /person01_walking_d1_0
                /0.png
                /1.png
                /...
            /person01_walking_d1_1
                /...
    """

    def __init__(self, path, transform=None, add_noise=False):
        assert os.path.exists(path), 'Invalid path to KTH actions data set: ' + path
        self.path = path
        self.transform = transform
        self.video_list = os.listdir(self.path)
        self.add_noise = add_noise

    def __getitem__(self, ind):
        # Frames are named <number>.png; sort numerically, not lexically.
        img_names = os.listdir(os.path.join(self.path, self.video_list[ind]))
        img_names = [img_name.split('.')[0] for img_name in img_names]
        img_names.sort(key=float)
        imgs = [Image.open(os.path.join(self.path, self.video_list[ind], i + '.png')).convert('L') for i in img_names]
        if self.transform is not None:
            # apply the image/video transforms
            imgs = self.transform(imgs)

        if self.add_noise:
            # Gaussian dequantization-style noise.
            imgs += torch.randn_like(imgs)/256

        return imgs

    def __len__(self):
        # FIX: use the listing cached in __init__ instead of re-reading the
        # directory on every call — avoids per-call I/O and keeps the length
        # consistent with the indices __getitem__ serves.
        return len(self.video_list)
import os
import torch
import numpy as np
from torch.utils.data import Dataset


class MovingMNIST(Dataset):
    """
    Dataset class for moving MNIST dataset.

    Args:
        path (str): path to the .npy dataset
        transform (torchvision.transforms): image/video transforms
    """

    def __init__(self, path, transform=None, add_noise=False):
        assert os.path.exists(path), 'Invalid path to Moving MNIST data set: ' + path
        self.transform = transform
        # Raw array layout is (seq_len, num_videos, H, W) — indexing below
        # selects the video on axis 1.
        self.data = np.load(path)
        self.add_noise = add_noise

    def __getitem__(self, ind):
        imgs = self.data[:, ind, :, :].astype('float32')
        s, h, w = imgs.shape
        # Insert a singleton channel axis: (seq, 1, H, W).
        imgs = imgs.reshape(s, 1, h, w)
        if self.transform is not None:
            # apply the image/video transforms
            imgs = self.transform(imgs)

        if self.add_noise:
            # NOTE(review): randn_like requires a torch tensor; with
            # transform=None `imgs` is still a numpy array here — confirm a
            # tensor-producing transform is always supplied when add_noise=True.
            imgs += torch.randn_like(imgs)/256
        return imgs

    def __len__(self):
        # number of videos (axis 1 of the raw array)
        return self.data.shape[1]


import torch
import numpy as np
import torchvision.transforms.functional as G
from torch.utils.data import Dataset


class Simulation(Dataset):
    # Simulation frames stored as one (N, H, W) array: first 8000 frames are
    # the train split, the remainder the test split.

    def __init__(self, path, number_of_frame, train, size, transform=None):
        # Min-max normalize with statistics of the FULL array (both splits),
        # so train and test share the same scaling.
        data = np.load(path).astype(np.single)
        mmin = data.min()
        mmax = data.max()
        self.number_of_frame = number_of_frame
        self.transform = transform

        if train:
            # self.t: frames per contiguous sequence within the split
            # (presumably 8 sequences of 1000 train frames — TODO confirm).
            self.t = 1000
            train = data[:8000, :, :]
            train = (train - mmin) / (mmax - mmin)
            self.data = torch.from_numpy(train)
            self.data = self.data.unsqueeze(1)  # add channel dim: (N, 1, H, W)
            self.data = G.resize(self.data, size)

        else:

            self.t = 250
            test = data[8000:, :, :]
            test = (test - mmin) / (mmax - mmin)
            self.data = torch.from_numpy(test)
            self.data = self.data.unsqueeze(1)  # add channel dim: (N, 1, H, W)
            self.data = G.resize(self.data, size)

    def __len__(self):

        return self.data.size()[0]

    def __getitem__(self, idx):
        # Map a flat index to a window of number_of_frame consecutive frames
        # that does not cross a sequence boundary of length self.t.
        width = self.t
        start = int(idx/(width))
        p = idx % width
        if p > width - self.number_of_frame:
            # Clamp so the window stays inside the sequence.
            p = width - self.number_of_frame
        begin = start * width + p
        frames = self.data[begin:begin + self.number_of_frame, :, :, :]
        return frames


import socket
import numpy as np
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch


class StochasticMovingMNIST(Dataset):
    """Data Handler that creates Bouncing MNIST dataset on the fly."""
    def __init__(self,
                 train,
                 data_root,
                 seq_len=20,
                 num_digits=2,
                 image_size=64,
                 deterministic=True,
                 add_noise=False,
                 epoch_size=0):
        path = data_root
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1
        self.digit_size = 32
        self.deterministic = deterministic
        self.seed_is_set = False  # multi threaded loading
        self.channels = 1
        self.add_noise = add_noise
        self.epoch_size = epoch_size

        # NOTE(review): transforms.Scale was renamed to transforms.Resize and
        # is removed in recent torchvision releases — confirm the pinned
        # torchvision version still provides it.
        self.data = datasets.MNIST(path,
                                   train=train,
                                   download=True,
                                   transform=transforms.Compose(
                                       [transforms.Scale(self.digit_size),
                                        transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        # Seed numpy once per worker process so loading is reproducible.
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        # epoch_size overrides the natural dataset length when positive.
        if self.epoch_size > 0:
            return self.epoch_size
        else:
            return self.N

    def __getitem__(self, index):
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        # Canvas: (seq, H, W, channels); digits are summed onto it.
        x = np.zeros((self.seq_len, image_size, image_size, self.channels), dtype=np.float32)
        for n in range(self.num_digits):
            # Random digit with random start position and integer velocity.
            idx = np.random.randint(self.N)
            digit, _ = self.data[idx]

            sx = np.random.randint(image_size - digit_size)
            sy = np.random.randint(image_size - digit_size)
            dx = np.random.randint(-4, 5)
            dy = np.random.randint(-4, 5)
            for t in range(self.seq_len):
                # Vertical bounce: reflect (deterministic) or resample the
                # velocity (stochastic) at the top/bottom edges.
                if sy < 0:
                    sy = 0
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(1, 5)
                        dx = np.random.randint(-4, 5)
                elif sy >= image_size - 32:
                    sy = image_size - 32 - 1
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(-4, 0)
                        dx = np.random.randint(-4, 5)

                # Horizontal bounce, same scheme.
                if sx < 0:
                    sx = 0
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(1, 5)
                        dy = np.random.randint(-4, 5)
                elif sx >= image_size - 32:
                    sx = image_size - 32 - 1
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(-4, 0)
                        dy = np.random.randint(-4, 5)

                # Stamp the 32x32 digit, then advance the position.
                x[t, sy:sy + 32, sx:sx + 32, 0] += digit.numpy().squeeze()
                sy += dy
                sx += dx

        # To (seq, channels, H, W).
        x = torch.FloatTensor(x).permute(0, 3, 1, 2).contiguous()
        if self.add_noise:
            x += torch.randn_like(x) / 256

        # Clamp overlapping digits / noise into [0, 1].
        x[x < 0] = 0.
        x[x > 1] = 1.
        return x


import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random


class UVG(Dataset):
    # UVG test videos: one directory per video, frames named 1.png, 2.png, ...

    def __init__(self, path, nframe=3, transform=None, seed=1212):
        assert os.path.exists(path), 'Invalid path to uvg data set: ' + path
        # Deterministically shuffle the video order.
        random.seed(seed)
        ldir = os.listdir(path)
        random.shuffle(ldir)
        self.transform = transform
        video_list = np.core.defchararray.add(f'{path}/', ldir)
        self.video_list = video_list
        self.nframe = nframe

    def __getitem__(self, ind):
        tot_nframe = len(os.listdir(self.video_list[ind]))
        assert tot_nframe >= self.nframe
        # Frames are 1-indexed on disk; pick a random temporal window.
        start_ind = torch.randint(1, 1 + tot_nframe - self.nframe, (1, )).item()
        imgs = [
            Image.open(os.path.join(self.video_list[ind],
                                    str(img_name) + '.png')) for img_name in range(start_ind, start_ind + self.nframe)
        ]
        if self.transform is not None:
            imgs = self.transform(imgs)

        return imgs

    def __len__(self):
        # total number of videos
        return len(self.video_list)
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch


class VIMEO(Dataset):
    """Vimeo-90k septuplet dataset.

    Each line of sep_trainlist.txt / sep_testlist.txt names a directory of
    frames under <path>/sequences.
    """

    def __init__(self, path, train=True, transform=None, add_noise=False):
        assert os.path.exists(
            path), 'Invalid path to VIMEO data set: ' + path
        self.path = path
        self.transform = transform
        if train:
            self.video_list = os.path.join(path, 'sep_trainlist.txt')
        else:
            self.video_list = os.path.join(path, 'sep_testlist.txt')
        self.video_list = np.loadtxt(self.video_list, dtype=str)
        self.video_list = np.core.defchararray.add(f'{os.path.join(path, "sequences")}/', self.video_list)

        self.add_noise = add_noise

    def __getitem__(self, ind):
        # FIX: os.listdir returns names in arbitrary order; sort so frames
        # are loaded in temporal order (im1.png ... im7.png sorts correctly
        # lexically for single-digit frame numbers).
        img_names = sorted(os.listdir(str(self.video_list[ind])))
        imgs = [Image.open(os.path.join(self.video_list[ind], str(img_name)))
                for img_name in img_names]
        if self.transform is not None:
            imgs = self.transform(imgs)

        if self.add_noise:
            # Uniform dequantization noise in [-1/512, 1/512).
            imgs = imgs + (torch.rand_like(imgs)-0.5) / 256.

        return imgs

    def __len__(self):
        # total number of videos
        return len(self.video_list)


import os
import numpy as np
from PIL import Image, ImageChops
from torch.utils.data import Dataset
import torch
import random


class Youtube(Dataset):
    """YouTube clips: one directory per video, frames named 1.png, 2.png, ...

    The last 32 (shuffled) videos form the validation split.
    """

    def __init__(self, path, nframe=3, train=True, transform=None, seed=1212):
        assert os.path.exists(path), 'Invalid path to youtube data set: ' + path
        # Deterministic shuffle so the train/val split is stable across runs.
        random.seed(seed)
        ldir = os.listdir(path)
        random.shuffle(ldir)
        self.transform = transform
        video_list = np.core.defchararray.add(f'{path}/', ldir)
        if train:
            self.video_list = video_list[:-32]
        else:
            self.video_list = video_list[-32:]
        self.nframe = nframe

    def __getitem__(self, ind):
        tot_nframe = len(os.listdir(self.video_list[ind]))
        assert tot_nframe >= self.nframe
        # Frames are 1-indexed on disk; pick a random temporal window.
        start_ind = torch.randint(1, 1 + tot_nframe - self.nframe, (1, )).item()
        imgs = [
            Image.open(os.path.join(self.video_list[ind],
                                    str(img_name) + '.png')) for img_name in range(start_ind, start_ind + self.nframe)
        ]
        if self.transform is not None:
            imgs = self.transform(imgs)

        return imgs

    def __len__(self):
        # total number of videos
        return len(self.video_list)


from .load_dataset import load_dataset
from .transposed_collate import train_transposed_collate, test_transposed_collate
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def load_data(data_config, batch_size, num_workers=4, pin_memory=True, distributed=False):
    """
    Wrapper around load_dataset. Gets the dataset, then places it in a DataLoader.

    Args:
        data_config (dict): data configuration dictionary
        batch_size (int): examples per batch
        num_workers (int): number of threads of multi-processed data Loading
        pin_memory (bool): whether or not to pin memory in cpu
        distributed (bool): wrap each split in a DistributedSampler when True

    Returns:
        (train_loader, val_loader); either may be None if the dataset split is.
    """
    train, val = load_dataset(data_config)
    # DistributedSampler already shuffles the train split, so the loader's own
    # shuffle must be off in that case.
    train_spl = DistributedSampler(train) if distributed else None
    val_spl = DistributedSampler(val, shuffle=False) if distributed else None

    if train is not None:
        train = DataLoader(
            train,
            batch_size=batch_size,
            shuffle=False if distributed else True,
            collate_fn=train_transposed_collate,
            num_workers=num_workers,
            pin_memory=pin_memory,
            sampler=train_spl
        )

    if val is not None:
        val = DataLoader(
            val,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=test_transposed_collate,
            num_workers=num_workers,
            pin_memory=pin_memory,
            sampler=val_spl
        )
    return train, val
-------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/convert_bair.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/convert_bair.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/convert_bair.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/convert_bair.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/convert_kth_actions.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/kth_actions_frames.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/transforms.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/url_save.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/url_save.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/url_save.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/url_save.cpython-37.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/url_save.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/url_save.cpython-38.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/__pycache__/url_save.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/data/misc_data_util/__pycache__/url_save.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/data/misc_data_util/convert_bair.py: 
import os
import numpy as np
from PIL import Image
import tensorflow.compat.v1 as tf
from tensorflow.python.platform import gfile
from imageio import imwrite as imsave

# Convert BAIR robot pushing data to numpy to use with PyTorch
# Based on Emily Denton's script: https://github.com/edenton/svg/blob/master/data/convert_bair.py


def convert(data_path):
    # Unpack the BAIR TF-record files into one directory of PNG frames per
    # 30-frame sequence: <data_path>/<split>/<sequence_index>/<frame>.png
    # iterate through the data splits
    for data_split in ['train', 'test']:
        os.makedirs(os.path.join(data_path, data_split))
        data_split_path = os.path.join(data_path, 'softmotion30_44k', data_split)
        data_split_files = gfile.Glob(os.path.join(data_split_path, '*'))
        # iterate through the TF records
        for f in data_split_files:
            print('Current file: ' + f)
            ind = int(f.split('/')[-1].split('_')[1])  # starting video index
            # iterate through the sequences in this TF record
            for serialized_example in tf.python_io.tf_record_iterator(f):
                os.makedirs(os.path.join(data_path, data_split, str(ind)))
                example = tf.train.Example()
                example.ParseFromString(serialized_example)
                # iterate through the sequence (30 frames per video)
                for i in range(30):
                    image_name = str(i) + '/image_aux1/encoded'
                    byte_str = example.features.feature[image_name].bytes_list.value[0]
                    img = Image.frombytes('RGB', (64, 64), byte_str)
                    # Scale to [0, 1] floats before writing.
                    img = np.array(img.getdata()).reshape(img.size[1], img.size[0], 3) / 255.
                    imsave(os.path.join(data_path, data_split, str(ind), str(i) + '.png'), img)
                print(' Finished processing sequence ' + str(ind))
                ind += 1


import os
from imageio import imwrite as imsave
from moviepy.editor import VideoFileClip
from .kth_actions_frames import kth_actions_dict

settings = ['d1', 'd2', 'd3', 'd4']
actions = ['walking', 'jogging', 'running', 'boxing', 'handwaving', 'handclapping']
person_ids = {'train': ['11', '12', '13', '14', '15', '16', '17', '18'],
              'val': ['19', '20', '21', '23', '24', '25', '01', '04'],
              'test': ['22', '02', '03', '05', '06', '07', '08', '09', '10']}


def convert(data_path):
    # Split raw KTH action .avi files into per-sequence PNG frame directories
    # using the (start, end) frame ranges recorded in kth_actions_dict.
    # iterate through the data splits
    for data_split in ['train', 'val', 'test']:
        print('Converting ' + data_split)
        os.makedirs(os.path.join(data_path, data_split))
        split_person_ids = person_ids[data_split]
        # iterate through the ids, actions, and settings for this split
        for person_id in split_person_ids:
            print(' Converting person' + person_id)
            for action in kth_actions_dict['person'+person_id]:
                for setting in kth_actions_dict['person'+person_id][action]:
                    frame_nums = kth_actions_dict['person'+person_id][action][setting]
                    if len(frame_nums) > 0:
                        # Each (start, end) pair delimits one sub-sequence
                        # within the video (1-indexed frame numbers).
                        start_frames = [frame_pair[0] for frame_pair in frame_nums]
                        end_frames = [frame_pair[1] for frame_pair in frame_nums]
                        # load the video
                        file_name = 'person' + person_id + '_' + action + '_' + setting + '_uncomp.avi'
                        print(file_name)
                        video = VideoFileClip(os.path.join(data_path, action, file_name))
                        # write each sequence to a directory
                        sequence_frame_index = 0
                        sequence_index = 0
                        sequence_name = ''
                        in_sequence = False
                        for frame_index, frame in enumerate(video.iter_frames()):
                            if frame_index + 1 in start_frames:
                                # start a new sequence
                                in_sequence = True
                                sequence_frame_index = 0
                                sequence_name = 'person' + person_id + '_' + action + '_' + setting + '_' + str(sequence_index)
                                os.makedirs(os.path.join(data_path, data_split, sequence_name))
                            if frame_index + 1 in end_frames:
                                # end the current sequence
                                in_sequence = False
                                sequence_index += 1
                                if frame_index + 1 == max(end_frames):
                                    # no more sequences in this video
                                    break
                            if in_sequence:
                                # write frame to the current sequence
                                frame = frame.astype('float32') / 255.
                                imsave(os.path.join(data_path, data_split, sequence_name, str(sequence_frame_index) + '.png'), frame)
                                sequence_frame_index += 1
                        # release the clip's reader explicitly to free file handles
                        del video.reader
                        del video
8 | """ 9 | 10 | def report_hook(count, block_size, total_size): 11 | # to display download progress 12 | # see https://blog.shichao.io/2012/10/04/progress_speed_indicator_for_urlretrieve_in_python.html 13 | global start_time 14 | if count == 0: 15 | start_time = time.time() 16 | return 17 | duration = time.time() - start_time 18 | progress_size = int(count * block_size) 19 | speed = int(progress_size / (1024 * duration)) 20 | percent = min(int(count*block_size*100/total_size),100) 21 | sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % 22 | (percent, progress_size / (1024 * 1024), speed, duration)) 23 | sys.stdout.flush() 24 | 25 | def save(url, file_name): 26 | urllib.request.urlretrieve(url, file_name, report_hook) 27 | -------------------------------------------------------------------------------- /xparam/data/transposed_collate.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | 4 | 5 | def train_transposed_collate(batch): 6 | """ 7 | Wrapper around the default collate function to return sequences of PyTorch 8 | tensors with sequence step as the first dimension and batch index as the 9 | second dimension. 10 | 11 | Args: 12 | batch (list): data examples 13 | """ 14 | batch = filter(lambda img: img is not None, batch) 15 | collated_batch = default_collate(list(batch)) 16 | transposed_batch = collated_batch.transpose_(0, 1) 17 | # assert transposed_batch.shape[0] >= 4 18 | # idx = torch.randint(4, transposed_batch.shape[0] + 1, size=(1,)).item() 19 | # return transposed_batch[:idx] 20 | return transposed_batch 21 | 22 | 23 | def test_transposed_collate(batch): 24 | """ 25 | Wrapper around the default collate function to return sequences of PyTorch 26 | tensors with sequence step as the first dimension and batch index as the 27 | second dimension. 
28 | 29 | Args: 30 | batch (list): data examples 31 | """ 32 | batch = filter(lambda img: img is not None, batch) 33 | collated_batch = default_collate(list(batch)) 34 | transposed_batch = collated_batch.transpose_(0, 1) 35 | return transposed_batch 36 | -------------------------------------------------------------------------------- /xparam/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__init__.py -------------------------------------------------------------------------------- /xparam/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/compress_modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/compress_modules.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/compress_modules.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/compress_modules.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/denoising_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/denoising_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/denoising_diffusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/denoising_diffusion.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/network_components.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/network_components.cpython-310.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/network_components.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buggyyang/CDC_compression/538f0f4fc0fa2c41757dc54547b01d5ae26e9ada/xparam/modules/__pycache__/network_components.cpython-39.pyc -------------------------------------------------------------------------------- /xparam/modules/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- 
import torch.nn as nn
from .network_components import ResnetBlock, FlexiblePrior, Downsample, Upsample
from .utils import quantize, NormalDistribution


class Compressor(nn.Module):
    """Mean-scale hyperprior compressor (Balle-style).

    Encodes an image into a quantized latent plus a quantized hyper-latent;
    the hyper-decoder predicts the mean/scale of the latent distribution,
    which is used both for dequantization offsets and for the bpp estimate.
    Subclasses populate enc/dec/hyper_enc/hyper_dec in build_network().
    """

    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 3, 4),
        reverse_dim_mults=(4, 3, 2, 1),
        hyper_dims_mults=(4, 4, 4),
        channels=3,
        out_channels=3,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        # encoder channel progression: input channels, then dim * m per stage
        self.dims = [channels, *map(lambda m: dim * m, dim_mults)]
        self.in_out = list(zip(self.dims[:-1], self.dims[1:]))
        self.reversed_dims = [*map(lambda m: dim * m, reverse_dim_mults), out_channels]
        self.reversed_in_out = list(zip(self.reversed_dims[:-1], self.reversed_dims[1:]))
        # decoder must start where the encoder ends
        assert self.dims[-1] == self.reversed_dims[0]
        self.hyper_dims = [self.dims[-1], *map(lambda m: dim * m, hyper_dims_mults)]
        self.hyper_in_out = list(zip(self.hyper_dims[:-1], self.hyper_dims[1:]))
        # hyper-decoder ends in 2x the latent channels (mean and scale halves)
        self.reversed_hyper_dims = list(
            reversed([self.dims[-1] * 2, *map(lambda m: dim * m, hyper_dims_mults)])
        )
        self.reversed_hyper_in_out = list(
            zip(self.reversed_hyper_dims[:-1], self.reversed_hyper_dims[1:])
        )
        self.prior = FlexiblePrior(self.hyper_dims[-1])

    def get_extra_loss(self):
        """Auxiliary loss of the learned factorized prior over hyper-latents."""
        return self.prior.get_extraloss()

    def build_network(self):
        # base class leaves all four stacks empty; subclasses override
        self.enc = nn.ModuleList([])
        self.dec = nn.ModuleList([])
        self.hyper_enc = nn.ModuleList([])
        self.hyper_dec = nn.ModuleList([])

    def encode(self, input):
        """Run analysis + hyper-analysis + hyper-synthesis.

        Returns:
            q_latent: dequantized latent (rounded around the predicted mean)
            q_hyper_latent: dequantized hyper-latent (rounded around prior medians)
            state4bpp: dict with raw latents and the predicted latent distribution,
                consumed by bpp()
        """
        for resnet, down in self.enc:
            input = resnet(input)
            input = down(input)
        latent = input
        for conv, act in self.hyper_enc:
            input = conv(input)
            input = act(input)
        hyper_latent = input
        q_hyper_latent = quantize(hyper_latent, "dequantize", self.prior.medians)
        input = q_hyper_latent
        for deconv, act in self.hyper_dec:
            input = deconv(input)
            input = act(input)

        # hyper-decoder output carries mean and scale stacked on channels
        mean, scale = input.chunk(2, 1)
        latent_distribution = NormalDistribution(mean, scale.clamp(min=0.1))
        q_latent = quantize(latent, "dequantize", latent_distribution.mean)
        state4bpp = {
            "latent": latent,
            "hyper_latent": hyper_latent,
            "latent_distribution": latent_distribution,
        }
        return q_latent, q_hyper_latent, state4bpp

    def decode(self, input):
        """Synthesis transform; returns per-stage outputs, deepest last reversed."""
        output = []
        for resnet, up in self.dec:
            input = resnet(input)
            input = up(input)
            output.append(input)
        return output[::-1]

    def bpp(self, shape, state4bpp):
        """Estimate bits-per-pixel from the latent and hyper-latent rates.

        Uses additive-noise proxies during training and hard rounding at eval,
        matching the usual relaxation for learned compression.
        """
        _, _, H, W = shape
        latent = state4bpp["latent"]
        hyper_latent = state4bpp["hyper_latent"]
        latent_distribution = state4bpp["latent_distribution"]
        if self.training:
            q_hyper_latent = quantize(hyper_latent, "noise")
            q_latent = quantize(latent, "noise")
        else:
            q_hyper_latent = quantize(hyper_latent, "dequantize", self.prior.medians)
            q_latent = quantize(latent, "dequantize", latent_distribution.mean)
        hyper_rate = -self.prior.likelihood(q_hyper_latent).log2()
        cond_rate = -latent_distribution.likelihood(q_latent).log2()
        bpp = (hyper_rate.sum(dim=(1, 2, 3)) + cond_rate.sum(dim=(1, 2, 3))) / (H * W)
        return bpp

    def forward(self, input):
        q_latent, q_hyper_latent, state4bpp = self.encode(input)
        bpp = self.bpp(input.shape, state4bpp)
        output = self.decode(q_latent)
        return {
            "output": output,
            "bpp": bpp,
            "q_latent": q_latent,
            "q_hyper_latent": q_hyper_latent,
        }


class ResnetCompressor(Compressor):
    """Compressor with ResNet analysis/synthesis and conv hyper transforms."""

    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 3, 4),
        reverse_dim_mults=(4, 3, 2, 1),
        hyper_dims_mults=(4, 4, 4),
        channels=3,
        out_channels=3,
    ):
        super().__init__(
            dim,
            dim_mults,
            reverse_dim_mults,
            hyper_dims_mults,
            channels,
            out_channels
        )
        self.build_network()

    def build_network(self):
        """Populate the four module stacks declared by the base class."""
        self.enc = nn.ModuleList([])
        self.dec = nn.ModuleList([])
        self.hyper_enc = nn.ModuleList([])
        self.hyper_dec = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(self.in_out):
            self.enc.append(
                nn.ModuleList(
                    [
                        # first stage gets the "large kernel" variant of ResnetBlock
                        ResnetBlock(dim_in, dim_out, None, True if ind == 0 else False),
                        Downsample(dim_out),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(self.reversed_in_out):
            is_last = ind >= (len(self.reversed_in_out) - 1)
            self.dec.append(
                nn.ModuleList(
                    [
                        # last stage keeps dim_in until the Upsample projects out
                        ResnetBlock(dim_in, dim_out if not is_last else dim_in),
                        Upsample(dim_out if not is_last else dim_in, dim_out),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(self.hyper_in_out):
            is_last = ind >= (len(self.hyper_in_out) - 1)
            self.hyper_enc.append(
                nn.ModuleList(
                    [
                        # stride-1 conv first, then stride-2 downsampling convs
                        nn.Conv2d(dim_in, dim_out, 3, 1, 1)
                        if ind == 0
                        else nn.Conv2d(dim_in, dim_out, 5, 2, 2),
                        nn.LeakyReLU(0.2) if not is_last else nn.Identity(),
                    ]
                )
            )

        for ind, (dim_in, dim_out) in enumerate(self.reversed_hyper_in_out):
            is_last = ind >= (len(self.reversed_hyper_in_out) - 1)
            self.hyper_dec.append(
                nn.ModuleList(
                    [
                        # transposed convs upsample; final layer is stride-1
                        nn.Conv2d(dim_in, dim_out, 3, 1, 1)
                        if is_last
                        else nn.ConvTranspose2d(dim_in, dim_out, 5, 2, 2, 1),
                        nn.LeakyReLU(0.2) if not is_last else nn.Identity(),
                    ]
                )
            )
class Unet(nn.Module):
    """U-Net denoiser with linear attention and an optional time embedding.

    Context feature maps (from the compressor) are concatenated channel-wise
    at the first few resolutions on the way down; skip connections are
    consumed in reverse order on the way up.
    """

    def __init__(
        self,
        dim,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        context_dim_mults=(1, 2, 3, 3),
        channels=3,
        context_channels=3,
        with_time_emb=True,
        embd_type="01"
    ):
        super().__init__()
        self.channels = channels

        feat_dims = [channels] + [dim * m for m in dim_mults]
        ctx_dims = [context_channels] + [dim * m for m in context_dim_mults]
        stage_io = list(zip(feat_dims[:-1], feat_dims[1:]))
        self.embd_type = embd_type

        if not with_time_emb:
            time_dim = None
            self.time_mlp = None
        elif embd_type == "01":
            # scalar timestep in [0, 1] passed through a small MLP
            time_dim = dim
            self.time_mlp = nn.Sequential(nn.Linear(1, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        elif embd_type == "index":
            # integer timestep index via sinusoidal embedding + MLP
            time_dim = dim
            self.time_mlp = nn.Sequential(
                ImprovedSinusoidalPosEmb(time_dim // 2),
                nn.Linear(time_dim // 2 + 1, time_dim * 4),
                nn.GELU(),
                nn.Linear(time_dim * 4, time_dim)
            )
        else:
            raise NotImplementedError

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        n_levels = len(stage_io)

        for level, (c_in, c_out) in enumerate(stage_io):
            last = level >= (n_levels - 1)
            # context is injected only at levels that have a context map
            takes_context = (not last) and (level < (len(ctx_dims) - 1))
            self.downs.append(
                nn.ModuleList(
                    [
                        ResnetBlock(
                            c_in + ctx_dims[level] if takes_context else c_in,
                            c_out,
                            time_dim,
                            level == 0
                        ),
                        ResnetBlock(c_out, c_out, time_dim),
                        Residual(PreNorm(c_out, LinearAttention(c_out))),
                        nn.Identity() if last else Downsample(c_out),
                    ]
                )
            )

        mid_dim = feat_dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_dim)

        for level, (c_in, c_out) in enumerate(reversed(stage_io[1:])):
            last = level >= (n_levels - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        # x2 input channels: skip connection is concatenated
                        ResnetBlock(c_out * 2, c_in, time_dim),
                        ResnetBlock(c_in, c_in, time_dim),
                        Residual(PreNorm(c_in, LinearAttention(c_in))),
                        nn.Identity() if last else Upsample(c_in),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(LayerNorm(dim), nn.Conv2d(dim, out_dim, 7, padding=3))

    def encode(self, x, t, context):
        """Downsampling half; returns bottleneck features plus skip activations."""
        h = []
        for level, (block_a, block_b, attn, down) in enumerate(self.downs):
            if level < len(context):
                x = torch.cat([x, context[level]], dim=1)
            x = block_a(x, t)
            x = block_b(x, t)
            x = attn(x)
            h.append(x)
            x = down(x)

        x = self.mid_block1(x, t)
        return x, h

    def decode(self, x, h, t):
        """Upsampling half; pops skip activations in reverse order."""
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block_a, block_b, attn, up in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block_a(x, t)
            x = block_b(x, t)
            x = attn(x)
            x = up(x)
        return self.final_conv(x)

    def forward(self, x, time=None, context=None):
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        x, h = self.encode(x, t, context)
        return self.decode(x, h, t)
import torch
from inspect import isfunction
from torch.autograd import Function
import numpy as np


def exists(x):
    """True when x is not None."""
    return x is not None


def default(val, d):
    """Return val if set; otherwise d (calling it when d is a factory)."""
    return val if exists(val) else (d() if isfunction(d) else d)


def cycle(dl):
    """Yield items from dl forever, restarting it when exhausted."""
    while True:
        yield from dl


def num_to_groups(num, divisor):
    """Split num into chunks of size divisor, plus a final remainder chunk."""
    full, remainder = divmod(num, divisor)
    arr = [divisor] * full
    if remainder:
        arr.append(remainder)
    return arr


def extract(a, t, x_shape):
    """Gather a[t] per batch element and reshape for broadcasting over x_shape."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))


def extract_tensor(a, t, place_holder=None):
    """Index a 2-D table by timestep t along dim 0 and batch position along dim 1."""
    return a[t, torch.arange(len(t))]


def noise_like(shape, device, repeat=False):
    """Gaussian noise of the given shape; one shared sample per batch if repeat."""
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)
    abar = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    abar = abar / abar[0]
    betas = 1 - (abar[1:] / abar[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


def linear_beta_schedule(timesteps):
    """Linear beta schedule scaled so the endpoints match the 1000-step reference."""
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return np.linspace(beta_start, beta_end, timesteps)


def noise(input, scale):
    """Additive uniform noise in [-scale/2, scale/2) (quantization proxy)."""
    jitter = torch.rand_like(input) - 0.5
    return input + scale * jitter


def round_w_offset(input, loc):
    """Round input relative to offset loc, with a straight-through gradient."""
    diff = STERound.apply(input - loc)
    return diff + loc


def quantize(x, mode='noise', offset=None):
    """Quantize x: 'noise' adds uniform noise, 'round' hard-rounds (STE),
    'dequantize' rounds around offset."""
    if mode == 'noise':
        return noise(x, 1)
    if mode == 'round':
        return STERound.apply(x)
    if mode == 'dequantize':
        return round_w_offset(x, offset)
    raise NotImplementedError


class STERound(Function):
    """Rounding with a straight-through (identity) gradient."""

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad):
        return grad


class LowerBound(Function):
    """max(inputs, bound) that still passes gradients pushing values upward."""

    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, b)
        return torch.max(inputs, b)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors

        # pass gradient when unconstrained, or when it would raise the value
        keep = (inputs >= b) | (grad_output < 0)
        return keep.type(grad_output.dtype) * grad_output, None


class UpperBound(Function):
    """min(inputs, bound) that still passes gradients pushing values downward."""

    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones_like(inputs) * bound
        ctx.save_for_backward(inputs, b)
        return torch.min(inputs, b)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors

        # pass gradient when unconstrained, or when it would lower the value
        keep = (inputs <= b) | (grad_output > 0)
        return keep.type(grad_output.dtype) * grad_output, None


class NormalDistribution:
    '''
    A normal distribution
    '''

    def __init__(self, loc, scale):
        assert loc.shape == scale.shape
        self.loc = loc
        self.scale = scale

    @property
    def mean(self):
        # detached so quantization offsets do not backprop through the mean
        return self.loc.detach()

    def std_cdf(self, inputs):
        """Standard normal CDF via the complementary error function."""
        half = 0.5
        const = -(2 ** -0.5)
        return half * torch.erfc(const * inputs)

    def sample(self):
        return self.scale * torch.randn_like(self.scale) + self.loc

    def likelihood(self, x, min=1e-9):
        """Probability mass of the unit-width bin around x, floored at min."""
        x = torch.abs(x - self.loc)
        upper = self.std_cdf((.5 - x) / self.scale)
        lower = self.std_cdf((-.5 - x) / self.scale)
        return LowerBound.apply(upper - lower, min)

    def scaled_likelihood(self, x, s=1, min=1e-9):
        """Probability mass of a bin of width s around x, floored at min."""
        x = torch.abs(x - self.loc)
        s = s * .5
        upper = self.std_cdf((s - x) / self.scale)
        lower = self.std_cdf((-s - x) / self.scale)
        return LowerBound.apply(upper - lower, min)
import argparse
import os
import torch
import torchvision
import numpy as np
import pathlib
from modules.denoising_diffusion import GaussianDiffusion
from modules.unet import Unet
from modules.compress_modules import ResnetCompressor
from ema_pytorch import EMA

# Command-line options forwarded from the launch script.
parser = argparse.ArgumentParser(description="values from bash script")

# checkpoint path
parser.add_argument("--ckpt", type=str, required=True)
# noise intensity for decoding
parser.add_argument("--gamma", type=float, default=0.8)
# number of denoising steps
parser.add_argument("--n_denoise_step", type=int, default=65)
# gpu device index
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--img_dir", type=str, default='../imgs')
parser.add_argument("--out_dir", type=str, default='../compressed_imgs')
# Either 0.9 or 0.0. Must match the ckpt in use: with weight > 0 the
# lpips-vggnet weights were saved during training, so mismatched values
# produce state_dict key errors when loading the checkpoint.
parser.add_argument("--lpips_weight", type=float, required=True)
config = parser.parse_args()


def main(rank):
    """Compress every image in config.img_dir with a pretrained diffusion codec.

    Builds the denoising U-Net, the hyperprior context compressor and the
    Gaussian diffusion wrapper, loads EMA weights from config.ckpt, then
    writes each reconstruction (and prints its bpp) to config.out_dir.

    Args:
        rank: torch device index to run inference on.
    """
    # architecture hyper-parameters must match those used at training time
    denoise_model = Unet(
        dim=64,
        channels=3,
        context_channels=64,
        dim_mults=[1, 2, 3, 4, 5, 6],
        context_dim_mults=[1, 2, 3, 4],
        embd_type="01",
    )

    context_model = ResnetCompressor(
        dim=64,
        dim_mults=[1, 2, 3, 4],
        reverse_dim_mults=[4, 3, 2, 1],
        hyper_dims_mults=[4, 4, 4],
        channels=3,
        out_channels=64,
    )

    diffusion = GaussianDiffusion(
        denoise_fn=denoise_model,
        context_fn=context_model,
        ae_fn=None,
        num_timesteps=8193,
        loss_type="l2",
        lagrangian=0.0032,
        pred_mode="x",
        aux_loss_weight=config.lpips_weight,
        aux_loss_type="lpips",
        var_schedule="cosine",
        use_loss_weight=True,
        loss_weight_min=5,
        use_aux_loss_weight_schedule=False,
    )

    # load on CPU first; the EMA wrapper holds the weights used for inference
    loaded_param = torch.load(
        config.ckpt,
        map_location=lambda storage, loc: storage,
    )
    ema = EMA(diffusion, beta=0.999, update_every=10, power=0.75, update_after_step=100)
    ema.load_state_dict(loaded_param["ema"])
    diffusion = ema.ema_model
    diffusion.to(rank)
    diffusion.eval()

    # create the output directory once, not per image
    pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True)

    for img in os.listdir(config.img_dir):
        if not img.endswith((".png", ".jpg")):
            continue
        # uint8 [0, 255] -> float [0, 1]; model expects [-1, 1] below
        to_be_compressed = (
            torchvision.io.read_image(os.path.join(config.img_dir, img)).unsqueeze(0).float().to(rank) / 255.0
        )
        with torch.no_grad():  # inference only; avoid building autograd graphs
            compressed, bpp = diffusion.compress(
                to_be_compressed * 2.0 - 1.0,
                sample_steps=config.n_denoise_step,
                bpp_return_mean=True,
                init=torch.randn_like(to_be_compressed) * config.gamma,
            )
        compressed = compressed.clamp(-1, 1) / 2.0 + 0.5
        torchvision.utils.save_image(compressed.cpu(), os.path.join(config.out_dir, img))
        print("bpp:", bpp)


if __name__ == "__main__":
    main(config.device)
#!/bin/bash
# SLURM submission script for xparam training.

#SBATCH --job-name=d_eori # Job name
#SBATCH --gres=gpu:1 # how many gpus would you like to use (here I use 1)
#SBATCH --mail-type=END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL)
#SBATCH --mail-user=ruihan.yang@uci.edu # Where to send mail (for notification only)
#SBATCH --nodes=1 # Run all processes on a single node
#SBATCH --ntasks=1 # Run a single task
#SBATCH --cpus-per-task=5 # Number of CPU cores per task
#SBATCH --mem=16G # Job memory request
#SBATCH --time=7-00:00:00 # Time limit hrs:min:sec
#SBATCH --partition=ava_m.p # partition name
#SBATCH --exclude=ava-m4 # select your node (or not)
#SBATCH --output=logs/job_%j.log # output log
#SBATCH -a 0


# Each array task picks one "<loss_type> <var_schedule>" pair from `pairs`.
pairs=("l2 cosine")
item=${pairs[$SLURM_ARRAY_TASK_ID]}
lt="${item% *}" # loss type: substring before the space
vs="${item#* }" # variance schedule: substring after the space
/home/ruihay1/miniconda3/envs/exp_pytorch/bin/python train.py --iteration_step 8193 \
--device 0 --loss_type $lt --var_schedule $vs --pred_mode "noise" --beta 0.000001 \
--aux_weight 0 --reverse_context_dim_mults 4 3 2 1 --ae_path "" --embd_type "01" --load_model --load_step