├── docs ├── static │ ├── .nojekyll │ └── img │ │ ├── logo.png │ │ ├── favicon.ico │ │ ├── logo_text.png │ │ ├── undraw_expanded_functionality.svg │ │ ├── undraw_fast_loading.svg │ │ └── undraw_compatibility.svg ├── babel.config.js ├── sidebars.js ├── .gitignore ├── blog │ └── 2020-09-15-helloworld.md ├── src │ ├── pages │ │ ├── styles.module.css │ │ └── index.js │ └── css │ │ └── custom.css ├── README.md ├── package.json ├── docusaurus.config.js └── docs │ ├── files.md │ └── gettingstarted.md ├── examples ├── sample_subset.json ├── regex_index.json ├── regex_dataset │ ├── a1.jpg │ ├── a2.jpg │ ├── a3.jpg │ ├── b1.jpg │ ├── b2.jpg │ └── b3.jpg ├── sample_dataset │ ├── image0.jpg │ ├── image1.jpg │ ├── image2.jpg │ ├── image3.jpg │ ├── image4.jpg │ ├── image5.jpg │ ├── image6.jpg │ └── image7.jpg ├── sample_index_unsupervised.json ├── sample_index.json ├── unsupervised.py └── supervised.py ├── .flake8 ├── setup.cfg ├── pyproject.toml ├── .pre-commit-config.yaml ├── Makefile ├── betterloader ├── standard_transforms.py ├── ImageFolderCustom.py ├── DatasetFolder.py ├── defaults.py └── __init__.py ├── .github ├── ISSUE_TEMPLATE │ ├── Feature_request.md │ └── Bug_report.md ├── PULL_REQUEST_TEMPLATE │ ├── Bugfix.md │ ├── Documentation.md │ └── Feature.md └── workflows │ ├── pr.yaml │ └── push.yaml ├── LICENSE.txt ├── requirements.txt ├── setup.py ├── CONTRIBUTING.md ├── .gitignore ├── README.md └── tests └── tests.py /docs/static/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/sample_subset.json: -------------------------------------------------------------------------------- 1 | ["image0.jpg","image3.jpg"] 2 | -------------------------------------------------------------------------------- /examples/regex_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "class1": "^a", 3 | "class2": "^b" 4 | } -------------------------------------------------------------------------------- /docs/static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/docs/static/img/logo.png -------------------------------------------------------------------------------- /docs/static/img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/docs/static/img/favicon.ico -------------------------------------------------------------------------------- /docs/static/img/logo_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/docs/static/img/logo_text.png -------------------------------------------------------------------------------- /examples/regex_dataset/a1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/a1.jpg -------------------------------------------------------------------------------- /examples/regex_dataset/a2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/a2.jpg -------------------------------------------------------------------------------- /examples/regex_dataset/a3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/a3.jpg -------------------------------------------------------------------------------- /examples/regex_dataset/b1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/b1.jpg -------------------------------------------------------------------------------- /examples/regex_dataset/b2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/b2.jpg -------------------------------------------------------------------------------- /examples/regex_dataset/b3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/regex_dataset/b3.jpg -------------------------------------------------------------------------------- /docs/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [require.resolve('@docusaurus/core/lib/babel/preset')], 3 | }; 4 | -------------------------------------------------------------------------------- /docs/sidebars.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | someSidebar: { 3 | BetterLoader: ['gettingstarted', 'files'], 4 | }, 5 | }; 6 | -------------------------------------------------------------------------------- /examples/sample_dataset/image0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image0.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image1.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image2.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image3.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image4.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image5.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image6.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image6.jpg -------------------------------------------------------------------------------- /examples/sample_dataset/image7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinItAI/BetterLoader/HEAD/examples/sample_dataset/image7.jpg -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | max-complexity = 18 4 | select = B,C,E,F,W,T4,B9 5 | ignore = E203, E266, E501, W503, F403, F401, E402 6 | -------------------------------------------------------------------------------- /examples/sample_index_unsupervised.json: -------------------------------------------------------------------------------- 1 | { 2 | "class0":["image0.jpg","image1.jpg","image2.jpg","image3.jpg", "image4.jpg","image5.jpg","image6.jpg","image7.jpg"] 3 | } -------------------------------------------------------------------------------- /examples/sample_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "class1":["image0.jpg","image1.jpg","image2.jpg","image3.jpg"], 3 | "class2":["image4.jpg","image5.jpg","image6.jpg","image7.jpg"] 4 | } -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file(s) in the wheel. 3 | # https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file 4 | license_files = LICENSE.txt 5 | description-file = README.md -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | | docs 16 | )/ 17 | ''' 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v2.3.0 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | /node_modules 3 | 4 | # Production 5 | /build 6 | 7 | # Generated files 8 | .docusaurus 9 | .cache-loader 10 | 11 | # Misc 12 | .DS_Store 13 | .env.local 14 | .env.development.local 15 | .env.test.local 16 | .env.production.local 17 | 18 | npm-debug.log* 19 | yarn-debug.log* 20 | yarn-error.log* 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | python -m pip install --upgrade pip 3 | pip install -r requirements.txt 4 | 5 | sdist: 6 | python setup.py sdist 7 | 8 | upload: 
9 | 	twine upload dist/*
10 | 
11 | sample:
12 | 	python examples/supervised.py
13 | 
14 | clean:
15 | 	rm -rf *.egg-info
16 | 	rm -rf dist/
17 | 
18 | test:
19 | 	python tests/tests.py
20 | 
21 | deploy: clean sdist upload
22 | 
--------------------------------------------------------------------------------
/betterloader/standard_transforms.py:
--------------------------------------------------------------------------------
 1 | """Transforms applied while sampling, e.g. for contrastive setups."""
 2 | import numpy as np
 3 | from torchvision import transforms
 4 | 
 5 | np.random.seed(1)
 6 | 
 7 | 
 8 | class TransformWhileSampling(object):
 9 |     """Apply the wrapped transform twice, yielding two views of one sample."""
10 | 
11 |     def __init__(self, transform):
12 |         self.transform = transform
13 | 
14 |     def __call__(self, sample):
15 |         x1 = self.transform(sample)
16 |         x2 = self.transform(sample)
17 |         return x1, x2
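18 | 
19 | 
20 | # Usage sketch (illustrative only, not a fixed part of the library API):
21 | # wrap any torchvision pipeline to get two independent augmentations of the
22 | # same image, as contrastive objectives such as SimCLR expect.
23 | #
24 | #   augment = transforms.Compose([
25 | #       transforms.RandomResizedCrop(150),
26 | #       transforms.RandomHorizontalFlip(),
27 | #       transforms.ToTensor(),
28 | #   ])
29 | #   two_views = TransformWhileSampling(augment)
30 | #   x1, x2 = two_views(pil_image)  # two different random views of one image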
--------------------------------------------------------------------------------
/docs/blog/2020-09-15-helloworld.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | slug: helloworld
 3 | title: Hello World
 4 | author: Raghav Mecheri
 5 | author_title: Engineering @ BinIt
 6 | author_url: https://github.com/raghavmecheri
 7 | author_image_url: https://avatars3.githubusercontent.com/u/37787004?s=460&u=d48281f3aee13537b28d33ebe23ddee5af16d65f&v=4
 8 | tags: [binit, docusaurus]
 9 | ---
10 | 
11 | Welcome to BetterLoader! After months of painstakingly maintaining numerous image directories for different PyTorch experiments, we decided that the PyTorch dataloader needed a facelift. BetterLoader is still most definitely a work in progress, but we'd love to hear what you think!
12 | 
13 | This blog is powered by Docusaurus :)
--------------------------------------------------------------------------------
/examples/unsupervised.py:
--------------------------------------------------------------------------------
 1 | from betterloader import UnsupervisedBetterLoader
 2 | from betterloader.defaults import collate_metadata
 3 | 
 4 | # Paths are relative to the repository root, matching supervised.py
 5 | index_json = "./examples/sample_index_unsupervised.json"
 6 | basepath = "./examples/sample_dataset/"
 7 | batch_size = 2
 8 | metadata = collate_metadata()
 9 | better_loader = UnsupervisedBetterLoader(
10 |     basepath=basepath,
11 |     base_experiment_details=["simclr", 1, (150, 150)],
12 |     index_json_path=index_json,
13 |     dataset_metadata=metadata,
14 | )
15 | dataloaders, sizes = better_loader.fetch_segmented_dataloaders(batch_size=batch_size)
16 | 
17 | # Each unsupervised batch carries two augmented views of every image
18 | for i, ((xp1, xp2), _) in enumerate(dataloaders["train"]):
19 |     print(i)
20 |     print(xp1.shape)
--------------------------------------------------------------------------------
/examples/supervised.py:
--------------------------------------------------------------------------------
 1 | """
 2 | A simple betterloader.BetterLoader usage example. We should really add to these.
 3 | Notes:
 4 | 1. This example uses the default parameters.
 5 | 2. The structure of the index file can be any JSON, as long as you change the other parameters to work with it.
 6 | """
 7 | 
 8 | from betterloader import BetterLoader
 9 | 
10 | index_json = "./examples/sample_index.json"
11 | basepath = "./examples/sample_dataset/"
12 | batch_size = 2
13 | 
14 | loader = BetterLoader(basepath=basepath, index_json_path=index_json)
15 | dataloaders, sizes = loader.fetch_segmented_dataloaders(
16 |     batch_size=batch_size, transform=None
17 | )
18 | 
19 | print("Dataloader sizes: {}".format(str(sizes)))
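20 | 
21 | # The returned dataloaders behave like ordinary torch DataLoaders, keyed by
22 | # split; with the defaults, supervised batches are (images, labels) pairs.
23 | # Sketch:
24 | #
25 | #   for images, labels in dataloaders["train"]:
26 | #       ...  # feed your model here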
--------------------------------------------------------------------------------
/docs/src/pages/styles.module.css:
--------------------------------------------------------------------------------
 1 | /* stylelint-disable docusaurus/copyright-header */
 2 | 
 3 | /**
 4 |  * CSS files with the .module.css suffix will be treated as CSS modules
 5 |  * and scoped locally.
 6 |  */
 7 | 
 8 | .heroBanner {
 9 |   padding: 4rem 0;
10 |   text-align: center;
11 |   position: relative;
12 |   overflow: hidden;
13 | }
14 | 
15 | @media screen and (max-width: 966px) {
16 |   .heroBanner {
17 |     padding: 2rem;
18 |   }
19 | }
20 | 
21 | .buttons {
22 |   display: flex;
23 |   align-items: center;
24 |   justify-content: center;
25 | }
26 | 
27 | .features {
28 |   display: flex;
29 |   align-items: center;
30 |   padding: 2rem 0;
31 |   width: 100%;
32 | }
33 | 
34 | .featureImage {
35 |   height: 200px;
36 |   width: 200px;
37 | }
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Feature_request.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: 🚀 Feature Request
 3 | about: I have a suggestion (and may want to implement it 🙂)!
 4 | 
 5 | ---
 6 | 
 7 | ## Feature Request
 8 | 
 9 | ### Description of Problem:
10 | ...what *problem* are you trying to solve that this project doesn't currently solve?
11 | 
12 | ...please resist the temptation to describe your request in terms of a solution. Job Story form ("When [triggering condition], I want to [motivation/goal], so I can [outcome].") can help ensure you're expressing a problem statement.
13 | 
14 | ### Potential Solutions:
15 | ...clearly and concisely describe what you want to happen. Add any considered drawbacks.
16 | 
17 | ...if you've considered alternatives, clearly and concisely describe those too.
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/Bug_report.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Bug Report
 3 | about: If something isn't working as expected 🤔.
 4 | 
 5 | ---
 6 | 
 7 | ## Bug Report
 8 | 
 9 | ### Steps to Reproduce:
10 | 1. Step 1 description
11 | 2. Step 2 description
12 | 3. Step 3 description
13 | 
14 | ### Expected Result:
15 | ...description of what you expected to see...
16 | 
17 | ### Actual Result:
18 | ...what actually happened, including full exceptions (please include the entire stack trace, including "caused by" entries), log entries, screenshots etc. where appropriate.
19 | 
20 | ### Environment:
21 | ...version and build of the project, OS and runtime versions, virtualised environment (if any), etc.
22 | 
23 | ### Additional Context:
24 | ...add any other context about the problem here. If applicable, add screenshots to help explain.
25 | 
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
 1 | # Website
 2 | 
 3 | This website is built using [Docusaurus 2](https://v2.docusaurus.io/), a modern static website generator.
 4 | 
 5 | ### Installation
 6 | 
 7 | ```
 8 | $ yarn
 9 | ```
10 | 
11 | ### Local Development
12 | 
13 | ```
14 | $ yarn start
15 | ```
16 | 
17 | This command starts a local development server and opens a browser window. Most changes are reflected live without having to restart the server.
18 | 
19 | ### Build
20 | 
21 | ```
22 | $ yarn build
23 | ```
24 | 
25 | This command generates static content into the `build` directory, which can be served by any static content hosting service.
26 | 
27 | ### Deployment
28 | 
29 | ```
30 | $ GIT_USER=<Your GitHub username> USE_SSH=true yarn deploy
31 | ```
32 | 
33 | If you are using GitHub pages for hosting, this command is a convenient way to build the website and push to the `gh-pages` branch.
--------------------------------------------------------------------------------
/docs/package.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "docs",
 3 |   "version": "0.0.0",
 4 |   "private": true,
 5 |   "scripts": {
 6 |     "docusaurus": "docusaurus",
 7 |     "start": "docusaurus start",
 8 |     "build": "docusaurus build",
 9 |     "swizzle": "docusaurus swizzle",
10 |     "deploy": "docusaurus deploy",
11 |     "serve": "docusaurus serve"
12 |   },
13 |   "dependencies": {
14 |     "@docusaurus/core": "2.0.0-alpha.64",
15 |     "@docusaurus/preset-classic": "2.0.0-alpha.64",
16 |     "@mdx-js/react": "^1.5.8",
17 |     "clsx": "^1.1.1",
18 |     "react": "^16.8.4",
19 |     "react-dom": "^16.8.4"
20 |   },
21 |   "browserslist": {
22 |     "production": [
23 |       ">0.2%",
24 |       "not dead",
25 |       "not op_mini all"
26 |     ],
27 |     "development": [
28 |       "last 1 chrome version",
29 |       "last 1 firefox version",
30 |       "last 1 safari version"
31 |     ]
32 |   }
33 | }
--------------------------------------------------------------------------------
/docs/src/css/custom.css:
--------------------------------------------------------------------------------
 1 | /* stylelint-disable docusaurus/copyright-header */
 2 | /**
 3 |  * Any CSS included here will be global. The classic template
 4 |  * bundles Infima by default. Infima is a CSS framework designed to
 5 |  * work well for content-centric websites.
 6 |  */
 7 | 
 8 | /* You can override the default Infima variables here. */
 9 | :root {
10 |   --ifm-color-primary: rgb(222, 29, 27);
11 |   --ifm-color-primary-dark: rgb(33, 175, 144);
12 |   --ifm-color-primary-darker: rgb(31, 165, 136);
13 |   --ifm-color-primary-darkest: rgb(26, 136, 112);
14 |   --ifm-color-primary-light: rgb(70, 203, 174);
15 |   --ifm-color-primary-lighter: rgb(102, 212, 189);
16 |   --ifm-color-primary-lightest: rgb(146, 224, 208);
17 |   --ifm-code-font-size: 95%;
18 | }
19 | 
20 | .docusaurus-highlight-code-line {
21 |   background-color: rgb(72, 77, 91);
22 |   display: block;
23 |   margin: 0 calc(-1 * var(--ifm-pre-padding));
24 |   padding: 0 var(--ifm-pre-padding);
25 | }
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/Bugfix.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: 🐛 Bug Fix
 3 | about: Fixes a bug reported in our issues list 📕
 4 | 
 5 | ---
 6 | 
 7 | # Bugfix
 8 | 
 9 | ## Testing
10 | 
11 | ## Screenshots (if appropriate)
12 | 
13 | ## Checklist
14 | 
15 | - [ ] I have read the [`CONTRIBUTING.md`](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md) and followed its [Guidelines](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md#guidelines)
16 | - [ ] I have linted my code locally, following the project's code style
17 | - [ ] I have tested my changes locally.
--------------------------------------------------------------------------------
/.github/workflows/pr.yaml:
--------------------------------------------------------------------------------
 1 | name: Pull Build
 2 | 
 3 | on:
 4 |   pull_request:
 5 |     branches:
 6 |       - master
 7 | 
 8 | jobs:
 9 |   Test:
10 |     runs-on: ubuntu-latest
11 | 
12 |     strategy:
13 |       matrix:
14 |         node-version: [10.x, 12.x]
15 |         python-version: [3.6, 3.7, 3.8]
16 | 
17 |     steps:
18 |       - name: Git checkout
19 |         uses: actions/checkout@v2
20 | 
21 |       - name: Use Node.js ${{ matrix.node-version }}
22 |         uses: actions/setup-node@v1
23 |         with:
24 |           node-version: ${{ matrix.node-version }}
25 | 
26 |       - uses: borales/actions-yarn@v2.0.0
27 | 
28 |       - name: Use Python ${{ matrix.python-version }}
29 |         uses: actions/setup-python@v2
30 |         with:
31 |           python-version: ${{ matrix.python-version }}
32 | 
33 |       - name: Install Python Dependencies
34 |         run: make install
35 | 
36 |       - name: Run Python test suite
37 |         run: |
38 |           export PYTHONPATH=$PYTHONPATH:$(pwd)
39 |           make test
40 | 
41 |       - name: Install documentation dependencies & build docs
42 |         run: |
43 |           cd docs
44 |           yarn
45 |           yarn build
46 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | -----------
 3 | 
 4 | Copyright (c) 2020 BinIt Inc (https://binit.in)
 5 | Permission is hereby granted, free of charge, to any person
 6 | obtaining a copy of this software and associated documentation
 7 | files (the "Software"), to deal in the Software without
 8 | restriction, including without limitation the rights to use,
 9 | copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the
11 | Software is furnished to do so, subject to the following
12 | conditions:
13 | 
14 | The above copyright notice and this permission notice shall be
15 | included in all copies or substantial portions of the Software.
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 19 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 21 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 22 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 23 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 24 | OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/Documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 📄 Documentation 3 | about: Improves documentation for the API or on the docs website. 4 | 5 | --- 6 | 7 | # Documentation 8 | 9 | 10 | 11 | ## Validation 12 | 13 | 18 | 19 | - [ ] I have linted, spell-checked, and grammar-checked my documentation additions. 20 | - [ ] Within `/docs`, I have validated the changes to the documentation after running `yarn start` within the directory 21 | - [ ] My additions are styled and structured correctly on documentation site 22 | 23 | ## Checklist 24 | 25 | 26 | 27 | - [ ] I have read the [`CONTRIBUTING.md`](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md) and followed its [Guidelines](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md#guidelines) 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | appdirs==1.4.4 3 | astroid==2.4.2 4 | autopep8==1.5.4 5 | Babel==2.8.0 6 | bleach==3.1.5 7 | certifi==2020.6.20 8 | cfgv==3.2.0 9 | chardet==3.0.4 10 | colorama==0.4.3 11 | distlib==0.3.1 12 | docutils==0.16 13 | filelock==3.0.12 14 | future==0.18.2 15 | identify==1.5.2 16 | idna==2.10 17 | imagesize==1.2.0 18 | importlib-metadata==1.7.0 19 | isort==5.5.2 20 | Jinja2==2.11.2 21 | keyring==21.3.0 22 | lazy-object-proxy==1.4.3 23 | MarkupSafe==1.1.1 24 | mccabe==0.6.1 25 | nodeenv==1.5.0 26 | numpy==1.19.1 27 | packaging==20.4 28 | Pillow==7.2.0 29 | pkginfo==1.5.0.1 30 | pre-commit==2.9.2 31 | pycodestyle==2.6.0 32 | Pygments==2.6.1 33 | pylint==2.6.0 34 | pyparsing==2.4.7 35 | pytz==2020.1 36 | PyYAML==5.3.1 37 | readme-renderer==26.0 38 | requests==2.24.0 39 | requests-toolbelt==0.9.1 40 | rfc3986==1.4.0 41 | six==1.15.0 42 | snowballstemmer==2.0.0 43 | Sphinx==3.2.0 44 | sphinxcontrib-applehelp==1.0.2 45 | sphinxcontrib-devhelp==1.0.2 46 | sphinxcontrib-htmlhelp==1.0.3 47 | sphinxcontrib-jsmath==1.0.1 48 | sphinxcontrib-qthelp==1.0.3 49 | sphinxcontrib-serializinghtml==1.1.4 50 | toml==0.10.1 51 | torch==1.7.1 52 | torchvision==0.8.2 53 | tqdm==4.48.2 54 | twine==3.2.0 55 | typed-ast==1.4.1 56 | urllib3==1.25.10 57 | virtualenv==20.0.31 58 | webencodings==0.5.1 59 | wrapt==1.12.1 60 | zipp==3.1.0 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """setup file for BetterLoader deployments 2 | """ 3 | 4 | from os import path 5 | from setuptools import setup 6 | 7 | this_directory = path.abspath(path.dirname(__file__)) 8 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: 9 | long_description = f.read() 10 | 11 | setup( 12 | name="BetterLoader", 13 | version="0.2.2", 14 | description="A better 
PyTorch dataloader",
15 |     long_description=long_description,
16 |     long_description_content_type="text/markdown",
17 |     url="https://github.com/BinItAI/BetterLoader",
18 |     author="BinIt Inc",
19 |     author_email="support@binit.in",
20 |     license="MIT",
21 |     # NOTE: keep this archive tag in sync with the version above
22 |     download_url="https://github.com/BinItAI/BetterLoader/archive/0.2.2.zip",
23 |     packages=["betterloader"],
24 |     install_requires=[
25 |         "future==0.18.2",
26 |         "numpy==1.19.1",
27 |         "Pillow==7.2.0",
28 |         "torch==1.7.1",
29 |         "torchvision==0.8.2",
30 |     ],
31 |     classifiers=[
32 |         "Development Status :: 3 - Alpha",
33 |         "Intended Audience :: Science/Research",
34 |         "License :: OSI Approved :: MIT License",
35 |         "Operating System :: POSIX :: Linux",
36 |         "Operating System :: MacOS :: MacOS X",
37 |         "Operating System :: Microsoft :: Windows :: Windows 10",
38 |         "Programming Language :: Python :: 3.7",
39 |     ],
40 | )
--------------------------------------------------------------------------------
/.github/workflows/push.yaml:
--------------------------------------------------------------------------------
 1 | name: Build
 2 | on:
 3 |   push:
 4 |     branches:
 5 |       - master
 6 | 
 7 | jobs:
 8 |   Test:
 9 |     runs-on: ubuntu-latest
10 | 
11 |     strategy:
12 |       matrix:
13 |         node-version: [10.x, 12.x]
14 |         python-version: [3.6, 3.7, 3.8]
15 | 
16 |     steps:
17 |       - name: Git checkout
18 |         uses: actions/checkout@v2
19 | 
20 |       - name: Use Node.js ${{ matrix.node-version }}
21 |         uses: actions/setup-node@v1
22 |         with:
23 |           node-version: ${{ matrix.node-version }}
24 | 
25 |       - uses: borales/actions-yarn@v2.0.0
26 | 
27 |       - name: Use Python ${{ matrix.python-version }}
28 |         uses: actions/setup-python@v2
29 |         with:
30 |           python-version: ${{ matrix.python-version }}
31 | 
32 |       - name: Install Python Dependencies
33 |         run: make install
34 | 
35 |       - name: Run Python test suite
36 |         run: |
37 |           export PYTHONPATH=$PYTHONPATH:$(pwd)
38 |           make test
39 | 
40 |       - name: Install documentation dependencies & build docs
41 |         run: |
42 |           cd docs
43 |           yarn
44 |           yarn build
45 | 
46 |   Deploy:
47 |     runs-on: ubuntu-latest
48 | 
49 |     steps:
50 |       - name: Git checkout
51 |         uses: actions/checkout@v2
52 | 
53 |       - name: Use Node 10
54 |         uses: actions/setup-node@v1
55 | 
56 |       - uses: borales/actions-yarn@v2.0.0
57 | 
58 |       - name: Build docs and deploy
59 |         run: |
60 |           git config --global user.name "${{ secrets.GH_NAME }}"
61 |           git config --global user.email "${{ secrets.GH_EMAIL }}"
62 |           echo "machine github.com login ${{ secrets.GH_NAME }} password ${{ secrets.GH_TOKEN }}" > ~/.netrc
63 |           cd docs && yarn
64 |           yarn deploy
65 |         env:
66 |           GIT_USER: ${{ secrets.GH_NAME }}
67 | 
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/Feature.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: New Feature
 3 | about: Adds a new feature to BetterLoader 🙂!
4 | 5 | --- 6 | 7 | # Feature 8 | 9 | 18 | 19 | 22 | 23 | - [ ] Breaking feature 24 | - [ ] Non-breaking feature 25 | 26 | 27 | ## Changelog 28 | 29 | - {A bullet-pointed changelog that outlines the contents of your PR.} 30 | 31 | ## Testing 32 | 33 | 42 | 43 | ## Screenshots (if appropriate) 44 | 45 | 47 | 48 | ## Checklist 49 | 50 | 51 | 52 | - [ ] I have read the [`CONTRIBUTING.md`](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md) and followed its [Guidelines](https://github.com/BinItAI/BetterLoader/blob/master/CONTRIBUTING.md#guidelines) 53 | - [ ] I have linted my code locally, following the project's code style 54 | - [ ] I have tested my changes locally. 55 | -------------------------------------------------------------------------------- /docs/docusaurus.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | title: 'BetterLoader', 3 | tagline: 'The augmented PyTorch dataloader', 4 | url: 'https://binitai.github.io', 5 | baseUrl: '/BetterLoader/', 6 | onBrokenLinks: 'throw', 7 | favicon: 'img/favicon.ico', 8 | organizationName: 'BinItAI', 9 | projectName: 'BetterLoader', 10 | themeConfig: { 11 | navbar: { 12 | title: 'BetterLoader', 13 | logo: { 14 | alt: 'My Site Logo', 15 | src: 'img/logo.png', 16 | }, 17 | items: [ 18 | { 19 | to: 'docs/', 20 | activeBasePath: 'docs', 21 | label: 'Docs', 22 | position: 'left', 23 | }, 24 | {to: 'blog', label: 'Blog', position: 'left'}, 25 | { 26 | href: 'https://github.com/binitai/BetterLoader', 27 | label: 'GitHub', 28 | position: 'right', 29 | }, 30 | ], 31 | }, 32 | footer: { 33 | style: 'dark', 34 | links: [ 35 | { 36 | title: 'Docs', 37 | items: [ 38 | { 39 | label: 'Getting Started', 40 | to: 'docs/', 41 | } 42 | ], 43 | }, 44 | { 45 | title: 'Community', 46 | items: [ 47 | { 48 | label: 'Stack Overflow', 49 | href: 'https://stackoverflow.com/questions/tagged/BetterLoader', 50 | }, 51 | { 52 | label: 'Discord', 53 | href: 'https://discord.gg/T4Hxcq6', 54 | } 55 | ], 56 | } 57 | ], 58 | copyright: `Copyright © ${new Date().getFullYear()} BinIt, Inc. Built with Docusaurus.`, 59 | }, 60 | }, 61 | presets: [ 62 | [ 63 | '@docusaurus/preset-classic', 64 | { 65 | docs: { 66 | sidebarPath: require.resolve('./sidebars.js'), 67 | // Please change this to your repo. 68 | editUrl: 69 | 'https://github.com/binitai/BetterLoader', 70 | }, 71 | theme: { 72 | customCss: require.resolve('./src/css/custom.css'), 73 | }, 74 | }, 75 | ], 76 | ], 77 | }; 78 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to BetterLoader! 4 | 5 | BetterLoader is built on open source and is maintained by the team at [BinIt](https://github.com/binitai/). We invite you to participate in our community by adding and commenting on [issues](https://github.com/BinItAI/BetterLoader/issues) (e.g., bug reports; new feature suggestions) or contributing code enhancements through a pull request. 6 | 7 | If you have any general questions about contributing to BetterLoader, please feel free to email either [Raghav](mailto:raghav.mecheri@columbia.edu) or [James](mailto:jbb2170@columbia.edu), or just open an issue on [Github](https://github.com/BinItAI/BetterLoader/issues/new). 
8 | ## Guidelines 9 | 10 | When submitting PRs to BetterLoader, please respect the following general 11 | coding guidelines: 12 | 13 | * All PRs should be accompanied by an appropriate label as per [lerna-changelog](https://github.com/lerna/lerna-changelog), and reference any issue they resolve. 14 | * Please try to keep PRs small and focused. If you find your PR touches multiple loosely related changes, it may be best to break up into multiple PRs. 15 | * Individual commits should preferably do One Thing (tm), and have descriptive commit messages. Do not make "WIP" or other mystery commit messages. 16 | * ... that being said, one-liners or other commits should typically be grouped. Please try to keep 'cleanup', 'formatting' or other non-functional changes to a single commit at most in your PR. 17 | * PRs that involve moving files around the repository tree should be organized in a stand-alone commit from actual code changes. 18 | * Please do not submit incomplete PRs or partially implemented features. Feature additions should be implemented completely. 19 | * Please do not submit PRs disabled by feature or build flag - experimental features should be kept on a branch until they are ready to be merged. 20 | * For feature additions, make sure you have added complete docstrings to any new APIs, as well as additions to the [Usage Guide]() if applicable. 21 | * All PRs should be accompanied by tests asserting their behavior in any packages they modify. 22 | * Do not commit with `--no-verify` or otherwise bypass commit hooks, and please respect the formatting and linting guidelines they enforce. 23 | * Do not `merge master` upstream changes into your PR. If your change has conflicts with the `master` branch, please pull master into your fork's master, then rebase. 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Ignore vscode files 141 | .vscode/ 142 | 143 | # Ignore website build directory 144 | website/ 145 | -------------------------------------------------------------------------------- /docs/src/pages/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import clsx from 'clsx'; 3 | import Layout from '@theme/Layout'; 4 | import Link from '@docusaurus/Link'; 5 | import useDocusaurusContext from '@docusaurus/useDocusaurusContext'; 6 | import useBaseUrl from '@docusaurus/useBaseUrl'; 7 | import styles from './styles.module.css'; 8 | 9 | const features = [ 10 | { 11 | title: 'Easy to Use', 12 | imageUrl: 'img/undraw_fast_loading.svg', 13 | description: ( 14 | <> 15 | BetterLoader was designed from the ground up to be easily installed and 16 | used to get your deep learning processes up and running quickly. 17 | 18 | ), 19 | }, 20 | { 21 | title: 'Expanded Functionality', 22 | imageUrl: 'img/undraw_expanded_functionality.svg', 23 | description: ( 24 | <> 25 | BetterLoader lets you focus on your deep learning models, while we handle conditional data loading, or multiple copies of files. 26 | 27 | ), 28 | }, 29 | { 30 | title: 'Compatible with PyTorch', 31 | imageUrl: 'img/undraw_compatibility.svg', 32 | description: ( 33 | <> 34 | BetterLoader just extends PyTorch classes. 
Plug BetterLoader calls right into your workflow, with minimal changes.
35 |       </>
36 |     ),
37 |   },
38 | ];
39 | 
40 | function Feature({imageUrl, title, description}) {
41 |   const imgUrl = useBaseUrl(imageUrl);
42 |   return (
43 |     <div className={clsx('col col--4', styles.feature)}>
44 |       {imgUrl && (
45 |         <div className="text--center">
46 |           <img className={styles.featureImage} src={imgUrl} alt={title} />
47 |         </div>
48 |       )}
49 |       <h3>{title}</h3>
50 |       <p>{description}</p>
51 |     </div>
52 |   );
53 | }
54 | 
55 | function Home() {
56 |   const context = useDocusaurusContext();
57 |   const {siteConfig = {}} = context;
58 |   return (
59 |     <Layout
60 |       title={`Hello from ${siteConfig.title}`}
61 |       description="Description will go into a meta tag in <head />">
62 |       <header className={clsx('hero hero--primary', styles.heroBanner)}>
63 |         <div className="container">
64 |           <h1 className="hero__title">{siteConfig.title}</h1>
65 |           <p className="hero__subtitle">{siteConfig.tagline}</p>
66 |           <div className={styles.buttons}>
67 |             <Link
68 |               className={clsx(
69 |                 'button button--outline button--secondary button--lg',
70 |                 styles.getStarted,
71 |               )}
72 |               to={useBaseUrl('docs/')}>
73 |               Get Started
74 |             </Link>
75 |           </div>
76 |         </div>
77 |       </header>
78 |       <main>
79 |         {features && features.length > 0 && (
80 |           <section className={styles.features}>
81 |             <div className="container">
82 |               <div className="row">
83 |                 {features.map((props, idx) => (
84 |                   <Feature key={idx} {...props} />
85 |                 ))}
86 |               </div>
87 |             </div>
88 |           </section>
89 |         )}
90 |       </main>
91 |     </Layout>
92 |   );
93 | }
94 | 
95 | export default Home;
96 | 
--------------------------------------------------------------------------------
/docs/docs/files.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | id: files
 3 | title: Index & Subset Files
 4 | sidebar_label: Index & Subset Files
 5 | slug: /files
 6 | ---
 7 | 
 8 | ## Overview
 9 | BetterLoader uses two kinds of files to drive its more interesting behaviour: index files and subset files.
10 | Index files let you specify labelled groupings for your image dataset, so your actual data can live in a single flat folder. Subset files, on the other hand, specify a list of image paths to load, which are then labelled via the index file. Together, these let you load subsets of your dataset and run multiple experiments, all with minimal file management.
11 | 
12 | ### Index Files
13 | Index JSON files are used, by default, to create a mapping from label to filenames: they are expected to be formatted as key-value pairs, where the values are lists of filenames. However, this format can be overridden by passing a value to the `train_test_val_instances` key of the `dataset_metadata` parameter of the BetterLoader constructor. Since the format is so flexible, there is a lot you can do; for example, index files can use regex, as long as the `train_test_val_instances` function is set up to parse the regex correctly. Nothing about the index file is hardcoded, except that it has to be valid JSON. A sample index file would look something like:
14 | 
15 | ```json
16 | {
17 |     "class1":["image0.jpg","image1.jpg","image2.jpg","image3.jpg"],
18 |     "class2":["image4.jpg","image5.jpg","image6.jpg","image7.jpg"]
19 | }
20 | ```
21 | 
22 | An index file can also look like the below. The only catch here is that, if you do this, you will have to pass a custom function as the `train_test_val_instances` key of the `dataset_metadata` parameter of the BetterLoader constructor. An example of such a function to group filenames based on regex can be found here; a sketch follows below.
23 | ```json
24 | {
25 |     "class1": "^a",
26 |     "class2": "^b"
27 | }
28 | ```
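29 | 
30 | As a rough sketch (mirroring the contract documented in `betterloader/DatasetFolder.py`, not the library's exact internals), a custom `train_test_val_instances` function receives the dataset root, a class-to-index mapping, the loaded index, and a file validator, and returns three lists of `(path, class_index, ...)` tuples:
31 | 
32 | ```python
33 | import os
34 | import re
35 | 
36 | 
37 | def regex_instances(directory, class_to_idx, index, is_valid_file):
38 |     """Group files whose names match each class's regex pattern."""
39 |     instances = []
40 |     for cls, pattern in index.items():
41 |         for fname in sorted(os.listdir(directory)):
42 |             path = os.path.join(directory, fname)
43 |             if re.match(pattern, fname) and is_valid_file(path):
44 |                 instances.append((path, class_to_idx[cls]))
45 |     # Naive alternating split, purely for illustration
46 |     return instances[0::3], instances[1::3], instances[2::3]
47 | ```
48 | 
49 | A ready-made regex variant ships with the library as `regex_metadata()` in `betterloader.defaults`; it is exercised by `tests/tests.py`.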
50 | 
51 | ### Subset Files
52 | As their name suggests, subset JSON files are used to instruct the BetterLoader to limit itself to only a subset of the dataset present at the root of the directory being loaded from. Currently, subset files just consist of a list of allowed files (as we've been auto-generating them as part of our workflow), but this is definitely something we would be open to refining as well. A sample subset file would look something like this:
53 | ```json
54 | ["image0.jpg","image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"]
55 | ```
56 | 
57 | ## Usage
58 | An index is required to use the BetterLoader: either a path to an index JSON must be supplied, or an index object. The index object may be any Python object, and it replaces what would be the result of loading the index JSON. An index is required because it replaces the traditional approach of the PyTorch dataloader, which uses folder names to infer class labels. Since we've done away with this mechanism entirely, an index is essential to loading data for supervised learning tasks.
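59 | 
60 | For example (mirroring `test_regex_metadata` in `tests/tests.py`), the regex-style index above can be wired in via the bundled defaults:
61 | 
62 | ```python
63 | from betterloader import BetterLoader
64 | from betterloader.defaults import regex_metadata
65 | 
66 | loader = BetterLoader(
67 |     basepath="./examples/regex_dataset/",
68 |     index_json_path="./examples/regex_index.json",
69 |     dataset_metadata=regex_metadata(),
70 | )
71 | dataloaders, sizes = loader.fetch_segmented_dataloaders(batch_size=2, transform=None)
72 | ```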
73 | 
74 | Subset files are an optional parameter. If a subset file is not specified, then the BetterLoader will just load your entire dataset :)
75 | You may also use a subset object, which is entirely analogous to the way index objects work.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # BetterLoader
 2 | 
 3 | **Making it harder to do easy things, but easier to do harder things with the PyTorch dataloader**
 4 | 
 5 | ---
 6 | 
 7 | About • Installation • Usage • Development • License
 8 | 
 9 | [![PyPi Badge](https://img.shields.io/pypi/dm/betterloader?style=for-the-badge)](https://pypi.org/project/BetterLoader/)
10 | [![PyPi Version](https://img.shields.io/pypi/v/betterloader?style=for-the-badge)](https://pypi.org/project/BetterLoader/)
11 | [![Github Actions Build Status](https://img.shields.io/github/workflow/status/BinItAI/BetterLoader/Build?style=for-the-badge)](https://img.shields.io/github/workflow/status/BinItAI/BetterLoader/Build?style=for-the-badge)
12 | [![Issues](https://img.shields.io/github/issues/binitai/betterloader?style=for-the-badge)](https://github.com/BinItAI/BetterLoader/issues)
13 | [![license](https://img.shields.io/github/license/binitai/betterloader?style=for-the-badge)](https://github.com/BinItAI/BetterLoader/blob/master/LICENSE.txt)
14 | 
15 | ---
16 | 
17 | ## About BetterLoader
18 | BetterLoader is a hyper-customizable extension of the default PyTorch dataloader class that allows for custom pre-load transformations and image subset definitions. Use the power of custom index files to maintain only a single copy of a dataset with a fixed, flat file structure, and let BetterLoader do all the heavy lifting.
19 | 
20 | ## Installation
21 | ```sh
22 | pip install betterloader
23 | ```
24 | 
25 | ## Usage
26 | BetterLoader allows you to dynamically assign images to labels, load subsets of images conditionally, perform custom pretransforms before loading an image, and much more.
27 | 
28 | ### Basic Usage
29 | A few points worth noting:
30 | - BetterLoader does not expect a nested folder structure. In its current iteration, files are expected to all be present in the root directory.
31 | - Every instance of BetterLoader requires an index file to function. Sample index files may be found here.
32 | 
33 | ```python
34 | from betterloader import BetterLoader
35 | 
36 | index_json = './examples/sample_index.json'
37 | basepath = "./examples/sample_dataset/"
38 | batch_size = 2
39 | 
40 | loader = BetterLoader(basepath=basepath, index_json_path=index_json)
41 | dataloaders, sizes = loader.fetch_segmented_dataloaders(batch_size=batch_size, transform=None)
42 | 
43 | print("Dataloader sizes: {}".format(str(sizes)))
44 | ```
45 | For more information and more detailed examples, please check out the BetterLoader docs!
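46 | 
47 | ### Per-split transforms
48 | The `transform` argument also accepts a dictionary keyed by split, so each of train/test/val can get its own pipeline (this mirrors `test_transformdict` in `tests/tests.py`):
49 | 
50 | ```python
51 | from torchvision import transforms
52 | 
53 | crop = transforms.Compose([transforms.CenterCrop(10)])
54 | per_split = {"train": crop, "test": crop, "val": crop}
55 | 
56 | dataloaders, sizes = loader.fetch_segmented_dataloaders(
57 |     batch_size=2, transform=per_split
58 | )
59 | ```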
60 | 
61 | ## Development
62 | 
63 | We use a Makefile to make our lives a little easier :)
64 | ### Install Dependencies
65 | ```sh
66 | make install
67 | ```
68 | ### Run Sample
69 | ```sh
70 | make sample
71 | ```
72 | ### Run Unit Tests
73 | ```sh
74 | make test
75 | ```
76 | 
77 | ## Meta
78 | Distributed under the MIT license. See ``LICENSE`` for more information.
79 | 
80 | ## Documentation & Usage
81 | - [Usage docs](https://binitai.github.io/BetterLoader/)
82 | - [Example implementation](./examples)
83 | - [Contributing](./CONTRIBUTING.md)
--------------------------------------------------------------------------------
/betterloader/ImageFolderCustom.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Modified version of the PyTorch ImageFolder class to make custom dataloading possible
  3 | 
  4 | """
  5 | 
  6 | from PIL import Image
  7 | 
  8 | from .DatasetFolder import DatasetFolder
  9 | 
 10 | IMG_EXTENSIONS = (
 11 |     ".jpg",
 12 |     ".jpeg",
 13 |     ".png",
 14 |     ".ppm",
 15 |     ".bmp",
 16 |     ".pgm",
 17 |     ".tif",
 18 |     ".tiff",
 19 |     ".webp",
 20 | )
 21 | 
 22 | 
 23 | def pil_loader(path):
 24 |     """Open path as a file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
 25 |     Args:
 26 |         path: Load image at path
 27 |     Returns:
 28 |         PIL.Image: A PIL image object in RGB format
 29 |     """
 30 |     with open(path, "rb") as f:
 31 |         img = Image.open(f)
 32 |         return img.convert("RGB")
 33 | 
 34 | 
 35 | def accimage_loader(path):
 36 |     """Helper to try and load an image as an accimage, if supported
 37 |     Args:
 38 |         path: Path to target image
 39 |     Returns:
 40 |         var: Image object, either as an accimage, or a PIL image
 41 |     """
 42 |     import accimage  # pylint: disable=import-error
 43 | 
 44 |     try:
 45 |         return accimage.Image(path)
 46 |     except IOError:
 47 |         # Potentially a decoding problem, fall back to PIL.Image
 48 |         return pil_loader(path)
 49 | 
 50 | 
 51 | def default_loader(path):
 52 |     """Load images given a path, either via accimage or PIL
 53 |     Args:
 54 |         path: Path to target image
 55 |     Returns:
 56 |         var: Image object, either as an accimage, or a PIL image
 57 |     """
 58 |     from torchvision import get_image_backend
 59 | 
 60 |     if get_image_backend() == "accimage":
 61 |         return accimage_loader(path)
 62 | 
 63 |     return pil_loader(path)
 64 | 
 65 | 
 66 | def default_classdata(_, index):
 67 |     """Load default classdata if no class data is passed
 68 |     Args:
 69 |         _: Ignored path value
 70 |         index: Index file dictionary
 71 |     Returns:
 72 |         classes: A list of image classes
 73 |         class_to_idx: Mapping from classes to the indices of those classes
 74 |     """
 75 |     classes = list(index.keys())
 76 |     classes.sort()
 77 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
 78 |     return classes, class_to_idx
 79 | 
 80 | 
 81 | class ImageFolderCustom(DatasetFolder):  # pylint: disable=too-few-public-methods
 82 |     """A generic data loader for images ::
 83 | 
 84 |     Args:
 85 |         root (string): Root directory path.
 86 |         transform (callable, optional): A function/transform that takes in a PIL image
 87 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
 88 |         target_transform (callable, optional): A function/transform that takes in the
 89 |             target and transforms it.
 90 |         loader (callable, optional): A function to load an image given its path.
 91 |         is_valid_file (callable, optional): A function that takes the path of an image file
 92 |             and checks if it is a valid file (used to check for corrupt files)
 93 |         instance (string, optional): One of 'train', 'test' or 'val'; which split you want
 94 |         index (dict[string:list[string]], optional): A dictionary that maps each class to a list of the image paths for that class, along with whatever other data you need to make your dataset;
 95 |             this can really be whatever you want, because it is only handled by train_test_val_instances.
 96 |         train_test_val_instances (callable, optional): A function that takes:
 97 |             a root directory,
 98 |             a mapping of class names to indices,
 99 |             the index,
100 |             and is_valid_file,
101 |             and returns a tuple of lists containing the instance data for each of train, test and val;
102 |             the instance data in each list is a tuple and can have whatever structure you want, as long as the image path is the first element.
103 |             Each of these tuples is processed by the pretransform.
104 |         class_data (tuple, optional): the first element is a list of the classes, the second is a mapping of the classes to their indices
105 |         pretransform (callable, optional): A function that takes the loaded image and any other relevant data for that image and returns a transformed version of that image
106 | 
107 |     Attributes:
108 |         classes (list): List of the class names sorted alphabetically.
109 |         class_to_idx (dict): Dict with items (class_name, class_index).
110 |         imgs (list): List of (image path, class_index) tuples
111 |     """
112 | 
113 |     def __init__(
114 |         self,
115 |         root,
116 |         transform=None,
117 |         target_transform=None,
118 |         loader=default_loader,
119 |         is_valid_file=None,
120 |         instance="train",
121 |         index=None,
122 |         train_test_val_instances=None,
123 |         class_data=None,
124 |         pretransform=None,
125 |     ):
126 | 
127 |         class_data = default_classdata if class_data is None else class_data
128 | 
129 |         super(ImageFolderCustom, self).__init__(
130 |             root,
131 |             loader,
132 |             IMG_EXTENSIONS if is_valid_file is None else None,
133 |             transform=transform,
134 |             target_transform=target_transform,
135 |             is_valid_file=is_valid_file,
136 |             instance=instance,
137 |             index=index,
138 |             train_test_val_instances=train_test_val_instances,
139 |             class_data=class_data,
140 |             pretransform=pretransform,
141 |         )
142 |         self.imgs = self.samples
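143 | 
144 | 
145 | # Construction sketch (illustrative only; BetterLoader normally builds this
146 | # for you). Any callable with the documented signatures can be plugged in,
147 | # e.g. a custom loader that forces grayscale decoding:
148 | #
149 | #   def grayscale_loader(path):
150 | #       with open(path, "rb") as f:
151 | #           return Image.open(f).convert("L")
152 | #
153 | #   dataset = ImageFolderCustom(
154 | #       root="./examples/sample_dataset/",
155 | #       loader=grayscale_loader,
156 | #       index=index_dict,                      # label -> list of filenames
157 | #       train_test_val_instances=my_splitter,  # as documented above
158 | #   )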
--------------------------------------------------------------------------------
/docs/static/img/undraw_expanded_functionality.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/betterloader/DatasetFolder.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Modified version of the PyTorch DatasetFolder class to make custom dataloading possible
  3 | 
  4 | """
  5 | 
  6 | import os
  7 | import os.path
  8 | 
  9 | from torchvision.datasets import VisionDataset
 10 | 
 11 | 
 12 | def has_file_allowed_extension(filename, extensions):
 13 |     """Checks if a file has an allowed extension.
 14 | 
 15 |     Args:
 16 |         filename (string): path to a file
 17 |         extensions (tuple of strings): extensions to consider (lowercase)
 18 | 
 19 |     Returns:
 20 |         bool: True if the filename ends with one of the given extensions
 21 |     """
 22 |     return filename.lower().endswith(extensions)
 23 | 
 24 | 
 25 | def default_pretransform(sample, values):
 26 |     """Returns the image sample without transforming it at all
 27 |     Args:
 28 |         sample: Loaded image data
 29 |         values: Tuple whose element at index 1 is the target (by default)
 30 | 
 31 |     Returns:
 32 |         var: The loaded sample image
 33 |         int: Value representing the image class (label for data)
 34 |     """
 35 |     target = values[1]
 36 |     return sample, target
 37 | 
 38 | 
 39 | def make_dataset(
 40 |     directory,
 41 |     class_to_idx,
 42 |     extensions=None,
 43 |     is_valid_file=None,
 44 |     instance="train",
 45 |     index=None,
 46 |     train_test_val_instances=None,
 47 | ):
 48 |     """Makes the actual dataset
 49 |     Args:
 50 |         directory (string): Root directory path.
 51 |         class_to_idx (dict): Dict which maps classes to index values
 52 |         extensions (tuple[string]): A list of allowed extensions. Both extensions and is_valid_file should not be passed.
 53 |         is_valid_file (callable, optional): A function that takes the path of a file and checks if it is a valid file (used to check for corrupt files). Both extensions and is_valid_file should not be passed.
 54 |         instance (str): String signifying the data segment (train, test, val)
 55 |         index (dict): Index file dict data
 56 |         train_test_val_instances (callable, optional): Returns a custom breakup of the train, test and val data
 57 | 
 58 |     Returns:
 59 |         list: List of PyTorch instance data to be loaded
 60 |     """
 61 |     directory = os.path.expanduser(directory)
 62 |     both_none = extensions is None and is_valid_file is None
 63 |     both_something = extensions is not None and is_valid_file is not None
 64 |     if both_none or both_something:
 65 |         raise ValueError(
 66 |             "Both extensions and is_valid_file cannot be None or not None at the same time"
 67 |         )
 68 |     if extensions is not None:
 69 | 
 70 |         def is_valid_file(x):
 71 |             return has_file_allowed_extension(x, extensions)
 72 | 
 73 |     train, test, val = train_test_val_instances(
 74 |         directory, class_to_idx, index, is_valid_file
 75 |     )
 76 | 
 77 |     return train if instance == "train" else test if instance == "test" else val
 78 | 
 79 | 
 80 | class DatasetFolder(VisionDataset):
 81 |     """A generic data loader ::
 82 | 
 83 |     Args:
 84 |         root (string): Root directory path.
 85 |         loader (callable): A function to load a sample given its path.
 86 |         extensions (tuple[string]): A list of allowed extensions.
 87 |             Both extensions and is_valid_file should not be passed.
 88 |         transform (callable, optional): A function/transform that takes in
 89 |             a sample and returns a transformed version.
 90 |             E.g., ``transforms.RandomCrop`` for images.
 91 |         target_transform (callable, optional): A function/transform that takes
 92 |             in the target and transforms it.
 93 |         is_valid_file (callable, optional): A function that takes the path of a file
 94 |             and checks if it is a valid file (used to check for corrupt files).
 95 |             Both extensions and is_valid_file should not be passed.
 96 |         instance (string, optional): One of 'train', 'test' or 'val'; which split you want
 97 |         index (dict[string:list[string]], optional): A dictionary that maps each class to a list of the image paths for that class, along with whatever other data you need to make your dataset;
 98 |             this can really be whatever you want, because it is only handled by train_test_val_instances.
 99 |         train_test_val_instances (callable, optional): A function that takes:
100 |             a root directory,
101 |             a mapping of class names to indices,
102 |             the index,
103 |             and is_valid_file,
104 |             and returns a tuple of lists containing the instance data for each of train, test and val;
105 |             the instance data in each list is a tuple and can have whatever structure you want, as long as the image path is the first element.
106 |             Each of these tuples is processed by the pretransform.
107 |         class_data (tuple, optional): the first element is a list of the classes, the second is a mapping of the classes to their indices
108 |         pretransform (callable, optional): A function that takes the loaded image and any other relevant data for that image and returns a transformed version of that image
109 | 
110 |     Attributes:
111 |         classes (list): List of the class names sorted alphabetically.
112 |         class_to_idx (dict): Dict with items (class_name, class_index).
113 |         pretransform (callable): returns a transformed image using data in the sample
114 |         class_data (tuple): (classes, class_to_idx)
115 |         samples (tuple): tuple of three (train, test, val) lists of (sample path, class_index, whatever else, ...) tuples
116 |         Unused: targets (list): The class_index value for each image in the dataset
117 |     """
118 | 
119 |     def __init__(
120 |         self,
121 |         root,
122 |         loader,
123 |         extensions,
124 |         transform,
125 |         target_transform,
126 |         is_valid_file,
127 |         instance,
128 |         index,
129 |         train_test_val_instances,
130 |         class_data,
131 |         pretransform,
132 |     ):
133 | 
134 |         super(DatasetFolder, self).__init__(
135 |             root, transform=transform, target_transform=target_transform
136 |         )
137 |         self.index = index
138 |         self.class_data = class_data
139 |         self.pretransform = (
140 |             default_pretransform if pretransform is None else pretransform
141 |         )
142 |         classes, class_to_idx = self._find_classes(self.root)
143 |         samples = make_dataset(
144 |             self.root,
145 |             class_to_idx,
146 |             extensions,
147 |             is_valid_file,
148 |             instance,
149 |             index,
150 |             train_test_val_instances,
151 |         )
152 | 
153 |         self.loader = loader
154 |         self.extensions = extensions
155 | 
156 |         self.classes = classes
157 |         self.class_to_idx = class_to_idx
158 |         self.samples = samples
159 |         self.targets = [s[1] for s in samples]
160 | 
161 |     def _find_classes(self, root_dir):
162 |         """
163 |         Finds the classes in a dataset.
164 | 
165 |         Args:
166 |             root_dir (string): Root directory path.
167 | 
168 |         Returns:
169 |             tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
170 | 
171 |         Ensures:
172 |             No class is a subdirectory of another.
173 |         """
174 | 
175 |         classes, class_to_idx = self.class_data(root_dir, self.index)
176 | 
177 |         return classes, class_to_idx
178 | 
179 |     def __getitem__(self, index):
180 |         """
181 |         Args:
182 |             index (int): Index
183 | 
184 |         Returns:
185 |             tuple: (sample, target) where target is the class_index of the target class.
186 | """ 187 | # path, target, pt = self.samples[index] 188 | 189 | values = self.samples[index] 190 | path = values[0] 191 | sample = self.loader(path) 192 | sample, target = self.pretransform(sample, values) 193 | 194 | if self.transform is not None: 195 | sample = self.transform(sample) 196 | if self.target_transform is not None: 197 | target = self.target_transform(target) 198 | 199 | return sample, target 200 | 201 | def __len__(self): 202 | return len(self.samples) 203 | -------------------------------------------------------------------------------- /docs/static/img/undraw_fast_loading.svg: -------------------------------------------------------------------------------- 1 | fast loading -------------------------------------------------------------------------------- /tests/tests.py: -------------------------------------------------------------------------------- 1 | """Test Suite for BetterLoader 2 | """ 3 | import unittest 4 | 5 | import sys 6 | import os 7 | 8 | sys.path.append(os.getcwd()) 9 | 10 | from betterloader import BetterLoader, UnsupervisedBetterLoader 11 | from betterloader.defaults import simple_metadata, regex_metadata, collate_metadata 12 | 13 | from torchvision import transforms 14 | 15 | # pylint: disable=no-self-use 16 | 17 | basic_transform = transforms.Compose( 18 | [ 19 | transforms.CenterCrop(10), 20 | ] 21 | ) 22 | 23 | dictionary_transform = { 24 | "train": basic_transform, 25 | "test": basic_transform, 26 | "val": basic_transform, 27 | } 28 | 29 | 30 | class Integration(unittest.TestCase): 31 | """Suite of Integration tests for BetterLoader""" 32 | 33 | def test_defaults(self): 34 | """Test the BetterLoader call using the default parameters""" 35 | index_json = "./examples/sample_index.json" 36 | basepath = "./examples/sample_dataset/" 37 | batch_size = 2 38 | 39 | loader = BetterLoader(basepath=basepath, index_json_path=index_json) 40 | dataloaders, sizes = loader.fetch_segmented_dataloaders( 41 | batch_size=batch_size, transform=None 42 | ) 43 | 44 | assert dataloaders is not None 45 | 46 | assert sizes["train"] == 4 47 | assert sizes["test"] == 2 48 | assert sizes["val"] == 2 49 | 50 | def test_transform(self): 51 | """Test the BetterLoader call using the default parameters""" 52 | index_json = "./examples/sample_index.json" 53 | basepath = "./examples/sample_dataset/" 54 | batch_size = 2 55 | 56 | loader = BetterLoader(basepath=basepath, index_json_path=index_json) 57 | dataloaders, sizes = loader.fetch_segmented_dataloaders( 58 | batch_size=batch_size, transform=basic_transform 59 | ) 60 | 61 | assert dataloaders is not None 62 | 63 | assert sizes["train"] == 4 64 | assert sizes["test"] == 2 65 | assert sizes["val"] == 2 66 | 67 | def test_transformdict(self): 68 | """Test the BetterLoader call using the default parameters""" 69 | index_json = "./examples/sample_index.json" 70 | basepath = "./examples/sample_dataset/" 71 | batch_size = 2 72 | 73 | loader = BetterLoader(basepath=basepath, index_json_path=index_json) 74 | dataloaders, sizes = loader.fetch_segmented_dataloaders( 75 | batch_size=batch_size, transform=dictionary_transform 76 | ) 77 | 78 | assert dataloaders is not None 79 | 80 | assert sizes["train"] == 4 81 | assert sizes["test"] == 2 82 | assert sizes["val"] == 2 83 | 84 | def test_unsupervised(self): 85 | """Test the BetterLoader Unsupervised call using the default parameters""" 86 | index_json = "./examples/sample_index_unsupervised.json" 87 | basepath = "./examples/sample_dataset/" 88 | batch_size = 2 89 | metadata = 
90 |         better_loader = UnsupervisedBetterLoader(
91 |             basepath=basepath,
92 |             base_experiment_details=["simclr", 1, (150, 150)],
93 |             index_json_path=index_json,
94 |             dataset_metadata=metadata,
95 |         )
96 |         dataloaders, sizes = better_loader.fetch_segmented_dataloaders(
97 |             batch_size=batch_size
98 |         )
99 | 
100 |         assert dataloaders is not None
101 | 
102 |         assert sizes["train"] == 4
103 |         assert sizes["test"] == 2
104 |         assert sizes["val"] == 2
105 | 
106 |     def test_defaults_with_object(self):
107 |         """Test the BetterLoader call using the default parameters, but using index_object instead of index_json_path"""
108 |         index = {
109 |             "class1": ["image0.jpg", "image1.jpg", "image2.jpg", "image3.jpg"],
110 |             "class2": ["image4.jpg", "image5.jpg", "image6.jpg", "image7.jpg"],
111 |         }
112 |         basepath = "./examples/sample_dataset/"
113 |         batch_size = 2
114 | 
115 |         loader = BetterLoader(basepath=basepath, index_object=index)
116 |         dataloaders, sizes = loader.fetch_segmented_dataloaders(
117 |             batch_size=batch_size, transform=None
118 |         )
119 | 
120 |         assert dataloaders is not None
121 | 
122 |         assert sizes["train"] == 4
123 |         assert sizes["test"] == 2
124 |         assert sizes["val"] == 2
125 | 
126 |     def test_simple_metadata(self):
127 |         """Test the BetterLoader call using the same default functions, but passed in explicitly this time"""
128 |         index_json = "./examples/sample_index.json"
129 |         basepath = "./examples/sample_dataset/"
130 |         batch_size = 2
131 | 
132 |         dataset_metadata = simple_metadata()
133 | 
134 |         loader = BetterLoader(
135 |             basepath=basepath,
136 |             index_json_path=index_json,
137 |             dataset_metadata=dataset_metadata,
138 |         )
139 |         dataloaders, sizes = loader.fetch_segmented_dataloaders(
140 |             batch_size=batch_size, transform=None
141 |         )
142 | 
143 |         assert dataloaders is not None
144 | 
145 |         assert sizes["train"] == 4
146 |         assert sizes["test"] == 2
147 |         assert sizes["val"] == 2
148 | 
149 |     def test_complex_metadata(self):
150 |         """Test the BetterLoader call using the collate-based metadata object"""
151 |         index_json = "./examples/sample_index.json"
152 |         basepath = "./examples/sample_dataset/"
153 |         batch_size = 2
154 | 
155 |         dataset_metadata = collate_metadata()
156 |         loader = BetterLoader(
157 |             basepath=basepath,
158 |             index_json_path=index_json,
159 |             dataset_metadata=dataset_metadata,
160 |         )
161 |         dataloaders, sizes = loader.fetch_segmented_dataloaders(
162 |             batch_size=batch_size, transform=None
163 |         )
164 | 
165 |         assert dataloaders is not None
166 | 
167 |         assert sizes["train"] == 4
168 |         assert sizes["test"] == 2
169 |         assert sizes["val"] == 2
170 | 
171 |     def test_regex_metadata(self):
172 |         """Test the BetterLoader call using regex-based metadata functions"""
173 |         index_json = "./examples/regex_index.json"
174 |         basepath = "./examples/regex_dataset/"
175 |         batch_size = 2
176 | 
177 |         dataset_metadata = regex_metadata()
178 | 
179 |         loader = BetterLoader(
180 |             basepath=basepath,
181 |             index_json_path=index_json,
182 |             dataset_metadata=dataset_metadata,
183 |         )
184 |         dataloaders, sizes = loader.fetch_segmented_dataloaders(
185 |             batch_size=batch_size, transform=None
186 |         )
187 | 
188 |         assert dataloaders is not None
189 | 
190 |         assert sizes["train"] == 2
191 |         assert sizes["test"] == 2
192 |         assert sizes["val"] == 2
193 | 
194 |     def test_bad_paths(self):
195 |         """Test the BetterLoader call using two bad paths - the basepath error should be raised first"""
196 |         index_json = "./badpath/"
197 |         basepath = "./badpath/"
198 | 
199 |         dataset_metadata = simple_metadata()
200 |         self.assertRaisesRegex(
201 |             FileNotFoundError,
202 |             "Please supply a valid path to your base folder!",
203 |             BetterLoader,
204 |             basepath,
205 |             index_json,
206 |             dataset_metadata,
207 |         )
208 | 
209 |     def test_bad_basepath(self):
210 |         """Test the BetterLoader call using a bad basepath"""
211 |         index_json = "./examples/sample_index.json"
212 |         basepath = "./badpath/"
213 | 
214 |         dataset_metadata = simple_metadata()
215 |         self.assertRaisesRegex(
216 |             FileNotFoundError,
217 |             "Please supply a valid path to your base folder!",
218 |             BetterLoader,
219 |             basepath,
220 |             index_json,
221 |             dataset_metadata,
222 |         )
223 | 
224 |     def test_bad_index(self):
225 |         """Test the BetterLoader call using a bad index path"""
226 |         index_json = "./badpath"
227 |         basepath = "./examples/sample_dataset/"
228 | 
229 |         dataset_metadata = simple_metadata()
230 |         self.assertRaisesRegex(
231 |             FileNotFoundError,
232 |             "Please supply a valid path to a dataset index file or valid index object!",
233 |             BetterLoader,
234 |             basepath,
235 |             index_json,
236 |             dataset_metadata,
237 |         )
238 | 
239 | 
240 | if __name__ == "__main__":
241 |     unittest.main()
242 | 
--------------------------------------------------------------------------------
/betterloader/defaults.py:
--------------------------------------------------------------------------------
1 | """A collection of aggregated default methods and metadata accessors, used for both testing and default values
2 | """
3 | 
4 | import os
5 | import re
6 | import torch
7 | from torch._six import container_abcs, string_classes, int_classes
8 | 
9 | 
10 | def _simple():
11 |     def train_test_val_instances(
12 |         split, directory, class_to_idx, index, is_valid_file
13 |     ):  # pylint: disable=too-many-locals
14 |         """Function to perform default train/test/val instance creation
15 |         Args:
16 |             split (tuple): Tuple of ratios (from 0 to 1) for train, test, val values
17 |             directory (str): Parent directory to read images from
18 |             class_to_idx (dict): Dictionary to map values from class strings to index values
19 |             index (dict): Index file dict object
20 |             is_valid_file (callable): Function to verify if a file should be loaded
21 |         Returns:
22 |             (tuple): Tuple of length 3 containing train, test, val instances
23 |         """
24 | 
25 | 
26 |         train, test, val = [], [], []
27 |         i = 0
28 |         for target_class in sorted(class_to_idx.keys()):
29 |             i += 1
30 |             if not os.path.isdir(directory):
31 |                 continue
32 |             instances = []
33 |             for file in index[target_class]:
34 |                 if is_valid_file(file):
35 |                     path = os.path.join(directory, file)
36 |                     instances.append((path, class_to_idx[target_class]))
37 | 
38 |             trainp, _, valp = split
39 | 
40 |             train += instances[: int(len(instances) * trainp)]
41 |             test += instances[
42 |                 int(len(instances) * trainp) : int(len(instances) * (1 - valp))
43 |             ]
44 |             val += instances[int(len(instances) * (1 - valp)) :]
45 | 
46 |         return train, test, val
47 | 
48 |     def classdata(_, index):
49 | 
50 |         """Given class data, just create the default classes list and class_to_idx dict"""
51 |         classes = list(index.keys())
52 |         classes.sort()
53 |         class_to_idx = {classes[i]: i for i in range(len(classes))}
54 | 
55 |         return classes, class_to_idx
56 | 
57 |     def pretransform(sample, values):
58 | 
59 |         """Given a sample and a values list as specified in the docs, just return the sample and its target class index"""
60 |         target = values[1]
61 |         return sample, target
62 | 
63 |     return train_test_val_instances, classdata, pretransform
64 | 
65 | 
66 | def _regex():
67 |     def train_test_val_instances(
68 |         split, directory, class_to_idx, index, is_valid_file
69 |     ):  # pylint: disable=too-many-locals
"""Function to perform default train/test/val instance creation 71 | Args: 72 | split (tuple): Tuple of ratios (from 0 to 1) for train, test, val values 73 | directory (str): Parent directory to read images from 74 | class_to_idx (dict): Dictionary to map values from class strings to index values 75 | index (dict): Index file dict object 76 | is_valid_file (callable): Function to verify if a file should be loaded 77 | Returns: 78 | (tuple): Tuple of length 3 containing train, test, val instances 79 | """ 80 | train, test, val = [], [], [] 81 | i = 0 82 | 83 | def _fetch_regex_names(regex): 84 | files = [] 85 | for filename in os.listdir(directory): 86 | if re.compile(regex).match(filename): 87 | files.append(filename) 88 | return files 89 | 90 | for target_class in sorted(class_to_idx.keys()): 91 | i += 1 92 | regex = index[target_class] 93 | if not os.path.isdir(directory): 94 | continue 95 | instances = [] 96 | files = _fetch_regex_names(regex) 97 | for file in files: 98 | if is_valid_file(file): 99 | path = os.path.join(directory, file) 100 | instances.append((path, class_to_idx[target_class])) 101 | 102 | trainp, _, valp = split 103 | 104 | train += instances[: int(len(instances) * trainp)] 105 | test += instances[ 106 | int(len(instances) * trainp) : int(len(instances) * (1 - valp)) 107 | ] 108 | val += instances[int(len(instances) * (1 - valp)) :] 109 | return train, test, val 110 | 111 | def classdata(_, index): 112 | """Given class data, just create the default classes list and class_to_idx dict""" 113 | classes = list(index.keys()) 114 | classes.sort() 115 | class_to_idx = {classes[i]: i for i in range(len(classes))} 116 | return classes, class_to_idx 117 | 118 | def pretransform(sample, values): 119 | """Given a sample and a values list as specified in the docs, just return the path""" 120 | target = values[1] 121 | return sample, target 122 | 123 | return train_test_val_instances, classdata, pretransform 124 | 125 | 126 | def _collate(): 127 | np_str_obj_array_pattern = re.compile(r"[SaUO]") 128 | default_collate_err_msg_format = ( 129 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 130 | "dicts or lists; found {}" 131 | ) 132 | 133 | def basic_collate_fn(batch): 134 | """Puts each data field into a tensor with outer dimension batch size""" 135 | elem = batch[0] 136 | elem_type = type(elem) 137 | if isinstance(elem, torch.Tensor): 138 | out = None 139 | if torch.utils.data.get_worker_info() is not None: 140 | # If we're in a background process, concatenate directly into a 141 | # shared memory tensor to avoid an extra copy 142 | numel = sum([x.numel() for x in batch]) 143 | storage = elem.storage()._new_shared(numel) 144 | out = elem.new(storage) 145 | return torch.stack(batch, 0, out=out) 146 | elif ( 147 | elem_type.__module__ == "numpy" 148 | and elem_type.__name__ != "str_" 149 | and elem_type.__name__ != "string_" 150 | ): 151 | if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": 152 | # array of string classes and object 153 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 154 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 155 | 156 | return basic_collate_fn([torch.as_tensor(b) for b in batch]) 157 | elif elem.shape == (): # scalars 158 | return torch.as_tensor(batch) 159 | elif isinstance(elem, float): 160 | return torch.tensor(batch, dtype=torch.float64) 161 | elif isinstance(elem, int_classes): 162 | return torch.tensor(batch) 163 | elif isinstance(elem, string_classes): 164 | return batch 
165 |         elif isinstance(elem, container_abcs.Mapping):
166 |             return {key: basic_collate_fn([d[key] for d in batch]) for key in elem}
167 |         elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
168 |             return elem_type(*(basic_collate_fn(samples) for samples in zip(*batch)))
169 |         elif isinstance(elem, container_abcs.Sequence):
170 |             # check to make sure that the elements in batch have consistent size
171 |             it = iter(batch)
172 |             elem_size = len(next(it))
173 |             if not all(len(elem) == elem_size for elem in it):
174 |                 raise RuntimeError(
175 |                     "each element in list of batch should be of equal size"
176 |                 )
177 |             transposed = zip(*batch)
178 |             return [basic_collate_fn(samples) for samples in transposed]
179 | 
180 |         raise TypeError(default_collate_err_msg_format.format(elem_type))
181 | 
182 |     return basic_collate_fn
183 | 
184 | 
185 | def simple_metadata():
186 |     """Create a very simple metadata object to test with"""
187 |     train_test_val_instances, classdata, pretransform = _simple()
188 |     metadata = {}
189 |     metadata["pretransform"] = pretransform
190 |     metadata["classdata"] = classdata
191 |     metadata["train_test_val_instances"] = train_test_val_instances
192 |     return metadata
193 | 
194 | 
195 | def regex_metadata():
196 |     """Create a regex based metadata object"""
197 |     train_test_val_instances, classdata, pretransform = _regex()
198 |     metadata = {}
199 |     metadata["pretransform"] = pretransform
200 |     metadata["classdata"] = classdata
201 |     metadata["train_test_val_instances"] = train_test_val_instances
202 |     return metadata
203 | 
204 | 
205 | def collate_metadata():
206 |     """Create a collation based metadata object"""
207 |     train_test_val_instances, classdata, pretransform = _simple()
208 |     basic_collate_fn = _collate()
209 |     metadata = {}
210 |     metadata["pretransform"] = pretransform
211 |     metadata["classdata"] = classdata
212 |     metadata["train_test_val_instances"] = train_test_val_instances
213 |     metadata["supervised"] = True
214 |     metadata["custom_collate"] = basic_collate_fn
215 |     metadata["drop_last"] = True
216 |     metadata["eccentric_object"] = False
217 |     metadata["sample_type"] = None
218 | 
219 |     return metadata
220 | 
--------------------------------------------------------------------------------
/docs/docs/gettingstarted.md:
--------------------------------------------------------------------------------
1 | ---
2 | id: gettingstarted
3 | title: Getting Started
4 | sidebar_label: Getting Started
5 | slug: /
6 | ---
7 | 
8 | ## Installation
9 | 
10 | ### Python
11 | The BetterLoader library is hosted on [pypi](https://pypi.org/) and can be installed via [pip](https://pip.pypa.io/en/stable/).
12 | ```bash
13 | pip install betterloader
14 | ```
15 | 
16 | ### From Source
17 | For developers, BetterLoader's source may also be found at our [Github repository](https://github.com/BinItAI/BetterLoader). You can also install BetterLoader from source, but if you're just trying to use the package, pip is probably a far better bet.
18 | 
19 | ## Why BetterLoader?
20 | BetterLoader really shines when you're working with a dataset, and you want to load subsets of image classes conditionally. Say you have 3 folders of images, and you only want to load those images that conform to a specific condition, or those that are present in a pre-defined subset file. What if you want to load a specific set of crops per source image, given a set of source images? BetterLoader can do all this, and more.
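For instance, loading only the images named in a subset file takes a single extra constructor argument (a minimal sketch, reusing the `examples/sample_subset.json` file shipped with this repository, which lists `image0.jpg` and `image3.jpg`):

```python
from betterloader import BetterLoader

# Only images whose paths appear in the subset file are treated as valid;
# everything else named in the index is skipped.
loader = BetterLoader(
    basepath="./examples/sample_dataset/",
    index_json_path="./examples/sample_index.json",
    subset_json_path="./examples/sample_subset.json",
)
dataloaders, sizes = loader.fetch_segmented_dataloaders(batch_size=1, transform=None)
```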
21 | Note: BetterLoader's primary focus is supervised deep learning tasks; experimental unsupervised support is available through the `UnsupervisedBetterLoader` class.
22 | 
23 | ### Creating a BetterLoader
24 | Using BetterLoader with its default parameters lets it function just like a standard PyTorch dataloader. A few points worth noting:
25 | - BetterLoader does not expect a nested folder structure. In its current iteration, files are expected to all be present in the root directory. This lets us use index files to define classes and labels dynamically, and vary them from experiment to experiment.
26 | - Every instance of BetterLoader requires an index file to function. The default index file format maps class names to a list of image paths, but the index file can be any json file as long as you modify train_test_val_instances to parse it correctly; for example, you could instead map class names to regexes for the file paths and pass a train_test_val_instances that reads the files based on those regexes. Sample index files may be found in the `examples` directory of the repository.
27 | 
28 | A sample use-case for BetterLoader may be found below. It's worth noting that at this point in time, the BetterLoader class exposes only one public method, `fetch_segmented_dataloaders`.
29 | ```python
30 | from betterloader import BetterLoader
31 | 
32 | index_json = './examples/sample_index.json'
33 | # or index_object = {"class1":["image0.jpg","image1.jpg","image2.jpg","image3.jpg"],"class2":["image4.jpg","image5.jpg","image6.jpg","image7.jpg"]}
34 | basepath = "./examples/sample_dataset/"
35 | batch_size = 2
36 | 
37 | loader = BetterLoader(basepath=basepath, index_json_path=index_json)
38 | # or loader = BetterLoader(basepath=basepath, index_object=index_object)
39 | dataloaders, sizes = loader.fetch_segmented_dataloaders(batch_size=batch_size, transform=None)
40 | 
41 | print("Dataloader sizes: {}".format(str(sizes)))
42 | ```
43 | 
44 | ### Constructor Parameters
45 | | field | type | description | optional (default) |
46 | | ------------- | :-----------: | -----: | -----------: |
47 | | basepath | str | path to image directory | no |
48 | | index_json_path | str | path to index file | yes (None) |
49 | | index_object | dict | an object representation of an index file | yes (None) |
50 | | num_workers | int | number of workers | yes (1) |
51 | | subset_json_path | str | path to subset json file | yes (None) |
52 | | subset_object | dict | an object representation of the subset file | yes (None) |
53 | | dataset_metadata | dict | dictionary of optional metadata attributes to customise the BetterLoader (more below) | yes (None) |
54 | 
55 | ### Usage
56 | The BetterLoader class' `fetch_segmented_dataloaders` method allows a user to obtain a tuple of dictionaries, most commonly referenced as `(dataloaders, sizes)`. Each dictionary contains `train`, `test`, and `val` keys, allowing for easy access to the dataloaders, as well as their sizes. The function header for the same may be found below:
57 | 
58 | ```python
59 | def fetch_segmented_dataloaders(self, batch_size, transform=None):
60 |     """Fetch custom dataloaders, which may be used with any PyTorch model
61 |     Args:
62 |         batch_size (int): Image batch size.
63 |         transform (callable or dict, optional): PyTorch transform object. This parameter may also be a dict with keys of 'train', 'test', and 'val', in order to enable separate transforms for each split.
64 |     Returns:
65 |         dict: A dictionary of dataloaders for the train/test/val splits
66 |         dict: A dictionary of dataset sizes for the train/test/val splits
67 |     """
68 | ```
69 | 
70 | ### Metadata Parameters
71 | BetterLoader accepts certain key-value pairs under the `dataset_metadata` parameter, in order to enable some custom functionality.
72 | 1. pretransform (callable, optional): A custom transform applied after an image is loaded from disk, but before it enters the dataloader and the regular transforms run.
73 | For basic usage, a pretransform that does nothing (the default) is usually sufficient. An example use case for the customizability is listed below.
74 | 2. classdata (callable, optional): Defines a custom mapping that reads class data from a custom-format index file for the DatasetFolder class.
75 | Since the index file may have any structure, we need to ensure that the classes and a mapping from the classes to their indices are always available.
76 | Returns a tuple (list of classes, dictionary mapping each class to its index).
77 | 3. split (tuple, optional): Defines a tuple of train, test, and val fractions, which must sum to one.
78 | 4. train_test_val_instances (callable, optional): Defines a custom function to read values from the index file.
79 | The default expects an index that is a dict mapping classes to lists of file paths; a custom version is needed for other index formats.
80 | It must always return train, test, and val splits, each a list of tuples, with one tuple per datapoint.
81 | The first element of each tuple must be the filepath of the image for that datapoint.
82 | The default also puts the target class index in the second element of the tuple, which suits most use cases.
83 | Each datapoint tuple is passed as the `values` argument to the pretransform, so any additional data needed to transform the datapoint before it is loaded can go in the datapoint tuple.
84 | 5. supervised (bool, optional): Defines whether or not the experiment is supervised
85 | 6. custom_collate (callable, optional): Custom function that merges a list of samples to form a mini-batch of Tensors
86 | 7. drop_last (bool, optional): Defines whether to drop the last incomplete batch if the dataset size is not divisible by the batch size, to avoid sizing errors
87 | 8. eccentric_object (bool, optional): Set this to have the dataloaders copy tensors into CUDA pinned memory before returning them, provided your data elements are not of a custom type
88 | 9. sample_type (str, optional): Selects a custom strategy to draw data from the dataset; currently `"subset_sampling"` (a `SubsetRandomSampler`) is the supported value
89 | 
90 | ---
91 | 
92 | Here is an example of a `pretransform` and a `train_test_val_instances` designed to allow a specified crop to be taken of each image (the index shape they assume is sketched below).
93 | Notes:
94 | 
95 | - The internals of the loader dictate that the elements of the `instances` variables generated from train_test_val_instances will become the `values` argument for a pretransform call, and the `sample` argument for pretransform is the image data loaded directly from the filepath in `values[0]` (or `instances[i][0]`).
96 | - Since the index file here has a similar structure to the default, we can get away with using the default classdata function; index files that don't have the classes as keys of a dictionary will need a custom way of determining the classes.
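As a concrete illustration, the two functions below assume an index whose entries pair each filename with its crop box. This is a hypothetical shape chosen for this example, and the `_crop` helper is likewise illustrative, not part of BetterLoader:

```python
import os  # needed by the train_test_val_instances example below

# Hypothetical index: each datapoint is [filename, crop box] rather than a bare path.
index = {
    "class1": [["image0.jpg", [0, 0, 100, 100]],
               ["image1.jpg", [10, 10, 110, 110]]],
    "class2": [["image4.jpg", [0, 0, 100, 100]]],
}

def _crop(image, crop_params):
    # PIL images support box cropping directly; crop_params is (left, upper, right, lower)
    return image.crop(crop_params)
```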
97 | 
98 | ```python
99 | def pretransform(sample, values):
100 |     """Example pretransform - takes an image and crops it based on the parameters defined in values
101 |     Args:
102 |         values (tuple): Tuple of values relevant to a given image - created by the train_test_val_instances function
103 | 
104 |     Returns:
105 |         tuple: Actual modified image, and the target class index for that image
106 |     """
107 |     image_path, target, crop_params = values
108 | 
109 |     # pretransform should always return a tuple of this structure (some image data, some target class index)
110 |     return (_crop(sample, crop_params), target)
111 | 
112 | ```
113 | 
114 | ```python
115 | def train_test_val_instances(split, directory, class_to_idx, index, is_valid_file):
116 |     """Function to create train/test/val instances whose datapoints carry crop parameters
117 |     Args:
118 |         split (tuple): Tuple of ratios (from 0 to 1) for train, test, val values
119 |         directory (str): Parent directory to read images from
120 |         class_to_idx (dict): Dictionary to map values from class strings to index values
121 |         index (dict): Index dict mapping each class to (filename, crop_params) pairs
122 |         is_valid_file (callable): Function to verify if a file should be loaded
123 |     Returns:
124 |         (tuple): Tuple of length 3 containing train, test, val instances
125 |     """
126 |     train, test, val = [], [], []
127 |     i = 0
128 |     for target_class in sorted(class_to_idx.keys()):
129 |         i += 1
130 |         if not os.path.isdir(directory):
131 |             continue
132 |         instances = []
133 |         for filename, crop_params in index[target_class]:
134 |             if is_valid_file(filename):
135 |                 path = os.path.join(directory, filename)
136 |                 instances.append((path, class_to_idx[target_class], crop_params))
137 | 
138 |         trainp, _, valp = split
139 | 
140 |         train += instances[:int(len(instances)*trainp)]
141 |         test += instances[int(len(instances)*trainp):int(len(instances)*(1-valp))]
142 |         val += instances[int(len(instances)*(1-valp)):]
143 |     return train, test, val
144 | ```
145 | 
--------------------------------------------------------------------------------
/docs/static/img/undraw_compatibility.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/betterloader/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Adding Hypercustomizability to the Pytorch ImageFolder Dataloader.
3 | id est: making it harder to do easy things, but easier to do harder things :)
4 | 
5 | """
6 | 
7 | __version__ = "0.1.3"
8 | __author__ = "BinIt Inc"
9 | __credits__ = "N/A"
10 | 
11 | import json
12 | import os
13 | from collections import defaultdict
14 | import torchvision.transforms as transforms
15 | from .standard_transforms import TransformWhileSampling
16 | from torch.utils.data.sampler import SubsetRandomSampler
17 | 
18 | import torch
19 | import torchvision
20 | 
21 | from .ImageFolderCustom import ImageFolderCustom
22 | from .defaults import simple_metadata
23 | 
24 | 
25 | def fetch_json_from_path(path):
26 |     """Helper method to fetch json dict from file
27 |     Args:
28 |         path: Path to fetch json dict
29 |     Returns:
30 |         dict: JSON object stored in file at path
31 |     """
32 |     if path is not None:
33 |         with open(path, "r") as file:
34 |             return json.load(file)
35 |     else:
36 |         return None
37 | 
38 | 
39 | def check_valid(subset_json):
40 |     """Helper to check if a file is valid, given a subset json instance
41 |     Args:
42 |         subset_json: Defined subset json file data
43 |     Returns:
44 |         bool: True/false value for validity
45 |     """
46 |     if subset_json is None:
47 |         return lambda path: True
48 | 
49 |     def curry(image_path):
50 |         if image_path in subset_json:
51 |             return True
52 |         return False
53 | 
54 |     return curry
55 | 
56 | 
57 | class BetterLoader:  # pylint: disable=too-few-public-methods
58 |     """A hypercustomisable Python dataloader
59 | 
60 |     Args:
61 |         basepath (string): Root directory path.
62 |         index_json_path (string): Path to index file
63 |         num_workers (int, optional): Number of workers
64 |         subset_json_path (string, optional): Path to subset json
65 |         dataset_metadata (dict, optional): Optional metadata parameters:
66 | 
67 |         - pretransform (callable, optional): This allows us to load a custom pretransform before images are loaded into the dataloader and transformed.
68 |         - classdata (callable, optional): Defines a custom mapping for a custom format index file to read data from the DatasetFolder class
69 |         - split (tuple, optional): Defines a tuple for train, test, val values which must sum to one
70 |         - train_test_val_instances (callable, optional): Defines a custom function to read values from the index file
71 | 
72 |     Attributes:
73 |         basepath (string): Root directory path.
74 |         num_workers (int, optional): Number of workers
75 |         subset_json_path (string, optional): Path to subset json
76 |         index_json_path (string): Path to index file
77 |         classes (list): List of the class names sorted alphabetically
78 |         class_to_idx (dict): Dict with items (class_name, class_index).
79 |         dataset_metadata (dict, optional): Optional metadata parameters. This dictionary may
contain a 'supervised' bool indicating whether the experiment is supervised, and additionally might contain 'custom_collate'
80 |             (a custom collator), a 'sample_type' string for custom data sampling, an 'eccentric_object' bool to indicate
81 |             whether tensors should be pinned in memory, and a 'drop_last' bool for when the dataset size is not
82 |             divisible by the batch size
83 |         split (tuple): Tuple of train, test, val float values
84 | 
85 |     """
86 | 
87 |     def __init__(
88 |         self,
89 |         basepath,
90 |         index_json_path=None,
91 |         num_workers=1,
92 |         index_object=None,
93 |         subset_json_path=None,
94 |         subset_object=None,
95 |         dataset_metadata=None,
96 |     ):
97 |         if not os.path.exists(basepath):
98 |             raise FileNotFoundError("Please supply a valid path to your base folder!")
99 | 
100 |         if index_json_path and not os.path.exists(index_json_path):
101 |             if not index_object:
102 |                 raise FileNotFoundError(
103 |                     "Please supply a valid path to a dataset index file or valid index object!"
104 |                 )
105 |         if index_object and index_json_path:
106 |             raise ValueError(
107 |                 "You may define either the index_json_path or the index_object, not both!"
108 |             )
109 |         if subset_object and subset_json_path:
110 |             raise ValueError(
111 |                 "You may define either the subset_json_path or the subset_object, not both!"
112 |             )
113 |         self.basepath = basepath
114 |         self.num_workers = num_workers
115 |         self.subset_json_path = subset_json_path
116 |         self.index_json_path = index_json_path
117 |         self.subset_object = subset_object
118 |         self.index_object = index_object
119 |         self.classes = []
120 |         self.class_to_idx = {}
121 |         self.dataset_metadata = (
122 |             {}
123 |             if dataset_metadata is None
124 |             else {i: dataset_metadata[i] for i in dataset_metadata if i != "split"}
125 |         )
126 |         self.split = (
127 |             self.dataset_metadata["split"]
128 |             if "split" in self.dataset_metadata
129 |             else (0.6, 0.2, 0.2)
130 |         )
131 |         self.supervised = (
132 |             self.dataset_metadata["supervised"]
133 |             if "supervised" in self.dataset_metadata
134 |             else True
135 |         )
136 |         self.custom_collator = (
137 |             self.dataset_metadata["custom_collate"]
138 |             if "custom_collate" in self.dataset_metadata
139 |             else None
140 |         )
141 |         self.drop_last = (
142 |             self.dataset_metadata["drop_last"]
143 |             if "drop_last" in self.dataset_metadata
144 |             else False
145 |         )
146 |         self.pin_mem = (
147 |             self.dataset_metadata["eccentric_object"]
148 |             if "eccentric_object" in self.dataset_metadata
149 |             else False
150 |         )
151 |         self.sampler = (
152 |             self.dataset_metadata["sample_type"]
153 |             if "sample_type" in self.dataset_metadata
154 |             else None
155 |         )
156 | 
157 |         # self.dataloader_params = self.dataset_metadata['dataloader_params'] if 'dataloader_params' in self.dataset_metadata else {}
158 | 
159 |     def _set_class_data(self, datasets):
160 |         """Wrapper to set class data values upon processing datasets
161 |         Args:
162 |             datasets (list): Datasets that have been processed
163 |         """
164 | 
165 |         if not (
166 |             all(x.classes == datasets[0].classes for x in datasets)
167 |             and all(x.class_to_idx == datasets[0].class_to_idx for x in datasets)
168 |         ):
169 |             print(
170 |                 "Class mismatch between the train/test/val data. This is usually caused by an uneven split, or by classes that are not present in all of train/test/val. Assigning train data class names and class_to_idx map."
171 |             )
172 | 
173 |         self.classes = datasets[0].classes
174 |         self.class_to_idx = datasets[0].class_to_idx
175 | 
176 |     def _fetch_metadata(self, key):
177 |         """Wrapper to fetch a value from the dataset metadata
178 |         Args:
179 |             key (string): Key that we're trying to fetch
180 | 
181 |         Returns:
182 |             any: Value for that key, or None if no such key exists
183 |         """
184 | 
185 |         if key in self.dataset_metadata:
186 |             return self.dataset_metadata[key]
187 | 
188 |         return None
189 | 
190 |     def fetch_segmented_dataloaders(self, batch_size, transform=None):
191 |         """Fetch custom dataloaders, which may be used with any PyTorch model
192 | 
193 |         Args:
194 |             batch_size (int): Image batch size.
195 |             transform (callable or dict, optional): PyTorch transform object. This parameter may also be a dict with keys of 'train', 'test', and 'val', in order to enable separate transforms for each split.
196 | 
197 |         Returns:
198 |             dict: A dictionary of dataloaders for the train/test/val splits
199 |             dict: A dictionary of dataset sizes for the train/test/val splits
200 |         """
201 | 
202 |         train_test_val_instances, class_data, pretransform = (
203 |             self._fetch_metadata("train_test_val_instances"),
204 |             self._fetch_metadata("classdata"),
205 |             self._fetch_metadata("pretransform"),
206 |         )
207 | 
208 |         if train_test_val_instances is None:
209 |             train_test_val_instances = simple_metadata()["train_test_val_instances"]
210 | 
211 |         index, subset_json = (
212 |             fetch_json_from_path(self.index_json_path)
213 |             if self.index_json_path
214 |             else self.index_object,
215 |             fetch_json_from_path(self.subset_json_path)
216 |             if self.subset_json_path
217 |             else self.subset_object,
218 |         )
219 | 
220 |         datasets = None
221 | 
222 |         def train_test_val_instances_wrap(
223 |             directory, class_to_idx, index, is_valid_file
224 |         ):
225 |             return train_test_val_instances(
226 |                 self.split, directory, class_to_idx, index, is_valid_file
227 |             )
228 | 
229 |         if transform is None:
230 |             datasets = [
231 |                 ImageFolderCustom(
232 |                     root=self.basepath,
233 |                     is_valid_file=check_valid(subset_json),
234 |                     instance=x,
235 |                     index=index,
236 |                     train_test_val_instances=train_test_val_instances_wrap,
237 |                     class_data=class_data,
238 |                     pretransform=pretransform,
239 |                 )
240 |                 for x in ("train", "test", "val")
241 |             ]
242 |         elif isinstance(transform, dict):
243 |             # if transform is a dict treat like {'train':transform1, 'test':transform2, 'val':transform3}
244 |             datasets = [
245 |                 ImageFolderCustom(
246 |                     root=self.basepath,
247 |                     transform=transform[x],
248 |                     is_valid_file=check_valid(subset_json),
249 |                     instance=x,
250 |                     index=index,
251 |                     train_test_val_instances=train_test_val_instances_wrap,
252 |                     class_data=class_data,
253 |                     pretransform=pretransform,
254 |                 )
255 |                 for x in ("train", "test", "val")
256 |             ]
257 |         else:
258 |             datasets = [
259 |                 ImageFolderCustom(
260 |                     root=self.basepath,
261 |                     transform=transform,
262 |                     is_valid_file=check_valid(subset_json),
263 |                     instance=x,
264 |                     index=index,
265 |                     train_test_val_instances=train_test_val_instances_wrap,
266 |                     class_data=class_data,
267 |                     pretransform=pretransform,
268 |                 )
269 |                 for x in ("train", "test", "val")
270 |             ]
271 | 
272 |         self._set_class_data(datasets)
273 | 
274 |         custom_collator = self.custom_collator
275 |         drop_last = self.drop_last
276 |         pin_mem = self.pin_mem
277 |         sampler = self.sampler
278 | 
279 |         if sampler is not None:
280 |             dataloaders = [
281 |                 torch.utils.data.DataLoader(
282 |                     x,
283 |                     batch_size=batch_size,
284 |                     num_workers=self.num_workers,
285 |                     collate_fn=custom_collator,
286 |                     sampler=SubsetRandomSampler(list(range(len(x))))
287 |                     if sampler == "subset_sampling"
288 |                     else None,
289 |                     pin_memory=pin_mem,
290 |                     drop_last=drop_last,
291 |                     shuffle=False,
292 |                 )
293 |                 for x in [datasets[0]]
294 |             ] + [
295 |                 torch.utils.data.DataLoader(
296 |                     x,
297 |                     batch_size=batch_size,
298 |                     num_workers=self.num_workers,
299 |                     collate_fn=custom_collator,
300 |                     sampler=SubsetRandomSampler(list(range(len(x))))
301 |                     if sampler == "subset_sampling"
302 |                     else None,
303 |                     pin_memory=pin_mem,
304 |                     drop_last=drop_last,
305 |                 )
306 |                 for x in datasets[1:]
307 |             ]
308 | 
309 |         else:
310 |             dataloaders = [
311 |                 torch.utils.data.DataLoader(
312 |                     x,
313 |                     batch_size=batch_size,
314 |                     num_workers=self.num_workers,
315 |                     collate_fn=custom_collator,
316 |                     shuffle=True,
317 |                     pin_memory=pin_mem,
318 |                     drop_last=drop_last,
319 |                 )
320 |                 for x in [datasets[0]]
321 |             ] + [
322 |                 torch.utils.data.DataLoader(
323 |                     x,
324 |                     batch_size=batch_size,
325 |                     num_workers=self.num_workers,
326 |                     collate_fn=custom_collator,
327 |                     shuffle=False,
328 |                     pin_memory=pin_mem,
329 |                     drop_last=drop_last,
330 |                 )
331 |                 for x in datasets[1:]
332 |             ]
333 | 
334 |         loaders = {
335 |             "train": dataloaders[0],
336 |             "test": dataloaders[1],
337 |             "val": dataloaders[2],
338 |         }
339 | 
340 |         sizes = {
341 |             "train": len(datasets[0]),
342 |             "test": len(datasets[1]),
343 |             "val": len(datasets[2]),
344 |         }
345 | 
346 |         return loaders, sizes
347 | 
348 | 
349 | class UnsupervisedBetterLoader(BetterLoader):
350 |     def __init__(
351 |         self,
352 |         basepath,
353 |         base_experiment_details,
354 |         index_json_path=None,
355 |         num_workers=1,
356 |         index_object=None,
357 |         subset_json_path=None,
358 |         subset_object=None,
359 |         dataset_metadata=None,
360 |     ):
361 | 
362 |         super(UnsupervisedBetterLoader, self).__init__(
363 |             basepath,
364 |             index_json_path,
365 |             num_workers,
366 |             index_object,
367 |             subset_json_path,
368 |             subset_object,
369 |             dataset_metadata,
370 |         )
371 | 
372 |         self.base_experiment_name = (
373 |             base_experiment_details[0]
374 |             if base_experiment_details is not None
375 |             else "simclr"
376 |         )
377 |         self.experiment_transform_params = base_experiment_details[1:]
378 | 
379 |         self.setup_sampling()
380 |         self.transforms = self.setup_transform()
381 | 
382 |     def setup_sampling(self):
383 | 
384 |         if self.base_experiment_name == "simclr":
385 | 
386 |             if self.dataset_metadata.get("sample_type") is None:
387 |                 self.dataset_metadata["sample_type"] = self.sampler = "subset_sampling"
388 | 
389 |         else:
390 |             raise Exception(
391 |                 "Experiment (name: {}) is not currently supported".format(
392 |                     self.base_experiment_name
393 |                 )
394 |             )
395 | 
396 |     def setup_transform(self):
397 | 
398 |         if self.base_experiment_name == "simclr":
399 | 
400 |             if len(self.experiment_transform_params) != 2:
401 |                 raise Exception(
402 |                     "For SimCLR, experiment details should be of the form [experiment_name, side_jitter, input_shape]"
403 |                 )
404 | 
405 |             side_jitter = self.experiment_transform_params[0]
406 |             input_shape = self.experiment_transform_params[1]
407 | 
408 |             color_jitter = transforms.ColorJitter(
409 |                 0.8 * side_jitter,
410 |                 0.8 * side_jitter,
411 |                 0.8 * side_jitter,
412 |                 0.2 * side_jitter,
413 |             )
414 | 
415 |             all_transforms = transforms.Compose(
416 |                 [
417 |                     transforms.RandomResizedCrop(size=input_shape[0]),
418 |                     transforms.RandomHorizontalFlip(),
419 |                     transforms.RandomApply([color_jitter], p=0.8),
420 |                     transforms.RandomGrayscale(p=0.2),
421 |                     transforms.GaussianBlur(
422 |                         kernel_size=[
423 |                             int(0.1 * input_shape[0]),
424 |                             int(0.1 * input_shape[0]),
425 |                         ]
426 |                     ),
427 |                     transforms.ToTensor(),
428 |                 ]
429 |             )
430 | 
431 |             return TransformWhileSampling(all_transforms)
432 | 
433 |         else:
434 | 
435 |             raise Exception(
436 |                 "Experiment (name: {}) is not currently supported".format(
437 |                     self.base_experiment_name
438 |                 )
439 |             )
440 | 
441 |     def fetch_segmented_dataloaders(self, batch_size):
442 | 
443 |         dataloaders, sizes = super(
444 |             UnsupervisedBetterLoader, self
445 |         ).fetch_segmented_dataloaders(batch_size, self.transforms)
446 |         return dataloaders, sizes
447 | 
--------------------------------------------------------------------------------