├── make_a_video
│   ├── __init__.py
│   └── components
│       ├── __init__.py
│       ├── backbone
│       │   ├── __init__.py
│       │   └── t2i.py
│       └── layers
│           ├── __init__.py
│           ├── frame_interpolation.py
│           ├── p3d_attention.py
│           └── p3d_conv.py
├── pseudo3d.png
├── t2v_architecture.png
├── README.md
├── LICENSE
└── .gitignore
/make_a_video/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_a_video/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_a_video/components/backbone/t2i.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_a_video/components/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_a_video/components/backbone/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/make_a_video/components/layers/frame_interpolation.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pseudo3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soran-ghaderi/make-a-video/HEAD/pseudo3d.png
--------------------------------------------------------------------------------
/t2v_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soran-ghaderi/make-a-video/HEAD/t2v_architecture.png
--------------------------------------------------------------------------------
/make_a_video/components/layers/p3d_attention.py:
--------------------------------------------------------------------------------
import einops
import tensorflow as tf


def flatten(input_tensor):
    # Collapse the spatial axes so attention can treat each frame as a
    # sequence of (h * w) positions.
    return einops.rearrange(input_tensor, 'b c f h w -> b c f (h w)')


def unflatten(input_tensor, h, w):
    # Restore the spatial axes collapsed by flatten().
    return einops.rearrange(input_tensor, '... (h w) -> ... h w', h=h, w=w)


class Attention2D:
    # Placeholder: spatial attention over the flattened (h w) axis.
    pass


class Attention1D:
    # Placeholder: temporal attention over the frame axis.
    pass


if __name__ == '__main__':
    inputs = tf.random.normal((10, 3, 7, 5, 5))  # (batch, channels, frames, h, w)
    print(inputs.shape)
    flattened = flatten(inputs)
    print(flattened.shape)
    print(unflatten(flattened, 5, 5).shape)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MAKE-A-VIDEO
### TEXT-TO-VIDEO GENERATION WITHOUT TEXT-VIDEO DATA

![Make-A-Video architecture](t2v_architecture.png)

### Main components
Make-A-Video consists of three main components:
1. A base T2I model trained on text-image pairs
2. Spatiotemporal convolution and attention layers that extend the networks' building blocks to the temporal dimension (see the sketch below this list)
3. Spatiotemporal networks built from those layers, plus a frame interpolation network for high-frame-rate generation, another element crucial for T2V

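For instance, component 2 is what the pseudo-3D layers under `make_a_video/components/layers/` implement. The sketch below is illustrative only: it drives this repo's `P3D` class with a random clip, assuming the channels-last layout used in `p3d_conv.py`'s `__main__` demo.

```python
import tensorflow as tf
from make_a_video.components.layers.p3d_conv import P3D

# A toy video batch: (batch, frames, height, width, channels).
clip = tf.random.normal((2, 7, 32, 32, 3))

# One pseudo-3D block: spatial (1, 3, 3) conv, then temporal (3, 1, 1) conv.
features = P3D(output_channels=16).call(clip)
print(features.shape)  # (2, 5, 30, 30, 16) under the default 'valid' padding
```
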
### Spatiotemporal layers
#### 1. Pseudo-3D convolutional layers
A 1D temporal convolution is stacked after each 2D convolutional (conv) layer, enabling information sharing between the spatial and temporal axes at far lower computational cost than full 3D conv layers; see the sketch below the figure.

![Pseudo-3D convolution](pseudo3d.png)

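As a minimal, self-contained sketch of the idea (illustrative, not this repo's exact API), a (1, 3, 3) spatial convolution followed by a (3, 1, 1) temporal convolution stands in for a full (3, 3, 3) convolution:

```python
import tensorflow as tf

# Hypothetical helper: factorized "2D + 1D" convolution over a video batch.
def pseudo3d_conv(filters: int) -> tf.keras.Sequential:
    return tf.keras.Sequential([
        # Spatial conv: the (1, 3, 3) kernel leaves the frame axis untouched.
        tf.keras.layers.Conv3D(filters, kernel_size=(1, 3, 3), padding='same'),
        # Temporal conv: the (3, 1, 1) kernel mixes information across frames.
        tf.keras.layers.Conv3D(filters, kernel_size=(3, 1, 1), padding='same'),
    ])

clip = tf.random.normal((2, 8, 32, 32, 3))  # (batch, frames, height, width, channels)
print(pseudo3d_conv(16)(clip).shape)  # (2, 8, 32, 32, 16)
```

The compute saving comes from the weight counts: a (3, 3, 3) kernel costs 27·c_in·c_out weights, while the factorized pair costs 9·c_in·c_out + 3·c_out², i.e. 2.25× fewer when the channel counts match.
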
#### 2. Pseudo-3D attention layers
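Analogously to the convolutional case, a 1D temporal attention layer is stacked after each 2D spatial attention layer. A minimal sketch, assuming standard Keras multi-head attention and the channels-last layout used above (the `flatten`/`unflatten` helpers in `p3d_attention.py` play the same reshaping role):

```python
import einops
import tensorflow as tf

# Hypothetical pseudo-3D attention: spatial self-attention within each frame,
# then temporal self-attention across frames at each spatial position.
def pseudo3d_attention(clip: tf.Tensor, heads: int = 4) -> tf.Tensor:
    b, f, h, w, c = clip.shape
    spatial = tf.keras.layers.MultiHeadAttention(num_heads=heads, key_dim=c)
    temporal = tf.keras.layers.MultiHeadAttention(num_heads=heads, key_dim=c)

    # Attend over the flattened (h w) positions of every frame.
    x = einops.rearrange(clip, 'b f h w c -> (b f) (h w) c')
    x = spatial(x, x)
    # Attend over frames at every spatial position.
    x = einops.rearrange(x, '(b f) s c -> (b s) f c', b=b)
    x = temporal(x, x)
    return einops.rearrange(x, '(b h w) f c -> b f h w c', b=b, h=h, w=w)

clip = tf.random.normal((2, 8, 16, 16, 32))
print(pseudo3d_attention(clip).shape)  # (2, 8, 16, 16, 32)
```
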
#### 3. Frame interpolation network

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Soran Ghaderi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/make_a_video/components/layers/p3d_conv.py:
--------------------------------------------------------------------------------
import os
from typing import Tuple, Union

import tensorflow as tf

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

__all__ = ['P3D']


class P3D:
    def __init__(self, output_channels: int, stride: Union[int, Tuple[int, int, int]] = 1,
                 padding: str = 'valid'):
        """Pseudo-3D convolution block with factorized spatial/temporal kernels.

        Parameters
        ----------
        output_channels :
            Number of channels produced by the convolution
        stride :
            Stride of the convolution. Default: 1
        padding :
            Padding mode passed to Conv3D, either 'valid' or 'same'
        """
        # Manual paddings for a channels-last (batch, frames, height, width,
        # channels) tensor: height/width for the spatial conv, frames for the
        # temporal conv. They assume the default 'valid' conv padding.
        self.s_padding = tf.constant([[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]])
        self.t_padding = tf.constant([[0, 0], [1, 1], [0, 0], [0, 0], [0, 0]])
        self.output_channels = output_channels
        self.stride = stride
        self.padding = padding

    def conv_S(self):
        """Build the spatial convolution applied to each frame.

        A (1, 3, 3) convolution layer as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Returns
        -------
        out:
            A Conv3D layer that convolves over the spatial axes only
        """
        return tf.keras.layers.Conv3D(filters=self.output_channels,
                                      kernel_size=[1, 3, 3],
                                      strides=self.stride,
                                      padding=self.padding,
                                      use_bias=False)

    def conv_T(self):
        """Build the temporal convolution applied across frames.

        A (3, 1, 1) convolution layer as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Returns
        -------
        out:
            A Conv3D layer that convolves over the temporal axis only
        """
        return tf.keras.layers.Conv3D(filters=self.output_channels,
                                      kernel_size=[3, 1, 1],
                                      strides=self.stride,
                                      padding=self.padding,
                                      use_bias=False)

    def p3d_a(self, inputs, convolve_across_time=True):
        """Return P3D-A as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5D input tensor of shape (batch, frames, height, width, channels)
        convolve_across_time :
            Whether to apply the temporal convolution after the spatial one

        Returns
        -------
        out:
            A 5D tensor; with 'valid' padding the spatial (and, if enabled,
            temporal) extents shrink by 2
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        # P3D-A chains the two factorized convolutions: spatial first, then
        # temporal. Because conv_S uses a (1, 3, 3) kernel and conv_T a
        # (3, 1, 1) kernel, both apply directly to the 5D tensor; no
        # rearranging between image and video layouts is needed. (The
        # earlier rearrange() calls here discarded their results.)
        x = spatial(inputs)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.nn.relu(x)

        if convolve_across_time:
            x = temporal(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.nn.relu(x)

        return x

    def p3d_b(self, inputs):
        """Return P3D-B as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5D input tensor of shape (batch, frames, height, width, channels)

        Returns
        -------
        out:
            A 5D tensor with the input's spatial and temporal extents
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        # P3D-B runs the spatial and temporal branches in parallel on the raw
        # input and sums their outputs.
        spatial_padded_inputs = tf.pad(inputs, self.s_padding)
        spatial_p3d = spatial(spatial_padded_inputs)
        spatial_p3d = tf.keras.layers.BatchNormalization()(spatial_p3d)
        spatial_p3d = tf.nn.relu(spatial_p3d)

        temporal_padded_inputs = tf.pad(inputs, self.t_padding)  # the raw 'inputs' is fed
        temporal_p3d = temporal(temporal_padded_inputs)
        temporal_p3d = tf.keras.layers.BatchNormalization()(temporal_p3d)
        temporal_p3d = tf.nn.relu(temporal_p3d)

        return temporal_p3d + spatial_p3d

    def p3d_c(self, inputs):
        """Return P3D-C as described in [1]_

        References
        ----------
        [1] https://openaccess.thecvf.com/content_ICCV_2017/papers/Qiu_Learning_Spatio-Temporal_Representation_ICCV_2017_paper.pdf

        Parameters
        ----------
        inputs :
            5D input tensor of shape (batch, frames, height, width, channels)

        Returns
        -------
        out:
            A 5D tensor with the input's spatial and temporal extents
        """
        spatial = self.conv_S()
        temporal = self.conv_T()

        # P3D-C feeds the spatial output into the temporal branch and adds a
        # residual connection from the spatial branch.
        spatial_padded_inputs = tf.pad(inputs, self.s_padding)
        spatial_p3d = spatial(spatial_padded_inputs)
        spatial_p3d = tf.keras.layers.BatchNormalization()(spatial_p3d)
        spatial_p3d = tf.nn.relu(spatial_p3d)

        temporal_padded_inputs = tf.pad(spatial_p3d, self.t_padding)  # spatial_p3d is fed
        temporal_p3d = temporal(temporal_padded_inputs)
        temporal_p3d = tf.keras.layers.BatchNormalization()(temporal_p3d)
        temporal_p3d = tf.nn.relu(temporal_p3d)

        return temporal_p3d + spatial_p3d

    def call(self, inputs: tf.Tensor, convolve_across_time: bool = True):
        """Apply the default P3D-A variant.

        Parameters
        ----------
        inputs :
            5D input tensor of shape (batch, frames, height, width, channels)
        convolve_across_time :
            Whether to apply the temporal convolution after the spatial one

        Returns
        -------
        out:
            A 5D tensor produced by p3d_a
        """
        return self.p3d_a(inputs, convolve_across_time)


if __name__ == '__main__':
    p3d = P3D(5)
    # A channels-last video batch: (batch, frames, height, width, channels).
    inputs = tf.random.normal((2, 7, 10, 10, 3))

    print("input shape: ", inputs.shape)
    print("p3d_a shape: ", p3d.p3d_a(inputs).shape)
    print("p3d_b shape: ", p3d.p3d_b(inputs).shape)
    print("p3d_c shape: ", p3d.p3d_c(inputs).shape)
--------------------------------------------------------------------------------