├── .editorconfig ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── pre-commit.yml │ ├── pytest-datamodules.yml │ ├── pytest-models.yml │ ├── python-build.yaml │ ├── python-publish.yml │ └── semgrep.yml ├── .gitignore ├── .pre-commit-config.yaml ├── AUTHORS.rst ├── CITATION.cff ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE ├── Makefile ├── README.md ├── docs ├── Makefile ├── _static │ └── images │ │ ├── logo.png │ │ └── torchfl-github.png ├── authors.rst ├── conf.py ├── contributing.rst ├── examples.rst ├── history.rst ├── index.rst ├── installation.rst ├── make.bat ├── modules.rst ├── torchfl.datamodules.rst ├── torchfl.federated.agents.rst ├── torchfl.federated.aggregators.rst ├── torchfl.federated.rst ├── torchfl.federated.samplers.rst ├── torchfl.models.core.cifar.cifar10.rst ├── torchfl.models.core.cifar.cifar100.rst ├── torchfl.models.core.cifar.rst ├── torchfl.models.core.emnist.balanced.rst ├── torchfl.models.core.emnist.byclass.rst ├── torchfl.models.core.emnist.bymerge.rst ├── torchfl.models.core.emnist.digits.rst ├── torchfl.models.core.emnist.letters.rst ├── torchfl.models.core.emnist.mnist.rst ├── torchfl.models.core.emnist.rst ├── torchfl.models.core.fashionmnist.rst ├── torchfl.models.core.rst ├── torchfl.models.rst ├── torchfl.models.sota.rst ├── torchfl.models.wrapper.rst └── torchfl.rst ├── examples ├── datamodules │ └── cifar.py ├── federated │ ├── .gitkeep │ ├── aggregators │ │ └── fedavg_test.py │ ├── mnist_entrypoint_iid.py │ └── mnist_entrypoint_niid.py └── trainers │ ├── cifar10.py │ ├── cifar10_scratch.py │ └── mnist.py ├── poetry.lock ├── pyproject.toml ├── tests ├── __init__.py ├── datamodules │ ├── __init__.py │ ├── test_cifar.py │ ├── test_emnist.py │ └── test_fashionmnist.py ├── models │ ├── __init__.py │ ├── test_alexnet.py │ ├── test_densenet.py │ ├── test_lenet.py │ ├── test_mlp.py │ ├── test_mobilenet.py │ ├── test_resnet.py │ ├── test_shufflenetv2.py │ ├── test_squeezenet.py │ └── test_vgg.py └── test_torchfl.py └── torchfl ├── __init__.py ├── cli.py ├── compatibility.py ├── config_resolver.py ├── datamodules ├── __init__.py ├── base.py ├── cifar.py ├── emnist.py ├── fashionmnist.py └── types.py ├── federated ├── __init__.py ├── agents │ ├── __init__.py │ ├── base.py │ └── v1.py ├── aggregators │ ├── __init__.py │ ├── base.py │ └── fedavg.py ├── entrypoint.py ├── fl_params.py ├── samplers │ ├── __init__.py │ ├── base.py │ └── random.py └── types.py ├── models ├── __init__.py ├── core │ ├── __init__.py │ ├── cifar │ │ ├── __init__.py │ │ ├── cifar10 │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ └── cifar100 │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ ├── emnist │ │ ├── __init__.py │ │ ├── balanced │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ ├── byclass │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ ├── bymerge │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── 
mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ ├── digits │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ ├── letters │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ │ └── mnist │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── densenet.py │ │ │ ├── lenet.py │ │ │ ├── mlp.py │ │ │ ├── mobilenet.py │ │ │ ├── resnet.py │ │ │ ├── shufflenetv2.py │ │ │ ├── squeezenet.py │ │ │ └── vgg.py │ └── fashionmnist │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── densenet.py │ │ ├── lenet.py │ │ ├── mlp.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ ├── shufflenetv2.py │ │ ├── squeezenet.py │ │ └── vgg.py ├── sota │ ├── __init__.py │ ├── alexnet.py │ ├── densenet.py │ ├── lenet.py │ ├── mlp.py │ ├── mobilenet.py │ ├── resnet.py │ ├── shufflenetv2.py │ ├── squeezenet.py │ └── vgg.py └── wrapper │ ├── __init__.py │ ├── cifar.py │ ├── emnist.py │ └── fashionmnist.py └── utils.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = tab 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * torchfl version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 
15 | ``` 16 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [master] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Install Poetry 14 | run: pipx install poetry 15 | - uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.10' 18 | cache: 'poetry' 19 | - uses: pre-commit/action@v3.0.0 20 | -------------------------------------------------------------------------------- /.github/workflows/pytest-datamodules.yml: -------------------------------------------------------------------------------- 1 | name: pytest-datamodules 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | push: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | changedfiles: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v3 17 | - name: Changed Datamodules 18 | uses: actions/cache@v3 19 | id: changed-datamodules 20 | with: 21 | path: tests/datamodules 22 | key: ${{ hashFiles('torchfl/datamodules/*') }} 23 | outputs: 24 | datamodules-cache-hit: ${{ steps.changed-datamodules.outputs.cache-hit }} 25 | 26 | test-datamodules: 27 | name: "Test changed datamodules." 28 | runs-on: ubuntu-latest 29 | needs: changedfiles 30 | strategy: 31 | matrix: 32 | python-version: ["3.10"] 33 | if: ${{needs.changedfiles.outputs.datamodules-cache-hit != 'true'}} 34 | steps: 35 | - uses: actions/checkout@v3 36 | - name: Install poetry 37 | run: pipx install poetry 38 | - name: Set up Python 39 | uses: actions/setup-python@v4 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | cache: 'poetry' 43 | - name: Install dependencies 44 | run: | 45 | poetry install 46 | - name: PyTest for datamodules 47 | run: | 48 | poetry run pytest tests/datamodules/ 49 | -------------------------------------------------------------------------------- /.github/workflows/pytest-models.yml: -------------------------------------------------------------------------------- 1 | name: pytest-models 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | push: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | changedfiles: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v3 17 | - name: Changed Models 18 | uses: actions/cache@v3 19 | id: changed-models 20 | with: 21 | path: tests/models 22 | key: ${{ hashFiles('torchfl/models/sota/*') }} 23 | outputs: 24 | models-cache-hit: ${{ steps.changed-models.outputs.cache-hit }} 25 | 26 | test-models: 27 | name: "Test changed model files." 
28 | runs-on: ubuntu-latest 29 | needs: changedfiles 30 | strategy: 31 | matrix: 32 | python-version: ["3.10"] 33 | if: ${{needs.changedfiles.outputs.models-cache-hit != 'true'}} 34 | steps: 35 | - uses: actions/checkout@v3 36 | - name: Install poetry 37 | run: pipx install poetry 38 | - name: Set up Python 39 | uses: actions/setup-python@v4 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | cache: 'poetry' 43 | - name: Install dependencies 44 | run: | 45 | poetry install 46 | - name: PyTest for models 47 | run: | 48 | poetry run pytest tests/models/ 49 | -------------------------------------------------------------------------------- /.github/workflows/python-build.yaml: -------------------------------------------------------------------------------- 1 | name: Build a Python package with different versions and operating systems. 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | fail-fast: true 15 | matrix: 16 | python-version: ["3.10"] 17 | os: ["ubuntu-latest", "macos-latest"] 18 | runs-on: ${{ matrix.os }} 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Install poetry 22 | run: | 23 | pipx install poetry 24 | - name: Set up Python 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | cache: 'poetry' 29 | - name: Install dependencies 30 | run: | 31 | poetry install 32 | - name: Build package 33 | run: poetry build 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package to PyPI 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | strategy: 20 | fail-fast: true 21 | matrix: 22 | python-version: ["3.10"] 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Install poetry 26 | run: | 27 | pipx install poetry 28 | - name: Set up Python 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | cache: 'poetry' 33 | - name: Install dependencies 34 | run: | 35 | poetry install 36 | - name: Build package 37 | run: poetry build 38 | - name: Publish package 39 | run: | 40 | poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }} 41 | poetry publish 42 | -------------------------------------------------------------------------------- /.github/workflows/semgrep.yml: -------------------------------------------------------------------------------- 1 | name: Semgrep 2 | on: 3 | pull_request: {} 4 | merge_group: 5 | push: 6 | branches: 7 | - main 8 | - master 9 | paths: 10 | - .github/workflows/semgrep.yml 11 | schedule: 12 | - cron: '0 0 * * 0' 13 | jobs: 14 | semgrep: 15 | name: Scan 16 | runs-on: ubuntu-20.04 17 | env: 18 | SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }} 19 | container: 20 | image: returntocorp/semgrep 21 | if: (github.actor != 'dependabot[bot]') 22 | steps: 23 | - uses: actions/checkout@v3 24 | - run: semgrep ci 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | .idea/ 107 | 108 | # PyTorch dataset files 109 | data/ 110 | 111 | # Log files or run files 112 | runs/ 113 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-added-large-files 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - id: check-ast 11 | - id: check-case-conflict 12 | - id: check-symlinks 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - repo: https://github.com/python-poetry/poetry 16 | rev: 1.3.2 17 | hooks: 18 | - id: poetry-lock 19 | - id: poetry-check 20 | - repo: https://github.com/psf/black 21 | rev: 23.1.0 22 | hooks: 23 | - id: black 24 | args: [--config=pyproject.toml] 25 | - repo: https://github.com/hadialqattan/pycln 26 | rev: v2.1.3 27 | hooks: 28 | - id: pycln 29 | args: [--config=pyproject.toml] 30 | - repo: https://github.com/PyCQA/isort 31 | rev: 5.12.0 32 | hooks: 33 | - id: isort 34 | files: "\\.(py)$" 35 | args: [--settings-path=pyproject.toml] 36 | - repo: https://github.com/pre-commit/mirrors-mypy 37 | rev: v1.0.0 38 | hooks: 39 | - id: mypy 40 | language_version: python3.10 41 | files: "^torchfl/.+" 42 | exclude: ^torchfl/(__pycache__/.+|data/.+|tests/.+|docs/.+)$ 43 | args: [--config-file=pyproject.toml] 44 | additional_dependencies: 45 | - "numpy" 46 | - "pytorch_lightning" 47 | - "torch" 48 | - "torchvision" 49 | - repo: https://github.com/PyCQA/doc8 50 | rev: v1.1.1 51 | hooks: 52 | - id: doc8 53 | name: doc8 54 | description: This hook runs doc8 for linting docs 55 | entry: python -m doc8 56 | language: python 57 | files: "\\.(rst)$" 58 | require_serial: true 59 | args: [--config=pyproject.toml] 60 | - repo: https://github.com/charliermarsh/ruff-pre-commit 61 | rev: "v0.0.246" 62 | hooks: 63 | - id: ruff 64 | language_version: python3.10 65 | args: [--fix, --exit-non-zero-on-fix] 66 | exclude: ^torchfl/(__pycache__/.+|data/.+|tests/.+|docs/.+)$ 67 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | We truly appreciate everyone contributing 
to this community! 6 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Khimani" 5 | given-names: "Vivek" 6 | orcid: "https://orcid.org/0000-0002-7395-9875" 7 | - family-names: "Jabbari" 8 | given-names: "Shahin" 9 | title: "torchfl" 10 | version: 1.0.0 11 | doi: 10.5281/zenodo.1234 12 | date-released: 2023-01-27 13 | url: "https://github.com/vivekkhimani/torchfl" 14 | preferred-citation: 15 | type: article 16 | authors: 17 | - family-names: "Khimani" 18 | given-names: "Vivek" 19 | orcid: "https://orcid.org/0000-0002-7395-9875" 20 | - family-names: "Jabbari" 21 | given-names: "Shahin" 22 | doi: "10.48550/ARXIV.2211.00735" 23 | start: 1 # First page number 24 | end: 20 # Last page number 25 | title: "TorchFL: A Performant Library for Bootstrapping Federated Learning Experiments" 26 | year: 2022 27 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/vivekkhimani/torchfl/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. 36 | Anything tagged with "enhancement", "help wanted", 37 | and "feature" is open to whoever wants to implement it. 38 | 39 | Write Documentation 40 | ~~~~~~~~~~~~~~~~~~~ 41 | 42 | torchfl could always use more documentation, whether as part of the 43 | official torchfl docs, in docstrings, or even on the web in blog posts, 44 | articles, and such. 45 | 46 | Submit Feedback 47 | ~~~~~~~~~~~~~~~ 48 | 49 | The best way to send feedback is to file an issue at https://github.com/vivekkhimani/torchfl/issues. 50 | 51 | If you are proposing a feature: 52 | 53 | * Explain in detail how it would work. 54 | * Keep the scope as narrow as possible, to make it easier to implement. 55 | * Remember that this is a volunteer-driven project, and that contributions 56 | are welcome :) 57 | 58 | Get Started! 59 | ------------ 60 | 61 | Ready to contribute? Here's how to set up `torchfl` for local development. 62 | 63 | 1. Fork the `torchfl` repo on GitHub. 64 | 2. Clone your fork locally:: 65 | 66 | $ git clone git@github.com:/torchfl.git 67 | 68 | 3. Install Poetry to manage dependencies and virtual environments from https://python-poetry.org/docs/. 69 | 4. Install the project dependencies using:: 70 | 71 | $ poetry install 72 | 73 | 5. To add a new dependency to the project, use:: 74 | 75 | $ poetry add 76 | 77 | 6. 
Create a branch for local development:: 78 | 79 | $ git checkout -b name-of-your-bugfix-or-feature 80 | 81 | Now you can make your changes locally and maintain them on your own branch. 82 | 83 | 7. When you're done making changes, check that your changes pass the tests:: 84 | 85 | $ poetry run pytest tests 86 | 87 | If you want to run a specific test file, use:: 88 | 89 | $ poetry run pytest 90 | 91 | If your changes are not covered by the tests, please add tests. 92 | 93 | 8. The pre-commit hooks will be run before every commit. 94 | If you want to run them manually, use:: 95 | 96 | $ pre-commit run --all-files 97 | 98 | 9. Commit your changes and push your branch to GitHub:: 99 | 100 | $ git add --all 101 | $ git commit -m "Your detailed description of your changes." 102 | $ git push origin 103 | 104 | 10. Submit a pull request through the GitHub website. 105 | 11. Once the pull request has been submitted, 106 | the CI pipelines will be triggered on GitHub Actions. 107 | All of them must pass before one of the maintainers 108 | can review the request and perform the merge. 109 | 110 | Pull Request Guidelines 111 | ---------------------------- 112 | 113 | 1. The pull request should include tests. 114 | 115 | 2. If the pull request adds functionality, the docs should be updated. Put 116 | your new functionality into a function with a docstring, and add the 117 | feature to the list in README.md. 118 | 119 | 3. The pull request should work for all supported Python versions. Check 120 | the GitHub Actions runs triggered for the pull request 121 | and make sure that all of them pass. 122 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2021-12-16) 6 | ------------------ 7 | 8 | * Started development. 9 | 10 | 0.1.1 (2022-09-15) 11 | ------------------ 12 | 13 | * First Release on PyPI. 14 | 15 | 0.1.2 (2022-09-16) 16 | ------------------ 17 | 18 | * Bug Fixes. 19 | 20 | 0.1.3 (2022-09-17) 21 | ------------------ 22 | 23 | * Bug Fixes. 24 | * Documentation Update and Dependency conflict resolution. 25 | 26 | 0.1.4 (2022-09-17) 27 | ------------------ 28 | 29 | * Bug fixes. 30 | 31 | 0.1.5 (2022-10-06) 32 | ------------------ 33 | 34 | * Major changes. 35 | * Added workflows to build the torchfl package before merging it to master. 36 | * Updated the existing workflows to include the checks for all python versions. 37 | * Removed the usage of Literal type to enable py36, and py37 support. 38 | * Defined a common cache in home dir for torchfl. 39 | 40 | 0.1.6 (2022-10-28) 41 | ------------------ 42 | 43 | * Officially launched the federated learning modules. 44 | * Added extensive examples for federated and non-federated settings. 45 | * Updated the documentation. 46 | 47 | 0.1.7 (2023-02-01) 48 | ------------------ 49 | 50 | * Code cleanup using pre-commit hooks. 51 | * Bug fixes. 52 | 53 | 0.1.8 (2023-02-01) 54 | ------------------ 55 | 56 | * Documentation fix. 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | A Python library for implementing concurrent (multi-threaded and multi-processed) federated learning using PyTorch API. 
5 | Copyright (C) 2021 Vivek Khimani 6 | 7 | This program is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This program is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this program. If not, see . 19 | 20 | Also add information on how to contact you by electronic and paper mail. 21 | 22 | You should also get your employer (if you work as a programmer) or school, 23 | if any, to sign a "copyright disclaimer" for the program, if necessary. 24 | For more information on this, and how to apply and follow the GNU GPL, see 25 | . 26 | 27 | The GNU General Public License does not permit incorporating your program 28 | into proprietary programs. If your program is a subroutine library, you 29 | may consider it more useful to permit linking proprietary applications with 30 | the library. If this is what you want to do, use the GNU Lesser General 31 | Public License instead of this License. But first, please read 32 | . 33 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . 
-name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | test: ## run tests quickly with the default Python 51 | pytest 52 | 53 | test-all: ## run tests on every Python version with tox 54 | tox 55 | 56 | coverage: ## check code coverage quickly with the default Python 57 | coverage run --source torchfl -m pytest 58 | coverage report -m 59 | coverage html 60 | $(BROWSER) htmlcov/index.html 61 | 62 | release: dist ## package and upload a release 63 | twine upload dist/* 64 | 65 | dist: clean ## builds source and wheel package 66 | python setup.py sdist 67 | python setup.py bdist_wheel 68 | ls -l dist 69 | 70 | install: clean ## install the package to the active Python's site-packages 71 | python setup.py install 72 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = torchfl 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torchfl-org/torchfl/170eb81c9e1254307e2b52b6c51b9186c6895f0c/docs/_static/images/logo.png -------------------------------------------------------------------------------- /docs/_static/images/torchfl-github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/torchfl-org/torchfl/170eb81c9e1254307e2b52b6c51b9186c6895f0c/docs/_static/images/torchfl-github.png -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Documentation 2 | ============= 3 | 4 | .. image:: ./_static/images/torchfl-github.png 5 | :width: 700 6 | :height: 350 7 | :alt: Logo 8 | 9 | .. 
toctree:: 10 | :maxdepth: 4 11 | :caption: Table of Contents: 12 | 13 | installation 14 | modules 15 | examples 16 | contributing 17 | authors 18 | history 19 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | 8 | Stable release 9 | -------------- 10 | 11 | To install torchfl, run this command in your terminal: 12 | 13 | .. code-block:: console 14 | 15 | $ pip install torchfl 16 | 17 | This is the preferred method to install torchfl with most stable release. 18 | 19 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 20 | you through the process. 21 | 22 | .. _pip: https://pip.pypa.io 23 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=torchfl 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | torchfl 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | torchfl 8 | -------------------------------------------------------------------------------- /docs/torchfl.datamodules.rst: -------------------------------------------------------------------------------- 1 | .. _torchfl.datamodules: 2 | 3 | torchfl.datamodules package 4 | =========================== 5 | 6 | Submodules 7 | ---------- 8 | 9 | torchfl.datamodules.cifar module 10 | -------------------------------- 11 | 12 | .. automodule:: torchfl.datamodules.cifar 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | torchfl.datamodules.emnist module 18 | --------------------------------- 19 | 20 | .. automodule:: torchfl.datamodules.emnist 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | torchfl.datamodules.fashionmnist module 26 | --------------------------------------- 27 | 28 | .. automodule:: torchfl.datamodules.fashionmnist 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. 
automodule:: torchfl.datamodules 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /docs/torchfl.federated.agents.rst: -------------------------------------------------------------------------------- 1 | torchfl.federated.agents package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.federated.agents.base module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: torchfl.federated.agents.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.federated.agents.v1 module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: torchfl.federated.agents.v1 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: torchfl.federated.agents 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/torchfl.federated.aggregators.rst: -------------------------------------------------------------------------------- 1 | torchfl.federated.aggregators package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.federated.aggregators.base module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: torchfl.federated.aggregators.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.federated.aggregators.fedavg module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: torchfl.federated.aggregators.fedavg 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: torchfl.federated.aggregators 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/torchfl.federated.rst: -------------------------------------------------------------------------------- 1 | torchfl.federated package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.federated.agents 11 | torchfl.federated.aggregators 12 | torchfl.federated.samplers 13 | 14 | Submodules 15 | ---------- 16 | 17 | torchfl.federated.entrypoint module 18 | ----------------------------------------------- 19 | 20 | .. automodule:: torchfl.federated.entrypoint 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | torchfl.federated.fl_params module 26 | ----------------------------------------------- 27 | 28 | .. automodule:: torchfl.federated.fl_params 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | torchfl.federated.types module 34 | ----------------------------------------------- 35 | 36 | .. automodule:: torchfl.federated.types 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | 42 | Module contents 43 | --------------- 44 | 45 | .. 
automodule:: torchfl.federated 46 | :members: 47 | :undoc-members: 48 | :show-inheritance: 49 | -------------------------------------------------------------------------------- /docs/torchfl.federated.samplers.rst: -------------------------------------------------------------------------------- 1 | torchfl.federated.samplers package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.federated.samplers.base module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: torchfl.federated.samplers.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.federated.samplers.random module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: torchfl.federated.samplers.random 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: torchfl.federated.samplers 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.cifar.cifar10.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.cifar.cifar10 package 2 | ========================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.cifar.cifar10.alexnet module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: torchfl.models.core.cifar.cifar10.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.cifar.cifar10.densenet module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.cifar.cifar10.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.cifar.cifar10.lenet module 24 | ---------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.cifar.cifar10.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.cifar.cifar10.mobilenet module 32 | -------------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.cifar.cifar10.mobilenet 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.cifar.cifar10.resnet module 40 | ----------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.cifar.cifar10.resnet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.cifar.cifar10.shufflenetv2 module 48 | ----------------------------------------------------- 49 | 50 | .. automodule:: torchfl.models.core.cifar.cifar10.shufflenetv2 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.cifar.cifar10.squeezenet module 56 | --------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.cifar.cifar10.squeezenet 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.cifar.cifar10.vgg module 64 | -------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.cifar.cifar10.vgg 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. 
automodule:: torchfl.models.core.cifar.cifar10 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.cifar.cifar100.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.cifar.cifar100 package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.cifar.cifar100.alexnet module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.cifar.cifar100.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.cifar.cifar100.densenet module 16 | -------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.cifar.cifar100.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.cifar.cifar100.lenet module 24 | ----------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.cifar.cifar100.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.cifar.cifar100.mobilenet module 32 | --------------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.cifar.cifar100.mobilenet 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.cifar.cifar100.resnet module 40 | ------------------------------------------------ 41 | 42 | .. automodule:: torchfl.models.core.cifar.cifar100.resnet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.cifar.cifar100.shufflenetv2 module 48 | ------------------------------------------------------ 49 | 50 | .. automodule:: torchfl.models.core.cifar.cifar100.shufflenetv2 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.cifar.cifar100.squeezenet module 56 | ---------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.cifar.cifar100.squeezenet 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.cifar.cifar100.vgg module 64 | --------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.cifar.cifar100.vgg 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: torchfl.models.core.cifar.cifar100 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.cifar.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.cifar package 2 | ================================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.models.core.cifar.cifar10 11 | torchfl.models.core.cifar.cifar100 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. 
automodule:: torchfl.models.core.cifar 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.balanced.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.balanced package 2 | =========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.balanced.alexnet module 8 | -------------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.emnist.balanced.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.balanced.densenet module 16 | --------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.emnist.balanced.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.balanced.lenet module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: torchfl.models.core.emnist.balanced.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.balanced.mlp module 32 | ---------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.balanced.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.balanced.mobilenet module 40 | ---------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.emnist.balanced.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.balanced.resnet module 48 | ------------------------------------------------- 49 | 50 | .. automodule:: torchfl.models.core.emnist.balanced.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.balanced.shufflenetv2 module 56 | ------------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.emnist.balanced.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.balanced.squeezenet module 64 | ----------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.emnist.balanced.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.balanced.vgg module 72 | ---------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.balanced.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.emnist.balanced 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.byclass.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.byclass package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.byclass.alexnet module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.emnist.byclass.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.byclass.densenet module 16 | -------------------------------------------------- 17 | 18 | .. 
automodule:: torchfl.models.core.emnist.byclass.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.byclass.lenet module 24 | ----------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.emnist.byclass.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.byclass.mlp module 32 | --------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.byclass.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.byclass.mobilenet module 40 | --------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.emnist.byclass.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.byclass.resnet module 48 | ------------------------------------------------ 49 | 50 | .. automodule:: torchfl.models.core.emnist.byclass.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.byclass.shufflenetv2 module 56 | ------------------------------------------------------ 57 | 58 | .. automodule:: torchfl.models.core.emnist.byclass.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.byclass.squeezenet module 64 | ---------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.emnist.byclass.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.byclass.vgg module 72 | --------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.byclass.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.emnist.byclass 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.bymerge.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.bymerge package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.bymerge.alexnet module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.emnist.bymerge.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.bymerge.densenet module 16 | -------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.emnist.bymerge.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.bymerge.lenet module 24 | ----------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.emnist.bymerge.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.bymerge.mlp module 32 | --------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.bymerge.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.bymerge.mobilenet module 40 | --------------------------------------------------- 41 | 42 | .. 
automodule:: torchfl.models.core.emnist.bymerge.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.bymerge.resnet module 48 | ------------------------------------------------ 49 | 50 | .. automodule:: torchfl.models.core.emnist.bymerge.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.bymerge.shufflenetv2 module 56 | ------------------------------------------------------ 57 | 58 | .. automodule:: torchfl.models.core.emnist.bymerge.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.bymerge.squeezenet module 64 | ---------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.emnist.bymerge.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.bymerge.vgg module 72 | --------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.bymerge.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.emnist.bymerge 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.digits.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.digits package 2 | ========================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.digits.alexnet module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: torchfl.models.core.emnist.digits.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.digits.densenet module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.emnist.digits.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.digits.lenet module 24 | ---------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.emnist.digits.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.digits.mlp module 32 | -------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.digits.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.digits.mobilenet module 40 | -------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.emnist.digits.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.digits.resnet module 48 | ----------------------------------------------- 49 | 50 | .. automodule:: torchfl.models.core.emnist.digits.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.digits.shufflenetv2 module 56 | ----------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.emnist.digits.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.digits.squeezenet module 64 | --------------------------------------------------- 65 | 66 | .. 
automodule:: torchfl.models.core.emnist.digits.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.digits.vgg module 72 | -------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.digits.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.emnist.digits 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.letters.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.letters package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.letters.alexnet module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.emnist.letters.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.letters.densenet module 16 | -------------------------------------------------- 17 | 18 | .. automodule:: torchfl.models.core.emnist.letters.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.letters.lenet module 24 | ----------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.emnist.letters.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.letters.mlp module 32 | --------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.letters.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.letters.mobilenet module 40 | --------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.emnist.letters.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.letters.resnet module 48 | ------------------------------------------------ 49 | 50 | .. automodule:: torchfl.models.core.emnist.letters.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.letters.shufflenetv2 module 56 | ------------------------------------------------------ 57 | 58 | .. automodule:: torchfl.models.core.emnist.letters.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.letters.squeezenet module 64 | ---------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.emnist.letters.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.letters.vgg module 72 | --------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.letters.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. 
automodule:: torchfl.models.core.emnist.letters 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.mnist.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist.mnist package 2 | ======================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.emnist.mnist.alexnet module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.emnist.mnist.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.emnist.mnist.densenet module 16 | ------------------------------------------------ 17 | 18 | .. automodule:: torchfl.models.core.emnist.mnist.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.emnist.mnist.lenet module 24 | --------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.emnist.mnist.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.emnist.mnist.mlp module 32 | ------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.emnist.mnist.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.emnist.mnist.mobilenet module 40 | ------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.emnist.mnist.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.emnist.mnist.resnet module 48 | ---------------------------------------------- 49 | 50 | .. automodule:: torchfl.models.core.emnist.mnist.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.emnist.mnist.shufflenetv2 module 56 | ---------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.emnist.mnist.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.emnist.mnist.squeezenet module 64 | -------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.emnist.mnist.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.emnist.mnist.vgg module 72 | ------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.emnist.mnist.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.emnist.mnist 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.emnist.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.emnist package 2 | ================================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.models.core.emnist.balanced 11 | torchfl.models.core.emnist.byclass 12 | torchfl.models.core.emnist.bymerge 13 | torchfl.models.core.emnist.digits 14 | torchfl.models.core.emnist.letters 15 | torchfl.models.core.emnist.mnist 16 | 17 | Module contents 18 | --------------- 19 | 20 | .. 
automodule:: torchfl.models.core.emnist 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.fashionmnist.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core.fashionmnist package 2 | ======================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.core.fashionmnist.alexnet module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: torchfl.models.core.fashionmnist.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.core.fashionmnist.densenet module 16 | ------------------------------------------------ 17 | 18 | .. automodule:: torchfl.models.core.fashionmnist.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.core.fashionmnist.lenet module 24 | --------------------------------------------- 25 | 26 | .. automodule:: torchfl.models.core.fashionmnist.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.core.fashionmnist.mlp module 32 | ------------------------------------------- 33 | 34 | .. automodule:: torchfl.models.core.fashionmnist.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.core.fashionmnist.mobilenet module 40 | ------------------------------------------------- 41 | 42 | .. automodule:: torchfl.models.core.fashionmnist.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.core.fashionmnist.resnet module 48 | ---------------------------------------------- 49 | 50 | .. automodule:: torchfl.models.core.fashionmnist.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.core.fashionmnist.shufflenetv2 module 56 | ---------------------------------------------------- 57 | 58 | .. automodule:: torchfl.models.core.fashionmnist.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.core.fashionmnist.squeezenet module 64 | -------------------------------------------------- 65 | 66 | .. automodule:: torchfl.models.core.fashionmnist.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.core.fashionmnist.vgg module 72 | ------------------------------------------- 73 | 74 | .. automodule:: torchfl.models.core.fashionmnist.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.core.fashionmnist 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.core.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.core package 2 | =========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.models.core.cifar 11 | torchfl.models.core.emnist 12 | torchfl.models.core.fashionmnist 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. 
automodule:: torchfl.models.core 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/torchfl.models.rst: -------------------------------------------------------------------------------- 1 | torchfl.models package 2 | ====================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.models.core 11 | torchfl.models.sota 12 | torchfl.models.wrapper 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: torchfl.models 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/torchfl.models.sota.rst: -------------------------------------------------------------------------------- 1 | torchfl.models.sota package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchfl.models.sota.alexnet module 8 | ---------------------------------- 9 | 10 | .. automodule:: torchfl.models.sota.alexnet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchfl.models.sota.densenet module 16 | ----------------------------------- 17 | 18 | .. automodule:: torchfl.models.sota.densenet 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchfl.models.sota.lenet module 24 | -------------------------------- 25 | 26 | .. automodule:: torchfl.models.sota.lenet 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchfl.models.sota.mlp module 32 | ------------------------------ 33 | 34 | .. automodule:: torchfl.models.sota.mlp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchfl.models.sota.mobilenet module 40 | ------------------------------------ 41 | 42 | .. automodule:: torchfl.models.sota.mobilenet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchfl.models.sota.resnet module 48 | --------------------------------- 49 | 50 | .. automodule:: torchfl.models.sota.resnet 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchfl.models.sota.shufflenetv2 module 56 | --------------------------------------- 57 | 58 | .. automodule:: torchfl.models.sota.shufflenetv2 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchfl.models.sota.squeezenet module 64 | ------------------------------------- 65 | 66 | .. automodule:: torchfl.models.sota.squeezenet 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchfl.models.sota.vgg module 72 | ------------------------------ 73 | 74 | .. automodule:: torchfl.models.sota.vgg 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: torchfl.models.sota 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/torchfl.models.wrapper.rst: -------------------------------------------------------------------------------- 1 | .. _torchfl.wrapper: 2 | 3 | torchfl.models.wrapper package 4 | ============================== 5 | 6 | Submodules 7 | ---------- 8 | 9 | torchfl.models.wrapper.cifar module 10 | ----------------------------------- 11 | 12 | .. automodule:: torchfl.models.wrapper.cifar 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | torchfl.models.wrapper.emnist module 18 | ------------------------------------ 19 | 20 | .. 
automodule:: torchfl.models.wrapper.emnist 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | torchfl.models.wrapper.fashionmnist module 26 | ------------------------------------------ 27 | 28 | .. automodule:: torchfl.models.wrapper.fashionmnist 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. automodule:: torchfl.models.wrapper 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /docs/torchfl.rst: -------------------------------------------------------------------------------- 1 | torchfl package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchfl.datamodules 11 | torchfl.federated 12 | torchfl.models 13 | 14 | Submodules 15 | ---------- 16 | 17 | torchfl.cli module 18 | ------------------ 19 | 20 | .. automodule:: torchfl.cli 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | torchfl.compatibility module 26 | ---------------------------- 27 | 28 | .. automodule:: torchfl.compatibility 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. automodule:: torchfl 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /examples/federated/.gitkeep: -------------------------------------------------------------------------------- 1 | keep 2 | -------------------------------------------------------------------------------- /examples/federated/aggregators/fedavg_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """An example script to test the FedAvg aggregation.""" 4 | from torchfl.compatibility import OPTIMIZERS_TYPE 5 | from torchfl.federated.aggregators.fedavg import FedAvgAggregator 6 | from torchfl.models.wrapper.emnist import EMNIST_MODELS_ENUM, MNISTEMNIST 7 | 8 | if __name__ == "__main__": 9 | model = MNISTEMNIST( 10 | EMNIST_MODELS_ENUM.LENET, 11 | OPTIMIZERS_TYPE.ADAM, 12 | {"lr": 0.001}, 13 | {}, 14 | ) 15 | a_map = {0: model, 1: model} 16 | agg = FedAvgAggregator([]) 17 | out = agg.aggregate(model, a_map) 18 | model.load_state_dict(out) 19 | -------------------------------------------------------------------------------- /examples/federated/mnist_entrypoint_iid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """An example script for using FL entrypoint to setup a FL experiment on MNIST.""" 4 | 5 | from typing import Any 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from torchfl.compatibility import OPTIMIZERS_TYPE 10 | from torchfl.datamodules.emnist import ( 11 | SUPPORTED_DATASETS_TYPE, 12 | EMNISTDataModule, 13 | ) 14 | from torchfl.federated.agents.v1 import V1Agent 15 | from torchfl.federated.aggregators.fedavg import FedAvgAggregator 16 | from torchfl.federated.entrypoint import Entrypoint 17 | from torchfl.federated.fl_params import FLParams 18 | from torchfl.federated.samplers.random import RandomSampler 19 | from torchfl.models.wrapper.emnist import EMNIST_MODELS_ENUM, MNISTEMNIST 20 | 21 | 22 | def initialize_agents( 23 | fl_params: FLParams, agent_data_shard_map: dict[int, DataLoader] 24 | ) -> list[V1Agent]: 25 | """Initialize agents.""" 26 | agents = [] 27 | for agent_id in range(fl_params.num_agents): 28 | 
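        # Each agent receives its own copy of the wrapped MNIST model plus the data
        # shard pre-assigned to its id; the hyperparameters deliberately mirror the
        # global model so that FedAvg later averages structurally identical state dicts.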
agent = V1Agent( 29 | id=agent_id, 30 | model=MNISTEMNIST( 31 | model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL, 32 | optimizer_name=OPTIMIZERS_TYPE.ADAM, 33 | optimizer_hparams={"lr": 0.001}, 34 | model_hparams={"pre_trained": True, "feature_extract": True}, 35 | fl_hparams=fl_params, 36 | ), 37 | data_shard=agent_data_shard_map[agent_id], 38 | ) 39 | agents.append(agent) 40 | return agents 41 | 42 | 43 | def get_agent_data_shard_map() -> EMNISTDataModule: 44 | datamodule: EMNISTDataModule = EMNISTDataModule( 45 | dataset_name=SUPPORTED_DATASETS_TYPE.MNIST, train_batch_size=10 46 | ) 47 | datamodule.prepare_data() 48 | datamodule.setup() 49 | return datamodule 50 | 51 | 52 | def main() -> None: 53 | """Main function.""" 54 | fl_params = FLParams( 55 | experiment_name="iid_mnist_fedavg_10_agents_5_sampled_50_epochs_mobilenetv3small_latest", 56 | num_agents=10, 57 | global_epochs=10, 58 | local_epochs=2, 59 | sampling_ratio=0.5, 60 | ) 61 | global_model = MNISTEMNIST( 62 | model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL, 63 | optimizer_name=OPTIMIZERS_TYPE.ADAM, 64 | optimizer_hparams={"lr": 0.001}, 65 | model_hparams={"pre_trained": True, "feature_extract": True}, 66 | fl_hparams=fl_params, 67 | ) 68 | agent_data_shard_map = get_agent_data_shard_map().federated_iid_dataloader( 69 | num_workers=fl_params.num_agents, 70 | workers_batch_size=fl_params.local_train_batch_size, 71 | ) 72 | all_agents: Any = initialize_agents(fl_params, agent_data_shard_map) 73 | entrypoint = Entrypoint( 74 | global_model=global_model, 75 | global_datamodule=get_agent_data_shard_map(), 76 | fl_hparams=fl_params, 77 | agents=all_agents, 78 | aggregator=FedAvgAggregator(all_agents=all_agents), 79 | sampler=RandomSampler(all_agents=all_agents), 80 | ) 81 | entrypoint.run() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /examples/federated/mnist_entrypoint_niid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """An example script for using FL entrypoint to setup a FL experiment on MNIST.""" 4 | 5 | from typing import Any 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from torchfl.compatibility import OPTIMIZERS_TYPE 10 | from torchfl.datamodules.emnist import ( 11 | SUPPORTED_DATASETS_TYPE, 12 | EMNISTDataModule, 13 | ) 14 | from torchfl.federated.agents.v1 import V1Agent 15 | from torchfl.federated.aggregators.fedavg import FedAvgAggregator 16 | from torchfl.federated.entrypoint import Entrypoint 17 | from torchfl.federated.fl_params import FLParams 18 | from torchfl.federated.samplers.random import RandomSampler 19 | from torchfl.models.wrapper.emnist import EMNIST_MODELS_ENUM, MNISTEMNIST 20 | 21 | 22 | def initialize_agents( 23 | fl_params: FLParams, agent_data_shard_map: dict[int, DataLoader] 24 | ) -> list[V1Agent]: 25 | """Initialize agents.""" 26 | agents = [] 27 | for agent_id in range(fl_params.num_agents): 28 | agent = V1Agent( 29 | id=agent_id, 30 | model=MNISTEMNIST( 31 | model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL, 32 | optimizer_name=OPTIMIZERS_TYPE.ADAM, 33 | optimizer_hparams={"lr": 0.001}, 34 | model_hparams={"pre_trained": True, "feature_extract": True}, 35 | fl_hparams=fl_params, 36 | ), 37 | data_shard=agent_data_shard_map[agent_id], 38 | ) 39 | agents.append(agent) 40 | return agents 41 | 42 | 43 | def get_agent_data_shard_map() -> EMNISTDataModule: 44 | datamodule: EMNISTDataModule = EMNISTDataModule( 45 | 
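        # MNIST split of the EMNIST family; train_batch_size here only configures the
        # datamodule's own loaders -- the per-agent shards built in main() are batched
        # with workers_batch_size taken from FLParams.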
dataset_name=SUPPORTED_DATASETS_TYPE.MNIST, train_batch_size=10 46 | ) 47 | datamodule.prepare_data() 48 | datamodule.setup() 49 | return datamodule 50 | 51 | 52 | def main() -> None: 53 | """Main function.""" 54 | fl_params = FLParams( 55 | experiment_name="niid_1_mnist_fedavg_10_agents_5_sampled_50_epochs_mobilenetv3small", 56 | num_agents=10, 57 | global_epochs=10, 58 | local_epochs=2, 59 | sampling_ratio=0.5, 60 | local_test_batch_size=1, 61 | ) 62 | global_model = MNISTEMNIST( 63 | model_name=EMNIST_MODELS_ENUM.MOBILENETV3SMALL, 64 | optimizer_name=OPTIMIZERS_TYPE.ADAM, 65 | optimizer_hparams={"lr": 0.001}, 66 | model_hparams={"pre_trained": True, "feature_extract": True}, 67 | fl_hparams=fl_params, 68 | ) 69 | agent_data_shard_map = ( 70 | get_agent_data_shard_map().federated_non_iid_dataloader( 71 | num_workers=fl_params.num_agents, 72 | workers_batch_size=fl_params.local_train_batch_size, 73 | niid_factor=1, 74 | ) 75 | ) 76 | all_agents: Any = initialize_agents(fl_params, agent_data_shard_map) 77 | entrypoint = Entrypoint( 78 | global_model=global_model, 79 | global_datamodule=get_agent_data_shard_map(), 80 | fl_hparams=fl_params, 81 | agents=all_agents, 82 | aggregator=FedAvgAggregator(all_agents=all_agents), 83 | sampler=RandomSampler(all_agents=all_agents), 84 | ) 85 | entrypoint.run() 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for torchfl.""" 2 | -------------------------------------------------------------------------------- /tests/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test sub-package for datamodules in torchfl.""" 2 | -------------------------------------------------------------------------------- /tests/datamodules/test_fashionmnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for FashionMNIST PyTorch LightningDataModule module in `torchfl` package.""" 4 | from collections import Counter 5 | 6 | import pytest 7 | 8 | from torchfl.datamodules.fashionmnist import FashionMNISTDataModule 9 | 10 | 11 | @pytest.fixture() 12 | def fashionmnist_data_module(): 13 | """Fixture for FashionMNIST PyTorch LightningDataModule 14 | 15 | Returns: 16 | FashionMNISTDataModule: PyTorch LightningDataModule for FashionMNIST. 17 | """ 18 | return FashionMNISTDataModule() 19 | 20 | 21 | @pytest.mark.datamodules_fashionmnist() 22 | def test_fashionmnist_train_val_split(fashionmnist_data_module): 23 | """Testing the fashionmnist dataset train and validation split with PyTorch Lightning wrapper. 24 | 25 | Args: 26 | fashionmnist_data_module (FashionMNISTDataModule): PyTorch LightningDataModule for fashionmnist. 27 | """ 28 | fashionmnist_data_module.prepare_data() 29 | fashionmnist_data_module.setup(stage="fit") 30 | train_dataloader = fashionmnist_data_module.train_dataloader() 31 | val_dataloader = fashionmnist_data_module.val_dataloader() 32 | assert len(train_dataloader.dataset) == 54000 33 | assert len(val_dataloader.dataset) == 6000 34 | 35 | 36 | @pytest.mark.datamodules_fashionmnist() 37 | def test_fashionmnist_test_split(fashionmnist_data_module): 38 | """Testing the fashionmnist dataset test split with PyTorch Lightning wrapper. 
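    FashionMNIST ships a fixed 10,000-image test split, so the assertion below
    checks the test dataloader for exactly that count.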
39 | 40 | Args: 41 | fashionmnist_data_module (FashionMNISTDataModule): PyTorch LightningDataModule for fashionmnist. 42 | """ 43 | fashionmnist_data_module.prepare_data() 44 | fashionmnist_data_module.setup(stage="test") 45 | test_dataloader = fashionmnist_data_module.test_dataloader() 46 | assert len(test_dataloader.dataset) == 10000 47 | 48 | 49 | @pytest.mark.datamodules_fashionmnist() 50 | def test_fashionmnist_prediction_split(fashionmnist_data_module): 51 | """Testing the fashionmnist dataset prediction split with PyTorch Lightning wrapper. 52 | 53 | Args: 54 | fashionmnist_data_module (FashionMNISTDataModule): PyTorch LightningDataModule for fashionmnist. 55 | """ 56 | fashionmnist_data_module.prepare_data() 57 | fashionmnist_data_module.setup(stage="predict") 58 | predict_dataloader = fashionmnist_data_module.predict_dataloader() 59 | assert len(predict_dataloader.dataset) == 10000 60 | 61 | 62 | @pytest.mark.datamodules_fashionmnist() 63 | def test_fashionmnist_federated_iid_split(fashionmnist_data_module): 64 | """Testing the fashionmnist dataset federated iid split with PyTorch Lightning wrapper. 65 | 66 | Args: 67 | fashionmnist_data_module (FashionMNISTDataModule): PyTorch LightningDataModule for fashionmnist. 68 | """ 69 | fashionmnist_data_module.prepare_data() 70 | fashionmnist_data_module.setup(stage="fit") 71 | dataloader = fashionmnist_data_module.federated_iid_dataloader() 72 | assert len(dataloader.keys()) == 10 73 | assert len(dataloader[0].dataset) == 6000 74 | frequency = Counter(list(dataloader[0].dataset.targets)) 75 | assert len(frequency.keys()) == 10 76 | 77 | 78 | @pytest.mark.datamodules_fashionmnist() 79 | def test_fashionmnist_federated_non_iid_split(fashionmnist_data_module): 80 | """Testing the fashionmnist dataset federated non iid split with PyTorch Lightning wrapper. 81 | 82 | Args: 83 | fashionmnist_data_module (FashionMNISTDataModule): PyTorch LightningDataModule for fashionmnist. 
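    With the default niid_factor, each of the 10 worker shards is expected to hold
    labels from only two classes, which the max(all_freq) == 2 assertion verifies.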
84 | """ 85 | fashionmnist_data_module.prepare_data() 86 | fashionmnist_data_module.setup(stage="fit") 87 | dataloader = fashionmnist_data_module.federated_non_iid_dataloader() 88 | assert len(dataloader.keys()) == 10 89 | all_freq = [] 90 | for i in range(10): 91 | assert len(dataloader[i].dataset) == 6000 92 | frequency = Counter(list(dataloader[i].dataset.targets)) 93 | all_freq.append(len(frequency.keys())) 94 | assert max(all_freq) == 2 95 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test sub-package for models in torchfl.""" 2 | -------------------------------------------------------------------------------- /tests/models/test_alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for AlexNet in `torchfl` package.""" 4 | import torch 5 | 6 | from torchfl.models.sota.alexnet import AlexNet 7 | 8 | 9 | def test_alexnet_single_channel_output_shape(): 10 | model = AlexNet(num_channels=1) 11 | model.zero_grad() 12 | out = model(torch.randn(1, 1, 224, 224)) 13 | assert out.size() == torch.Size([1, 10]) 14 | 15 | 16 | def test_alexnet_three_channel_output_shape(): 17 | model = AlexNet(num_channels=3) 18 | model.zero_grad() 19 | out = model(torch.randn(1, 3, 224, 224)) 20 | assert out.size() == torch.Size([1, 10]) 21 | -------------------------------------------------------------------------------- /tests/models/test_densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for DenseNet in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.densenet import ( 8 | DenseNet121, 9 | DenseNet161, 10 | DenseNet169, 11 | DenseNet201, 12 | ) 13 | 14 | 15 | def test_densenet121_single_channel_ouput_shape(): 16 | model = DenseNet121(num_channels=1) 17 | model.zero_grad() 18 | out = model(torch.randn(1, 1, 224, 224)) 19 | assert out.size() == torch.Size([1, 10]) 20 | 21 | 22 | def test_densenet121_three_channel_ouput_shape(): 23 | model = DenseNet121(num_channels=3) 24 | model.zero_grad() 25 | out = model(torch.randn(1, 3, 224, 224)) 26 | assert out.size() == torch.Size([1, 10]) 27 | 28 | 29 | def test_densenet161_single_channel_ouput_shape(): 30 | model = DenseNet161(num_channels=1) 31 | model.zero_grad() 32 | out = model(torch.randn(1, 1, 224, 224)) 33 | assert out.size() == torch.Size([1, 10]) 34 | 35 | 36 | def test_densenet161_three_channel_ouput_shape(): 37 | model = DenseNet161(num_channels=3) 38 | model.zero_grad() 39 | out = model(torch.randn(1, 3, 224, 224)) 40 | assert out.size() == torch.Size([1, 10]) 41 | 42 | 43 | def test_densenet169_single_channel_ouput_shape(): 44 | model = DenseNet169(num_channels=1) 45 | model.zero_grad() 46 | out = model(torch.randn(1, 1, 224, 224)) 47 | assert out.size() == torch.Size([1, 10]) 48 | 49 | 50 | def test_densenet169_three_channel_ouput_shape(): 51 | model = DenseNet169(num_channels=3) 52 | model.zero_grad() 53 | out = model(torch.randn(1, 3, 224, 224)) 54 | assert out.size() == torch.Size([1, 10]) 55 | 56 | 57 | def test_densenet201_single_channel_ouput_shape(): 58 | model = DenseNet201(num_channels=1) 59 | model.zero_grad() 60 | out = model(torch.randn(1, 1, 224, 224)) 61 | assert out.size() == torch.Size([1, 10]) 62 | 63 | 64 | def test_densenet201_three_channel_ouput_shape(): 65 | model = 
DenseNet201(num_channels=3) 66 | model.zero_grad() 67 | out = model(torch.randn(1, 3, 224, 224)) 68 | assert out.size() == torch.Size([1, 10]) 69 | -------------------------------------------------------------------------------- /tests/models/test_lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for LeNet in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.lenet import LeNet 8 | 9 | 10 | def test_lenet_single_channel_ouput_shape(): 11 | model = LeNet(num_channels=1) 12 | model.zero_grad() 13 | out = model(torch.randn(1, 1, 224, 224)) 14 | assert out.size() == torch.Size([1, 10]) 15 | 16 | 17 | def test_lenet_three_channel_ouput_shape(): 18 | model = LeNet(num_channels=3) 19 | model.zero_grad() 20 | out = model(torch.randn(1, 3, 224, 224)) 21 | assert out.size() == torch.Size([1, 10]) 22 | -------------------------------------------------------------------------------- /tests/models/test_mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for MLP in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.mlp import MLP 8 | 9 | 10 | def test_mlp_single_channel_ouput_shape(): 11 | model = MLP(num_channels=1) 12 | model.zero_grad() 13 | out = model(torch.randn(1, 1, 28, 28)) 14 | assert out.size() == torch.Size([1, 10]) 15 | 16 | 17 | def test_mlp_three_channel_ouput_shape(): 18 | model = MLP(num_channels=3) 19 | model.zero_grad() 20 | out = model(torch.randn(1, 3, 28, 28)) 21 | assert out.size() == torch.Size([1, 10]) 22 | -------------------------------------------------------------------------------- /tests/models/test_mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for MobileNet in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.mobilenet import ( 8 | MobileNetV2, 9 | MobileNetV3Large, 10 | MobileNetV3Small, 11 | ) 12 | 13 | 14 | def test_mobilenetv2_single_channel_ouput_shape(): 15 | model = MobileNetV2(num_channels=1) 16 | model.zero_grad() 17 | out = model(torch.randn(1, 1, 224, 224)) 18 | assert out.size() == torch.Size([1, 10]) 19 | 20 | 21 | def test_mobilenetv2_three_channel_ouput_shape(): 22 | model = MobileNetV2(num_channels=3) 23 | model.zero_grad() 24 | out = model(torch.randn(1, 3, 224, 224)) 25 | assert out.size() == torch.Size([1, 10]) 26 | 27 | 28 | def test_mobilenetv3large_single_channel_ouput_shape(): 29 | model = MobileNetV3Large(num_channels=1) 30 | model.zero_grad() 31 | out = model(torch.randn(1, 1, 224, 224)) 32 | assert out.size() == torch.Size([1, 10]) 33 | 34 | 35 | def test_mobilenetv3large_three_channel_ouput_shape(): 36 | model = MobileNetV3Large(num_channels=3) 37 | model.zero_grad() 38 | out = model(torch.randn(1, 3, 224, 224)) 39 | assert out.size() == torch.Size([1, 10]) 40 | 41 | 42 | def test_mobilenetv3small_single_channel_ouput_shape(): 43 | model = MobileNetV3Small(num_channels=1) 44 | model.zero_grad() 45 | out = model(torch.randn(1, 1, 224, 224)) 46 | assert out.size() == torch.Size([1, 10]) 47 | 48 | 49 | def test_mobilenetv3small_three_channel_ouput_shape(): 50 | model = MobileNetV3Small(num_channels=3) 51 | model.zero_grad() 52 | out = model(torch.randn(1, 3, 224, 224)) 53 | assert out.size() == torch.Size([1, 10]) 54 | -------------------------------------------------------------------------------- 
/tests/models/test_shufflenetv2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for ShuffleNetv2 in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.shufflenetv2 import ( 8 | ShuffleNetv2_x0_5, 9 | ShuffleNetv2_x1_0, 10 | ShuffleNetv2_x1_5, 11 | ShuffleNetv2_x2_0, 12 | ) 13 | 14 | 15 | def test_shufflenetv2_x0_5_single_channel_ouput_shape(): 16 | model = ShuffleNetv2_x0_5(num_channels=1) 17 | model.zero_grad() 18 | out = model(torch.randn(1, 1, 224, 224)) 19 | assert out.size() == torch.Size([1, 10]) 20 | 21 | 22 | def test_shufflenetv2_x0_5_three_channel_ouput_shape(): 23 | model = ShuffleNetv2_x0_5(num_channels=3) 24 | model.zero_grad() 25 | out = model(torch.randn(1, 3, 224, 224)) 26 | assert out.size() == torch.Size([1, 10]) 27 | 28 | 29 | def test_shufflenetv2_x1_0_single_channel_ouput_shape(): 30 | model = ShuffleNetv2_x1_0(num_channels=1) 31 | model.zero_grad() 32 | out = model(torch.randn(1, 1, 224, 224)) 33 | assert out.size() == torch.Size([1, 10]) 34 | 35 | 36 | def test_shufflenetv2_x1_0_three_channel_ouput_shape(): 37 | model = ShuffleNetv2_x1_0(num_channels=3) 38 | model.zero_grad() 39 | out = model(torch.randn(1, 3, 224, 224)) 40 | assert out.size() == torch.Size([1, 10]) 41 | 42 | 43 | def test_shufflenetv2_x1_5_single_channel_ouput_shape(): 44 | model = ShuffleNetv2_x1_5(num_channels=1, pre_trained=False) 45 | model.zero_grad() 46 | out = model(torch.randn(1, 1, 224, 224)) 47 | assert out.size() == torch.Size([1, 10]) 48 | 49 | 50 | def test_shufflenetv2_x1_5_three_channel_ouput_shape(): 51 | model = ShuffleNetv2_x1_5(num_channels=3, pre_trained=False) 52 | model.zero_grad() 53 | out = model(torch.randn(1, 3, 224, 224)) 54 | assert out.size() == torch.Size([1, 10]) 55 | 56 | 57 | def test_shufflenetv2_x2_0_single_channel_ouput_shape(): 58 | model = ShuffleNetv2_x2_0(num_channels=1, pre_trained=False) 59 | model.zero_grad() 60 | out = model(torch.randn(1, 1, 224, 224)) 61 | assert out.size() == torch.Size([1, 10]) 62 | 63 | 64 | def test_shufflenetv2_x2_0_three_channel_ouput_shape(): 65 | model = ShuffleNetv2_x2_0(num_channels=3, pre_trained=False) 66 | model.zero_grad() 67 | out = model(torch.randn(1, 3, 224, 224)) 68 | assert out.size() == torch.Size([1, 10]) 69 | -------------------------------------------------------------------------------- /tests/models/test_squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for SqueezeNet in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.squeezenet import SqueezeNet1_0, SqueezeNet1_1 8 | 9 | 10 | def test_squeezenet1_0_single_channel_ouput_shape(): 11 | model = SqueezeNet1_0(num_channels=1) 12 | model.zero_grad() 13 | out = model(torch.randn(1, 1, 224, 224)) 14 | assert out.size() == torch.Size([1, 10]) 15 | 16 | 17 | def test_squeezenet1_0_three_channel_ouput_shape(): 18 | model = SqueezeNet1_0(num_channels=3) 19 | model.zero_grad() 20 | out = model(torch.randn(1, 3, 224, 224)) 21 | assert out.size() == torch.Size([1, 10]) 22 | 23 | 24 | def test_squeezenet1_1_single_channel_ouput_shape(): 25 | model = SqueezeNet1_1(num_channels=1) 26 | model.zero_grad() 27 | out = model(torch.randn(1, 1, 224, 224)) 28 | assert out.size() == torch.Size([1, 10]) 29 | 30 | 31 | def test_squeezenet1_1_three_channel_ouput_shape(): 32 | model = SqueezeNet1_1(num_channels=3) 33 | model.zero_grad() 34 | out = model(torch.randn(1, 3, 
224, 224)) 35 | assert out.size() == torch.Size([1, 10]) 36 | -------------------------------------------------------------------------------- /tests/models/test_vgg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for VGG in `torchfl` package.""" 4 | 5 | import torch 6 | 7 | from torchfl.models.sota.vgg import ( 8 | VGG11, 9 | VGG11_BN, 10 | VGG13, 11 | VGG13_BN, 12 | VGG16, 13 | VGG16_BN, 14 | VGG19, 15 | VGG19_BN, 16 | ) 17 | 18 | 19 | def test_vgg11_single_channel_ouput_shape(): 20 | model = VGG11(num_channels=1) 21 | model.zero_grad() 22 | out = model(torch.randn(1, 1, 224, 224)) 23 | assert out.size() == torch.Size([1, 10]) 24 | 25 | 26 | def test_vgg11_three_channel_ouput_shape(): 27 | model = VGG11(num_channels=3) 28 | model.zero_grad() 29 | out = model(torch.randn(1, 3, 224, 224)) 30 | assert out.size() == torch.Size([1, 10]) 31 | 32 | 33 | def test_vgg11_bn_single_channel_ouput_shape(): 34 | model = VGG11_BN(num_channels=1) 35 | model.zero_grad() 36 | out = model(torch.randn(1, 1, 224, 224)) 37 | assert out.size() == torch.Size([1, 10]) 38 | 39 | 40 | def test_vgg11_bn_three_channel_ouput_shape(): 41 | model = VGG11_BN(num_channels=3) 42 | model.zero_grad() 43 | out = model(torch.randn(1, 3, 224, 224)) 44 | assert out.size() == torch.Size([1, 10]) 45 | 46 | 47 | def test_vgg13_single_channel_ouput_shape(): 48 | model = VGG13(num_channels=1) 49 | model.zero_grad() 50 | out = model(torch.randn(1, 1, 224, 224)) 51 | assert out.size() == torch.Size([1, 10]) 52 | 53 | 54 | def test_vgg13_three_channel_ouput_shape(): 55 | model = VGG13(num_channels=3) 56 | model.zero_grad() 57 | out = model(torch.randn(1, 3, 224, 224)) 58 | assert out.size() == torch.Size([1, 10]) 59 | 60 | 61 | def test_vgg13_bn_single_channel_ouput_shape(): 62 | model = VGG13_BN(num_channels=1) 63 | model.zero_grad() 64 | out = model(torch.randn(1, 1, 224, 224)) 65 | assert out.size() == torch.Size([1, 10]) 66 | 67 | 68 | def test_vgg13_bn_three_channel_ouput_shape(): 69 | model = VGG13_BN(num_channels=3) 70 | model.zero_grad() 71 | out = model(torch.randn(1, 3, 224, 224)) 72 | assert out.size() == torch.Size([1, 10]) 73 | 74 | 75 | def test_vgg16_single_channel_ouput_shape(): 76 | model = VGG16(num_channels=1) 77 | model.zero_grad() 78 | out = model(torch.randn(1, 1, 224, 224)) 79 | assert out.size() == torch.Size([1, 10]) 80 | 81 | 82 | def test_vgg16_three_channel_ouput_shape(): 83 | model = VGG16(num_channels=3) 84 | model.zero_grad() 85 | out = model(torch.randn(1, 3, 224, 224)) 86 | assert out.size() == torch.Size([1, 10]) 87 | 88 | 89 | def test_vgg16_bn_single_channel_ouput_shape(): 90 | model = VGG16_BN(num_channels=1) 91 | model.zero_grad() 92 | out = model(torch.randn(1, 1, 224, 224)) 93 | assert out.size() == torch.Size([1, 10]) 94 | 95 | 96 | def test_vgg16_bn_three_channel_ouput_shape(): 97 | model = VGG16_BN(num_channels=3) 98 | model.zero_grad() 99 | out = model(torch.randn(1, 3, 224, 224)) 100 | assert out.size() == torch.Size([1, 10]) 101 | 102 | 103 | def test_vgg19_single_channel_ouput_shape(): 104 | model = VGG19(num_channels=1) 105 | model.zero_grad() 106 | out = model(torch.randn(1, 1, 224, 224)) 107 | assert out.size() == torch.Size([1, 10]) 108 | 109 | 110 | def test_vgg19_three_channel_ouput_shape(): 111 | model = VGG19(num_channels=3) 112 | model.zero_grad() 113 | out = model(torch.randn(1, 3, 224, 224)) 114 | assert out.size() == torch.Size([1, 10]) 115 | 116 | 117 | def 
test_vgg19_bn_single_channel_ouput_shape(): 118 | model = VGG19_BN(num_channels=1) 119 | model.zero_grad() 120 | out = model(torch.randn(1, 1, 224, 224)) 121 | assert out.size() == torch.Size([1, 10]) 122 | 123 | 124 | def test_vgg19_bn_three_channel_ouput_shape(): 125 | model = VGG19_BN(num_channels=3) 126 | model.zero_grad() 127 | out = model(torch.randn(1, 3, 224, 224)) 128 | assert out.size() == torch.Size([1, 10]) 129 | -------------------------------------------------------------------------------- /tests/test_torchfl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Test for `torchfl` package.""" 4 | 5 | import pytest 6 | 7 | 8 | @pytest.mark.skip() 9 | def test_skip(): 10 | raise Exception("This test should be skipped") 11 | -------------------------------------------------------------------------------- /torchfl/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.9" 6 | -------------------------------------------------------------------------------- /torchfl/compatibility.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # mypy: ignore-errors 3 | 4 | """Defines the constants to ensure the consistency and compatibility between the files.""" 5 | import enum 6 | import os 7 | from pathlib import Path 8 | from typing import Any 9 | 10 | from torch.nn import GELU, LeakyReLU, ReLU, Tanh 11 | from torch.optim import ( 12 | ASGD, 13 | LBFGS, 14 | SGD, 15 | Adadelta, 16 | Adagrad, 17 | Adam, 18 | Adamax, 19 | AdamW, 20 | NAdam, 21 | RAdam, 22 | RMSprop, 23 | Rprop, 24 | SparseAdam, 25 | ) 26 | 27 | TORCHFL_DIR: str = os.path.join(Path.home(), ".torchfl") 28 | OPTIMIZERS = [ 29 | "adadelta", 30 | "adagrad", 31 | "adam", 32 | "adamw", 33 | "sparseadam", 34 | "adamax", 35 | "asgd", 36 | "lbfgs", 37 | "nadam", 38 | "radam", 39 | "rmsprop", 40 | "rprop", 41 | "sgd", 42 | ] 43 | ACTIVATION_FUNCTIONS = ["tanh", "relu", "leakyrelu", "gelu"] 44 | 45 | 46 | class OPTIMIZERS_TYPE(enum.Enum): 47 | """Enum class for the supported optimizers.""" 48 | 49 | ADAM = "adam" 50 | ADAMW = "adamw" 51 | ADAMAX = "adamax" 52 | ADAGRAD = "adagrad" 53 | ADADALTA = "adadelta" 54 | ASGD = "asgd" 55 | LBFGS = "lbfgs" 56 | NADAM = "nadam" 57 | RADAM = "radam" 58 | RMSPROP = "rmsprop" 59 | RPROP = "rprop" 60 | SGD = "sgd" 61 | SPARSEADAM = "sparseadam" 62 | 63 | 64 | class ACTIVATION_FUNCTIONS_TYPE(enum.Enum): 65 | TANH = "tanh" 66 | RELU = "relu" 67 | LEAKYRELU = "leakyrelu" 68 | GELU = "gelu" 69 | 70 | 71 | # mappings 72 | OPTIMIZERS_BY_NAME: dict[str, Any] = { 73 | "adadelta": Adadelta, 74 | "adagrad": Adagrad, 75 | "adam": Adam, 76 | "adamw": AdamW, 77 | "sparseadam": SparseAdam, 78 | "adamax": Adamax, 79 | "asgd": ASGD, 80 | "lbfgs": LBFGS, 81 | "nadam": NAdam, 82 | "radam": RAdam, 83 | "rmsprop": RMSprop, 84 | "rprop": Rprop, 85 | "sgd": SGD, 86 | } 87 | ACTIVATION_FUNCTIONS_BY_NAME: dict[str, Any] = { 88 | "tanh": Tanh, 89 | "relu": ReLU, 90 | "leakyrelu": LeakyReLU, 91 | "gelu": GELU, 92 | } 93 | -------------------------------------------------------------------------------- /torchfl/config_resolver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the implementation of the ConfigResolver class.""" 4 | 5 | from 
argparse import Namespace 6 | 7 | 8 | class ConfigResolver: 9 | def __init__(self, config: Namespace) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - config (Namespace): Namespace object from argparse 14 | """ 15 | self.config = config 16 | -------------------------------------------------------------------------------- /torchfl/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for datamodules used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.9.0" 6 | __all__ = ["emnist", "cifar", "fashionmnist"] 7 | -------------------------------------------------------------------------------- /torchfl/datamodules/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Types used within the datamodules utilities.""" 4 | 5 | import enum 6 | 7 | from torchfl.datamodules.cifar import ( 8 | SUPPORTED_DATASETS_TYPE as CIFAR_DATASETS_TYPE, 9 | ) 10 | from torchfl.datamodules.cifar import CIFARDataModule 11 | from torchfl.datamodules.emnist import ( 12 | SUPPORTED_DATASETS_TYPE as EMNIST_DATASETS_TYPE, 13 | ) 14 | from torchfl.datamodules.emnist import EMNISTDataModule 15 | from torchfl.datamodules.fashionmnist import ( 16 | SUPPORTED_DATASETS_TYPE as FASHIONMNIST_DATASETS_TYPE, 17 | ) 18 | from torchfl.datamodules.fashionmnist import FashionMNISTDataModule 19 | from torchfl.utils import _get_enum_values 20 | 21 | EMNIST_DATASETS: list[str] = _get_enum_values(EMNIST_DATASETS_TYPE) 22 | CIFAR_DATASETS: list[str] = _get_enum_values(CIFAR_DATASETS_TYPE) 23 | FASHIONMNIST_DATASETS: list[str] = _get_enum_values(FASHIONMNIST_DATASETS_TYPE) 24 | 25 | DATASET_GROUPS_MAP: dict[str, list[str]] = { 26 | "emnist": EMNIST_DATASETS, 27 | "cifar": CIFAR_DATASETS, 28 | "fashionmnist": FASHIONMNIST_DATASETS, 29 | } 30 | 31 | 32 | class DatasetGroupsEnum(enum.Enum): 33 | EMNIST = EMNISTDataModule 34 | CIFAR = CIFARDataModule 35 | FASHIONMNIST = FashionMNISTDataModule 36 | 37 | 38 | DatasetGroupsType = EMNISTDataModule | CIFARDataModule | FashionMNISTDataModule 39 | -------------------------------------------------------------------------------- /torchfl/federated/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for federated learning used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.9.0" 6 | __all__ = ["entrypoint", "fl_params"] 7 | -------------------------------------------------------------------------------- /torchfl/federated/agents/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for federated learning agents used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.9.0" 6 | __all__ = ["v1"] 7 | -------------------------------------------------------------------------------- /torchfl/federated/agents/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Base Agent class used in FL.""" 4 | 5 | from abc import ABCMeta, abstractmethod 6 | from typing import Any 7 | 8 | import pytorch_lightning as pl 9 | from torch.utils.data import DataLoader 10 | 11 | from torchfl.federated.fl_params import FLParams 12 | 13 | 14 | class BaseAgent(metaclass=ABCMeta): 15 | 
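    # Concrete agents (e.g. V1Agent) must implement train(); assign_model() and
    # assign_data_shard() are the hooks for pushing updated global weights and a
    # local data shard onto the agent between federated rounds.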
"""BaseAgent class used in FL.""" 16 | 17 | def __init__( 18 | self, 19 | id: int, 20 | data_shard: DataLoader, 21 | model: Any, 22 | ) -> None: 23 | """Constructor.""" 24 | self.id: int = id 25 | self.data_shard: DataLoader = data_shard 26 | self.model: Any = model 27 | 28 | def assign_model(self, model: Any) -> None: 29 | """Assign a model to the agent.""" 30 | self.model.load_state_dict(model.state_dict()) 31 | 32 | def assign_data_shard(self, data_shard: DataLoader) -> None: 33 | """Assign a data shard to the agent.""" 34 | self.data_shard = data_shard 35 | 36 | @abstractmethod 37 | def train( 38 | self, 39 | trainer: pl.Trainer, 40 | fl_params: FLParams, 41 | ) -> None: 42 | """Train the agent.""" 43 | raise NotImplementedError 44 | -------------------------------------------------------------------------------- /torchfl/federated/agents/v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """V1 Agent class used in FL.""" 4 | 5 | from typing import Any 6 | 7 | import pytorch_lightning as pl 8 | from torch.utils.data import DataLoader, random_split 9 | 10 | from torchfl.federated.agents.base import BaseAgent 11 | from torchfl.federated.fl_params import FLParams 12 | 13 | pl.seed_everything(42) 14 | 15 | 16 | class V1Agent(BaseAgent): 17 | """V1Agent class used in FL.""" 18 | 19 | def __init__( 20 | self, id: int, data_shard: DataLoader, model: Any | None = None 21 | ) -> None: 22 | """Constructor.""" 23 | super().__init__(id, data_shard, model) 24 | 25 | def train( 26 | self, 27 | trainer: pl.Trainer, 28 | fl_params: FLParams, 29 | ) -> Any: 30 | """ 31 | Train the agent. 32 | 33 | Args: 34 | trainer (pl.Trainer): Trainer object used to train the model. 35 | fl_params (FLParams): FLParams object containing the FL parameters. 36 | """ 37 | if self.model is None: 38 | raise ValueError( 39 | f"Model is not assigned to the agent with id={self.id}." 
40 | ) 41 | 42 | train_data_shard_len = int( 43 | len(self.data_shard.dataset) # type:ignore 44 | * fl_params.local_train_split 45 | ) 46 | test_data_shard_len = ( 47 | len(self.data_shard.dataset) - train_data_shard_len # type:ignore 48 | ) 49 | train_data_shard, val_data_shard = random_split( 50 | self.data_shard.dataset, 51 | [train_data_shard_len, test_data_shard_len], 52 | ) 53 | train_dataloader = DataLoader( 54 | train_data_shard, 55 | batch_size=fl_params.local_train_batch_size, 56 | shuffle=True, 57 | ) 58 | val_dataloader = DataLoader( 59 | val_data_shard, 60 | batch_size=fl_params.local_test_batch_size, 61 | shuffle=False, 62 | ) 63 | 64 | trainer.fit(self.model, train_dataloader, val_dataloader) 65 | # test best model based on the validation and test set 66 | val_result: list[dict[str, float]] = trainer.test( 67 | self.model, dataloaders=val_dataloader, verbose=True 68 | ) 69 | test_result: list[dict[str, float]] = trainer.test( 70 | self.model, dataloaders=val_dataloader, verbose=True 71 | ) 72 | result: dict[str, float] = { # type:ignore 73 | "test_acc": test_result[0][ 74 | f"{fl_params.experiment_name}_test_acc" 75 | ], 76 | "val_acc": val_result[0][f"{fl_params.experiment_name}_test_acc"], 77 | } 78 | return self.model, result 79 | -------------------------------------------------------------------------------- /torchfl/federated/aggregators/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for federated learning aggregators used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.9.0" 6 | __all__ = ["fedavg"] 7 | -------------------------------------------------------------------------------- /torchfl/federated/aggregators/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Base Aggregator class used in FL.""" 4 | from abc import ABCMeta, abstractmethod 5 | from typing import Any 6 | 7 | 8 | class BaseAggregator(metaclass=ABCMeta): 9 | """BaseAggregator class used in FL.""" 10 | 11 | def __init__(self, all_agents: list[Any]) -> None: 12 | """Constructor.""" 13 | super().__init__() 14 | self.agents: list[Any] = all_agents 15 | 16 | @abstractmethod 17 | def aggregate( 18 | self, global_model: Any, agent_models_map: dict[int, Any] 19 | ) -> Any: 20 | """ 21 | Aggregate the weights of the agents. Compute the new global model using agent_models_map and update the models of all the agents. 22 | 23 | Args: 24 | global_model: global model 25 | agent_models_map: map of agent id to agent model 26 | 27 | Returns: 28 | new global model 29 | """ 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /torchfl/federated/aggregators/fedavg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """FedAvg Aggregator class used in FL.""" 4 | 5 | from collections import OrderedDict 6 | from typing import Any 7 | 8 | import torch 9 | 10 | from torchfl.federated.aggregators.base import BaseAggregator 11 | 12 | 13 | class FedAvgAggregator(BaseAggregator): 14 | """FedAvgAggregator class used in FL.""" 15 | 16 | def __init__(self, all_agents: list[Any]) -> None: 17 | """Constructor.""" 18 | super().__init__(all_agents) 19 | 20 | def aggregate( 21 | self, global_model: Any, agent_models_map: dict[int, Any] 22 | ) -> Any: 23 | """ 24 | Aggregate the weights of the agents. 
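        This is plain, unweighted FedAvg: every entry in agent_models_map contributes
        equally to the average, which implicitly assumes the sampled agents hold
        roughly equal-sized data shards.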
Compute the new global model using agent_models_map and update the models of all the agents. 25 | 26 | Args: 27 | global_model (Any): Global model used in the FL experiment. 28 | agent_models_map (Dict[int, Any]): map of agent id to agent model 29 | 30 | Returns: 31 | new global model 32 | """ 33 | w_avg: dict[Any, Any] = OrderedDict() 34 | for _, models in agent_models_map.items(): 35 | for key in global_model.state_dict().keys(): 36 | if key not in w_avg.keys(): 37 | w_avg[key] = models[key].clone() 38 | else: 39 | w_avg[key] += models[key].clone() 40 | for key in w_avg.keys(): 41 | w_avg[key] = torch.divide(w_avg[key], len(agent_models_map)) 42 | return w_avg 43 | -------------------------------------------------------------------------------- /torchfl/federated/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for federated learning samplers used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.9.0" 6 | __all__ = ["random"] 7 | -------------------------------------------------------------------------------- /torchfl/federated/samplers/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Base Sampler class used in FL.""" 4 | from abc import ABCMeta, abstractmethod 5 | from typing import Any 6 | 7 | 8 | class BaseSampler(metaclass=ABCMeta): 9 | """BaseSampler class used in FL.""" 10 | 11 | def __init__(self, all_agents: list[Any]) -> None: 12 | """Constructor.""" 13 | super().__init__() 14 | self.agents: list[Any] = all_agents 15 | 16 | @abstractmethod 17 | def sample(self, num: int) -> list[Any]: 18 | """ 19 | Sample agents. 20 | 21 | Args: 22 | num: number of agents to sample 23 | 24 | Returns: 25 | List of sampled agents 26 | """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /torchfl/federated/samplers/random.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Random Sampler class used in FL.""" 4 | 5 | import random 6 | from typing import Any 7 | 8 | from torchfl.federated.samplers.base import BaseSampler 9 | 10 | random.seed(42) 11 | 12 | 13 | class RandomSampler(BaseSampler): 14 | """RandomSampler class used in FL.""" 15 | 16 | def __init__(self, all_agents: list[Any]) -> None: 17 | """Constructor.""" 18 | super().__init__(all_agents=all_agents) 19 | 20 | def sample(self, num: int) -> list[Any]: 21 | """ 22 | Sample agents. 
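        Sampling uses random.sample, i.e. uniformly at random and without replacement,
        with the module-level seed fixed to 42 for reproducibility.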
23 | 24 | Args: 25 | num: number of agents to sample 26 | 27 | Returns: 28 | List of sampled agents 29 | """ 30 | return random.sample(self.agents, num) 31 | -------------------------------------------------------------------------------- /torchfl/federated/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Types used within the federated learning utilities.""" 4 | 5 | import enum 6 | 7 | from torchfl.federated.agents.base import BaseAgent 8 | from torchfl.federated.agents.v1 import V1Agent 9 | from torchfl.federated.aggregators.base import BaseAggregator 10 | from torchfl.federated.aggregators.fedavg import FedAvgAggregator 11 | from torchfl.federated.samplers.base import BaseSampler 12 | from torchfl.federated.samplers.random import RandomSampler 13 | 14 | 15 | # enums 16 | class AgentsEnum(enum.Enum): 17 | """Enum class for the supported agent types.""" 18 | 19 | BASE = BaseAgent 20 | V1 = V1Agent 21 | 22 | 23 | class AggregatorsEnum(enum.Enum): 24 | """Enum class for the supported aggregator types.""" 25 | 26 | BASE = BaseAggregator 27 | FEDAVG = FedAvgAggregator 28 | 29 | 30 | class SamplersEnum(enum.Enum): 31 | """Enum class for the supported sampler types.""" 32 | 33 | BASE = BaseSampler 34 | RANDOM = RandomSampler 35 | 36 | 37 | # type aliases 38 | AgentsType = BaseAgent | V1Agent 39 | AggregatorsType = BaseAggregator | FedAvgAggregator 40 | SamplersType = BaseSampler | RandomSampler 41 | 42 | AGENTS_TYPE = ["v1"] 43 | AGGREGATORS_TYPE = ["fedavg"] 44 | SAMPLERS_TYPE = ["random"] 45 | -------------------------------------------------------------------------------- /torchfl/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for models used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for core model implementations used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for CIFAR model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for CIFAR10 model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for CIFAR10 dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def 
__init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=3 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=10, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for CIFAR10 dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=3 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=10, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=3 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=10, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=3 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 
68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=10, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=3 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=10, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for CIFAR10 dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=3) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 14 | """ 15 | super().__init__( 16 | num_classes=10, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for FashionMNIST dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=3 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=10, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=3 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=10, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=3 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=10, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar10/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for CIFAR10 dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=3 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=10, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=3 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=10, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for CIFAR100 model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for CIFAR100 dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=3 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=100, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for CIFAR100 dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=3 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=100, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=3 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. 
Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=100, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=3 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=100, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=3 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=100, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for CIFAR100 dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=3) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 14 | """ 15 | super().__init__( 16 | num_classes=100, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for CIFAR100 dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=3 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False.
29 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=100, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=3 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=100, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=3 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=100, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/cifar/cifar100/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for CIFAR100 dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=3 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=100, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=3 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 3.
44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=100, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for core EMNIST model implementations used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (balanced) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for EMNIST (balanced) dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=47, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=47, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=47, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=47, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=47, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for EMNIST (balanced) dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
14 | """ 15 | super().__init__( 16 | num_classes=47, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for EMNIST (balanced) dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 16 | """ 17 | super().__init__( 18 | num_classes=47, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=47, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=47, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. 
Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=47, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/balanced/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=47, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=47, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (by class) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for EMNIST (by class) dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. 
Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=62, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for EMNIST (by class) dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=62, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=62, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=62, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. 
Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=62, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for EMNIST (by class) dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=62, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for EMNIST (by class) dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 16 | """ 17 | super().__init__( 18 | num_classes=62, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (by class) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1.
30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=62, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=62, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=62, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/byclass/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for EMNIST (by class) dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=62, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=62, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (by merge) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for EMNIST (by merge) dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=47, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for EMNIST (by merge) dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=47, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. 
Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=47, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=47, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=47, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for EMNIST (by merge) dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=47, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for EMNIST (by merge) dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 
16 | """ 17 | super().__init__( 18 | num_classes=47, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=47, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=47, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=47, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/bymerge/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for EMNIST (by merge) dataset. 
4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=47, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=47, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (digits) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for EMNIST (digits) dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=10, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for EMNIST (digits) dataset. 
4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=10, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=10, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=10, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=10, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for EMNIST (digits) dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=10, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for EMNIST (digits) dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 16 | """ 17 | super().__init__( 18 | num_classes=10, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=10, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 
48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=10, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=10, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/digits/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for EMNIST (digits) dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=10, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=10, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (letters) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for EMNIST (letters) dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=26, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for EMNIST (letters) dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=26, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. 
Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=26, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=26, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=26, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for EMNIST (letters) dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=26, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for EMNIST (letters) dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 
16 | """ 17 | super().__init__( 18 | num_classes=26, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=26, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=26, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=26, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/letters/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for EMNIST (letters) dataset. 
4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=26, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=26, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for EMNIST (MNIST) model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for MNIST dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=10, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for MNIST dataset. 
4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=10, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=10, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=10, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=10, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for MNIST dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=10, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for MNIST dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 16 | """ 17 | super().__init__( 18 | num_classes=10, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for EMNIST (balanced) dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=10, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 
48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=10, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=10, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the ShuffleNetv2 model implementations for MNIST dataset. 4 | 5 | Contains: 6 | - ShuffleNetv2_x0_5 7 | - ShuffleNetv2_x1_0 8 | - ShuffleNetv2_x1_5 9 | - ShuffleNetv2_x2_0 10 | """ 11 | 12 | from torchfl.models.sota.shufflenetv2 import ( 13 | ShuffleNetv2_x0_5 as BaseShuffleNetv2_x0_5, 14 | ) 15 | from torchfl.models.sota.shufflenetv2 import ( 16 | ShuffleNetv2_x1_0 as BaseShuffleNetv2_x1_0, 17 | ) 18 | from torchfl.models.sota.shufflenetv2 import ( 19 | ShuffleNetv2_x1_5 as BaseShuffleNetv2_x1_5, 20 | ) 21 | from torchfl.models.sota.shufflenetv2 import ( 22 | ShuffleNetv2_x2_0 as BaseShuffleNetv2_x2_0, 23 | ) 24 | 25 | 26 | class ShuffleNetv2_x0_5(BaseShuffleNetv2_x0_5): 27 | def __init__( 28 | self, pre_trained=True, feature_extract=False, num_channels=1 29 | ) -> None: 30 | """Constructor 31 | 32 | Args: 33 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 34 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 35 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 36 | """ 37 | super().__init__( 38 | pre_trained=pre_trained, 39 | feature_extract=feature_extract, 40 | num_classes=10, 41 | num_channels=num_channels, 42 | act_fn_name="relu", 43 | ) 44 | 45 | 46 | class ShuffleNetv2_x1_0(BaseShuffleNetv2_x1_0): 47 | def __init__( 48 | self, pre_trained=True, feature_extract=False, num_channels=1 49 | ) -> None: 50 | """Constructor 51 | 52 | Args: 53 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 54 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 55 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
56 | """ 57 | super().__init__( 58 | pre_trained=pre_trained, 59 | feature_extract=feature_extract, 60 | num_classes=10, 61 | num_channels=num_channels, 62 | act_fn_name="relu", 63 | ) 64 | 65 | 66 | class ShuffleNetv2_x1_5(BaseShuffleNetv2_x1_5): 67 | def __init__( 68 | self, pre_trained=False, feature_extract=False, num_channels=1 69 | ) -> None: 70 | """Constructor 71 | 72 | Args: 73 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to False. 74 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 75 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 76 | """ 77 | super().__init__( 78 | pre_trained=pre_trained, 79 | feature_extract=feature_extract, 80 | num_classes=10, 81 | num_channels=num_channels, 82 | act_fn_name="relu", 83 | ) 84 | 85 | 86 | class ShuffleNetv2_x2_0(BaseShuffleNetv2_x2_0): 87 | def __init__( 88 | self, pre_trained=False, feature_extract=False, num_channels=1 89 | ) -> None: 90 | """Constructor 91 | 92 | Args: 93 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to False. 94 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 95 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 96 | """ 97 | super().__init__( 98 | pre_trained=pre_trained, 99 | feature_extract=feature_extract, 100 | num_classes=10, 101 | num_channels=num_channels, 102 | act_fn_name="relu", 103 | ) 104 | -------------------------------------------------------------------------------- /torchfl/models/core/emnist/mnist/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for MNIST dataset. 4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=10, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 
44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=10, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for FashionMNIST model implementations provided by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the AlexNet model implementations for FashionMNIST dataset.""" 4 | 5 | from torchfl.models.sota.alexnet import AlexNet as BaseAlexNet 6 | 7 | 8 | class AlexNet(BaseAlexNet): 9 | def __init__( 10 | self, pre_trained=True, feature_extract=False, num_channels=1 11 | ) -> None: 12 | """Constructor 13 | 14 | Args: 15 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 16 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 17 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 18 | """ 19 | super().__init__( 20 | pre_trained=pre_trained, 21 | feature_extract=feature_extract, 22 | num_channels=num_channels, 23 | num_classes=10, 24 | act_fn_name="relu", 25 | ) 26 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the DenseNet model implementations for FashionMNIST dataset. 4 | 5 | Contains: 6 | - DenseNet121 7 | - DenseNet161 8 | - DenseNet169 9 | - DenseNet201 10 | """ 11 | 12 | from torchfl.models.sota.densenet import DenseNet121 as BaseDenseNet121 13 | from torchfl.models.sota.densenet import DenseNet161 as BaseDenseNet161 14 | from torchfl.models.sota.densenet import DenseNet169 as BaseDenseNet169 15 | from torchfl.models.sota.densenet import DenseNet201 as BaseDenseNet201 16 | 17 | 18 | class DenseNet121(BaseDenseNet121): 19 | def __init__( 20 | self, pre_trained=True, feature_extract=False, num_channels=1 21 | ) -> None: 22 | """Constructor 23 | 24 | Args: 25 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 26 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 27 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 28 | """ 29 | super().__init__( 30 | pre_trained=pre_trained, 31 | feature_extract=feature_extract, 32 | num_channels=num_channels, 33 | num_classes=10, 34 | act_fn_name="relu", 35 | ) 36 | 37 | 38 | class DenseNet161(BaseDenseNet161): 39 | def __init__( 40 | self, pre_trained=True, feature_extract=False, num_channels=1 41 | ) -> None: 42 | """Constructor 43 | 44 | Args: 45 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 46 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. 
Defaults to False. 47 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 48 | """ 49 | super().__init__( 50 | pre_trained=pre_trained, 51 | feature_extract=feature_extract, 52 | num_channels=num_channels, 53 | num_classes=10, 54 | act_fn_name="relu", 55 | ) 56 | 57 | 58 | class DenseNet169(BaseDenseNet169): 59 | def __init__( 60 | self, pre_trained=True, feature_extract=False, num_channels=1 61 | ) -> None: 62 | """Constructor 63 | 64 | Args: 65 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 66 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 67 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 68 | """ 69 | super().__init__( 70 | pre_trained=pre_trained, 71 | feature_extract=feature_extract, 72 | num_channels=num_channels, 73 | num_classes=10, 74 | act_fn_name="relu", 75 | ) 76 | 77 | 78 | class DenseNet201(BaseDenseNet201): 79 | def __init__( 80 | self, pre_trained=True, feature_extract=False, num_channels=1 81 | ) -> None: 82 | """Constructor 83 | 84 | Args: 85 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 86 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 87 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 88 | """ 89 | super().__init__( 90 | pre_trained=pre_trained, 91 | feature_extract=feature_extract, 92 | num_channels=num_channels, 93 | num_classes=10, 94 | act_fn_name="relu", 95 | ) 96 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the LeNet model implementations for FashionMNIST dataset.""" 4 | 5 | from torchfl.models.sota.lenet import LeNet as BaseLeNet 6 | 7 | 8 | class LeNet(BaseLeNet): 9 | def __init__(self, num_channels=1) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | """ 15 | super().__init__( 16 | num_classes=10, num_channels=num_channels, act_fn_name="tanh" 17 | ) 18 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MLP model implementations for FashionMNIST dataset.""" 4 | 5 | from torchfl.models.sota.mlp import MLP as BaseMLP 6 | 7 | 8 | class MLP(BaseMLP): 9 | def __init__(self, num_channels=1, img_w=28, img_h=28) -> None: 10 | """Constructor 11 | 12 | Args: 13 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 14 | - img_w (int, optional): Width of the input image. Defaults to 28. 15 | - img_h (int, optional): Height of the input image. Defaults to 28. 
16 | """ 17 | super().__init__( 18 | num_classes=10, 19 | num_channels=num_channels, 20 | img_w=img_w, 21 | img_h=img_h, 22 | hidden_dims=[256, 128], 23 | ) 24 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the MobileNet model implementations for FashionMNIST dataset. 4 | 5 | Contains: 6 | - MobileNetV2 7 | - MobileNetV3Small 8 | - MobileNetV3Large 9 | """ 10 | 11 | from torchfl.models.sota.mobilenet import MobileNetV2 as BaseMobileNetV2 12 | from torchfl.models.sota.mobilenet import ( 13 | MobileNetV3Large as BaseMobileNetV3Large, 14 | ) 15 | from torchfl.models.sota.mobilenet import ( 16 | MobileNetV3Small as BaseMobileNetV3Small, 17 | ) 18 | 19 | 20 | class MobileNetV2(BaseMobileNetV2): 21 | def __init__( 22 | self, pre_trained=True, feature_extract=False, num_channels=1 23 | ) -> None: 24 | """Constructor 25 | 26 | Args: 27 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 28 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 29 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 30 | """ 31 | super().__init__( 32 | pre_trained=pre_trained, 33 | feature_extract=feature_extract, 34 | num_classes=10, 35 | num_channels=num_channels, 36 | act_fn_name="relu", 37 | ) 38 | 39 | 40 | class MobileNetV3Small(BaseMobileNetV3Small): 41 | def __init__( 42 | self, pre_trained=True, feature_extract=False, num_channels=1 43 | ) -> None: 44 | """Constructor 45 | 46 | Args: 47 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 48 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 49 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 50 | """ 51 | super().__init__( 52 | pre_trained=pre_trained, 53 | feature_extract=feature_extract, 54 | num_classes=10, 55 | num_channels=num_channels, 56 | act_fn_name="relu", 57 | ) 58 | 59 | 60 | class MobileNetV3Large(BaseMobileNetV3Large): 61 | def __init__( 62 | self, pre_trained=True, feature_extract=False, num_channels=1 63 | ) -> None: 64 | """Constructor 65 | 66 | Args: 67 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 68 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 69 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 70 | """ 71 | super().__init__( 72 | pre_trained=pre_trained, 73 | feature_extract=feature_extract, 74 | num_classes=10, 75 | num_channels=num_channels, 76 | act_fn_name="relu", 77 | ) 78 | -------------------------------------------------------------------------------- /torchfl/models/core/fashionmnist/squeezenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Contains the SqueezeNet model implementations for FashionMNIST dataset. 
4 | 5 | Contains: 6 | - SqueezeNet1_0 7 | - SqueezeNet1_1 8 | """ 9 | 10 | from torchfl.models.sota.squeezenet import SqueezeNet1_0 as BaseSqueezeNet1_0 11 | from torchfl.models.sota.squeezenet import SqueezeNet1_1 as BaseSqueezeNet1_1 12 | 13 | 14 | class SqueezeNet1_0(BaseSqueezeNet1_0): 15 | def __init__( 16 | self, pre_trained=True, feature_extract=False, num_channels=1 17 | ) -> None: 18 | """Constructor 19 | 20 | Args: 21 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 22 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 23 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 24 | """ 25 | super().__init__( 26 | pre_trained=pre_trained, 27 | feature_extract=feature_extract, 28 | num_classes=10, 29 | num_channels=num_channels, 30 | act_fn_name="relu", 31 | ) 32 | 33 | 34 | class SqueezeNet1_1(BaseSqueezeNet1_1): 35 | def __init__( 36 | self, pre_trained=True, feature_extract=False, num_channels=1 37 | ) -> None: 38 | """Constructor 39 | 40 | Args: 41 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 42 | - feature_extract (bool, optional): Use transfer learning and only train the classifier. Otherwise, finetune the whole model. Defaults to False. 43 | - num_channels (int, optional): Number of incoming channels. Defaults to 1. 44 | """ 45 | super().__init__( 46 | pre_trained=pre_trained, 47 | feature_extract=feature_extract, 48 | num_classes=10, 49 | num_channels=num_channels, 50 | act_fn_name="relu", 51 | ) 52 | -------------------------------------------------------------------------------- /torchfl/models/sota/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for state-of-the-art model implementations pre-trained by torchvision and used by the core models provided by torchfl. 2 | 3 | Available modules: 4 | - alexnet 5 | - resnet 6 | - vgg 7 | - squeezenet 8 | - densenet 9 | - shufflenetv2 10 | - mobilenet 11 | """ 12 | 13 | __author__ = """Vivek Khimani""" 14 | __email__ = "vivekkhimani07@gmail.com" 15 | __version__ = "0.1.0" 16 | -------------------------------------------------------------------------------- /torchfl/models/sota/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # type: ignore 3 | 4 | """Implementation of the pre-trained AlexNet using PyTorch and torchvision.""" 5 | 6 | from types import SimpleNamespace 7 | 8 | import torch.nn as nn 9 | from torchvision import models 10 | 11 | from torchfl.compatibility import ACTIVATION_FUNCTIONS_BY_NAME 12 | 13 | 14 | class AlexNet(models.AlexNet): 15 | """AlexNet base definition""" 16 | 17 | def __init__( 18 | self, 19 | pre_trained=True, 20 | feature_extract=True, 21 | num_classes=10, 22 | num_channels=3, 23 | act_fn_name="relu", 24 | **kwargs 25 | ) -> None: 26 | """Constructor 27 | 28 | Args: 29 | - pre_trained (bool, optional): Use the model pre-trained on the ImageNet dataset. Defaults to True. 30 | - feature_extract (bool, optional): Only trains the sequential layers of the pre-trained model. If False, the entire model is finetuned. Defaults to True. 31 | - num_classes (int, optional): Number of classification outputs. Defaults to 10. 32 | - num_channels (int, optional): Number of incoming channels. Defaults to 3. 
33 | - act_fn_name (str, optional): Activation function to be used. Defaults to "relu". Accepted: ["tanh", "relu", "leakyrelu", "gelu"]. 34 | """ 35 | super().__init__() 36 | self.hparams = SimpleNamespace( 37 | model_name="alexnet", 38 | pre_trained=pre_trained, 39 | feature_extract=bool(pre_trained and feature_extract), 40 | finetune=bool(not feature_extract), 41 | quantized=False, 42 | num_classes=num_classes, 43 | num_channels=num_channels, 44 | act_fn_name=act_fn_name, 45 | act_fn=ACTIVATION_FUNCTIONS_BY_NAME[act_fn_name], 46 | ) 47 | if pre_trained: 48 | pretrained_model = models.alexnet(pretrained=True, progress=True) 49 | self.load_state_dict(pretrained_model.state_dict()) 50 | 51 | if feature_extract: 52 | for param in self.parameters(): 53 | param.requires_grad = False 54 | 55 | if num_channels != 3: 56 | self.features[0] = nn.Conv2d( 57 | num_channels, 64, kernel_size=11, stride=4, padding=2 58 | ) 59 | 60 | in_features = self.classifier[6].in_features 61 | self.classifier[6] = nn.Linear(in_features, self.hparams.num_classes) 62 | -------------------------------------------------------------------------------- /torchfl/models/sota/lenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # type: ignore 3 | 4 | """Implementation of the general LeNet architecture using PyTorch.""" 5 | 6 | from types import SimpleNamespace 7 | 8 | import torch.nn as nn 9 | 10 | from torchfl.compatibility import ACTIVATION_FUNCTIONS_BY_NAME 11 | from torchfl.models.sota.mlp import LinearBlock 12 | 13 | 14 | class LeNet(nn.Module): 15 | """LeNet base definition""" 16 | 17 | def __init__( 18 | self, num_classes=10, num_channels=1, act_fn_name="relu", **kwargs 19 | ) -> None: 20 | """Constructor 21 | 22 | Args: 23 | - num_classes (int, optional): Number of classification outputs. Defaults to 10. 24 | - num_channels (int, optional): Number of channels for the images in the dataset. Defaults to 1. 25 | - act_fn_name (str, optional): Activation function to be used. Defaults to "relu". Accepted: ["tanh", "relu", "leakyrelu", "gelu"].
26 | """ 27 | super().__init__() 28 | self.hparams = SimpleNamespace( 29 | model_name="lenet", 30 | num_classes=num_classes, 31 | num_channels=num_channels, 32 | act_fn_name=act_fn_name, 33 | act_fn=ACTIVATION_FUNCTIONS_BY_NAME[act_fn_name], 34 | pre_trained=False, 35 | feature_extract=False, 36 | finetune=False, 37 | quantized=False, 38 | ) 39 | self._create_network() 40 | 41 | def _create_network(self): 42 | self.input_net = nn.Sequential( 43 | nn.Conv2d( 44 | self.hparams.num_channels, 45 | 6, 46 | kernel_size=5, 47 | stride=1, 48 | padding=2, 49 | ), 50 | self.hparams.act_fn(), 51 | nn.AvgPool2d(kernel_size=2, stride=2), 52 | ) 53 | self.conv_net = nn.Sequential( 54 | nn.Conv2d(6, 16, kernel_size=5, stride=1), 55 | nn.AvgPool2d(kernel_size=2, stride=2), 56 | nn.Conv2d(16, 120, kernel_size=5, stride=1), 57 | self.hparams.act_fn(), 58 | nn.Flatten(start_dim=1), 59 | LinearBlock(300000, 84, self.hparams.act_fn, False), 60 | ) 61 | self.output_net = nn.Sequential( 62 | nn.Linear(84, self.hparams.num_classes) 63 | ) 64 | 65 | def forward(self, x): 66 | """Forward propagation 67 | 68 | Args: 69 | - x (torch.Tensor): Input Tensor 70 | 71 | Returns: 72 | - torch.Tensor: Returns the tensor after forward propagation 73 | """ 74 | return self.output_net(self.conv_net(self.input_net(x))) 75 | -------------------------------------------------------------------------------- /torchfl/models/wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | """Sub-package for wrapper model implementations used by torchfl.""" 2 | 3 | __author__ = """Vivek Khimani""" 4 | __email__ = "vivekkhimani07@gmail.com" 5 | __version__ = "0.1.0" 6 | -------------------------------------------------------------------------------- /torchfl/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Utility functions used within the torchfl package.""" 4 | 5 | from typing import Any 6 | 7 | 8 | def _get_enum_values(enum_class: Any) -> list[str]: 9 | """Return a list of values of an enum class.""" 10 | return [x.value for x in enum_class._member_map_.values()] 11 | --------------------------------------------------------------------------------