├── movie_gen ├── __init__.py └── tae.py ├── requirements.txt ├── scripts ├── tests.sh ├── test_name.sh ├── code_quality.sh └── merge_all_prs.sh ├── agorabanner.png ├── .github ├── workflows │ ├── ruff.yml │ ├── pull-request-links.yml │ ├── docs.yml │ ├── lints.yml │ ├── testing.yml │ ├── quality.yml │ ├── pr_request_checks.yml │ ├── label.yml │ ├── run_test.yml │ ├── welcome.yml │ ├── pylint.yml │ ├── docs_test.yml │ ├── unit-test.yml │ ├── python-publish.yml │ ├── code_quality_control.yml │ ├── stale.yml │ ├── cos_integration.yml │ └── test.yml ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── FUNDING.yml ├── PULL_REQUEST_TEMPLATE.yml └── labeler.yml ├── .pre-commit-config.yaml ├── LICENSE ├── example.py ├── pyproject.toml ├── .gitignore └── README.md /movie_gen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | zetascale 3 | swarms 4 | -------------------------------------------------------------------------------- /scripts/tests.sh: -------------------------------------------------------------------------------- 1 | find ./tests -name '*.py' -exec pytest {} \; -------------------------------------------------------------------------------- /agorabanner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/movie-gen/HEAD/agorabanner.png -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [ push, pull_request ] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: chartboost/ruff-action@v1 9 | 
-------------------------------------------------------------------------------- /scripts/test_name.sh: -------------------------------------------------------------------------------- 1 | find ./tests -name "*.py" -type f | while read file 2 | do 3 | filename=$(basename "$file") 4 | dir=$(dirname "$file") 5 | if [[ $filename != test_* ]]; then 6 | mv "$file" "$dir/test_$filename" 7 | fi 8 | done -------------------------------------------------------------------------------- /.github/workflows/pull-request-links.yml: -------------------------------------------------------------------------------- 1 | name: readthedocs/actions 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | paths: 7 | - "docs/**" 8 | 9 | permissions: 10 | pull-requests: write 11 | 12 | jobs: 13 | pull-request-links: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: readthedocs/actions/preview@v1 17 | with: 18 | project-slug: swarms_torch -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates 2 | 3 | version: 2 4 | updates: 5 | - package-ecosystem: "github-actions" 6 | directory: "/" 7 | schedule: 8 | interval: "weekly" 9 | 10 | - package-ecosystem: "pip" 11 | directory: "/" 12 | schedule: 13 | interval: "weekly" 14 | 15 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs WorkFlow 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - main 8 | - develop 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.10' 17 | - run: pip 
install mkdocs-material 18 | - run: pip install "mkdocstrings[python]" 19 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/charliermarsh/ruff-pre-commit 7 | rev: 'v0.0.255' 8 | hooks: 9 | - id: ruff 10 | args: [--fix] 11 | - repo: https://github.com/nbQA-dev/nbQA 12 | rev: 1.6.3 13 | hooks: 14 | - id: nbqa-black 15 | additional_dependencies: [ipython==8.12, black] 16 | - id: nbqa-ruff 17 | args: ["--ignore=I001"] 18 | additional_dependencies: [ipython==8.12, ruff] -------------------------------------------------------------------------------- /.github/workflows/lints.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install --no-cache-dir -r requirements.txt 23 | 24 | - name: Run linters 25 | run: pylint swarms_torch -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install --no-cache-dir -r 
requirements.txt 23 | 24 | - name: Run unit tests 25 | run: pytest tests/ -------------------------------------------------------------------------------- /.github/workflows/quality.yml: -------------------------------------------------------------------------------- 1 | name: Quality 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | steps: 15 | - name: Checkout actions 16 | uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | - name: Init environment 20 | uses: ./.github/actions/init-environment 21 | - name: Run linter 22 | run: | 23 | pylint `git diff --name-only --diff-filter=d origin/main HEAD | grep -E '\.py$' | tr '\n' ' '` -------------------------------------------------------------------------------- /.github/workflows/pr_request_checks.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request Checks 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install --no-cache-dir -r requirements.txt 23 | 24 | - name: Run tests and checks 25 | run: | 26 | pytest tests/ 27 | pylint swarms_torch -------------------------------------------------------------------------------- /.github/workflows/label.yml: -------------------------------------------------------------------------------- 1 | # This workflow will triage pull requests and apply a label based on the 2 | # paths that are modified in the pull request. 3 | # 4 | # To use this workflow, you will need to set up a .github/labeler.yml 5 | # file with configuration. 
For more information, see: 6 | # https://github.com/actions/labeler 7 | 8 | name: Labeler 9 | on: [pull_request_target] 10 | 11 | jobs: 12 | label: 13 | 14 | runs-on: ubuntu-latest 15 | permissions: 16 | contents: read 17 | pull-requests: write 18 | 19 | steps: 20 | - uses: actions/labeler@v5.0.0 21 | with: 22 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 23 | -------------------------------------------------------------------------------- /.github/workflows/run_test.yml: -------------------------------------------------------------------------------- 1 | name: Python application test 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 3.10 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: '3.10' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --no-cache-dir --upgrade pip 19 | pip install pytest 20 | if [ -f requirements.txt ]; then pip install --no-cache-dir -r requirements.txt; fi 21 | - name: Run tests with pytest 22 | run: | 23 | pytest tests/ 24 | -------------------------------------------------------------------------------- /.github/workflows/welcome.yml: -------------------------------------------------------------------------------- 1 | name: Welcome WorkFlow 2 | 3 | on: 4 | issues: 5 | types: [opened] 6 | pull_request_target: 7 | types: [opened] 8 | 9 | jobs: 10 | build: 11 | name: 👋 Welcome 12 | permissions: write-all 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/first-interaction@v1.3.0 16 | with: 17 | repo-token: ${{ secrets.GITHUB_TOKEN }} 18 | issue-message: "Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap." 19 | pr-message: "Hello there, thank you for opening an PR ! 🙏🏻 The team was notified and they will get back to you asap." 
-------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --no-cache-dir --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint $(git ls-files '*.py') 24 | -------------------------------------------------------------------------------- /.github/workflows/docs_test.yml: -------------------------------------------------------------------------------- 1 | name: Documentation Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install --no-cache-dir -r requirements.txt 23 | 24 | - name: Build documentation 25 | run: make docs 26 | 27 | - name: Validate documentation 28 | run: sphinx-build -b linkcheck docs build/docs -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: 'kyegomez' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? 
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [kyegomez] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: #Nothing 14 | -------------------------------------------------------------------------------- /.github/workflows/unit-test.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Setup Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install dependencies 24 | run: pip install 
--no-cache-dir -r requirements.txt 25 | 26 | - name: Run Python unit tests 27 | run: python3 -m unittest tests/ 28 | 29 | - name: Verify that the Docker image for the action builds 30 | run: docker build . --file Dockerfile 31 | 32 | - name: Verify integration test results 33 | run: python3 -m unittest tests/ 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a detailed report on the bug and it's root cause. Conduct root cause error analysis 4 | title: "[BUG] " 5 | labels: bug 6 | assignees: kyegomez 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is and what the main root cause error is. Test very thoroughly before submitting. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 
28 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Upload Python Package 3 | 4 | on: 5 | release: 6 | types: [published] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.10' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --no-cache-dir --upgrade pip 25 | pip install build 26 | - name: Build package 27 | run: python -m build 28 | - name: Publish package 29 | uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/code_quality_control.yml: -------------------------------------------------------------------------------- 1 | name: Linting and Formatting 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | lint_and_format: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install --no-cache-dir -r requirements.txt 23 | 24 | - name: Find Python files 25 | run: find swarms_torch -name "*.py" -type f -exec autopep8 --in-place --aggressive --aggressive {} + 26 | 27 | - name: Push changes 28 | uses: ad-m/github-push-action@master 29 | with: 30 | github_token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 
1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '26 12 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v9 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' -------------------------------------------------------------------------------- /scripts/code_quality.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Navigate to the directory containing the 'package' folder 4 | # cd /path/to/your/code/directory 5 | 6 | # Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i) 7 | # on all Python files (*.py) under the 'package' directory. 8 | autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes package/ 9 | 10 | # Run black with default settings, since black does not have an aggressiveness level. 11 | # Black will format all Python files it finds in the 'package' directory. 12 | black --experimental-string-processing package/ 13 | 14 | # Run ruff on the 'package' directory. 15 | # Add any additional flags if needed according to your version of ruff. 
16 | ruff check --fix --unsafe-fixes package/ 17 | 18 | # YAPF 19 | yapf --recursive --in-place --verbose --style=google --parallel package 20 | -------------------------------------------------------------------------------- /scripts/merge_all_prs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if we are inside a Git repository 4 | if ! git rev-parse --git-dir > /dev/null 2>&1; then 5 | echo "Error: Must be run inside a Git repository." 6 | exit 1 7 | fi 8 | 9 | # Fetch all open pull requests 10 | echo "Fetching open PRs..." 11 | prs=$(gh pr list --state open --json number --jq '.[].number') 12 | 13 | # Check if there are PRs to merge 14 | if [ -z "$prs" ]; then 15 | echo "No open PRs to merge." 16 | exit 0 17 | fi 18 | 19 | echo "Found PRs: $prs" 20 | 21 | # Loop through each pull request number and merge it 22 | for pr in $prs; do 23 | echo "Attempting to merge PR #$pr" 24 | merge_output=$(gh pr merge $pr --auto --merge) 25 | merge_status=$? 26 | if [ $merge_status -ne 0 ]; then 27 | echo "Failed to merge PR #$pr. Error: $merge_output" 28 | else 29 | echo "Successfully merged PR #$pr" 30 | fi 31 | done 32 | 33 | echo "Processing complete."
34 | -------------------------------------------------------------------------------- /.github/workflows/cos_integration.yml: -------------------------------------------------------------------------------- 1 | name: Continuous Integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | 20 | - name: Install dependencies 21 | run: pip install --no-cache-dir -r requirements.txt 22 | 23 | - name: Run unit tests 24 | run: pytest tests/unit 25 | 26 | - name: Run integration tests 27 | run: pytest tests/integration 28 | 29 | - name: Run code coverage 30 | run: pytest --cov=swarms tests/ 31 | 32 | - name: Run linters 33 | run: pylint swarms 34 | 35 | - name: Build documentation 36 | run: make docs 37 | 38 | - name: Validate documentation 39 | run: sphinx-build -b linkcheck docs build/docs 40 | 41 | - name: Run performance tests 42 | run: pytest tests/performance -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.yml: -------------------------------------------------------------------------------- 1 | 130 | 131 | 146 | 147 | 148 | ## Usage 149 | 150 | 151 | ### `TemporalAutoencoder` or `TAE` 152 | 153 | ```python 154 | import torch 155 | from loguru import logger 156 | from movie_gen.tae import TemporalAutoencoder 157 | 158 | def test_temporal_autoencoder(): 159 | """ 160 | Test the TemporalAutoencoder model with a dummy input tensor. 161 | This function creates a random input tensor representing a batch of videos, 162 | passes it through the model, and prints out the input and output shapes. 
163 | """ 164 | # Set the logger to display debug messages 165 | logger.add(lambda msg: print(msg, end='')) 166 | 167 | # Instantiate the model 168 | model = TemporalAutoencoder(in_channels=3, latent_channels=16) 169 | 170 | # Create a dummy input tensor representing a batch of videos 171 | # Batch size B=2, T0=16 frames, 3 channels (RGB), H0=64, W0=64 172 | B, T0, C_in, H0, W0 = 1, 16, 3, 64, 64 173 | x = torch.randn(B, T0, C_in, H0, W0) 174 | 175 | # Forward pass through the model 176 | recon = model(x) 177 | 178 | # Print the shapes 179 | print(f"Input shape: {x.shape}") 180 | print(f"Reconstructed output shape: {recon.shape}") 181 | 182 | if __name__ == "__main__": 183 | test_temporal_autoencoder() 184 | 185 | 186 | ``` 187 | 188 | ## Evaluation 189 | 190 | Their models have been rigorously evaluated on multiple tasks, including: 191 | 192 | - Text-to-video generation benchmarks. 193 | - Video personalization accuracy. 194 | - Instruction-based video editing precision. 195 | - Audio generation quality. 196 | 197 | ### Reproducing Their Results 198 | 199 | To reproduce the evaluation metrics in Their paper, use the following command: 200 | 201 | ```bash 202 | python evaluate.py --model movie-gen-max --task text-to-video 203 | ``` 204 | 205 | ## Contributing 206 | 207 | They welcome contributions! Please follow the standard GitHub flow: 208 | 209 | 1. Fork the repository 210 | 2. Create a new feature branch (`git checkout -b feature-branch`) 211 | 3. Make your changes 212 | 4. Submit a pull request 213 | 214 | For a list of core contributors, please refer to the appendix of the [Movie Gen Paper](link_to_paper). 215 | 216 | ## License 217 | 218 | Movie Gen is licensed under the MIT License. See `LICENSE` for more information. 
219 | 220 | ## Contact 221 | 222 | For any questions or collaboration opportunities, please reach out to the **Movie Gen** team at: 223 | 224 | - **Email**: http://agoralab.ai 225 | - **Website**: [http://agoralab.ai](http://agoralab.ai) -------------------------------------------------------------------------------- /movie_gen/tae.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from loguru import logger 5 | 6 | 7 | def get_num_heads(in_channels: int) -> int: 8 | """ 9 | Calculate the appropriate number of attention heads. 10 | Args: 11 | in_channels (int): The number of input channels. 12 | Returns: 13 | int: The number of attention heads. 14 | """ 15 | num_heads = min(in_channels, 8) 16 | while in_channels % num_heads != 0 and num_heads > 1: 17 | num_heads -= 1 18 | if in_channels % num_heads != 0: 19 | num_heads = 1 # Fallback to 1 if no divisor is found 20 | return num_heads 21 | 22 | 23 | class ResNetBlock(nn.Module): 24 | """ 25 | Residual block with spatial and temporal convolutions. 26 | This block implements the ResNet block described in Section 3.1.1 of the paper, 27 | where after each 2D spatial convolution, a 1D temporal convolution is added to 'inflate' the model. 
28 | """ 29 | 30 | def __init__(self, in_channels: int, out_channels: int): 31 | super(ResNetBlock, self).__init__() 32 | # 2D spatial convolution 33 | self.conv_spatial = nn.Conv3d( 34 | in_channels, 35 | out_channels, 36 | kernel_size=(1, 3, 3), 37 | padding=(0, 1, 1), 38 | ) 39 | # 1D temporal convolution with symmetrical replicate padding 40 | self.conv_temporal = nn.Conv3d( 41 | out_channels, 42 | out_channels, 43 | kernel_size=(3, 1, 1), 44 | padding=(1, 0, 0), 45 | padding_mode="replicate", 46 | ) 47 | # Shortcut connection 48 | if in_channels != out_channels: 49 | self.shortcut = nn.Conv3d( 50 | in_channels, out_channels, kernel_size=1 51 | ) 52 | else: 53 | self.shortcut = nn.Identity() 54 | self.relu = nn.ReLU(inplace=True) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | """ 58 | Forward pass of the ResNet block. 59 | Args: 60 | x (torch.Tensor): Input tensor of shape (B, C_in, T, H, W) 61 | Returns: 62 | torch.Tensor: Output tensor of shape (B, C_out, T, H, W) 63 | """ 64 | identity = self.shortcut(x) 65 | out = self.relu(self.conv_spatial(x)) # Spatial convolution 66 | out = self.relu( 67 | self.conv_temporal(out) 68 | ) # Temporal convolution with symmetrical replicate padding 69 | out += identity # Residual connection 70 | out = self.relu(out) 71 | return out 72 | 73 | 74 | class SpatialSelfAttention(nn.Module): 75 | """ 76 | Spatial Self-Attention Block. 77 | Implements spatial attention as described in Section 3.1.1, where spatial attention is applied after spatial convolutions. 78 | """ 79 | 80 | def __init__(self, in_channels: int): 81 | super(SpatialSelfAttention, self).__init__() 82 | self.in_channels = in_channels 83 | num_heads = self.get_num_heads(in_channels) 84 | self.attention = nn.MultiheadAttention( 85 | embed_dim=in_channels, num_heads=num_heads 86 | ) 87 | 88 | def get_num_heads(self, in_channels: int) -> int: 89 | """ 90 | Calculate the appropriate number of attention heads. 
91 | Args: 92 | in_channels (int): The number of input channels. 93 | Returns: 94 | int: The number of attention heads. 95 | """ 96 | num_heads = min(in_channels, 8) 97 | while in_channels % num_heads != 0 and num_heads > 1: 98 | num_heads -= 1 99 | if in_channels % num_heads != 0: 100 | num_heads = 1 # Fallback to 1 if no divisor is found 101 | return num_heads 102 | 103 | def forward(self, x: torch.Tensor) -> torch.Tensor: 104 | """ 105 | Forward pass of the spatial self-attention block. 106 | Args: 107 | x (torch.Tensor): Input tensor of shape (B, C, T, H, W) 108 | Returns: 109 | torch.Tensor: Output tensor of shape (B, C, T, H, W) 110 | """ 111 | B, C, T, H, W = x.shape 112 | x_reshaped = x.permute(0, 2, 3, 4, 1).reshape( 113 | B * T, H * W, C 114 | ) # (B*T, H*W, C) 115 | attn_output, _ = self.attention( 116 | x_reshaped, x_reshaped, x_reshaped 117 | ) 118 | attn_output = attn_output.view(B, T, H, W, C).permute( 119 | 0, 4, 1, 2, 3 120 | ) 121 | return x + attn_output # Residual connection 122 | 123 | 124 | class TemporalSelfAttention(nn.Module): 125 | """ 126 | Temporal Self-Attention Block. 127 | Implements temporal attention as described in Section 3.1.1, where temporal attention is applied after temporal convolutions. 128 | """ 129 | 130 | def __init__(self, in_channels: int): 131 | super(TemporalSelfAttention, self).__init__() 132 | self.in_channels = in_channels 133 | heads = get_num_heads(in_channels) 134 | self.attention = nn.MultiheadAttention( 135 | embed_dim=in_channels, num_heads=heads 136 | ) 137 | 138 | def forward(self, x: torch.Tensor) -> torch.Tensor: 139 | """ 140 | Forward pass of the temporal self-attention block. 
141 | Args: 142 | x (torch.Tensor): Input tensor of shape (B, C, T, H, W) 143 | Returns: 144 | torch.Tensor: Output tensor of shape (B, C, T, H, W) 145 | """ 146 | B, C, T, H, W = x.shape 147 | x_reshaped = x.permute(0, 3, 4, 2, 1).reshape( 148 | B * H * W, T, C 149 | ) # (B*H*W, T, C) 150 | attn_output, _ = self.attention( 151 | x_reshaped, x_reshaped, x_reshaped 152 | ) 153 | attn_output = attn_output.view(B, H, W, T, C).permute( 154 | 0, 4, 3, 1, 2 155 | ) 156 | return x + attn_output # Residual connection 157 | 158 | 159 | class TAEEncoder(nn.Module): 160 | """ 161 | Temporal Autoencoder Encoder. 162 | Compresses input pixel space video V of shape (B, T0, 3, H0, W0) to latent X of shape (B, C, T, H, W). 163 | As described in Section 3.1.1, the encoder compresses the input 8x across each spatio-temporal dimension. 164 | """ 165 | 166 | def __init__( 167 | self, in_channels: int = 3, latent_channels: int = 16 168 | ): 169 | super(TAEEncoder, self).__init__() 170 | self.latent_channels = latent_channels 171 | # Initial convolution to increase channel dimension 172 | self.initial_conv = nn.Conv3d( 173 | in_channels, 64, kernel_size=3, padding=1 174 | ) 175 | current_channels = 64 176 | # Downsampling layers to achieve 8x compression 177 | self.downsampling_layers = nn.ModuleList() 178 | for _ in range(3): # 8x compression over T, H, W 179 | self.downsampling_layers.append( 180 | nn.Sequential( 181 | ResNetBlock(current_channels, current_channels), 182 | SpatialSelfAttention(current_channels), 183 | TemporalSelfAttention(current_channels), 184 | # Temporal downsampling via strided convolution with stride of 2 185 | nn.Conv3d( 186 | current_channels, 187 | current_channels * 2, 188 | kernel_size=3, 189 | stride=2, 190 | padding=1, 191 | ), 192 | ) 193 | ) 194 | current_channels *= 2 195 | # Final block to get latent representation 196 | self.final_block = nn.Sequential( 197 | ResNetBlock(current_channels, latent_channels), 198 | SpatialSelfAttention(latent_channels), 
199 | TemporalSelfAttention(latent_channels), 200 | ) 201 | 202 | def forward(self, x: torch.Tensor) -> torch.Tensor: 203 | """ 204 | Forward pass of the TAE encoder. 205 | Args: 206 | x (torch.Tensor): Input tensor of shape (B, T0, 3, H0, W0) 207 | Returns: 208 | torch.Tensor: Output tensor of shape (B, C, T, H, W) 209 | """ 210 | logger.debug("Encoding input of shape {}", x.shape) 211 | B, T0, C_in, H0, W0 = x.shape 212 | x = x.permute(0, 2, 1, 3, 4) # (B, C_in, T0, H0, W0) 213 | x = self.initial_conv(x) 214 | for layer in self.downsampling_layers: 215 | x = layer(x) 216 | logger.debug( 217 | "After downsampling layer, shape: {}", x.shape 218 | ) 219 | x = self.final_block(x) 220 | logger.debug("Final latent shape: {}", x.shape) 221 | return x # (B, latent_channels, T, H, W) 222 | 223 | 224 | class TAEDecoder(nn.Module): 225 | """ 226 | Temporal Autoencoder Decoder. 227 | Decodes latent X of shape (B, C, T, H, W) to reconstructed video V_hat of shape (B, T0, 3, H0, W0). 228 | Upsampling is performed via nearest-neighbor interpolation followed by convolution as described in Section 3.1.1. 
229 | """ 230 | 231 | def __init__( 232 | self, out_channels: int = 3, latent_channels: int = 16 233 | ): 234 | super(TAEDecoder, self).__init__() 235 | self.latent_channels = latent_channels 236 | current_channels = latent_channels 237 | # Initial block 238 | self.initial_block = nn.Sequential( 239 | ResNetBlock(current_channels, current_channels), 240 | SpatialSelfAttention(current_channels), 241 | TemporalSelfAttention(current_channels), 242 | ) 243 | # Upsampling layers to reverse the compression 244 | self.upsampling_layers = nn.ModuleList() 245 | for _ in range(3): # Reverse of encoder 246 | self.upsampling_layers.append( 247 | nn.Sequential( 248 | # Upsampling via nearest-neighbor interpolation 249 | nn.Upsample(scale_factor=2, mode="nearest"), 250 | ResNetBlock( 251 | current_channels, current_channels // 2 252 | ), 253 | SpatialSelfAttention(current_channels // 2), 254 | TemporalSelfAttention(current_channels // 2), 255 | ) 256 | ) 257 | current_channels = current_channels // 2 258 | # Final convolution to get output image 259 | self.final_conv = nn.Conv3d( 260 | current_channels, out_channels, kernel_size=3, padding=1 261 | ) 262 | self.sigmoid = ( 263 | nn.Sigmoid() 264 | ) # Assuming input images are normalized between 0 and 1 265 | 266 | def forward(self, x: torch.Tensor) -> torch.Tensor: 267 | """ 268 | Forward pass of the TAE decoder. 
269 | Args: 270 | x (torch.Tensor): Input tensor of shape (B, C, T, H, W) 271 | Returns: 272 | torch.Tensor: Output tensor of shape (B, T0, 3, H0, W0) 273 | """ 274 | logger.debug("Decoding latent of shape {}", x.shape) 275 | x = self.initial_block(x) 276 | for layer in self.upsampling_layers: 277 | x = layer(x) 278 | logger.debug("After upsampling layer, shape: {}", x.shape) 279 | x = self.final_conv(x) 280 | x = self.sigmoid(x) 281 | x = x.permute(0, 2, 1, 3, 4) # (B, T0, C_out, H0, W0) 282 | logger.debug( 283 | "Reconstructed output shape before trimming: {}", x.shape 284 | ) 285 | return x 286 | 287 | 288 | class TemporalAutoencoder(nn.Module): 289 | """ 290 | Temporal Autoencoder (TAE) model. 291 | This model combines the encoder and decoder, and handles variable-length videos as described in Section 3.1.1. 292 | """ 293 | 294 | def __init__( 295 | self, in_channels: int = 3, latent_channels: int = 16 296 | ): 297 | super(TemporalAutoencoder, self).__init__() 298 | self.encoder = TAEEncoder(in_channels, latent_channels) 299 | self.decoder = TAEDecoder(in_channels, latent_channels) 300 | 301 | def forward(self, x: torch.Tensor) -> torch.Tensor: 302 | """ 303 | Forward pass of the TAE. 
304 | Args: 305 | x (torch.Tensor): Input tensor of shape (B, T0, 3, H0, W0) 306 | Returns: 307 | torch.Tensor: Reconstructed tensor of shape (B, T0, 3, H0, W0) 308 | """ 309 | logger.debug("Starting encoding") 310 | latent = self.encoder(x) 311 | logger.debug("Encoding complete") 312 | logger.debug("Starting decoding") 313 | recon = self.decoder(latent) 314 | logger.debug("Decoding complete") 315 | # Discard spurious frames as shown in Figure 4 of the paper 316 | T0 = x.shape[1] 317 | T_decoded = recon.shape[1] 318 | if T_decoded > T0: 319 | recon = recon[:, :T0] 320 | logger.debug( 321 | "Discarded spurious frames, final output shape: {}", 322 | recon.shape, 323 | ) 324 | else: 325 | logger.debug( 326 | "No spurious frames to discard, final output shape: {}", 327 | recon.shape, 328 | ) 329 | return recon 330 | 331 | 332 | # import torch 333 | # from loguru import logger 334 | 335 | # # Assuming the TemporalAutoencoder and its dependencies have been defined/imported as provided above 336 | 337 | # def test_temporal_autoencoder(): 338 | # """ 339 | # Test the TemporalAutoencoder model with a dummy input tensor. 340 | # This function creates a random input tensor representing a batch of videos, 341 | # passes it through the model, and prints out the input and output shapes. 
342 | # """ 343 | # # Set the logger to display debug messages 344 | # logger.add(lambda msg: print(msg, end='')) 345 | 346 | # # Instantiate the model 347 | # model = TemporalAutoencoder(in_channels=3, latent_channels=16) 348 | 349 | # # Create a dummy input tensor representing a batch of videos 350 | # # Batch size B=2, T0=16 frames, 3 channels (RGB), H0=64, W0=64 351 | # B, T0, C_in, H0, W0 = 1, 16, 3, 64, 64 352 | # x = torch.randn(B, T0, C_in, H0, W0) 353 | 354 | # # Forward pass through the model 355 | # recon = model(x) 356 | 357 | # # Print the shapes 358 | # print(f"Input shape: {x.shape}") 359 | # print(f"Reconstructed output shape: {recon.shape}") 360 | 361 | # if __name__ == "__main__": 362 | # test_temporal_autoencoder() 363 | --------------------------------------------------------------------------------