├── make_a_video
│   ├── __init__.py
│   └── components
│       ├── __init__.py
│       ├── backbone
│       │   ├── t2i.py
│       │   └── __init__.py
│       └── layers
│           ├── __init__.py
│           ├── frame_interpolation.py
│           ├── p3d_attention.py
│           └── p3d_conv.py
├── pseudo3d.png
├── t2v_architecture.png
├── README.md
├── LICENSE
└── .gitignore
--------------------------------------------------------------------------------
/make_a_video/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/make_a_video/components/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/make_a_video/components/backbone/t2i.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/make_a_video/components/layers/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/make_a_video/components/backbone/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/make_a_video/components/layers/frame_interpolation.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/pseudo3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soran-ghaderi/make-a-video/HEAD/pseudo3d.png
--------------------------------------------------------------------------------
/t2v_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soran-ghaderi/make-a-video/HEAD/t2v_architecture.png
--------------------------------------------------------------------------------
/make_a_video/components/layers/p3d_attention.py:
--------------------------------------------------------------------------------
import einops
import tensorflow as tf


def flatten(input_tensor):
    # Collapse the spatial axes into one so attention can run over (h * w) positions.
    return einops.rearrange(input_tensor, 'b c f h w -> b c f (h w)')


def unflatten(input_tensor, h, w):
    # Restore the spatial axes after attention has been applied.
    return einops.rearrange(input_tensor, '... (h w) -> ... h w', h=h, w=w)


class Attention2D:
    pass


class Attention1D:
    pass


if __name__ == '__main__':
    input_tensor = tf.random.normal((10, 3, 7, 5, 5))
    print(input_tensor.shape)
    flattened = flatten(input_tensor)
    print(flattened.shape)
    print(unflatten(flattened, 5, 5).shape)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MAKE-A-VIDEO
### TEXT-TO-VIDEO GENERATION WITHOUT TEXT-VIDEO DATA

![Make-A-Video text-to-video architecture](t2v_architecture.png)
### Main components
The architecture consists of three main components:
1. A base T2I model trained on text-image pairs.
2. Spatiotemporal convolution and attention layers that extend the network's building blocks to the temporal dimension.
3. Spatiotemporal networks that consist of those spatiotemporal layers, plus another crucial element needed for T2V generation: a frame interpolation network for high-frame-rate generation.

### Spatiotemporal layers
#### 1. Pseudo-3D convolutional layers
A 1D temporal convolution is stacked after each 2D convolutional (conv) layer, which facilitates information sharing between the spatial and temporal axes at a much lower computational cost than full 3D conv layers (see the sketch below the figure).
![Pseudo-3D convolutional layers](pseudo3d.png)
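A minimal sketch of the idea (not this repo's API; `filters=64` and the tensor sizes are illustrative, and a channels-last layout of `(batch, frames, height, width, channels)` is assumed):

```python
import tensorflow as tf

# Factorized "pseudo-3D" convolution: a (1, 3, 3) kernel mixes information
# spatially within each frame, then a (3, 1, 1) kernel mixes across frames.
spatial = tf.keras.layers.Conv3D(filters=64, kernel_size=(1, 3, 3), padding='same')
temporal = tf.keras.layers.Conv3D(filters=64, kernel_size=(3, 1, 1), padding='same')

video = tf.random.normal((2, 8, 32, 32, 3))  # (batch, frames, h, w, channels)
out = temporal(spatial(video))               # spatial mixing, then temporal mixing
print(out.shape)                             # (2, 8, 32, 32, 64)
```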
#### 2. Pseudo-3D attention layers
#### 3. Frame interpolation network
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Soran Ghaderi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/make_a_video/components/layers/p3d_conv.py:
--------------------------------------------------------------------------------
import os
from typing import Tuple, Union

import tensorflow as tf

# Force CPU execution so the module can be smoke-tested without a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

__all__ = ['P3D']


class P3D:
    def __init__(self, output_channels: int, stride: Union[int, Tuple[int, int, int]] = 1,
                 padding: str = 'valid'):
        """Pseudo-3D (P3D) convolution blocks for channels-last video tensors
        of shape (batch, frames, height, width, channels).

        Parameters
        ----------
        output_channels :
            Number of channels produced by the convolutions
        stride :
            Stride of the convolutions. Default: 1
        padding :
            Padding mode of the convolutions. Default: 'valid'; the manual
            pads below then keep the frame and spatial sizes unchanged.
        """
        # Pad height and width by one on each side (for the (1, 3, 3) spatial conv).
        self.s_padding = tf.constant([[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]])
        # Pad the frame axis by one on each side (for the (3, 1, 1) temporal conv).
        self.t_padding = tf.constant([[0, 0], [1, 1], [0, 0], [0, 0], [0, 0]])
        self.output_channels = output_channels
        self.stride = stride
        self.padding = padding

    def conv_S(self):
        """Build the spatial half of the factorized 3D convolution.

        A (1, 3, 3) convolution layer as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Returns
        -------
        out:
            A tf.keras.layers.Conv3D layer with a (1, 3, 3) kernel
        """
        return tf.keras.layers.Conv3D(filters=self.output_channels,
                                      kernel_size=[1, 3, 3],
                                      strides=self.stride,
                                      padding=self.padding,
                                      use_bias=False)

    def conv_T(self):
        """Build the temporal half of the factorized 3D convolution.

        A (3, 1, 1) convolution layer as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Returns
        -------
        out:
            A tf.keras.layers.Conv3D layer with a (3, 1, 1) kernel
        """
        return tf.keras.layers.Conv3D(filters=self.output_channels,
                                      kernel_size=[3, 1, 1],
                                      strides=self.stride,
                                      padding=self.padding,
                                      use_bias=False)

    def p3d_a(self, inputs, convolve_across_time=True):
        """Return P3D-A (spatial conv followed by temporal conv in series) as
        described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5-d input tensor of shape (batch, frames, height, width, channels)
        convolve_across_time :
            Whether to apply the temporal convolution after the spatial one
        Returns
        -------
        out:
            A 5-d tensor with `output_channels` channels
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        # Spatial (1, 3, 3) convolution; the manual pad preserves height and width.
        x = tf.pad(inputs, self.s_padding)
        x = spatial(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.nn.relu(x)

        if convolve_across_time:
            # Temporal (3, 1, 1) convolution; the manual pad preserves the frame count.
            x = tf.pad(x, self.t_padding)
            x = temporal(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.nn.relu(x)

        return x

    def p3d_b(self, inputs):
        """Return P3D-B (spatial and temporal convs in parallel, summed) as
        described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5-d input tensor of shape (batch, frames, height, width, channels)
        Returns
        -------
        out:
            A 5-d tensor with `output_channels` channels
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        spatial_padded_inputs = tf.pad(inputs, self.s_padding)
        spatial_p3d = spatial(spatial_padded_inputs)
        spatial_p3d = tf.keras.layers.BatchNormalization()(spatial_p3d)
        spatial_p3d = tf.nn.relu(spatial_p3d)

        temporal_padded_inputs = tf.pad(inputs, self.t_padding)  # the raw `inputs` is fed
        temporal_p3d = temporal(temporal_padded_inputs)
        temporal_p3d = tf.keras.layers.BatchNormalization()(temporal_p3d)
        temporal_p3d = tf.nn.relu(temporal_p3d)

        return temporal_p3d + spatial_p3d

    def p3d_c(self, inputs):
        """Return P3D-C (temporal conv on the spatial output, added back as a
        residual) as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5-d input tensor of shape (batch, frames, height, width, channels)
        Returns
        -------
        out:
            A 5-d tensor with `output_channels` channels
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        spatial_padded_inputs = tf.pad(inputs, self.s_padding)
        spatial_p3d = spatial(spatial_padded_inputs)
        spatial_p3d = tf.keras.layers.BatchNormalization()(spatial_p3d)
        spatial_p3d = tf.nn.relu(spatial_p3d)

        temporal_padded_inputs = tf.pad(spatial_p3d, self.t_padding)  # `spatial_p3d` is fed
        temporal_p3d = temporal(temporal_padded_inputs)
        temporal_p3d = tf.keras.layers.BatchNormalization()(temporal_p3d)
        temporal_p3d = tf.nn.relu(temporal_p3d)

        return temporal_p3d + spatial_p3d

    def call(self, inputs: tf.Tensor, convolve_across_time: bool = True):
        """Apply the default (P3D-A) block.

        Parameters
        ----------
        inputs :
            5-d input tensor of shape (batch, frames, height, width, channels)
        convolve_across_time :
            Whether to apply the temporal convolution after the spatial one

        Returns
        -------
        out:
            A 5-d tensor with `output_channels` channels
        """
        return self.p3d_a(inputs, convolve_across_time)


if __name__ == '__main__':
    p3d = P3D(5)
    # A channels-last video batch: (batch, frames, height, width, channels).
    video = tf.random.normal((2, 7, 10, 10, 3))

    print("input shape: ", video.shape)
    print("p3d_a shape: ", p3d.p3d_a(video).shape)
    print("p3d_b shape: ", p3d.p3d_b(video).shape)
    print("p3d_c shape: ", p3d.p3d_c(video).shape)
--------------------------------------------------------------------------------
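To make the README's efficiency claim concrete, here is a standalone sanity check (a sketch, not part of the module; the channel counts are illustrative) comparing the weight count of a full 3 x 3 x 3 kernel against the factorized (1, 3, 3) + (3, 1, 1) pair that conv_S and conv_T build:

import tensorflow as tf

c_in, c_out = 3, 5  # illustrative channel counts

full = tf.keras.layers.Conv3D(c_out, (3, 3, 3), use_bias=False)
spatial = tf.keras.layers.Conv3D(c_out, (1, 3, 3), use_bias=False)
temporal = tf.keras.layers.Conv3D(c_out, (3, 1, 1), use_bias=False)

x = tf.random.normal((1, 4, 8, 8, c_in))
full(x)               # call once so the layer builds its weights
temporal(spatial(x))  # the temporal conv sees c_out input channels here

print(full.count_params())                               # 27 * 3 * 5 = 405
print(spatial.count_params() + temporal.count_params())  # 9*3*5 + 3*5*5 = 210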